minor: ruff format
Some checks failed
ci / ruff (push) Successful in 4s
ci / ty (push) Successful in 4s
ci / tests (push) Failing after 7s
ci / build (push) Has been cancelled

Tabs -> spaces
This commit is contained in:
Chris Milne 2026-06-22 15:04:11 +01:00
parent b2921b73b8
commit fab228bf8f
56 changed files with 3629 additions and 3630 deletions

View file

@ -22,5 +22,5 @@ from fastapi import APIRouter
router = APIRouter(
tags=[""],
tags=[""],
)

View file

@ -8,6 +8,6 @@ Exports:
from fastapi import APIRouter
router = APIRouter(
tags=["admin"],
prefix="/admin",
tags=["admin"],
prefix="/admin",
)

View file

@ -26,15 +26,15 @@ api_router.include_router(iam_router)
class HealthCheckResponse(CustomBaseModel):
status: str
status: str
@api_router.get(
path="/healthcheck",
status_code=status.HTTP_200_OK,
response_model=HealthCheckResponse,
include_in_schema=False,
path="/healthcheck",
status_code=status.HTTP_200_OK,
response_model=HealthCheckResponse,
include_in_schema=False,
)
def healthcheck():
"""Simple health check endpoint."""
return {"status": "ok"}
"""Simple health check endpoint."""
return {"status": "ok"}

View file

@ -9,9 +9,9 @@ from src.config import CustomBaseSettings
class AuthConfig(CustomBaseSettings):
OIDC_CONFIG: str = ""
OIDC_ISSUER: str = ""
CLIENT_ID: str = ""
OIDC_CONFIG: str = ""
OIDC_ISSUER: str = ""
CLIENT_ID: str = ""
auth_settings = AuthConfig()

View file

@ -16,92 +16,92 @@ from src.exceptions import ForbiddenException
from src.user.dependencies import user_model_claims_dependency
from src.user.models import User
from src.organisation.dependencies import (
org_model_query_dependency,
org_model_body_dependency,
org_model_query_dependency,
org_model_body_dependency,
)
from src.organisation.models import Organisation as Org
async def org_query_user_claims(
org_model: org_model_query_dependency, user_model: user_model_claims_dependency
org_model: org_model_query_dependency, user_model: user_model_claims_dependency
):
if user_model in org_model.user_rel:
return True
if user_model in org_model.user_rel:
return True
raise ForbiddenException()
raise ForbiddenException()
org_query_user_claims_dependency = Annotated[bool, Depends(org_query_user_claims)]
def get_super_admin_list():
return []
return []
def empty_su_list():
return []
return []
def testing_su_list():
return ["admin@test.com"]
return ["admin@test.com"]
su_list_dependency = Annotated[list[str | None], Depends(get_super_admin_list)]
async def user_model_super_admin(
user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency
user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency
):
if user_model.email in super_admin_emails:
return user_model
if user_model.email in super_admin_emails:
return user_model
raise ForbiddenException(message="Must be super admin")
raise ForbiddenException(message="Must be super admin")
super_admin_dependency = Annotated[User, Depends(user_model_super_admin)]
async def org_query_root_claims(
user_model: user_model_claims_dependency,
org_model: org_model_query_dependency,
su_emails: su_list_dependency,
request: Request,
user_model: user_model_claims_dependency,
org_model: org_model_query_dependency,
su_emails: su_list_dependency,
request: Request,
):
try:
if await user_model_super_admin(user_model, su_emails):
return org_model
except ForbiddenException:
pass
try:
if await user_model_super_admin(user_model, su_emails):
return org_model
except ForbiddenException:
pass
await org_status_check(org_model, request)
await org_status_check(org_model, request)
if org_model.root_user_id == user_model.id:
return org_model
if org_model.root_user_id == user_model.id:
return org_model
raise ForbiddenException(message="Must be the org's root user")
raise ForbiddenException(message="Must be the org's root user")
org_model_root_claim_query_dependency = Annotated[Org, Depends(org_query_root_claims)]
async def org_body_root_claims(
user_model: user_model_claims_dependency,
org_model: org_model_body_dependency,
su_emails: su_list_dependency,
request: Request,
user_model: user_model_claims_dependency,
org_model: org_model_body_dependency,
su_emails: su_list_dependency,
request: Request,
):
try:
if await user_model_super_admin(user_model, su_emails):
return org_model
except ForbiddenException:
pass
try:
if await user_model_super_admin(user_model, su_emails):
return org_model
except ForbiddenException:
pass
await org_status_check(org_model, request)
await org_status_check(org_model, request)
if org_model.root_user_id == user_model.id:
return org_model
if org_model.root_user_id == user_model.id:
return org_model
raise ForbiddenException(message="Must be the org's root user")
raise ForbiddenException(message="Must be the org's root user")
org_model_root_claim_body_dependency = Annotated[Org, Depends(org_body_root_claims)]

View file

@ -8,5 +8,5 @@ Exports:
from fastapi import APIRouter
router = APIRouter(
tags=["auth"],
tags=["auth"],
)

View file

@ -31,56 +31,56 @@ oidc_dependency = Annotated[str, Depends(oidc)]
async def get_dev_user():
return {"db_id": 1, "email": "chris@sr2.uk"}
return {"db_id": 1, "email": "chris@sr2.uk"}
async def get_current_user(
oidc_auth_string: oidc_dependency, db: DbSession
oidc_auth_string: oidc_dependency, db: DbSession
) -> dict[str, Any]:
config_url = urlopen(auth_settings.OIDC_CONFIG)
config = json.loads(config_url.read())
jwks_uri = config["jwks_uri"]
key_response = requests.get(jwks_uri)
jwk_keys = KeySet.import_key_set(key_response.json())
config_url = urlopen(auth_settings.OIDC_CONFIG)
config = json.loads(config_url.read())
jwks_uri = config["jwks_uri"]
key_response = requests.get(jwks_uri)
jwk_keys = KeySet.import_key_set(key_response.json())
token = jwt.decode(oidc_auth_string.replace("Bearer ", ""), jwk_keys)
token = jwt.decode(oidc_auth_string.replace("Bearer ", ""), jwk_keys)
claims_requests = jwt.JWTClaimsRegistry(
exp={"essential": True}, iss={"essential": True, "value": auth_settings.OIDC_ISSUER}
)
claims_requests = jwt.JWTClaimsRegistry(
exp={"essential": True}, iss={"essential": True, "value": auth_settings.OIDC_ISSUER}
)
try:
claims_requests.validate(token.claims)
except ExpiredTokenError:
raise UnauthorizedException(message="Token is expired")
db_id = await add_user(db, token.claims)
try:
claims_requests.validate(token.claims)
except ExpiredTokenError:
raise UnauthorizedException(message="Token is expired")
db_id = await add_user(db, token.claims)
token.claims["db_id"] = db_id
token.claims["db_id"] = db_id
return token.claims
return token.claims
claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)]
async def org_status_check(org_model: Org, request: Request):
org_status = OrgStatus(org_model.status)
if org_status.is_blocked:
raise ForbiddenException("This organisation cannot perform this action.")
org_status = OrgStatus(org_model.status)
if org_status.is_blocked:
raise ForbiddenException("This organisation cannot perform this action.")
root = "/api/v1"
root = "/api/v1"
pre_approval_endpoints = [
f"PATCH{root}/org/status",
f"PATCH{root}/org/questionnaire",
f"GET{root}/org",
f"GET{root}/org/contact",
f"PATCH{root}/org/contact",
f"DELETE{root}/org/self",
]
current_request = f"{request.method}{request.url.path}"
if (
current_request not in pre_approval_endpoints
and org_model.status != OrgStatus.APPROVED
):
raise AwaitingApprovalException(org_model.id)
pre_approval_endpoints = [
f"PATCH{root}/org/status",
f"PATCH{root}/org/questionnaire",
f"GET{root}/org",
f"GET{root}/org/contact",
f"PATCH{root}/org/contact",
f"DELETE{root}/org/self",
]
current_request = f"{request.method}{request.url.path}"
if (
current_request not in pre_approval_endpoints
and org_model.status != OrgStatus.APPROVED
):
raise AwaitingApprovalException(org_model.id)

View file

@ -16,31 +16,31 @@ from src.constants import Environment
class CustomBaseSettings(BaseSettings):
model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", extra="ignore"
)
model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", extra="ignore"
)
class Config(CustomBaseSettings):
APP_VERSION: str = "0.1"
ENVIRONMENT: Environment = Environment.PRODUCTION
SECRET_KEY: SecretStr = SecretStr("")
DISABLE_AUTH: bool = False
APP_VERSION: str = "0.1"
ENVIRONMENT: Environment = Environment.PRODUCTION
SECRET_KEY: SecretStr = SecretStr("")
DISABLE_AUTH: bool = False
CORS_ORIGINS: list[str] = ["*"]
CORS_ORIGINS_REGEX: str | None = None
CORS_HEADERS: list[str] = ["*"]
CORS_ORIGINS: list[str] = ["*"]
CORS_ORIGINS_REGEX: str | None = None
CORS_HEADERS: list[str] = ["*"]
DATABASE_NAME: str = "fastapi-exp"
DATABASE_PORT: str = "5432"
DATABASE_HOSTNAME: str = "localhost"
DATABASE_CREDENTIALS: SecretStr = SecretStr(":")
DATABASE_NAME: str = "fastapi-exp"
DATABASE_PORT: str = "5432"
DATABASE_HOSTNAME: str = "localhost"
DATABASE_CREDENTIALS: SecretStr = SecretStr(":")
DATABASE_POOL_SIZE: int = 16
DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes
DATABASE_POOL_PRE_PING: bool = True
DATABASE_POOL_SIZE: int = 16
DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes
DATABASE_POOL_PRE_PING: bool = True
LETTERMINT_API_TOKEN: SecretStr = SecretStr("")
LETTERMINT_API_TOKEN: SecretStr = SecretStr("")
settings = Config()
@ -51,20 +51,20 @@ DATABASE_HOSTNAME = settings.DATABASE_HOSTNAME
DATABASE_CREDENTIALS = settings.DATABASE_CREDENTIALS.get_secret_value()
# this will support special chars for credentials
_DATABASE_CREDENTIAL_USER, _DATABASE_CREDENTIAL_PASSWORD = str(DATABASE_CREDENTIALS).split(
":"
":"
)
_QUOTED_DATABASE_PASSWORD = parse.quote_plus(str(_DATABASE_CREDENTIAL_PASSWORD))
SQLALCHEMY_DATABASE_URI = SecretStr(
f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}"
f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}"
)
if settings.ENVIRONMENT == Environment.TESTING:
SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:")
SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:")
app_configs: dict[str, Any] = {"title": "App API"}
if settings.ENVIRONMENT.is_deployed:
app_configs["root_path"] = f"/v{settings.APP_VERSION}"
app_configs["root_path"] = f"/v{settings.APP_VERSION}"
if not settings.ENVIRONMENT.is_debug:
app_configs["openapi_url"] = None # hide docs
app_configs["openapi_url"] = None # hide docs

View file

@ -9,29 +9,29 @@ from enum import StrEnum, auto
class Environment(StrEnum):
"""
Enumeration of environments.
"""
Enumeration of environments.
Attributes:
LOCAL (str): Application is running locally
TESTING (str): Application is running in testing mode
STAGING (str): Application is running in staging mode (ie not testing)
PRODUCTION (str): Application is running in production mode
"""
Attributes:
LOCAL (str): Application is running locally
TESTING (str): Application is running in testing mode
STAGING (str): Application is running in staging mode (ie not testing)
PRODUCTION (str): Application is running in production mode
"""
LOCAL = auto()
TESTING = auto()
STAGING = auto()
PRODUCTION = auto()
LOCAL = auto()
TESTING = auto()
STAGING = auto()
PRODUCTION = auto()
@property
def is_debug(self):
return self in (self.LOCAL, self.STAGING, self.TESTING)
@property
def is_debug(self):
return self in (self.LOCAL, self.STAGING, self.TESTING)
@property
def is_testing(self):
return self == self.TESTING
@property
def is_testing(self):
return self == self.TESTING
@property
def is_deployed(self) -> bool:
return self in (self.STAGING, self.PRODUCTION)
@property
def is_deployed(self) -> bool:
return self in (self.STAGING, self.PRODUCTION)

View file

@ -11,13 +11,13 @@ from fastapi import HTTPException, status
class ContactNotFoundException(HTTPException):
def __init__(self, contact_id: Optional[int] = None) -> None:
detail = (
"Contact not found"
if contact_id is None
else f"Contact with ID '{contact_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
def __init__(self, contact_id: Optional[int] = None) -> None:
detail = (
"Contact not found"
if contact_id is None
else f"Contact with ID '{contact_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)

View file

@ -15,22 +15,22 @@ from src.models import CustomBase
class Contact(CustomBase, IdMixin):
__tablename__ = "contact"
__tablename__ = "contact"
email: Mapped[str] = mapped_column(default=None, nullable=True)
first_name: Mapped[str] = mapped_column(default=None, nullable=True)
last_name: Mapped[str] = mapped_column(default=None, nullable=True)
phonenumber: Mapped[str] = mapped_column(default=None, nullable=True)
vat_number: Mapped[str | None] = mapped_column(default=None, nullable=True)
email: Mapped[str] = mapped_column(default=None, nullable=True)
first_name: Mapped[str] = mapped_column(default=None, nullable=True)
last_name: Mapped[str] = mapped_column(default=None, nullable=True)
phonenumber: Mapped[str] = mapped_column(default=None, nullable=True)
vat_number: Mapped[str | None] = mapped_column(default=None, nullable=True)
street_address: Mapped[str] = mapped_column(default=None, nullable=True)
street_address_line_2: Mapped[str] = mapped_column(default=None, nullable=True)
post_office_box_number: Mapped[str | None] = mapped_column(default=None, nullable=True)
locality: Mapped[str] = mapped_column(default=None, nullable=True) # Ie City
country_code: Mapped[str] = mapped_column(default=None, nullable=True) # Eg GB
address_region: Mapped[str | None] = mapped_column(default=None, nullable=True)
postal_code: Mapped[str] = mapped_column(default=None, nullable=True)
street_address: Mapped[str] = mapped_column(default=None, nullable=True)
street_address_line_2: Mapped[str] = mapped_column(default=None, nullable=True)
post_office_box_number: Mapped[str | None] = mapped_column(default=None, nullable=True)
locality: Mapped[str] = mapped_column(default=None, nullable=True) # Ie City
country_code: Mapped[str] = mapped_column(default=None, nullable=True) # Eg GB
address_region: Mapped[str | None] = mapped_column(default=None, nullable=True)
postal_code: Mapped[str] = mapped_column(default=None, nullable=True)
org_id: Mapped[int] = mapped_column(
ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False
)
org_id: Mapped[int] = mapped_column(
ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False
)

View file

@ -6,6 +6,6 @@ from fastapi import APIRouter
router = APIRouter(
prefix="/contact",
tags=["contact"],
prefix="/contact",
tags=["contact"],
)

View file

@ -14,22 +14,22 @@ from src.schemas import CustomBaseModel
class ContactAddress(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore")
model_config = ConfigDict(from_attributes=True, extra="ignore")
post_office_box_number: Optional[str] = None
street_address: Optional[str] = None
street_address_line_2: Optional[str] = None
locality: Optional[str] = None
address_region: Optional[str] = None
country_code: Optional[str] = None
postal_code: Optional[str] = None
post_office_box_number: Optional[str] = None
street_address: Optional[str] = None
street_address_line_2: Optional[str] = None
locality: Optional[str] = None
address_region: Optional[str] = None
country_code: Optional[str] = None
postal_code: Optional[str] = None
class ContactModel(CustomBaseModel):
email: Optional[EmailStr] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
phonenumber: Optional[str] = None
vat_number: Optional[str] = None
email: Optional[EmailStr] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
phonenumber: Optional[str] = None
vat_number: Optional[str] = None
address: ContactAddress
address: ContactAddress

View file

@ -1,6 +1,7 @@
"""
Database connection and session utilities
"""
from contextlib import contextmanager
from typing import Annotated, Generator
from sqlalchemy import create_engine, StaticPool, Connection
@ -29,6 +30,7 @@ else:
sm = sessionmaker(autocommit=False, expire_on_commit=False, bind=engine)
@contextmanager
def get_db_connection() -> Generator[Connection, None, None]:
with engine.connect() as connection:
@ -38,12 +40,15 @@ def get_db_connection() -> Generator[Connection, None, None]:
connection.rollback()
raise
def _get_db_connection() -> Generator[Connection, None]:
with get_db_connection() as connection:
yield connection
DbConnection = Annotated[Connection, Depends(_get_db_connection)]
@contextmanager
def get_db_session() -> Generator[Session, None, None]:
session = sm()
@ -60,4 +65,5 @@ def _get_db_session() -> Generator[Session, None]:
with get_db_session() as session:
yield session
DbSession = Annotated[Session, Depends(_get_db_session)]

View file

@ -12,36 +12,36 @@ from fastapi import HTTPException, status
class UnprocessableContentException(HTTPException):
def __init__(self, message: Optional[str] = None) -> None:
detail = "Unprocessable content" if not message else message
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail=detail,
)
def __init__(self, message: Optional[str] = None) -> None:
detail = "Unprocessable content" if not message else message
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail=detail,
)
class ConflictException(HTTPException):
def __init__(self, message: Optional[str] = None) -> None:
detail = "Conflict" if not message else message
super().__init__(
status_code=status.HTTP_409_CONFLICT,
detail=detail,
)
def __init__(self, message: Optional[str] = None) -> None:
detail = "Conflict" if not message else message
super().__init__(
status_code=status.HTTP_409_CONFLICT,
detail=detail,
)
class ForbiddenException(HTTPException):
def __init__(self, message: Optional[str] = None) -> None:
detail = "Forbidden" if not message else message
super().__init__(
status_code=status.HTTP_403_FORBIDDEN,
detail=detail,
)
def __init__(self, message: Optional[str] = None) -> None:
detail = "Forbidden" if not message else message
super().__init__(
status_code=status.HTTP_403_FORBIDDEN,
detail=detail,
)
class UnauthorizedException(HTTPException):
def __init__(self, message: Optional[str] = None) -> None:
detail = "Not authorized" if not message else message
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,
)
def __init__(self, message: Optional[str] = None) -> None:
detail = "Not authorized" if not message else message
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,
)

