diff --git a/src/_module_template/router.py b/src/_module_template/router.py index 09250df..7d29b56 100644 --- a/src/_module_template/router.py +++ b/src/_module_template/router.py @@ -22,5 +22,5 @@ from fastapi import APIRouter router = APIRouter( - tags=[""], + tags=[""], ) diff --git a/src/admin/router.py b/src/admin/router.py index 9fe91eb..8742405 100644 --- a/src/admin/router.py +++ b/src/admin/router.py @@ -8,6 +8,6 @@ Exports: from fastapi import APIRouter router = APIRouter( - tags=["admin"], - prefix="/admin", + tags=["admin"], + prefix="/admin", ) diff --git a/src/api.py b/src/api.py index 0e46a53..9ad8b5d 100644 --- a/src/api.py +++ b/src/api.py @@ -26,15 +26,15 @@ api_router.include_router(iam_router) class HealthCheckResponse(CustomBaseModel): - status: str + status: str @api_router.get( - path="/healthcheck", - status_code=status.HTTP_200_OK, - response_model=HealthCheckResponse, - include_in_schema=False, + path="/healthcheck", + status_code=status.HTTP_200_OK, + response_model=HealthCheckResponse, + include_in_schema=False, ) def healthcheck(): - """Simple health check endpoint.""" - return {"status": "ok"} + """Simple health check endpoint.""" + return {"status": "ok"} diff --git a/src/auth/config.py b/src/auth/config.py index 030c36e..0e2734a 100644 --- a/src/auth/config.py +++ b/src/auth/config.py @@ -9,9 +9,9 @@ from src.config import CustomBaseSettings class AuthConfig(CustomBaseSettings): - OIDC_CONFIG: str = "" - OIDC_ISSUER: str = "" - CLIENT_ID: str = "" + OIDC_CONFIG: str = "" + OIDC_ISSUER: str = "" + CLIENT_ID: str = "" auth_settings = AuthConfig() diff --git a/src/auth/dependencies.py b/src/auth/dependencies.py index 7cf4e9f..65fb16d 100644 --- a/src/auth/dependencies.py +++ b/src/auth/dependencies.py @@ -16,92 +16,92 @@ from src.exceptions import ForbiddenException from src.user.dependencies import user_model_claims_dependency from src.user.models import User from src.organisation.dependencies import ( - org_model_query_dependency, - org_model_body_dependency, + org_model_query_dependency, + org_model_body_dependency, ) from src.organisation.models import Organisation as Org async def org_query_user_claims( - org_model: org_model_query_dependency, user_model: user_model_claims_dependency + org_model: org_model_query_dependency, user_model: user_model_claims_dependency ): - if user_model in org_model.user_rel: - return True + if user_model in org_model.user_rel: + return True - raise ForbiddenException() + raise ForbiddenException() org_query_user_claims_dependency = Annotated[bool, Depends(org_query_user_claims)] def get_super_admin_list(): - return [] + return [] def empty_su_list(): - return [] + return [] def testing_su_list(): - return ["admin@test.com"] + return ["admin@test.com"] su_list_dependency = Annotated[list[str | None], Depends(get_super_admin_list)] async def user_model_super_admin( - user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency + user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency ): - if user_model.email in super_admin_emails: - return user_model + if user_model.email in super_admin_emails: + return user_model - raise ForbiddenException(message="Must be super admin") + raise ForbiddenException(message="Must be super admin") super_admin_dependency = Annotated[User, Depends(user_model_super_admin)] async def org_query_root_claims( - user_model: user_model_claims_dependency, - org_model: org_model_query_dependency, - su_emails: su_list_dependency, - request: Request, + user_model: user_model_claims_dependency, + org_model: org_model_query_dependency, + su_emails: su_list_dependency, + request: Request, ): - try: - if await user_model_super_admin(user_model, su_emails): - return org_model - except ForbiddenException: - pass + try: + if await user_model_super_admin(user_model, su_emails): + return org_model + except ForbiddenException: + pass - await org_status_check(org_model, request) + await org_status_check(org_model, request) - if org_model.root_user_id == user_model.id: - return org_model + if org_model.root_user_id == user_model.id: + return org_model - raise ForbiddenException(message="Must be the org's root user") + raise ForbiddenException(message="Must be the org's root user") org_model_root_claim_query_dependency = Annotated[Org, Depends(org_query_root_claims)] async def org_body_root_claims( - user_model: user_model_claims_dependency, - org_model: org_model_body_dependency, - su_emails: su_list_dependency, - request: Request, + user_model: user_model_claims_dependency, + org_model: org_model_body_dependency, + su_emails: su_list_dependency, + request: Request, ): - try: - if await user_model_super_admin(user_model, su_emails): - return org_model - except ForbiddenException: - pass + try: + if await user_model_super_admin(user_model, su_emails): + return org_model + except ForbiddenException: + pass - await org_status_check(org_model, request) + await org_status_check(org_model, request) - if org_model.root_user_id == user_model.id: - return org_model + if org_model.root_user_id == user_model.id: + return org_model - raise ForbiddenException(message="Must be the org's root user") + raise ForbiddenException(message="Must be the org's root user") org_model_root_claim_body_dependency = Annotated[Org, Depends(org_body_root_claims)] diff --git a/src/auth/router.py b/src/auth/router.py index ee32033..b472586 100644 --- a/src/auth/router.py +++ b/src/auth/router.py @@ -8,5 +8,5 @@ Exports: from fastapi import APIRouter router = APIRouter( - tags=["auth"], + tags=["auth"], ) diff --git a/src/auth/service.py b/src/auth/service.py index 4b42590..1808341 100644 --- a/src/auth/service.py +++ b/src/auth/service.py @@ -31,56 +31,56 @@ oidc_dependency = Annotated[str, Depends(oidc)] async def get_dev_user(): - return {"db_id": 1, "email": "chris@sr2.uk"} + return {"db_id": 1, "email": "chris@sr2.uk"} async def get_current_user( - oidc_auth_string: oidc_dependency, db: DbSession + oidc_auth_string: oidc_dependency, db: DbSession ) -> dict[str, Any]: - config_url = urlopen(auth_settings.OIDC_CONFIG) - config = json.loads(config_url.read()) - jwks_uri = config["jwks_uri"] - key_response = requests.get(jwks_uri) - jwk_keys = KeySet.import_key_set(key_response.json()) + config_url = urlopen(auth_settings.OIDC_CONFIG) + config = json.loads(config_url.read()) + jwks_uri = config["jwks_uri"] + key_response = requests.get(jwks_uri) + jwk_keys = KeySet.import_key_set(key_response.json()) - token = jwt.decode(oidc_auth_string.replace("Bearer ", ""), jwk_keys) + token = jwt.decode(oidc_auth_string.replace("Bearer ", ""), jwk_keys) - claims_requests = jwt.JWTClaimsRegistry( - exp={"essential": True}, iss={"essential": True, "value": auth_settings.OIDC_ISSUER} - ) + claims_requests = jwt.JWTClaimsRegistry( + exp={"essential": True}, iss={"essential": True, "value": auth_settings.OIDC_ISSUER} + ) - try: - claims_requests.validate(token.claims) - except ExpiredTokenError: - raise UnauthorizedException(message="Token is expired") - db_id = await add_user(db, token.claims) + try: + claims_requests.validate(token.claims) + except ExpiredTokenError: + raise UnauthorizedException(message="Token is expired") + db_id = await add_user(db, token.claims) - token.claims["db_id"] = db_id + token.claims["db_id"] = db_id - return token.claims + return token.claims claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)] async def org_status_check(org_model: Org, request: Request): - org_status = OrgStatus(org_model.status) - if org_status.is_blocked: - raise ForbiddenException("This organisation cannot perform this action.") + org_status = OrgStatus(org_model.status) + if org_status.is_blocked: + raise ForbiddenException("This organisation cannot perform this action.") - root = "/api/v1" + root = "/api/v1" - pre_approval_endpoints = [ - f"PATCH{root}/org/status", - f"PATCH{root}/org/questionnaire", - f"GET{root}/org", - f"GET{root}/org/contact", - f"PATCH{root}/org/contact", - f"DELETE{root}/org/self", - ] - current_request = f"{request.method}{request.url.path}" - if ( - current_request not in pre_approval_endpoints - and org_model.status != OrgStatus.APPROVED - ): - raise AwaitingApprovalException(org_model.id) + pre_approval_endpoints = [ + f"PATCH{root}/org/status", + f"PATCH{root}/org/questionnaire", + f"GET{root}/org", + f"GET{root}/org/contact", + f"PATCH{root}/org/contact", + f"DELETE{root}/org/self", + ] + current_request = f"{request.method}{request.url.path}" + if ( + current_request not in pre_approval_endpoints + and org_model.status != OrgStatus.APPROVED + ): + raise AwaitingApprovalException(org_model.id) diff --git a/src/config.py b/src/config.py index 6fc68f0..1b46118 100644 --- a/src/config.py +++ b/src/config.py @@ -16,31 +16,31 @@ from src.constants import Environment class CustomBaseSettings(BaseSettings): - model_config = SettingsConfigDict( - env_file=".env", env_file_encoding="utf-8", extra="ignore" - ) + model_config = SettingsConfigDict( + env_file=".env", env_file_encoding="utf-8", extra="ignore" + ) class Config(CustomBaseSettings): - APP_VERSION: str = "0.1" - ENVIRONMENT: Environment = Environment.PRODUCTION - SECRET_KEY: SecretStr = SecretStr("") - DISABLE_AUTH: bool = False + APP_VERSION: str = "0.1" + ENVIRONMENT: Environment = Environment.PRODUCTION + SECRET_KEY: SecretStr = SecretStr("") + DISABLE_AUTH: bool = False - CORS_ORIGINS: list[str] = ["*"] - CORS_ORIGINS_REGEX: str | None = None - CORS_HEADERS: list[str] = ["*"] + CORS_ORIGINS: list[str] = ["*"] + CORS_ORIGINS_REGEX: str | None = None + CORS_HEADERS: list[str] = ["*"] - DATABASE_NAME: str = "fastapi-exp" - DATABASE_PORT: str = "5432" - DATABASE_HOSTNAME: str = "localhost" - DATABASE_CREDENTIALS: SecretStr = SecretStr(":") + DATABASE_NAME: str = "fastapi-exp" + DATABASE_PORT: str = "5432" + DATABASE_HOSTNAME: str = "localhost" + DATABASE_CREDENTIALS: SecretStr = SecretStr(":") - DATABASE_POOL_SIZE: int = 16 - DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes - DATABASE_POOL_PRE_PING: bool = True + DATABASE_POOL_SIZE: int = 16 + DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes + DATABASE_POOL_PRE_PING: bool = True - LETTERMINT_API_TOKEN: SecretStr = SecretStr("") + LETTERMINT_API_TOKEN: SecretStr = SecretStr("") settings = Config() @@ -51,20 +51,20 @@ DATABASE_HOSTNAME = settings.DATABASE_HOSTNAME DATABASE_CREDENTIALS = settings.DATABASE_CREDENTIALS.get_secret_value() # this will support special chars for credentials _DATABASE_CREDENTIAL_USER, _DATABASE_CREDENTIAL_PASSWORD = str(DATABASE_CREDENTIALS).split( - ":" + ":" ) _QUOTED_DATABASE_PASSWORD = parse.quote_plus(str(_DATABASE_CREDENTIAL_PASSWORD)) SQLALCHEMY_DATABASE_URI = SecretStr( - f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}" + f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}" ) if settings.ENVIRONMENT == Environment.TESTING: - SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:") + SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:") app_configs: dict[str, Any] = {"title": "App API"} if settings.ENVIRONMENT.is_deployed: - app_configs["root_path"] = f"/v{settings.APP_VERSION}" + app_configs["root_path"] = f"/v{settings.APP_VERSION}" if not settings.ENVIRONMENT.is_debug: - app_configs["openapi_url"] = None # hide docs + app_configs["openapi_url"] = None # hide docs diff --git a/src/constants.py b/src/constants.py index f9237e4..d18e300 100644 --- a/src/constants.py +++ b/src/constants.py @@ -9,29 +9,29 @@ from enum import StrEnum, auto class Environment(StrEnum): - """ - Enumeration of environments. + """ + Enumeration of environments. - Attributes: - LOCAL (str): Application is running locally - TESTING (str): Application is running in testing mode - STAGING (str): Application is running in staging mode (ie not testing) - PRODUCTION (str): Application is running in production mode - """ + Attributes: + LOCAL (str): Application is running locally + TESTING (str): Application is running in testing mode + STAGING (str): Application is running in staging mode (ie not testing) + PRODUCTION (str): Application is running in production mode + """ - LOCAL = auto() - TESTING = auto() - STAGING = auto() - PRODUCTION = auto() + LOCAL = auto() + TESTING = auto() + STAGING = auto() + PRODUCTION = auto() - @property - def is_debug(self): - return self in (self.LOCAL, self.STAGING, self.TESTING) + @property + def is_debug(self): + return self in (self.LOCAL, self.STAGING, self.TESTING) - @property - def is_testing(self): - return self == self.TESTING + @property + def is_testing(self): + return self == self.TESTING - @property - def is_deployed(self) -> bool: - return self in (self.STAGING, self.PRODUCTION) + @property + def is_deployed(self) -> bool: + return self in (self.STAGING, self.PRODUCTION) diff --git a/src/contact/exceptions.py b/src/contact/exceptions.py index 55e9e30..7c36941 100644 --- a/src/contact/exceptions.py +++ b/src/contact/exceptions.py @@ -11,13 +11,13 @@ from fastapi import HTTPException, status class ContactNotFoundException(HTTPException): - def __init__(self, contact_id: Optional[int] = None) -> None: - detail = ( - "Contact not found" - if contact_id is None - else f"Contact with ID '{contact_id}' was not found." - ) - super().__init__( - status_code=status.HTTP_404_NOT_FOUND, - detail=detail, - ) + def __init__(self, contact_id: Optional[int] = None) -> None: + detail = ( + "Contact not found" + if contact_id is None + else f"Contact with ID '{contact_id}' was not found." + ) + super().__init__( + status_code=status.HTTP_404_NOT_FOUND, + detail=detail, + ) diff --git a/src/contact/models.py b/src/contact/models.py index 37665e5..4bf3a79 100644 --- a/src/contact/models.py +++ b/src/contact/models.py @@ -15,22 +15,22 @@ from src.models import CustomBase class Contact(CustomBase, IdMixin): - __tablename__ = "contact" + __tablename__ = "contact" - email: Mapped[str] = mapped_column(default=None, nullable=True) - first_name: Mapped[str] = mapped_column(default=None, nullable=True) - last_name: Mapped[str] = mapped_column(default=None, nullable=True) - phonenumber: Mapped[str] = mapped_column(default=None, nullable=True) - vat_number: Mapped[str | None] = mapped_column(default=None, nullable=True) + email: Mapped[str] = mapped_column(default=None, nullable=True) + first_name: Mapped[str] = mapped_column(default=None, nullable=True) + last_name: Mapped[str] = mapped_column(default=None, nullable=True) + phonenumber: Mapped[str] = mapped_column(default=None, nullable=True) + vat_number: Mapped[str | None] = mapped_column(default=None, nullable=True) - street_address: Mapped[str] = mapped_column(default=None, nullable=True) - street_address_line_2: Mapped[str] = mapped_column(default=None, nullable=True) - post_office_box_number: Mapped[str | None] = mapped_column(default=None, nullable=True) - locality: Mapped[str] = mapped_column(default=None, nullable=True) # Ie City - country_code: Mapped[str] = mapped_column(default=None, nullable=True) # Eg GB - address_region: Mapped[str | None] = mapped_column(default=None, nullable=True) - postal_code: Mapped[str] = mapped_column(default=None, nullable=True) + street_address: Mapped[str] = mapped_column(default=None, nullable=True) + street_address_line_2: Mapped[str] = mapped_column(default=None, nullable=True) + post_office_box_number: Mapped[str | None] = mapped_column(default=None, nullable=True) + locality: Mapped[str] = mapped_column(default=None, nullable=True) # Ie City + country_code: Mapped[str] = mapped_column(default=None, nullable=True) # Eg GB + address_region: Mapped[str | None] = mapped_column(default=None, nullable=True) + postal_code: Mapped[str] = mapped_column(default=None, nullable=True) - org_id: Mapped[int] = mapped_column( - ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False - ) + org_id: Mapped[int] = mapped_column( + ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False + ) diff --git a/src/contact/router.py b/src/contact/router.py index 2e5f8f4..6f2376b 100644 --- a/src/contact/router.py +++ b/src/contact/router.py @@ -6,6 +6,6 @@ from fastapi import APIRouter router = APIRouter( - prefix="/contact", - tags=["contact"], + prefix="/contact", + tags=["contact"], ) diff --git a/src/contact/schemas.py b/src/contact/schemas.py index b9cec61..f9d4635 100644 --- a/src/contact/schemas.py +++ b/src/contact/schemas.py @@ -14,22 +14,22 @@ from src.schemas import CustomBaseModel class ContactAddress(CustomBaseModel): - model_config = ConfigDict(from_attributes=True, extra="ignore") + model_config = ConfigDict(from_attributes=True, extra="ignore") - post_office_box_number: Optional[str] = None - street_address: Optional[str] = None - street_address_line_2: Optional[str] = None - locality: Optional[str] = None - address_region: Optional[str] = None - country_code: Optional[str] = None - postal_code: Optional[str] = None + post_office_box_number: Optional[str] = None + street_address: Optional[str] = None + street_address_line_2: Optional[str] = None + locality: Optional[str] = None + address_region: Optional[str] = None + country_code: Optional[str] = None + postal_code: Optional[str] = None class ContactModel(CustomBaseModel): - email: Optional[EmailStr] = None - first_name: Optional[str] = None - last_name: Optional[str] = None - phonenumber: Optional[str] = None - vat_number: Optional[str] = None + email: Optional[EmailStr] = None + first_name: Optional[str] = None + last_name: Optional[str] = None + phonenumber: Optional[str] = None + vat_number: Optional[str] = None - address: ContactAddress + address: ContactAddress diff --git a/src/database.py b/src/database.py index 41509d4..01c55c8 100644 --- a/src/database.py +++ b/src/database.py @@ -1,6 +1,7 @@ """ Database connection and session utilities """ + from contextlib import contextmanager from typing import Annotated, Generator from sqlalchemy import create_engine, StaticPool, Connection @@ -29,6 +30,7 @@ else: sm = sessionmaker(autocommit=False, expire_on_commit=False, bind=engine) + @contextmanager def get_db_connection() -> Generator[Connection, None, None]: with engine.connect() as connection: @@ -38,12 +40,15 @@ def get_db_connection() -> Generator[Connection, None, None]: connection.rollback() raise + def _get_db_connection() -> Generator[Connection, None]: with get_db_connection() as connection: yield connection + DbConnection = Annotated[Connection, Depends(_get_db_connection)] + @contextmanager def get_db_session() -> Generator[Session, None, None]: session = sm() @@ -60,4 +65,5 @@ def _get_db_session() -> Generator[Session, None]: with get_db_session() as session: yield session + DbSession = Annotated[Session, Depends(_get_db_session)] diff --git a/src/exceptions.py b/src/exceptions.py index 1f5fdcb..f34ad35 100644 --- a/src/exceptions.py +++ b/src/exceptions.py @@ -12,36 +12,36 @@ from fastapi import HTTPException, status class UnprocessableContentException(HTTPException): - def __init__(self, message: Optional[str] = None) -> None: - detail = "Unprocessable content" if not message else message - super().__init__( - status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, - detail=detail, - ) + def __init__(self, message: Optional[str] = None) -> None: + detail = "Unprocessable content" if not message else message + super().__init__( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail=detail, + ) class ConflictException(HTTPException): - def __init__(self, message: Optional[str] = None) -> None: - detail = "Conflict" if not message else message - super().__init__( - status_code=status.HTTP_409_CONFLICT, - detail=detail, - ) + def __init__(self, message: Optional[str] = None) -> None: + detail = "Conflict" if not message else message + super().__init__( + status_code=status.HTTP_409_CONFLICT, + detail=detail, + ) class ForbiddenException(HTTPException): - def __init__(self, message: Optional[str] = None) -> None: - detail = "Forbidden" if not message else message - super().__init__( - status_code=status.HTTP_403_FORBIDDEN, - detail=detail, - ) + def __init__(self, message: Optional[str] = None) -> None: + detail = "Forbidden" if not message else message + super().__init__( + status_code=status.HTTP_403_FORBIDDEN, + detail=detail, + ) class UnauthorizedException(HTTPException): - def __init__(self, message: Optional[str] = None) -> None: - detail = "Not authorized" if not message else message - super().__init__( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=detail, - ) + def __init__(self, message: Optional[str] = None) -> None: + detail = "Not authorized" if not message else message + super().__init__( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=detail, + ) diff --git a/src/iam/dependencies.py b/src/iam/dependencies.py index 113a3c6..e559bc1 100644 --- a/src/iam/dependencies.py +++ b/src/iam/dependencies.py @@ -18,59 +18,55 @@ from src.iam.exceptions import GroupNotFoundException, PermNotFoundException from src.iam.schemas import GroupIDMixin, PermIDMixin -def get_group_model_query( - db: DbSession, group_id: Annotated[int, Query(gt=0)] -) -> Group: - group_model = db.get(Group, group_id) - if group_model is None: - raise GroupNotFoundException(group_id) +def get_group_model_query(db: DbSession, group_id: Annotated[int, Query(gt=0)]) -> Group: + group_model = db.get(Group, group_id) + if group_model is None: + raise GroupNotFoundException(group_id) - return group_model + return group_model group_model_query_dependency = Annotated[Group, Depends(get_group_model_query)] def get_group_model_body( - db: DbSession, request_model: Optional[GroupIDMixin] = None + db: DbSession, request_model: Optional[GroupIDMixin] = None ) -> Group: - group_id = getattr(request_model, "group_id", None) - if group_id is None: - raise GroupNotFoundException() - group_model = db.get(Group, group_id) - if group_model is None: - raise GroupNotFoundException(group_id) + group_id = getattr(request_model, "group_id", None) + if group_id is None: + raise GroupNotFoundException() + group_model = db.get(Group, group_id) + if group_model is None: + raise GroupNotFoundException(group_id) - return group_model + return group_model group_model_body_dependency = Annotated[Group, Depends(get_group_model_body)] def get_perm_model_body( - db: DbSession, request_model: Optional[PermIDMixin] = None + db: DbSession, request_model: Optional[PermIDMixin] = None ) -> Permission: - perm_id = getattr(request_model, "permission_id", None) - if perm_id is None: - raise PermNotFoundException - perm_model = db.get(Permission, perm_id) - if perm_model is None: - raise PermNotFoundException(perm_id) + perm_id = getattr(request_model, "permission_id", None) + if perm_id is None: + raise PermNotFoundException + perm_model = db.get(Permission, perm_id) + if perm_model is None: + raise PermNotFoundException(perm_id) - return perm_model + return perm_model perm_model_body_dependency = Annotated[Permission, Depends(get_perm_model_body)] -def get_perm_model_query( - db: DbSession, perm_id: Annotated[int, Query(gt=0)] -) -> Permission: - perm_model = db.get(Permission, perm_id) - if perm_model is None: - raise PermNotFoundException(perm_id) +def get_perm_model_query(db: DbSession, perm_id: Annotated[int, Query(gt=0)]) -> Permission: + perm_model = db.get(Permission, perm_id) + if perm_model is None: + raise PermNotFoundException(perm_id) - return perm_model + return perm_model perm_model_query_dependency = Annotated[Permission, Depends(get_perm_model_query)] diff --git a/src/iam/exceptions.py b/src/iam/exceptions.py index 503b844..7d38887 100644 --- a/src/iam/exceptions.py +++ b/src/iam/exceptions.py @@ -12,26 +12,26 @@ from fastapi import HTTPException, status class GroupNotFoundException(HTTPException): - def __init__(self, group_id: Optional[int] = None) -> None: - detail = ( - "Group not found" - if group_id is None - else f"User with ID '{group_id}' was not found." - ) - super().__init__( - status_code=status.HTTP_404_NOT_FOUND, - detail=detail, - ) + def __init__(self, group_id: Optional[int] = None) -> None: + detail = ( + "Group not found" + if group_id is None + else f"User with ID '{group_id}' was not found." + ) + super().__init__( + status_code=status.HTTP_404_NOT_FOUND, + detail=detail, + ) class PermNotFoundException(HTTPException): - def __init__(self, perm_id: Optional[int] = None) -> None: - detail = ( - "Permission not found" - if perm_id is None - else f"User with ID '{perm_id}' was not found." - ) - super().__init__( - status_code=status.HTTP_404_NOT_FOUND, - detail=detail, - ) + def __init__(self, perm_id: Optional[int] = None) -> None: + detail = ( + "Permission not found" + if perm_id is None + else f"User with ID '{perm_id}' was not found." + ) + super().__init__( + status_code=status.HTTP_404_NOT_FOUND, + detail=detail, + ) diff --git a/src/iam/models.py b/src/iam/models.py index 7fdb7c7..9a0fc36 100644 --- a/src/iam/models.py +++ b/src/iam/models.py @@ -25,90 +25,90 @@ from src.models import CustomBase, IdMixin class Permission(CustomBase, IdMixin): - __tablename__ = "permission" + __tablename__ = "permission" - resource: Mapped[str] - action: Mapped[str] + resource: Mapped[str] + action: Mapped[str] - service_id: Mapped[int] = mapped_column(ForeignKey("service.id", ondelete="CASCADE")) + service_id: Mapped[int] = mapped_column(ForeignKey("service.id", ondelete="CASCADE")) - __table_args__ = ( - UniqueConstraint( - "service_id", - "resource", - "action", - name="uniq_permission_resource_and_action", - ), - ) + __table_args__ = ( + UniqueConstraint( + "service_id", + "resource", + "action", + name="uniq_permission_resource_and_action", + ), + ) - service_rel = relationship( - "Service", - back_populates="permission_rel", - foreign_keys="Permission.service_id", - ) + service_rel = relationship( + "Service", + back_populates="permission_rel", + foreign_keys="Permission.service_id", + ) - group_rel = relationship( - "Group", secondary="group_permissions", back_populates="permission_rel" - ) + group_rel = relationship( + "Group", secondary="group_permissions", back_populates="permission_rel" + ) - org_rel = relationship( - "Organisation", secondary="org_permissions", back_populates="permission_rel" - ) + org_rel = relationship( + "Organisation", secondary="org_permissions", back_populates="permission_rel" + ) - @property - def service_name(self): - return self.service_rel.name + @property + def service_name(self): + return self.service_rel.name class Group(CustomBase, IdMixin): - __tablename__ = "group" + __tablename__ = "group" - name: Mapped[str] + name: Mapped[str] - org_id: Mapped[int] = mapped_column(ForeignKey("organisation.id", ondelete="CASCADE")) + org_id: Mapped[int] = mapped_column(ForeignKey("organisation.id", ondelete="CASCADE")) - __table_args__ = ( - UniqueConstraint( - "name", - "org_id", - name="uniq_group_name_org_id", - ), - ) + __table_args__ = ( + UniqueConstraint( + "name", + "org_id", + name="uniq_group_name_org_id", + ), + ) - user_rel = relationship("User", secondary="user_groups", back_populates="group_rel") + user_rel = relationship("User", secondary="user_groups", back_populates="group_rel") - org_rel = relationship("Organisation", back_populates="group_rel") + org_rel = relationship("Organisation", back_populates="group_rel") - permission_rel = relationship( - "Permission", secondary="group_permissions", back_populates="group_rel" - ) + permission_rel = relationship( + "Permission", secondary="group_permissions", back_populates="group_rel" + ) class GroupPermissions(CustomBase): - __tablename__ = "group_permissions" - group_id: Mapped[int] = mapped_column( - ForeignKey("group.id", ondelete="CASCADE"), primary_key=True - ) - permission_id: Mapped[int] = mapped_column( - ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True - ) + __tablename__ = "group_permissions" + group_id: Mapped[int] = mapped_column( + ForeignKey("group.id", ondelete="CASCADE"), primary_key=True + ) + permission_id: Mapped[int] = mapped_column( + ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True + ) class UserGroups(CustomBase): - __tablename__ = "user_groups" - user_id: Mapped[int] = mapped_column( - ForeignKey("user.id", ondelete="CASCADE"), primary_key=True - ) - group_id: Mapped[int] = mapped_column( - ForeignKey("group.id", ondelete="CASCADE"), primary_key=True - ) + __tablename__ = "user_groups" + user_id: Mapped[int] = mapped_column( + ForeignKey("user.id", ondelete="CASCADE"), primary_key=True + ) + group_id: Mapped[int] = mapped_column( + ForeignKey("group.id", ondelete="CASCADE"), primary_key=True + ) class OrgPermissions(CustomBase): - __tablename__ = "org_permissions" - org_id: Mapped[int] = mapped_column( - ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True - ) - permission_id: Mapped[int] = mapped_column( - ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True - ) + __tablename__ = "org_permissions" + org_id: Mapped[int] = mapped_column( + ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True + ) + permission_id: Mapped[int] = mapped_column( + ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True + ) diff --git a/src/iam/router.py b/src/iam/router.py index 1b2c297..2d50b36 100644 --- a/src/iam/router.py +++ b/src/iam/router.py @@ -28,674 +28,672 @@ from src.organisation.exceptions import OrgNotFoundException from src.schemas import GroupSummary, OrgSummary from src.service.dependencies import service_model_body_dependency from src.exceptions import ( - ConflictException, - ForbiddenException, - UnprocessableContentException, + ConflictException, + ForbiddenException, + UnprocessableContentException, ) from src.database import DbSession from src.auth.service import claims_dependency from src.auth.dependencies import ( - org_model_root_claim_query_dependency, - org_model_root_claim_body_dependency, - super_admin_dependency, + org_model_root_claim_query_dependency, + org_model_root_claim_body_dependency, + super_admin_dependency, ) from src.user.models import User from src.user.dependencies import ( - user_model_body_dependency, - user_model_query_dependency, - user_model_claims_dependency, + user_model_body_dependency, + user_model_query_dependency, + user_model_claims_dependency, ) from src.organisation.models import Organisation as Org from src.service.models import Service from src.iam.service import service_key_dependency, send_user_group_invitation from src.iam.models import ( - Permission as Perm, - GroupPermissions as GPerms, - Group, - UserGroups, + Permission as Perm, + GroupPermissions as GPerms, + Group, + UserGroups, ) from src.iam.dependencies import ( - group_model_query_dependency, - group_model_body_dependency, - perm_model_body_dependency, - perm_model_query_dependency, + group_model_query_dependency, + group_model_body_dependency, + perm_model_body_dependency, + perm_model_query_dependency, ) from src.iam.schemas import ( - IAMCAoRRequest, - IAMGetGroupPermissionsResponse, - IAMGetGroupUsersResponse, - IAMPostGroupRequest, - IAMPostGroupResponse, - IAMPutGroupPermissionRequest, - IAMPutGroupPermissionResponse, - IAMPutGroupUserRequest, - IAMPutGroupUserResponse, - IAMDeleteGroupPermissionResponse, - IAMDeleteGroupUserResponse, - IAMGetPermissionsResponse, - IAMPostPermissionRequest, - IAMPostPermissionResponse, - IAMGetPermissionsSearchRequest, - IAMGetPermissionsSearchResponse, - IAMPutGroupInvitationRequest, - IAMPutGroupInvitationAcceptRequest, - IAMCAoRResponse, - IAMPutGroupInvitationAcceptResponse, - IAMPutGroupInvitationResponse, - IAMPutOrgPermissionsRequest, - IAMPutOrgPermissionsResponse, + IAMCAoRRequest, + IAMGetGroupPermissionsResponse, + IAMGetGroupUsersResponse, + IAMPostGroupRequest, + IAMPostGroupResponse, + IAMPutGroupPermissionRequest, + IAMPutGroupPermissionResponse, + IAMPutGroupUserRequest, + IAMPutGroupUserResponse, + IAMDeleteGroupPermissionResponse, + IAMDeleteGroupUserResponse, + IAMGetPermissionsResponse, + IAMPostPermissionRequest, + IAMPostPermissionResponse, + IAMGetPermissionsSearchRequest, + IAMGetPermissionsSearchResponse, + IAMPutGroupInvitationRequest, + IAMPutGroupInvitationAcceptRequest, + IAMCAoRResponse, + IAMPutGroupInvitationAcceptResponse, + IAMPutGroupInvitationResponse, + IAMPutOrgPermissionsRequest, + IAMPutOrgPermissionsResponse, ) from src.utils import verify_email_token router = APIRouter( - tags=["IAM"], - prefix="/iam", + tags=["IAM"], + prefix="/iam", ) @router.post( - path="/can_act_on_resource", - summary="Used for services to check user access permission", - status_code=status.HTTP_200_OK, - response_model=IAMCAoRResponse, - responses={ - status.HTTP_401_UNAUTHORIZED: { - "description": "API Key missing or invalid | Issue verifying user OIDC claims" - }, - }, + path="/can_act_on_resource", + summary="Used for services to check user access permission", + status_code=status.HTTP_200_OK, + response_model=IAMCAoRResponse, + responses={ + status.HTTP_401_UNAUTHORIZED: { + "description": "API Key missing or invalid | Issue verifying user OIDC claims" + }, + }, ) async def can_act_on_resource( - valid_key: service_key_dependency, - db: DbSession, - user_claims: claims_dependency, - request_model: IAMCAoRRequest, + valid_key: service_key_dependency, + db: DbSession, + user_claims: claims_dependency, + request_model: IAMCAoRRequest, ): - """ - This endpoint is not meant for the Hub frontend to interact with.\n - Services accessing this endpoint must be already registered within the Hub and been issued an API key.\n - Resource Names have an instance property but permissions do not presently have that level of granularity.\n - """ - response = { - "allowed": False, - "rn": request_model.rn, - "action": "", - "user": {"id": 0, "email": ""}, - } + """ + This endpoint is not meant for the Hub frontend to interact with.\n + Services accessing this endpoint must be already registered within the Hub and been issued an API key.\n + Resource Names have an instance property but permissions do not presently have that level of granularity.\n + """ + response = { + "allowed": False, + "rn": request_model.rn, + "action": "", + "user": {"id": 0, "email": ""}, + } - try: - rn = request_model.rn - action = request_model.action - user_id = user_claims["db_id"] - rn_org = rn.organisation_id - rn_service = rn.service - rn_resource = rn.resource + try: + rn = request_model.rn + action = request_model.action + user_id = user_claims["db_id"] + rn_org = rn.organisation_id + rn_service = rn.service + rn_resource = rn.resource - response["user"] = {"id": user_id, "email": user_claims["email"]} - response["action"] = action - response["rn"] = rn + response["user"] = {"id": user_id, "email": user_claims["email"]} + response["action"] = action + response["rn"] = rn - result = ( - db.query(Perm) - .join(Service, Service.id == Perm.service_id) - .join(GPerms, GPerms.permission_id == Perm.id) - .join(Group, Group.id == GPerms.group_id) - .join(Org, Org.id == Group.org_id) - .join(UserGroups, UserGroups.group_id == Group.id) - .join(User, User.id == UserGroups.user_id) - .filter(User.id == user_id) - .filter(Org.id == rn_org) - .filter(Service.name == rn_service) - .filter(Perm.resource == rn_resource) - .filter(Perm.action == action) - ).first() + result = ( + db.query(Perm) + .join(Service, Service.id == Perm.service_id) + .join(GPerms, GPerms.permission_id == Perm.id) + .join(Group, Group.id == GPerms.group_id) + .join(Org, Org.id == Group.org_id) + .join(UserGroups, UserGroups.group_id == Group.id) + .join(User, User.id == UserGroups.user_id) + .filter(User.id == user_id) + .filter(Org.id == rn_org) + .filter(Service.name == rn_service) + .filter(Perm.resource == rn_resource) + .filter(Perm.action == action) + ).first() - if result: - response["allowed"] = True - else: - response["allowed"] = False - except Exception as e: - print(e) - response["allowed"] = False + if result: + response["allowed"] = True + else: + response["allowed"] = False + except Exception as e: + print(e) + response["allowed"] = False - return response + return response @router.get( - path="/group/permissions", - summary="Gets a list of permissions granted to a group", - status_code=status.HTTP_200_OK, - response_model=IAMGetGroupPermissionsResponse, - responses={ - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Unprocessable content.", - "content": { - "application/json": { - "examples": { - "org_id": {"summary": "Invalid or missing org ID."}, - "oidc_claims": {"summary": "Invalid or missing OIDC claims."}, - } - } - }, - }, - status.HTTP_401_UNAUTHORIZED: { - "description": "Unauthorized", - "content": { - "application/json": { - "examples": { - "awaiting_approval": { - "summary": "Organisation has not yet been approved." - }, - "expired_token": {"summary": "User token has expired."}, - "oidc": {"summary": "Failed to verify OIDC claims."}, - } - } - }, - }, - status.HTTP_403_FORBIDDEN: { - "description": "Forbidden", - "content": { - "application/json": { - "examples": { - "not_root": {"summary": "Not authorised. Must be root user."}, - } - } - }, - }, - status.HTTP_404_NOT_FOUND: { - "description": "Not found", - "content": { - "application/json": { - "examples": { - "db_id": {"summary": "User not found in db when checking claims."}, - "user_model": {"summary": "User model not found in db."}, - "org_model": {"summary": "Org model not found in db."}, - "group_model": {"summary": "Group model not found in db."}, - } - } - }, - }, - }, + path="/group/permissions", + summary="Gets a list of permissions granted to a group", + status_code=status.HTTP_200_OK, + response_model=IAMGetGroupPermissionsResponse, + responses={ + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Unprocessable content.", + "content": { + "application/json": { + "examples": { + "org_id": {"summary": "Invalid or missing org ID."}, + "oidc_claims": {"summary": "Invalid or missing OIDC claims."}, + } + } + }, + }, + status.HTTP_401_UNAUTHORIZED: { + "description": "Unauthorized", + "content": { + "application/json": { + "examples": { + "awaiting_approval": { + "summary": "Organisation has not yet been approved." + }, + "expired_token": {"summary": "User token has expired."}, + "oidc": {"summary": "Failed to verify OIDC claims."}, + } + } + }, + }, + status.HTTP_403_FORBIDDEN: { + "description": "Forbidden", + "content": { + "application/json": { + "examples": { + "not_root": {"summary": "Not authorised. Must be root user."}, + } + } + }, + }, + status.HTTP_404_NOT_FOUND: { + "description": "Not found", + "content": { + "application/json": { + "examples": { + "db_id": {"summary": "User not found in db when checking claims."}, + "user_model": {"summary": "User model not found in db."}, + "org_model": {"summary": "Org model not found in db."}, + "group_model": {"summary": "Group model not found in db."}, + } + } + }, + }, + }, ) async def get_group_permissions( - group_model: group_model_query_dependency, - org_model: org_model_root_claim_query_dependency, + group_model: group_model_query_dependency, + org_model: org_model_root_claim_query_dependency, ): - """ - Gets a list of permissions granted to the group. Also returns a summary for the org and group. - """ - if group_model.org_id != org_model.id: - raise ForbiddenException("Group does not belong to this organization") - return { - "organisation": org_model, - "group": group_model, - "permissions": group_model.permission_rel, - } + """ + Gets a list of permissions granted to the group. Also returns a summary for the org and group. + """ + if group_model.org_id != org_model.id: + raise ForbiddenException("Group does not belong to this organization") + return { + "organisation": org_model, + "group": group_model, + "permissions": group_model.permission_rel, + } @router.get( - path="/group/users", - summary="Gets a list of users assigned to a group", - status_code=status.HTTP_200_OK, - response_model=IAMGetGroupUsersResponse, - responses={ - status.HTTP_401_UNAUTHORIZED: { - "description": "Group does not belong to this organization" - }, - }, + path="/group/users", + summary="Gets a list of users assigned to a group", + status_code=status.HTTP_200_OK, + response_model=IAMGetGroupUsersResponse, + responses={ + status.HTTP_401_UNAUTHORIZED: { + "description": "Group does not belong to this organization" + }, + }, ) async def get_group_users( - group_model: group_model_query_dependency, - org_model: org_model_root_claim_query_dependency, + group_model: group_model_query_dependency, + org_model: org_model_root_claim_query_dependency, ): - """ - Gets a list of users assigned to the group. Also returns a summary for the org and group. - """ - if group_model.org_id != org_model.id: - raise ForbiddenException("Group does not belong to this organization") - return { - "organisation": org_model, - "group": group_model, - "users": group_model.user_rel, - } + """ + Gets a list of users assigned to the group. Also returns a summary for the org and group. + """ + if group_model.org_id != org_model.id: + raise ForbiddenException("Group does not belong to this organization") + return { + "organisation": org_model, + "group": group_model, + "users": group_model.user_rel, + } @router.post( - path="/group", - summary="Creates a new group", - status_code=status.HTTP_201_CREATED, - response_model=IAMPostGroupResponse, - responses={ - status.HTTP_409_CONFLICT: {"description": "Group with this name already exists"}, - }, + path="/group", + summary="Creates a new group", + status_code=status.HTTP_201_CREATED, + response_model=IAMPostGroupResponse, + responses={ + status.HTTP_409_CONFLICT: {"description": "Group with this name already exists"}, + }, ) async def create_group( - db: DbSession, - org_model: org_model_root_claim_body_dependency, - request_model: IAMPostGroupRequest, + db: DbSession, + org_model: org_model_root_claim_body_dependency, + request_model: IAMPostGroupRequest, ): - """ - Creates a new IAM group. - """ - group_model = Group(name=request_model.name, org_id=org_model.id) + """ + Creates a new IAM group. + """ + group_model = Group(name=request_model.name, org_id=org_model.id) - db.add(group_model) - try: - db.flush() - except IntegrityError as e: - if ( - isinstance(e.orig, UniqueViolation) # Postgres unique violation - or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation - ): - raise ConflictException("Group with this name already exists") - raise - group_response = GroupSummary(**group_model.__dict__) - org_response = OrgSummary(**org_model.__dict__) - db.commit() - return {"group": group_response, "organisation": org_response} + db.add(group_model) + try: + db.flush() + except IntegrityError as e: + if ( + isinstance(e.orig, UniqueViolation) # Postgres unique violation + or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation + ): + raise ConflictException("Group with this name already exists") + raise + group_response = GroupSummary(**group_model.__dict__) + org_response = OrgSummary(**org_model.__dict__) + db.commit() + return {"group": group_response, "organisation": org_response} @router.put( - path="/group/permission", - summary="Grants a permission to a group", - status_code=status.HTTP_200_OK, - response_model=IAMPutGroupPermissionResponse, - responses={ - status.HTTP_401_UNAUTHORIZED: { - "description": "Group does not belong to this organization" - }, - status.HTTP_409_CONFLICT: { - "description": "This permission is already granted to this group" - }, - }, + path="/group/permission", + summary="Grants a permission to a group", + status_code=status.HTTP_200_OK, + response_model=IAMPutGroupPermissionResponse, + responses={ + status.HTTP_401_UNAUTHORIZED: { + "description": "Group does not belong to this organization" + }, + status.HTTP_409_CONFLICT: { + "description": "This permission is already granted to this group" + }, + }, ) async def add_group_permission( - db: DbSession, - group_model: group_model_body_dependency, - perm_model: perm_model_body_dependency, - org_model: org_model_root_claim_body_dependency, - request_model: IAMPutGroupPermissionRequest, + db: DbSession, + group_model: group_model_body_dependency, + perm_model: perm_model_body_dependency, + org_model: org_model_root_claim_body_dependency, + request_model: IAMPutGroupPermissionRequest, ): - """ - Grants a permission to a group. Returns a list of the permissions in the group as well as a summary for the org and group. - """ - if group_model.org_id != org_model.id: - raise ForbiddenException("Group does not belong to this organization") + """ + Grants a permission to a group. Returns a list of the permissions in the group as well as a summary for the org and group. + """ + if group_model.org_id != org_model.id: + raise ForbiddenException("Group does not belong to this organization") - if perm_model in group_model.permission_rel: - raise ConflictException("Group already has this permission") + if perm_model in group_model.permission_rel: + raise ConflictException("Group already has this permission") - if perm_model not in org_model.permission_rel: # TODO: and not su - raise ForbiddenException("You cannot grant this permission") + if perm_model not in org_model.permission_rel: # TODO: and not su + raise ForbiddenException("You cannot grant this permission") - group_model.permission_rel.append(perm_model) + group_model.permission_rel.append(perm_model) - db.flush() - response = IAMPutGroupPermissionResponse( - organisation=OrgSummary(**org_model.__dict__), - group=GroupSummary(**group_model.__dict__), - permissions=group_model.permission_rel, - ) - db.commit() - return response + db.flush() + response = IAMPutGroupPermissionResponse( + organisation=OrgSummary(**org_model.__dict__), + group=GroupSummary(**group_model.__dict__), + permissions=group_model.permission_rel, + ) + db.commit() + return response @router.put( - path="/group/user", - summary="Directly adds a user to the group", - status_code=status.HTTP_200_OK, - response_model=IAMPutGroupUserResponse, - responses={ - status.HTTP_401_UNAUTHORIZED: { - "description": "Group not in org | User not authenticated | User does not have permission" - }, - status.HTTP_409_CONFLICT: {"description": "User is already in group"}, - status.HTTP_403_FORBIDDEN: { - "description": "Only existing org members can be added directly." - }, - }, + path="/group/user", + summary="Directly adds a user to the group", + status_code=status.HTTP_200_OK, + response_model=IAMPutGroupUserResponse, + responses={ + status.HTTP_401_UNAUTHORIZED: { + "description": "Group not in org | User not authenticated | User does not have permission" + }, + status.HTTP_409_CONFLICT: {"description": "User is already in group"}, + status.HTTP_403_FORBIDDEN: { + "description": "Only existing org members can be added directly." + }, + }, ) async def add_group_user( - db: DbSession, - group_model: group_model_body_dependency, - user_model: user_model_body_dependency, - org_model: org_model_root_claim_body_dependency, - request_model: IAMPutGroupUserRequest, + db: DbSession, + group_model: group_model_body_dependency, + user_model: user_model_body_dependency, + org_model: org_model_root_claim_body_dependency, + request_model: IAMPutGroupUserRequest, ): - """ - Directly adds an organisation member to a group.\n - To add a non-member, use an email invitation instead.\n - The user's email address must match the email on their OIDC profile. - """ - if group_model.org_id != org_model.id: - raise ForbiddenException("Group does not belong to this organization") + """ + Directly adds an organisation member to a group.\n + To add a non-member, use an email invitation instead.\n + The user's email address must match the email on their OIDC profile. + """ + if group_model.org_id != org_model.id: + raise ForbiddenException("Group does not belong to this organization") - if user_model in group_model.user_rel: - raise ConflictException("User already in group") + if user_model in group_model.user_rel: + raise ConflictException("User already in group") - if user_model not in org_model.user_rel: - raise ForbiddenException( - "Adding users directly can only be done with org members. Use email invitation instead." - ) + if user_model not in org_model.user_rel: + raise ForbiddenException( + "Adding users directly can only be done with org members. Use email invitation instead." + ) - group_model.user_rel.append(user_model) - db.flush() - response = IAMPutGroupUserResponse( - group=GroupSummary(**group_model.__dict__), users=group_model.user_rel - ) - db.commit() - return response + group_model.user_rel.append(user_model) + db.flush() + response = IAMPutGroupUserResponse( + group=GroupSummary(**group_model.__dict__), users=group_model.user_rel + ) + db.commit() + return response @router.delete( - path="/group/permission", - summary="Removes a permission from the group", - status_code=status.HTTP_200_OK, - response_model=IAMDeleteGroupPermissionResponse, - responses={ - status.HTTP_401_UNAUTHORIZED: { - "description": "Group not in org | User not authenticated | User does not have permission" - }, - }, + path="/group/permission", + summary="Removes a permission from the group", + status_code=status.HTTP_200_OK, + response_model=IAMDeleteGroupPermissionResponse, + responses={ + status.HTTP_401_UNAUTHORIZED: { + "description": "Group not in org | User not authenticated | User does not have permission" + }, + }, ) async def remove_group_permission( - db: DbSession, - group_model: group_model_query_dependency, - perm_model: perm_model_query_dependency, - org_model: org_model_root_claim_query_dependency, + db: DbSession, + group_model: group_model_query_dependency, + perm_model: perm_model_query_dependency, + org_model: org_model_root_claim_query_dependency, ): - """ - Removes a permission from the group. - """ - if group_model.org_id != org_model.id: - raise ForbiddenException("Group does not belong to this organization") + """ + Removes a permission from the group. + """ + if group_model.org_id != org_model.id: + raise ForbiddenException("Group does not belong to this organization") - if perm_model not in group_model.permission_rel: - raise UnprocessableContentException("Permission not granted to group") + if perm_model not in group_model.permission_rel: + raise UnprocessableContentException("Permission not granted to group") - group_model.permission_rel.remove(perm_model) - db.flush() - response = IAMDeleteGroupPermissionResponse( - group=GroupSummary(**group_model.__dict__), - permissions=group_model.permission_rel, - ) - db.commit() - return response + group_model.permission_rel.remove(perm_model) + db.flush() + response = IAMDeleteGroupPermissionResponse( + group=GroupSummary(**group_model.__dict__), + permissions=group_model.permission_rel, + ) + db.commit() + return response @router.delete( - path="/group/user", - summary="Removes a user from the group", - status_code=status.HTTP_200_OK, - response_model=IAMDeleteGroupUserResponse, - responses={ - status.HTTP_401_UNAUTHORIZED: { - "description": "User not authenticated | User does not have permission" - }, - status.HTTP_403_FORBIDDEN: {"description": "Group not in org"}, - }, + path="/group/user", + summary="Removes a user from the group", + status_code=status.HTTP_200_OK, + response_model=IAMDeleteGroupUserResponse, + responses={ + status.HTTP_401_UNAUTHORIZED: { + "description": "User not authenticated | User does not have permission" + }, + status.HTTP_403_FORBIDDEN: {"description": "Group not in org"}, + }, ) async def remove_group_user( - db: DbSession, - group_model: group_model_query_dependency, - user_model: user_model_query_dependency, - org_model: org_model_root_claim_query_dependency, + db: DbSession, + group_model: group_model_query_dependency, + user_model: user_model_query_dependency, + org_model: org_model_root_claim_query_dependency, ): - """ - Removes a user from the group. - """ - if group_model.org_id != org_model.id: - raise ForbiddenException("Group does not belong to this organization") + """ + Removes a user from the group. + """ + if group_model.org_id != org_model.id: + raise ForbiddenException("Group does not belong to this organization") - user_model.group_rel.remove(group_model) - db.flush() - response = IAMDeleteGroupUserResponse( - group=GroupSummary(**group_model.__dict__), users=group_model.user_rel - ) - db.commit() + user_model.group_rel.remove(group_model) + db.flush() + response = IAMDeleteGroupUserResponse( + group=GroupSummary(**group_model.__dict__), users=group_model.user_rel + ) + db.commit() - return response + return response @router.get( - path="/permissions", - summary="Returns a full list of permissions", - status_code=status.HTTP_200_OK, - response_model=IAMGetPermissionsResponse, - responses={ - status.HTTP_401_UNAUTHORIZED: { - "description": "User must be root user of an organisation." - }, - }, + path="/permissions", + summary="Returns a full list of permissions", + status_code=status.HTTP_200_OK, + response_model=IAMGetPermissionsResponse, + responses={ + status.HTTP_401_UNAUTHORIZED: { + "description": "User must be root user of an organisation." + }, + }, ) -async def get_permissions( - db: DbSession, org_model: org_model_root_claim_query_dependency -): - """ - Returns a full list of permissions. - """ - # TODO: if su: - # permission_models = db.query(Perm).all() - # else - permission_models = db.query(Perm).filter(Perm.org_rel.any(id=org_model.id)).all() - return {"permissions": permission_models} +async def get_permissions(db: DbSession, org_model: org_model_root_claim_query_dependency): + """ + Returns a full list of permissions. + """ + # TODO: if su: + # permission_models = db.query(Perm).all() + # else + permission_models = db.query(Perm).filter(Perm.org_rel.any(id=org_model.id)).all() + return {"permissions": permission_models} @router.post( - path="/permission", - summary="Creates a new permission for a service", - status_code=status.HTTP_201_CREATED, - response_model=IAMPostPermissionResponse, - responses={ - status.HTTP_401_UNAUTHORIZED: {"description": "Must be super user."}, - status.HTTP_404_NOT_FOUND: {"description": "Service does not exist"}, - status.HTTP_409_CONFLICT: {"description": "Permission already exists"}, - }, + path="/permission", + summary="Creates a new permission for a service", + status_code=status.HTTP_201_CREATED, + response_model=IAMPostPermissionResponse, + responses={ + status.HTTP_401_UNAUTHORIZED: {"description": "Must be super user."}, + status.HTTP_404_NOT_FOUND: {"description": "Service does not exist"}, + status.HTTP_409_CONFLICT: {"description": "Permission already exists"}, + }, ) async def create_new_permission( - db: DbSession, - su: super_admin_dependency, - request_model: IAMPostPermissionRequest, - service_model: service_model_body_dependency, # Used to verify service model exists + db: DbSession, + su: super_admin_dependency, + request_model: IAMPostPermissionRequest, + service_model: service_model_body_dependency, # Used to verify service model exists ): - """ - Allows a super admin to create a new IAM permission for a service. - """ - perm_model = Perm(**request_model.__dict__) - db.add(perm_model) - try: - db.flush() - except IntegrityError as e: - if ( - isinstance(e.orig, UniqueViolation) # Postgres unique violation - or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation - ): - raise ConflictException(message="Permission already exists") - raise - response = { - "id": perm_model.id, - "service_name": perm_model.service_name, - "resource": perm_model.resource, - "action": perm_model.action, - } - db.commit() - return {"permission": response} + """ + Allows a super admin to create a new IAM permission for a service. + """ + perm_model = Perm(**request_model.__dict__) + db.add(perm_model) + try: + db.flush() + except IntegrityError as e: + if ( + isinstance(e.orig, UniqueViolation) # Postgres unique violation + or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation + ): + raise ConflictException(message="Permission already exists") + raise + response = { + "id": perm_model.id, + "service_name": perm_model.service_name, + "resource": perm_model.resource, + "action": perm_model.action, + } + db.commit() + return {"permission": response} @router.delete( - path="/permission", - summary="Deletes a permission", - status_code=status.HTTP_204_NO_CONTENT, - responses={}, + path="/permission", + summary="Deletes a permission", + status_code=status.HTTP_204_NO_CONTENT, + responses={}, ) async def delete_permission( - db: DbSession, - su: super_admin_dependency, - perm_model: perm_model_query_dependency, + db: DbSession, + su: super_admin_dependency, + perm_model: perm_model_query_dependency, ): - """ - Allows a super admin to remove a permission. - """ - db.delete(perm_model) - db.commit() + """ + Allows a super admin to remove a permission. + """ + db.delete(perm_model) + db.commit() @router.post( - path="/permissions/search", - summary="Search list of permissions", - status_code=status.HTTP_200_OK, - response_model=IAMGetPermissionsSearchResponse, - responses={}, + path="/permissions/search", + summary="Search list of permissions", + status_code=status.HTTP_200_OK, + response_model=IAMGetPermissionsSearchResponse, + responses={}, ) async def permissions_search( - db: DbSession, - org_model: org_model_root_claim_body_dependency, - request_model: IAMGetPermissionsSearchRequest, + db: DbSession, + org_model: org_model_root_claim_body_dependency, + request_model: IAMGetPermissionsSearchRequest, ): - """ - Returns a list of permissions filtered by the queries provided.\n - If a query is null, it will be ignored. - """ - permission_query = db.query(Perm) + """ + Returns a list of permissions filtered by the queries provided.\n + If a query is null, it will be ignored. + """ + permission_query = db.query(Perm) - if not (request_model.service_id is None or request_model.service_id == ""): - permission_query = permission_query.filter( - Perm.service_id == request_model.service_id - ) + if not (request_model.service_id is None or request_model.service_id == ""): + permission_query = permission_query.filter( + Perm.service_id == request_model.service_id + ) - if not (request_model.resource is None or request_model.resource == ""): - permission_query = permission_query.filter(Perm.resource == request_model.resource) + if not (request_model.resource is None or request_model.resource == ""): + permission_query = permission_query.filter(Perm.resource == request_model.resource) - if not (request_model.action is None or request_model.action == ""): - permission_query = permission_query.filter(Perm.action == request_model.action) + if not (request_model.action is None or request_model.action == ""): + permission_query = permission_query.filter(Perm.action == request_model.action) - # TODO: if not su: - permission_query = permission_query.filter(Perm.org_rel.any(id=org_model.id)) + # TODO: if not su: + permission_query = permission_query.filter(Perm.org_rel.any(id=org_model.id)) - permission_models = permission_query.all() + permission_models = permission_query.all() - return {"permissions": permission_models} + return {"permissions": permission_models} @router.put( - path="/group/user/invitation", - summary="Send an email invitation for non-org member to join a group", - status_code=status.HTTP_200_OK, - response_model=IAMPutGroupInvitationResponse, - responses={}, + path="/group/user/invitation", + summary="Send an email invitation for non-org member to join a group", + status_code=status.HTTP_200_OK, + response_model=IAMPutGroupInvitationResponse, + responses={}, ) async def invitation( - background_tasks: BackgroundTasks, - org_model: org_model_root_claim_body_dependency, - group_model: group_model_body_dependency, - request_model: IAMPutGroupInvitationRequest, + background_tasks: BackgroundTasks, + org_model: org_model_root_claim_body_dependency, + group_model: group_model_body_dependency, + request_model: IAMPutGroupInvitationRequest, ): - """ - Sends an email invitation to join a group.\n - This is intended for inviting no-members to a group, giving them permission to access org resources.\n - i.e. Allowing somebody in a partner organisation to view metrics.\n - Can also be used for inviting organisaion members if needed. - """ - org_id: int = org_model.id - org_name: str = org_model.name - user_email = request_model.user_email - group_id: int = group_model.id - group_name: str = group_model.name + """ + Sends an email invitation to join a group.\n + This is intended for inviting no-members to a group, giving them permission to access org resources.\n + i.e. Allowing somebody in a partner organisation to view metrics.\n + Can also be used for inviting organisaion members if needed. + """ + org_id: int = org_model.id + org_name: str = org_model.name + user_email = request_model.user_email + group_id: int = group_model.id + group_name: str = group_model.name - background_tasks.add_task( - send_user_group_invitation, - org_id=org_id, - org_name=org_name, - user_email=user_email, - group_id=group_id, - group_name=group_name, - ) + background_tasks.add_task( + send_user_group_invitation, + org_id=org_id, + org_name=org_name, + user_email=user_email, + group_id=group_id, + group_name=group_name, + ) - response = { - "organisation": org_model, - "group": group_model, - "invited_email": user_email, - } + response = { + "organisation": org_model, + "group": group_model, + "invited_email": user_email, + } - return response + return response @router.put( - path="/group/user/invitation/accept", - summary="Accept email invitation to join an org's group", - status_code=status.HTTP_200_OK, - response_model=IAMPutGroupInvitationAcceptResponse, - responses={ - status.HTTP_404_NOT_FOUND: {"description": "User|Org|Group not found"}, - status.HTTP_403_FORBIDDEN: {"description": "Group and organisation do not match"}, - status.HTTP_409_CONFLICT: {"description": "User is already in the group"}, - }, + path="/group/user/invitation/accept", + summary="Accept email invitation to join an org's group", + status_code=status.HTTP_200_OK, + response_model=IAMPutGroupInvitationAcceptResponse, + responses={ + status.HTTP_404_NOT_FOUND: {"description": "User|Org|Group not found"}, + status.HTTP_403_FORBIDDEN: {"description": "Group and organisation do not match"}, + status.HTTP_409_CONFLICT: {"description": "User is already in the group"}, + }, ) async def accept_invitation( - db: DbSession, - user_model: user_model_claims_dependency, - request_model: IAMPutGroupInvitationAcceptRequest, + db: DbSession, + user_model: user_model_claims_dependency, + request_model: IAMPutGroupInvitationAcceptRequest, ): - """ - Accepts an invitation to join an org's group - """ - email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model) + """ + Accepts an invitation to join an org's group + """ + email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model) - org_model = db.get(Org, email_claims["org_id"]) - if org_model is None: - raise OrgNotFoundException(email_claims["org_id"]) + org_model = db.get(Org, email_claims["org_id"]) + if org_model is None: + raise OrgNotFoundException(email_claims["org_id"]) - group_model = db.get(Group, email_claims["group_id"]) - if group_model is None: - raise GroupNotFoundException(email_claims["group_id"]) + group_model = db.get(Group, email_claims["group_id"]) + if group_model is None: + raise GroupNotFoundException(email_claims["group_id"]) - if group_model not in org_model.group_rel: - raise ForbiddenException("Group and org do not match.") + if group_model not in org_model.group_rel: + raise ForbiddenException("Group and org do not match.") - if user_model in group_model.user_rel: - raise ConflictException("User already in group.") + if user_model in group_model.user_rel: + raise ConflictException("User already in group.") - group_model.user_rel.append(user_model) - db.flush() + group_model.user_rel.append(user_model) + db.flush() - response = { - "organisation": org_model, - "user": user_model, - "group": {"details": group_model, "permissions": group_model.permission_rel}, - } - db.commit() + response = { + "organisation": org_model, + "user": user_model, + "group": {"details": group_model, "permissions": group_model.permission_rel}, + } + db.commit() - return response + return response @router.put( - path="/org/permissions", - summary="Grants an org access to permissions", - status_code=status.HTTP_200_OK, - response_model=IAMPutOrgPermissionsResponse, - responses={ - status.HTTP_401_UNAUTHORIZED: {"description": "Must be super user."}, - }, + path="/org/permissions", + summary="Grants an org access to permissions", + status_code=status.HTTP_200_OK, + response_model=IAMPutOrgPermissionsResponse, + responses={ + status.HTTP_401_UNAUTHORIZED: {"description": "Must be super user."}, + }, ) async def add_org_permissions( - db: DbSession, - su: super_admin_dependency, - org_model: org_model_body_dependency, - request_model: IAMPutOrgPermissionsRequest, + db: DbSession, + su: super_admin_dependency, + org_model: org_model_body_dependency, + request_model: IAMPutOrgPermissionsRequest, ): - """ - Grants a permission to a group. Returns a list of the permissions in the group as well as a summary for the org and group. - """ - for permission in request_model.permissions: - perm_model = db.get(Perm, permission) + """ + Grants a permission to a group. Returns a list of the permissions in the group as well as a summary for the org and group. + """ + for permission in request_model.permissions: + perm_model = db.get(Perm, permission) - if perm_model not in org_model.permission_rel: - org_model.permission_rel.append(perm_model) + if perm_model not in org_model.permission_rel: + org_model.permission_rel.append(perm_model) - db.flush() - response = IAMPutOrgPermissionsResponse( - organisation=OrgSummary(**org_model.__dict__), - permissions=org_model.permission_rel, - ) - db.commit() - return response + db.flush() + response = IAMPutOrgPermissionsResponse( + organisation=OrgSummary(**org_model.__dict__), + permissions=org_model.permission_rel, + ) + db.commit() + return response diff --git a/src/iam/schemas.py b/src/iam/schemas.py index 8072914..3a8380c 100644 --- a/src/iam/schemas.py +++ b/src/iam/schemas.py @@ -11,151 +11,151 @@ from typing import Optional, Annotated from pydantic import EmailStr, ConfigDict, Field from src.schemas import ( - CustomBaseModel, - ResourceName, - ServiceIDMixin, - OrgIDMixin, - UserIDMixin, - PermIDMixin, - GroupIDMixin, - GroupSummary, - OrgSummary, - UserSummary, + CustomBaseModel, + ResourceName, + ServiceIDMixin, + OrgIDMixin, + UserIDMixin, + PermIDMixin, + GroupIDMixin, + GroupSummary, + OrgSummary, + UserSummary, ) class UserSchema(CustomBaseModel): - model_config = ConfigDict(from_attributes=True, extra="ignore") + model_config = ConfigDict(from_attributes=True, extra="ignore") - id: int - first_name: str - last_name: str - email: EmailStr + id: int + first_name: str + last_name: str + email: EmailStr class PermissionSchema(CustomBaseModel): - model_config = ConfigDict(from_attributes=True, extra="ignore") + model_config = ConfigDict(from_attributes=True, extra="ignore") - id: int - service_name: str - resource: str - action: str + id: int + service_name: str + resource: str + action: str class GroupDetails(CustomBaseModel): - details: GroupSummary - permissions: list[PermissionSchema] + details: GroupSummary + permissions: list[PermissionSchema] class IAMCAoRRequest(CustomBaseModel): - action: str - rn: ResourceName + action: str + rn: ResourceName class IAMCAoRResponse(CustomBaseModel): - allowed: bool - user: UserSummary - action: str - rn: ResourceName + allowed: bool + user: UserSummary + action: str + rn: ResourceName class IAMGetGroupPermissionsResponse(CustomBaseModel): - organisation: OrgSummary - group: GroupSummary - permissions: list[PermissionSchema] + organisation: OrgSummary + group: GroupSummary + permissions: list[PermissionSchema] class IAMGetGroupUsersResponse(CustomBaseModel): - organisation: OrgSummary - group: GroupSummary - users: list[UserSummary] + organisation: OrgSummary + group: GroupSummary + users: list[UserSummary] class IAMPostGroupRequest(OrgIDMixin): - name: str = Field(min_length=3) + name: str = Field(min_length=3) class IAMPostGroupResponse(CustomBaseModel): - organisation: OrgSummary - group: GroupSummary + organisation: OrgSummary + group: GroupSummary class IAMPutGroupPermissionRequest(GroupIDMixin, PermIDMixin, OrgIDMixin): - pass + pass class IAMPutGroupPermissionResponse(CustomBaseModel): - organisation: OrgSummary - group: GroupSummary - permissions: list[PermissionSchema] + organisation: OrgSummary + group: GroupSummary + permissions: list[PermissionSchema] class IAMPutGroupUserRequest(GroupIDMixin, UserIDMixin, OrgIDMixin): - pass + pass class IAMPutGroupUserResponse(CustomBaseModel): - group: GroupSummary - users: list[UserSchema] + group: GroupSummary + users: list[UserSchema] class IAMDeleteGroupPermissionResponse(CustomBaseModel): - group: GroupSummary - permissions: list[PermissionSchema] + group: GroupSummary + permissions: list[PermissionSchema] class IAMDeleteGroupUserResponse(CustomBaseModel): - group: GroupSummary - users: list[UserSchema] + group: GroupSummary + users: list[UserSchema] class IAMGetPermissionsResponse(CustomBaseModel): - permissions: list[PermissionSchema] + permissions: list[PermissionSchema] class IAMPostPermissionRequest(ServiceIDMixin): - resource: str - action: str + resource: str + action: str class IAMPostPermissionResponse(CustomBaseModel): - permission: PermissionSchema + permission: PermissionSchema class IAMGetPermissionsSearchRequest(OrgIDMixin): - service_id: Annotated[int | None, Field(gt=0)] = None - resource: Optional[str] = None - action: Optional[str] = None + service_id: Annotated[int | None, Field(gt=0)] = None + resource: Optional[str] = None + action: Optional[str] = None class IAMGetPermissionsSearchResponse(CustomBaseModel): - permissions: list[PermissionSchema] + permissions: list[PermissionSchema] class IAMPutGroupInvitationRequest(OrgIDMixin, GroupIDMixin): - user_email: EmailStr + user_email: EmailStr class IAMPutGroupInvitationResponse(CustomBaseModel): - organisation: OrgSummary - group: GroupSummary - invited_email: EmailStr + organisation: OrgSummary + group: GroupSummary + invited_email: EmailStr class IAMPutGroupInvitationAcceptRequest(CustomBaseModel): - jwt: str + jwt: str class IAMPutGroupInvitationAcceptResponse(CustomBaseModel): - organisation: OrgSummary - user: UserSummary - group: GroupDetails + organisation: OrgSummary + user: UserSummary + group: GroupDetails class IAMPutOrgPermissionsRequest(OrgIDMixin): - permissions: list[int] + permissions: list[int] class IAMPutOrgPermissionsResponse(CustomBaseModel): - organisation: OrgSummary - permissions: list[PermissionSchema] + organisation: OrgSummary + permissions: list[PermissionSchema] diff --git a/src/iam/service.py b/src/iam/service.py index 2112c6c..0af4bfd 100644 --- a/src/iam/service.py +++ b/src/iam/service.py @@ -23,90 +23,90 @@ from src.service.schemas import HasServiceName def valid_service_key( - db: DbSession, request: Request, request_model: HasServiceName + db: DbSession, request: Request, request_model: HasServiceName ) -> bool: - rn = request_model.rn - api_key = request.headers.get("X-API-Key", None) - if not api_key: - raise UnauthorizedException("Missing API key") - service = rn.service - result = ( - db.query(Service) - .filter(Service.name == service) - .filter(Service.api_key == api_key) - .first() - ) - if result is None: - raise UnauthorizedException("Invalid API key") + rn = request_model.rn + api_key = request.headers.get("X-API-Key", None) + if not api_key: + raise UnauthorizedException("Missing API key") + service = rn.service + result = ( + db.query(Service) + .filter(Service.name == service) + .filter(Service.api_key == api_key) + .first() + ) + if result is None: + raise UnauthorizedException("Invalid API key") - return True + return True service_key_dependency = Annotated[bool, Depends(valid_service_key)] async def send_user_group_invitation( - user_email: str, org_name: str, org_id: int, group_id: int, group_name: str + user_email: str, org_name: str, org_id: int, group_id: int, group_name: str ): - expiry_delta = timedelta(hours=24) - expiry = datetime.now(timezone.utc) + expiry_delta - claims = { - "email": user_email, - "org_id": org_id, - "group_id": group_id, - "group_name": group_name, - "exp": expiry, - "type": "group_invite", - } + expiry_delta = timedelta(hours=24) + expiry = datetime.now(timezone.utc) + expiry_delta + claims = { + "email": user_email, + "org_id": org_id, + "group_id": group_id, + "group_name": group_name, + "exp": expiry, + "type": "group_invite", + } - token = await generate_jwt(claims) - subject = f"You have been invited to join a group of {org_name}" - body = f"You have been invited to join {group_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}" + token = await generate_jwt(claims) + subject = f"You have been invited to join a group of {org_name}" + body = f"You have been invited to join {group_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}" - await send_email( - recipient=user_email, - subject=subject, - body=body, - ) + await send_email( + recipient=user_email, + subject=subject, + body=body, + ) async def create_group_and_assign_perms( - db: Session, org_model: Org, group_name: str, perm_list: list[int] + db: Session, org_model: Org, group_name: str, perm_list: list[int] ): - new_group = Group(name=group_name, org_id=org_model.id) - db.add(new_group) - db.flush() + new_group = Group(name=group_name, org_id=org_model.id) + db.add(new_group) + db.flush() - for permission in perm_list: - perm_model = db.get(Perm, permission) + for permission in perm_list: + perm_model = db.get(Perm, permission) - if perm_model is None: - continue + if perm_model is None: + continue - new_group.permission_rel.append(perm_model) - db.flush() + new_group.permission_rel.append(perm_model) + db.flush() - return new_group + return new_group async def assign_default_group( - db: DbSession, - org_model: Org, - user_model: User, - group_name: str, - perm_list: list[int], + db: DbSession, + org_model: Org, + user_model: User, + group_name: str, + perm_list: list[int], ): - group_model = ( - db.query(Group) - .filter(Group.org_id == org_model.id) - .filter(Group.name == group_name) - .first() - ) + group_model = ( + db.query(Group) + .filter(Group.org_id == org_model.id) + .filter(Group.name == group_name) + .first() + ) - if group_model is None: - group_model = await create_group_and_assign_perms( - db=db, group_name=group_name, org_model=org_model, perm_list=perm_list - ) + if group_model is None: + group_model = await create_group_and_assign_perms( + db=db, group_name=group_name, org_model=org_model, perm_list=perm_list + ) - user_model.group_rel.append(group_model) - db.flush() + user_model.group_rel.append(group_model) + db.flush() diff --git a/src/main.py b/src/main.py index a96eddc..787cb36 100644 --- a/src/main.py +++ b/src/main.py @@ -1,6 +1,7 @@ """ Application root file: Inits the FastAPI application """ + import os.path from contextlib import asynccontextmanager from typing import AsyncGenerator @@ -19,43 +20,43 @@ from src.auth.service import get_current_user, get_dev_user @asynccontextmanager async def lifespan(_application: FastAPI) -> AsyncGenerator: - # Startup - yield - # Shutdown + # Startup + yield + # Shutdown if settings.ENVIRONMENT.is_deployed: - # Just a precaution, should be False anyway - settings.DISABLE_AUTH = False + # Just a precaution, should be False anyway + settings.DISABLE_AUTH = False tags_metadata = [ - { - "name": "User", - "description": "User related operations, includes getting information about the current user", - }, - { - "name": "Organisation", - "description": "Organisation related operations, includes getting lists of users etc associated with orgs", - }, - { - "name": "Service", - "description": "Services related operations, includes registering services and reissuing API keys", - }, - { - "name": "IAM", - "description": "Operations related to the role based identity and access management system. This includes management of groups, permissions, and related users.", - }, + { + "name": "User", + "description": "User related operations, includes getting information about the current user", + }, + { + "name": "Organisation", + "description": "Organisation related operations, includes getting lists of users etc associated with orgs", + }, + { + "name": "Service", + "description": "Services related operations, includes registering services and reissuing API keys", + }, + { + "name": "IAM", + "description": "Operations related to the role based identity and access management system. This includes management of groups, permissions, and related users.", + }, ] app = FastAPI( - swagger_ui_init_oauth={ - "clientId": auth_settings.CLIENT_ID, - "usePkceWithAuthorizationCodeGrant": True, - "scopes": "openid profile email", - }, - openapi_tags=tags_metadata, + swagger_ui_init_oauth={ + "clientId": auth_settings.CLIENT_ID, + "usePkceWithAuthorizationCodeGrant": True, + "scopes": "openid profile email", + }, + openapi_tags=tags_metadata, ) # Type inspection disabled for middleware injection. @@ -64,19 +65,19 @@ app = FastAPI( app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY.get_secret_value()) # noinspection PyTypeChecker app.add_middleware( - CORSMiddleware, - allow_origins=settings.CORS_ORIGINS, - allow_origin_regex=settings.CORS_ORIGINS_REGEX, - allow_credentials=True, - allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"), - allow_headers=settings.CORS_HEADERS, + CORSMiddleware, + allow_origins=settings.CORS_ORIGINS, + allow_origin_regex=settings.CORS_ORIGINS_REGEX, + allow_credentials=True, + allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"), + allow_headers=settings.CORS_HEADERS, ) if settings.DISABLE_AUTH and (settings.ENVIRONMENT == Environment.LOCAL): - app.dependency_overrides[get_current_user] = get_dev_user + app.dependency_overrides[get_current_user] = get_dev_user app.include_router(api_router) if os.path.exists("/app/static"): - app.frontend("/ui", directory="/app/static", fallback="index.html") + app.frontend("/ui", directory="/app/static", fallback="index.html") diff --git a/src/models.py b/src/models.py index 3f2295d..3dbc891 100644 --- a/src/models.py +++ b/src/models.py @@ -10,28 +10,28 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column class CustomBase(DeclarativeBase): - type_annotation_map = { - datetime: DateTime(timezone=True), - dict[str, Any]: JSON, - } + type_annotation_map = { + datetime: DateTime(timezone=True), + dict[str, Any]: JSON, + } class ActivatedMixin: - active: Mapped[bool] = mapped_column(default=True) + active: Mapped[bool] = mapped_column(default=True) class DeletedTimestampMixin: - deleted_at: Mapped[datetime | None] = mapped_column(nullable=True) + deleted_at: Mapped[datetime | None] = mapped_column(nullable=True) class DescriptionMixin: - description: Mapped[str] + description: Mapped[str] class IdMixin: - id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) class TimestampMixin: - created_at: Mapped[datetime] = mapped_column(default=func.now()) - updated_at: Mapped[datetime] = mapped_column(default=func.now(), onupdate=func.now()) + created_at: Mapped[datetime] = mapped_column(default=func.now()) + updated_at: Mapped[datetime] = mapped_column(default=func.now(), onupdate=func.now()) diff --git a/src/organisation/constants.py b/src/organisation/constants.py index 94bcc2d..dc3d4c2 100644 --- a/src/organisation/constants.py +++ b/src/organisation/constants.py @@ -10,48 +10,48 @@ from enum import StrEnum, auto class Status(StrEnum): - """ - Enumeration of organisation statuses. + """ + Enumeration of organisation statuses. - Attributes: - PARTIAL(str): Organisation has been created but questionnaire hasn't been submitted. - SUBMITTED (str): Questionnaire submitted but not approved. - REMEDIATION (str): Questionnaire submitted but requires revisions. - APPROVED (str): Questionnaire has been approved by an admin. - REJECTED (str): Questionnaire has been rejected by an admin. - REMOVED (str): Organisation has been removed. - """ + Attributes: + PARTIAL(str): Organisation has been created but questionnaire hasn't been submitted. + SUBMITTED (str): Questionnaire submitted but not approved. + REMEDIATION (str): Questionnaire submitted but requires revisions. + APPROVED (str): Questionnaire has been approved by an admin. + REJECTED (str): Questionnaire has been rejected by an admin. + REMOVED (str): Organisation has been removed. + """ - PARTIAL = auto() - SUBMITTED = auto() - REMEDIATION = auto() - APPROVED = auto() - REJECTED = auto() - REMOVED = auto() + PARTIAL = auto() + SUBMITTED = auto() + REMEDIATION = auto() + APPROVED = auto() + REJECTED = auto() + REMOVED = auto() - @property - def is_pre_approval(self): - return self in (self.PARTIAL, self.SUBMITTED, self.REMEDIATION) + @property + def is_pre_approval(self): + return self in (self.PARTIAL, self.SUBMITTED, self.REMEDIATION) - @property - def is_pre_submission(self): - return self in (self.PARTIAL, self.REMEDIATION) + @property + def is_pre_submission(self): + return self in (self.PARTIAL, self.REMEDIATION) - @property - def is_blocked(self): - return self in (self.REMOVED, self.REJECTED) + @property + def is_blocked(self): + return self in (self.REMOVED, self.REJECTED) class ContactType(StrEnum): - """ - Enumeration of organisation contact types. + """ + Enumeration of organisation contact types. - Attributes: - BILLING(str): Billing contact. - SECURITY (str): Security contact. - OWNER (str): Owner contact. - """ + Attributes: + BILLING(str): Billing contact. + SECURITY (str): Security contact. + OWNER (str): Owner contact. + """ - BILLING = auto() - SECURITY = auto() - OWNER = auto() + BILLING = auto() + SECURITY = auto() + OWNER = auto() diff --git a/src/organisation/dependencies.py b/src/organisation/dependencies.py index 2686560..e9f6f6b 100644 --- a/src/organisation/dependencies.py +++ b/src/organisation/dependencies.py @@ -18,25 +18,25 @@ from src.organisation.exceptions import OrgNotFoundException def get_org_model_query(db: DbSession, org_id: Annotated[int, Query(gt=0)]) -> Org: - org_model = db.get(Org, org_id) - if org_model is None: - raise OrgNotFoundException(org_id) - return org_model + org_model = db.get(Org, org_id) + if org_model is None: + raise OrgNotFoundException(org_id) + return org_model org_model_query_dependency = Annotated[Org, Depends(get_org_model_query)] def get_org_model_body(db: DbSession, request_model: OrgIDMixin) -> Org: - org_id: Optional[int] = getattr(request_model, "organisation_id", None) - if org_id is None: - raise OrgNotFoundException() + org_id: Optional[int] = getattr(request_model, "organisation_id", None) + if org_id is None: + raise OrgNotFoundException() - org_model = db.get(Org, org_id) - if org_model is None: - raise OrgNotFoundException(org_id) + org_model = db.get(Org, org_id) + if org_model is None: + raise OrgNotFoundException(org_id) - return org_model + return org_model org_model_body_dependency = Annotated[Org, Depends(get_org_model_body)] diff --git a/src/organisation/exceptions.py b/src/organisation/exceptions.py index a56b395..930aaf5 100644 --- a/src/organisation/exceptions.py +++ b/src/organisation/exceptions.py @@ -12,26 +12,26 @@ from fastapi import HTTPException, status class OrgNotFoundException(HTTPException): - def __init__(self, org_id: Optional[int] = None) -> None: - detail = ( - "Organisation not found" - if org_id is None - else f"Organisation with ID '{org_id}' was not found." - ) - super().__init__( - status_code=status.HTTP_404_NOT_FOUND, - detail=detail, - ) + def __init__(self, org_id: Optional[int] = None) -> None: + detail = ( + "Organisation not found" + if org_id is None + else f"Organisation with ID '{org_id}' was not found." + ) + super().__init__( + status_code=status.HTTP_404_NOT_FOUND, + detail=detail, + ) class AwaitingApprovalException(HTTPException): - def __init__(self, org_id: Optional[int] = None) -> None: - detail = ( - "Organisation has not been approved." - if org_id is None - else f"Organisation with ID '{org_id}' has not been approved." - ) - super().__init__( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=detail, - ) + def __init__(self, org_id: Optional[int] = None) -> None: + detail = ( + "Organisation has not been approved." + if org_id is None + else f"Organisation with ID '{org_id}' has not been approved." + ) + super().__init__( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=detail, + ) diff --git a/src/organisation/models.py b/src/organisation/models.py index 97d1247..e28c7cc 100644 --- a/src/organisation/models.py +++ b/src/organisation/models.py @@ -25,51 +25,51 @@ from src.models import CustomBase, TimestampMixin class Organisation(CustomBase, IdMixin, TimestampMixin, DeletedTimestampMixin): - __tablename__ = "organisation" + __tablename__ = "organisation" - name: Mapped[str] - status: Mapped[str] = mapped_column(default="partial") - intake_questionnaire: Mapped[dict[str, Any] | None] + name: Mapped[str] + status: Mapped[str] = mapped_column(default="partial") + intake_questionnaire: Mapped[dict[str, Any] | None] - root_user_id: Mapped[int] = mapped_column(ForeignKey("user.id")) - billing_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True) - security_contact_id: Mapped[int] = mapped_column( - ForeignKey("contact.id"), nullable=True - ) - owner_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True) + root_user_id: Mapped[int] = mapped_column(ForeignKey("user.id")) + billing_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True) + security_contact_id: Mapped[int] = mapped_column( + ForeignKey("contact.id"), nullable=True + ) + owner_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True) - user_rel = relationship("User", secondary="orgusers", back_populates="organisation_rel") + user_rel = relationship("User", secondary="orgusers", back_populates="organisation_rel") - group_rel = relationship( - "Group", back_populates="org_rel", cascade="all, delete-orphan" - ) - root_user_rel = relationship("User", foreign_keys="Organisation.root_user_id") + group_rel = relationship( + "Group", back_populates="org_rel", cascade="all, delete-orphan" + ) + root_user_rel = relationship("User", foreign_keys="Organisation.root_user_id") - billing_contact_rel = relationship( - "Contact", foreign_keys="Organisation.billing_contact_id" - ) - security_contact_rel = relationship( - "Contact", foreign_keys="Organisation.security_contact_id" - ) - owner_contact_rel = relationship( - "Contact", foreign_keys="Organisation.owner_contact_id" - ) + billing_contact_rel = relationship( + "Contact", foreign_keys="Organisation.billing_contact_id" + ) + security_contact_rel = relationship( + "Contact", foreign_keys="Organisation.security_contact_id" + ) + owner_contact_rel = relationship( + "Contact", foreign_keys="Organisation.owner_contact_id" + ) - permission_rel = relationship( - "Permission", secondary="org_permissions", back_populates="org_rel" - ) + permission_rel = relationship( + "Permission", secondary="org_permissions", back_populates="org_rel" + ) - @property - def root_user_email(self) -> str: - return self.root_user_rel.email if self.root_user_rel else "" + @property + def root_user_email(self) -> str: + return self.root_user_rel.email if self.root_user_rel else "" class OrgUsers(CustomBase): - __tablename__ = "orgusers" + __tablename__ = "orgusers" - org_id: Mapped[int] = mapped_column( - ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True - ) - user_id: Mapped[int] = mapped_column( - ForeignKey("user.id", ondelete="CASCADE"), primary_key=True - ) + org_id: Mapped[int] = mapped_column( + ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True + ) + user_id: Mapped[int] = mapped_column( + ForeignKey("user.id", ondelete="CASCADE"), primary_key=True + ) diff --git a/src/organisation/router.py b/src/organisation/router.py index 78159ca..3c06b91 100644 --- a/src/organisation/router.py +++ b/src/organisation/router.py @@ -26,9 +26,9 @@ from fastapi.params import Query from src.contact.schemas import ContactModel from src.exceptions import ( - UnprocessableContentException, - ConflictException, - ForbiddenException, + UnprocessableContentException, + ConflictException, + ForbiddenException, ) from src.contact.models import Contact from src.contact.schemas import ContactAddress @@ -37,614 +37,612 @@ from src.database import DbSession from src.organisation.schemas_questionnaires import QuestionnaireQuestionsVersion0 from src.organisation.service import assign_defaults from src.user.dependencies import ( - user_model_body_dependency, - user_model_claims_dependency, - user_model_query_dependency, + user_model_body_dependency, + user_model_claims_dependency, + user_model_query_dependency, ) from src.auth.dependencies import ( - super_admin_dependency, - org_model_root_claim_query_dependency, - org_model_root_claim_body_dependency, + super_admin_dependency, + org_model_root_claim_query_dependency, + org_model_root_claim_body_dependency, ) from src.iam.models import Group from src.organisation.dependencies import ( - org_model_body_dependency, - org_model_query_dependency, + org_model_body_dependency, + org_model_query_dependency, ) from src.organisation.constants import ContactType, Status as StatusEnum from src.organisation.models import Organisation as Org from src.organisation.schemas import ( - OrgPostOrgRequest, - OrgPatchQuestionnaireRequest, - OrgPatchStatusRequest, - OrgPatchContactRequest, - OrgPostUserRequest, - OrgGetUserResponse, - OrgGetContactResponse, - OrgGetOrgResponse, - OrgPatchRootRequest, - OrgGetGroupResponse, - OrgPostOrgResponse, - OrgPatchQuestionnaireResponse, - OrgPatchStatusResponse, - OrgPostUserResponse, - OrgPatchRootResponse, - Questionnaire, - OrgPatchContactResponse, - QuestionnaireMetadata, + OrgPostOrgRequest, + OrgPatchQuestionnaireRequest, + OrgPatchStatusRequest, + OrgPatchContactRequest, + OrgPostUserRequest, + OrgGetUserResponse, + OrgGetContactResponse, + OrgGetOrgResponse, + OrgPatchRootRequest, + OrgGetGroupResponse, + OrgPostOrgResponse, + OrgPatchQuestionnaireResponse, + OrgPatchStatusResponse, + OrgPostUserResponse, + OrgPatchRootResponse, + Questionnaire, + OrgPatchContactResponse, + QuestionnaireMetadata, ) router = APIRouter( - prefix="/org", - tags=["Organisation"], + prefix="/org", + tags=["Organisation"], ) @router.get( - "", - summary="Get org details by ID.", - response_model=OrgGetOrgResponse, - status_code=status.HTTP_200_OK, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, - status.HTTP_404_NOT_FOUND: {"description": "Organisation not found"}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Missing or invalid org_id query parameter" - }, - status.HTTP_403_FORBIDDEN: { - "description": "Not authorised. Must be org root user." - }, - }, + "", + summary="Get org details by ID.", + response_model=OrgGetOrgResponse, + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, + status.HTTP_404_NOT_FOUND: {"description": "Organisation not found"}, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Missing or invalid org_id query parameter" + }, + status.HTTP_403_FORBIDDEN: { + "description": "Not authorised. Must be org root user." + }, + }, ) -async def get_org_by_id( - db: DbSession, org_model: org_model_root_claim_query_dependency -): - """ - Returns organisation details including key member email addresses - """ - response = { - "organisation_id": org_model.id, - "name": org_model.name, - "status": org_model.status, - "intake_questionnaire": org_model.intake_questionnaire, - "root_user_email": org_model.root_user_email, - "billing_contact": { - "id": org_model.billing_contact_id, - "email": org_model.billing_contact_rel.email, - }, - "owner_contact": { - "id": org_model.owner_contact_id, - "email": org_model.owner_contact_rel.email, - }, - "security_contact": { - "id": org_model.security_contact_id, - "email": org_model.security_contact_rel.email, - }, - } +async def get_org_by_id(db: DbSession, org_model: org_model_root_claim_query_dependency): + """ + Returns organisation details including key member email addresses + """ + response = { + "organisation_id": org_model.id, + "name": org_model.name, + "status": org_model.status, + "intake_questionnaire": org_model.intake_questionnaire, + "root_user_email": org_model.root_user_email, + "billing_contact": { + "id": org_model.billing_contact_id, + "email": org_model.billing_contact_rel.email, + }, + "owner_contact": { + "id": org_model.owner_contact_id, + "email": org_model.owner_contact_rel.email, + }, + "security_contact": { + "id": org_model.security_contact_id, + "email": org_model.security_contact_rel.email, + }, + } - return {"organisations": [response]} + return {"organisations": [response]} @router.post( - "", - summary="Create new organisation.", - status_code=status.HTTP_201_CREATED, - response_model=OrgPostOrgResponse, - responses={ - status.HTTP_201_CREATED: {"description": "Successfully created organisation."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - status.HTTP_401_UNAUTHORIZED: { - "description": "User must be logged in with OIDC to create organisation." - }, - status.HTTP_409_CONFLICT: { - "description": "Organisation with this name already exists." - }, - }, + "", + summary="Create new organisation.", + status_code=status.HTTP_201_CREATED, + response_model=OrgPostOrgResponse, + responses={ + status.HTTP_201_CREATED: {"description": "Successfully created organisation."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_401_UNAUTHORIZED: { + "description": "User must be logged in with OIDC to create organisation." + }, + status.HTTP_409_CONFLICT: { + "description": "Organisation with this name already exists." + }, + }, ) async def create_org( - db: DbSession, - user_model: user_model_claims_dependency, - request_model: OrgPostOrgRequest, - background_tasks: BackgroundTasks, + db: DbSession, + user_model: user_model_claims_dependency, + request_model: OrgPostOrgRequest, + background_tasks: BackgroundTasks, ): - """ - Creates a new organisation with optional questionnaire (to be completed or submitted). - ALl organisations are given the "partial" status on creation. See update_questionnaire() for more details. - """ - if request_model.intake_questionnaire: - questionnaire_questions = request_model.intake_questionnaire.model_dump() - else: - questionnaire_questions = QuestionnaireQuestionsVersion0().model_dump() + """ + Creates a new organisation with optional questionnaire (to be completed or submitted). + ALl organisations are given the "partial" status on creation. See update_questionnaire() for more details. + """ + if request_model.intake_questionnaire: + questionnaire_questions = request_model.intake_questionnaire.model_dump() + else: + questionnaire_questions = QuestionnaireQuestionsVersion0().model_dump() - questionnaire_metadata = QuestionnaireMetadata(version=0, submission_date=None) + questionnaire_metadata = QuestionnaireMetadata(version=0, submission_date=None) - intake_questionnaire = Questionnaire( - metadata=questionnaire_metadata, - questions=questionnaire_questions, - ) + intake_questionnaire = Questionnaire( + metadata=questionnaire_metadata, + questions=questionnaire_questions, + ) - org_model = Org( - name=request_model.name, - intake_questionnaire=intake_questionnaire.model_dump(mode="json"), - root_user_id=user_model.id, - ) + org_model = Org( + name=request_model.name, + intake_questionnaire=intake_questionnaire.model_dump(mode="json"), + root_user_id=user_model.id, + ) - org_model.status = "partial" + org_model.status = "partial" - db.add(org_model) - try: - db.flush() - except IntegrityError as e: - if ( - isinstance(e.orig, UniqueViolation) # Postgres unique violation - or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation - ): - raise ConflictException(message="Organisation with this name already exists") - raise - # Adds currently logged-in user to org users list and sets them as root_user - org_model.user_rel.append(user_model) + db.add(org_model) + try: + db.flush() + except IntegrityError as e: + if ( + isinstance(e.orig, UniqueViolation) # Postgres unique violation + or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation + ): + raise ConflictException(message="Organisation with this name already exists") + raise + # Adds currently logged-in user to org users list and sets them as root_user + org_model.user_rel.append(user_model) - background_tasks.add_task( - assign_defaults, db, org_id=org_model.id, user_id=user_model.id - ) + background_tasks.add_task( + assign_defaults, db, org_id=org_model.id, user_id=user_model.id + ) - for contact_type in [ - "billing_contact_id", - "security_contact_id", - "owner_contact_id", - ]: - contact_model = Contact(org_id=org_model.id) - db.add(contact_model) - db.flush() - org_model.__setattr__(contact_type, contact_model.id) - response = OrgPostOrgResponse(**org_model.__dict__) - db.commit() - return response + for contact_type in [ + "billing_contact_id", + "security_contact_id", + "owner_contact_id", + ]: + contact_model = Contact(org_id=org_model.id) + db.add(contact_model) + db.flush() + org_model.__setattr__(contact_type, contact_model.id) + response = OrgPostOrgResponse(**org_model.__dict__) + db.commit() + return response @router.patch( - "/questionnaire", - summary="Update questionnaire.", - status_code=status.HTTP_200_OK, - response_model=OrgPatchQuestionnaireResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully updated questionnaire."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - status.HTTP_403_FORBIDDEN: { - "description": "Not authorised. Must be org root user." - }, - }, + "/questionnaire", + summary="Update questionnaire.", + status_code=status.HTTP_200_OK, + response_model=OrgPatchQuestionnaireResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully updated questionnaire."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_403_FORBIDDEN: { + "description": "Not authorised. Must be org root user." + }, + }, ) async def update_questionnaire( - db: DbSession, - org_model: org_model_root_claim_body_dependency, - request_model: OrgPatchQuestionnaireRequest, + db: DbSession, + org_model: org_model_root_claim_body_dependency, + request_model: OrgPatchQuestionnaireRequest, ): - """ - Route for updating questionnaire. - The partial bool allows for submission of partially completed questionnaire and/or - final "are you sure" check before setting the org to be in "submitted" status, awaiting admin approval. - """ - org_status = StatusEnum(org_model.status) - if not org_status.is_pre_submission: - raise ForbiddenException("Questionnaire may only be modified prior to submission.") - update_data: dict = request_model.intake_questionnaire.model_dump(exclude_none=True) - questionnaire = org_model.intake_questionnaire - if questionnaire is None: - questionnaire_questions = QuestionnaireQuestionsVersion0().model_dump() + """ + Route for updating questionnaire. + The partial bool allows for submission of partially completed questionnaire and/or + final "are you sure" check before setting the org to be in "submitted" status, awaiting admin approval. + """ + org_status = StatusEnum(org_model.status) + if not org_status.is_pre_submission: + raise ForbiddenException("Questionnaire may only be modified prior to submission.") + update_data: dict = request_model.intake_questionnaire.model_dump(exclude_none=True) + questionnaire = org_model.intake_questionnaire + if questionnaire is None: + questionnaire_questions = QuestionnaireQuestionsVersion0().model_dump() - questionnaire_metadata = QuestionnaireMetadata(version=0, submission_date=None) + questionnaire_metadata = QuestionnaireMetadata(version=0, submission_date=None) - questionnaire = Questionnaire( - metadata=questionnaire_metadata, - questions=questionnaire_questions, - ).model_dump() + questionnaire = Questionnaire( + metadata=questionnaire_metadata, + questions=questionnaire_questions, + ).model_dump() - questions_model = QuestionnaireQuestionsVersion0(**questionnaire["questions"]) - else: - questions_model = QuestionnaireQuestionsVersion0(**questionnaire["questions"]) - for key, value in update_data.items(): - if hasattr(questions_model, key): - setattr(questions_model, key, value) - else: - raise UnprocessableContentException("Invalid keys in update request") + questions_model = QuestionnaireQuestionsVersion0(**questionnaire["questions"]) + else: + questions_model = QuestionnaireQuestionsVersion0(**questionnaire["questions"]) + for key, value in update_data.items(): + if hasattr(questions_model, key): + setattr(questions_model, key, value) + else: + raise UnprocessableContentException("Invalid keys in update request") - metadata = QuestionnaireMetadata(version=questionnaire["metadata"]["version"]) + metadata = QuestionnaireMetadata(version=questionnaire["metadata"]["version"]) - # Allows for partially completed questionnaires to be saved without being submitted for review - if not request_model.partial: - org_model.status = "submitted" - metadata.submission_date = datetime.now(timezone.utc) + # Allows for partially completed questionnaires to be saved without being submitted for review + if not request_model.partial: + org_model.status = "submitted" + metadata.submission_date = datetime.now(timezone.utc) - questionnaire_model = Questionnaire( - metadata=metadata, - questions=questions_model, - ) + questionnaire_model = Questionnaire( + metadata=metadata, + questions=questions_model, + ) - org_model.intake_questionnaire = questionnaire_model.model_dump(mode="json") - db.flush() - response = OrgPatchQuestionnaireResponse(**org_model.__dict__) - db.commit() - return response + org_model.intake_questionnaire = questionnaire_model.model_dump(mode="json") + db.flush() + response = OrgPatchQuestionnaireResponse(**org_model.__dict__) + db.commit() + return response @router.patch( - "/status", - summary="Update status of organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgPatchStatusResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully updated organisation status."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - status.HTTP_403_FORBIDDEN: {"description": "Not authorised. Must be super admin."}, - }, + "/status", + summary="Update status of organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgPatchStatusResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully updated organisation status."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_403_FORBIDDEN: {"description": "Not authorised. Must be super admin."}, + }, ) async def update_status( - db: DbSession, - org_model: org_model_body_dependency, - su: super_admin_dependency, - request_model: OrgPatchStatusRequest, + db: DbSession, + org_model: org_model_body_dependency, + su: super_admin_dependency, + request_model: OrgPatchStatusRequest, ): - """ - Sets an organisation's status. This is the endpoint for approving or denying an organisation after reviewing the questionnaire. - """ - org_model.status = request_model.status - db.flush() - response = OrgPatchStatusResponse(**org_model.__dict__) - db.commit() - return response + """ + Sets an organisation's status. This is the endpoint for approving or denying an organisation after reviewing the questionnaire. + """ + org_model.status = request_model.status + db.flush() + response = OrgPatchStatusResponse(**org_model.__dict__) + db.commit() + return response @router.get( - "/users", - summary="Get email addresses of users of the organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgGetUserResponse, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval of users."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Org ID missing or invalid." - }, - status.HTTP_403_FORBIDDEN: { - "description": "Not authorised. Must be org root user." - }, - }, + "/users", + summary="Get email addresses of users of the organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgGetUserResponse, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval of users."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Org ID missing or invalid." + }, + status.HTTP_403_FORBIDDEN: { + "description": "Not authorised. Must be org root user." + }, + }, ) async def get_users(org_model: org_model_root_claim_query_dependency): - """ - Returns a list of the email addresses of all users of the organisation. - """ - return { - "users": [{"email": user.email, "id": user.id} for user in org_model.user_rel], - "organisation": org_model, - } + """ + Returns a list of the email addresses of all users of the organisation. + """ + return { + "users": [{"email": user.email, "id": user.id} for user in org_model.user_rel], + "organisation": org_model, + } @router.post( - "/user", - summary="Add user to the organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgPostUserResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully added user to the organisation."}, - status.HTTP_403_FORBIDDEN: { - "description": "Not authorised. Must be org root user." - }, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - status.HTTP_409_CONFLICT: { - "description": "User is already a member of the organisation." - }, - }, + "/user", + summary="Add user to the organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgPostUserResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully added user to the organisation."}, + status.HTTP_403_FORBIDDEN: { + "description": "Not authorised. Must be org root user." + }, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_409_CONFLICT: { + "description": "User is already a member of the organisation." + }, + }, ) async def add_user_to_org( - db: DbSession, - org_model: org_model_body_dependency, - user_model: user_model_body_dependency, - su: super_admin_dependency, - request_model: OrgPostUserRequest, + db: DbSession, + org_model: org_model_body_dependency, + user_model: user_model_body_dependency, + su: super_admin_dependency, + request_model: OrgPostUserRequest, ): - """ - Adds a user to the organisation. - """ - if user_model in org_model.user_rel: - raise ConflictException(message="User already a part of this organisation") - org_model.user_rel.append(user_model) - db.flush() - group_model = ( - db.query(Group) - .filter(Group.org_id == org_model.id) - .filter(Group.name == "Default Users") - .first() - ) - if group_model is not None: - user_model.group_rel.append(group_model) - response = { - "organisation": org_model, - "users": [{"id": user.id, "email": user.email} for user in org_model.user_rel], - } - db.commit() - return response + """ + Adds a user to the organisation. + """ + if user_model in org_model.user_rel: + raise ConflictException(message="User already a part of this organisation") + org_model.user_rel.append(user_model) + db.flush() + group_model = ( + db.query(Group) + .filter(Group.org_id == org_model.id) + .filter(Group.name == "Default Users") + .first() + ) + if group_model is not None: + user_model.group_rel.append(group_model) + response = { + "organisation": org_model, + "users": [{"id": user.id, "email": user.email} for user in org_model.user_rel], + } + db.commit() + return response @router.delete( - "", - summary="Delete organisation from the hub.", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - status.HTTP_204_NO_CONTENT: {"description": "Successfully deleted organisation."}, - status.HTTP_403_FORBIDDEN: {"description": "Not authorised. Must be super admin."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Org ID missing or invalid." - }, - }, + "", + summary="Delete organisation from the hub.", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + status.HTTP_204_NO_CONTENT: {"description": "Successfully deleted organisation."}, + status.HTTP_403_FORBIDDEN: {"description": "Not authorised. Must be super admin."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Org ID missing or invalid." + }, + }, ) async def delete_organisation_by_id( - db: DbSession, - org_model: org_model_query_dependency, - su: super_admin_dependency, + db: DbSession, + org_model: org_model_query_dependency, + su: super_admin_dependency, ): - """ - Removes an organisation from the hub. - """ - org_model.status = "removed" - org_model.deleted_at = datetime.now(tz=timezone.utc) - db.commit() + """ + Removes an organisation from the hub. + """ + org_model.status = "removed" + org_model.deleted_at = datetime.now(tz=timezone.utc) + db.commit() @router.delete( - "/self", - summary="Delete organisation from the hub as root user before it has been approved.", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - status.HTTP_204_NO_CONTENT: {"description": "Successfully deleted organisation."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Unprocessable content.", - "content": { - "application/json": { - "examples": { - "org_id": {"summary": "Invalid or missing org ID."}, - "oidc_claims": {"summary": "Invalid or missing OIDC claims."}, - } - } - }, - }, - status.HTTP_401_UNAUTHORIZED: { - "description": "Unauthorized", - "content": { - "application/json": { - "examples": { - "awaiting_approval": { - "summary": "Organisation has not yet been approved." - }, - "expired_token": {"summary": "User token has expired."}, - "oidc": {"summary": "Failed to verify OIDC claims."}, - } - } - }, - }, - status.HTTP_403_FORBIDDEN: { - "description": "Forbidden", - "content": { - "application/json": { - "examples": { - "invalid_state": { - "summary": "Organisation is no longer in pre-approval state." - }, - "not_root": {"summary": "Not authorised. Must be root user."}, - } - } - }, - }, - status.HTTP_404_NOT_FOUND: { - "description": "Not found", - "content": { - "application/json": { - "examples": { - "db_id": {"summary": "User not found in db when checking claims."}, - "user_model": {"summary": "User model not found in db."}, - "org_model": {"summary": "Org model not found in db."}, - } - } - }, - }, - }, + "/self", + summary="Delete organisation from the hub as root user before it has been approved.", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + status.HTTP_204_NO_CONTENT: {"description": "Successfully deleted organisation."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Unprocessable content.", + "content": { + "application/json": { + "examples": { + "org_id": {"summary": "Invalid or missing org ID."}, + "oidc_claims": {"summary": "Invalid or missing OIDC claims."}, + } + } + }, + }, + status.HTTP_401_UNAUTHORIZED: { + "description": "Unauthorized", + "content": { + "application/json": { + "examples": { + "awaiting_approval": { + "summary": "Organisation has not yet been approved." + }, + "expired_token": {"summary": "User token has expired."}, + "oidc": {"summary": "Failed to verify OIDC claims."}, + } + } + }, + }, + status.HTTP_403_FORBIDDEN: { + "description": "Forbidden", + "content": { + "application/json": { + "examples": { + "invalid_state": { + "summary": "Organisation is no longer in pre-approval state." + }, + "not_root": {"summary": "Not authorised. Must be root user."}, + } + } + }, + }, + status.HTTP_404_NOT_FOUND: { + "description": "Not found", + "content": { + "application/json": { + "examples": { + "db_id": {"summary": "User not found in db when checking claims."}, + "user_model": {"summary": "User model not found in db."}, + "org_model": {"summary": "Org model not found in db."}, + } + } + }, + }, + }, ) async def delete_preapproved_organisation_by_id( - db: DbSession, - org_model: org_model_root_claim_query_dependency, + db: DbSession, + org_model: org_model_root_claim_query_dependency, ): - """ - Removes an organisation from the hub before it has been approved, if user is root. - """ - org_status = StatusEnum(org_model.status) - if not org_status.is_pre_approval: - raise ForbiddenException(message="Organisation is no longer in pre-approval state.") + """ + Removes an organisation from the hub before it has been approved, if user is root. + """ + org_status = StatusEnum(org_model.status) + if not org_status.is_pre_approval: + raise ForbiddenException(message="Organisation is no longer in pre-approval state.") - db.delete(org_model) - db.commit() + db.delete(org_model) + db.commit() @router.patch( - "/root_user", - summary="Update the root user of the organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgPatchRootResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully updated root user."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - status.HTTP_401_UNAUTHORIZED: { - "description": "Not authorised. Must be super admin." - }, - }, + "/root_user", + summary="Update the root user of the organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgPatchRootResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully updated root user."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_401_UNAUTHORIZED: { + "description": "Not authorised. Must be super admin." + }, + }, ) async def update_root_user( - db: DbSession, - org_model: org_model_body_dependency, - user_model: user_model_body_dependency, - su: super_admin_dependency, - request_model: OrgPatchRootRequest, + db: DbSession, + org_model: org_model_body_dependency, + user_model: user_model_body_dependency, + su: super_admin_dependency, + request_model: OrgPatchRootRequest, ): - """ - Promotes an existing organisation user to the root user, giving them full control of the org. - """ - if user_model not in org_model.user_rel: - raise UnprocessableContentException( - message="This user does not belong to your organisation." - ) - org_model.root_user_rel = user_model - db.flush() - response = OrgPatchRootResponse( - name=org_model.name, root_user_email=org_model.root_user_email - ) - db.commit() - return response + """ + Promotes an existing organisation user to the root user, giving them full control of the org. + """ + if user_model not in org_model.user_rel: + raise UnprocessableContentException( + message="This user does not belong to your organisation." + ) + org_model.root_user_rel = user_model + db.flush() + response = OrgPatchRootResponse( + name=org_model.name, root_user_email=org_model.root_user_email + ) + db.commit() + return response @router.get( - "/groups", - summary="Get all organisation IAM groups.", - status_code=status.HTTP_200_OK, - response_model=OrgGetGroupResponse, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval of IAM groups."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Org ID missing or invalid." - }, - status.HTTP_403_FORBIDDEN: { - "description": "Not authorised. Must be org root user." - }, - }, + "/groups", + summary="Get all organisation IAM groups.", + status_code=status.HTTP_200_OK, + response_model=OrgGetGroupResponse, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval of IAM groups."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Org ID missing or invalid." + }, + status.HTTP_403_FORBIDDEN: { + "description": "Not authorised. Must be org root user." + }, + }, ) async def get_org_groups(org_model: org_model_root_claim_query_dependency): - """ - Returns a list of the names of all IAM groups created by the organisation. - """ - return { - "organisation": org_model, - "groups": [{"id": group.id, "name": group.name} for group in org_model.group_rel], - } + """ + Returns a list of the names of all IAM groups created by the organisation. + """ + return { + "organisation": org_model, + "groups": [{"id": group.id, "name": group.name} for group in org_model.group_rel], + } @router.delete( - "/user", - summary="Remove user from organisation.", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - status.HTTP_204_NO_CONTENT: {"description": "Successfully removed user."}, - status.HTTP_403_FORBIDDEN: { - "description": "Not authorised. Must be org root user." - }, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - }, + "/user", + summary="Remove user from organisation.", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + status.HTTP_204_NO_CONTENT: {"description": "Successfully removed user."}, + status.HTTP_403_FORBIDDEN: { + "description": "Not authorised. Must be org root user." + }, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + }, ) async def remove_user_from_org( - db: DbSession, - org_model: org_model_root_claim_query_dependency, - user_model: user_model_query_dependency, + db: DbSession, + org_model: org_model_root_claim_query_dependency, + user_model: user_model_query_dependency, ): - """ - Revokes a user's membership in an organisation. - """ - if user_model not in org_model.user_rel: - return + """ + Revokes a user's membership in an organisation. + """ + if user_model not in org_model.user_rel: + return - org_model.user_rel.remove(user_model) - db.commit() + org_model.user_rel.remove(user_model) + db.commit() @router.get( - "/contact", - summary="Get contact for organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgGetContactResponse, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval of contact."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - status.HTTP_403_FORBIDDEN: { - "description": "Not authorised. Must be org root user." - }, - }, + "/contact", + summary="Get contact for organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgGetContactResponse, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval of contact."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_403_FORBIDDEN: { + "description": "Not authorised. Must be org root user." + }, + }, ) async def get_contact( - org_model: org_model_root_claim_query_dependency, - contact_type: Annotated[ - ContactType, Query(description="Must be billing|security|owner") - ], + org_model: org_model_root_claim_query_dependency, + contact_type: Annotated[ + ContactType, Query(description="Must be billing|security|owner") + ], ): - """ - Gets full details for a contact point at an organisation. - """ - match contact_type: - case "billing": - contact_model = org_model.billing_contact_rel - case "security": - contact_model = org_model.security_contact_rel - case "owner": - contact_model = org_model.owner_contact_rel - case _: - raise UnprocessableContentException("Invalid contact type") + """ + Gets full details for a contact point at an organisation. + """ + match contact_type: + case "billing": + contact_model = org_model.billing_contact_rel + case "security": + contact_model = org_model.security_contact_rel + case "owner": + contact_model = org_model.owner_contact_rel + case _: + raise UnprocessableContentException("Invalid contact type") - if contact_model is None: - raise ContactNotFoundException() + if contact_model is None: + raise ContactNotFoundException() - address = ContactAddress.model_validate(contact_model) - contact_response = ContactModel.model_construct( - **contact_model.__dict__, address=address - ) + address = ContactAddress.model_validate(contact_model) + contact_response = ContactModel.model_construct( + **contact_model.__dict__, address=address + ) - return {"contact": contact_response, "organisation": org_model} + return {"contact": contact_response, "organisation": org_model} @router.patch( - "/contact", - summary="Update contact for organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgPatchContactResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully updated contact."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - status.HTTP_403_FORBIDDEN: { - "description": "Not authorised. Must be org root user." - }, - }, + "/contact", + summary="Update contact for organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgPatchContactResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully updated contact."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_403_FORBIDDEN: { + "description": "Not authorised. Must be org root user." + }, + }, ) async def update_contact( - db: DbSession, - org_model: org_model_root_claim_body_dependency, - request_model: OrgPatchContactRequest, + db: DbSession, + org_model: org_model_root_claim_body_dependency, + request_model: OrgPatchContactRequest, ): - """ - Updates details for a contact point at an organisation. - """ - match request_model.contact_type: - case "billing": - contact_model = org_model.billing_contact_rel - case "security": - contact_model = org_model.security_contact_rel - case "owner": - contact_model = org_model.owner_contact_rel - case _: - raise UnprocessableContentException("Invalid contact type") + """ + Updates details for a contact point at an organisation. + """ + match request_model.contact_type: + case "billing": + contact_model = org_model.billing_contact_rel + case "security": + contact_model = org_model.security_contact_rel + case "owner": + contact_model = org_model.owner_contact_rel + case _: + raise UnprocessableContentException("Invalid contact type") - if contact_model is None: - raise ContactNotFoundException() + if contact_model is None: + raise ContactNotFoundException() - update_data = request_model.model_dump(exclude_none=True) - for key, value in update_data.items(): - if hasattr(contact_model, key): - setattr(contact_model, key, value) - else: - if key == "contact_type" or key == "organisation_id": - continue - raise UnprocessableContentException("Invalid keys in update request") - db.flush() + update_data = request_model.model_dump(exclude_none=True) + for key, value in update_data.items(): + if hasattr(contact_model, key): + setattr(contact_model, key, value) + else: + if key == "contact_type" or key == "organisation_id": + continue + raise UnprocessableContentException("Invalid keys in update request") + db.flush() - address = ContactAddress.model_validate(contact_model) - contact_response = ContactModel.model_construct( - **contact_model.__dict__, address=address - ) + address = ContactAddress.model_validate(contact_model) + contact_response = ContactModel.model_construct( + **contact_model.__dict__, address=address + ) - db.commit() + db.commit() - return {"contact": contact_response, "organisation": org_model} + return {"contact": contact_response, "organisation": org_model} diff --git a/src/organisation/schemas.py b/src/organisation/schemas.py index 522158e..d4bf592 100644 --- a/src/organisation/schemas.py +++ b/src/organisation/schemas.py @@ -12,139 +12,139 @@ from datetime import datetime from pydantic import EmailStr, ConfigDict, Field from src.schemas import ( - CustomBaseModel, - OrgIDMixin, - UserIDMixin, - GroupSummary, - OrgSummary, - UserSummary, + CustomBaseModel, + OrgIDMixin, + UserIDMixin, + GroupSummary, + OrgSummary, + UserSummary, ) from src.contact.schemas import ContactModel from src.organisation.constants import Status, ContactType from src.organisation.schemas_questionnaires import ( - QuestionnaireQuestionsVersion0 as CurrentQuestions, - questionnaire_union, + QuestionnaireQuestionsVersion0 as CurrentQuestions, + questionnaire_union, ) class QuestionnaireMetadata(CustomBaseModel): - version: int - submission_date: Optional[datetime] = None + version: int + submission_date: Optional[datetime] = None class Questionnaire(CustomBaseModel): - metadata: QuestionnaireMetadata - questions: questionnaire_union + metadata: QuestionnaireMetadata + questions: questionnaire_union class ContactSummary(CustomBaseModel): - id: int - email: Optional[EmailStr] = None + id: int + email: Optional[EmailStr] = None class OrgSchema(OrgIDMixin): - name: str - status: Status - root_user_email: EmailStr - intake_questionnaire: Optional[Questionnaire] = None + name: str + status: Status + root_user_email: EmailStr + intake_questionnaire: Optional[Questionnaire] = None - billing_contact: ContactSummary - owner_contact: ContactSummary - security_contact: ContactSummary + billing_contact: ContactSummary + owner_contact: ContactSummary + security_contact: ContactSummary class OrgPostOrgRequest(CustomBaseModel): - name: str = Field(min_length=3) - intake_questionnaire: Optional[CurrentQuestions] = None + name: str = Field(min_length=3) + intake_questionnaire: Optional[CurrentQuestions] = None class OrgPostOrgResponse(CustomBaseModel): - id: int - name: str - status: Status + id: int + name: str + status: Status class OrgPatchQuestionnaireRequest(OrgIDMixin): - intake_questionnaire: CurrentQuestions - partial: bool + intake_questionnaire: CurrentQuestions + partial: bool class OrgPatchQuestionnaireResponse(CustomBaseModel): - id: int - name: str - intake_questionnaire: Questionnaire - status: Status + id: int + name: str + intake_questionnaire: Questionnaire + status: Status class OrgPatchStatusRequest(OrgIDMixin): - status: Status + status: Status class OrgPatchStatusResponse(CustomBaseModel): - id: int - name: str - status: Status + id: int + name: str + status: Status class OrgPatchContactRequest(OrgIDMixin): - contact_type: ContactType + contact_type: ContactType - email: Optional[EmailStr] = None - first_name: Optional[str] = None - last_name: Optional[str] = None - phonenumber: Optional[str] = None - vat_number: Optional[str] = None - post_office_box_number: Optional[str] = None - street_address: Optional[str] = None - street_address_line_2: Optional[str] = None - locality: Optional[str] = None - address_region: Optional[str] = None - country_code: Optional[str] = None - postal_code: Optional[str] = None + email: Optional[EmailStr] = None + first_name: Optional[str] = None + last_name: Optional[str] = None + phonenumber: Optional[str] = None + vat_number: Optional[str] = None + post_office_box_number: Optional[str] = None + street_address: Optional[str] = None + street_address_line_2: Optional[str] = None + locality: Optional[str] = None + address_region: Optional[str] = None + country_code: Optional[str] = None + postal_code: Optional[str] = None class OrgPostUserRequest(OrgIDMixin, UserIDMixin): - pass + pass class OrgPostUserResponse(CustomBaseModel): - organisation: OrgSummary - users: list[UserSummary] + organisation: OrgSummary + users: list[UserSummary] class OrgPatchRootRequest(OrgIDMixin, UserIDMixin): - pass + pass class OrgPatchRootResponse(CustomBaseModel): - name: str - root_user_email: str + name: str + root_user_email: str class OrgGetUserResponse(CustomBaseModel): - users: list[dict[str, str | int]] - organisation: OrgSummary + users: list[dict[str, str | int]] + organisation: OrgSummary class OrgGetGroupResponse(CustomBaseModel): - organisation: OrgSummary - groups: list[GroupSummary] + organisation: OrgSummary + groups: list[GroupSummary] class OrgGetContactResponse(CustomBaseModel): - model_config = ConfigDict(from_attributes=True, extra="ignore") + model_config = ConfigDict(from_attributes=True, extra="ignore") - contact: ContactModel - organisation: OrgSummary + contact: ContactModel + organisation: OrgSummary class OrgPatchContactResponse(CustomBaseModel): - model_config = ConfigDict(from_attributes=True, extra="ignore") + model_config = ConfigDict(from_attributes=True, extra="ignore") - contact: ContactModel - organisation: OrgSummary + contact: ContactModel + organisation: OrgSummary class OrgGetOrgResponse(CustomBaseModel): - organisations: list[OrgSchema] + organisations: list[OrgSchema] diff --git a/src/organisation/schemas_questionnaires.py b/src/organisation/schemas_questionnaires.py index 491dfe5..b257c53 100644 --- a/src/organisation/schemas_questionnaires.py +++ b/src/organisation/schemas_questionnaires.py @@ -4,13 +4,13 @@ from src.schemas import CustomBaseModel class QuestionnaireQuestions(CustomBaseModel): - pass + pass class QuestionnaireQuestionsVersion0(QuestionnaireQuestions): - question_one: Optional[str] = None - question_two: Optional[str] = None - question_three: Optional[str] = None + question_one: Optional[str] = None + question_two: Optional[str] = None + question_three: Optional[str] = None questionnaire_union = QuestionnaireQuestionsVersion0 # | QuestionnaireQuestionsVersion1 diff --git a/src/organisation/service.py b/src/organisation/service.py index 4fc04f6..dffd044 100644 --- a/src/organisation/service.py +++ b/src/organisation/service.py @@ -11,57 +11,57 @@ from src.user.models import User async def add_default_org_permissions( - db: Session, - org_model: Org, - perm_list: list[int], + db: Session, + org_model: Org, + perm_list: list[int], ): - for permission in perm_list: - perm_model = db.get(Perm, permission) + for permission in perm_list: + perm_model = db.get(Perm, permission) - if perm_model is None: - continue + if perm_model is None: + continue - if perm_model in org_model.permission_rel: - continue + if perm_model in org_model.permission_rel: + continue - org_model.permission_rel.append(perm_model) - db.flush() + org_model.permission_rel.append(perm_model) + db.flush() - db.commit() + db.commit() async def assign_defaults( - db: Session, - org_id: int, - user_id: int, + db: Session, + org_id: int, + user_id: int, ): - default_org_permissions = [] + default_org_permissions = [] - default_user_permissions = [] + default_user_permissions = [] - org_model = db.get(Org, org_id) - if org_model is None: - print("Org not found while adding defaults") - return + org_model = db.get(Org, org_id) + if org_model is None: + print("Org not found while adding defaults") + return - user_model = db.get(User, user_id) - if user_model is None: - print("User not found while adding defaults") - return + user_model = db.get(User, user_id) + if user_model is None: + print("User not found while adding defaults") + return - await add_default_org_permissions(db, org_model, default_org_permissions) - await assign_default_group( - db=db, - org_model=org_model, - user_model=user_model, - group_name="Default Users", - perm_list=default_user_permissions, - ) - await assign_default_group( - db=db, - org_model=org_model, - user_model=user_model, - group_name="Root User", - perm_list=default_org_permissions, - ) - db.commit() + await add_default_org_permissions(db, org_model, default_org_permissions) + await assign_default_group( + db=db, + org_model=org_model, + user_model=user_model, + group_name="Default Users", + perm_list=default_user_permissions, + ) + await assign_default_group( + db=db, + org_model=org_model, + user_model=user_model, + group_name="Root User", + perm_list=default_org_permissions, + ) + db.commit() diff --git a/src/schemas.py b/src/schemas.py index 8bb4fa2..33178ce 100644 --- a/src/schemas.py +++ b/src/schemas.py @@ -11,54 +11,54 @@ from typing import Optional class CustomBaseModel(BaseModel): - pass + pass ### Mixins ### class OrgIDMixin(CustomBaseModel): - organisation_id: int = Field(gt=0) + organisation_id: int = Field(gt=0) class GroupIDMixin(CustomBaseModel): - group_id: int = Field(gt=0) + group_id: int = Field(gt=0) class PermIDMixin(CustomBaseModel): - permission_id: int = Field(gt=0) + permission_id: int = Field(gt=0) class ServiceIDMixin(CustomBaseModel): - service_id: int = Field(gt=0) + service_id: int = Field(gt=0) class UserIDMixin(CustomBaseModel): - user_id: int = Field(gt=0) + user_id: int = Field(gt=0) class ServiceNameMixin(CustomBaseModel): - service: str + service: str class OrgSummary(CustomBaseModel): - id: int - name: str + id: int + name: str class GroupSummary(CustomBaseModel): - id: int - name: str + id: int + name: str class UserSummary(CustomBaseModel): - id: int - email: str + id: int + email: str class ServiceSummary(CustomBaseModel): - id: int - name: str + id: int + name: str class ResourceName(ServiceNameMixin, OrgIDMixin): - resource: str - instance: Optional[str] = None + resource: str + instance: Optional[str] = None diff --git a/src/service/dependencies.py b/src/service/dependencies.py index ea42c89..5104035 100644 --- a/src/service/dependencies.py +++ b/src/service/dependencies.py @@ -16,25 +16,23 @@ from src.service.models import Service from src.service.schemas import ServiceIDMixin -async def get_service_model_query( - db: DbSession, service_id: Annotated[int, Query(gt=0)] -): - service_model = db.get(Service, service_id) - if service_model is None: - raise ServiceNotFoundException(service_id=service_id) +async def get_service_model_query(db: DbSession, service_id: Annotated[int, Query(gt=0)]): + service_model = db.get(Service, service_id) + if service_model is None: + raise ServiceNotFoundException(service_id=service_id) - return service_model + return service_model service_model_query_dependency = Annotated[Service, Depends(get_service_model_query)] async def get_service_model_body(db: DbSession, request_model: ServiceIDMixin): - service_model = db.get(Service, request_model.service_id) - if service_model is None: - raise ServiceNotFoundException(service_id=request_model.service_id) + service_model = db.get(Service, request_model.service_id) + if service_model is None: + raise ServiceNotFoundException(service_id=request_model.service_id) - return service_model + return service_model service_model_body_dependency = Annotated[Service, Depends(get_service_model_body)] diff --git a/src/service/exceptions.py b/src/service/exceptions.py index 36a927d..bece7d3 100644 --- a/src/service/exceptions.py +++ b/src/service/exceptions.py @@ -11,13 +11,13 @@ from fastapi import HTTPException, status class ServiceNotFoundException(HTTPException): - def __init__(self, service_id: Optional[int] = None) -> None: - detail = ( - "Service not found" - if service_id is None - else f"Service with ID '{service_id}' was not found." - ) - super().__init__( - status_code=status.HTTP_404_NOT_FOUND, - detail=detail, - ) + def __init__(self, service_id: Optional[int] = None) -> None: + detail = ( + "Service not found" + if service_id is None + else f"Service with ID '{service_id}' was not found." + ) + super().__init__( + status_code=status.HTTP_404_NOT_FOUND, + detail=detail, + ) diff --git a/src/service/models.py b/src/service/models.py index de6dcd6..7ab26fb 100644 --- a/src/service/models.py +++ b/src/service/models.py @@ -12,11 +12,11 @@ from src.models import CustomBase, IdMixin class Service(CustomBase, IdMixin): - __tablename__ = "service" + __tablename__ = "service" - name: Mapped[str] = mapped_column(unique=True) - api_key: Mapped[str] + name: Mapped[str] = mapped_column(unique=True) + api_key: Mapped[str] - permission_rel = relationship( - "Permission", back_populates="service_rel", cascade="all, delete-orphan" - ) + permission_rel = relationship( + "Permission", back_populates="service_rel", cascade="all, delete-orphan" + ) diff --git a/src/service/router.py b/src/service/router.py index 54a7a3a..25d4925 100644 --- a/src/service/router.py +++ b/src/service/router.py @@ -15,8 +15,8 @@ from psycopg.errors import UniqueViolation from src.exceptions import ConflictException from src.database import DbSession from src.auth.dependencies import ( - super_admin_dependency, - org_model_root_claim_query_dependency, + super_admin_dependency, + org_model_root_claim_query_dependency, ) from src.iam.service import service_key_dependency from src.iam.models import Permission as Perm @@ -25,212 +25,210 @@ from src.service.exceptions import ServiceNotFoundException from src.service.models import Service from src.service.utils import generate_api_key from src.service.dependencies import ( - service_model_body_dependency, - service_model_query_dependency, + service_model_body_dependency, + service_model_query_dependency, ) from src.service.schemas import ( - ServiceGetServiceResponse, - ServicePostServiceRequest, - ServicePostServiceResponse, - ServiceWithKeySchema, - ServicePatchKeyResponse, - ServicePatchKeyRequest, - ServicePostPermissionsResponse, - ServicePostPermissionsRequest, + ServiceGetServiceResponse, + ServicePostServiceRequest, + ServicePostServiceResponse, + ServiceWithKeySchema, + ServicePatchKeyResponse, + ServicePatchKeyRequest, + ServicePostPermissionsResponse, + ServicePostPermissionsRequest, ) router = APIRouter( - tags=["Service"], - prefix="/service", + tags=["Service"], + prefix="/service", ) @router.get( - "", - summary="Get all services", - status_code=status.HTTP_200_OK, - response_model=ServiceGetServiceResponse, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, - status.HTTP_401_UNAUTHORIZED: { - "description": "Unauthorized", - "content": { - "application/json": { - "examples": { - "awaiting_approval": { - "summary": "Organisation has not yet been approved." - }, - } - } - }, - }, - status.HTTP_403_FORBIDDEN: { - "description": "Forbidden", - "content": { - "application/json": { - "examples": { - "not_root": {"summary": "Not authorised. Must be root user."}, - } - } - }, - }, - }, + "", + summary="Get all services", + status_code=status.HTTP_200_OK, + response_model=ServiceGetServiceResponse, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, + status.HTTP_401_UNAUTHORIZED: { + "description": "Unauthorized", + "content": { + "application/json": { + "examples": { + "awaiting_approval": { + "summary": "Organisation has not yet been approved." + }, + } + } + }, + }, + status.HTTP_403_FORBIDDEN: { + "description": "Forbidden", + "content": { + "application/json": { + "examples": { + "not_root": {"summary": "Not authorised. Must be root user."}, + } + } + }, + }, + }, ) -async def get_all_services( - db: DbSession, org_model: org_model_root_claim_query_dependency -): - """ - Returns the ID and name of all services registered to the hub. - """ - permission_models = db.query(Service).all() +async def get_all_services(db: DbSession, org_model: org_model_root_claim_query_dependency): + """ + Returns the ID and name of all services registered to the hub. + """ + permission_models = db.query(Service).all() - return {"services": permission_models} + return {"services": permission_models} @router.post( - "", - summary="Register a new service.", - status_code=status.HTTP_200_OK, - response_model=ServicePostServiceResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully registered a new service"}, - status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, - status.HTTP_409_CONFLICT: {"description": "Service with this name already exists"}, - }, + "", + summary="Register a new service.", + status_code=status.HTTP_200_OK, + response_model=ServicePostServiceResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully registered a new service"}, + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + status.HTTP_409_CONFLICT: {"description": "Service with this name already exists"}, + }, ) async def register_service( - db: DbSession, - su: super_admin_dependency, - request_model: ServicePostServiceRequest, + db: DbSession, + su: super_admin_dependency, + request_model: ServicePostServiceRequest, ): - """ - Registers a new service to the hub, generating and returning an API key for it. - """ - key = generate_api_key() - service_model = Service(name=request_model.name, api_key=key) + """ + Registers a new service to the hub, generating and returning an API key for it. + """ + key = generate_api_key() + service_model = Service(name=request_model.name, api_key=key) - db.add(service_model) - try: - db.flush() - except IntegrityError as e: - if ( - isinstance(e.orig, UniqueViolation) # Postgres unique violation - or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation - ): - raise ConflictException(message="Service with this name already exists") - raise - response = ServiceWithKeySchema(**service_model.__dict__) - db.commit() - return {"service": response} + db.add(service_model) + try: + db.flush() + except IntegrityError as e: + if ( + isinstance(e.orig, UniqueViolation) # Postgres unique violation + or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation + ): + raise ConflictException(message="Service with this name already exists") + raise + response = ServiceWithKeySchema(**service_model.__dict__) + db.commit() + return {"service": response} @router.patch( - "/key", - summary="Regenerate service API key.", - status_code=status.HTTP_200_OK, - response_model=ServicePatchKeyResponse, - responses={ - status.HTTP_200_OK: {"description": "Successful update of API key"}, - status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, - }, + "/key", + summary="Regenerate service API key.", + status_code=status.HTTP_200_OK, + response_model=ServicePatchKeyResponse, + responses={ + status.HTTP_200_OK: {"description": "Successful update of API key"}, + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + }, ) async def regenerate_api_key( - db: DbSession, - su: super_admin_dependency, - service_model: service_model_body_dependency, - request_model: ServicePatchKeyRequest, + db: DbSession, + su: super_admin_dependency, + service_model: service_model_body_dependency, + request_model: ServicePatchKeyRequest, ): - """ - Generates and returns a new API key for the service to access the hub. - """ - key = generate_api_key() - service_model.api_key = key + """ + Generates and returns a new API key for the service to access the hub. + """ + key = generate_api_key() + service_model.api_key = key - db.flush() - response = ServiceWithKeySchema(**service_model.__dict__) - db.commit() - return {"service": response} + db.flush() + response = ServiceWithKeySchema(**service_model.__dict__) + db.commit() + return {"service": response} @router.delete( - "", - summary="Remove a service.", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - status.HTTP_204_NO_CONTENT: {"description": "Successfully removed service from db"}, - status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, - }, + "", + summary="Remove a service.", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + status.HTTP_204_NO_CONTENT: {"description": "Successfully removed service from db"}, + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + }, ) async def remove_service( - db: DbSession, - service_model: service_model_query_dependency, - su: super_admin_dependency, + db: DbSession, + service_model: service_model_query_dependency, + su: super_admin_dependency, ): - """ - Removes a service from the hub. - """ - db.delete(service_model) - db.commit() + """ + Removes a service from the hub. + """ + db.delete(service_model) + db.commit() @router.post( - path="/permissions", - summary="Service endpoint for creating its own permissions.", - status_code=status.HTTP_200_OK, - response_model=ServicePostPermissionsResponse, - responses={ - status.HTTP_401_UNAUTHORIZED: { - "description": "API Key missing or invalid | Issue verifying user OIDC claims" - }, - }, + path="/permissions", + summary="Service endpoint for creating its own permissions.", + status_code=status.HTTP_200_OK, + response_model=ServicePostPermissionsResponse, + responses={ + status.HTTP_401_UNAUTHORIZED: { + "description": "API Key missing or invalid | Issue verifying user OIDC claims" + }, + }, ) async def service_create_new_permissions( - db: DbSession, - request_model: ServicePostPermissionsRequest, - valid_key: service_key_dependency, + db: DbSession, + request_model: ServicePostPermissionsRequest, + valid_key: service_key_dependency, ): - """ - Allows a service to register its own set of permissions. - """ - service_model = ( - db.query(Service).filter(Service.name == request_model.rn.service).first() - ) - if service_model is None: - raise ServiceNotFoundException() - else: - service_id = service_model.id - response_list = [] - for new_permission in request_model.permissions: - perm_model = ( - db.query(Perm) - .filter(Perm.service_id == service_id) - .filter(Perm.resource == new_permission.resource) - .filter(Perm.action == new_permission.action) - .first() - ) - if perm_model is not None: - response_code = 409 - response = { - "id": perm_model.id, - "service_name": perm_model.service_name, - "resource": perm_model.resource, - "action": perm_model.action, - } - response_list.append((response, response_code)) - continue + """ + Allows a service to register its own set of permissions. + """ + service_model = ( + db.query(Service).filter(Service.name == request_model.rn.service).first() + ) + if service_model is None: + raise ServiceNotFoundException() + else: + service_id = service_model.id + response_list = [] + for new_permission in request_model.permissions: + perm_model = ( + db.query(Perm) + .filter(Perm.service_id == service_id) + .filter(Perm.resource == new_permission.resource) + .filter(Perm.action == new_permission.action) + .first() + ) + if perm_model is not None: + response_code = 409 + response = { + "id": perm_model.id, + "service_name": perm_model.service_name, + "resource": perm_model.resource, + "action": perm_model.action, + } + response_list.append((response, response_code)) + continue - new_perm_model = Perm(**new_permission.__dict__) - new_perm_model.service_id = service_id - db.add(new_perm_model) - db.flush() - response_code = 201 - response = { - "id": new_perm_model.id, - "service_name": new_perm_model.service_name, - "resource": new_perm_model.resource, - "action": new_perm_model.action, - } - response_list.append((response, response_code)) + new_perm_model = Perm(**new_permission.__dict__) + new_perm_model.service_id = service_id + db.add(new_perm_model) + db.flush() + response_code = 201 + response = { + "id": new_perm_model.id, + "service_name": new_perm_model.service_name, + "resource": new_perm_model.resource, + "action": new_perm_model.action, + } + response_list.append((response, response_code)) - db.commit() - return {"permissions": response_list} + db.commit() + return {"permissions": response_list} diff --git a/src/service/schemas.py b/src/service/schemas.py index 544dac3..e618d6d 100644 --- a/src/service/schemas.py +++ b/src/service/schemas.py @@ -10,10 +10,10 @@ from typing import Generic, TypeVar from pydantic import Field, ConfigDict from src.schemas import ( - CustomBaseModel, - ServiceIDMixin, - ServiceSummary, - ServiceNameMixin, + CustomBaseModel, + ServiceIDMixin, + ServiceSummary, + ServiceNameMixin, ) @@ -21,51 +21,51 @@ T = TypeVar("T", bound=ServiceNameMixin) class HasServiceName(CustomBaseModel, Generic[T]): - rn: T + rn: T class PermissionResponseSchema(CustomBaseModel): - model_config = ConfigDict(from_attributes=True, extra="ignore") + model_config = ConfigDict(from_attributes=True, extra="ignore") - id: int - service_name: str - resource: str - action: str + id: int + service_name: str + resource: str + action: str class PermissionRequestSchema(CustomBaseModel): - resource: str - action: str + resource: str + action: str class ServiceWithKeySchema(ServiceSummary): - api_key: str + api_key: str class ServiceGetServiceResponse(CustomBaseModel): - services: list[ServiceSummary] + services: list[ServiceSummary] class ServicePostServiceRequest(CustomBaseModel): - name: str = Field(min_length=3) + name: str = Field(min_length=3) class ServicePostServiceResponse(CustomBaseModel): - service: ServiceWithKeySchema + service: ServiceWithKeySchema class ServicePatchKeyRequest(ServiceIDMixin): - pass + pass class ServicePatchKeyResponse(CustomBaseModel): - service: ServiceWithKeySchema + service: ServiceWithKeySchema class ServicePostPermissionsRequest(CustomBaseModel): - rn: ServiceNameMixin - permissions: list[PermissionRequestSchema] + rn: ServiceNameMixin + permissions: list[PermissionRequestSchema] class ServicePostPermissionsResponse(CustomBaseModel): - permissions: list[tuple[PermissionResponseSchema, int]] + permissions: list[tuple[PermissionResponseSchema, int]] diff --git a/src/service/utils.py b/src/service/utils.py index 79bb91f..27cfc27 100644 --- a/src/service/utils.py +++ b/src/service/utils.py @@ -9,4 +9,4 @@ import uuid def generate_api_key() -> str: - return str(uuid.uuid4()) + return str(uuid.uuid4()) diff --git a/src/user/dependencies.py b/src/user/dependencies.py index de23693..7518c16 100644 --- a/src/user/dependencies.py +++ b/src/user/dependencies.py @@ -19,37 +19,37 @@ from src.schemas import UserIDMixin async def get_user_model_claims(claims: claims_dependency, db: DbSession): - user_id = claims.get("db_id", None) - if user_id is None: - raise UserNotFoundException() + user_id = claims.get("db_id", None) + if user_id is None: + raise UserNotFoundException() - user_model = db.get(User, user_id) - if user_model is None: - raise UserNotFoundException(user_id=user_id) + user_model = db.get(User, user_id) + if user_model is None: + raise UserNotFoundException(user_id=user_id) - return user_model + return user_model user_model_claims_dependency = Annotated[User, Depends(get_user_model_claims)] async def get_user_model_query(db: DbSession, user_id: Annotated[int, Query(gt=0)]): - user_model = db.get(User, user_id) - if user_model is None: - raise UserNotFoundException(user_id=user_id) + user_model = db.get(User, user_id) + if user_model is None: + raise UserNotFoundException(user_id=user_id) - return user_model + return user_model user_model_query_dependency = Annotated[User, Depends(get_user_model_query)] async def get_user_model_body(db: DbSession, request_model: UserIDMixin): - user_model = db.get(User, request_model.user_id) - if user_model is None: - raise UserNotFoundException(user_id=request_model.user_id) + user_model = db.get(User, request_model.user_id) + if user_model is None: + raise UserNotFoundException(user_id=request_model.user_id) - return user_model + return user_model user_model_body_dependency = Annotated[User, Depends(get_user_model_body)] diff --git a/src/user/exceptions.py b/src/user/exceptions.py index fa4db2f..9b0509f 100644 --- a/src/user/exceptions.py +++ b/src/user/exceptions.py @@ -11,13 +11,13 @@ from fastapi import HTTPException, status class UserNotFoundException(HTTPException): - def __init__(self, user_id: Optional[int] = None) -> None: - detail = ( - "User not found" - if user_id is None - else f"User with ID '{user_id}' was not found." - ) - super().__init__( - status_code=status.HTTP_404_NOT_FOUND, - detail=detail, - ) + def __init__(self, user_id: Optional[int] = None) -> None: + detail = ( + "User not found" + if user_id is None + else f"User with ID '{user_id}' was not found." + ) + super().__init__( + status_code=status.HTTP_404_NOT_FOUND, + detail=detail, + ) diff --git a/src/user/models.py b/src/user/models.py index 4803d54..333e78a 100644 --- a/src/user/models.py +++ b/src/user/models.py @@ -20,26 +20,26 @@ from src.models import CustomBase class User(CustomBase, IdMixin, ActivatedMixin, TimestampMixin, DeletedTimestampMixin): - __tablename__ = "user" + __tablename__ = "user" - email: Mapped[str] - first_name: Mapped[str] - last_name: Mapped[str] - oidc_id: Mapped[str] = mapped_column(index=True, unique=True) + email: Mapped[str] + first_name: Mapped[str] + last_name: Mapped[str] + oidc_id: Mapped[str] = mapped_column(index=True, unique=True) - organisation_rel = relationship( - "Organisation", secondary="orgusers", back_populates="user_rel" - ) + organisation_rel = relationship( + "Organisation", secondary="orgusers", back_populates="user_rel" + ) - group_rel = relationship("Group", secondary="user_groups", back_populates="user_rel") + group_rel = relationship("Group", secondary="user_groups", back_populates="user_rel") - @property - def organisations(self): - return [{"name": org.name, "id": org.id} for org in self.organisation_rel] + @property + def organisations(self): + return [{"name": org.name, "id": org.id} for org in self.organisation_rel] - @property - def groups(self): - result = defaultdict(list) - for group in self.group_rel: - result[group.org_rel.name].append({"name": group.name, "id": group.id}) - return dict(result) + @property + def groups(self): + result = defaultdict(list) + for group in self.group_rel: + result[group.org_rel.name].append({"name": group.name, "id": group.id}) + return dict(result) diff --git a/src/user/router.py b/src/user/router.py index ad0ccb1..6c87e24 100644 --- a/src/user/router.py +++ b/src/user/router.py @@ -13,205 +13,205 @@ from fastapi import APIRouter, status, BackgroundTasks from src.iam.models import Group from src.organisation.exceptions import OrgNotFoundException from src.user.schemas import ( - UserResponse, - OIDCClaims, - UserPostInvitationRequest, - UserPostInvitationAcceptRequest, - UserGetSelfOrgsResponse, - UserPostInvitationResponse, - UserPostInvitationAcceptResponse, + UserResponse, + OIDCClaims, + UserPostInvitationRequest, + UserPostInvitationAcceptRequest, + UserGetSelfOrgsResponse, + UserPostInvitationResponse, + UserPostInvitationAcceptResponse, ) from src.user.dependencies import ( - user_model_claims_dependency, - user_model_query_dependency, + user_model_claims_dependency, + user_model_query_dependency, ) from src.user.service import send_invitation from src.organisation.models import Organisation as Org from src.auth.dependencies import ( - super_admin_dependency, - org_model_root_claim_body_dependency, + super_admin_dependency, + org_model_root_claim_body_dependency, ) from src.auth.service import claims_dependency from src.database import DbSession from src.utils import verify_email_token router = APIRouter( - prefix="/user", - tags=["User"], + prefix="/user", + tags=["User"], ) @router.get( - "/self/claims", - summary="Get current user OIDC claims.", - response_model=OIDCClaims, - status_code=status.HTTP_200_OK, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, - }, + "/self/claims", + summary="Get current user OIDC claims.", + response_model=OIDCClaims, + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, + }, ) async def current_user_claims(user: claims_dependency): - """ - Returns the full OIDC claims associated with the currently logged-in user. - """ - user["allowed_origins"] = user.get("allowed-origins", []) - return user + """ + Returns the full OIDC claims associated with the currently logged-in user. + """ + user["allowed_origins"] = user.get("allowed-origins", []) + return user @router.get( - "/self/db", - summary="Get current user hub details.", - response_model=UserResponse, - status_code=status.HTTP_200_OK, - responses={ - status.HTTP_404_NOT_FOUND: {"description": "User not found"}, - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, - }, + "/self/db", + summary="Get current user hub details.", + response_model=UserResponse, + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_404_NOT_FOUND: {"description": "User not found"}, + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, + }, ) async def current_user(user_model: user_model_claims_dependency): - """ - Returns the database details associated with the currently logged-in user. - """ - return user_model + """ + Returns the database details associated with the currently logged-in user. + """ + return user_model @router.get( - "", - summary="Get user hub details by ID.", - response_model=UserResponse, - status_code=status.HTTP_200_OK, - responses={ - status.HTTP_404_NOT_FOUND: {"description": "User not found"}, - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, - }, + "", + summary="Get user hub details by ID.", + response_model=UserResponse, + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_404_NOT_FOUND: {"description": "User not found"}, + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, + }, ) async def get_user_by_id( - user_model: user_model_query_dependency, su: super_admin_dependency + user_model: user_model_query_dependency, su: super_admin_dependency ): - """ - Returns the database details associated with the provided user ID. - """ - return user_model + """ + Returns the database details associated with the provided user ID. + """ + return user_model @router.delete( - "", - summary="Delete user from hub by ID.", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - status.HTTP_204_NO_CONTENT: {"description": "User deleted"}, - status.HTTP_404_NOT_FOUND: {"description": "User not found"}, - }, + "", + summary="Delete user from hub by ID.", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + status.HTTP_204_NO_CONTENT: {"description": "User deleted"}, + status.HTTP_404_NOT_FOUND: {"description": "User not found"}, + }, ) async def delete_user_by_id( - db: DbSession, - user_model: user_model_query_dependency, - su: super_admin_dependency, + db: DbSession, + user_model: user_model_query_dependency, + su: super_admin_dependency, ): - """ - Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login. - """ - db.delete(user_model) - db.commit() + """ + Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login. + """ + db.delete(user_model) + db.commit() @router.get( - "/self/orgs", - summary="Get all orgs the current user is a member of", - status_code=status.HTTP_200_OK, - response_model=UserGetSelfOrgsResponse, - responses={}, + "/self/orgs", + summary="Get all orgs the current user is a member of", + status_code=status.HTTP_200_OK, + response_model=UserGetSelfOrgsResponse, + responses={}, ) async def get_user_orgs(user_model: user_model_claims_dependency): - user_orgs = user_model.organisation_rel - response = [] - for org in user_orgs: - response.append( - { - "organisation_id": org.id, - "name": org.name, - "status": org.status, - "intake_questionnaire": org.intake_questionnaire, - "root_user_email": org.root_user_email, - "billing_contact": { - "id": org.billing_contact_id, - "email": org.billing_contact_rel.email, - }, - "owner_contact": { - "id": org.owner_contact_id, - "email": org.owner_contact_rel.email, - }, - "security_contact": { - "id": org.security_contact_id, - "email": org.security_contact_rel.email, - }, - } - ) + user_orgs = user_model.organisation_rel + response = [] + for org in user_orgs: + response.append( + { + "organisation_id": org.id, + "name": org.name, + "status": org.status, + "intake_questionnaire": org.intake_questionnaire, + "root_user_email": org.root_user_email, + "billing_contact": { + "id": org.billing_contact_id, + "email": org.billing_contact_rel.email, + }, + "owner_contact": { + "id": org.owner_contact_id, + "email": org.owner_contact_rel.email, + }, + "security_contact": { + "id": org.security_contact_id, + "email": org.security_contact_rel.email, + }, + } + ) - return {"organisations": response} + return {"organisations": response} @router.post( - "/invitation", - summary="Send an email invitation for a user to join an org", - status_code=status.HTTP_200_OK, - response_model=UserPostInvitationResponse, + "/invitation", + summary="Send an email invitation for a user to join an org", + status_code=status.HTTP_200_OK, + response_model=UserPostInvitationResponse, ) async def invitation( - background_tasks: BackgroundTasks, - org_model: org_model_root_claim_body_dependency, - request_model: UserPostInvitationRequest, + background_tasks: BackgroundTasks, + org_model: org_model_root_claim_body_dependency, + request_model: UserPostInvitationRequest, ): - org_id = org_model.id - org_name = org_model.name - user_email = request_model.user_email + org_id = org_model.id + org_name = org_model.name + user_email = request_model.user_email - background_tasks.add_task( - send_invitation, org_id=org_id, org_name=org_name, user_email=user_email - ) + background_tasks.add_task( + send_invitation, org_id=org_id, org_name=org_name, user_email=user_email + ) - response = { - "organisation": org_model, - "invited_email": user_email, - } + response = { + "organisation": org_model, + "invited_email": user_email, + } - return response + return response @router.post( - "/invitation/accept", - summary="Accept email invitation to join an org", - status_code=status.HTTP_200_OK, - response_model=UserPostInvitationAcceptResponse, + "/invitation/accept", + summary="Accept email invitation to join an org", + status_code=status.HTTP_200_OK, + response_model=UserPostInvitationAcceptResponse, ) async def accept_invitation( - db: DbSession, - user_model: user_model_claims_dependency, - request_model: UserPostInvitationAcceptRequest, + db: DbSession, + user_model: user_model_claims_dependency, + request_model: UserPostInvitationAcceptRequest, ): - email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model) + email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model) - org_model = db.get(Org, email_claims["org_id"]) - if org_model is None: - raise OrgNotFoundException() + org_model = db.get(Org, email_claims["org_id"]) + if org_model is None: + raise OrgNotFoundException() - org_model.user_rel.append(user_model) - db.flush() - group_model = ( - db.query(Group) - .filter(Group.org_id == org_model.id) - .filter(Group.name == "Default Users") - .first() - ) - if group_model is not None: - user_model.group_rel.append(group_model) + org_model.user_rel.append(user_model) + db.flush() + group_model = ( + db.query(Group) + .filter(Group.org_id == org_model.id) + .filter(Group.name == "Default Users") + .first() + ) + if group_model is not None: + user_model.group_rel.append(group_model) - response = { - "organisation": org_model, - "user": user_model, - } + response = { + "organisation": org_model, + "user": user_model, + } - db.commit() + db.commit() - return response + return response diff --git a/src/user/schemas.py b/src/user/schemas.py index 65ab530..666f91f 100644 --- a/src/user/schemas.py +++ b/src/user/schemas.py @@ -10,63 +10,63 @@ from src.schemas import CustomBaseModel, OrgIDMixin, OrgSummary, UserSummary class OIDCClaims(CustomBaseModel): - exp: int - iat: int - auth_time: int - jti: str - iss: str - aud: str - sub: str - typ: str - azp: str - sid: str - acr: str - allowed_origins: list[str] - realm_access: dict[str, list[str]] - resource_access: dict[str, dict[str, list[str]]] - scope: str - email_verified: bool - name: str - preferred_username: str - given_name: str - family_name: str - email: str - db_id: int + exp: int + iat: int + auth_time: int + jti: str + iss: str + aud: str + sub: str + typ: str + azp: str + sid: str + acr: str + allowed_origins: list[str] + realm_access: dict[str, list[str]] + resource_access: dict[str, dict[str, list[str]]] + scope: str + email_verified: bool + name: str + preferred_username: str + given_name: str + family_name: str + email: str + db_id: int class OIDCUser(CustomBaseModel): - first_name: str - last_name: str - email: str - oidc_id: str + first_name: str + last_name: str + email: str + oidc_id: str class UserResponse(CustomBaseModel): - id: int - first_name: str - last_name: str - email: str - organisations: list[Optional[dict[str, str | int]]] - groups: Optional[dict[str, list[dict[str, str | int]]]] = None + id: int + first_name: str + last_name: str + email: str + organisations: list[Optional[dict[str, str | int]]] + groups: Optional[dict[str, list[dict[str, str | int]]]] = None class UserPostInvitationRequest(OrgIDMixin): - user_email: EmailStr + user_email: EmailStr class UserPostInvitationAcceptRequest(CustomBaseModel): - jwt: str + jwt: str class UserGetSelfOrgsResponse(CustomBaseModel): - organisations: list[OrgSchema] + organisations: list[OrgSchema] class UserPostInvitationResponse(CustomBaseModel): - organisation: OrgSummary - invited_email: EmailStr + organisation: OrgSummary + invited_email: EmailStr class UserPostInvitationAcceptResponse(CustomBaseModel): - organisation: OrgSummary - user: UserSummary + organisation: OrgSummary + user: UserSummary diff --git a/src/user/service.py b/src/user/service.py index 8721b88..c450cbc 100644 --- a/src/user/service.py +++ b/src/user/service.py @@ -16,49 +16,49 @@ from src.user.models import User async def add_user(db: Session, user_claims: dict[str, Any]) -> int: - try: - valid_user = OIDCUser( - first_name=user_claims["given_name"], - last_name=user_claims["family_name"], - email=user_claims["email"], - oidc_id=user_claims["sub"], - ) - except Exception as e: - logging.exception(e) - raise UnprocessableContentException("Invalid or missing OIDC data") + try: + valid_user = OIDCUser( + first_name=user_claims["given_name"], + last_name=user_claims["family_name"], + email=user_claims["email"], + oidc_id=user_claims["sub"], + ) + except Exception as e: + logging.exception(e) + raise UnprocessableContentException("Invalid or missing OIDC data") - db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first() + db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first() - if not db_user: - user_model = User(**valid_user.model_dump()) - db.add(user_model) - user_id = user_model.id - db.commit() - return user_id + if not db_user: + user_model = User(**valid_user.model_dump()) + db.add(user_model) + user_id = user_model.id + db.commit() + return user_id - user_id = db_user.id - db_user.first_name = valid_user.first_name - db_user.last_name = valid_user.last_name - db.commit() - return user_id + user_id = db_user.id + db_user.first_name = valid_user.first_name + db_user.last_name = valid_user.last_name + db.commit() + return user_id async def send_invitation(user_email: str, org_name: str, org_id: int): - expiry_delta = timedelta(hours=24) - expiry = datetime.now(timezone.utc) + expiry_delta - claims = { - "email": user_email, - "org_id": org_id, - "exp": expiry, - "type": "org_invite", - } + expiry_delta = timedelta(hours=24) + expiry = datetime.now(timezone.utc) + expiry_delta + claims = { + "email": user_email, + "org_id": org_id, + "exp": expiry, + "type": "org_invite", + } - token = await generate_jwt(claims) - subject = f"You have been invited to join {org_name}" - body = f"You have been invited to join {org_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}" + token = await generate_jwt(claims) + subject = f"You have been invited to join {org_name}" + body = f"You have been invited to join {org_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}" - await send_email( - recipient=user_email, - subject=subject, - body=body, - ) + await send_email( + recipient=user_email, + subject=subject, + body=body, + ) diff --git a/src/utils.py b/src/utils.py index ff8cb89..27f5d90 100644 --- a/src/utils.py +++ b/src/utils.py @@ -11,52 +11,56 @@ KEY = jwk.import_key(settings.SECRET_KEY.get_secret_value(), "oct") async def generate_jwt(claims): - jwt_token = jwt.encode(header={"alg": "HS256"}, key=KEY, claims=claims) + jwt_token = jwt.encode(header={"alg": "HS256"}, key=KEY, claims=claims) - return jwt_token + return jwt_token async def decode_jwt(encoded): - try: - token = jwt.decode(encoded, key=KEY) - return token.claims - except errors.DecodeError: - raise UnauthorizedException("Invalid JWS") + try: + token = jwt.decode(encoded, key=KEY) + return token.claims + except errors.DecodeError: + raise UnauthorizedException("Invalid JWS") async def verify_email_token(user_model, token): - email_claims = await decode_jwt(token) + email_claims = await decode_jwt(token) - claimed_email = email_claims["email"] + claimed_email = email_claims["email"] - expiry = datetime.fromtimestamp(email_claims["exp"], timezone.utc) + expiry = datetime.fromtimestamp(email_claims["exp"], timezone.utc) - if expiry < datetime.now(timezone.utc): - raise UnauthorizedException("Invitation expired.") + if expiry < datetime.now(timezone.utc): + raise UnauthorizedException("Invitation expired.") - if user_model.email != claimed_email: - raise ForbiddenException("The logged in user and email do not match.") + if user_model.email != claimed_email: + raise ForbiddenException("The logged in user and email do not match.") - return email_claims + return email_claims async def send_email(recipient: str, subject: str, body: str): - if settings.ENVIRONMENT.is_testing: - return + if settings.ENVIRONMENT.is_testing: + return - lettermint = Lettermint(api_token=settings.LETTERMINT_API_TOKEN.get_secret_value()) + lettermint = Lettermint(api_token=settings.LETTERMINT_API_TOKEN.get_secret_value()) - if settings.ENVIRONMENT == "local": - recipient = "ok@testing.lettermint.co" + if settings.ENVIRONMENT == "local": + recipient = "ok@testing.lettermint.co" - try: - response = ( - lettermint.email.from_("noreply@sr2.uk") - .to(recipient) - .subject(subject) - .text(body) - .send() - ) - logging.info("Email sent to {} with subject {} (Status: {})".format(recipient, subject, response.status_code)) - except ValidationError as e: - logging.exception(e) + try: + response = ( + lettermint.email.from_("noreply@sr2.uk") + .to(recipient) + .subject(subject) + .text(body) + .send() + ) + logging.info( + "Email sent to {} with subject {} (Status: {})".format( + recipient, subject, response.status_code + ) + ) + except ValidationError as e: + logging.exception(e) diff --git a/test/conftest.py b/test/conftest.py index 6821b9d..022b8dc 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -22,15 +22,15 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) @pytest.fixture() def db_session(): - CustomBase.metadata.drop_all(bind=engine) - CustomBase.metadata.create_all(bind=engine) - db = SessionLocal() - try: - _seed(db) # extracted seeding logic into a plain function - yield db - finally: - db.rollback() - db.close() + CustomBase.metadata.drop_all(bind=engine) + CustomBase.metadata.create_all(bind=engine) + db = SessionLocal() + try: + _seed(db) # extracted seeding logic into a plain function + yield db + finally: + db.rollback() + db.close() @pytest.fixture @@ -83,176 +83,176 @@ async def no_su_client(db_session) -> AsyncGenerator[AsyncClient, None]: def _seed(db): - db.add( - User( - email="admin@test.com", - first_name="Admin", - last_name="Test", - oidc_id="abcd-efgh-ijkl-mnop", - ) - ) - db.add( - User( - email="user@orgone.com", - first_name="User", - last_name="Test", - oidc_id="abcd-efgh-ijkl-qwer", - ) - ) - db.add( - User( - email="root@orgtwo.com", - first_name="Root", - last_name="Test", - oidc_id="abcd-efgh-ijkl-hjkl", - ) - ) - db.add(Contact(org_id=1, email="billing@orgone.com", phonenumber="07521539927")) - db.add(Contact(org_id=1, email="owner@orgone.com", phonenumber="07521539927")) - db.add(Contact(org_id=1, email="security@orgone.com", phonenumber="07521539927")) - db.add(Contact(org_id=2, email="billing@orgtwo.com", phonenumber="07521539927")) - db.add(Contact(org_id=2, email="owner@orgtwo.com", phonenumber="07521539927")) - db.add(Contact(org_id=2, email="security@orgtwo.com", phonenumber="07521539927")) - db.add(Contact(org_id=3, email="billing@orgthree.com", phonenumber="07521539927")) - db.add(Contact(org_id=3, email="owner@orgthree.com", phonenumber="07521539927")) - db.add(Contact(org_id=3, email="security@orgthree.com", phonenumber="07521539927")) - db.flush() - db.add( - Org( - name="Org One", - root_user_id=1, - billing_contact_id=1, - owner_contact_id=2, - security_contact_id=3, - status="approved", - intake_questionnaire={ - "metadata": {"version": 0, "submission_date": None}, - "questions": {"question_two": "answer two"}, - }, - ) - ) - db.add( - Org( - name="Org Two", - root_user_id=3, - billing_contact_id=4, - owner_contact_id=5, - security_contact_id=6, - status="approved", - intake_questionnaire={ - "metadata": {"version": 0, "submission_date": None}, - "questions": {"question_two": "answer two"}, - }, - ) - ) - db.add( - Org( - name="Org Three", - root_user_id=1, - billing_contact_id=7, - owner_contact_id=8, - security_contact_id=9, - status="partial", - intake_questionnaire={ - "metadata": {"version": 0, "submission_date": None}, - "questions": {"question_two": "answer two"}, - }, - ) - ) - db.add(OrgUsers(org_id=1, user_id=2)) - db.add(Service(name="Test Service", api_key="123456789")) - db.add(Permission(service_id=1, resource="test_resource", action="read")) - db.add(Permission(service_id=1, resource="test_resource", action="move")) - db.add(Permission(service_id=1, resource="test_resource", action="delete")) - db.add(OrgPermissions(org_id=1, permission_id=1)) - db.add(OrgPermissions(org_id=1, permission_id=2)) - db.add(Group(name="Org One Group", org_id=1)) - db.add(Group(name="Org Two Group", org_id=2)) - db.add(Group(name="Org One Group Two", org_id=1)) - db.flush() - group_model = db.get(Group, 1) - perm_model = db.get(Permission, 1) - group_model.permission_rel.append(perm_model) - user_model = db.get(User, 1) - org_model = db.get(Org, 1) - org_model.user_rel.append(user_model) - org_model.group_rel.append(group_model) - db.flush() - group_model.user_rel.append(user_model) - db.commit() + db.add( + User( + email="admin@test.com", + first_name="Admin", + last_name="Test", + oidc_id="abcd-efgh-ijkl-mnop", + ) + ) + db.add( + User( + email="user@orgone.com", + first_name="User", + last_name="Test", + oidc_id="abcd-efgh-ijkl-qwer", + ) + ) + db.add( + User( + email="root@orgtwo.com", + first_name="Root", + last_name="Test", + oidc_id="abcd-efgh-ijkl-hjkl", + ) + ) + db.add(Contact(org_id=1, email="billing@orgone.com", phonenumber="07521539927")) + db.add(Contact(org_id=1, email="owner@orgone.com", phonenumber="07521539927")) + db.add(Contact(org_id=1, email="security@orgone.com", phonenumber="07521539927")) + db.add(Contact(org_id=2, email="billing@orgtwo.com", phonenumber="07521539927")) + db.add(Contact(org_id=2, email="owner@orgtwo.com", phonenumber="07521539927")) + db.add(Contact(org_id=2, email="security@orgtwo.com", phonenumber="07521539927")) + db.add(Contact(org_id=3, email="billing@orgthree.com", phonenumber="07521539927")) + db.add(Contact(org_id=3, email="owner@orgthree.com", phonenumber="07521539927")) + db.add(Contact(org_id=3, email="security@orgthree.com", phonenumber="07521539927")) + db.flush() + db.add( + Org( + name="Org One", + root_user_id=1, + billing_contact_id=1, + owner_contact_id=2, + security_contact_id=3, + status="approved", + intake_questionnaire={ + "metadata": {"version": 0, "submission_date": None}, + "questions": {"question_two": "answer two"}, + }, + ) + ) + db.add( + Org( + name="Org Two", + root_user_id=3, + billing_contact_id=4, + owner_contact_id=5, + security_contact_id=6, + status="approved", + intake_questionnaire={ + "metadata": {"version": 0, "submission_date": None}, + "questions": {"question_two": "answer two"}, + }, + ) + ) + db.add( + Org( + name="Org Three", + root_user_id=1, + billing_contact_id=7, + owner_contact_id=8, + security_contact_id=9, + status="partial", + intake_questionnaire={ + "metadata": {"version": 0, "submission_date": None}, + "questions": {"question_two": "answer two"}, + }, + ) + ) + db.add(OrgUsers(org_id=1, user_id=2)) + db.add(Service(name="Test Service", api_key="123456789")) + db.add(Permission(service_id=1, resource="test_resource", action="read")) + db.add(Permission(service_id=1, resource="test_resource", action="move")) + db.add(Permission(service_id=1, resource="test_resource", action="delete")) + db.add(OrgPermissions(org_id=1, permission_id=1)) + db.add(OrgPermissions(org_id=1, permission_id=2)) + db.add(Group(name="Org One Group", org_id=1)) + db.add(Group(name="Org Two Group", org_id=2)) + db.add(Group(name="Org One Group Two", org_id=1)) + db.flush() + group_model = db.get(Group, 1) + perm_model = db.get(Permission, 1) + group_model.permission_rel.append(perm_model) + user_model = db.get(User, 1) + org_model = db.get(Org, 1) + org_model.user_rel.append(user_model) + org_model.group_rel.append(group_model) + db.flush() + group_model.user_rel.append(user_model) + db.commit() def generate_query_and_status(params) -> list[tuple[str, int]]: - possible_values = [0, -1, 42, "banana", ""] + possible_values = [0, -1, 42, "banana", ""] - defaults = [f"{param}=1" for param in params] + defaults = [f"{param}=1" for param in params] - # Missing params - query_list = [ - "&".join(combo) - for r in range(len(defaults) + 1) - for combo in combinations(defaults, r) - ] + # Missing params + query_list = [ + "&".join(combo) + for r in range(len(defaults) + 1) + for combo in combinations(defaults, r) + ] - # Complete query as default for invalid checks - default_query = query_list.pop(-1) + # Complete query as default for invalid checks + default_query = query_list.pop(-1) - # Checks for each param being invalid - for param in params: - for value in possible_values: - new_value = f"&{param}={value}" - query_list.append(default_query.replace(f"{param}=1", new_value)) + # Checks for each param being invalid + for param in params: + for value in possible_values: + new_value = f"&{param}={value}" + query_list.append(default_query.replace(f"{param}=1", new_value)) - query_and_status = [] + query_and_status = [] - # Assign expected status - for query in query_list: - # ID 42 is used to represent a non-existent entry. So it should 404. - status = 404 if "42" in query else 422 - # Remove leading "&" if present - query = query if len(query) > 1 and query[0] != "&" else query[1:] - query_and_status.append((query, status)) + # Assign expected status + for query in query_list: + # ID 42 is used to represent a non-existent entry. So it should 404. + status = 404 if "42" in query else 422 + # Remove leading "&" if present + query = query if len(query) > 1 and query[0] != "&" else query[1:] + query_and_status.append((query, status)) - return query_and_status + return query_and_status def generate_body_and_status(params: dict[str, str]) -> list[tuple[dict, int]]: - possible_values_int = [0, -1, 42, "banana", ""] - possible_values_str = [0, "", "a"] + possible_values_int = [0, -1, 42, "banana", ""] + possible_values_str = [0, "", "a"] - defaults = [{param: 1 for param in params.keys()}] + defaults = [{param: 1 for param in params.keys()}] - # Missing params - body_list = [ - {key: ("valid string" if params[key] == "str" else 1) for key in combo} - for r in range(len(defaults[0].keys()) + 1) - for combo in combinations(defaults[0].keys(), r) - ] + # Missing params + body_list = [ + {key: ("valid string" if params[key] == "str" else 1) for key in combo} + for r in range(len(defaults[0].keys()) + 1) + for combo in combinations(defaults[0].keys(), r) + ] - # Complete body as default for generating invalid checks - default_body = body_list.pop(-1) + # Complete body as default for generating invalid checks + default_body = body_list.pop(-1) - # Generates checks for each param being invalid - for param, typ in params.items(): - if typ == "int": - possible_values = possible_values_int - elif typ == "str": - possible_values = possible_values_str - else: - raise TypeError(f"Unknown type {typ}") - for value in possible_values: - new_record = default_body.copy() - new_record[param] = value - body_list.append(new_record) + # Generates checks for each param being invalid + for param, typ in params.items(): + if typ == "int": + possible_values = possible_values_int + elif typ == "str": + possible_values = possible_values_str + else: + raise TypeError(f"Unknown type {typ}") + for value in possible_values: + new_record = default_body.copy() + new_record[param] = value + body_list.append(new_record) - body_and_status = [] + body_and_status = [] - # Assign expected status - for body in body_list: - # ID 42 is used to represent a non-existent entry. So it should 404. - status = 404 if 42 in body.values() else 422 - body_and_status.append((body, status)) - return body_and_status + # Assign expected status + for body in body_list: + # ID 42 is used to represent a non-existent entry. So it should 404. + status = 404 if 42 in body.values() else 422 + body_and_status.append((body, status)) + return body_and_status def get_testable_routes(): diff --git a/test/test_auth_approval.py b/test/test_auth_approval.py index 42d0cd9..a23e410 100644 --- a/test/test_auth_approval.py +++ b/test/test_auth_approval.py @@ -8,181 +8,181 @@ import pytest from httpx import AsyncClient pytestmark = [ - pytest.mark.auth, - pytest.mark.preapproval, + pytest.mark.auth, + pytest.mark.preapproval, ] @pytest.mark.anyio async def test_get_org_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/org?org_id=3") - assert resp.status_code != 422 - assert resp.status_code == 200 + resp = await no_su_client.get("/org?org_id=3") + assert resp.status_code != 422 + assert resp.status_code == 200 @pytest.mark.anyio async def test_patch_org_questionnaire_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.patch( - "/org/questionnaire", - json={ - "organisation_id": 3, - "intake_questionnaire": { - "question_one": "new answer one", - "question_two": None, - "question_three": None, - }, - "partial": True, - }, - ) - assert resp.status_code != 422 - assert resp.status_code == 200 + resp = await no_su_client.patch( + "/org/questionnaire", + json={ + "organisation_id": 3, + "intake_questionnaire": { + "question_one": "new answer one", + "question_two": None, + "question_three": None, + }, + "partial": True, + }, + ) + assert resp.status_code != 422 + assert resp.status_code == 200 @pytest.mark.anyio async def test_get_org_users_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/org/users?org_id=3") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + resp = await no_su_client.get("/org/users?org_id=3") + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_get_org_groups_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/org/groups?org_id=3") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + resp = await no_su_client.get("/org/groups?org_id=3") + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_get_org_contact_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/org/contact?org_id=3&contact_type=billing") - assert resp.status_code != 422 - assert resp.status_code == 200 + resp = await no_su_client.get("/org/contact?org_id=3&contact_type=billing") + assert resp.status_code != 422 + assert resp.status_code == 200 @pytest.mark.anyio async def test_patch_org_contact_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.patch( - "/org/contact", - json={ - "organisation_id": 3, - "contact_type": "billing", - "email": "user@example.com", - }, - ) - assert resp.status_code != 422 - assert resp.status_code == 200 + resp = await no_su_client.patch( + "/org/contact", + json={ + "organisation_id": 3, + "contact_type": "billing", + "email": "user@example.com", + }, + ) + assert resp.status_code != 422 + assert resp.status_code == 200 @pytest.mark.anyio async def test_get_service_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/service?org_id=3") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + resp = await no_su_client.get("/service?org_id=3") + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_get_iam_group_permissions_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/iam/group/permissions?org_id=3&group_id=1") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + resp = await no_su_client.get("/iam/group/permissions?org_id=3&group_id=1") + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_get_iam_group_users_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/iam/group/users?org_id=3&group_id=1") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + resp = await no_su_client.get("/iam/group/users?org_id=3&group_id=1") + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_post_iam_group_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.post( - "/iam/group", json={"name": "New Group", "organisation_id": 3} - ) - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + resp = await no_su_client.post( + "/iam/group", json={"name": "New Group", "organisation_id": 3} + ) + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_put_iam_group_permission_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.put( - "/iam/group/permission", - json={"permission_id": 1, "group_id": 2, "organisation_id": 3}, - ) - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + resp = await no_su_client.put( + "/iam/group/permission", + json={"permission_id": 1, "group_id": 2, "organisation_id": 3}, + ) + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_put_iam_group_user_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.put( - "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 3} - ) - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + resp = await no_su_client.put( + "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 3} + ) + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_get_iam_permissions_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/iam/permissions?org_id=3") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + resp = await no_su_client.get("/iam/permissions?org_id=3") + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_post_iam_permissions_search_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.post( - "/iam/permissions/search", json={"organisation_id": 3, "action": "read"} - ) - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + resp = await no_su_client.post( + "/iam/permissions/search", json={"organisation_id": 3, "action": "read"} + ) + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_delete_org_user_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.delete("/org/user?org_id=3&user_id=1") + resp = await no_su_client.delete("/org/user?org_id=3&user_id=1") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_delete_preapproval_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.delete("/org/self?org_id=3") + resp = await no_su_client.delete("/org/self?org_id=3") - assert resp.status_code != 422 - assert resp.status_code == 204 + assert resp.status_code != 422 + assert resp.status_code == 204 @pytest.mark.anyio async def test_post_user_invitation_auth_approval(no_su_client: AsyncClient): - body = {"user_email": "admin@test.com", "organisation_id": 3} - resp = await no_su_client.post("/user/invitation", json=body) + body = {"user_email": "admin@test.com", "organisation_id": 3} + resp = await no_su_client.post("/user/invitation", json=body) - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_delete_group_permissions_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.delete("/iam/group/permission?org_id=3&group_id=1&perm_id=1") + resp = await no_su_client.delete("/iam/group/permission?org_id=3&group_id=1&perm_id=1") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_delete_group_users_success(no_su_client: AsyncClient): - resp = await no_su_client.delete("/iam/group/user?org_id=3&group_id=1&user_id=1") + resp = await no_su_client.delete("/iam/group/user?org_id=3&group_id=1&user_id=1") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_put_group_user_invitation_success(no_su_client: AsyncClient): - body = {"user_email": "admin@test.com", "organisation_id": 3, "group_id": 1} - resp = await no_su_client.put("/iam/group/user/invitation", json=body) + body = {"user_email": "admin@test.com", "organisation_id": 3, "group_id": 1} + resp = await no_su_client.put("/iam/group/user/invitation", json=body) - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] diff --git a/test/test_auth_general.py b/test/test_auth_general.py index 0599c33..543af7a 100644 --- a/test/test_auth_general.py +++ b/test/test_auth_general.py @@ -5,14 +5,14 @@ from httpx import AsyncClient pytestmark = [ - pytest.mark.auth, + pytest.mark.auth, ] @pytest.mark.anyio async def test_get_org_auth_root_su(default_client: AsyncClient): - # If a super admin can access a resource when not the root user - resp = await default_client.get("/org?org_id=2") - assert resp.status_code != 422 - assert resp.status_code == 200 - assert resp.json()["organisations"][0]["name"] == "Org Two" + # If a super admin can access a resource when not the root user + resp = await default_client.get("/org?org_id=2") + assert resp.status_code != 422 + assert resp.status_code == 200 + assert resp.json()["organisations"][0]["name"] == "Org Two" diff --git a/test/test_auth_root.py b/test/test_auth_root.py index d28d6b5..429bf4b 100644 --- a/test/test_auth_root.py +++ b/test/test_auth_root.py @@ -7,147 +7,147 @@ import pytest from httpx import AsyncClient pytestmark = [ - pytest.mark.auth, - pytest.mark.root_user, + pytest.mark.auth, + pytest.mark.root_user, ] @pytest.mark.anyio async def test_get_org_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.get("/org?org_id=2") - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.get("/org?org_id=2") + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_patch_org_questionnaire_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.patch( - "/org/questionnaire", - json={ - "organisation_id": 2, - "intake_questionnaire": { - "question_one": "new answer one", - "question_two": None, - "question_three": None, - }, - "partial": True, - }, - ) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.patch( + "/org/questionnaire", + json={ + "organisation_id": 2, + "intake_questionnaire": { + "question_one": "new answer one", + "question_two": None, + "question_three": None, + }, + "partial": True, + }, + ) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_get_org_users_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.get("/org/users?org_id=2") - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.get("/org/users?org_id=2") + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_get_org_groups_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.get("/org/groups?org_id=2") - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.get("/org/groups?org_id=2") + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_get_org_contact_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.get("/org/contact?org_id=2&contact_type=billing") - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.get("/org/contact?org_id=2&contact_type=billing") + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_patch_org_contact_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.patch( - "/org/contact", - json={ - "organisation_id": 2, - "contact_type": "billing", - "email": "user@example.com", - }, - ) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.patch( + "/org/contact", + json={ + "organisation_id": 2, + "contact_type": "billing", + "email": "user@example.com", + }, + ) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_get_service_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.get("/service?org_id=2") - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.get("/service?org_id=2") + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_get_iam_group_permissions_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.get("/iam/group/permissions?org_id=2&group_id=1") - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.get("/iam/group/permissions?org_id=2&group_id=1") + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_get_iam_group_users_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.get("/iam/group/users?org_id=2&group_id=1") - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.get("/iam/group/users?org_id=2&group_id=1") + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_post_iam_group_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.post( - "/iam/group", json={"name": "New Group", "organisation_id": 2} - ) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.post( + "/iam/group", json={"name": "New Group", "organisation_id": 2} + ) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_put_iam_group_permission_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.put( - "/iam/group/permission", - json={"permission_id": 1, "group_id": 2, "organisation_id": 2}, - ) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.put( + "/iam/group/permission", + json={"permission_id": 1, "group_id": 2, "organisation_id": 2}, + ) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_put_iam_group_user_auth_root( - no_su_client: AsyncClient, + no_su_client: AsyncClient, ): - resp = await no_su_client.put( - "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2} - ) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.put( + "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2} + ) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_get_iam_permissions_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.get("/iam/permissions?org_id=2") - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.get("/iam/permissions?org_id=2") + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_post_iam_permissions_search_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.post( - "/iam/permissions/search", json={"organisation_id": 2, "action": "read"} - ) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.post( + "/iam/permissions/search", json={"organisation_id": 2, "action": "read"} + ) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] diff --git a/test/test_auth_su.py b/test/test_auth_su.py index 09ce558..f0136bf 100644 --- a/test/test_auth_su.py +++ b/test/test_auth_su.py @@ -7,69 +7,69 @@ import pytest from httpx import AsyncClient pytestmark = [ - pytest.mark.auth, - pytest.mark.super_admin, + pytest.mark.auth, + pytest.mark.super_admin, ] @pytest.mark.anyio async def test_get_user_auth_su(no_su_client: AsyncClient): - resp = await no_su_client.get("/user?user_id=1") - assert resp.status_code != 422 - assert resp.status_code == 403 - assert resp.json()["detail"] == "Must be super admin" + resp = await no_su_client.get("/user?user_id=1") + assert resp.status_code != 422 + assert resp.status_code == 403 + assert resp.json()["detail"] == "Must be super admin" @pytest.mark.anyio async def test_patch_org_status_auth_su(no_su_client: AsyncClient): - resp = await no_su_client.patch( - "/org/status", json={"organisation_id": 1, "status": "submitted"} - ) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert resp.json()["detail"] == "Must be super admin" + resp = await no_su_client.patch( + "/org/status", json={"organisation_id": 1, "status": "submitted"} + ) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert resp.json()["detail"] == "Must be super admin" @pytest.mark.anyio async def test_patch_org_root_user_auth_su(no_su_client: AsyncClient): - resp = await no_su_client.patch( - "/org/root_user", json={"organisation_id": 1, "user_id": 2} - ) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert resp.json()["detail"] == "Must be super admin" + resp = await no_su_client.patch( + "/org/root_user", json={"organisation_id": 1, "user_id": 2} + ) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert resp.json()["detail"] == "Must be super admin" @pytest.mark.anyio async def test_patch_service_key_auth_su(no_su_client: AsyncClient): - resp = await no_su_client.patch("/service/key", json={"service_id": 1}) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert resp.json()["detail"] == "Must be super admin" + resp = await no_su_client.patch("/service/key", json={"service_id": 1}) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert resp.json()["detail"] == "Must be super admin" @pytest.mark.anyio async def test_post_service_auth_su(no_su_client: AsyncClient): - resp = await no_su_client.post("/service", json={"name": "New Test Service"}) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert resp.json()["detail"] == "Must be super admin" + resp = await no_su_client.post("/service", json={"name": "New Test Service"}) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert resp.json()["detail"] == "Must be super admin" @pytest.mark.anyio async def test_post_perm_auth_su(no_su_client: AsyncClient): - resp = await no_su_client.post( - "/iam/permission", - json={"service_id": 1, "resource": "test_resource", "action": "create"}, - ) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert resp.json()["detail"] == "Must be super admin" + resp = await no_su_client.post( + "/iam/permission", + json={"service_id": 1, "resource": "test_resource", "action": "create"}, + ) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert resp.json()["detail"] == "Must be super admin" @pytest.mark.anyio async def test_post_org_user_auth_su(no_su_client: AsyncClient): - resp = await no_su_client.post("/org/user", json={"organisation_id": 1, "user_id": 2}) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be super admin" in resp.json()["detail"] + resp = await no_su_client.post("/org/user", json={"organisation_id": 1, "user_id": 2}) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be super admin" in resp.json()["detail"] diff --git a/test/test_auth_user.py b/test/test_auth_user.py index 99f5579..3d2f6ce 100644 --- a/test/test_auth_user.py +++ b/test/test_auth_user.py @@ -7,22 +7,22 @@ from httpx import AsyncClient pytestmark = [ - pytest.mark.auth, - pytest.mark.user, + pytest.mark.auth, + pytest.mark.user, ] @pytest.mark.anyio async def test_get_self_db_auth_user(no_user_client: AsyncClient): - resp = await no_user_client.get("/user/self/db") - assert resp.status_code != 422 - assert resp.status_code == 401 - assert resp.json()["detail"] == "Not authenticated" + resp = await no_user_client.get("/user/self/db") + assert resp.status_code != 422 + assert resp.status_code == 401 + assert resp.json()["detail"] == "Not authenticated" @pytest.mark.anyio async def test_post_org_success_auth_user(no_user_client: AsyncClient): - resp = await no_user_client.post("/org", json={"name": "New Test Org"}) - assert resp.status_code != 422 - assert resp.status_code == 401 - assert resp.json()["detail"] == "Not authenticated" + resp = await no_user_client.post("/org", json={"name": "New Test Org"}) + assert resp.status_code != 422 + assert resp.status_code == 401 + assert resp.json()["detail"] == "Not authenticated" diff --git a/test/test_healthcheck.py b/test/test_healthcheck.py index 47a3993..6fdb9be 100644 --- a/test/test_healthcheck.py +++ b/test/test_healthcheck.py @@ -4,7 +4,7 @@ from httpx import AsyncClient @pytest.mark.anyio async def test_healthcheck(default_client: AsyncClient): - resp = await default_client.get("/healthcheck") + resp = await default_client.get("/healthcheck") - assert resp.status_code == 200 - assert resp.json() == {"status": "ok"} + assert resp.status_code == 200 + assert resp.json() == {"status": "ok"} diff --git a/test/test_iam.py b/test/test_iam.py index ad45543..c890925 100644 --- a/test/test_iam.py +++ b/test/test_iam.py @@ -6,747 +6,747 @@ from httpx import AsyncClient from .conftest import generate_query_and_status, generate_body_and_status pytestmark = [ - pytest.mark.iam_module, + pytest.mark.iam_module, ] @pytest.mark.anyio async def test_post_act_on_resource_endpoint_success(default_client: AsyncClient): - body = { - "rn": { - "service": "Test Service", - "organisation_id": 1, - "resource": "test_resource", - "instance": None, - }, - "action": "read", - } - headers = { - "Authorization": "Bearer not_checked_when_auth_is_disabled", - "X-API-Key": "123456789", - } - resp = await default_client.post("/iam/can_act_on_resource", json=body, headers=headers) - data = resp.json() + body = { + "rn": { + "service": "Test Service", + "organisation_id": 1, + "resource": "test_resource", + "instance": None, + }, + "action": "read", + } + headers = { + "Authorization": "Bearer not_checked_when_auth_is_disabled", + "X-API-Key": "123456789", + } + resp = await default_client.post("/iam/can_act_on_resource", json=body, headers=headers) + data = resp.json() - assert resp.status_code == 200 - assert data["allowed"] is True + assert resp.status_code == 200 + assert data["allowed"] is True - print(data) + print(data) @pytest.mark.parametrize( - "service, api_key", - [("Test Service", "not_the_correct_key"), ("Test Service Two", "123456789")], + "service, api_key", + [("Test Service", "not_the_correct_key"), ("Test Service Two", "123456789")], ) @pytest.mark.anyio async def test_act_on_resource_wrong_key( - default_client: AsyncClient, service: str, api_key: str + default_client: AsyncClient, service: str, api_key: str ): - body = { - "rn": { - "service": service, - "organisation": "Test Org", - "resource": "test_resource", - }, - "action": "read", - } - headers = { - "Authorization": "Bearer not_checked_when_auth_is_disabled", - "X-API-Key": api_key, - } - resp = await default_client.post("/iam/can_act_on_resource", json=body, headers=headers) - data = resp.json() + body = { + "rn": { + "service": service, + "organisation": "Test Org", + "resource": "test_resource", + }, + "action": "read", + } + headers = { + "Authorization": "Bearer not_checked_when_auth_is_disabled", + "X-API-Key": api_key, + } + resp = await default_client.post("/iam/can_act_on_resource", json=body, headers=headers) + data = resp.json() - assert resp.status_code == 401 - assert data["detail"] == "Invalid API key" + assert resp.status_code == 401 + assert data["detail"] == "Invalid API key" @pytest.mark.anyio async def test_act_on_resource_missing_key(default_client: AsyncClient): - body = { - "rn": { - "service": "Test Service", - "organisation": "Test Org", - "resource": "test_resource", - }, - "action": "read", - } - headers = {"Authorization": "Bearer not_checked_when_auth_is_disabled"} - resp = await default_client.post( - "/iam/can_act_on_resource?action=read", json=body, headers=headers - ) - data = resp.json() + body = { + "rn": { + "service": "Test Service", + "organisation": "Test Org", + "resource": "test_resource", + }, + "action": "read", + } + headers = {"Authorization": "Bearer not_checked_when_auth_is_disabled"} + resp = await default_client.post( + "/iam/can_act_on_resource?action=read", json=body, headers=headers + ) + data = resp.json() - assert resp.status_code == 401 - assert data["detail"] == "Missing API key" + assert resp.status_code == 401 + assert data["detail"] == "Missing API key" @pytest.mark.parametrize( - "service, org, resource, action, expected_status", - [ - (None, "Test Org", "test_resource", "read", 422), - (42, "Test Org", "test_resource", "read", 422), - ("Test Service", None, "test_resource", "read", 422), - ("Test Service", 42, "test_resource", "read", 422), - ("Test Service", "Test Org", None, "read", 422), - ("Test Service", "Test Org", 42, "read", 422), - ("Test Service", "Test Org", "test_resource", None, 422), - ("Test Service", "Test Org", "test_resource", 42, 422), - ], + "service, org, resource, action, expected_status", + [ + (None, "Test Org", "test_resource", "read", 422), + (42, "Test Org", "test_resource", "read", 422), + ("Test Service", None, "test_resource", "read", 422), + ("Test Service", 42, "test_resource", "read", 422), + ("Test Service", "Test Org", None, "read", 422), + ("Test Service", "Test Org", 42, "read", 422), + ("Test Service", "Test Org", "test_resource", None, 422), + ("Test Service", "Test Org", "test_resource", 42, 422), + ], ) @pytest.mark.anyio async def test_act_on_resource_endpoint_status_checks( - default_client: AsyncClient, service, org, resource, action, expected_status: int + default_client: AsyncClient, service, org, resource, action, expected_status: int ): - body = { - "rn": {"service": service, "organisation": org, "resource": resource}, - "action": action, - } - headers = { - "Authorization": "Bearer not_checked_when_auth_is_disabled", - "X-API-Key": "123456789", - } - resp = await default_client.post("/iam/can_act_on_resource", json=body, headers=headers) + body = { + "rn": {"service": service, "organisation": org, "resource": resource}, + "action": action, + } + headers = { + "Authorization": "Bearer not_checked_when_auth_is_disabled", + "X-API-Key": "123456789", + } + resp = await default_client.post("/iam/can_act_on_resource", json=body, headers=headers) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.parametrize( - "service, org, resource, action, expected_response", - [ - ("Test Service", 1, "test_resource", "read", True), - ("Test Service", 1, "test_resource", "create", False), - ("Test Service", 1, "no_access_here", "read", False), - ("Test Service", 2, "test_resource", "read", False), - ], + "service, org, resource, action, expected_response", + [ + ("Test Service", 1, "test_resource", "read", True), + ("Test Service", 1, "test_resource", "create", False), + ("Test Service", 1, "no_access_here", "read", False), + ("Test Service", 2, "test_resource", "read", False), + ], ) @pytest.mark.anyio async def test_act_on_resource_logic( - default_client: AsyncClient, - service, - org, - resource, - action, - expected_response: bool, + default_client: AsyncClient, + service, + org, + resource, + action, + expected_response: bool, ): - body = { - "rn": {"service": service, "organisation_id": org, "resource": resource}, - "action": action, - } - headers = { - "Authorization": "Bearer not_checked_when_auth_is_disabled", - "X-API-Key": "123456789", - } - resp = await default_client.post("/iam/can_act_on_resource", json=body, headers=headers) - data = resp.json() + body = { + "rn": {"service": service, "organisation_id": org, "resource": resource}, + "action": action, + } + headers = { + "Authorization": "Bearer not_checked_when_auth_is_disabled", + "X-API-Key": "123456789", + } + resp = await default_client.post("/iam/can_act_on_resource", json=body, headers=headers) + data = resp.json() - assert resp.status_code == 200 - assert data["allowed"] == expected_response + assert resp.status_code == 200 + assert data["allowed"] == expected_response @pytest.mark.anyio async def test_get_group_permissions_success(default_client: AsyncClient): - resp = await default_client.get("/iam/group/permissions?org_id=1&group_id=1") - assert resp.status_code == 200 + resp = await default_client.get("/iam/group/permissions?org_id=1&group_id=1") + assert resp.status_code == 200 - data = resp.json() + data = resp.json() - assert "permissions" in data - assert isinstance(data["permissions"], list) + assert "permissions" in data + assert isinstance(data["permissions"], list) - permission = data["permissions"][0] - assert permission["id"] == 1 - assert permission["service_name"] == "Test Service" - assert permission["resource"] == "test_resource" - assert permission["action"] == "read" + permission = data["permissions"][0] + assert permission["id"] == 1 + assert permission["service_name"] == "Test Service" + assert permission["resource"] == "test_resource" + assert permission["action"] == "read" @pytest.mark.parametrize( - "query, expected_status", generate_query_and_status(["group_id", "org_id"]) + "query, expected_status", generate_query_and_status(["group_id", "org_id"]) ) @pytest.mark.anyio async def test_get_group_permissions_status_checks( - default_client: AsyncClient, query: str, expected_status: int + default_client: AsyncClient, query: str, expected_status: int ): - resp = await default_client.get(f"/iam/group/permissions?{query}") + resp = await default_client.get(f"/iam/group/permissions?{query}") - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.parametrize( - "query", - [ - "org_id=1&group_id=2", - "org_id=2&group_id=1", - ], + "query", + [ + "org_id=1&group_id=2", + "org_id=2&group_id=1", + ], ) @pytest.mark.anyio async def test_get_group_permissions_mismatch(default_client: AsyncClient, query: str): - resp = await default_client.get(f"/iam/group/permissions?{query}") + resp = await default_client.get(f"/iam/group/permissions?{query}") - assert resp.status_code == 403 - assert resp.json()["detail"] == "Group does not belong to this organization" + assert resp.status_code == 403 + assert resp.json()["detail"] == "Group does not belong to this organization" @pytest.mark.anyio async def test_get_group_users_success(default_client: AsyncClient): - resp = await default_client.get("/iam/group/users?org_id=1&group_id=1") - assert resp.status_code == 200 + resp = await default_client.get("/iam/group/users?org_id=1&group_id=1") + assert resp.status_code == 200 - data = resp.json() + data = resp.json() - assert "users" in data - assert isinstance(data["users"], list) + assert "users" in data + assert isinstance(data["users"], list) - user = data["users"][0] - assert user["id"] == 1 - assert user["email"] == "admin@test.com" + user = data["users"][0] + assert user["id"] == 1 + assert user["email"] == "admin@test.com" - assert "group" in data - assert isinstance(data["group"], dict) - assert data["group"]["id"] == 1 - assert data["group"]["name"] == "Org One Group" + assert "group" in data + assert isinstance(data["group"], dict) + assert data["group"]["id"] == 1 + assert data["group"]["name"] == "Org One Group" - assert "organisation" in data - assert isinstance(data["organisation"], dict) - assert data["organisation"]["id"] == 1 - assert data["organisation"]["name"] == "Org One" + assert "organisation" in data + assert isinstance(data["organisation"], dict) + assert data["organisation"]["id"] == 1 + assert data["organisation"]["name"] == "Org One" @pytest.mark.parametrize( - "query, expected_status", generate_query_and_status(["group_id", "org_id"]) + "query, expected_status", generate_query_and_status(["group_id", "org_id"]) ) @pytest.mark.anyio async def test_get_group_users_status_checks( - default_client: AsyncClient, query: str, expected_status: int + default_client: AsyncClient, query: str, expected_status: int ): - resp = await default_client.get(f"/iam/group/users?{query}") + resp = await default_client.get(f"/iam/group/users?{query}") - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.parametrize( - "query", - [ - "org_id=1&group_id=2", - "org_id=2&group_id=1", - ], + "query", + [ + "org_id=1&group_id=2", + "org_id=2&group_id=1", + ], ) @pytest.mark.anyio async def test_get_group_users_mismatch(default_client: AsyncClient, query: str): - resp = await default_client.get(f"/iam/group/users?{query}") + resp = await default_client.get(f"/iam/group/users?{query}") - assert resp.status_code == 403 - assert resp.json()["detail"] == "Group does not belong to this organization" + assert resp.status_code == 403 + assert resp.json()["detail"] == "Group does not belong to this organization" @pytest.mark.anyio async def test_post_group_success(default_client: AsyncClient): - resp = await default_client.post( - "/iam/group", json={"name": "New Group", "organisation_id": 1} - ) - assert resp.status_code == 201 + resp = await default_client.post( + "/iam/group", json={"name": "New Group", "organisation_id": 1} + ) + assert resp.status_code == 201 - data = resp.json() + data = resp.json() - assert "group" in data - assert isinstance(data["group"], dict) - assert data["group"]["name"] == "New Group" - assert data["group"]["id"] == 4 + assert "group" in data + assert isinstance(data["group"], dict) + assert data["group"]["name"] == "New Group" + assert data["group"]["id"] == 4 @pytest.mark.parametrize( - "body, expected_status", - generate_body_and_status({"organisation_id": "int", "name": "str"}), + "body, expected_status", + generate_body_and_status({"organisation_id": "int", "name": "str"}), ) @pytest.mark.anyio async def test_post_group_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.post("/iam/group", json=body) + resp = await default_client.post("/iam/group", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_post_group_conflict(default_client: AsyncClient): - resp = await default_client.post( - "/iam/group", json={"organisation_id": 1, "name": "Org One Group"} - ) + resp = await default_client.post( + "/iam/group", json={"organisation_id": 1, "name": "Org One Group"} + ) - assert resp.status_code == 409 + assert resp.status_code == 409 @pytest.mark.anyio async def test_post_group_non_conflict(default_client: AsyncClient): - resp = await default_client.post( - "/iam/group", json={"organisation_id": 2, "name": "Org One Group"} - ) + resp = await default_client.post( + "/iam/group", json={"organisation_id": 2, "name": "Org One Group"} + ) - assert resp.status_code == 201 + assert resp.status_code == 201 @pytest.mark.anyio async def test_put_group_perm_success(default_client: AsyncClient): - resp = await default_client.put( - "/iam/group/permission", - json={"permission_id": 1, "group_id": 3, "organisation_id": 1}, - ) - assert resp.status_code == 200 + resp = await default_client.put( + "/iam/group/permission", + json={"permission_id": 1, "group_id": 3, "organisation_id": 1}, + ) + assert resp.status_code == 200 - data = resp.json() + data = resp.json() - assert "group" in data - assert isinstance(data["group"], dict) - assert data["group"]["name"] == "Org One Group Two" - assert data["group"]["id"] == 3 + assert "group" in data + assert isinstance(data["group"], dict) + assert data["group"]["name"] == "Org One Group Two" + assert data["group"]["id"] == 3 - assert "permissions" in data - assert isinstance(data["permissions"], list) + assert "permissions" in data + assert isinstance(data["permissions"], list) - permission = data["permissions"][0] - assert permission["id"] == 1 - assert permission["service_name"] == "Test Service" - assert permission["resource"] == "test_resource" - assert permission["action"] == "read" + permission = data["permissions"][0] + assert permission["id"] == 1 + assert permission["service_name"] == "Test Service" + assert permission["resource"] == "test_resource" + assert permission["action"] == "read" @pytest.mark.parametrize( - "body, expected_status", - generate_body_and_status( - {"organisation_id": "int", "group_id": "int", "permission_id": "int"} - ), + "body, expected_status", + generate_body_and_status( + {"organisation_id": "int", "group_id": "int", "permission_id": "int"} + ), ) @pytest.mark.anyio async def test_put_group_perm_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.put("/iam/group/permission", json=body) + resp = await default_client.put("/iam/group/permission", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_put_group_perm_conflict(default_client: AsyncClient): - resp = await default_client.put( - "/iam/group/permission", - json={"organisation_id": 1, "group_id": 1, "permission_id": 1}, - ) + resp = await default_client.put( + "/iam/group/permission", + json={"organisation_id": 1, "group_id": 1, "permission_id": 1}, + ) - assert resp.status_code == 409 + assert resp.status_code == 409 @pytest.mark.parametrize( - "body", - [ - {"organisation_id": 1, "group_id": 2, "permission_id": 1}, - {"organisation_id": 2, "group_id": 1, "permission_id": 1}, - ], + "body", + [ + {"organisation_id": 1, "group_id": 2, "permission_id": 1}, + {"organisation_id": 2, "group_id": 1, "permission_id": 1}, + ], ) @pytest.mark.anyio async def test_put_group_perm_mismatch(default_client: AsyncClient, body: dict): - resp = await default_client.put("/iam/group/permission", json=body) + resp = await default_client.put("/iam/group/permission", json=body) - assert resp.status_code == 403 - assert resp.json()["detail"] == "Group does not belong to this organization" + assert resp.status_code == 403 + assert resp.json()["detail"] == "Group does not belong to this organization" @pytest.mark.anyio async def test_put_group_user_success(default_client: AsyncClient): - resp = await default_client.put( - "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1} - ) - assert resp.status_code == 200 + resp = await default_client.put( + "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1} + ) + assert resp.status_code == 200 - data = resp.json() + data = resp.json() - assert "group" in data - assert isinstance(data["group"], dict) - assert data["group"]["name"] == "Org One Group" - assert data["group"]["id"] == 1 + assert "group" in data + assert isinstance(data["group"], dict) + assert data["group"]["name"] == "Org One Group" + assert data["group"]["id"] == 1 - assert "users" in data - assert isinstance(data["users"], list) + assert "users" in data + assert isinstance(data["users"], list) - user = data["users"][1] - assert user["id"] == 2 - assert user["first_name"] == "User" - assert user["last_name"] == "Test" - assert user["email"] == "user@orgone.com" + user = data["users"][1] + assert user["id"] == 2 + assert user["first_name"] == "User" + assert user["last_name"] == "Test" + assert user["email"] == "user@orgone.com" @pytest.mark.parametrize( - "body, expected_status", - generate_body_and_status( - {"organisation_id": "int", "group_id": "int", "user_id": "int"} - ), + "body, expected_status", + generate_body_and_status( + {"organisation_id": "int", "group_id": "int", "user_id": "int"} + ), ) @pytest.mark.anyio async def test_put_group_user_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.put("/iam/group/user", json=body) + resp = await default_client.put("/iam/group/user", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_get_permissions_success(default_client: AsyncClient): - resp = await default_client.get("/iam/permissions?org_id=1") - data = resp.json() + resp = await default_client.get("/iam/permissions?org_id=1") + data = resp.json() - assert resp.status_code == 200 - assert "permissions" in data - assert isinstance(data["permissions"], list) + assert resp.status_code == 200 + assert "permissions" in data + assert isinstance(data["permissions"], list) - permission = data["permissions"][0] - assert permission["id"] == 1 - assert permission["service_name"] == "Test Service" - assert permission["resource"] == "test_resource" - assert permission["action"] == "read" + permission = data["permissions"][0] + assert permission["id"] == 1 + assert permission["service_name"] == "Test Service" + assert permission["resource"] == "test_resource" + assert permission["action"] == "read" @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.anyio async def test_get_permissions_status_checks( - default_client: AsyncClient, query: str, expected_status: int + default_client: AsyncClient, query: str, expected_status: int ): - resp = await default_client.get(f"/iam/permissions?{query}") + resp = await default_client.get(f"/iam/permissions?{query}") - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_post_perm_success(default_client: AsyncClient): - resp = await default_client.post( - "/iam/permission", - json={"service_id": 1, "resource": "test_resource", "action": "create"}, - ) - assert resp.status_code == 201 + resp = await default_client.post( + "/iam/permission", + json={"service_id": 1, "resource": "test_resource", "action": "create"}, + ) + assert resp.status_code == 201 - data = resp.json() + data = resp.json() - assert "permission" in data - assert isinstance(data["permission"], dict) + assert "permission" in data + assert isinstance(data["permission"], dict) - assert data["permission"]["id"] == 4 - assert data["permission"]["service_name"] == "Test Service" - assert data["permission"]["resource"] == "test_resource" - assert data["permission"]["action"] == "create" + assert data["permission"]["id"] == 4 + assert data["permission"]["service_name"] == "Test Service" + assert data["permission"]["resource"] == "test_resource" + assert data["permission"]["action"] == "create" @pytest.mark.parametrize( - "body, expected_status", - [ - ( - {"service_id": 1, "resource": "test_resource", "action": "read"}, - 409, - ), - # service_id tests - ( - {"service_id": 42, "resource": "test_resource", "action": "read"}, - 404, - ), # Non-existent service - ( - {"service_id": "banana", "resource": "test_resource", "action": "read"}, - 422, - ), # Invalid service ID - ( - {"service_id": "", "resource": "test_resource", "action": "read"}, - 422, - ), # Blank service ID - ( - {"service_id": -1, "resource": "test_resource", "action": "read"}, - 422, - ), # Negative service ID - # resource tests - ( - {"service_id": 1, "resource": 42, "action": "read"}, - 422, - ), # Invalid resource type - # action tests - ( - {"service_id": 1, "resource": "test_resource", "action": 42}, - 422, - ), # Invalid action type - # missing/partial body tests - ({}, 422), # Blank body - ({"resource": "test_resource"}, 422), # Only resource - ({"action": "read"}, 422), # Only action - ({"service_id": 1}, 422), # Only service - ({"service_id": 1, "action": "read"}, 422), # Missing resource - ({"service_id": 1, "resource": "test_resource"}, 422), # Missing action - ({"resource": "test_resource", "action": "read"}, 422), # Missing service - ], + "body, expected_status", + [ + ( + {"service_id": 1, "resource": "test_resource", "action": "read"}, + 409, + ), + # service_id tests + ( + {"service_id": 42, "resource": "test_resource", "action": "read"}, + 404, + ), # Non-existent service + ( + {"service_id": "banana", "resource": "test_resource", "action": "read"}, + 422, + ), # Invalid service ID + ( + {"service_id": "", "resource": "test_resource", "action": "read"}, + 422, + ), # Blank service ID + ( + {"service_id": -1, "resource": "test_resource", "action": "read"}, + 422, + ), # Negative service ID + # resource tests + ( + {"service_id": 1, "resource": 42, "action": "read"}, + 422, + ), # Invalid resource type + # action tests + ( + {"service_id": 1, "resource": "test_resource", "action": 42}, + 422, + ), # Invalid action type + # missing/partial body tests + ({}, 422), # Blank body + ({"resource": "test_resource"}, 422), # Only resource + ({"action": "read"}, 422), # Only action + ({"service_id": 1}, 422), # Only service + ({"service_id": 1, "action": "read"}, 422), # Missing resource + ({"service_id": 1, "resource": "test_resource"}, 422), # Missing action + ({"resource": "test_resource", "action": "read"}, 422), # Missing service + ], ) @pytest.mark.anyio async def test_post_perm_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.post("/iam/permission", json=body) + resp = await default_client.post("/iam/permission", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.parametrize( - "body", - [ - { - "organisation_id": 1, - "service_id": 1, - "resource": "test_resource", - "action": "read", - }, - {"organisation_id": 1, "service_id": 1}, - {"organisation_id": 1, "resource": "test_resource"}, - {"organisation_id": 1, "action": "read"}, - {"organisation_id": 1, "service_id": 1, "action": "read"}, - {"organisation_id": 1, "service_id": 1, "resource": "test_resource"}, - {"organisation_id": 1, "resource": "test_resource", "action": "read"}, - ], + "body", + [ + { + "organisation_id": 1, + "service_id": 1, + "resource": "test_resource", + "action": "read", + }, + {"organisation_id": 1, "service_id": 1}, + {"organisation_id": 1, "resource": "test_resource"}, + {"organisation_id": 1, "action": "read"}, + {"organisation_id": 1, "service_id": 1, "action": "read"}, + {"organisation_id": 1, "service_id": 1, "resource": "test_resource"}, + {"organisation_id": 1, "resource": "test_resource", "action": "read"}, + ], ) @pytest.mark.anyio async def test_post_perm_search_success(default_client: AsyncClient, body): - resp = await default_client.post("/iam/permissions/search", json=body) - data = resp.json() + resp = await default_client.post("/iam/permissions/search", json=body) + data = resp.json() - assert resp.status_code == 200 - assert "permissions" in data - assert isinstance(data["permissions"], list) + assert resp.status_code == 200 + assert "permissions" in data + assert isinstance(data["permissions"], list) - permissions_filtered = [ - permission for permission in data["permissions"] if permission["id"] == 1 - ] - assert len(permissions_filtered) == 1 - permission = permissions_filtered[0] - assert permission["id"] == 1 - assert permission["service_name"] == "Test Service" - assert permission["resource"] == "test_resource" - assert permission["action"] == "read" + permissions_filtered = [ + permission for permission in data["permissions"] if permission["id"] == 1 + ] + assert len(permissions_filtered) == 1 + permission = permissions_filtered[0] + assert permission["id"] == 1 + assert permission["service_name"] == "Test Service" + assert permission["resource"] == "test_resource" + assert permission["action"] == "read" @pytest.mark.parametrize( - "body, expected_status", - [ - # organisation_id tests - ( - { - "organisation_id": 42, - "service_id": 1, - "resource": "test_resource", - "action": "read", - }, - 404, - ), # Non-existent organisation - ( - { - "organisation_id": "banana", - "service_id": 1, - "resource": "test_resource", - "action": "read", - }, - 422, - ), # Invalid organisation ID - ( - { - "organisation_id": "", - "service_id": 1, - "resource": "test_resource", - "action": "read", - }, - 422, - ), # Blank organisation ID - ( - { - "organisation_id": -1, - "service_id": 1, - "resource": "test_resource", - "action": "read", - }, - 422, - ), # Negative organisation ID - # service_id tests - ( - { - "organisation_id": 1, - "service_id": "banana", - "resource": "test_resource", - "action": "read", - }, - 422, - ), # Invalid service ID - ( - { - "organisation_id": 1, - "service_id": "", - "resource": "test_resource", - "action": "read", - }, - 422, - ), # Blank service ID - ( - { - "organisation_id": 1, - "service_id": -1, - "resource": "test_resource", - "action": "read", - }, - 422, - ), # Negative service ID - # resource tests - ( - {"organisation_id": 1, "service_id": 1, "resource": 42, "action": "read"}, - 422, - ), # Invalid resource type - # action tests - ( - { - "organisation_id": 1, - "service_id": 1, - "resource": "test_resource", - "action": 42, - }, - 422, - ), # Invalid action type - # missing/partial body tests - ({}, 422), # Blank body - ], + "body, expected_status", + [ + # organisation_id tests + ( + { + "organisation_id": 42, + "service_id": 1, + "resource": "test_resource", + "action": "read", + }, + 404, + ), # Non-existent organisation + ( + { + "organisation_id": "banana", + "service_id": 1, + "resource": "test_resource", + "action": "read", + }, + 422, + ), # Invalid organisation ID + ( + { + "organisation_id": "", + "service_id": 1, + "resource": "test_resource", + "action": "read", + }, + 422, + ), # Blank organisation ID + ( + { + "organisation_id": -1, + "service_id": 1, + "resource": "test_resource", + "action": "read", + }, + 422, + ), # Negative organisation ID + # service_id tests + ( + { + "organisation_id": 1, + "service_id": "banana", + "resource": "test_resource", + "action": "read", + }, + 422, + ), # Invalid service ID + ( + { + "organisation_id": 1, + "service_id": "", + "resource": "test_resource", + "action": "read", + }, + 422, + ), # Blank service ID + ( + { + "organisation_id": 1, + "service_id": -1, + "resource": "test_resource", + "action": "read", + }, + 422, + ), # Negative service ID + # resource tests + ( + {"organisation_id": 1, "service_id": 1, "resource": 42, "action": "read"}, + 422, + ), # Invalid resource type + # action tests + ( + { + "organisation_id": 1, + "service_id": 1, + "resource": "test_resource", + "action": 42, + }, + 422, + ), # Invalid action type + # missing/partial body tests + ({}, 422), # Blank body + ], ) @pytest.mark.anyio async def test_post_perm_search_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.post("/iam/permissions/search", json=body) + resp = await default_client.post("/iam/permissions/search", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_delete_group_permissions_success(default_client: AsyncClient): - resp = await default_client.delete( - "/iam/group/permission?org_id=1&group_id=1&perm_id=1" - ) - data = resp.json() + resp = await default_client.delete( + "/iam/group/permission?org_id=1&group_id=1&perm_id=1" + ) + data = resp.json() - assert resp.status_code == 200 - assert "permissions" in data - assert isinstance(data["permissions"], list) - assert len(data["permissions"]) == 0 - assert "group" in data - assert data["group"]["id"] == 1 - assert data["group"]["name"] == "Org One Group" + assert resp.status_code == 200 + assert "permissions" in data + assert isinstance(data["permissions"], list) + assert len(data["permissions"]) == 0 + assert "group" in data + assert data["group"]["id"] == 1 + assert data["group"]["name"] == "Org One Group" @pytest.mark.parametrize( - "query, expected_status", - generate_query_and_status(["group_id", "org_id", "perm_id"]), + "query, expected_status", + generate_query_and_status(["group_id", "org_id", "perm_id"]), ) @pytest.mark.anyio async def test_delete_group_permissions_status_checks( - default_client: AsyncClient, query: str, expected_status: int + default_client: AsyncClient, query: str, expected_status: int ): - resp = await default_client.delete(f"/iam/group/permission?{query}") - assert resp.status_code == expected_status + resp = await default_client.delete(f"/iam/group/permission?{query}") + assert resp.status_code == expected_status @pytest.mark.anyio async def test_delete_group_permissions_not_in_group(default_client: AsyncClient): - resp = await default_client.delete( - "/iam/group/permission?org_id=1&group_id=1&perm_id=2" - ) - assert resp.status_code == 422 + resp = await default_client.delete( + "/iam/group/permission?org_id=1&group_id=1&perm_id=2" + ) + assert resp.status_code == 422 @pytest.mark.anyio async def test_delete_permissions_success(default_client: AsyncClient): - resp = await default_client.delete("/iam/permission?perm_id=1") + resp = await default_client.delete("/iam/permission?perm_id=1") - assert resp.status_code == 204 + assert resp.status_code == 204 @pytest.mark.anyio async def test_delete_group_users_success(default_client: AsyncClient): - resp = await default_client.delete("/iam/group/user?org_id=1&group_id=1&user_id=1") - data = resp.json() + resp = await default_client.delete("/iam/group/user?org_id=1&group_id=1&user_id=1") + data = resp.json() - assert resp.status_code == 200 - assert "users" in data - assert isinstance(data["users"], list) - assert len(data["users"]) == 0 - assert "group" in data - assert data["group"]["id"] == 1 - assert data["group"]["name"] == "Org One Group" + assert resp.status_code == 200 + assert "users" in data + assert isinstance(data["users"], list) + assert len(data["users"]) == 0 + assert "group" in data + assert data["group"]["id"] == 1 + assert data["group"]["name"] == "Org One Group" @pytest.mark.anyio async def test_put_group_user_invitation_success(default_client: AsyncClient): - body = {"user_email": "admin@test.com", "organisation_id": 1, "group_id": 1} - resp = await default_client.put("/iam/group/user/invitation", json=body) + body = {"user_email": "admin@test.com", "organisation_id": 1, "group_id": 1} + resp = await default_client.put("/iam/group/user/invitation", json=body) - assert resp.status_code == 200 - data = resp.json() - assert "organisation" in data - assert isinstance(data["organisation"], dict) - assert data["organisation"]["id"] == 1 - assert data["organisation"]["name"] == "Org One" + assert resp.status_code == 200 + data = resp.json() + assert "organisation" in data + assert isinstance(data["organisation"], dict) + assert data["organisation"]["id"] == 1 + assert data["organisation"]["name"] == "Org One" - assert "invited_email" in data - assert isinstance(data["invited_email"], str) - assert data["invited_email"] == "admin@test.com" + assert "invited_email" in data + assert isinstance(data["invited_email"], str) + assert data["invited_email"] == "admin@test.com" - assert "group" in data - assert isinstance(data["group"], dict) - assert data["group"]["name"] == "Org One Group" - assert data["group"]["id"] == 1 + assert "group" in data + assert isinstance(data["group"], dict) + assert data["group"]["name"] == "Org One Group" + assert data["group"]["id"] == 1 @pytest.mark.parametrize( - "body, expected_status", - [ - ({"organisation_id": 42, "user_email": "admin@test.com", "group_id": 1}, 404), - ( - { - "organisation_id": "Test Org", - "user_email": "admin@test.com", - "group_id": 1, - }, - 422, - ), - ({"organisation_id": "", "user_email": "admin@test.com", "group_id": 1}, 422), - ({}, 422), - ({"user_email": 42, "group_id": 1}, 422), - ({"organisation_id": 1, "user_email": "Test User", "group_id": 1}, 422), - ({"organisation_id": 1, "user_email": "admin@test.com", "group_id": 42}, 404), - ({"organisation_id": "Test Org", "user_email": "admin@test.com"}, 422), - ({"organisation_id": "", "user_email": "admin@test.com"}, 422), - ({"user_email": 42}, 422), - ], + "body, expected_status", + [ + ({"organisation_id": 42, "user_email": "admin@test.com", "group_id": 1}, 404), + ( + { + "organisation_id": "Test Org", + "user_email": "admin@test.com", + "group_id": 1, + }, + 422, + ), + ({"organisation_id": "", "user_email": "admin@test.com", "group_id": 1}, 422), + ({}, 422), + ({"user_email": 42, "group_id": 1}, 422), + ({"organisation_id": 1, "user_email": "Test User", "group_id": 1}, 422), + ({"organisation_id": 1, "user_email": "admin@test.com", "group_id": 42}, 404), + ({"organisation_id": "Test Org", "user_email": "admin@test.com"}, 422), + ({"organisation_id": "", "user_email": "admin@test.com"}, 422), + ({"user_email": 42}, 422), + ], ) @pytest.mark.anyio async def test_put_group_user_invitation_status_checks( - default_client: AsyncClient, body, expected_status + default_client: AsyncClient, body, expected_status ): - resp = await default_client.put("/iam/group/user/invitation", json=body) + resp = await default_client.put("/iam/group/user/invitation", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.parametrize( - "body, expected_status", - [ - ({"jwt": "invalid"}, 401), - ({"jwt": ""}, 401), - ({"jwt": None}, 422), - ({"jwt": 42}, 422), - ], + "body, expected_status", + [ + ({"jwt": "invalid"}, 401), + ({"jwt": ""}, 401), + ({"jwt": None}, 422), + ({"jwt": 42}, 422), + ], ) @pytest.mark.anyio async def test_put_group_user_invitation_accept_status_checks( - default_client: AsyncClient, body, expected_status + default_client: AsyncClient, body, expected_status ): - resp = await default_client.put("/iam/group/user/invitation/accept", json=body) + resp = await default_client.put("/iam/group/user/invitation/accept", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status - if resp.status_code == 401: - assert resp.json()["detail"] == "Invalid JWS" + if resp.status_code == 401: + assert resp.json()["detail"] == "Invalid JWS" diff --git a/test/test_organisation.py b/test/test_organisation.py index 8c9cff6..49dd0a3 100644 --- a/test/test_organisation.py +++ b/test/test_organisation.py @@ -9,506 +9,506 @@ from .conftest import generate_query_and_status pytestmark = [ - pytest.mark.org_module, + pytest.mark.org_module, ] @pytest.mark.anyio async def test_get_org_success(default_client: AsyncClient): - resp = await default_client.get("/org?org_id=1") - data = resp.json() + resp = await default_client.get("/org?org_id=1") + data = resp.json() - assert resp.status_code == 200 + assert resp.status_code == 200 - org = data["organisations"][0] + org = data["organisations"][0] - assert isinstance(org, dict) - assert org["organisation_id"] == 1 - assert org["name"] == "Org One" - assert org["status"] == "approved" - assert org["root_user_email"] == "admin@test.com" - assert "intake_questionnaire" in org - assert isinstance(org["intake_questionnaire"], dict) + assert isinstance(org, dict) + assert org["organisation_id"] == 1 + assert org["name"] == "Org One" + assert org["status"] == "approved" + assert org["root_user_email"] == "admin@test.com" + assert "intake_questionnaire" in org + assert isinstance(org["intake_questionnaire"], dict) - assert org["billing_contact"]["email"] == "billing@orgone.com" - assert org["owner_contact"]["email"] == "owner@orgone.com" - assert org["security_contact"]["email"] == "security@orgone.com" + assert org["billing_contact"]["email"] == "billing@orgone.com" + assert org["owner_contact"]["email"] == "owner@orgone.com" + assert org["security_contact"]["email"] == "security@orgone.com" @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.anyio async def test_get_org_status_checks( - default_client: AsyncClient, query: str, expected_status: int + default_client: AsyncClient, query: str, expected_status: int ): - resp = await default_client.get(f"/org?{query}") + resp = await default_client.get(f"/org?{query}") - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_post_org_success(default_client: AsyncClient): - resp = await default_client.post("/org", json={"name": "New Test Org"}) - data = resp.json() + resp = await default_client.post("/org", json={"name": "New Test Org"}) + data = resp.json() - assert resp.status_code == 201 - assert data["name"] == "New Test Org" - assert data["status"] == "partial" + assert resp.status_code == 201 + assert data["name"] == "New Test Org" + assert data["status"] == "partial" @pytest.mark.parametrize( - "body, expected_status", - [ - ({"name": 42}, 422), - ({}, 422), - ({"name": "New Test Org", "intake_questionnaire": {"question_one": 42}}, 422), - ], + "body, expected_status", + [ + ({"name": 42}, 422), + ({}, 422), + ({"name": "New Test Org", "intake_questionnaire": {"question_one": 42}}, 422), + ], ) @pytest.mark.anyio async def test_post_org_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.post("/org", json=body) + resp = await default_client.post("/org", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_patch_org_questionnaire_partial_success(default_client: AsyncClient): - resp = await default_client.patch( - "/org/questionnaire", - json={ - "organisation_id": 3, - "intake_questionnaire": { - "question_one": "new answer one", - "question_two": None, - "question_three": None, - }, - "partial": True, - }, - ) - data = resp.json() + resp = await default_client.patch( + "/org/questionnaire", + json={ + "organisation_id": 3, + "intake_questionnaire": { + "question_one": "new answer one", + "question_two": None, + "question_three": None, + }, + "partial": True, + }, + ) + data = resp.json() - assert resp.status_code == 200 - assert data["name"] == "Org Three" - assert data["status"] == "partial" - assert "intake_questionnaire" in data - assert isinstance(data["intake_questionnaire"], dict) - metadata = data["intake_questionnaire"]["metadata"] - assert metadata["version"] == 0 - assert metadata["submission_date"] is None - questions = data["intake_questionnaire"]["questions"] - assert questions["question_one"] == "new answer one" - assert questions["question_two"] == "answer two" - assert questions["question_three"] is None + assert resp.status_code == 200 + assert data["name"] == "Org Three" + assert data["status"] == "partial" + assert "intake_questionnaire" in data + assert isinstance(data["intake_questionnaire"], dict) + metadata = data["intake_questionnaire"]["metadata"] + assert metadata["version"] == 0 + assert metadata["submission_date"] is None + questions = data["intake_questionnaire"]["questions"] + assert questions["question_one"] == "new answer one" + assert questions["question_two"] == "answer two" + assert questions["question_three"] is None @pytest.mark.parametrize( - "body, expected_status", - [ - ({"organisation_id": 42}, 404), - ({"organisation_id": "Org One"}, 422), - ({"organisation_id": ""}, 422), - ({}, 422), - ( - { - "organisation_id": "1", - "intake_questionnaire": {"question_one": 42}, - "partial": True, - }, - 422, - ), - ( - {"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}}, - 422, - ), - ( - { - "organisation_id": "1", - "intake_questionnaire": {"question_one": "valid"}, - "partial": 42, - }, - 422, - ), - ], + "body, expected_status", + [ + ({"organisation_id": 42}, 404), + ({"organisation_id": "Org One"}, 422), + ({"organisation_id": ""}, 422), + ({}, 422), + ( + { + "organisation_id": "1", + "intake_questionnaire": {"question_one": 42}, + "partial": True, + }, + 422, + ), + ( + {"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}}, + 422, + ), + ( + { + "organisation_id": "1", + "intake_questionnaire": {"question_one": "valid"}, + "partial": 42, + }, + 422, + ), + ], ) @pytest.mark.anyio async def test_patch_questionnaire_partial_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.patch("/org/questionnaire", json=body) + resp = await default_client.patch("/org/questionnaire", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_patch_org_questionnaire_submit_success(default_client: AsyncClient): - resp = await default_client.patch( - "/org/questionnaire", - json={ - "organisation_id": 3, - "intake_questionnaire": { - "question_one": "new answer one", - "question_two": None, - "question_three": None, - }, - "partial": False, - }, - ) - data = resp.json() + resp = await default_client.patch( + "/org/questionnaire", + json={ + "organisation_id": 3, + "intake_questionnaire": { + "question_one": "new answer one", + "question_two": None, + "question_three": None, + }, + "partial": False, + }, + ) + data = resp.json() - assert resp.status_code == 200 - assert data["name"] == "Org Three" - assert data["status"] == "submitted" - assert "intake_questionnaire" in data - assert isinstance(data["intake_questionnaire"], dict) - metadata = data["intake_questionnaire"]["metadata"] - assert metadata["version"] == 0 - assert metadata["submission_date"] is not None - questions = data["intake_questionnaire"]["questions"] - assert questions["question_one"] == "new answer one" - assert questions["question_two"] == "answer two" - assert questions["question_three"] is None + assert resp.status_code == 200 + assert data["name"] == "Org Three" + assert data["status"] == "submitted" + assert "intake_questionnaire" in data + assert isinstance(data["intake_questionnaire"], dict) + metadata = data["intake_questionnaire"]["metadata"] + assert metadata["version"] == 0 + assert metadata["submission_date"] is not None + questions = data["intake_questionnaire"]["questions"] + assert questions["question_one"] == "new answer one" + assert questions["question_two"] == "answer two" + assert questions["question_three"] is None @pytest.mark.parametrize( - "status", ["partial", "submitted", "remediation", "approved", "rejected", "removed"] + "status", ["partial", "submitted", "remediation", "approved", "rejected", "removed"] ) @pytest.mark.anyio async def test_patch_org_status_success(default_client: AsyncClient, status: str): - resp = await default_client.patch( - "/org/status", json={"organisation_id": 1, "status": status} - ) - data = resp.json() + resp = await default_client.patch( + "/org/status", json={"organisation_id": 1, "status": status} + ) + data = resp.json() - assert resp.status_code == 200 - assert data["name"] == "Org One" - assert data["status"] == status + assert resp.status_code == 200 + assert data["name"] == "Org One" + assert data["status"] == status @pytest.mark.parametrize( - "body, expected_status", - [ - ({"organisation_id": 42}, 404), - ({"organisation_id": "Org One"}, 422), - ({"organisation_id": ""}, 422), - ({}, 422), - ({"organisation_id": "1", "status": True}, 422), - ({"organisation_id": "1", "status": 42}, 422), - ], + "body, expected_status", + [ + ({"organisation_id": 42}, 404), + ({"organisation_id": "Org One"}, 422), + ({"organisation_id": ""}, 422), + ({}, 422), + ({"organisation_id": "1", "status": True}, 422), + ({"organisation_id": "1", "status": 42}, 422), + ], ) @pytest.mark.anyio async def test_patch_org_status_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.patch("/org/status", json=body) + resp = await default_client.patch("/org/status", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_get_org_users_success(default_client: AsyncClient): - resp = await default_client.get("/org/users?org_id=1") - data = resp.json() + resp = await default_client.get("/org/users?org_id=1") + data = resp.json() - assert resp.status_code == 200 + assert resp.status_code == 200 - assert "users" in data - assert isinstance(data["users"], list) - assert len(data["users"]) == 2 + assert "users" in data + assert isinstance(data["users"], list) + assert len(data["users"]) == 2 - user = data["users"][0] - assert isinstance(user, dict) - assert user["email"] == "admin@test.com" - assert user["id"] == 1 + user = data["users"][0] + assert isinstance(user, dict) + assert user["email"] == "admin@test.com" + assert user["id"] == 1 - assert "organisation" in data - assert data["organisation"]["name"] == "Org One" - assert data["organisation"]["id"] == 1 + assert "organisation" in data + assert data["organisation"]["name"] == "Org One" + assert data["organisation"]["id"] == 1 @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.anyio async def test_get_org_users_status_checks( - default_client: AsyncClient, query: str, expected_status: int + default_client: AsyncClient, query: str, expected_status: int ): - resp = await default_client.get(f"/org/users?{query}") + resp = await default_client.get(f"/org/users?{query}") - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_post_org_user_success(default_client: AsyncClient): - resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 3}) + resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 3}) - assert resp.status_code == 200 + assert resp.status_code == 200 - data = resp.json() + data = resp.json() - assert "organisation" in data - assert isinstance(data["organisation"], dict) - assert data["organisation"]["id"] == 1 - assert data["organisation"]["name"] == "Org One" + assert "organisation" in data + assert isinstance(data["organisation"], dict) + assert data["organisation"]["id"] == 1 + assert data["organisation"]["name"] == "Org One" - assert "users" in data - assert isinstance(data["users"], list) - assert len([user for user in data["users"] if user["email"] == "root@orgtwo.com"]) == 1 + assert "users" in data + assert isinstance(data["users"], list) + assert len([user for user in data["users"] if user["email"] == "root@orgtwo.com"]) == 1 @pytest.mark.parametrize( - "body, expected_status", - [ - ({"organisation_id": 42}, 404), - ({}, 422), - ({"organisation_id": 1, "user_id": "id"}, 422), - ({"user_id": 2}, 422), - ({"organisation_id": 1, "user_id": 42}, 404), - ({"organisation_id": 1, "user_id": 1}, 409), - ], + "body, expected_status", + [ + ({"organisation_id": 42}, 404), + ({}, 422), + ({"organisation_id": 1, "user_id": "id"}, 422), + ({"user_id": 2}, 422), + ({"organisation_id": 1, "user_id": 42}, 404), + ({"organisation_id": 1, "user_id": 1}, 409), + ], ) @pytest.mark.anyio async def test_post_org_user_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.post("/org/user", json=body) + resp = await default_client.post("/org/user", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_patch_org_root_user_success(default_client: AsyncClient): - resp = await default_client.patch( - "/org/root_user", json={"organisation_id": 1, "user_id": 2} - ) - assert resp.status_code == 200 + resp = await default_client.patch( + "/org/root_user", json={"organisation_id": 1, "user_id": 2} + ) + assert resp.status_code == 200 - data = resp.json() + data = resp.json() - assert data["name"] == "Org One" - assert data["root_user_email"] == "user@orgone.com" + assert data["name"] == "Org One" + assert data["root_user_email"] == "user@orgone.com" @pytest.mark.parametrize( - "body, expected_status", - [ - ({"organisation_id": 42, "user_id": 2}, 404), - ({"organisation_id": "Org One", "user_id": 2}, 422), - ({"organisation_id": "", "user_id": 2}, 422), - ({}, 422), - ({"user_id": 2}, 422), - ({"user_id": 42}, 404), - ({"organisation_id": 1, "user_id": "Test User"}, 422), - ], + "body, expected_status", + [ + ({"organisation_id": 42, "user_id": 2}, 404), + ({"organisation_id": "Org One", "user_id": 2}, 422), + ({"organisation_id": "", "user_id": 2}, 422), + ({}, 422), + ({"user_id": 2}, 422), + ({"user_id": 42}, 404), + ({"organisation_id": 1, "user_id": "Test User"}, 422), + ], ) @pytest.mark.anyio async def test_patch_root_user_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.patch("/org/root_user", json=body) + resp = await default_client.patch("/org/root_user", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_patch_org_root_user_non_member(default_client: AsyncClient): - resp = await default_client.patch( - "/org/root_user", json={"organisation_id": 1, "user_id": 3} - ) - data = resp.json() + resp = await default_client.patch( + "/org/root_user", json={"organisation_id": 1, "user_id": 3} + ) + data = resp.json() - assert resp.status_code == 422 - assert data["detail"] == "This user does not belong to your organisation." + assert resp.status_code == 422 + assert data["detail"] == "This user does not belong to your organisation." @pytest.mark.anyio async def test_get_org_groups_success(default_client: AsyncClient): - resp = await default_client.get("/org/groups?org_id=1") - assert resp.status_code == 200 + resp = await default_client.get("/org/groups?org_id=1") + assert resp.status_code == 200 - data = resp.json() + data = resp.json() - assert "organisation" in data - assert isinstance(data["organisation"], dict) - assert data["organisation"]["id"] == 1 - assert data["organisation"]["name"] == "Org One" + assert "organisation" in data + assert isinstance(data["organisation"], dict) + assert data["organisation"]["id"] == 1 + assert data["organisation"]["name"] == "Org One" - assert "groups" in data - assert isinstance(data["groups"], list) - group = data["groups"][0] - assert isinstance(group, dict) - assert group["id"] == 1 - assert group["name"] == "Org One Group" + assert "groups" in data + assert isinstance(data["groups"], list) + group = data["groups"][0] + assert isinstance(group, dict) + assert group["id"] == 1 + assert group["name"] == "Org One Group" @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.anyio async def test_get_org_groups_status_checks( - default_client: AsyncClient, query: str, expected_status: int + default_client: AsyncClient, query: str, expected_status: int ): - resp = await default_client.get(f"/org/groups?{query}") + resp = await default_client.get(f"/org/groups?{query}") - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.parametrize("contact_type", ["billing", "security", "owner"]) @pytest.mark.anyio async def test_get_org_contact_success(default_client: AsyncClient, contact_type: str): - resp = await default_client.get(f"/org/contact?org_id=1&contact_type={contact_type}") - data = resp.json() + resp = await default_client.get(f"/org/contact?org_id=1&contact_type={contact_type}") + data = resp.json() - assert resp.status_code == 200 + assert resp.status_code == 200 - assert "organisation" in data - assert data["organisation"]["id"] == 1 - assert data["organisation"]["name"] == "Org One" + assert "organisation" in data + assert data["organisation"]["id"] == 1 + assert data["organisation"]["name"] == "Org One" - attributes = [ - "email", - "first_name", - "last_name", - "phonenumber", - "vat_number", - "address", - ] + attributes = [ + "email", + "first_name", + "last_name", + "phonenumber", + "vat_number", + "address", + ] - for attribute in attributes: - assert attribute in data["contact"] + for attribute in attributes: + assert attribute in data["contact"] - address_attributes = [ - "post_office_box_number", - "street_address", - "street_address_line_2", - "locality", - "address_region", - "country_code", - "postal_code", - ] + address_attributes = [ + "post_office_box_number", + "street_address", + "street_address_line_2", + "locality", + "address_region", + "country_code", + "postal_code", + ] - for attribute in address_attributes: - assert attribute in data["contact"]["address"] + for attribute in address_attributes: + assert attribute in data["contact"]["address"] @pytest.mark.parametrize( - "query, expected_status", - [ - ("org_id=42&contact_type=billing", 404), - ("org_id=banana&contact_type=billing", 422), - ("", 422), - ("org_id=1&contact_type=contact", 422), - ("contact_type=billing", 422), - ], + "query, expected_status", + [ + ("org_id=42&contact_type=billing", 404), + ("org_id=banana&contact_type=billing", 422), + ("", 422), + ("org_id=1&contact_type=contact", 422), + ("contact_type=billing", 422), + ], ) @pytest.mark.anyio async def test_get_org_contact_status_checks( - default_client: AsyncClient, query: str, expected_status: int + default_client: AsyncClient, query: str, expected_status: int ): - resp = await default_client.get(f"/org/contact?{query}") + resp = await default_client.get(f"/org/contact?{query}") - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.parametrize( - "key, value", - [ - ("email", "user@example.com"), - ("first_name", "John"), - ("last_name", "Doe"), - ("phonenumber", "+441234567890"), - ("vat_number", "GB123456789"), - ("post_office_box_number", "PO Box 123"), - ("street_address", "123 Example Street"), - ("street_address_line_2", "Suite 4B"), - ("locality", "Glasgow"), - ("address_region", "Glasgow City"), - ("country_code", "GB"), - ("postal_code", "G1 1AA"), - ], + "key, value", + [ + ("email", "user@example.com"), + ("first_name", "John"), + ("last_name", "Doe"), + ("phonenumber", "+441234567890"), + ("vat_number", "GB123456789"), + ("post_office_box_number", "PO Box 123"), + ("street_address", "123 Example Street"), + ("street_address_line_2", "Suite 4B"), + ("locality", "Glasgow"), + ("address_region", "Glasgow City"), + ("country_code", "GB"), + ("postal_code", "G1 1AA"), + ], ) @pytest.mark.anyio async def test_patch_org_contact_success(default_client: AsyncClient, key: str, value: str): - resp = await default_client.patch( - "/org/contact", - json={"organisation_id": 1, "contact_type": "billing", key: value}, - ) - assert resp.status_code == 200 + resp = await default_client.patch( + "/org/contact", + json={"organisation_id": 1, "contact_type": "billing", key: value}, + ) + assert resp.status_code == 200 - data = resp.json() + data = resp.json() - assert "organisation" in data - assert data["organisation"]["id"] == 1 - assert data["organisation"]["name"] == "Org One" + assert "organisation" in data + assert data["organisation"]["id"] == 1 + assert data["organisation"]["name"] == "Org One" - attributes = [ - "email", - "first_name", - "last_name", - "phonenumber", - "vat_number", - "address", - ] + attributes = [ + "email", + "first_name", + "last_name", + "phonenumber", + "vat_number", + "address", + ] - for attribute in attributes: - assert attribute in data["contact"] + for attribute in attributes: + assert attribute in data["contact"] - address_attributes = [ - "post_office_box_number", - "street_address", - "street_address_line_2", - "locality", - "address_region", - "country_code", - "postal_code", - ] + address_attributes = [ + "post_office_box_number", + "street_address", + "street_address_line_2", + "locality", + "address_region", + "country_code", + "postal_code", + ] - for attribute in address_attributes: - assert attribute in data["contact"]["address"] + for attribute in address_attributes: + assert attribute in data["contact"]["address"] - if key in data["contact"]: - assert data["contact"][key] == value - elif key in data["contact"]["address"]: - assert data["contact"]["address"][key] == value - else: - pytest.fail(f"Invalid contact key: {key}") + if key in data["contact"]: + assert data["contact"][key] == value + elif key in data["contact"]["address"]: + assert data["contact"]["address"][key] == value + else: + pytest.fail(f"Invalid contact key: {key}") @pytest.mark.parametrize( - "body, expected_status", - [ - ({"organisation_id": 42, "contact_type": "billing"}, 404), - ({"organisation_id": 1, "contact_type": "security"}, 200), - ({"organisation_id": 1, "contact_type": "owner"}, 200), - ({"organisation_id": "Org One", "contact_type": "billing"}, 422), - ({"organisation_id": "", "contact_type": "billing"}, 422), - ({}, 422), - ({"organisation_id": 1, "contact_type": "not_real"}, 422), - ({"organisation_id": 1, "contact_type": 42}, 422), - ({"organisation_id": 1, "contact_type": ""}, 422), - ], + "body, expected_status", + [ + ({"organisation_id": 42, "contact_type": "billing"}, 404), + ({"organisation_id": 1, "contact_type": "security"}, 200), + ({"organisation_id": 1, "contact_type": "owner"}, 200), + ({"organisation_id": "Org One", "contact_type": "billing"}, 422), + ({"organisation_id": "", "contact_type": "billing"}, 422), + ({}, 422), + ({"organisation_id": 1, "contact_type": "not_real"}, 422), + ({"organisation_id": 1, "contact_type": 42}, 422), + ({"organisation_id": 1, "contact_type": ""}, 422), + ], ) @pytest.mark.anyio async def test_patch_org_contact_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.patch("/org/contact", json=body) + resp = await default_client.patch("/org/contact", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_delete_org_success(default_client: AsyncClient): - resp = await default_client.delete("/org?org_id=1") + resp = await default_client.delete("/org?org_id=1") - assert resp.status_code == 204 + assert resp.status_code == 204 @pytest.mark.anyio async def test_delete_org_users_success(default_client: AsyncClient): - resp = await default_client.delete("/org/user?org_id=1&user_id=2") + resp = await default_client.delete("/org/user?org_id=1&user_id=2") - assert resp.status_code == 204 + assert resp.status_code == 204 @pytest.mark.anyio async def test_delete_preapproval_org_success(default_client: AsyncClient): - resp = await default_client.delete("/org/self?org_id=3") + resp = await default_client.delete("/org/self?org_id=3") - assert resp.status_code == 204 + assert resp.status_code == 204 diff --git a/test/test_service.py b/test/test_service.py index 43e8bc4..9b9b98b 100644 --- a/test/test_service.py +++ b/test/test_service.py @@ -8,90 +8,90 @@ from httpx import AsyncClient from .conftest import generate_query_and_status, generate_body_and_status pytestmark = [ - pytest.mark.service_module, + pytest.mark.service_module, ] @pytest.mark.anyio async def test_get_services_success(default_client: AsyncClient): - resp = await default_client.get("/service?org_id=1") - data = resp.json() + resp = await default_client.get("/service?org_id=1") + data = resp.json() - assert resp.status_code == 200 - assert "services" in data - assert isinstance(data["services"], list) - assert data["services"][0]["id"] == 1 - assert data["services"][0]["name"] == "Test Service" + assert resp.status_code == 200 + assert "services" in data + assert isinstance(data["services"], list) + assert data["services"][0]["id"] == 1 + assert data["services"][0]["name"] == "Test Service" @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.anyio async def test_get_services_status_checks( - default_client: AsyncClient, query: str, expected_status: int + default_client: AsyncClient, query: str, expected_status: int ): - resp = await default_client.get(f"/service?{query}") + resp = await default_client.get(f"/service?{query}") - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_post_service_success(default_client: AsyncClient): - resp = await default_client.post("/service", json={"name": "New Test Service"}) - data = resp.json() + resp = await default_client.post("/service", json={"name": "New Test Service"}) + data = resp.json() - assert resp.status_code == 200 - assert "service" in data - assert isinstance(data["service"], dict) - assert data["service"]["name"] == "New Test Service" - assert data["service"]["id"] == 2 - assert isinstance(data["service"]["api_key"], str) + assert resp.status_code == 200 + assert "service" in data + assert isinstance(data["service"], dict) + assert data["service"]["name"] == "New Test Service" + assert data["service"]["id"] == 2 + assert isinstance(data["service"]["api_key"], str) @pytest.mark.parametrize("body, expected_status", generate_body_and_status({"name": "str"})) @pytest.mark.anyio async def test_post_service_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.post("/service", json=body) + resp = await default_client.post("/service", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_post_service_conflict(default_client: AsyncClient): - resp = await default_client.post("/service", json={"name": "Test Service"}) + resp = await default_client.post("/service", json={"name": "Test Service"}) - assert resp.status_code == 409 + assert resp.status_code == 409 @pytest.mark.anyio async def test_patch_service_success(default_client: AsyncClient): - resp = await default_client.patch("/service/key", json={"service_id": 1}) - data = resp.json() + resp = await default_client.patch("/service/key", json={"service_id": 1}) + data = resp.json() - assert resp.status_code == 200 - assert "service" in data - assert isinstance(data["service"], dict) - assert data["service"]["name"] == "Test Service" - assert data["service"]["id"] == 1 - assert isinstance(data["service"]["api_key"], str) + assert resp.status_code == 200 + assert "service" in data + assert isinstance(data["service"], dict) + assert data["service"]["name"] == "Test Service" + assert data["service"]["id"] == 1 + assert isinstance(data["service"]["api_key"], str) @pytest.mark.parametrize( - "body, expected_status", - generate_body_and_status({"service_id": "int"}), + "body, expected_status", + generate_body_and_status({"service_id": "int"}), ) @pytest.mark.anyio async def test_patch_services_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.patch("/service/key", json=body) + resp = await default_client.patch("/service/key", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_delete_service_success(default_client: AsyncClient): - resp = await default_client.delete("/service?service_id=1") + resp = await default_client.delete("/service?service_id=1") - assert resp.status_code == 204 + assert resp.status_code == 204 diff --git a/test/test_user.py b/test/test_user.py index a497fa9..b018841 100644 --- a/test/test_user.py +++ b/test/test_user.py @@ -11,191 +11,191 @@ from .conftest import generate_query_and_status pytestmark = [ - pytest.mark.user_module, + pytest.mark.user_module, ] @pytest.mark.anyio async def test_get_self_db_success(default_client: AsyncClient): - resp = await default_client.get("/user/self/db") - data = resp.json() + resp = await default_client.get("/user/self/db") + data = resp.json() - assert resp.status_code == 200 - assert data["first_name"] == "Admin" - assert data["last_name"] == "Test" - assert data["email"] == "admin@test.com" - assert "organisations" in data - assert isinstance(data["organisations"], list) - assert "groups" in data - assert isinstance(data["groups"], dict) + assert resp.status_code == 200 + assert data["first_name"] == "Admin" + assert data["last_name"] == "Test" + assert data["email"] == "admin@test.com" + assert "organisations" in data + assert isinstance(data["organisations"], list) + assert "groups" in data + assert isinstance(data["groups"], dict) @pytest.mark.anyio async def test_get_user_success(default_client: AsyncClient): - resp = await default_client.get("/user?user_id=1") - data = resp.json() + resp = await default_client.get("/user?user_id=1") + data = resp.json() - assert resp.status_code == 200 - assert data["first_name"] == "Admin" - assert data["last_name"] == "Test" - assert data["email"] == "admin@test.com" - assert "organisations" in data - assert isinstance(data["organisations"], list) - assert "groups" in data - assert isinstance(data["groups"], dict) + assert resp.status_code == 200 + assert data["first_name"] == "Admin" + assert data["last_name"] == "Test" + assert data["email"] == "admin@test.com" + assert "organisations" in data + assert isinstance(data["organisations"], list) + assert "groups" in data + assert isinstance(data["groups"], dict) @pytest.mark.anyio @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["user_id"])) async def test_get_user_status_checks( - default_client: AsyncClient, query: str, expected_status: int + default_client: AsyncClient, query: str, expected_status: int ): - resp = await default_client.get(f"/user?{query}") + resp = await default_client.get(f"/user?{query}") - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_delete_user_success(default_client: AsyncClient): - resp = await default_client.delete("/user?user_id=1") + resp = await default_client.delete("/user?user_id=1") - assert resp.status_code == 204 + assert resp.status_code == 204 @pytest.mark.anyio async def test_post_user_invitation_success(default_client: AsyncClient): - body = {"user_email": "admin@test.com", "organisation_id": 1} - resp = await default_client.post("/user/invitation", json=body) + body = {"user_email": "admin@test.com", "organisation_id": 1} + resp = await default_client.post("/user/invitation", json=body) - assert resp.status_code == 200 - data = resp.json() - assert "organisation" in data - assert isinstance(data["organisation"], dict) - assert data["organisation"]["id"] == 1 - assert data["organisation"]["name"] == "Org One" + assert resp.status_code == 200 + data = resp.json() + assert "organisation" in data + assert isinstance(data["organisation"], dict) + assert data["organisation"]["id"] == 1 + assert data["organisation"]["name"] == "Org One" - assert "invited_email" in data - assert isinstance(data["invited_email"], str) - assert data["invited_email"] == "admin@test.com" + assert "invited_email" in data + assert isinstance(data["invited_email"], str) + assert data["invited_email"] == "admin@test.com" @pytest.mark.parametrize( - "body, expected_status", - [ - ({"organisation_id": 42, "user_email": "admin@test.com"}, 404), - ({"organisation_id": "Test Org", "user_email": "admin@test.com"}, 422), - ({"organisation_id": "", "user_email": "admin@test.com"}, 422), - ({}, 422), - ({"user_email": 42}, 422), - ({"organisation_id": 1, "user_email": "Test User"}, 422), - ], + "body, expected_status", + [ + ({"organisation_id": 42, "user_email": "admin@test.com"}, 404), + ({"organisation_id": "Test Org", "user_email": "admin@test.com"}, 422), + ({"organisation_id": "", "user_email": "admin@test.com"}, 422), + ({}, 422), + ({"user_email": 42}, 422), + ({"organisation_id": 1, "user_email": "Test User"}, 422), + ], ) @pytest.mark.anyio async def test_post_user_invitation_status_checks( - default_client: AsyncClient, body, expected_status + default_client: AsyncClient, body, expected_status ): - resp = await default_client.post("/user/invitation", json=body) + resp = await default_client.post("/user/invitation", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.parametrize( - "body, expected_status", - [ - ({"jwt": "invalid"}, 401), - ({"jwt": ""}, 401), - ({"jwt": None}, 422), - ({"jwt": 42}, 422), - ], + "body, expected_status", + [ + ({"jwt": "invalid"}, 401), + ({"jwt": ""}, 401), + ({"jwt": None}, 422), + ({"jwt": 42}, 422), + ], ) @pytest.mark.anyio async def test_post_user_invitation_accept_status_checks( - default_client: AsyncClient, body, expected_status + default_client: AsyncClient, body, expected_status ): - resp = await default_client.post("/user/invitation/accept", json=body) + resp = await default_client.post("/user/invitation/accept", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status - if resp.status_code == 401: - assert resp.json()["detail"] == "Invalid JWS" + if resp.status_code == 401: + assert resp.json()["detail"] == "Invalid JWS" @pytest.mark.anyio async def test_get_self_orgs_success(default_client: AsyncClient): - resp = await default_client.get("/user/self/orgs") - assert resp.status_code == 200 + resp = await default_client.get("/user/self/orgs") + assert resp.status_code == 200 - data = resp.json() + data = resp.json() - assert "organisations" in data - assert isinstance(data["organisations"], list) - assert len(data["organisations"]) > 0 + assert "organisations" in data + assert isinstance(data["organisations"], list) + assert len(data["organisations"]) > 0 - org = data["organisations"][0] - assert org["organisation_id"] == 1 - assert org["name"] == "Org One" - assert org["status"] == "approved" - assert org["root_user_email"] == "admin@test.com" - assert "intake_questionnaire" in org - assert isinstance(org["intake_questionnaire"], dict) + org = data["organisations"][0] + assert org["organisation_id"] == 1 + assert org["name"] == "Org One" + assert org["status"] == "approved" + assert org["root_user_email"] == "admin@test.com" + assert "intake_questionnaire" in org + assert isinstance(org["intake_questionnaire"], dict) - assert isinstance(org["billing_contact"], dict) - assert org["billing_contact"]["email"] == "billing@orgone.com" - assert org["billing_contact"]["id"] == 1 + assert isinstance(org["billing_contact"], dict) + assert org["billing_contact"]["email"] == "billing@orgone.com" + assert org["billing_contact"]["id"] == 1 - assert isinstance(org["owner_contact"], dict) - assert org["owner_contact"]["email"] == "owner@orgone.com" - assert org["owner_contact"]["id"] == 2 + assert isinstance(org["owner_contact"], dict) + assert org["owner_contact"]["email"] == "owner@orgone.com" + assert org["owner_contact"]["id"] == 2 - assert isinstance(org["security_contact"], dict) - assert org["security_contact"]["email"] == "security@orgone.com" - assert org["security_contact"]["id"] == 3 + assert isinstance(org["security_contact"], dict) + assert org["security_contact"]["email"] == "security@orgone.com" + assert org["security_contact"]["id"] == 3 @pytest.mark.anyio async def test_get_self_orgs_dynamic(default_client: AsyncClient): - method = "GET" - path = "/user/self/orgs" - expected_data = { - "organisations": [ - { - "organisation_id": 1, - "name": "Org One", - "status": "approved", - "root_user_email": "admin@test.com", - "owner_contact": {"email": "owner@orgone.com", "id": 2}, - "security_contact": {"email": "security@orgone.com", "id": 3}, - "billing_contact": {"email": "billing@orgone.com", "id": 1}, - "intake_questionnaire": { - "questions": { - "question_one": None, - "question_three": None, - "question_two": "answer two", - }, - "metadata": {"version": 0, "submission_date": None}, - }, - } - ] - } + method = "GET" + path = "/user/self/orgs" + expected_data = { + "organisations": [ + { + "organisation_id": 1, + "name": "Org One", + "status": "approved", + "root_user_email": "admin@test.com", + "owner_contact": {"email": "owner@orgone.com", "id": 2}, + "security_contact": {"email": "security@orgone.com", "id": 3}, + "billing_contact": {"email": "billing@orgone.com", "id": 1}, + "intake_questionnaire": { + "questions": { + "question_one": None, + "question_three": None, + "question_two": "answer two", + }, + "metadata": {"version": 0, "submission_date": None}, + }, + } + ] + } - resp = await default_client.get(path) + resp = await default_client.get(path) - route = next( - route - for route in default_client._transport.app.routes # ty:ignore[unresolved-attribute] - if isinstance(route, APIRoute) and path in route.path and method in route.methods - ) + route = next( + route + for route in default_client._transport.app.routes # ty:ignore[unresolved-attribute] + if isinstance(route, APIRoute) and path in route.path and method in route.methods + ) - assert resp.status_code == route.status_code - if route.status_code == 204: - return + assert resp.status_code == route.status_code + if route.status_code == 204: + return - expected_response_schema = route.response_model - data = resp.json() + expected_response_schema = route.response_model + data = resp.json() - response_model = expected_response_schema(**data) - assert isinstance(response_model, expected_response_schema) + response_model = expected_response_schema(**data) + assert isinstance(response_model, expected_response_schema) - expected_response_model = expected_response_schema(**expected_data) + expected_response_model = expected_response_schema(**expected_data) - assert response_model == expected_response_model + assert response_model == expected_response_model