View file

@ -18,59 +18,55 @@ from src.iam.exceptions import GroupNotFoundException, PermNotFoundException
from src.iam.schemas import GroupIDMixin, PermIDMixin
def get_group_model_query(
db: DbSession, group_id: Annotated[int, Query(gt=0)]
) -> Group:
group_model = db.get(Group, group_id)
if group_model is None:
raise GroupNotFoundException(group_id)
def get_group_model_query(db: DbSession, group_id: Annotated[int, Query(gt=0)]) -> Group:
group_model = db.get(Group, group_id)
if group_model is None:
raise GroupNotFoundException(group_id)
return group_model
return group_model
group_model_query_dependency = Annotated[Group, Depends(get_group_model_query)]
def get_group_model_body(
db: DbSession, request_model: Optional[GroupIDMixin] = None
db: DbSession, request_model: Optional[GroupIDMixin] = None
) -> Group:
group_id = getattr(request_model, "group_id", None)
if group_id is None:
raise GroupNotFoundException()
group_model = db.get(Group, group_id)
if group_model is None:
raise GroupNotFoundException(group_id)
group_id = getattr(request_model, "group_id", None)
if group_id is None:
raise GroupNotFoundException()
group_model = db.get(Group, group_id)
if group_model is None:
raise GroupNotFoundException(group_id)
return group_model
return group_model
group_model_body_dependency = Annotated[Group, Depends(get_group_model_body)]
def get_perm_model_body(
db: DbSession, request_model: Optional[PermIDMixin] = None
db: DbSession, request_model: Optional[PermIDMixin] = None
) -> Permission:
perm_id = getattr(request_model, "permission_id", None)
if perm_id is None:
raise PermNotFoundException
perm_model = db.get(Permission, perm_id)
if perm_model is None:
raise PermNotFoundException(perm_id)
perm_id = getattr(request_model, "permission_id", None)
if perm_id is None:
raise PermNotFoundException
perm_model = db.get(Permission, perm_id)
if perm_model is None:
raise PermNotFoundException(perm_id)
return perm_model
return perm_model
perm_model_body_dependency = Annotated[Permission, Depends(get_perm_model_body)]
def get_perm_model_query(
db: DbSession, perm_id: Annotated[int, Query(gt=0)]
) -> Permission:
perm_model = db.get(Permission, perm_id)
if perm_model is None:
raise PermNotFoundException(perm_id)
def get_perm_model_query(db: DbSession, perm_id: Annotated[int, Query(gt=0)]) -> Permission:
perm_model = db.get(Permission, perm_id)
if perm_model is None:
raise PermNotFoundException(perm_id)
return perm_model
return perm_model
perm_model_query_dependency = Annotated[Permission, Depends(get_perm_model_query)]

View file

@ -12,26 +12,26 @@ from fastapi import HTTPException, status
class GroupNotFoundException(HTTPException):
def __init__(self, group_id: Optional[int] = None) -> None:
detail = (
"Group not found"
if group_id is None
else f"User with ID '{group_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
def __init__(self, group_id: Optional[int] = None) -> None:
detail = (
"Group not found"
if group_id is None
else f"User with ID '{group_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
class PermNotFoundException(HTTPException):
def __init__(self, perm_id: Optional[int] = None) -> None:
detail = (
"Permission not found"
if perm_id is None
else f"User with ID '{perm_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
def __init__(self, perm_id: Optional[int] = None) -> None:
detail = (
"Permission not found"
if perm_id is None
else f"User with ID '{perm_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)

View file

@ -25,90 +25,90 @@ from src.models import CustomBase, IdMixin
class Permission(CustomBase, IdMixin):
__tablename__ = "permission"
__tablename__ = "permission"
resource: Mapped[str]
action: Mapped[str]
resource: Mapped[str]
action: Mapped[str]
service_id: Mapped[int] = mapped_column(ForeignKey("service.id", ondelete="CASCADE"))
service_id: Mapped[int] = mapped_column(ForeignKey("service.id", ondelete="CASCADE"))
__table_args__ = (
UniqueConstraint(
"service_id",
"resource",
"action",
name="uniq_permission_resource_and_action",
),
)
__table_args__ = (
UniqueConstraint(
"service_id",
"resource",
"action",
name="uniq_permission_resource_and_action",
),
)
service_rel = relationship(
"Service",
back_populates="permission_rel",
foreign_keys="Permission.service_id",
)
service_rel = relationship(
"Service",
back_populates="permission_rel",
foreign_keys="Permission.service_id",
)
group_rel = relationship(
"Group", secondary="group_permissions", back_populates="permission_rel"
)
group_rel = relationship(
"Group", secondary="group_permissions", back_populates="permission_rel"
)
org_rel = relationship(
"Organisation", secondary="org_permissions", back_populates="permission_rel"
)
org_rel = relationship(
"Organisation", secondary="org_permissions", back_populates="permission_rel"
)
@property
def service_name(self):
return self.service_rel.name
@property
def service_name(self):
return self.service_rel.name
class Group(CustomBase, IdMixin):
__tablename__ = "group"
__tablename__ = "group"
name: Mapped[str]
name: Mapped[str]
org_id: Mapped[int] = mapped_column(ForeignKey("organisation.id", ondelete="CASCADE"))
org_id: Mapped[int] = mapped_column(ForeignKey("organisation.id", ondelete="CASCADE"))
__table_args__ = (
UniqueConstraint(
"name",
"org_id",
name="uniq_group_name_org_id",
),
)
__table_args__ = (
UniqueConstraint(
"name",
"org_id",
name="uniq_group_name_org_id",
),
)
user_rel = relationship("User", secondary="user_groups", back_populates="group_rel")
user_rel = relationship("User", secondary="user_groups", back_populates="group_rel")
org_rel = relationship("Organisation", back_populates="group_rel")
org_rel = relationship("Organisation", back_populates="group_rel")
permission_rel = relationship(
"Permission", secondary="group_permissions", back_populates="group_rel"
)
permission_rel = relationship(
"Permission", secondary="group_permissions", back_populates="group_rel"
)
class GroupPermissions(CustomBase):
__tablename__ = "group_permissions"
group_id: Mapped[int] = mapped_column(
ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
)
permission_id: Mapped[int] = mapped_column(
ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
)
__tablename__ = "group_permissions"
group_id: Mapped[int] = mapped_column(
ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
)
permission_id: Mapped[int] = mapped_column(
ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
)
class UserGroups(CustomBase):
__tablename__ = "user_groups"
user_id: Mapped[int] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
)
group_id: Mapped[int] = mapped_column(
ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
)
__tablename__ = "user_groups"
user_id: Mapped[int] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
)
group_id: Mapped[int] = mapped_column(
ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
)
class OrgPermissions(CustomBase):
__tablename__ = "org_permissions"
org_id: Mapped[int] = mapped_column(
ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True
)
permission_id: Mapped[int] = mapped_column(
ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
)
__tablename__ = "org_permissions"
org_id: Mapped[int] = mapped_column(
ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True
)
permission_id: Mapped[int] = mapped_column(
ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
)

File diff suppressed because it is too large Load diff

View file

@ -11,151 +11,151 @@ from typing import Optional, Annotated
from pydantic import EmailStr, ConfigDict, Field
from src.schemas import (
CustomBaseModel,
ResourceName,
ServiceIDMixin,
OrgIDMixin,
UserIDMixin,
PermIDMixin,
GroupIDMixin,
GroupSummary,
OrgSummary,
UserSummary,
CustomBaseModel,
ResourceName,
ServiceIDMixin,
OrgIDMixin,
UserIDMixin,
PermIDMixin,
GroupIDMixin,
GroupSummary,
OrgSummary,
UserSummary,
)
class UserSchema(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore")
model_config = ConfigDict(from_attributes=True, extra="ignore")
id: int
first_name: str
last_name: str
email: EmailStr
id: int
first_name: str
last_name: str
email: EmailStr
class PermissionSchema(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore")
model_config = ConfigDict(from_attributes=True, extra="ignore")
id: int
service_name: str
resource: str
action: str
id: int
service_name: str
resource: str
action: str
class GroupDetails(CustomBaseModel):
details: GroupSummary
permissions: list[PermissionSchema]
details: GroupSummary
permissions: list[PermissionSchema]
class IAMCAoRRequest(CustomBaseModel):
action: str
rn: ResourceName
action: str
rn: ResourceName
class IAMCAoRResponse(CustomBaseModel):
allowed: bool
user: UserSummary
action: str
rn: ResourceName
allowed: bool
user: UserSummary
action: str
rn: ResourceName
class IAMGetGroupPermissionsResponse(CustomBaseModel):
organisation: OrgSummary
group: GroupSummary
permissions: list[PermissionSchema]
organisation: OrgSummary
group: GroupSummary
permissions: list[PermissionSchema]
class IAMGetGroupUsersResponse(CustomBaseModel):
organisation: OrgSummary
group: GroupSummary
users: list[UserSummary]
organisation: OrgSummary
group: GroupSummary
users: list[UserSummary]
class IAMPostGroupRequest(OrgIDMixin):
name: str = Field(min_length=3)
name: str = Field(min_length=3)
class IAMPostGroupResponse(CustomBaseModel):
organisation: OrgSummary
group: GroupSummary
organisation: OrgSummary
group: GroupSummary
class IAMPutGroupPermissionRequest(GroupIDMixin, PermIDMixin, OrgIDMixin):
pass
pass
class IAMPutGroupPermissionResponse(CustomBaseModel):
organisation: OrgSummary
group: GroupSummary
permissions: list[PermissionSchema]
organisation: OrgSummary
group: GroupSummary
permissions: list[PermissionSchema]
class IAMPutGroupUserRequest(GroupIDMixin, UserIDMixin, OrgIDMixin):
pass
pass
class IAMPutGroupUserResponse(CustomBaseModel):
group: GroupSummary
users: list[UserSchema]
group: GroupSummary
users: list[UserSchema]
class IAMDeleteGroupPermissionResponse(CustomBaseModel):
group: GroupSummary
permissions: list[PermissionSchema]
group: GroupSummary
permissions: list[PermissionSchema]
class IAMDeleteGroupUserResponse(CustomBaseModel):
group: GroupSummary
users: list[UserSchema]
group: GroupSummary
users: list[UserSchema]
class IAMGetPermissionsResponse(CustomBaseModel):
permissions: list[PermissionSchema]
permissions: list[PermissionSchema]
class IAMPostPermissionRequest(ServiceIDMixin):
resource: str
action: str
resource: str
action: str
class IAMPostPermissionResponse(CustomBaseModel):
permission: PermissionSchema
permission: PermissionSchema
class IAMGetPermissionsSearchRequest(OrgIDMixin):
service_id: Annotated[int | None, Field(gt=0)] = None
resource: Optional[str] = None
action: Optional[str] = None
service_id: Annotated[int | None, Field(gt=0)] = None
resource: Optional[str] = None
action: Optional[str] = None
class IAMGetPermissionsSearchResponse(CustomBaseModel):
permissions: list[PermissionSchema]
permissions: list[PermissionSchema]
class IAMPutGroupInvitationRequest(OrgIDMixin, GroupIDMixin):
user_email: EmailStr
user_email: EmailStr
class IAMPutGroupInvitationResponse(CustomBaseModel):
organisation: OrgSummary
group: GroupSummary
invited_email: EmailStr
organisation: OrgSummary
group: GroupSummary
invited_email: EmailStr
class IAMPutGroupInvitationAcceptRequest(CustomBaseModel):
jwt: str
jwt: str
class IAMPutGroupInvitationAcceptResponse(CustomBaseModel):
organisation: OrgSummary
user: UserSummary
group: GroupDetails
organisation: OrgSummary
user: UserSummary
group: GroupDetails
class IAMPutOrgPermissionsRequest(OrgIDMixin):
permissions: list[int]
permissions: list[int]
class IAMPutOrgPermissionsResponse(CustomBaseModel):
organisation: OrgSummary
permissions: list[PermissionSchema]
organisation: OrgSummary
permissions: list[PermissionSchema]

View file

@ -23,90 +23,90 @@ from src.service.schemas import HasServiceName
def valid_service_key(
db: DbSession, request: Request, request_model: HasServiceName
db: DbSession, request: Request, request_model: HasServiceName
) -> bool:
rn = request_model.rn
api_key = request.headers.get("X-API-Key", None)
if not api_key:
raise UnauthorizedException("Missing API key")
service = rn.service
result = (
db.query(Service)
.filter(Service.name == service)
.filter(Service.api_key == api_key)
.first()
)
if result is None:
raise UnauthorizedException("Invalid API key")
rn = request_model.rn
api_key = request.headers.get("X-API-Key", None)
if not api_key:
raise UnauthorizedException("Missing API key")
service = rn.service
result = (
db.query(Service)
.filter(Service.name == service)
.filter(Service.api_key == api_key)
.first()
)
if result is None:
raise UnauthorizedException("Invalid API key")
return True
return True
service_key_dependency = Annotated[bool, Depends(valid_service_key)]
async def send_user_group_invitation(
user_email: str, org_name: str, org_id: int, group_id: int, group_name: str
user_email: str, org_name: str, org_id: int, group_id: int, group_name: str
):
expiry_delta = timedelta(hours=24)
expiry = datetime.now(timezone.utc) + expiry_delta
claims = {
"email": user_email,
"org_id": org_id,
"group_id": group_id,
"group_name": group_name,
"exp": expiry,
"type": "group_invite",
}
expiry_delta = timedelta(hours=24)
expiry = datetime.now(timezone.utc) + expiry_delta
claims = {
"email": user_email,
"org_id": org_id,
"group_id": group_id,
"group_name": group_name,
"exp": expiry,
"type": "group_invite",
}
token = await generate_jwt(claims)
subject = f"You have been invited to join a group of {org_name}"
body = f"You have been invited to join {group_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
token = await generate_jwt(claims)
subject = f"You have been invited to join a group of {org_name}"
body = f"You have been invited to join {group_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
await send_email(
recipient=user_email,
subject=subject,
body=body,
)
await send_email(
recipient=user_email,
subject=subject,
body=body,
)
async def create_group_and_assign_perms(
db: Session, org_model: Org, group_name: str, perm_list: list[int]
db: Session, org_model: Org, group_name: str, perm_list: list[int]
):
new_group = Group(name=group_name, org_id=org_model.id)
db.add(new_group)
db.flush()
new_group = Group(name=group_name, org_id=org_model.id)
db.add(new_group)
db.flush()
for permission in perm_list:
perm_model = db.get(Perm, permission)
for permission in perm_list:
perm_model = db.get(Perm, permission)
if perm_model is None:
continue
if perm_model is None:
continue
new_group.permission_rel.append(perm_model)
db.flush()
new_group.permission_rel.append(perm_model)
db.flush()
return new_group
return new_group
async def assign_default_group(
db: DbSession,
org_model: Org,
user_model: User,
group_name: str,
perm_list: list[int],
db: DbSession,
org_model: Org,
user_model: User,
group_name: str,
perm_list: list[int],
):
group_model = (
db.query(Group)
.filter(Group.org_id == org_model.id)
.filter(Group.name == group_name)
.first()
)
group_model = (
db.query(Group)
.filter(Group.org_id == org_model.id)
.filter(Group.name == group_name)
.first()
)
if group_model is None:
group_model = await create_group_and_assign_perms(
db=db, group_name=group_name, org_model=org_model, perm_list=perm_list
)
if group_model is None:
group_model = await create_group_and_assign_perms(
db=db, group_name=group_name, org_model=org_model, perm_list=perm_list
)
user_model.group_rel.append(group_model)
db.flush()
user_model.group_rel.append(group_model)
db.flush()

View file

@ -1,6 +1,7 @@
"""
Application root file: Inits the FastAPI application
"""
import os.path
from contextlib import asynccontextmanager
from typing import AsyncGenerator
@ -19,43 +20,43 @@ from src.auth.service import get_current_user, get_dev_user
@asynccontextmanager
async def lifespan(_application: FastAPI) -> AsyncGenerator:
# Startup
yield
# Shutdown
# Startup
yield
# Shutdown
if settings.ENVIRONMENT.is_deployed:
# Just a precaution, should be False anyway
settings.DISABLE_AUTH = False
# Just a precaution, should be False anyway
settings.DISABLE_AUTH = False
tags_metadata = [
{
"name": "User",
"description": "User related operations, includes getting information about the current user",
},
{
"name": "Organisation",
"description": "Organisation related operations, includes getting lists of users etc associated with orgs",
},
{
"name": "Service",
"description": "Services related operations, includes registering services and reissuing API keys",
},
{
"name": "IAM",
"description": "Operations related to the role based identity and access management system. This includes management of groups, permissions, and related users.",
},
{
"name": "User",
"description": "User related operations, includes getting information about the current user",
},
{
"name": "Organisation",
"description": "Organisation related operations, includes getting lists of users etc associated with orgs",
},
{
"name": "Service",
"description": "Services related operations, includes registering services and reissuing API keys",
},
{
"name": "IAM",
"description": "Operations related to the role based identity and access management system. This includes management of groups, permissions, and related users.",
},
]
app = FastAPI(
swagger_ui_init_oauth={
"clientId": auth_settings.CLIENT_ID,
"usePkceWithAuthorizationCodeGrant": True,
"scopes": "openid profile email",
},
openapi_tags=tags_metadata,
swagger_ui_init_oauth={
"clientId": auth_settings.CLIENT_ID,
"usePkceWithAuthorizationCodeGrant": True,
"scopes": "openid profile email",
},
openapi_tags=tags_metadata,
)
# Type inspection disabled for middleware injection.
@ -64,19 +65,19 @@ app = FastAPI(
app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY.get_secret_value())
# noinspection PyTypeChecker
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_origin_regex=settings.CORS_ORIGINS_REGEX,
allow_credentials=True,
allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"),
allow_headers=settings.CORS_HEADERS,
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_origin_regex=settings.CORS_ORIGINS_REGEX,
allow_credentials=True,
allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"),
allow_headers=settings.CORS_HEADERS,
)
if settings.DISABLE_AUTH and (settings.ENVIRONMENT == Environment.LOCAL):
app.dependency_overrides[get_current_user] = get_dev_user
app.dependency_overrides[get_current_user] = get_dev_user
app.include_router(api_router)
if os.path.exists("/app/static"):
app.frontend("/ui", directory="/app/static", fallback="index.html")
app.frontend("/ui", directory="/app/static", fallback="index.html")

View file

@ -10,28 +10,28 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class CustomBase(DeclarativeBase):
type_annotation_map = {
datetime: DateTime(timezone=True),
dict[str, Any]: JSON,
}
type_annotation_map = {
datetime: DateTime(timezone=True),
dict[str, Any]: JSON,
}
class ActivatedMixin:
active: Mapped[bool] = mapped_column(default=True)
active: Mapped[bool] = mapped_column(default=True)
class DeletedTimestampMixin:
deleted_at: Mapped[datetime | None] = mapped_column(nullable=True)
deleted_at: Mapped[datetime | None] = mapped_column(nullable=True)
class DescriptionMixin:
description: Mapped[str]
description: Mapped[str]
class IdMixin:
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
class TimestampMixin:
created_at: Mapped[datetime] = mapped_column(default=func.now())
updated_at: Mapped[datetime] = mapped_column(default=func.now(), onupdate=func.now())
created_at: Mapped[datetime] = mapped_column(default=func.now())
updated_at: Mapped[datetime] = mapped_column(default=func.now(), onupdate=func.now())

View file

@ -10,48 +10,48 @@ from enum import StrEnum, auto
class Status(StrEnum):
"""
Enumeration of organisation statuses.
"""
Enumeration of organisation statuses.
Attributes:
PARTIAL(str): Organisation has been created but questionnaire hasn't been submitted.
SUBMITTED (str): Questionnaire submitted but not approved.
REMEDIATION (str): Questionnaire submitted but requires revisions.
APPROVED (str): Questionnaire has been approved by an admin.
REJECTED (str): Questionnaire has been rejected by an admin.
REMOVED (str): Organisation has been removed.
"""
Attributes:
PARTIAL(str): Organisation has been created but questionnaire hasn't been submitted.
SUBMITTED (str): Questionnaire submitted but not approved.
REMEDIATION (str): Questionnaire submitted but requires revisions.
APPROVED (str): Questionnaire has been approved by an admin.
REJECTED (str): Questionnaire has been rejected by an admin.
REMOVED (str): Organisation has been removed.
"""
PARTIAL = auto()
SUBMITTED = auto()
REMEDIATION = auto()
APPROVED = auto()
REJECTED = auto()
REMOVED = auto()
PARTIAL = auto()
SUBMITTED = auto()
REMEDIATION = auto()
APPROVED = auto()
REJECTED = auto()
REMOVED = auto()
@property
def is_pre_approval(self):
return self in (self.PARTIAL, self.SUBMITTED, self.REMEDIATION)
@property
def is_pre_approval(self):
return self in (self.PARTIAL, self.SUBMITTED, self.REMEDIATION)
@property
def is_pre_submission(self):
return self in (self.PARTIAL, self.REMEDIATION)
@property
def is_pre_submission(self):
return self in (self.PARTIAL, self.REMEDIATION)
@property
def is_blocked(self):
return self in (self.REMOVED, self.REJECTED)
@property
def is_blocked(self):
return self in (self.REMOVED, self.REJECTED)
class ContactType(StrEnum):
"""
Enumeration of organisation contact types.
"""
Enumeration of organisation contact types.
Attributes:
BILLING(str): Billing contact.
SECURITY (str): Security contact.
OWNER (str): Owner contact.
"""
Attributes:
BILLING(str): Billing contact.
SECURITY (str): Security contact.
OWNER (str): Owner contact.
"""
BILLING = auto()
SECURITY = auto()
OWNER = auto()
BILLING = auto()
SECURITY = auto()
OWNER = auto()

View file

@ -18,25 +18,25 @@ from src.organisation.exceptions import OrgNotFoundException
def get_org_model_query(db: DbSession, org_id: Annotated[int, Query(gt=0)]) -> Org:
org_model = db.get(Org, org_id)
if org_model is None:
raise OrgNotFoundException(org_id)
return org_model
org_model = db.get(Org, org_id)
if org_model is None:
raise OrgNotFoundException(org_id)
return org_model
org_model_query_dependency = Annotated[Org, Depends(get_org_model_query)]
def get_org_model_body(db: DbSession, request_model: OrgIDMixin) -> Org:
org_id: Optional[int] = getattr(request_model, "organisation_id", None)
if org_id is None:
raise OrgNotFoundException()
org_id: Optional[int] = getattr(request_model, "organisation_id", None)
if org_id is None:
raise OrgNotFoundException()
org_model = db.get(Org, org_id)
if org_model is None:
raise OrgNotFoundException(org_id)
org_model = db.get(Org, org_id)
if org_model is None:
raise OrgNotFoundException(org_id)
return org_model
return org_model
org_model_body_dependency = Annotated[Org, Depends(get_org_model_body)]

View file

@ -12,26 +12,26 @@ from fastapi import HTTPException, status
class OrgNotFoundException(HTTPException):
def __init__(self, org_id: Optional[int] = None) -> None:
detail = (
"Organisation not found"
if org_id is None
else f"Organisation with ID '{org_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
def __init__(self, org_id: Optional[int] = None) -> None:
detail = (
"Organisation not found"
if org_id is None
else f"Organisation with ID '{org_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
class AwaitingApprovalException(HTTPException):
def __init__(self, org_id: Optional[int] = None) -> None:
detail = (
"Organisation has not been approved."
if org_id is None
else f"Organisation with ID '{org_id}' has not been approved."
)
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,
)
def __init__(self, org_id: Optional[int] = None) -> None:
detail = (
"Organisation has not been approved."
if org_id is None
else f"Organisation with ID '{org_id}' has not been approved."
)
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,
)

View file

@ -25,51 +25,51 @@ from src.models import CustomBase, TimestampMixin
class Organisation(CustomBase, IdMixin, TimestampMixin, DeletedTimestampMixin):
__tablename__ = "organisation"
__tablename__ = "organisation"
name: Mapped[str]
status: Mapped[str] = mapped_column(default="partial")
intake_questionnaire: Mapped[dict[str, Any] | None]
name: Mapped[str]
status: Mapped[str] = mapped_column(default="partial")
intake_questionnaire: Mapped[dict[str, Any] | None]
root_user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
billing_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True)
security_contact_id: Mapped[int] = mapped_column(
ForeignKey("contact.id"), nullable=True
)
owner_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True)
root_user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
billing_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True)
security_contact_id: Mapped[int] = mapped_column(
ForeignKey("contact.id"), nullable=True
)
owner_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True)
user_rel = relationship("User", secondary="orgusers", back_populates="organisation_rel")
user_rel = relationship("User", secondary="orgusers", back_populates="organisation_rel")
group_rel = relationship(
"Group", back_populates="org_rel", cascade="all, delete-orphan"
)
root_user_rel = relationship("User", foreign_keys="Organisation.root_user_id")
group_rel = relationship(
"Group", back_populates="org_rel", cascade="all, delete-orphan"
)
root_user_rel = relationship("User", foreign_keys="Organisation.root_user_id")
billing_contact_rel = relationship(
"Contact", foreign_keys="Organisation.billing_contact_id"
)
security_contact_rel = relationship(
"Contact", foreign_keys="Organisation.security_contact_id"
)
owner_contact_rel = relationship(
"Contact", foreign_keys="Organisation.owner_contact_id"
)
billing_contact_rel = relationship(
"Contact", foreign_keys="Organisation.billing_contact_id"
)
security_contact_rel = relationship(
"Contact", foreign_keys="Organisation.security_contact_id"
)
owner_contact_rel = relationship(
"Contact", foreign_keys="Organisation.owner_contact_id"
)
permission_rel = relationship(
"Permission", secondary="org_permissions", back_populates="org_rel"
)
permission_rel = relationship(
"Permission", secondary="org_permissions", back_populates="org_rel"
)
@property
def root_user_email(self) -> str:
return self.root_user_rel.email if self.root_user_rel else ""
@property
def root_user_email(self) -> str:
return self.root_user_rel.email if self.root_user_rel else ""
class OrgUsers(CustomBase):
__tablename__ = "orgusers"
__tablename__ = "orgusers"
org_id: Mapped[int] = mapped_column(
ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True
)
user_id: Mapped[int] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
)
org_id: Mapped[int] = mapped_column(
ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True
)
user_id: Mapped[int] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
)

File diff suppressed because it is too large Load diff

View file

@ -12,139 +12,139 @@ from datetime import datetime
from pydantic import EmailStr, ConfigDict, Field
from src.schemas import (
CustomBaseModel,
OrgIDMixin,
UserIDMixin,
GroupSummary,
OrgSummary,
UserSummary,
CustomBaseModel,
OrgIDMixin,
UserIDMixin,
GroupSummary,
OrgSummary,
UserSummary,
)
from src.contact.schemas import ContactModel
from src.organisation.constants import Status, ContactType
from src.organisation.schemas_questionnaires import (
QuestionnaireQuestionsVersion0 as CurrentQuestions,
questionnaire_union,
QuestionnaireQuestionsVersion0 as CurrentQuestions,
questionnaire_union,
)
class QuestionnaireMetadata(CustomBaseModel):
version: int
submission_date: Optional[datetime] = None
version: int
submission_date: Optional[datetime] = None
class Questionnaire(CustomBaseModel):
metadata: QuestionnaireMetadata
questions: questionnaire_union
metadata: QuestionnaireMetadata
questions: questionnaire_union
class ContactSummary(CustomBaseModel):
id: int
email: Optional[EmailStr] = None
id: int
email: Optional[EmailStr] = None
class OrgSchema(OrgIDMixin):
name: str
status: Status
root_user_email: EmailStr
intake_questionnaire: Optional[Questionnaire] = None
name: str
status: Status
root_user_email: EmailStr
intake_questionnaire: Optional[Questionnaire] = None
billing_contact: ContactSummary
owner_contact: ContactSummary
security_contact: ContactSummary
billing_contact: ContactSummary
owner_contact: ContactSummary
security_contact: ContactSummary
class OrgPostOrgRequest(CustomBaseModel):
name: str = Field(min_length=3)
intake_questionnaire: Optional[CurrentQuestions] = None
name: str = Field(min_length=3)
intake_questionnaire: Optional[CurrentQuestions] = None
class OrgPostOrgResponse(CustomBaseModel):
id: int
name: str
status: Status
id: int
name: str
status: Status
class OrgPatchQuestionnaireRequest(OrgIDMixin):
intake_questionnaire: CurrentQuestions
partial: bool
intake_questionnaire: CurrentQuestions
partial: bool
class OrgPatchQuestionnaireResponse(CustomBaseModel):
id: int
name: str
intake_questionnaire: Questionnaire
status: Status
id: int
name: str
intake_questionnaire: Questionnaire
status: Status
class OrgPatchStatusRequest(OrgIDMixin):
status: Status
status: Status
class OrgPatchStatusResponse(CustomBaseModel):
id: int
name: str
status: Status
id: int
name: str
status: Status
class OrgPatchContactRequest(OrgIDMixin):
contact_type: ContactType
contact_type: ContactType
email: Optional[EmailStr] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
phonenumber: Optional[str] = None
vat_number: Optional[str] = None
post_office_box_number: Optional[str] = None
street_address: Optional[str] = None
street_address_line_2: Optional[str] = None
locality: Optional[str] = None
address_region: Optional[str] = None
country_code: Optional[str] = None
postal_code: Optional[str] = None
email: Optional[EmailStr] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
phonenumber: Optional[str] = None
vat_number: Optional[str] = None
post_office_box_number: Optional[str] = None
street_address: Optional[str] = None
street_address_line_2: Optional[str] = None
locality: Optional[str] = None
address_region: Optional[str] = None
country_code: Optional[str] = None
postal_code: Optional[str] = None
class OrgPostUserRequest(OrgIDMixin, UserIDMixin):
pass
pass
class OrgPostUserResponse(CustomBaseModel):
organisation: OrgSummary
users: list[UserSummary]
organisation: OrgSummary
users: list[UserSummary]
class OrgPatchRootRequest(OrgIDMixin, UserIDMixin):
pass
pass
class OrgPatchRootResponse(CustomBaseModel):
name: str
root_user_email: str
name: str
root_user_email: str
class OrgGetUserResponse(CustomBaseModel):
users: list[dict[str, str | int]]
organisation: OrgSummary
users: list[dict[str, str | int]]
organisation: OrgSummary
class OrgGetGroupResponse(CustomBaseModel):
organisation: OrgSummary
groups: list[GroupSummary]
organisation: OrgSummary
groups: list[GroupSummary]
class OrgGetContactResponse(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore")
model_config = ConfigDict(from_attributes=True, extra="ignore")
contact: ContactModel
organisation: OrgSummary
contact: ContactModel
organisation: OrgSummary
class OrgPatchContactResponse(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore")
model_config = ConfigDict(from_attributes=True, extra="ignore")
contact: ContactModel
organisation: OrgSummary
contact: ContactModel
organisation: OrgSummary
class OrgGetOrgResponse(CustomBaseModel):
organisations: list[OrgSchema]
organisations: list[OrgSchema]

View file

@ -4,13 +4,13 @@ from src.schemas import CustomBaseModel
class QuestionnaireQuestions(CustomBaseModel):
pass
pass
class QuestionnaireQuestionsVersion0(QuestionnaireQuestions):
question_one: Optional[str] = None
question_two: Optional[str] = None
question_three: Optional[str] = None
question_one: Optional[str] = None
question_two: Optional[str] = None
question_three: Optional[str] = None
questionnaire_union = QuestionnaireQuestionsVersion0 # | QuestionnaireQuestionsVersion1

View file

@ -11,57 +11,57 @@ from src.user.models import User
async def add_default_org_permissions(
db: Session,
org_model: Org,
perm_list: list[int],
db: Session,
org_model: Org,
perm_list: list[int],
):
for permission in perm_list:
perm_model = db.get(Perm, permission)
for permission in perm_list:
perm_model = db.get(Perm, permission)
if perm_model is None:
continue
if perm_model is None:
continue
if perm_model in org_model.permission_rel:
continue
if perm_model in org_model.permission_rel:
continue
org_model.permission_rel.append(perm_model)
db.flush()
org_model.permission_rel.append(perm_model)
db.flush()
db.commit()
db.commit()
async def assign_defaults(
db: Session,
org_id: int,
user_id: int,
db: Session,
org_id: int,
user_id: int,
):
default_org_permissions = []
default_org_permissions = []
default_user_permissions = []
default_user_permissions = []
org_model = db.get(Org, org_id)
if org_model is None:
print("Org not found while adding defaults")
return
org_model = db.get(Org, org_id)
if org_model is None:
print("Org not found while adding defaults")
return
user_model = db.get(User, user_id)
if user_model is None:
print("User not found while adding defaults")
return
user_model = db.get(User, user_id)
if user_model is None:
print("User not found while adding defaults")
return
await add_default_org_permissions(db, org_model, default_org_permissions)
await assign_default_group(
db=db,
org_model=org_model,
user_model=user_model,
group_name="Default Users",
perm_list=default_user_permissions,
)
await assign_default_group(
db=db,
org_model=org_model,
user_model=user_model,
group_name="Root User",
perm_list=default_org_permissions,
)
db.commit()
await add_default_org_permissions(db, org_model, default_org_permissions)
await assign_default_group(
db=db,
org_model=org_model,
user_model=user_model,
group_name="Default Users",
perm_list=default_user_permissions,
)
await assign_default_group(
db=db,
org_model=org_model,
user_model=user_model,
group_name="Root User",
perm_list=default_org_permissions,
)
db.commit()

View file

@ -11,54 +11,54 @@ from typing import Optional
class CustomBaseModel(BaseModel):
pass
pass
### Mixins ###
class OrgIDMixin(CustomBaseModel):
organisation_id: int = Field(gt=0)
organisation_id: int = Field(gt=0)
class GroupIDMixin(CustomBaseModel):
group_id: int = Field(gt=0)
group_id: int = Field(gt=0)
class PermIDMixin(CustomBaseModel):
permission_id: int = Field(gt=0)
permission_id: int = Field(gt=0)
class ServiceIDMixin(CustomBaseModel):
service_id: int = Field(gt=0)
service_id: int = Field(gt=0)
class UserIDMixin(CustomBaseModel):
user_id: int = Field(gt=0)
user_id: int = Field(gt=0)
class ServiceNameMixin(CustomBaseModel):
service: str
service: str
class OrgSummary(CustomBaseModel):
id: int
name: str
id: int
name: str
class GroupSummary(CustomBaseModel):
id: int
name: str
id: int
name: str
class UserSummary(CustomBaseModel):
id: int
email: str
id: int
email: str
class ServiceSummary(CustomBaseModel):
id: int
name: str
id: int
name: str
class ResourceName(ServiceNameMixin, OrgIDMixin):
resource: str
instance: Optional[str] = None
resource: str
instance: Optional[str] = None

View file

@ -16,25 +16,23 @@ from src.service.models import Service
from src.service.schemas import ServiceIDMixin
async def get_service_model_query(
db: DbSession, service_id: Annotated[int, Query(gt=0)]
):
service_model = db.get(Service, service_id)
if service_model is None:
raise ServiceNotFoundException(service_id=service_id)
async def get_service_model_query(db: DbSession, service_id: Annotated[int, Query(gt=0)]):
service_model = db.get(Service, service_id)
if service_model is None:
raise ServiceNotFoundException(service_id=service_id)
return service_model
return service_model
service_model_query_dependency = Annotated[Service, Depends(get_service_model_query)]
async def get_service_model_body(db: DbSession, request_model: ServiceIDMixin):
service_model = db.get(Service, request_model.service_id)
if service_model is None:
raise ServiceNotFoundException(service_id=request_model.service_id)
service_model = db.get(Service, request_model.service_id)
if service_model is None:
raise ServiceNotFoundException(service_id=request_model.service_id)
return service_model
return service_model
service_model_body_dependency = Annotated[Service, Depends(get_service_model_body)]

View file

@ -11,13 +11,13 @@ from fastapi import HTTPException, status
class ServiceNotFoundException(HTTPException):
def __init__(self, service_id: Optional[int] = None) -> None:
detail = (
"Service not found"
if service_id is None
else f"Service with ID '{service_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
def __init__(self, service_id: Optional[int] = None) -> None:
detail = (
"Service not found"
if service_id is None
else f"Service with ID '{service_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)

View file

@ -12,11 +12,11 @@ from src.models import CustomBase, IdMixin
class Service(CustomBase, IdMixin):
__tablename__ = "service"
__tablename__ = "service"
name: Mapped[str] = mapped_column(unique=True)
api_key: Mapped[str]
name: Mapped[str] = mapped_column(unique=True)
api_key: Mapped[str]
permission_rel = relationship(
"Permission", back_populates="service_rel", cascade="all, delete-orphan"
)
permission_rel = relationship(
"Permission", back_populates="service_rel", cascade="all, delete-orphan"
)

View file

@ -15,8 +15,8 @@ from psycopg.errors import UniqueViolation
from src.exceptions import ConflictException
from src.database import DbSession
from src.auth.dependencies import (
super_admin_dependency,
org_model_root_claim_query_dependency,
super_admin_dependency,
org_model_root_claim_query_dependency,
)
from src.iam.service import service_key_dependency
from src.iam.models import Permission as Perm
@ -25,212 +25,210 @@ from src.service.exceptions import ServiceNotFoundException
from src.service.models import Service
from src.service.utils import generate_api_key
from src.service.dependencies import (
service_model_body_dependency,
service_model_query_dependency,
service_model_body_dependency,
service_model_query_dependency,
)
from src.service.schemas import (
ServiceGetServiceResponse,
ServicePostServiceRequest,
ServicePostServiceResponse,
ServiceWithKeySchema,
ServicePatchKeyResponse,
ServicePatchKeyRequest,
ServicePostPermissionsResponse,
ServicePostPermissionsRequest,
ServiceGetServiceResponse,
ServicePostServiceRequest,
ServicePostServiceResponse,
ServiceWithKeySchema,
ServicePatchKeyResponse,
ServicePatchKeyRequest,
ServicePostPermissionsResponse,
ServicePostPermissionsRequest,
)
router = APIRouter(
tags=["Service"],
prefix="/service",
tags=["Service"],
prefix="/service",
)
@router.get(
"",
summary="Get all services",
status_code=status.HTTP_200_OK,
response_model=ServiceGetServiceResponse,
responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
status.HTTP_401_UNAUTHORIZED: {
"description": "Unauthorized",
"content": {
"application/json": {
"examples": {
"awaiting_approval": {
"summary": "Organisation has not yet been approved."
},
}
}
},
},
status.HTTP_403_FORBIDDEN: {
"description": "Forbidden",
"content": {
"application/json": {
"examples": {
"not_root": {"summary": "Not authorised. Must be root user."},
}
}
},
},
},
"",
summary="Get all services",
status_code=status.HTTP_200_OK,
response_model=ServiceGetServiceResponse,
responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
status.HTTP_401_UNAUTHORIZED: {
"description": "Unauthorized",
"content": {
"application/json": {
"examples": {
"awaiting_approval": {
"summary": "Organisation has not yet been approved."
},
}
}
},
},
status.HTTP_403_FORBIDDEN: {
"description": "Forbidden",
"content": {
"application/json": {
"examples": {
"not_root": {"summary": "Not authorised. Must be root user."},
}
}
},
},
},
)
async def get_all_services(
db: DbSession, org_model: org_model_root_claim_query_dependency
):
"""
Returns the ID and name of all services registered to the hub.
"""
permission_models = db.query(Service).all()
async def get_all_services(db: DbSession, org_model: org_model_root_claim_query_dependency):
"""
Returns the ID and name of all services registered to the hub.
"""
permission_models = db.query(Service).all()
return {"services": permission_models}
return {"services": permission_models}
@router.post(
"",
summary="Register a new service.",
status_code=status.HTTP_200_OK,
response_model=ServicePostServiceResponse,
responses={
status.HTTP_200_OK: {"description": "Successfully registered a new service"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
status.HTTP_409_CONFLICT: {"description": "Service with this name already exists"},
},
"",
summary="Register a new service.",
status_code=status.HTTP_200_OK,
response_model=ServicePostServiceResponse,
responses={
status.HTTP_200_OK: {"description": "Successfully registered a new service"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
status.HTTP_409_CONFLICT: {"description": "Service with this name already exists"},
},
)
async def register_service(
db: DbSession,
su: super_admin_dependency,
request_model: ServicePostServiceRequest,
db: DbSession,
su: super_admin_dependency,
request_model: ServicePostServiceRequest,
):
"""
Registers a new service to the hub, generating and returning an API key for it.
"""
key = generate_api_key()
service_model = Service(name=request_model.name, api_key=key)
"""
Registers a new service to the hub, generating and returning an API key for it.
"""
key = generate_api_key()
service_model = Service(name=request_model.name, api_key=key)
db.add(service_model)
try:
db.flush()
except IntegrityError as e:
if (
isinstance(e.orig, UniqueViolation) # Postgres unique violation
or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation
):
raise ConflictException(message="Service with this name already exists")
raise
response = ServiceWithKeySchema(**service_model.__dict__)
db.commit()
return {"service": response}
db.add(service_model)
try:
db.flush()
except IntegrityError as e:
if (
isinstance(e.orig, UniqueViolation) # Postgres unique violation
or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation
):
raise ConflictException(message="Service with this name already exists")
raise
response = ServiceWithKeySchema(**service_model.__dict__)
db.commit()
return {"service": response}
@router.patch(
"/key",
summary="Regenerate service API key.",
status_code=status.HTTP_200_OK,
response_model=ServicePatchKeyResponse,
responses={
status.HTTP_200_OK: {"description": "Successful update of API key"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
},
"/key",
summary="Regenerate service API key.",
status_code=status.HTTP_200_OK,
response_model=ServicePatchKeyResponse,
responses={
status.HTTP_200_OK: {"description": "Successful update of API key"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
},
)
async def regenerate_api_key(
db: DbSession,
su: super_admin_dependency,
service_model: service_model_body_dependency,
request_model: ServicePatchKeyRequest,
db: DbSession,
su: super_admin_dependency,
service_model: service_model_body_dependency,
request_model: ServicePatchKeyRequest,
):
"""
Generates and returns a new API key for the service to access the hub.
"""
key = generate_api_key()
service_model.api_key = key
"""
Generates and returns a new API key for the service to access the hub.
"""
key = generate_api_key()
service_model.api_key = key
db.flush()
response = ServiceWithKeySchema(**service_model.__dict__)
db.commit()
return {"service": response}
db.flush()
response = ServiceWithKeySchema(**service_model.__dict__)
db.commit()
return {"service": response}
@router.delete(
"",
summary="Remove a service.",
status_code=status.HTTP_204_NO_CONTENT,
responses={
status.HTTP_204_NO_CONTENT: {"description": "Successfully removed service from db"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
},
"",
summary="Remove a service.",
status_code=status.HTTP_204_NO_CONTENT,
responses={
status.HTTP_204_NO_CONTENT: {"description": "Successfully removed service from db"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
},
)
async def remove_service(
db: DbSession,
service_model: service_model_query_dependency,
su: super_admin_dependency,
db: DbSession,
service_model: service_model_query_dependency,
su: super_admin_dependency,
):
"""
Removes a service from the hub.
"""
db.delete(service_model)
db.commit()
"""
Removes a service from the hub.
"""
db.delete(service_model)
db.commit()
@router.post(
path="/permissions",
summary="Service endpoint for creating its own permissions.",
status_code=status.HTTP_200_OK,
response_model=ServicePostPermissionsResponse,
responses={
status.HTTP_401_UNAUTHORIZED: {
"description": "API Key missing or invalid | Issue verifying user OIDC claims"
},
},
path="/permissions",
summary="Service endpoint for creating its own permissions.",
status_code=status.HTTP_200_OK,
response_model=ServicePostPermissionsResponse,
responses={
status.HTTP_401_UNAUTHORIZED: {
"description": "API Key missing or invalid | Issue verifying user OIDC claims"
},
},
)
async def service_create_new_permissions(
db: DbSession,
request_model: ServicePostPermissionsRequest,
valid_key: service_key_dependency,
db: DbSession,
request_model: ServicePostPermissionsRequest,
valid_key: service_key_dependency,
):
"""
Allows a service to register its own set of permissions.
"""
service_model = (
db.query(Service).filter(Service.name == request_model.rn.service).first()
)
if service_model is None:
raise ServiceNotFoundException()
else:
service_id = service_model.id
response_list = []
for new_permission in request_model.permissions:
perm_model = (
db.query(Perm)
.filter(Perm.service_id == service_id)
.filter(Perm.resource == new_permission.resource)
.filter(Perm.action == new_permission.action)
.first()
)
if perm_model is not None:
response_code = 409
response = {
"id": perm_model.id,
"service_name": perm_model.service_name,
"resource": perm_model.resource,
"action": perm_model.action,
}
response_list.append((response, response_code))
continue
"""
Allows a service to register its own set of permissions.
"""
service_model = (
db.query(Service).filter(Service.name == request_model.rn.service).first()
)
if service_model is None:
raise ServiceNotFoundException()
else:
service_id = service_model.id
response_list = []
for new_permission in request_model.permissions:
perm_model = (
db.query(Perm)
.filter(Perm.service_id == service_id)
.filter(Perm.resource == new_permission.resource)
.filter(Perm.action == new_permission.action)
.first()
)
if perm_model is not None:
response_code = 409
response = {
"id": perm_model.id,
"service_name": perm_model.service_name,
"resource": perm_model.resource,
"action": perm_model.action,
}
response_list.append((response, response_code))
continue
new_perm_model = Perm(**new_permission.__dict__)
new_perm_model.service_id = service_id
db.add(new_perm_model)
db.flush()
response_code = 201
response = {
"id": new_perm_model.id,
"service_name": new_perm_model.service_name,
"resource": new_perm_model.resource,
"action": new_perm_model.action,
}
response_list.append((response, response_code))
new_perm_model = Perm(**new_permission.__dict__)
new_perm_model.service_id = service_id
db.add(new_perm_model)
db.flush()
response_code = 201
response = {
"id": new_perm_model.id,
"service_name": new_perm_model.service_name,
"resource": new_perm_model.resource,
"action": new_perm_model.action,
}
response_list.append((response, response_code))
db.commit()
return {"permissions": response_list}
db.commit()
return {"permissions": response_list}

View file

@ -10,10 +10,10 @@ from typing import Generic, TypeVar
from pydantic import Field, ConfigDict
from src.schemas import (
CustomBaseModel,
ServiceIDMixin,
ServiceSummary,
ServiceNameMixin,
CustomBaseModel,
ServiceIDMixin,
ServiceSummary,
ServiceNameMixin,
)
@ -21,51 +21,51 @@ T = TypeVar("T", bound=ServiceNameMixin)
class HasServiceName(CustomBaseModel, Generic[T]):
rn: T
rn: T
class PermissionResponseSchema(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore")
model_config = ConfigDict(from_attributes=True, extra="ignore")
id: int
service_name: str
resource: str
action: str
id: int
service_name: str
resource: str
action: str
class PermissionRequestSchema(CustomBaseModel):
resource: str
action: str
resource: str
action: str
class ServiceWithKeySchema(ServiceSummary):
api_key: str
api_key: str
class ServiceGetServiceResponse(CustomBaseModel):
services: list[ServiceSummary]
services: list[ServiceSummary]
class ServicePostServiceRequest(CustomBaseModel):
name: str = Field(min_length=3)
name: str = Field(min_length=3)
class ServicePostServiceResponse(CustomBaseModel):
service: ServiceWithKeySchema
service: ServiceWithKeySchema
class ServicePatchKeyRequest(ServiceIDMixin):
pass
pass
class ServicePatchKeyResponse(CustomBaseModel):
service: ServiceWithKeySchema
service: ServiceWithKeySchema
class ServicePostPermissionsRequest(CustomBaseModel):
rn: ServiceNameMixin
permissions: list[PermissionRequestSchema]
rn: ServiceNameMixin
permissions: list[PermissionRequestSchema]
class ServicePostPermissionsResponse(CustomBaseModel):
permissions: list[tuple[PermissionResponseSchema, int]]
permissions: list[tuple[PermissionResponseSchema, int]]

View file

@ -9,4 +9,4 @@ import uuid
def generate_api_key() -> str:
return str(uuid.uuid4())
return str(uuid.uuid4())

View file

@ -19,37 +19,37 @@ from src.schemas import UserIDMixin
async def get_user_model_claims(claims: claims_dependency, db: DbSession):
user_id = claims.get("db_id", None)
if user_id is None:
raise UserNotFoundException()
user_id = claims.get("db_id", None)
if user_id is None:
raise UserNotFoundException()
user_model = db.get(User, user_id)
if user_model is None:
raise UserNotFoundException(user_id=user_id)
user_model = db.get(User, user_id)
if user_model is None:
raise UserNotFoundException(user_id=user_id)
return user_model
return user_model
user_model_claims_dependency = Annotated[User, Depends(get_user_model_claims)]
async def get_user_model_query(db: DbSession, user_id: Annotated[int, Query(gt=0)]):
user_model = db.get(User, user_id)
if user_model is None:
raise UserNotFoundException(user_id=user_id)
user_model = db.get(User, user_id)
if user_model is None:
raise UserNotFoundException(user_id=user_id)
return user_model
return user_model
user_model_query_dependency = Annotated[User, Depends(get_user_model_query)]
async def get_user_model_body(db: DbSession, request_model: UserIDMixin):
user_model = db.get(User, request_model.user_id)
if user_model is None:
raise UserNotFoundException(user_id=request_model.user_id)
user_model = db.get(User, request_model.user_id)
if user_model is None:
raise UserNotFoundException(user_id=request_model.user_id)
return user_model
return user_model
user_model_body_dependency = Annotated[User, Depends(get_user_model_body)]

View file

@ -11,13 +11,13 @@ from fastapi import HTTPException, status
class UserNotFoundException(HTTPException):
def __init__(self, user_id: Optional[int] = None) -> None:
detail = (
"User not found"
if user_id is None
else f"User with ID '{user_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
def __init__(self, user_id: Optional[int] = None) -> None:
detail = (
"User not found"
if user_id is None
else f"User with ID '{user_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)

View file

@ -20,26 +20,26 @@ from src.models import CustomBase
class User(CustomBase, IdMixin, ActivatedMixin, TimestampMixin, DeletedTimestampMixin):
__tablename__ = "user"
__tablename__ = "user"
email: Mapped[str]
first_name: Mapped[str]
last_name: Mapped[str]
oidc_id: Mapped[str] = mapped_column(index=True, unique=True)
email: Mapped[str]
first_name: Mapped[str]
last_name: Mapped[str]
oidc_id: Mapped[str] = mapped_column(index=True, unique=True)
organisation_rel = relationship(
"Organisation", secondary="orgusers", back_populates="user_rel"
)
organisation_rel = relationship(
"Organisation", secondary="orgusers", back_populates="user_rel"
)
group_rel = relationship("Group", secondary="user_groups", back_populates="user_rel")
group_rel = relationship("Group", secondary="user_groups", back_populates="user_rel")
@property
def organisations(self):
return [{"name": org.name, "id": org.id} for org in self.organisation_rel]
@property
def organisations(self):
return [{"name": org.name, "id": org.id} for org in self.organisation_rel]
@property
def groups(self):
result = defaultdict(list)
for group in self.group_rel:
result[group.org_rel.name].append({"name": group.name, "id": group.id})
return dict(result)
@property
def groups(self):
result = defaultdict(list)
for group in self.group_rel:
result[group.org_rel.name].append({"name": group.name, "id": group.id})
return dict(result)

View file

@ -13,205 +13,205 @@ from fastapi import APIRouter, status, BackgroundTasks
from src.iam.models import Group
from src.organisation.exceptions import OrgNotFoundException
from src.user.schemas import (
UserResponse,
OIDCClaims,
UserPostInvitationRequest,
UserPostInvitationAcceptRequest,
UserGetSelfOrgsResponse,
UserPostInvitationResponse,
UserPostInvitationAcceptResponse,
UserResponse,
OIDCClaims,
UserPostInvitationRequest,
UserPostInvitationAcceptRequest,
UserGetSelfOrgsResponse,
UserPostInvitationResponse,
UserPostInvitationAcceptResponse,
)
from src.user.dependencies import (
user_model_claims_dependency,
user_model_query_dependency,
user_model_claims_dependency,
user_model_query_dependency,
)
from src.user.service import send_invitation
from src.organisation.models import Organisation as Org
from src.auth.dependencies import (
super_admin_dependency,
org_model_root_claim_body_dependency,
super_admin_dependency,
org_model_root_claim_body_dependency,
)
from src.auth.service import claims_dependency
from src.database import DbSession
from src.utils import verify_email_token
router = APIRouter(
prefix="/user",
tags=["User"],
prefix="/user",
tags=["User"],
)
@router.get(
"/self/claims",
summary="Get current user OIDC claims.",
response_model=OIDCClaims,
status_code=status.HTTP_200_OK,
responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
},
"/self/claims",
summary="Get current user OIDC claims.",
response_model=OIDCClaims,
status_code=status.HTTP_200_OK,
responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
},
)
async def current_user_claims(user: claims_dependency):
"""
Returns the full OIDC claims associated with the currently logged-in user.
"""
user["allowed_origins"] = user.get("allowed-origins", [])
return user
"""
Returns the full OIDC claims associated with the currently logged-in user.
"""
user["allowed_origins"] = user.get("allowed-origins", [])
return user
@router.get(
"/self/db",
summary="Get current user hub details.",
response_model=UserResponse,
status_code=status.HTTP_200_OK,
responses={
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
},
"/self/db",
summary="Get current user hub details.",
response_model=UserResponse,
status_code=status.HTTP_200_OK,
responses={
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
},
)
async def current_user(user_model: user_model_claims_dependency):
"""
Returns the database details associated with the currently logged-in user.
"""
return user_model
"""
Returns the database details associated with the currently logged-in user.
"""
return user_model
@router.get(
"",
summary="Get user hub details by ID.",
response_model=UserResponse,
status_code=status.HTTP_200_OK,
responses={
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
},
"",
summary="Get user hub details by ID.",
response_model=UserResponse,
status_code=status.HTTP_200_OK,
responses={
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
},
)
async def get_user_by_id(
user_model: user_model_query_dependency, su: super_admin_dependency
user_model: user_model_query_dependency, su: super_admin_dependency
):
"""
Returns the database details associated with the provided user ID.
"""
return user_model
"""
Returns the database details associated with the provided user ID.
"""
return user_model
@router.delete(
"",
summary="Delete user from hub by ID.",
status_code=status.HTTP_204_NO_CONTENT,
responses={
status.HTTP_204_NO_CONTENT: {"description": "User deleted"},
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
},
"",
summary="Delete user from hub by ID.",
status_code=status.HTTP_204_NO_CONTENT,
responses={
status.HTTP_204_NO_CONTENT: {"description": "User deleted"},
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
},
)
async def delete_user_by_id(
db: DbSession,
user_model: user_model_query_dependency,
su: super_admin_dependency,
db: DbSession,
user_model: user_model_query_dependency,
su: super_admin_dependency,
):
"""
Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login.
"""
db.delete(user_model)
db.commit()
"""
Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login.
"""
db.delete(user_model)
db.commit()
@router.get(
"/self/orgs",
summary="Get all orgs the current user is a member of",
status_code=status.HTTP_200_OK,
response_model=UserGetSelfOrgsResponse,
responses={},
"/self/orgs",
summary="Get all orgs the current user is a member of",
status_code=status.HTTP_200_OK,
response_model=UserGetSelfOrgsResponse,
responses={},
)
async def get_user_orgs(user_model: user_model_claims_dependency):
user_orgs = user_model.organisation_rel
response = []
for org in user_orgs:
response.append(
{
"organisation_id": org.id,
"name": org.name,
"status": org.status,
"intake_questionnaire": org.intake_questionnaire,
"root_user_email": org.root_user_email,
"billing_contact": {
"id": org.billing_contact_id,
"email": org.billing_contact_rel.email,
},
"owner_contact": {
"id": org.owner_contact_id,
"email": org.owner_contact_rel.email,
},
"security_contact": {
"id": org.security_contact_id,
"email": org.security_contact_rel.email,
},
}
)
user_orgs = user_model.organisation_rel
response = []
for org in user_orgs:
response.append(
{
"organisation_id": org.id,
"name": org.name,
"status": org.status,
"intake_questionnaire": org.intake_questionnaire,
"root_user_email": org.root_user_email,
"billing_contact": {
"id": org.billing_contact_id,
"email": org.billing_contact_rel.email,
},
"owner_contact": {
"id": org.owner_contact_id,
"email": org.owner_contact_rel.email,
},
"security_contact": {
"id": org.security_contact_id,
"email": org.security_contact_rel.email,
},
}
)
return {"organisations": response}
return {"organisations": response}
@router.post(
"/invitation",
summary="Send an email invitation for a user to join an org",
status_code=status.HTTP_200_OK,
response_model=UserPostInvitationResponse,
"/invitation",
summary="Send an email invitation for a user to join an org",
status_code=status.HTTP_200_OK,
response_model=UserPostInvitationResponse,
)
async def invitation(
background_tasks: BackgroundTasks,
org_model: org_model_root_claim_body_dependency,
request_model: UserPostInvitationRequest,
background_tasks: BackgroundTasks,
org_model: org_model_root_claim_body_dependency,
request_model: UserPostInvitationRequest,
):
org_id = org_model.id
org_name = org_model.name
user_email = request_model.user_email
org_id = org_model.id
org_name = org_model.name
user_email = request_model.user_email
background_tasks.add_task(
send_invitation, org_id=org_id, org_name=org_name, user_email=user_email
)
background_tasks.add_task(
send_invitation, org_id=org_id, org_name=org_name, user_email=user_email
)
response = {
"organisation": org_model,
"invited_email": user_email,
}
response = {
"organisation": org_model,
"invited_email": user_email,
}
return response
return response
@router.post(
"/invitation/accept",
summary="Accept email invitation to join an org",
status_code=status.HTTP_200_OK,
response_model=UserPostInvitationAcceptResponse,
"/invitation/accept",
summary="Accept email invitation to join an org",
status_code=status.HTTP_200_OK,
response_model=UserPostInvitationAcceptResponse,
)
async def accept_invitation(
db: DbSession,
user_model: user_model_claims_dependency,
request_model: UserPostInvitationAcceptRequest,
db: DbSession,
user_model: user_model_claims_dependency,
request_model: UserPostInvitationAcceptRequest,
):
email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model)
email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model)
org_model = db.get(Org, email_claims["org_id"])
if org_model is None:
raise OrgNotFoundException()
org_model = db.get(Org, email_claims["org_id"])
if org_model is None:
raise OrgNotFoundException()
org_model.user_rel.append(user_model)
db.flush()
group_model = (
db.query(Group)
.filter(Group.org_id == org_model.id)
.filter(Group.name == "Default Users")
.first()
)
if group_model is not None:
user_model.group_rel.append(group_model)
org_model.user_rel.append(user_model)
db.flush()
group_model = (
db.query(Group)
.filter(Group.org_id == org_model.id)
.filter(Group.name == "Default Users")
.first()
)
if group_model is not None:
user_model.group_rel.append(group_model)
response = {
"organisation": org_model,
"user": user_model,
}
response = {
"organisation": org_model,
"user": user_model,
}
db.commit()
db.commit()
return response
return response

View file

@ -10,63 +10,63 @@ from src.schemas import CustomBaseModel, OrgIDMixin, OrgSummary, UserSummary
class OIDCClaims(CustomBaseModel):
exp: int
iat: int
auth_time: int
jti: str
iss: str
aud: str
sub: str
typ: str
azp: str
sid: str
acr: str
allowed_origins: list[str]
realm_access: dict[str, list[str]]
resource_access: dict[str, dict[str, list[str]]]
scope: str
email_verified: bool
name: str
preferred_username: str
given_name: str
family_name: str
email: str
db_id: int
exp: int
iat: int
auth_time: int
jti: str
iss: str
aud: str
sub: str
typ: str
azp: str
sid: str
acr: str
allowed_origins: list[str]
realm_access: dict[str, list[str]]
resource_access: dict[str, dict[str, list[str]]]
scope: str
email_verified: bool
name: str
preferred_username: str
given_name: str
family_name: str
email: str
db_id: int
class OIDCUser(CustomBaseModel):
first_name: str
last_name: str
email: str
oidc_id: str
first_name: str
last_name: str
email: str
oidc_id: str
class UserResponse(CustomBaseModel):
id: int
first_name: str
last_name: str
email: str
organisations: list[Optional[dict[str, str | int]]]
groups: Optional[dict[str, list[dict[str, str | int]]]] = None
id: int
first_name: str
last_name: str
email: str
organisations: list[Optional[dict[str, str | int]]]
groups: Optional[dict[str, list[dict[str, str | int]]]] = None
class UserPostInvitationRequest(OrgIDMixin):
user_email: EmailStr
user_email: EmailStr
class UserPostInvitationAcceptRequest(CustomBaseModel):
jwt: str
jwt: str
class UserGetSelfOrgsResponse(CustomBaseModel):
organisations: list[OrgSchema]
organisations: list[OrgSchema]
class UserPostInvitationResponse(CustomBaseModel):
organisation: OrgSummary
invited_email: EmailStr
organisation: OrgSummary
invited_email: EmailStr
class UserPostInvitationAcceptResponse(CustomBaseModel):
organisation: OrgSummary
user: UserSummary
organisation: OrgSummary
user: UserSummary

View file

@ -16,49 +16,49 @@ from src.user.models import User
async def add_user(db: Session, user_claims: dict[str, Any]) -> int:
try:
valid_user = OIDCUser(
first_name=user_claims["given_name"],
last_name=user_claims["family_name"],
email=user_claims["email"],
oidc_id=user_claims["sub"],
)
except Exception as e:
logging.exception(e)
raise UnprocessableContentException("Invalid or missing OIDC data")
try:
valid_user = OIDCUser(
first_name=user_claims["given_name"],
last_name=user_claims["family_name"],
email=user_claims["email"],
oidc_id=user_claims["sub"],
)
except Exception as e:
logging.exception(e)
raise UnprocessableContentException("Invalid or missing OIDC data")
db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first()
db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first()
if not db_user:
user_model = User(**valid_user.model_dump())
db.add(user_model)
user_id = user_model.id
db.commit()
return user_id
if not db_user:
user_model = User(**valid_user.model_dump())
db.add(user_model)
user_id = user_model.id
db.commit()
return user_id
user_id = db_user.id
db_user.first_name = valid_user.first_name
db_user.last_name = valid_user.last_name
db.commit()
return user_id
user_id = db_user.id
db_user.first_name = valid_user.first_name
db_user.last_name = valid_user.last_name
db.commit()
return user_id
async def send_invitation(user_email: str, org_name: str, org_id: int):
expiry_delta = timedelta(hours=24)
expiry = datetime.now(timezone.utc) + expiry_delta
claims = {
"email": user_email,
"org_id": org_id,
"exp": expiry,
"type": "org_invite",
}
expiry_delta = timedelta(hours=24)
expiry = datetime.now(timezone.utc) + expiry_delta
claims = {
"email": user_email,
"org_id": org_id,
"exp": expiry,
"type": "org_invite",
}
token = await generate_jwt(claims)
subject = f"You have been invited to join {org_name}"
body = f"You have been invited to join {org_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
token = await generate_jwt(claims)
subject = f"You have been invited to join {org_name}"
body = f"You have been invited to join {org_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
await send_email(
recipient=user_email,
subject=subject,
body=body,
)
await send_email(
recipient=user_email,
subject=subject,
body=body,
)

View file

@ -11,52 +11,56 @@ KEY = jwk.import_key(settings.SECRET_KEY.get_secret_value(), "oct")
async def generate_jwt(claims):
jwt_token = jwt.encode(header={"alg": "HS256"}, key=KEY, claims=claims)
jwt_token = jwt.encode(header={"alg": "HS256"}, key=KEY, claims=claims)
return jwt_token
return jwt_token
async def decode_jwt(encoded):
try:
token = jwt.decode(encoded, key=KEY)
return token.claims
except errors.DecodeError:
raise UnauthorizedException("Invalid JWS")
try:
token = jwt.decode(encoded, key=KEY)
return token.claims
except errors.DecodeError:
raise UnauthorizedException("Invalid JWS")
async def verify_email_token(user_model, token):
email_claims = await decode_jwt(token)
email_claims = await decode_jwt(token)
claimed_email = email_claims["email"]
claimed_email = email_claims["email"]
expiry = datetime.fromtimestamp(email_claims["exp"], timezone.utc)
expiry = datetime.fromtimestamp(email_claims["exp"], timezone.utc)
if expiry < datetime.now(timezone.utc):
raise UnauthorizedException("Invitation expired.")
if expiry < datetime.now(timezone.utc):
raise UnauthorizedException("Invitation expired.")
if user_model.email != claimed_email:
raise ForbiddenException("The logged in user and email do not match.")
if user_model.email != claimed_email:
raise ForbiddenException("The logged in user and email do not match.")
return email_claims
return email_claims
async def send_email(recipient: str, subject: str, body: str):
if settings.ENVIRONMENT.is_testing:
return
if settings.ENVIRONMENT.is_testing:
return
lettermint = Lettermint(api_token=settings.LETTERMINT_API_TOKEN.get_secret_value())
lettermint = Lettermint(api_token=settings.LETTERMINT_API_TOKEN.get_secret_value())
if settings.ENVIRONMENT == "local":
recipient = "ok@testing.lettermint.co"
if settings.ENVIRONMENT == "local":
recipient = "ok@testing.lettermint.co"
try:
response = (
lettermint.email.from_("noreply@sr2.uk")
.to(recipient)
.subject(subject)
.text(body)
.send()
)
logging.info("Email sent to {} with subject {} (Status: {})".format(recipient, subject, response.status_code))
except ValidationError as e:
logging.exception(e)
try:
response = (
lettermint.email.from_("noreply@sr2.uk")
.to(recipient)
.subject(subject)
.text(body)
.send()
)
logging.info(
"Email sent to {} with subject {} (Status: {})".format(
recipient, subject, response.status_code
)
)
except ValidationError as e:
logging.exception(e)

View file

@ -22,15 +22,15 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@pytest.fixture()
def db_session():
CustomBase.metadata.drop_all(bind=engine)
CustomBase.metadata.create_all(bind=engine)
db = SessionLocal()
try:
_seed(db) # extracted seeding logic into a plain function
yield db
finally:
db.rollback()
db.close()
CustomBase.metadata.drop_all(bind=engine)
CustomBase.metadata.create_all(bind=engine)
db = SessionLocal()
try:
_seed(db) # extracted seeding logic into a plain function
yield db
finally:
db.rollback()
db.close()
@pytest.fixture
@ -83,176 +83,176 @@ async def no_su_client(db_session) -> AsyncGenerator[AsyncClient, None]:
def _seed(db):
db.add(
User(
email="admin@test.com",
first_name="Admin",
last_name="Test",
oidc_id="abcd-efgh-ijkl-mnop",
)
)
db.add(
User(
email="user@orgone.com",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-qwer",
)
)
db.add(
User(
email="root@orgtwo.com",
first_name="Root",
last_name="Test",
oidc_id="abcd-efgh-ijkl-hjkl",
)
)
db.add(Contact(org_id=1, email="billing@orgone.com", phonenumber="07521539927"))
db.add(Contact(org_id=1, email="owner@orgone.com", phonenumber="07521539927"))
db.add(Contact(org_id=1, email="security@orgone.com", phonenumber="07521539927"))
db.add(Contact(org_id=2, email="billing@orgtwo.com", phonenumber="07521539927"))
db.add(Contact(org_id=2, email="owner@orgtwo.com", phonenumber="07521539927"))
db.add(Contact(org_id=2, email="security@orgtwo.com", phonenumber="07521539927"))
db.add(Contact(org_id=3, email="billing@orgthree.com", phonenumber="07521539927"))
db.add(Contact(org_id=3, email="owner@orgthree.com", phonenumber="07521539927"))
db.add(Contact(org_id=3, email="security@orgthree.com", phonenumber="07521539927"))
db.flush()
db.add(
Org(
name="Org One",
root_user_id=1,
billing_contact_id=1,
owner_contact_id=2,
security_contact_id=3,
status="approved",
intake_questionnaire={
"metadata": {"version": 0, "submission_date": None},
"questions": {"question_two": "answer two"},
},
)
)
db.add(
Org(
name="Org Two",
root_user_id=3,
billing_contact_id=4,
owner_contact_id=5,
security_contact_id=6,
status="approved",
intake_questionnaire={
"metadata": {"version": 0, "submission_date": None},
"questions": {"question_two": "answer two"},
},
)
)
db.add(
Org(
name="Org Three",
root_user_id=1,
billing_contact_id=7,
owner_contact_id=8,
security_contact_id=9,
status="partial",
intake_questionnaire={
"metadata": {"version": 0, "submission_date": None},
"questions": {"question_two": "answer two"},
},
)
)
db.add(OrgUsers(org_id=1, user_id=2))
db.add(Service(name="Test Service", api_key="123456789"))
db.add(Permission(service_id=1, resource="test_resource", action="read"))
db.add(Permission(service_id=1, resource="test_resource", action="move"))
db.add(Permission(service_id=1, resource="test_resource", action="delete"))
db.add(OrgPermissions(org_id=1, permission_id=1))
db.add(OrgPermissions(org_id=1, permission_id=2))
db.add(Group(name="Org One Group", org_id=1))
db.add(Group(name="Org Two Group", org_id=2))
db.add(Group(name="Org One Group Two", org_id=1))
db.flush()
group_model = db.get(Group, 1)
perm_model = db.get(Permission, 1)
group_model.permission_rel.append(perm_model)
user_model = db.get(User, 1)
org_model = db.get(Org, 1)
org_model.user_rel.append(user_model)
org_model.group_rel.append(group_model)
db.flush()
group_model.user_rel.append(user_model)
db.commit()
db.add(
User(
email="admin@test.com",
first_name="Admin",
last_name="Test",
oidc_id="abcd-efgh-ijkl-mnop",
)
)
db.add(
User(
email="user@orgone.com",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-qwer",
)
)
db.add(
User(
email="root@orgtwo.com",
first_name="Root",
last_name="Test",
oidc_id="abcd-efgh-ijkl-hjkl",
)
)
db.add(Contact(org_id=1, email="billing@orgone.com", phonenumber="07521539927"))
db.add(Contact(org_id=1, email="owner@orgone.com", phonenumber="07521539927"))
db.add(Contact(org_id=1, email="security@orgone.com", phonenumber="07521539927"))
db.add(Contact(org_id=2, email="billing@orgtwo.com", phonenumber="07521539927"))
db.add(Contact(org_id=2, email="owner@orgtwo.com", phonenumber="07521539927"))
db.add(Contact(org_id=2, email="security@orgtwo.com", phonenumber="07521539927"))
db.add(Contact(org_id=3, email="billing@orgthree.com", phonenumber="07521539927"))
db.add(Contact(org_id=3, email="owner@orgthree.com", phonenumber="07521539927"))
db.add(Contact(org_id=3, email="security@orgthree.com", phonenumber="07521539927"))
db.flush()
db.add(
Org(
name="Org One",
root_user_id=1,
billing_contact_id=1,
owner_contact_id=2,
security_contact_id=3,
status="approved",
intake_questionnaire={
"metadata": {"version": 0, "submission_date": None},
"questions": {"question_two": "answer two"},
},
)
)
db.add(
Org(
name="Org Two",
root_user_id=3,
billing_contact_id=4,
owner_contact_id=5,
security_contact_id=6,
status="approved",
intake_questionnaire={
"metadata": {"version": 0, "submission_date": None},
"questions": {"question_two": "answer two"},
},
)
)
db.add(
Org(
name="Org Three",
root_user_id=1,
billing_contact_id=7,
owner_contact_id=8,
security_contact_id=9,
status="partial",
intake_questionnaire={
"metadata": {"version": 0, "submission_date": None},
"questions": {"question_two": "answer two"},
},
)
)
db.add(OrgUsers(org_id=1, user_id=2))
db.add(Service(name="Test Service", api_key="123456789"))
db.add(Permission(service_id=1, resource="test_resource", action="read"))
db.add(Permission(service_id=1, resource="test_resource", action="move"))
db.add(Permission(service_id=1, resource="test_resource", action="delete"))
db.add(OrgPermissions(org_id=1, permission_id=1))
db.add(OrgPermissions(org_id=1, permission_id=2))
db.add(Group(name="Org One Group", org_id=1))
db.add(Group(name="Org Two Group", org_id=2))
db.add(Group(name="Org One Group Two", org_id=1))
db.flush()
group_model = db.get(Group, 1)
perm_model = db.get(Permission, 1)
group_model.permission_rel.append(perm_model)
user_model = db.get(User, 1)
org_model = db.get(Org, 1)
org_model.user_rel.append(user_model)
org_model.group_rel.append(group_model)
db.flush()
group_model.user_rel.append(user_model)
db.commit()
def generate_query_and_status(params) -> list[tuple[str, int]]:
possible_values = [0, -1, 42, "banana", ""]
possible_values = [0, -1, 42, "banana", ""]
defaults = [f"{param}=1" for param in params]
defaults = [f"{param}=1" for param in params]
# Missing params
query_list = [
"&".join(combo)
for r in range(len(defaults) + 1)
for combo in combinations(defaults, r)
]
# Missing params
query_list = [
"&".join(combo)
for r in range(len(defaults) + 1)
for combo in combinations(defaults, r)
]
# Complete query as default for invalid checks
default_query = query_list.pop(-1)
# Complete query as default for invalid checks
default_query = query_list.pop(-1)
# Checks for each param being invalid
for param in params:
for value in possible_values:
new_value = f"&{param}={value}"
query_list.append(default_query.replace(f"{param}=1", new_value))
# Checks for each param being invalid
for param in params:
for value in possible_values:
new_value = f"&{param}={value}"
query_list.append(default_query.replace(f"{param}=1", new_value))
query_and_status = []
query_and_status = []
# Assign expected status
for query in query_list:
# ID 42 is used to represent a non-existent entry. So it should 404.
status = 404 if "42" in query else 422
# Remove leading "&" if present
query = query if len(query) > 1 and query[0] != "&" else query[1:]
query_and_status.append((query, status))
# Assign expected status
for query in query_list:
# ID 42 is used to represent a non-existent entry. So it should 404.
status = 404 if "42" in query else 422
# Remove leading "&" if present
query = query if len(query) > 1 and query[0] != "&" else query[1:]
query_and_status.append((query, status))
return query_and_status
return query_and_status
def generate_body_and_status(params: dict[str, str]) -> list[tuple[dict, int]]:
possible_values_int = [0, -1, 42, "banana", ""]
possible_values_str = [0, "", "a"]
possible_values_int = [0, -1, 42, "banana", ""]
possible_values_str = [0, "", "a"]
defaults = [{param: 1 for param in params.keys()}]
defaults = [{param: 1 for param in params.keys()}]
# Missing params
body_list = [
{key: ("valid string" if params[key] == "str" else 1) for key in combo}
for r in range(len(defaults[0].keys()) + 1)
for combo in combinations(defaults[0].keys(), r)
]
# Missing params
body_list = [
{key: ("valid string" if params[key] == "str" else 1) for key in combo}
for r in range(len(defaults[0].keys()) + 1)
for combo in combinations(defaults[0].keys(), r)
]
# Complete body as default for generating invalid checks
default_body = body_list.pop(-1)
# Complete body as default for generating invalid checks
default_body = body_list.pop(-1)
# Generates checks for each param being invalid
for param, typ in params.items():
if typ == "int":
possible_values = possible_values_int
elif typ == "str":
possible_values = possible_values_str
else:
raise TypeError(f"Unknown type {typ}")
for value in possible_values:
new_record = default_body.copy()
new_record[param] = value
body_list.append(new_record)
# Generates checks for each param being invalid
for param, typ in params.items():
if typ == "int":
possible_values = possible_values_int
elif typ == "str":
possible_values = possible_values_str
else:
raise TypeError(f"Unknown type {typ}")
for value in possible_values:
new_record = default_body.copy()
new_record[param] = value
body_list.append(new_record)
body_and_status = []
body_and_status = []
# Assign expected status
for body in body_list:
# ID 42 is used to represent a non-existent entry. So it should 404.
status = 404 if 42 in body.values() else 422
body_and_status.append((body, status))
return body_and_status
# Assign expected status
for body in body_list:
# ID 42 is used to represent a non-existent entry. So it should 404.
status = 404 if 42 in body.values() else 422
body_and_status.append((body, status))
return body_and_status
def get_testable_routes():

View file

@ -8,181 +8,181 @@ import pytest
from httpx import AsyncClient
pytestmark = [
pytest.mark.auth,
pytest.mark.preapproval,
pytest.mark.auth,
pytest.mark.preapproval,
]
@pytest.mark.anyio
async def test_get_org_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/org?org_id=3")
assert resp.status_code != 422
assert resp.status_code == 200
resp = await no_su_client.get("/org?org_id=3")
assert resp.status_code != 422
assert resp.status_code == 200
@pytest.mark.anyio
async def test_patch_org_questionnaire_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.patch(
"/org/questionnaire",
json={
"organisation_id": 3,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": True,
},
)
assert resp.status_code != 422
assert resp.status_code == 200
resp = await no_su_client.patch(
"/org/questionnaire",
json={
"organisation_id": 3,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": True,
},
)
assert resp.status_code != 422
assert resp.status_code == 200
@pytest.mark.anyio
async def test_get_org_users_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/users?org_id=3")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
resp = await no_su_client.get("/org/users?org_id=3")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_org_groups_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/groups?org_id=3")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
resp = await no_su_client.get("/org/groups?org_id=3")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_org_contact_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/contact?org_id=3&contact_type=billing")
assert resp.status_code != 422
assert resp.status_code == 200
resp = await no_su_client.get("/org/contact?org_id=3&contact_type=billing")
assert resp.status_code != 422
assert resp.status_code == 200
@pytest.mark.anyio
async def test_patch_org_contact_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.patch(
"/org/contact",
json={
"organisation_id": 3,
"contact_type": "billing",
"email": "user@example.com",
},
)
assert resp.status_code != 422
assert resp.status_code == 200
resp = await no_su_client.patch(
"/org/contact",
json={
"organisation_id": 3,
"contact_type": "billing",
"email": "user@example.com",
},
)
assert resp.status_code != 422
assert resp.status_code == 200
@pytest.mark.anyio
async def test_get_service_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/service?org_id=3")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
resp = await no_su_client.get("/service?org_id=3")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_iam_group_permissions_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/group/permissions?org_id=3&group_id=1")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
resp = await no_su_client.get("/iam/group/permissions?org_id=3&group_id=1")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_iam_group_users_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/group/users?org_id=3&group_id=1")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
resp = await no_su_client.get("/iam/group/users?org_id=3&group_id=1")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_post_iam_group_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.post(
"/iam/group", json={"name": "New Group", "organisation_id": 3}
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
resp = await no_su_client.post(
"/iam/group", json={"name": "New Group", "organisation_id": 3}
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_put_iam_group_permission_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.put(
"/iam/group/permission",
json={"permission_id": 1, "group_id": 2, "organisation_id": 3},
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
resp = await no_su_client.put(
"/iam/group/permission",
json={"permission_id": 1, "group_id": 2, "organisation_id": 3},
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_put_iam_group_user_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.put(
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 3}
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
resp = await no_su_client.put(
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 3}
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_iam_permissions_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/permissions?org_id=3")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
resp = await no_su_client.get("/iam/permissions?org_id=3")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_post_iam_permissions_search_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.post(
"/iam/permissions/search", json={"organisation_id": 3, "action": "read"}
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
resp = await no_su_client.post(
"/iam/permissions/search", json={"organisation_id": 3, "action": "read"}
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_delete_org_user_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.delete("/org/user?org_id=3&user_id=1")
resp = await no_su_client.delete("/org/user?org_id=3&user_id=1")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_delete_preapproval_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.delete("/org/self?org_id=3")
resp = await no_su_client.delete("/org/self?org_id=3")
assert resp.status_code != 422
assert resp.status_code == 204
assert resp.status_code != 422
assert resp.status_code == 204
@pytest.mark.anyio
async def test_post_user_invitation_auth_approval(no_su_client: AsyncClient):
body = {"user_email": "admin@test.com", "organisation_id": 3}
resp = await no_su_client.post("/user/invitation", json=body)
body = {"user_email": "admin@test.com", "organisation_id": 3}
resp = await no_su_client.post("/user/invitation", json=body)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_delete_group_permissions_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.delete("/iam/group/permission?org_id=3&group_id=1&perm_id=1")
resp = await no_su_client.delete("/iam/group/permission?org_id=3&group_id=1&perm_id=1")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_delete_group_users_success(no_su_client: AsyncClient):
resp = await no_su_client.delete("/iam/group/user?org_id=3&group_id=1&user_id=1")
resp = await no_su_client.delete("/iam/group/user?org_id=3&group_id=1&user_id=1")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_put_group_user_invitation_success(no_su_client: AsyncClient):
body = {"user_email": "admin@test.com", "organisation_id": 3, "group_id": 1}
resp = await no_su_client.put("/iam/group/user/invitation", json=body)
body = {"user_email": "admin@test.com", "organisation_id": 3, "group_id": 1}
resp = await no_su_client.put("/iam/group/user/invitation", json=body)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]

View file

@ -5,14 +5,14 @@ from httpx import AsyncClient
pytestmark = [
pytest.mark.auth,
pytest.mark.auth,
]
@pytest.mark.anyio
async def test_get_org_auth_root_su(default_client: AsyncClient):
# If a super admin can access a resource when not the root user
resp = await default_client.get("/org?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 200
assert resp.json()["organisations"][0]["name"] == "Org Two"
# If a super admin can access a resource when not the root user
resp = await default_client.get("/org?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 200
assert resp.json()["organisations"][0]["name"] == "Org Two"

View file

@ -7,147 +7,147 @@ import pytest
from httpx import AsyncClient
pytestmark = [
pytest.mark.auth,
pytest.mark.root_user,
pytest.mark.auth,
pytest.mark.root_user,
]
@pytest.mark.anyio
async def test_get_org_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/org?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.get("/org?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_patch_org_questionnaire_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.patch(
"/org/questionnaire",
json={
"organisation_id": 2,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": True,
},
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.patch(
"/org/questionnaire",
json={
"organisation_id": 2,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": True,
},
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_org_users_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/users?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.get("/org/users?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_org_groups_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/groups?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.get("/org/groups?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_org_contact_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/contact?org_id=2&contact_type=billing")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.get("/org/contact?org_id=2&contact_type=billing")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_patch_org_contact_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.patch(
"/org/contact",
json={
"organisation_id": 2,
"contact_type": "billing",
"email": "user@example.com",
},
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.patch(
"/org/contact",
json={
"organisation_id": 2,
"contact_type": "billing",
"email": "user@example.com",
},
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_service_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/service?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.get("/service?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_iam_group_permissions_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/group/permissions?org_id=2&group_id=1")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.get("/iam/group/permissions?org_id=2&group_id=1")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_iam_group_users_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/group/users?org_id=2&group_id=1")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.get("/iam/group/users?org_id=2&group_id=1")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_post_iam_group_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.post(
"/iam/group", json={"name": "New Group", "organisation_id": 2}
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.post(
"/iam/group", json={"name": "New Group", "organisation_id": 2}
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_put_iam_group_permission_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.put(
"/iam/group/permission",
json={"permission_id": 1, "group_id": 2, "organisation_id": 2},
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.put(
"/iam/group/permission",
json={"permission_id": 1, "group_id": 2, "organisation_id": 2},
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_put_iam_group_user_auth_root(
no_su_client: AsyncClient,
no_su_client: AsyncClient,
):
resp = await no_su_client.put(
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2}
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.put(
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2}
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_iam_permissions_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/permissions?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.get("/iam/permissions?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_post_iam_permissions_search_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.post(
"/iam/permissions/search", json={"organisation_id": 2, "action": "read"}
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.post(
"/iam/permissions/search", json={"organisation_id": 2, "action": "read"}
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]

View file

@ -7,69 +7,69 @@ import pytest
from httpx import AsyncClient
pytestmark = [
pytest.mark.auth,
pytest.mark.super_admin,
pytest.mark.auth,
pytest.mark.super_admin,
]
@pytest.mark.anyio
async def test_get_user_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.get("/user?user_id=1")
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
resp = await no_su_client.get("/user?user_id=1")
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio
async def test_patch_org_status_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.patch(
"/org/status", json={"organisation_id": 1, "status": "submitted"}
)
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
resp = await no_su_client.patch(
"/org/status", json={"organisation_id": 1, "status": "submitted"}
)
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio
async def test_patch_org_root_user_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 2}
)
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
resp = await no_su_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 2}
)
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio
async def test_patch_service_key_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.patch("/service/key", json={"service_id": 1})
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
resp = await no_su_client.patch("/service/key", json={"service_id": 1})
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio
async def test_post_service_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.post("/service", json={"name": "New Test Service"})
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
resp = await no_su_client.post("/service", json={"name": "New Test Service"})
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio
async def test_post_perm_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.post(
"/iam/permission",
json={"service_id": 1, "resource": "test_resource", "action": "create"},
)
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
resp = await no_su_client.post(
"/iam/permission",
json={"service_id": 1, "resource": "test_resource", "action": "create"},
)
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio
async def test_post_org_user_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.post("/org/user", json={"organisation_id": 1, "user_id": 2})
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be super admin" in resp.json()["detail"]
resp = await no_su_client.post("/org/user", json={"organisation_id": 1, "user_id": 2})
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be super admin" in resp.json()["detail"]

View file

@ -7,22 +7,22 @@ from httpx import AsyncClient
pytestmark = [
pytest.mark.auth,
pytest.mark.user,
pytest.mark.auth,
pytest.mark.user,
]
@pytest.mark.anyio
async def test_get_self_db_auth_user(no_user_client: AsyncClient):
resp = await no_user_client.get("/user/self/db")
assert resp.status_code != 422
assert resp.status_code == 401
assert resp.json()["detail"] == "Not authenticated"
resp = await no_user_client.get("/user/self/db")
assert resp.status_code != 422
assert resp.status_code == 401
assert resp.json()["detail"] == "Not authenticated"
@pytest.mark.anyio
async def test_post_org_success_auth_user(no_user_client: AsyncClient):
resp = await no_user_client.post("/org", json={"name": "New Test Org"})
assert resp.status_code != 422
assert resp.status_code == 401
assert resp.json()["detail"] == "Not authenticated"
resp = await no_user_client.post("/org", json={"name": "New Test Org"})
assert resp.status_code != 422
assert resp.status_code == 401
assert resp.json()["detail"] == "Not authenticated"

View file

@ -4,7 +4,7 @@ from httpx import AsyncClient
@pytest.mark.anyio
async def test_healthcheck(default_client: AsyncClient):
resp = await default_client.get("/healthcheck")
resp = await default_client.get("/healthcheck")
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}

File diff suppressed because it is too large Load diff

View file

@ -9,506 +9,506 @@ from .conftest import generate_query_and_status
pytestmark = [
pytest.mark.org_module,
pytest.mark.org_module,
]
@pytest.mark.anyio
async def test_get_org_success(default_client: AsyncClient):
resp = await default_client.get("/org?org_id=1")
data = resp.json()
resp = await default_client.get("/org?org_id=1")
data = resp.json()
assert resp.status_code == 200
assert resp.status_code == 200
org = data["organisations"][0]
org = data["organisations"][0]
assert isinstance(org, dict)
assert org["organisation_id"] == 1
assert org["name"] == "Org One"
assert org["status"] == "approved"
assert org["root_user_email"] == "admin@test.com"
assert "intake_questionnaire" in org
assert isinstance(org["intake_questionnaire"], dict)
assert isinstance(org, dict)
assert org["organisation_id"] == 1
assert org["name"] == "Org One"
assert org["status"] == "approved"
assert org["root_user_email"] == "admin@test.com"
assert "intake_questionnaire" in org
assert isinstance(org["intake_questionnaire"], dict)
assert org["billing_contact"]["email"] == "billing@orgone.com"
assert org["owner_contact"]["email"] == "owner@orgone.com"
assert org["security_contact"]["email"] == "security@orgone.com"
assert org["billing_contact"]["email"] == "billing@orgone.com"
assert org["owner_contact"]["email"] == "owner@orgone.com"
assert org["security_contact"]["email"] == "security@orgone.com"
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
@pytest.mark.anyio
async def test_get_org_status_checks(
default_client: AsyncClient, query: str, expected_status: int
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/org?{query}")
resp = await default_client.get(f"/org?{query}")
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_post_org_success(default_client: AsyncClient):
resp = await default_client.post("/org", json={"name": "New Test Org"})
data = resp.json()
resp = await default_client.post("/org", json={"name": "New Test Org"})
data = resp.json()
assert resp.status_code == 201
assert data["name"] == "New Test Org"
assert data["status"] == "partial"
assert resp.status_code == 201
assert data["name"] == "New Test Org"
assert data["status"] == "partial"
@pytest.mark.parametrize(
"body, expected_status",
[
({"name": 42}, 422),
({}, 422),
({"name": "New Test Org", "intake_questionnaire": {"question_one": 42}}, 422),
],
"body, expected_status",
[
({"name": 42}, 422),
({}, 422),
({"name": "New Test Org", "intake_questionnaire": {"question_one": 42}}, 422),
],
)
@pytest.mark.anyio
async def test_post_org_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.post("/org", json=body)
resp = await default_client.post("/org", json=body)
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_patch_org_questionnaire_partial_success(default_client: AsyncClient):
resp = await default_client.patch(
"/org/questionnaire",
json={
"organisation_id": 3,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": True,
},
)
data = resp.json()
resp = await default_client.patch(
"/org/questionnaire",
json={
"organisation_id": 3,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": True,
},
)
data = resp.json()
assert resp.status_code == 200
assert data["name"] == "Org Three"
assert data["status"] == "partial"
assert "intake_questionnaire" in data
assert isinstance(data["intake_questionnaire"], dict)
metadata = data["intake_questionnaire"]["metadata"]
assert metadata["version"] == 0
assert metadata["submission_date"] is None
questions = data["intake_questionnaire"]["questions"]
assert questions["question_one"] == "new answer one"
assert questions["question_two"] == "answer two"
assert questions["question_three"] is None
assert resp.status_code == 200
assert data["name"] == "Org Three"
assert data["status"] == "partial"
assert "intake_questionnaire" in data
assert isinstance(data["intake_questionnaire"], dict)
metadata = data["intake_questionnaire"]["metadata"]
assert metadata["version"] == 0
assert metadata["submission_date"] is None
questions = data["intake_questionnaire"]["questions"]
assert questions["question_one"] == "new answer one"
assert questions["question_two"] == "answer two"
assert questions["question_three"] is None
@pytest.mark.parametrize(
"body, expected_status",
[
({"organisation_id": 42}, 404),
({"organisation_id": "Org One"}, 422),
({"organisation_id": ""}, 422),
({}, 422),
(
{
"organisation_id": "1",
"intake_questionnaire": {"question_one": 42},
"partial": True,
},
422,
),
(
{"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}},
422,
),
(
{
"organisation_id": "1",
"intake_questionnaire": {"question_one": "valid"},
"partial": 42,
},
422,
),
],
"body, expected_status",
[
({"organisation_id": 42}, 404),
({"organisation_id": "Org One"}, 422),
({"organisation_id": ""}, 422),
({}, 422),
(
{
"organisation_id": "1",
"intake_questionnaire": {"question_one": 42},
"partial": True,
},
422,
),
(
{"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}},
422,
),
(
{
"organisation_id": "1",
"intake_questionnaire": {"question_one": "valid"},
"partial": 42,
},
422,
),
],
)
@pytest.mark.anyio
async def test_patch_questionnaire_partial_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.patch("/org/questionnaire", json=body)
resp = await default_client.patch("/org/questionnaire", json=body)
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_patch_org_questionnaire_submit_success(default_client: AsyncClient):
resp = await default_client.patch(
"/org/questionnaire",
json={
"organisation_id": 3,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": False,
},
)
data = resp.json()
resp = await default_client.patch(
"/org/questionnaire",
json={
"organisation_id": 3,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": False,
},
)
data = resp.json()
assert resp.status_code == 200
assert data["name"] == "Org Three"
assert data["status"] == "submitted"
assert "intake_questionnaire" in data
assert isinstance(data["intake_questionnaire"], dict)
metadata = data["intake_questionnaire"]["metadata"]
assert metadata["version"] == 0
assert metadata["submission_date"] is not None
questions = data["intake_questionnaire"]["questions"]
assert questions["question_one"] == "new answer one"
assert questions["question_two"] == "answer two"
assert questions["question_three"] is None
assert resp.status_code == 200
assert data["name"] == "Org Three"
assert data["status"] == "submitted"
assert "intake_questionnaire" in data
assert isinstance(data["intake_questionnaire"], dict)
metadata = data["intake_questionnaire"]["metadata"]
assert metadata["version"] == 0
assert metadata["submission_date"] is not None
questions = data["intake_questionnaire"]["questions"]
assert questions["question_one"] == "new answer one"
assert questions["question_two"] == "answer two"
assert questions["question_three"] is None
@pytest.mark.parametrize(
"status", ["partial", "submitted", "remediation", "approved", "rejected", "removed"]
"status", ["partial", "submitted", "remediation", "approved", "rejected", "removed"]
)
@pytest.mark.anyio
async def test_patch_org_status_success(default_client: AsyncClient, status: str):
resp = await default_client.patch(
"/org/status", json={"organisation_id": 1, "status": status}
)
data = resp.json()
resp = await default_client.patch(
"/org/status", json={"organisation_id": 1, "status": status}
)
data = resp.json()
assert resp.status_code == 200
assert data["name"] == "Org One"
assert data["status"] == status
assert resp.status_code == 200
assert data["name"] == "Org One"
assert data["status"] == status
@pytest.mark.parametrize(
"body, expected_status",
[
({"organisation_id": 42}, 404),
({"organisation_id": "Org One"}, 422),
({"organisation_id": ""}, 422),
({}, 422),
({"organisation_id": "1", "status": True}, 422),
({"organisation_id": "1", "status": 42}, 422),
],
"body, expected_status",
[
({"organisation_id": 42}, 404),
({"organisation_id": "Org One"}, 422),
({"organisation_id": ""}, 422),
({}, 422),
({"organisation_id": "1", "status": True}, 422),
({"organisation_id": "1", "status": 42}, 422),
],
)
@pytest.mark.anyio
async def test_patch_org_status_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.patch("/org/status", json=body)
resp = await default_client.patch("/org/status", json=body)
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_get_org_users_success(default_client: AsyncClient):
resp = await default_client.get("/org/users?org_id=1")
data = resp.json()
resp = await default_client.get("/org/users?org_id=1")
data = resp.json()
assert resp.status_code == 200
assert resp.status_code == 200
assert "users" in data
assert isinstance(data["users"], list)
assert len(data["users"]) == 2
assert "users" in data
assert isinstance(data["users"], list)
assert len(data["users"]) == 2
user = data["users"][0]
assert isinstance(user, dict)
assert user["email"] == "admin@test.com"
assert user["id"] == 1
user = data["users"][0]
assert isinstance(user, dict)
assert user["email"] == "admin@test.com"
assert user["id"] == 1
assert "organisation" in data
assert data["organisation"]["name"] == "Org One"
assert data["organisation"]["id"] == 1
assert "organisation" in data
assert data["organisation"]["name"] == "Org One"
assert data["organisation"]["id"] == 1
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
@pytest.mark.anyio
async def test_get_org_users_status_checks(
default_client: AsyncClient, query: str, expected_status: int
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/org/users?{query}")
resp = await default_client.get(f"/org/users?{query}")
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_post_org_user_success(default_client: AsyncClient):
resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 3})
resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 3})
assert resp.status_code == 200
assert resp.status_code == 200
data = resp.json()
data = resp.json()
assert "organisation" in data
assert isinstance(data["organisation"], dict)
assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One"
assert "organisation" in data
assert isinstance(data["organisation"], dict)
assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One"
assert "users" in data
assert isinstance(data["users"], list)
assert len([user for user in data["users"] if user["email"] == "root@orgtwo.com"]) == 1
assert "users" in data
assert isinstance(data["users"], list)
assert len([user for user in data["users"] if user["email"] == "root@orgtwo.com"]) == 1
@pytest.mark.parametrize(
"body, expected_status",
[
({"organisation_id": 42}, 404),
({}, 422),
({"organisation_id": 1, "user_id": "id"}, 422),
({"user_id": 2}, 422),
({"organisation_id": 1, "user_id": 42}, 404),
({"organisation_id": 1, "user_id": 1}, 409),
],
"body, expected_status",
[
({"organisation_id": 42}, 404),
({}, 422),
({"organisation_id": 1, "user_id": "id"}, 422),
({"user_id": 2}, 422),
({"organisation_id": 1, "user_id": 42}, 404),
({"organisation_id": 1, "user_id": 1}, 409),
],
)
@pytest.mark.anyio
async def test_post_org_user_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.post("/org/user", json=body)
resp = await default_client.post("/org/user", json=body)
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_patch_org_root_user_success(default_client: AsyncClient):
resp = await default_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 2}
)
assert resp.status_code == 200
resp = await default_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 2}
)
assert resp.status_code == 200
data = resp.json()
data = resp.json()
assert data["name"] == "Org One"
assert data["root_user_email"] == "user@orgone.com"
assert data["name"] == "Org One"
assert data["root_user_email"] == "user@orgone.com"
@pytest.mark.parametrize(
"body, expected_status",
[
({"organisation_id": 42, "user_id": 2}, 404),
({"organisation_id": "Org One", "user_id": 2}, 422),
({"organisation_id": "", "user_id": 2}, 422),
({}, 422),
({"user_id": 2}, 422),
({"user_id": 42}, 404),
({"organisation_id": 1, "user_id": "Test User"}, 422),
],
"body, expected_status",
[
({"organisation_id": 42, "user_id": 2}, 404),
({"organisation_id": "Org One", "user_id": 2}, 422),
({"organisation_id": "", "user_id": 2}, 422),
({}, 422),
({"user_id": 2}, 422),
({"user_id": 42}, 404),
({"organisation_id": 1, "user_id": "Test User"}, 422),
],
)
@pytest.mark.anyio
async def test_patch_root_user_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.patch("/org/root_user", json=body)
resp = await default_client.patch("/org/root_user", json=body)
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_patch_org_root_user_non_member(default_client: AsyncClient):
resp = await default_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 3}
)
data = resp.json()
resp = await default_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 3}
)
data = resp.json()
assert resp.status_code == 422
assert data["detail"] == "This user does not belong to your organisation."
assert resp.status_code == 422
assert data["detail"] == "This user does not belong to your organisation."
@pytest.mark.anyio
async def test_get_org_groups_success(default_client: AsyncClient):
resp = await default_client.get("/org/groups?org_id=1")
assert resp.status_code == 200
resp = await default_client.get("/org/groups?org_id=1")
assert resp.status_code == 200
data = resp.json()
data = resp.json()
assert "organisation" in data
assert isinstance(data["organisation"], dict)
assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One"
assert "organisation" in data
assert isinstance(data["organisation"], dict)
assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One"
assert "groups" in data
assert isinstance(data["groups"], list)
group = data["groups"][0]
assert isinstance(group, dict)
assert group["id"] == 1
assert group["name"] == "Org One Group"
assert "groups" in data
assert isinstance(data["groups"], list)
group = data["groups"][0]
assert isinstance(group, dict)
assert group["id"] == 1
assert group["name"] == "Org One Group"
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
@pytest.mark.anyio
async def test_get_org_groups_status_checks(
default_client: AsyncClient, query: str, expected_status: int
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/org/groups?{query}")
resp = await default_client.get(f"/org/groups?{query}")
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.parametrize("contact_type", ["billing", "security", "owner"])
@pytest.mark.anyio
async def test_get_org_contact_success(default_client: AsyncClient, contact_type: str):
resp = await default_client.get(f"/org/contact?org_id=1&contact_type={contact_type}")
data = resp.json()
resp = await default_client.get(f"/org/contact?org_id=1&contact_type={contact_type}")
data = resp.json()
assert resp.status_code == 200
assert resp.status_code == 200
assert "organisation" in data
assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One"
assert "organisation" in data
assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One"
attributes = [
"email",
"first_name",
"last_name",
"phonenumber",
"vat_number",
"address",
]
attributes = [
"email",
"first_name",
"last_name",
"phonenumber",
"vat_number",
"address",
]
for attribute in attributes:
assert attribute in data["contact"]
for attribute in attributes:
assert attribute in data["contact"]
address_attributes = [
"post_office_box_number",
"street_address",
"street_address_line_2",
"locality",
"address_region",
"country_code",
"postal_code",
]
address_attributes = [
"post_office_box_number",
"street_address",
"street_address_line_2",
"locality",
"address_region",
"country_code",
"postal_code",
]
for attribute in address_attributes:
assert attribute in data["contact"]["address"]
for attribute in address_attributes:
assert attribute in data["contact"]["address"]
@pytest.mark.parametrize(
"query, expected_status",
[
("org_id=42&contact_type=billing", 404),
("org_id=banana&contact_type=billing", 422),
("", 422),
("org_id=1&contact_type=contact", 422),
("contact_type=billing", 422),
],
"query, expected_status",
[
("org_id=42&contact_type=billing", 404),
("org_id=banana&contact_type=billing", 422),
("", 422),
("org_id=1&contact_type=contact", 422),
("contact_type=billing", 422),
],
)
@pytest.mark.anyio
async def test_get_org_contact_status_checks(
default_client: AsyncClient, query: str, expected_status: int
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/org/contact?{query}")
resp = await default_client.get(f"/org/contact?{query}")
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.parametrize(
"key, value",
[
("email", "user@example.com"),
("first_name", "John"),
("last_name", "Doe"),
("phonenumber", "+441234567890"),
("vat_number", "GB123456789"),
("post_office_box_number", "PO Box 123"),
("street_address", "123 Example Street"),
("street_address_line_2", "Suite 4B"),
("locality", "Glasgow"),
("address_region", "Glasgow City"),
("country_code", "GB"),
("postal_code", "G1 1AA"),
],
"key, value",
[
("email", "user@example.com"),
("first_name", "John"),
("last_name", "Doe"),
("phonenumber", "+441234567890"),
("vat_number", "GB123456789"),
("post_office_box_number", "PO Box 123"),
("street_address", "123 Example Street"),
("street_address_line_2", "Suite 4B"),
("locality", "Glasgow"),
("address_region", "Glasgow City"),
("country_code", "GB"),
("postal_code", "G1 1AA"),
],
)
@pytest.mark.anyio
async def test_patch_org_contact_success(default_client: AsyncClient, key: str, value: str):
resp = await default_client.patch(
"/org/contact",
json={"organisation_id": 1, "contact_type": "billing", key: value},
)
assert resp.status_code == 200
resp = await default_client.patch(
"/org/contact",
json={"organisation_id": 1, "contact_type": "billing", key: value},
)
assert resp.status_code == 200
data = resp.json()
data = resp.json()
assert "organisation" in data
assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One"
assert "organisation" in data
assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One"
attributes = [
"email",
"first_name",
"last_name",
"phonenumber",
"vat_number",
"address",
]
attributes = [
"email",
"first_name",
"last_name",
"phonenumber",
"vat_number",
"address",
]
for attribute in attributes:
assert attribute in data["contact"]
for attribute in attributes:
assert attribute in data["contact"]
address_attributes = [
"post_office_box_number",
"street_address",
"street_address_line_2",
"locality",
"address_region",
"country_code",
"postal_code",
]
address_attributes = [
"post_office_box_number",
"street_address",
"street_address_line_2",
"locality",
"address_region",
"country_code",
"postal_code",
]
for attribute in address_attributes:
assert attribute in data["contact"]["address"]
for attribute in address_attributes:
assert attribute in data["contact"]["address"]
if key in data["contact"]:
assert data["contact"][key] == value
elif key in data["contact"]["address"]:
assert data["contact"]["address"][key] == value
else:
pytest.fail(f"Invalid contact key: {key}")
if key in data["contact"]:
assert data["contact"][key] == value
elif key in data["contact"]["address"]:
assert data["contact"]["address"][key] == value
else:
pytest.fail(f"Invalid contact key: {key}")
@pytest.mark.parametrize(
"body, expected_status",
[
({"organisation_id": 42, "contact_type": "billing"}, 404),
({"organisation_id": 1, "contact_type": "security"}, 200),
({"organisation_id": 1, "contact_type": "owner"}, 200),
({"organisation_id": "Org One", "contact_type": "billing"}, 422),
({"organisation_id": "", "contact_type": "billing"}, 422),
({}, 422),
({"organisation_id": 1, "contact_type": "not_real"}, 422),
({"organisation_id": 1, "contact_type": 42}, 422),
({"organisation_id": 1, "contact_type": ""}, 422),
],
"body, expected_status",
[
({"organisation_id": 42, "contact_type": "billing"}, 404),
({"organisation_id": 1, "contact_type": "security"}, 200),
({"organisation_id": 1, "contact_type": "owner"}, 200),
({"organisation_id": "Org One", "contact_type": "billing"}, 422),
({"organisation_id": "", "contact_type": "billing"}, 422),
({}, 422),
({"organisation_id": 1, "contact_type": "not_real"}, 422),
({"organisation_id": 1, "contact_type": 42}, 422),
({"organisation_id": 1, "contact_type": ""}, 422),
],
)
@pytest.mark.anyio
async def test_patch_org_contact_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.patch("/org/contact", json=body)
resp = await default_client.patch("/org/contact", json=body)
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_delete_org_success(default_client: AsyncClient):
resp = await default_client.delete("/org?org_id=1")
resp = await default_client.delete("/org?org_id=1")
assert resp.status_code == 204
assert resp.status_code == 204
@pytest.mark.anyio
async def test_delete_org_users_success(default_client: AsyncClient):
resp = await default_client.delete("/org/user?org_id=1&user_id=2")
resp = await default_client.delete("/org/user?org_id=1&user_id=2")
assert resp.status_code == 204
assert resp.status_code == 204
@pytest.mark.anyio
async def test_delete_preapproval_org_success(default_client: AsyncClient):
resp = await default_client.delete("/org/self?org_id=3")
resp = await default_client.delete("/org/self?org_id=3")
assert resp.status_code == 204
assert resp.status_code == 204

View file

@ -8,90 +8,90 @@ from httpx import AsyncClient
from .conftest import generate_query_and_status, generate_body_and_status
pytestmark = [
pytest.mark.service_module,
pytest.mark.service_module,
]
@pytest.mark.anyio
async def test_get_services_success(default_client: AsyncClient):
resp = await default_client.get("/service?org_id=1")
data = resp.json()
resp = await default_client.get("/service?org_id=1")
data = resp.json()
assert resp.status_code == 200
assert "services" in data
assert isinstance(data["services"], list)
assert data["services"][0]["id"] == 1
assert data["services"][0]["name"] == "Test Service"
assert resp.status_code == 200
assert "services" in data
assert isinstance(data["services"], list)
assert data["services"][0]["id"] == 1
assert data["services"][0]["name"] == "Test Service"
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
@pytest.mark.anyio
async def test_get_services_status_checks(
default_client: AsyncClient, query: str, expected_status: int
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/service?{query}")
resp = await default_client.get(f"/service?{query}")
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_post_service_success(default_client: AsyncClient):
resp = await default_client.post("/service", json={"name": "New Test Service"})
data = resp.json()
resp = await default_client.post("/service", json={"name": "New Test Service"})
data = resp.json()
assert resp.status_code == 200
assert "service" in data
assert isinstance(data["service"], dict)
assert data["service"]["name"] == "New Test Service"
assert data["service"]["id"] == 2
assert isinstance(data["service"]["api_key"], str)
assert resp.status_code == 200
assert "service" in data
assert isinstance(data["service"], dict)
assert data["service"]["name"] == "New Test Service"
assert data["service"]["id"] == 2
assert isinstance(data["service"]["api_key"], str)
@pytest.mark.parametrize("body, expected_status", generate_body_and_status({"name": "str"}))
@pytest.mark.anyio
async def test_post_service_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.post("/service", json=body)
resp = await default_client.post("/service", json=body)
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_post_service_conflict(default_client: AsyncClient):
resp = await default_client.post("/service", json={"name": "Test Service"})
resp = await default_client.post("/service", json={"name": "Test Service"})
assert resp.status_code == 409
assert resp.status_code == 409
@pytest.mark.anyio
async def test_patch_service_success(default_client: AsyncClient):
resp = await default_client.patch("/service/key", json={"service_id": 1})
data = resp.json()
resp = await default_client.patch("/service/key", json={"service_id": 1})
data = resp.json()
assert resp.status_code == 200
assert "service" in data
assert isinstance(data["service"], dict)
assert data["service"]["name"] == "Test Service"
assert data["service"]["id"] == 1
assert isinstance(data["service"]["api_key"], str)
assert resp.status_code == 200
assert "service" in data
assert isinstance(data["service"], dict)
assert data["service"]["name"] == "Test Service"
assert data["service"]["id"] == 1
assert isinstance(data["service"]["api_key"], str)
@pytest.mark.parametrize(
"body, expected_status",
generate_body_and_status({"service_id": "int"}),
"body, expected_status",
generate_body_and_status({"service_id": "int"}),
)
@pytest.mark.anyio
async def test_patch_services_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.patch("/service/key", json=body)
resp = await default_client.patch("/service/key", json=body)
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_delete_service_success(default_client: AsyncClient):
resp = await default_client.delete("/service?service_id=1")
resp = await default_client.delete("/service?service_id=1")
assert resp.status_code == 204
assert resp.status_code == 204

View file

@ -11,191 +11,191 @@ from .conftest import generate_query_and_status
pytestmark = [
pytest.mark.user_module,
pytest.mark.user_module,
]
@pytest.mark.anyio
async def test_get_self_db_success(default_client: AsyncClient):
resp = await default_client.get("/user/self/db")
data = resp.json()
resp = await default_client.get("/user/self/db")
data = resp.json()
assert resp.status_code == 200
assert data["first_name"] == "Admin"
assert data["last_name"] == "Test"
assert data["email"] == "admin@test.com"
assert "organisations" in data
assert isinstance(data["organisations"], list)
assert "groups" in data
assert isinstance(data["groups"], dict)
assert resp.status_code == 200
assert data["first_name"] == "Admin"
assert data["last_name"] == "Test"
assert data["email"] == "admin@test.com"
assert "organisations" in data
assert isinstance(data["organisations"], list)
assert "groups" in data
assert isinstance(data["groups"], dict)
@pytest.mark.anyio
async def test_get_user_success(default_client: AsyncClient):
resp = await default_client.get("/user?user_id=1")
data = resp.json()
resp = await default_client.get("/user?user_id=1")
data = resp.json()
assert resp.status_code == 200
assert data["first_name"] == "Admin"
assert data["last_name"] == "Test"
assert data["email"] == "admin@test.com"
assert "organisations" in data
assert isinstance(data["organisations"], list)
assert "groups" in data
assert isinstance(data["groups"], dict)
assert resp.status_code == 200
assert data["first_name"] == "Admin"
assert data["last_name"] == "Test"
assert data["email"] == "admin@test.com"
assert "organisations" in data
assert isinstance(data["organisations"], list)
assert "groups" in data
assert isinstance(data["groups"], dict)
@pytest.mark.anyio
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["user_id"]))
async def test_get_user_status_checks(
default_client: AsyncClient, query: str, expected_status: int
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/user?{query}")
resp = await default_client.get(f"/user?{query}")
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_delete_user_success(default_client: AsyncClient):
resp = await default_client.delete("/user?user_id=1")
resp = await default_client.delete("/user?user_id=1")
assert resp.status_code == 204
assert resp.status_code == 204
@pytest.mark.anyio
async def test_post_user_invitation_success(default_client: AsyncClient):
body = {"user_email": "admin@test.com", "organisation_id": 1}
resp = await default_client.post("/user/invitation", json=body)
body = {"user_email": "admin@test.com", "organisation_id": 1}
resp = await default_client.post("/user/invitation", json=body)
assert resp.status_code == 200
data = resp.json()
assert "organisation" in data
assert isinstance(data["organisation"], dict)
assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One"
assert resp.status_code == 200
data = resp.json()
assert "organisation" in data
assert isinstance(data["organisation"], dict)
assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One"
assert "invited_email" in data
assert isinstance(data["invited_email"], str)
assert data["invited_email"] == "admin@test.com"
assert "invited_email" in data
assert isinstance(data["invited_email"], str)
assert data["invited_email"] == "admin@test.com"
@pytest.mark.parametrize(
"body, expected_status",
[
({"organisation_id": 42, "user_email": "admin@test.com"}, 404),
({"organisation_id": "Test Org", "user_email": "admin@test.com"}, 422),
({"organisation_id": "", "user_email": "admin@test.com"}, 422),
({}, 422),
({"user_email": 42}, 422),
({"organisation_id": 1, "user_email": "Test User"}, 422),
],
"body, expected_status",
[
({"organisation_id": 42, "user_email": "admin@test.com"}, 404),
({"organisation_id": "Test Org", "user_email": "admin@test.com"}, 422),
({"organisation_id": "", "user_email": "admin@test.com"}, 422),
({}, 422),
({"user_email": 42}, 422),
({"organisation_id": 1, "user_email": "Test User"}, 422),
],
)
@pytest.mark.anyio
async def test_post_user_invitation_status_checks(
default_client: AsyncClient, body, expected_status
default_client: AsyncClient, body, expected_status
):
resp = await default_client.post("/user/invitation", json=body)
resp = await default_client.post("/user/invitation", json=body)
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.parametrize(
"body, expected_status",
[
({"jwt": "invalid"}, 401),
({"jwt": ""}, 401),
({"jwt": None}, 422),
({"jwt": 42}, 422),
],
"body, expected_status",
[
({"jwt": "invalid"}, 401),
({"jwt": ""}, 401),
({"jwt": None}, 422),
({"jwt": 42}, 422),
],
)
@pytest.mark.anyio
async def test_post_user_invitation_accept_status_checks(
default_client: AsyncClient, body, expected_status
default_client: AsyncClient, body, expected_status
):
resp = await default_client.post("/user/invitation/accept", json=body)
resp = await default_client.post("/user/invitation/accept", json=body)
assert resp.status_code == expected_status
assert resp.status_code == expected_status
if resp.status_code == 401:
assert resp.json()["detail"] == "Invalid JWS"
if resp.status_code == 401:
assert resp.json()["detail"] == "Invalid JWS"
@pytest.mark.anyio
async def test_get_self_orgs_success(default_client: AsyncClient):
resp = await default_client.get("/user/self/orgs")
assert resp.status_code == 200
resp = await default_client.get("/user/self/orgs")
assert resp.status_code == 200
data = resp.json()
data = resp.json()
assert "organisations" in data
assert isinstance(data["organisations"], list)
assert len(data["organisations"]) > 0
assert "organisations" in data
assert isinstance(data["organisations"], list)
assert len(data["organisations"]) > 0
org = data["organisations"][0]
assert org["organisation_id"] == 1
assert org["name"] == "Org One"
assert org["status"] == "approved"
assert org["root_user_email"] == "admin@test.com"
assert "intake_questionnaire" in org
assert isinstance(org["intake_questionnaire"], dict)
org = data["organisations"][0]
assert org["organisation_id"] == 1
assert org["name"] == "Org One"
assert org["status"] == "approved"
assert org["root_user_email"] == "admin@test.com"
assert "intake_questionnaire" in org
assert isinstance(org["intake_questionnaire"], dict)
assert isinstance(org["billing_contact"], dict)
assert org["billing_contact"]["email"] == "billing@orgone.com"
assert org["billing_contact"]["id"] == 1
assert isinstance(org["billing_contact"], dict)
assert org["billing_contact"]["email"] == "billing@orgone.com"
assert org["billing_contact"]["id"] == 1
assert isinstance(org["owner_contact"], dict)
assert org["owner_contact"]["email"] == "owner@orgone.com"
assert org["owner_contact"]["id"] == 2
assert isinstance(org["owner_contact"], dict)
assert org["owner_contact"]["email"] == "owner@orgone.com"
assert org["owner_contact"]["id"] == 2
assert isinstance(org["security_contact"], dict)
assert org["security_contact"]["email"] == "security@orgone.com"
assert org["security_contact"]["id"] == 3
assert isinstance(org["security_contact"], dict)
assert org["security_contact"]["email"] == "security@orgone.com"
assert org["security_contact"]["id"] == 3
@pytest.mark.anyio
async def test_get_self_orgs_dynamic(default_client: AsyncClient):
method = "GET"
path = "/user/self/orgs"
expected_data = {
"organisations": [
{
"organisation_id": 1,
"name": "Org One",
"status": "approved",
"root_user_email": "admin@test.com",
"owner_contact": {"email": "owner@orgone.com", "id": 2},
"security_contact": {"email": "security@orgone.com", "id": 3},
"billing_contact": {"email": "billing@orgone.com", "id": 1},
"intake_questionnaire": {
"questions": {
"question_one": None,
"question_three": None,
"question_two": "answer two",
},
"metadata": {"version": 0, "submission_date": None},
},
}
]
}
method = "GET"
path = "/user/self/orgs"
expected_data = {
"organisations": [
{
"organisation_id": 1,
"name": "Org One",
"status": "approved",
"root_user_email": "admin@test.com",
"owner_contact": {"email": "owner@orgone.com", "id": 2},
"security_contact": {"email": "security@orgone.com", "id": 3},
"billing_contact": {"email": "billing@orgone.com", "id": 1},
"intake_questionnaire": {
"questions": {
"question_one": None,
"question_three": None,
"question_two": "answer two",
},
"metadata": {"version": 0, "submission_date": None},
},
}
]
}
resp = await default_client.get(path)
resp = await default_client.get(path)
route = next(
route
for route in default_client._transport.app.routes # ty:ignore[unresolved-attribute]
if isinstance(route, APIRoute) and path in route.path and method in route.methods
)
route = next(
route
for route in default_client._transport.app.routes # ty:ignore[unresolved-attribute]
if isinstance(route, APIRoute) and path in route.path and method in route.methods
)
assert resp.status_code == route.status_code
if route.status_code == 204:
return
assert resp.status_code == route.status_code
if route.status_code == 204:
return
expected_response_schema = route.response_model
data = resp.json()
expected_response_schema = route.response_model
data = resp.json()
response_model = expected_response_schema(**data)
assert isinstance(response_model, expected_response_schema)
response_model = expected_response_schema(**data)
assert isinstance(response_model, expected_response_schema)
expected_response_model = expected_response_schema(**expected_data)
expected_response_model = expected_response_schema(**expected_data)
assert response_model == expected_response_model
assert response_model == expected_response_model