From 63e7d48c07004a0a217e6b2f2f866a2c1f5ab542 Mon Sep 17 00:00:00 2001 From: luxferre Date: Mon, 22 Jun 2026 12:04:39 +0100 Subject: [PATCH 01/33] ci: remove non-ty checks from ty job --- .forgejo/workflows/publish.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.forgejo/workflows/publish.yaml b/.forgejo/workflows/publish.yaml index 7abc689..85bbfcb 100644 --- a/.forgejo/workflows/publish.yaml +++ b/.forgejo/workflows/publish.yaml @@ -34,8 +34,6 @@ jobs: - run: uv python install # Gets Python version from pyproject.toml - run: uv sync --dev - run: uv run ty check - - run: uv run ruff format - - run: uv run pytest test env: ENVIRONMENT: testing From 02ddf9a3eddec09449085c43c1ffb6f258758b6d Mon Sep 17 00:00:00 2001 From: luxferre Date: Mon, 22 Jun 2026 12:06:43 +0100 Subject: [PATCH 02/33] fix: skip sending email process while running tests Removes the need for lettermint api key in CI. --- src/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/utils.py b/src/utils.py index b314d61..1aca16a 100644 --- a/src/utils.py +++ b/src/utils.py @@ -39,9 +39,12 @@ async def verify_email_token(user_model, token): async def send_email(recipient: str, subject: str, body: str): + if settings.ENVIRONMENT.is_testing: + return + lettermint = Lettermint(api_token=settings.LETTERMINT_API_TOKEN.get_secret_value()) - if settings.ENVIRONMENT.is_testing or settings.ENVIRONMENT == "local": + if settings.ENVIRONMENT == "local": recipient = "ok@testing.lettermint.co" try: From a9e539ef747c1775681f31dcabc4b8a1598473af Mon Sep 17 00:00:00 2001 From: irl Date: Mon, 22 Jun 2026 12:21:02 +0100 Subject: [PATCH 03/33] fix(user): simplify add_user --- src/auth/service.py | 4 ++-- src/user/service.py | 26 ++++++++------------------ 2 files changed, 10 insertions(+), 20 deletions(-) diff --git a/src/auth/service.py b/src/auth/service.py index 1b90b8c..2675dc2 100644 --- a/src/auth/service.py +++ b/src/auth/service.py @@ -22,7 +22,7 @@ from src.organisation.exceptions import AwaitingApprovalException from src.organisation.models import Organisation as Org from src.exceptions import UnauthorizedException, ForbiddenException from src.auth.config import auth_settings -from src.user.service import add_user_to_db +from src.user.service import add_user from src.database import db_dependency @@ -53,7 +53,7 @@ async def get_current_user( claims_requests.validate(token.claims) except ExpiredTokenError: raise UnauthorizedException(message="Token is expired") - db_id = await add_user_to_db(db, token.claims) + db_id = await add_user(db, token.claims) token.claims["db_id"] = db_id diff --git a/src/user/service.py b/src/user/service.py index ff1da8b..b70056f 100644 --- a/src/user/service.py +++ b/src/user/service.py @@ -1,8 +1,5 @@ """ Module specific business logic for user module - -Exports: - - add_user_to_db: Creates a User record from OIDC claims, or updates user details """ from typing import Any @@ -17,7 +14,7 @@ from src.user.schemas import OIDCUser from src.user.models import User -async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int: +async def add_user(db: Session, user_claims: dict[str, Any]) -> int: try: valid_user = OIDCUser( first_name=user_claims["given_name"], @@ -26,7 +23,7 @@ async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int: oidc_id=user_claims["sub"], ) except Exception as e: - print(e) + logging.exception(e) raise UnprocessableContentException("Invalid or missing OIDC data") db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first() @@ -37,19 +34,12 @@ async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int: user_id = user_model.id db.commit() return user_id - else: - user_id = db_user.id - change = False - if db_user.first_name != valid_user.first_name: - db_user.first_name = valid_user.first_name - change = True - if db_user.last_name != valid_user.last_name: - db_user.last_name = valid_user.last_name - change = True - if change: - db.add(db_user) - db.commit() - return user_id + + user_id = db_user.id + db_user.first_name = valid_user.first_name + db_user.last_name = valid_user.last_name + db.commit() + return user_id async def send_invitation(user_email: str, org_name: str, org_id: int): From 2a1d28bc5488c1b68bfffade3d1c444777949e47 Mon Sep 17 00:00:00 2001 From: irl Date: Mon, 22 Jun 2026 12:58:37 +0100 Subject: [PATCH 04/33] feat(db): db tuning options and consistency --- src/auth/service.py | 4 +- src/config.py | 4 ++ src/database.py | 74 ++++++++++++++++++++------------ src/iam/dependencies.py | 10 ++--- src/iam/router.py | 26 +++++------ src/iam/service.py | 6 +-- src/organisation/dependencies.py | 6 +-- src/organisation/router.py | 22 +++++----- src/service/dependencies.py | 6 +-- src/service/router.py | 12 +++--- src/user/dependencies.py | 8 ++-- src/user/router.py | 6 +-- 12 files changed, 104 insertions(+), 80 deletions(-) diff --git a/src/auth/service.py b/src/auth/service.py index 1b90b8c..130a0f2 100644 --- a/src/auth/service.py +++ b/src/auth/service.py @@ -23,7 +23,7 @@ from src.organisation.models import Organisation as Org from src.exceptions import UnauthorizedException, ForbiddenException from src.auth.config import auth_settings from src.user.service import add_user_to_db -from src.database import db_dependency +from src.database import DbSession oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG) @@ -35,7 +35,7 @@ async def get_dev_user(): async def get_current_user( - oidc_auth_string: oidc_dependency, db: db_dependency + oidc_auth_string: oidc_dependency, db: DbSession ) -> dict[str, Any]: config_url = urlopen(auth_settings.OIDC_CONFIG) config = json.loads(config_url.read()) diff --git a/src/config.py b/src/config.py index 6ebfcf5..6fc68f0 100644 --- a/src/config.py +++ b/src/config.py @@ -36,6 +36,10 @@ class Config(CustomBaseSettings): DATABASE_HOSTNAME: str = "localhost" DATABASE_CREDENTIALS: SecretStr = SecretStr(":") + DATABASE_POOL_SIZE: int = 16 + DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes + DATABASE_POOL_PRE_PING: bool = True + LETTERMINT_API_TOKEN: SecretStr = SecretStr("") diff --git a/src/database.py b/src/database.py index fb29a41..41509d4 100644 --- a/src/database.py +++ b/src/database.py @@ -1,13 +1,9 @@ """ -Database connections and init - -Exports: - - db_dependency - - Base (sqlalchemy base model) +Database connection and session utilities """ - -from typing import Annotated -from sqlalchemy import create_engine, StaticPool +from contextlib import contextmanager +from typing import Annotated, Generator +from sqlalchemy import create_engine, StaticPool, Connection from sqlalchemy.orm import sessionmaker, Session from fastapi import Depends @@ -16,28 +12,52 @@ from src.constants import Environment from src.config import SQLALCHEMY_DATABASE_URI, settings as global_settings if global_settings.ENVIRONMENT == Environment.TESTING: - connect_args = {"check_same_thread": False} - engine = create_engine( - SQLALCHEMY_DATABASE_URI.get_secret_value(), - connect_args=connect_args, - poolclass=StaticPool, - ) + connect_args = {"check_same_thread": False} + engine = create_engine( + SQLALCHEMY_DATABASE_URI.get_secret_value(), + connect_args=connect_args, + poolclass=StaticPool, + ) else: - engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value()) + engine = create_engine( + SQLALCHEMY_DATABASE_URI.get_secret_value(), + pool_size=global_settings.DATABASE_POOL_SIZE, + pool_recycle=global_settings.DATABASE_POOL_TTL, + pool_pre_ping=global_settings.DATABASE_POOL_PRE_PING, + ) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +sm = sessionmaker(autocommit=False, expire_on_commit=False, bind=engine) + +@contextmanager +def get_db_connection() -> Generator[Connection, None, None]: + with engine.connect() as connection: + try: + yield connection + except Exception: + connection.rollback() + raise + +def _get_db_connection() -> Generator[Connection, None]: + with get_db_connection() as connection: + yield connection + +DbConnection = Annotated[Connection, Depends(_get_db_connection)] + +@contextmanager +def get_db_session() -> Generator[Session, None, None]: + session = sm() + try: + yield session + except Exception: + session.rollback() + raise + finally: + session.close() -def get_db(): - db = SessionLocal() - try: - yield db - except: - db.rollback() - raise - finally: - db.close() +def _get_db_session() -> Generator[Session, None]: + with get_db_session() as session: + yield session - -db_dependency = Annotated[Session, Depends(get_db)] +DbSession = Annotated[Session, Depends(_get_db_session)] diff --git a/src/iam/dependencies.py b/src/iam/dependencies.py index 14fb4ad..113a3c6 100644 --- a/src/iam/dependencies.py +++ b/src/iam/dependencies.py @@ -11,7 +11,7 @@ from typing import Annotated, Optional from fastapi import Depends, Query -from src.database import db_dependency +from src.database import DbSession from src.iam.models import Group, Permission from src.iam.exceptions import GroupNotFoundException, PermNotFoundException @@ -19,7 +19,7 @@ from src.iam.schemas import GroupIDMixin, PermIDMixin def get_group_model_query( - db: db_dependency, group_id: Annotated[int, Query(gt=0)] + db: DbSession, group_id: Annotated[int, Query(gt=0)] ) -> Group: group_model = db.get(Group, group_id) if group_model is None: @@ -32,7 +32,7 @@ group_model_query_dependency = Annotated[Group, Depends(get_group_model_query)] def get_group_model_body( - db: db_dependency, request_model: Optional[GroupIDMixin] = None + db: DbSession, request_model: Optional[GroupIDMixin] = None ) -> Group: group_id = getattr(request_model, "group_id", None) if group_id is None: @@ -48,7 +48,7 @@ group_model_body_dependency = Annotated[Group, Depends(get_group_model_body)] def get_perm_model_body( - db: db_dependency, request_model: Optional[PermIDMixin] = None + db: DbSession, request_model: Optional[PermIDMixin] = None ) -> Permission: perm_id = getattr(request_model, "permission_id", None) if perm_id is None: @@ -64,7 +64,7 @@ perm_model_body_dependency = Annotated[Permission, Depends(get_perm_model_body)] def get_perm_model_query( - db: db_dependency, perm_id: Annotated[int, Query(gt=0)] + db: DbSession, perm_id: Annotated[int, Query(gt=0)] ) -> Permission: perm_model = db.get(Permission, perm_id) if perm_model is None: diff --git a/src/iam/router.py b/src/iam/router.py index af783a1..1b2c297 100644 --- a/src/iam/router.py +++ b/src/iam/router.py @@ -32,7 +32,7 @@ from src.exceptions import ( ForbiddenException, UnprocessableContentException, ) -from src.database import db_dependency +from src.database import DbSession from src.auth.service import claims_dependency from src.auth.dependencies import ( org_model_root_claim_query_dependency, @@ -107,7 +107,7 @@ router = APIRouter( ) async def can_act_on_resource( valid_key: service_key_dependency, - db: db_dependency, + db: DbSession, user_claims: claims_dependency, request_model: IAMCAoRRequest, ): @@ -270,7 +270,7 @@ async def get_group_users( }, ) async def create_group( - db: db_dependency, + db: DbSession, org_model: org_model_root_claim_body_dependency, request_model: IAMPostGroupRequest, ): @@ -310,7 +310,7 @@ async def create_group( }, ) async def add_group_permission( - db: db_dependency, + db: DbSession, group_model: group_model_body_dependency, perm_model: perm_model_body_dependency, org_model: org_model_root_claim_body_dependency, @@ -356,7 +356,7 @@ async def add_group_permission( }, ) async def add_group_user( - db: db_dependency, + db: DbSession, group_model: group_model_body_dependency, user_model: user_model_body_dependency, org_model: org_model_root_claim_body_dependency, @@ -399,7 +399,7 @@ async def add_group_user( }, ) async def remove_group_permission( - db: db_dependency, + db: DbSession, group_model: group_model_query_dependency, perm_model: perm_model_query_dependency, org_model: org_model_root_claim_query_dependency, @@ -436,7 +436,7 @@ async def remove_group_permission( }, ) async def remove_group_user( - db: db_dependency, + db: DbSession, group_model: group_model_query_dependency, user_model: user_model_query_dependency, org_model: org_model_root_claim_query_dependency, @@ -469,7 +469,7 @@ async def remove_group_user( }, ) async def get_permissions( - db: db_dependency, org_model: org_model_root_claim_query_dependency + db: DbSession, org_model: org_model_root_claim_query_dependency ): """ Returns a full list of permissions. @@ -493,7 +493,7 @@ async def get_permissions( }, ) async def create_new_permission( - db: db_dependency, + db: DbSession, su: super_admin_dependency, request_model: IAMPostPermissionRequest, service_model: service_model_body_dependency, # Used to verify service model exists @@ -529,7 +529,7 @@ async def create_new_permission( responses={}, ) async def delete_permission( - db: db_dependency, + db: DbSession, su: super_admin_dependency, perm_model: perm_model_query_dependency, ): @@ -548,7 +548,7 @@ async def delete_permission( responses={}, ) async def permissions_search( - db: db_dependency, + db: DbSession, org_model: org_model_root_claim_body_dependency, request_model: IAMGetPermissionsSearchRequest, ): @@ -632,7 +632,7 @@ async def invitation( }, ) async def accept_invitation( - db: db_dependency, + db: DbSession, user_model: user_model_claims_dependency, request_model: IAMPutGroupInvitationAcceptRequest, ): @@ -678,7 +678,7 @@ async def accept_invitation( }, ) async def add_org_permissions( - db: db_dependency, + db: DbSession, su: super_admin_dependency, org_model: org_model_body_dependency, request_model: IAMPutOrgPermissionsRequest, diff --git a/src/iam/service.py b/src/iam/service.py index 4f23969..2112c6c 100644 --- a/src/iam/service.py +++ b/src/iam/service.py @@ -10,7 +10,7 @@ from datetime import datetime, timedelta, timezone from fastapi import Request, Depends from sqlalchemy.orm import Session -from src.database import db_dependency +from src.database import DbSession from src.exceptions import UnauthorizedException from src.utils import send_email, generate_jwt from src.iam.models import Group @@ -23,7 +23,7 @@ from src.service.schemas import HasServiceName def valid_service_key( - db: db_dependency, request: Request, request_model: HasServiceName + db: DbSession, request: Request, request_model: HasServiceName ) -> bool: rn = request_model.rn api_key = request.headers.get("X-API-Key", None) @@ -90,7 +90,7 @@ async def create_group_and_assign_perms( async def assign_default_group( - db: db_dependency, + db: DbSession, org_model: Org, user_model: User, group_name: str, diff --git a/src/organisation/dependencies.py b/src/organisation/dependencies.py index 1ecdca8..2686560 100644 --- a/src/organisation/dependencies.py +++ b/src/organisation/dependencies.py @@ -10,14 +10,14 @@ from typing import Annotated, Optional from fastapi import Depends, Query -from src.database import db_dependency +from src.database import DbSession from src.organisation.schemas import OrgIDMixin from src.organisation.models import Organisation as Org from src.organisation.exceptions import OrgNotFoundException -def get_org_model_query(db: db_dependency, org_id: Annotated[int, Query(gt=0)]) -> Org: +def get_org_model_query(db: DbSession, org_id: Annotated[int, Query(gt=0)]) -> Org: org_model = db.get(Org, org_id) if org_model is None: raise OrgNotFoundException(org_id) @@ -27,7 +27,7 @@ def get_org_model_query(db: db_dependency, org_id: Annotated[int, Query(gt=0)]) org_model_query_dependency = Annotated[Org, Depends(get_org_model_query)] -def get_org_model_body(db: db_dependency, request_model: OrgIDMixin) -> Org: +def get_org_model_body(db: DbSession, request_model: OrgIDMixin) -> Org: org_id: Optional[int] = getattr(request_model, "organisation_id", None) if org_id is None: raise OrgNotFoundException() diff --git a/src/organisation/router.py b/src/organisation/router.py index d968e8e..5f58852 100644 --- a/src/organisation/router.py +++ b/src/organisation/router.py @@ -33,7 +33,7 @@ from src.exceptions import ( from src.contact.models import Contact from src.contact.schemas import ContactAddress from src.contact.exceptions import ContactNotFoundException -from src.database import db_dependency +from src.database import DbSession from src.organisation.schemas_questionnaires import QuestionnaireQuestionsVersion0 from src.organisation.service import assign_defaults from src.user.dependencies import ( @@ -98,7 +98,7 @@ router = APIRouter( }, ) async def get_org_by_id( - db: db_dependency, org_model: org_model_root_claim_query_dependency + db: DbSession, org_model: org_model_root_claim_query_dependency ): """ Returns organisation details including key member email addresses @@ -143,7 +143,7 @@ async def get_org_by_id( }, ) async def create_org( - db: db_dependency, + db: DbSession, user_model: user_model_claims_dependency, request_model: OrgPostOrgRequest, background_tasks: BackgroundTasks, @@ -217,7 +217,7 @@ async def create_org( }, ) async def update_questionnaire( - db: db_dependency, + db: DbSession, org_model: org_model_root_claim_body_dependency, request_model: OrgPatchQuestionnaireRequest, ): @@ -281,7 +281,7 @@ async def update_questionnaire( }, ) async def update_status( - db: db_dependency, + db: DbSession, org_model: org_model_body_dependency, su: super_admin_dependency, request_model: OrgPatchStatusRequest, @@ -338,7 +338,7 @@ async def get_users(org_model: org_model_root_claim_query_dependency): }, ) async def add_user_to_org( - db: db_dependency, + db: DbSession, org_model: org_model_body_dependency, user_model: user_model_body_dependency, su: super_admin_dependency, @@ -380,7 +380,7 @@ async def add_user_to_org( }, ) async def delete_organisation_by_id( - db: db_dependency, + db: DbSession, org_model: org_model_query_dependency, su: super_admin_dependency, ): @@ -450,7 +450,7 @@ async def delete_organisation_by_id( }, ) async def delete_preapproved_organisation_by_id( - db: db_dependency, + db: DbSession, org_model: org_model_root_claim_query_dependency, ): """ @@ -478,7 +478,7 @@ async def delete_preapproved_organisation_by_id( }, ) async def update_root_user( - db: db_dependency, + db: DbSession, org_model: org_model_body_dependency, user_model: user_model_body_dependency, su: super_admin_dependency, @@ -538,7 +538,7 @@ async def get_org_groups(org_model: org_model_root_claim_query_dependency): }, ) async def remove_user_from_org( - db: db_dependency, + db: DbSession, org_model: org_model_root_claim_query_dependency, user_model: user_model_query_dependency, ): @@ -609,7 +609,7 @@ async def get_contact( }, ) async def update_contact( - db: db_dependency, + db: DbSession, org_model: org_model_root_claim_body_dependency, request_model: OrgPatchContactRequest, ): diff --git a/src/service/dependencies.py b/src/service/dependencies.py index bf6b314..ea42c89 100644 --- a/src/service/dependencies.py +++ b/src/service/dependencies.py @@ -9,7 +9,7 @@ Exports: from typing import Annotated from fastapi import Depends, Query -from src.database import db_dependency +from src.database import DbSession from src.service.exceptions import ServiceNotFoundException from src.service.models import Service @@ -17,7 +17,7 @@ from src.service.schemas import ServiceIDMixin async def get_service_model_query( - db: db_dependency, service_id: Annotated[int, Query(gt=0)] + db: DbSession, service_id: Annotated[int, Query(gt=0)] ): service_model = db.get(Service, service_id) if service_model is None: @@ -29,7 +29,7 @@ async def get_service_model_query( service_model_query_dependency = Annotated[Service, Depends(get_service_model_query)] -async def get_service_model_body(db: db_dependency, request_model: ServiceIDMixin): +async def get_service_model_body(db: DbSession, request_model: ServiceIDMixin): service_model = db.get(Service, request_model.service_id) if service_model is None: raise ServiceNotFoundException(service_id=request_model.service_id) diff --git a/src/service/router.py b/src/service/router.py index 143fd7a..54a7a3a 100644 --- a/src/service/router.py +++ b/src/service/router.py @@ -13,7 +13,7 @@ from sqlalchemy.exc import IntegrityError from psycopg.errors import UniqueViolation from src.exceptions import ConflictException -from src.database import db_dependency +from src.database import DbSession from src.auth.dependencies import ( super_admin_dependency, org_model_root_claim_query_dependency, @@ -77,7 +77,7 @@ router = APIRouter( }, ) async def get_all_services( - db: db_dependency, org_model: org_model_root_claim_query_dependency + db: DbSession, org_model: org_model_root_claim_query_dependency ): """ Returns the ID and name of all services registered to the hub. @@ -99,7 +99,7 @@ async def get_all_services( }, ) async def register_service( - db: db_dependency, + db: DbSession, su: super_admin_dependency, request_model: ServicePostServiceRequest, ): @@ -135,7 +135,7 @@ async def register_service( }, ) async def regenerate_api_key( - db: db_dependency, + db: DbSession, su: super_admin_dependency, service_model: service_model_body_dependency, request_model: ServicePatchKeyRequest, @@ -162,7 +162,7 @@ async def regenerate_api_key( }, ) async def remove_service( - db: db_dependency, + db: DbSession, service_model: service_model_query_dependency, su: super_admin_dependency, ): @@ -185,7 +185,7 @@ async def remove_service( }, ) async def service_create_new_permissions( - db: db_dependency, + db: DbSession, request_model: ServicePostPermissionsRequest, valid_key: service_key_dependency, ): diff --git a/src/user/dependencies.py b/src/user/dependencies.py index 0d50daa..de23693 100644 --- a/src/user/dependencies.py +++ b/src/user/dependencies.py @@ -14,11 +14,11 @@ from src.user.exceptions import UserNotFoundException from src.user.models import User from src.auth.service import claims_dependency -from src.database import db_dependency +from src.database import DbSession from src.schemas import UserIDMixin -async def get_user_model_claims(claims: claims_dependency, db: db_dependency): +async def get_user_model_claims(claims: claims_dependency, db: DbSession): user_id = claims.get("db_id", None) if user_id is None: raise UserNotFoundException() @@ -33,7 +33,7 @@ async def get_user_model_claims(claims: claims_dependency, db: db_dependency): user_model_claims_dependency = Annotated[User, Depends(get_user_model_claims)] -async def get_user_model_query(db: db_dependency, user_id: Annotated[int, Query(gt=0)]): +async def get_user_model_query(db: DbSession, user_id: Annotated[int, Query(gt=0)]): user_model = db.get(User, user_id) if user_model is None: raise UserNotFoundException(user_id=user_id) @@ -44,7 +44,7 @@ async def get_user_model_query(db: db_dependency, user_id: Annotated[int, Query( user_model_query_dependency = Annotated[User, Depends(get_user_model_query)] -async def get_user_model_body(db: db_dependency, request_model: UserIDMixin): +async def get_user_model_body(db: DbSession, request_model: UserIDMixin): user_model = db.get(User, request_model.user_id) if user_model is None: raise UserNotFoundException(user_id=request_model.user_id) diff --git a/src/user/router.py b/src/user/router.py index a5ae4f5..ad0ccb1 100644 --- a/src/user/router.py +++ b/src/user/router.py @@ -33,7 +33,7 @@ from src.auth.dependencies import ( org_model_root_claim_body_dependency, ) from src.auth.service import claims_dependency -from src.database import db_dependency +from src.database import DbSession from src.utils import verify_email_token router = APIRouter( @@ -105,7 +105,7 @@ async def get_user_by_id( }, ) async def delete_user_by_id( - db: db_dependency, + db: DbSession, user_model: user_model_query_dependency, su: super_admin_dependency, ): @@ -186,7 +186,7 @@ async def invitation( response_model=UserPostInvitationAcceptResponse, ) async def accept_invitation( - db: db_dependency, + db: DbSession, user_model: user_model_claims_dependency, request_model: UserPostInvitationAcceptRequest, ): From c2777db2e313b236d180f4b5f4dd7a854a7f089a Mon Sep 17 00:00:00 2001 From: renovate Date: Wed, 20 May 2026 09:25:00 +0000 Subject: [PATCH 05/33] Add renovate.json --- renovate.json | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 renovate.json diff --git a/renovate.json b/renovate.json new file mode 100644 index 0000000..c433e3a --- /dev/null +++ b/renovate.json @@ -0,0 +1,8 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": [ + "config:recommended" + ], + "minimumReleaseAge": "14 days", + "gitAuthor": "Renovate" +} From fe8f627fa5599ebc6743b56934d1a1a772af19a1 Mon Sep 17 00:00:00 2001 From: irl Date: Mon, 22 Jun 2026 12:02:29 +0000 Subject: [PATCH 06/33] ci: reduce min age for renovate to 7 days --- renovate.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/renovate.json b/renovate.json index c433e3a..400c290 100644 --- a/renovate.json +++ b/renovate.json @@ -3,6 +3,6 @@ "extends": [ "config:recommended" ], - "minimumReleaseAge": "14 days", + "minimumReleaseAge": "7 days", "gitAuthor": "Renovate" } From 53b42b24dded04f051f24a83e3436fbc42374516 Mon Sep 17 00:00:00 2001 From: irl Date: Mon, 22 Jun 2026 13:26:47 +0100 Subject: [PATCH 07/33] feat(utils): use logging around email send --- src/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/utils.py b/src/utils.py index b314d61..d471cfa 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,3 +1,5 @@ +import logging + from lettermint import Lettermint, ValidationError from datetime import datetime, timezone from joserfc import jwt, jwk, errors @@ -52,8 +54,6 @@ async def send_email(recipient: str, subject: str, body: str): .text(body) .send() ) - - print(response.status_code) - except ValidationError: - # Error thrown if domain not approved for project - print("Lettermint validation error") + logging.info("Email sent to {} with subject {} (Status: {})".format(recipient, subject, response.status_code)) + except ValidationError as e: + logging.exception(e) From 0baa50d10f5a2f69f1c7003a9279599b6ac2c8f7 Mon Sep 17 00:00:00 2001 From: irl Date: Mon, 22 Jun 2026 13:30:53 +0100 Subject: [PATCH 08/33] misc: add frontend dir to .gitignore --- .gitignore | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 3dc86d6..8a25589 100644 --- a/.gitignore +++ b/.gitignore @@ -206,5 +206,7 @@ marimo/_static/ marimo/_lsp/ __marimo__/ +endpoints.txt -endpoints.txt \ No newline at end of file +# React Frontend +/frontend/ \ No newline at end of file From 7e1ab6c6ee145ecfdf17938fd00d1e362747671b Mon Sep 17 00:00:00 2001 From: luxferre Date: Mon, 22 Jun 2026 13:23:06 +0100 Subject: [PATCH 09/33] feat: db model mixins --- src/models.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/src/models.py b/src/models.py index f2467de..3f2295d 100644 --- a/src/models.py +++ b/src/models.py @@ -5,8 +5,8 @@ Global database models from datetime import datetime from typing import Any -from sqlalchemy import DateTime, JSON -from sqlalchemy.orm import DeclarativeBase +from sqlalchemy import DateTime, JSON, func +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column class CustomBase(DeclarativeBase): @@ -14,3 +14,24 @@ class CustomBase(DeclarativeBase): datetime: DateTime(timezone=True), dict[str, Any]: JSON, } + + +class ActivatedMixin: + active: Mapped[bool] = mapped_column(default=True) + + +class DeletedTimestampMixin: + deleted_at: Mapped[datetime | None] = mapped_column(nullable=True) + + +class DescriptionMixin: + description: Mapped[str] + + +class IdMixin: + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + + +class TimestampMixin: + created_at: Mapped[datetime] = mapped_column(default=func.now()) + updated_at: Mapped[datetime] = mapped_column(default=func.now(), onupdate=func.now()) From c28b4dc37b289b05ba16e77602dee65ad71c1cdf Mon Sep 17 00:00:00 2001 From: luxferre Date: Mon, 22 Jun 2026 13:45:37 +0100 Subject: [PATCH 10/33] feat: applied model mixins IdMixin used on every table with an ID index (no changes needed to db) Timestamp and Deleted mixins applied to org and user tables. ActivatedMixin added to users. --- alembic/versions/2026-06-22_model_mixins.py | 44 +++++++++++++++++++++ src/contact/models.py | 5 ++- src/iam/models.py | 9 ++--- src/organisation/models.py | 7 ++-- src/service/models.py | 5 +-- src/user/models.py | 5 ++- 6 files changed, 60 insertions(+), 15 deletions(-) create mode 100644 alembic/versions/2026-06-22_model_mixins.py diff --git a/alembic/versions/2026-06-22_model_mixins.py b/alembic/versions/2026-06-22_model_mixins.py new file mode 100644 index 0000000..db030e1 --- /dev/null +++ b/alembic/versions/2026-06-22_model_mixins.py @@ -0,0 +1,44 @@ +"""model mixins + +Revision ID: 661202797ecd +Revises: 869d48618a1c +Create Date: 2026-06-22 13:29:39.689067 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '661202797ecd' +down_revision: Union[str, Sequence[str], None] = '869d48618a1c' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('organisation', sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now())) + op.add_column('organisation', sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now())) + op.add_column('organisation', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True)) + op.add_column('user', sa.Column('active', sa.Boolean(), nullable=False, server_default=sa.false())) + op.add_column('user', sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now())) + op.add_column('user', sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now())) + op.add_column('user', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('user', 'deleted_at') + op.drop_column('user', 'updated_at') + op.drop_column('user', 'created_at') + op.drop_column('user', 'active') + op.drop_column('organisation', 'deleted_at') + op.drop_column('organisation', 'updated_at') + op.drop_column('organisation', 'created_at') + # ### end Alembic commands ### diff --git a/src/contact/models.py b/src/contact/models.py index ca359cd..37665e5 100644 --- a/src/contact/models.py +++ b/src/contact/models.py @@ -6,16 +6,17 @@ Models: street_address, street_address_line_2, post_office_box_number, address_locality, country_code, address_region, postal_code """ +from src.models import IdMixin + from sqlalchemy import ForeignKey from sqlalchemy.orm import mapped_column, Mapped from src.models import CustomBase -class Contact(CustomBase): +class Contact(CustomBase, IdMixin): __tablename__ = "contact" - id: Mapped[int] = mapped_column(primary_key=True) email: Mapped[str] = mapped_column(default=None, nullable=True) first_name: Mapped[str] = mapped_column(default=None, nullable=True) last_name: Mapped[str] = mapped_column(default=None, nullable=True) diff --git a/src/iam/models.py b/src/iam/models.py index 1f6d9ba..7fdb7c7 100644 --- a/src/iam/models.py +++ b/src/iam/models.py @@ -21,13 +21,12 @@ Models: from sqlalchemy import ForeignKey, UniqueConstraint from sqlalchemy.orm import relationship, mapped_column, Mapped -from src.models import CustomBase +from src.models import CustomBase, IdMixin -class Permission(CustomBase): +class Permission(CustomBase, IdMixin): __tablename__ = "permission" - id: Mapped[int] = mapped_column(primary_key=True) resource: Mapped[str] action: Mapped[str] @@ -61,9 +60,9 @@ class Permission(CustomBase): return self.service_rel.name -class Group(CustomBase): +class Group(CustomBase, IdMixin): __tablename__ = "group" - id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] org_id: Mapped[int] = mapped_column(ForeignKey("organisation.id", ondelete="CASCADE")) diff --git a/src/organisation/models.py b/src/organisation/models.py index e6f5acf..97d1247 100644 --- a/src/organisation/models.py +++ b/src/organisation/models.py @@ -14,18 +14,19 @@ Models: - OrgUsers: org_id[FK][PK], user_id[FK][PK] """ +from src.models import IdMixin, DeletedTimestampMixin + from typing import Any from sqlalchemy import ForeignKey from sqlalchemy.orm import relationship, Mapped, mapped_column -from src.models import CustomBase +from src.models import CustomBase, TimestampMixin -class Organisation(CustomBase): +class Organisation(CustomBase, IdMixin, TimestampMixin, DeletedTimestampMixin): __tablename__ = "organisation" - id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] status: Mapped[str] = mapped_column(default="partial") intake_questionnaire: Mapped[dict[str, Any] | None] diff --git a/src/service/models.py b/src/service/models.py index 63719a6..de6dcd6 100644 --- a/src/service/models.py +++ b/src/service/models.py @@ -8,13 +8,12 @@ Models: from sqlalchemy.orm import relationship, mapped_column, Mapped -from src.models import CustomBase +from src.models import CustomBase, IdMixin -class Service(CustomBase): +class Service(CustomBase, IdMixin): __tablename__ = "service" - id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(unique=True) api_key: Mapped[str] diff --git a/src/user/models.py b/src/user/models.py index 5f603ee..4803d54 100644 --- a/src/user/models.py +++ b/src/user/models.py @@ -10,6 +10,8 @@ Models: - groups: Calc property dict of {group_rel.org_rel.name: group_rel.name} """ +from src.models import IdMixin, ActivatedMixin, TimestampMixin, DeletedTimestampMixin + from collections import defaultdict from sqlalchemy.orm import relationship, mapped_column, Mapped @@ -17,10 +19,9 @@ from sqlalchemy.orm import relationship, mapped_column, Mapped from src.models import CustomBase -class User(CustomBase): +class User(CustomBase, IdMixin, ActivatedMixin, TimestampMixin, DeletedTimestampMixin): __tablename__ = "user" - id: Mapped[int] = mapped_column(primary_key=True) email: Mapped[str] first_name: Mapped[str] last_name: Mapped[str] From 9e1d6026b5545d33ad539176972a3dcbddc70e56 Mon Sep 17 00:00:00 2001 From: irl Date: Mon, 22 Jun 2026 14:24:34 +0100 Subject: [PATCH 11/33] feat: adds Containerfile with frontend serving --- .python-version | 2 +- Containerfile | 42 ++++++++++++++++++++++++++++++++++++++++++ pyproject.toml | 6 ++++-- src/main.py | 2 ++ uv.lock | 11 +++++++---- 5 files changed, 56 insertions(+), 7 deletions(-) create mode 100644 Containerfile diff --git a/.python-version b/.python-version index 6324d40..e4fba21 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.14 +3.12 diff --git a/Containerfile b/Containerfile new file mode 100644 index 0000000..ab6c9c4 --- /dev/null +++ b/Containerfile @@ -0,0 +1,42 @@ +FROM node:22-slim AS react-builder + +WORKDIR /app +COPY frontend/ /app/ +RUN --mount=type=cache,target=/root/.npm npm ci +RUN npm run build # Outputs to /app/dist + +FROM ghcr.io/astral-sh/uv:python3.12-trixie-slim AS python-builder + +ENV UV_PYTHON_DOWNLOADS=0 + +WORKDIR /app + +# Install dependencies first (layer caching) +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,source=uv.lock,target=uv.lock \ + --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ + uv sync --locked --no-install-project --no-editable + +# Copy project source and install the project itself +COPY ./ /app/ +RUN --mount=type=cache,target=/root/.cache/uv \ + uv sync --locked --no-editable + + +FROM python:3.12-slim-trixie + +WORKDIR /app + +COPY alembic /app/alembic +COPY alembic.ini /app +COPY src /app/src +COPY --from=python-builder /app/.venv /app/.venv +COPY --from=react-builder /app/dist /app/static + +# Ensure venv is on PATH +ENV PATH="/app/.venv/bin:$PATH" \ + UV_PYTHON_DOWNLOADS=0 + +EXPOSE 8000 + +CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/pyproject.toml b/pyproject.toml index bd23ff7..bb7d902 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ requires-python = ">=3.12" dependencies = [ "alembic>=1.18.4", "email-validator>=2.3.0", - "fastapi>=0.136.3", + "fastapi>=0.138.0", "httptools>=0.7.1", "httpx>=0.28.1", "itsdangerous>=2.2.0", @@ -34,11 +34,13 @@ line-length = 92 [tool.ruff.format] quote-style = "double" -indent-style = "tab" [tool.uv] add-bounds = "major" exclude-newer = "P2W" +exclude-newer-package = { + "fastapi" = "2026-06-22T00:00:00Z" +} [dependency-groups] dev = [ diff --git a/src/main.py b/src/main.py index bf671db..022cb6f 100644 --- a/src/main.py +++ b/src/main.py @@ -77,3 +77,5 @@ if settings.DISABLE_AUTH and (settings.ENVIRONMENT == Environment.LOCAL): app.include_router(api_router) + +app.frontend("/ui", directory="/app/static", fallback="index.html") diff --git a/uv.lock b/uv.lock index f4be9e1..147ce12 100644 --- a/uv.lock +++ b/uv.lock @@ -6,6 +6,9 @@ requires-python = ">=3.12" exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values. exclude-newer-span = "P2W" +[options.exclude-newer-package] +fastapi = "2026-06-22T00:00:00Z" + [[package]] name = "alembic" version = "1.18.4" @@ -238,7 +241,7 @@ dev = [ requires-dist = [ { name = "alembic", specifier = ">=1.18.4" }, { name = "email-validator", specifier = ">=2.3.0" }, - { name = "fastapi", specifier = ">=0.136.3" }, + { name = "fastapi", specifier = ">=0.138.0" }, { name = "httptools", specifier = ">=0.7.1" }, { name = "httpx", specifier = ">=0.28.1" }, { name = "itsdangerous", specifier = ">=2.2.0" }, @@ -349,7 +352,7 @@ wheels = [ [[package]] name = "fastapi" -version = "0.136.3" +version = "0.138.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "annotated-doc" }, @@ -358,9 +361,9 @@ dependencies = [ { name = "typing-extensions" }, { name = "typing-inspection" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/81/2d/ff8d91d7b564d464629a0fd50a4489c97fcb836ac230bf3a7269232a9b1f/fastapi-0.136.3.tar.gz", hash = "sha256:e487fae93ad408e6f47641ee4dfe389864fd7bec92e547ea8498fc13f43e83ab", size = 396410, upload-time = "2026-05-23T18:53:15.192Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5b/58/ff455d9fe47c60abadb34b9e05a304b1f05f5ab8000ac01565156b6f5e43/fastapi-0.138.0.tar.gz", hash = "sha256:d445a4877636ad191e7053e08c9bf98cb921a6756776848400bb773d1740c061", size = 419240, upload-time = "2026-06-20T01:18:05.259Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e0/82/45359b62a067409bd929ae8a56b8ed13e5a8c8a61194b3c236920999ab83/fastapi-0.136.3-py3-none-any.whl", hash = "sha256:3d2a69bdf04b7e9f3afa292c3bc7a98816bbfafa10bc9b45f3f3700d2f761620", size = 117481, upload-time = "2026-05-23T18:53:16.924Z" }, + { url = "https://files.pythonhosted.org/packages/6c/ff/8496d9847a5fedae775eb49460722d3efaa80487854273e9647ae876218c/fastapi-0.138.0-py3-none-any.whl", hash = "sha256:b6f54fd1bd72c80b0f899f172c61a600f6f7af9b43d4d772a018f35624048cb0", size = 126779, upload-time = "2026-06-20T01:18:03.483Z" }, ] [[package]] From df8ab32cb1b23efea62bf223aa788acee7463475 Mon Sep 17 00:00:00 2001 From: irl Date: Mon, 22 Jun 2026 14:38:11 +0100 Subject: [PATCH 12/33] ci: build and publish OCI image --- .forgejo/workflows/publish.yaml | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/.forgejo/workflows/publish.yaml b/.forgejo/workflows/publish.yaml index 85bbfcb..a16e5a9 100644 --- a/.forgejo/workflows/publish.yaml +++ b/.forgejo/workflows/publish.yaml @@ -52,3 +52,27 @@ jobs: - run: uv run pytest test env: ENVIRONMENT: testing + + build: + needs: [ ruff, ty, tests ] + if: ${{ always() && needs.ruff.result == 'success' && needs.ty.result == 'success' && needs.tests.result == 'success' }} + container: + image: ghcr.io/catthehacker/ubuntu:act-latest + options: -v /dind/docker.sock:/var/run/docker.sock + steps: + - name: Checkout the repo + uses: actions/checkout@v4 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Login to the registry + uses: docker/login-action@v3 + with: + registry: guardianproject.dev + username: irl + password: ${{ secrets.PACKAGE_TOKEN }} + - name: Build and push + uses: docker/build-push-action@v6 + with: + file: Containerfile + push: true + tags: guardianproject.dev/${{ github.repository }}:${{ github.branch }} From 1384ee7bd62e59a5df64dd09c86a781685644550 Mon Sep 17 00:00:00 2001 From: irl Date: Mon, 22 Jun 2026 14:40:02 +0100 Subject: [PATCH 13/33] feat: adds empty static directory for frontend --- static/.gitkeep | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 static/.gitkeep diff --git a/static/.gitkeep b/static/.gitkeep new file mode 100644 index 0000000..e69de29 From d395b01997b429ac056d84c620b5d3e9d4b1fa61 Mon Sep 17 00:00:00 2001 From: irl Date: Mon, 22 Jun 2026 14:42:13 +0100 Subject: [PATCH 14/33] fix: only serve frontend if present in prod --- src/main.py | 5 +++-- static/.gitkeep | 0 2 files changed, 3 insertions(+), 2 deletions(-) delete mode 100644 static/.gitkeep diff --git a/src/main.py b/src/main.py index 022cb6f..a96eddc 100644 --- a/src/main.py +++ b/src/main.py @@ -1,7 +1,7 @@ """ Application root file: Inits the FastAPI application """ - +import os.path from contextlib import asynccontextmanager from typing import AsyncGenerator @@ -78,4 +78,5 @@ if settings.DISABLE_AUTH and (settings.ENVIRONMENT == Environment.LOCAL): app.include_router(api_router) -app.frontend("/ui", directory="/app/static", fallback="index.html") +if os.path.exists("/app/static"): + app.frontend("/ui", directory="/app/static", fallback="index.html") diff --git a/static/.gitkeep b/static/.gitkeep deleted file mode 100644 index e69de29..0000000 From 40918fd8b8faebc18a3565b9660b41ed6322b5c7 Mon Sep 17 00:00:00 2001 From: luxferre Date: Mon, 22 Jun 2026 13:53:15 +0100 Subject: [PATCH 15/33] feat: delete org soft deletes --- src/organisation/router.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/organisation/router.py b/src/organisation/router.py index d968e8e..8bd5ad7 100644 --- a/src/organisation/router.py +++ b/src/organisation/router.py @@ -387,7 +387,8 @@ async def delete_organisation_by_id( """ Removes an organisation from the hub. """ - db.delete(org_model) + org_model.status = "removed" + org_model.deleted_at = datetime.now(tz=timezone.utc) db.commit() From 84ba3b6bee0dbc65abfed27206a35997c72dee42 Mon Sep 17 00:00:00 2001 From: irl Date: Mon, 22 Jun 2026 12:58:37 +0100 Subject: [PATCH 16/33] feat(db): db tuning options and consistency --- src/auth/service.py | 4 +- src/config.py | 4 ++ src/database.py | 74 ++++++++++++++++++++------------ src/iam/dependencies.py | 10 ++--- src/iam/router.py | 26 +++++------ src/iam/service.py | 6 +-- src/organisation/dependencies.py | 6 +-- src/organisation/router.py | 22 +++++----- src/service/dependencies.py | 6 +-- src/service/router.py | 12 +++--- src/user/dependencies.py | 8 ++-- src/user/router.py | 6 +-- 12 files changed, 104 insertions(+), 80 deletions(-) diff --git a/src/auth/service.py b/src/auth/service.py index 2675dc2..4b42590 100644 --- a/src/auth/service.py +++ b/src/auth/service.py @@ -23,7 +23,7 @@ from src.organisation.models import Organisation as Org from src.exceptions import UnauthorizedException, ForbiddenException from src.auth.config import auth_settings from src.user.service import add_user -from src.database import db_dependency +from src.database import DbSession oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG) @@ -35,7 +35,7 @@ async def get_dev_user(): async def get_current_user( - oidc_auth_string: oidc_dependency, db: db_dependency + oidc_auth_string: oidc_dependency, db: DbSession ) -> dict[str, Any]: config_url = urlopen(auth_settings.OIDC_CONFIG) config = json.loads(config_url.read()) diff --git a/src/config.py b/src/config.py index 6ebfcf5..6fc68f0 100644 --- a/src/config.py +++ b/src/config.py @@ -36,6 +36,10 @@ class Config(CustomBaseSettings): DATABASE_HOSTNAME: str = "localhost" DATABASE_CREDENTIALS: SecretStr = SecretStr(":") + DATABASE_POOL_SIZE: int = 16 + DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes + DATABASE_POOL_PRE_PING: bool = True + LETTERMINT_API_TOKEN: SecretStr = SecretStr("") diff --git a/src/database.py b/src/database.py index fb29a41..41509d4 100644 --- a/src/database.py +++ b/src/database.py @@ -1,13 +1,9 @@ """ -Database connections and init - -Exports: - - db_dependency - - Base (sqlalchemy base model) +Database connection and session utilities """ - -from typing import Annotated -from sqlalchemy import create_engine, StaticPool +from contextlib import contextmanager +from typing import Annotated, Generator +from sqlalchemy import create_engine, StaticPool, Connection from sqlalchemy.orm import sessionmaker, Session from fastapi import Depends @@ -16,28 +12,52 @@ from src.constants import Environment from src.config import SQLALCHEMY_DATABASE_URI, settings as global_settings if global_settings.ENVIRONMENT == Environment.TESTING: - connect_args = {"check_same_thread": False} - engine = create_engine( - SQLALCHEMY_DATABASE_URI.get_secret_value(), - connect_args=connect_args, - poolclass=StaticPool, - ) + connect_args = {"check_same_thread": False} + engine = create_engine( + SQLALCHEMY_DATABASE_URI.get_secret_value(), + connect_args=connect_args, + poolclass=StaticPool, + ) else: - engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value()) + engine = create_engine( + SQLALCHEMY_DATABASE_URI.get_secret_value(), + pool_size=global_settings.DATABASE_POOL_SIZE, + pool_recycle=global_settings.DATABASE_POOL_TTL, + pool_pre_ping=global_settings.DATABASE_POOL_PRE_PING, + ) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +sm = sessionmaker(autocommit=False, expire_on_commit=False, bind=engine) + +@contextmanager +def get_db_connection() -> Generator[Connection, None, None]: + with engine.connect() as connection: + try: + yield connection + except Exception: + connection.rollback() + raise + +def _get_db_connection() -> Generator[Connection, None]: + with get_db_connection() as connection: + yield connection + +DbConnection = Annotated[Connection, Depends(_get_db_connection)] + +@contextmanager +def get_db_session() -> Generator[Session, None, None]: + session = sm() + try: + yield session + except Exception: + session.rollback() + raise + finally: + session.close() -def get_db(): - db = SessionLocal() - try: - yield db - except: - db.rollback() - raise - finally: - db.close() +def _get_db_session() -> Generator[Session, None]: + with get_db_session() as session: + yield session - -db_dependency = Annotated[Session, Depends(get_db)] +DbSession = Annotated[Session, Depends(_get_db_session)] diff --git a/src/iam/dependencies.py b/src/iam/dependencies.py index 14fb4ad..113a3c6 100644 --- a/src/iam/dependencies.py +++ b/src/iam/dependencies.py @@ -11,7 +11,7 @@ from typing import Annotated, Optional from fastapi import Depends, Query -from src.database import db_dependency +from src.database import DbSession from src.iam.models import Group, Permission from src.iam.exceptions import GroupNotFoundException, PermNotFoundException @@ -19,7 +19,7 @@ from src.iam.schemas import GroupIDMixin, PermIDMixin def get_group_model_query( - db: db_dependency, group_id: Annotated[int, Query(gt=0)] + db: DbSession, group_id: Annotated[int, Query(gt=0)] ) -> Group: group_model = db.get(Group, group_id) if group_model is None: @@ -32,7 +32,7 @@ group_model_query_dependency = Annotated[Group, Depends(get_group_model_query)] def get_group_model_body( - db: db_dependency, request_model: Optional[GroupIDMixin] = None + db: DbSession, request_model: Optional[GroupIDMixin] = None ) -> Group: group_id = getattr(request_model, "group_id", None) if group_id is None: @@ -48,7 +48,7 @@ group_model_body_dependency = Annotated[Group, Depends(get_group_model_body)] def get_perm_model_body( - db: db_dependency, request_model: Optional[PermIDMixin] = None + db: DbSession, request_model: Optional[PermIDMixin] = None ) -> Permission: perm_id = getattr(request_model, "permission_id", None) if perm_id is None: @@ -64,7 +64,7 @@ perm_model_body_dependency = Annotated[Permission, Depends(get_perm_model_body)] def get_perm_model_query( - db: db_dependency, perm_id: Annotated[int, Query(gt=0)] + db: DbSession, perm_id: Annotated[int, Query(gt=0)] ) -> Permission: perm_model = db.get(Permission, perm_id) if perm_model is None: diff --git a/src/iam/router.py b/src/iam/router.py index af783a1..1b2c297 100644 --- a/src/iam/router.py +++ b/src/iam/router.py @@ -32,7 +32,7 @@ from src.exceptions import ( ForbiddenException, UnprocessableContentException, ) -from src.database import db_dependency +from src.database import DbSession from src.auth.service import claims_dependency from src.auth.dependencies import ( org_model_root_claim_query_dependency, @@ -107,7 +107,7 @@ router = APIRouter( ) async def can_act_on_resource( valid_key: service_key_dependency, - db: db_dependency, + db: DbSession, user_claims: claims_dependency, request_model: IAMCAoRRequest, ): @@ -270,7 +270,7 @@ async def get_group_users( }, ) async def create_group( - db: db_dependency, + db: DbSession, org_model: org_model_root_claim_body_dependency, request_model: IAMPostGroupRequest, ): @@ -310,7 +310,7 @@ async def create_group( }, ) async def add_group_permission( - db: db_dependency, + db: DbSession, group_model: group_model_body_dependency, perm_model: perm_model_body_dependency, org_model: org_model_root_claim_body_dependency, @@ -356,7 +356,7 @@ async def add_group_permission( }, ) async def add_group_user( - db: db_dependency, + db: DbSession, group_model: group_model_body_dependency, user_model: user_model_body_dependency, org_model: org_model_root_claim_body_dependency, @@ -399,7 +399,7 @@ async def add_group_user( }, ) async def remove_group_permission( - db: db_dependency, + db: DbSession, group_model: group_model_query_dependency, perm_model: perm_model_query_dependency, org_model: org_model_root_claim_query_dependency, @@ -436,7 +436,7 @@ async def remove_group_permission( }, ) async def remove_group_user( - db: db_dependency, + db: DbSession, group_model: group_model_query_dependency, user_model: user_model_query_dependency, org_model: org_model_root_claim_query_dependency, @@ -469,7 +469,7 @@ async def remove_group_user( }, ) async def get_permissions( - db: db_dependency, org_model: org_model_root_claim_query_dependency + db: DbSession, org_model: org_model_root_claim_query_dependency ): """ Returns a full list of permissions. @@ -493,7 +493,7 @@ async def get_permissions( }, ) async def create_new_permission( - db: db_dependency, + db: DbSession, su: super_admin_dependency, request_model: IAMPostPermissionRequest, service_model: service_model_body_dependency, # Used to verify service model exists @@ -529,7 +529,7 @@ async def create_new_permission( responses={}, ) async def delete_permission( - db: db_dependency, + db: DbSession, su: super_admin_dependency, perm_model: perm_model_query_dependency, ): @@ -548,7 +548,7 @@ async def delete_permission( responses={}, ) async def permissions_search( - db: db_dependency, + db: DbSession, org_model: org_model_root_claim_body_dependency, request_model: IAMGetPermissionsSearchRequest, ): @@ -632,7 +632,7 @@ async def invitation( }, ) async def accept_invitation( - db: db_dependency, + db: DbSession, user_model: user_model_claims_dependency, request_model: IAMPutGroupInvitationAcceptRequest, ): @@ -678,7 +678,7 @@ async def accept_invitation( }, ) async def add_org_permissions( - db: db_dependency, + db: DbSession, su: super_admin_dependency, org_model: org_model_body_dependency, request_model: IAMPutOrgPermissionsRequest, diff --git a/src/iam/service.py b/src/iam/service.py index 4f23969..2112c6c 100644 --- a/src/iam/service.py +++ b/src/iam/service.py @@ -10,7 +10,7 @@ from datetime import datetime, timedelta, timezone from fastapi import Request, Depends from sqlalchemy.orm import Session -from src.database import db_dependency +from src.database import DbSession from src.exceptions import UnauthorizedException from src.utils import send_email, generate_jwt from src.iam.models import Group @@ -23,7 +23,7 @@ from src.service.schemas import HasServiceName def valid_service_key( - db: db_dependency, request: Request, request_model: HasServiceName + db: DbSession, request: Request, request_model: HasServiceName ) -> bool: rn = request_model.rn api_key = request.headers.get("X-API-Key", None) @@ -90,7 +90,7 @@ async def create_group_and_assign_perms( async def assign_default_group( - db: db_dependency, + db: DbSession, org_model: Org, user_model: User, group_name: str, diff --git a/src/organisation/dependencies.py b/src/organisation/dependencies.py index 1ecdca8..2686560 100644 --- a/src/organisation/dependencies.py +++ b/src/organisation/dependencies.py @@ -10,14 +10,14 @@ from typing import Annotated, Optional from fastapi import Depends, Query -from src.database import db_dependency +from src.database import DbSession from src.organisation.schemas import OrgIDMixin from src.organisation.models import Organisation as Org from src.organisation.exceptions import OrgNotFoundException -def get_org_model_query(db: db_dependency, org_id: Annotated[int, Query(gt=0)]) -> Org: +def get_org_model_query(db: DbSession, org_id: Annotated[int, Query(gt=0)]) -> Org: org_model = db.get(Org, org_id) if org_model is None: raise OrgNotFoundException(org_id) @@ -27,7 +27,7 @@ def get_org_model_query(db: db_dependency, org_id: Annotated[int, Query(gt=0)]) org_model_query_dependency = Annotated[Org, Depends(get_org_model_query)] -def get_org_model_body(db: db_dependency, request_model: OrgIDMixin) -> Org: +def get_org_model_body(db: DbSession, request_model: OrgIDMixin) -> Org: org_id: Optional[int] = getattr(request_model, "organisation_id", None) if org_id is None: raise OrgNotFoundException() diff --git a/src/organisation/router.py b/src/organisation/router.py index 8bd5ad7..78159ca 100644 --- a/src/organisation/router.py +++ b/src/organisation/router.py @@ -33,7 +33,7 @@ from src.exceptions import ( from src.contact.models import Contact from src.contact.schemas import ContactAddress from src.contact.exceptions import ContactNotFoundException -from src.database import db_dependency +from src.database import DbSession from src.organisation.schemas_questionnaires import QuestionnaireQuestionsVersion0 from src.organisation.service import assign_defaults from src.user.dependencies import ( @@ -98,7 +98,7 @@ router = APIRouter( }, ) async def get_org_by_id( - db: db_dependency, org_model: org_model_root_claim_query_dependency + db: DbSession, org_model: org_model_root_claim_query_dependency ): """ Returns organisation details including key member email addresses @@ -143,7 +143,7 @@ async def get_org_by_id( }, ) async def create_org( - db: db_dependency, + db: DbSession, user_model: user_model_claims_dependency, request_model: OrgPostOrgRequest, background_tasks: BackgroundTasks, @@ -217,7 +217,7 @@ async def create_org( }, ) async def update_questionnaire( - db: db_dependency, + db: DbSession, org_model: org_model_root_claim_body_dependency, request_model: OrgPatchQuestionnaireRequest, ): @@ -281,7 +281,7 @@ async def update_questionnaire( }, ) async def update_status( - db: db_dependency, + db: DbSession, org_model: org_model_body_dependency, su: super_admin_dependency, request_model: OrgPatchStatusRequest, @@ -338,7 +338,7 @@ async def get_users(org_model: org_model_root_claim_query_dependency): }, ) async def add_user_to_org( - db: db_dependency, + db: DbSession, org_model: org_model_body_dependency, user_model: user_model_body_dependency, su: super_admin_dependency, @@ -380,7 +380,7 @@ async def add_user_to_org( }, ) async def delete_organisation_by_id( - db: db_dependency, + db: DbSession, org_model: org_model_query_dependency, su: super_admin_dependency, ): @@ -451,7 +451,7 @@ async def delete_organisation_by_id( }, ) async def delete_preapproved_organisation_by_id( - db: db_dependency, + db: DbSession, org_model: org_model_root_claim_query_dependency, ): """ @@ -479,7 +479,7 @@ async def delete_preapproved_organisation_by_id( }, ) async def update_root_user( - db: db_dependency, + db: DbSession, org_model: org_model_body_dependency, user_model: user_model_body_dependency, su: super_admin_dependency, @@ -539,7 +539,7 @@ async def get_org_groups(org_model: org_model_root_claim_query_dependency): }, ) async def remove_user_from_org( - db: db_dependency, + db: DbSession, org_model: org_model_root_claim_query_dependency, user_model: user_model_query_dependency, ): @@ -610,7 +610,7 @@ async def get_contact( }, ) async def update_contact( - db: db_dependency, + db: DbSession, org_model: org_model_root_claim_body_dependency, request_model: OrgPatchContactRequest, ): diff --git a/src/service/dependencies.py b/src/service/dependencies.py index bf6b314..ea42c89 100644 --- a/src/service/dependencies.py +++ b/src/service/dependencies.py @@ -9,7 +9,7 @@ Exports: from typing import Annotated from fastapi import Depends, Query -from src.database import db_dependency +from src.database import DbSession from src.service.exceptions import ServiceNotFoundException from src.service.models import Service @@ -17,7 +17,7 @@ from src.service.schemas import ServiceIDMixin async def get_service_model_query( - db: db_dependency, service_id: Annotated[int, Query(gt=0)] + db: DbSession, service_id: Annotated[int, Query(gt=0)] ): service_model = db.get(Service, service_id) if service_model is None: @@ -29,7 +29,7 @@ async def get_service_model_query( service_model_query_dependency = Annotated[Service, Depends(get_service_model_query)] -async def get_service_model_body(db: db_dependency, request_model: ServiceIDMixin): +async def get_service_model_body(db: DbSession, request_model: ServiceIDMixin): service_model = db.get(Service, request_model.service_id) if service_model is None: raise ServiceNotFoundException(service_id=request_model.service_id) diff --git a/src/service/router.py b/src/service/router.py index 143fd7a..54a7a3a 100644 --- a/src/service/router.py +++ b/src/service/router.py @@ -13,7 +13,7 @@ from sqlalchemy.exc import IntegrityError from psycopg.errors import UniqueViolation from src.exceptions import ConflictException -from src.database import db_dependency +from src.database import DbSession from src.auth.dependencies import ( super_admin_dependency, org_model_root_claim_query_dependency, @@ -77,7 +77,7 @@ router = APIRouter( }, ) async def get_all_services( - db: db_dependency, org_model: org_model_root_claim_query_dependency + db: DbSession, org_model: org_model_root_claim_query_dependency ): """ Returns the ID and name of all services registered to the hub. @@ -99,7 +99,7 @@ async def get_all_services( }, ) async def register_service( - db: db_dependency, + db: DbSession, su: super_admin_dependency, request_model: ServicePostServiceRequest, ): @@ -135,7 +135,7 @@ async def register_service( }, ) async def regenerate_api_key( - db: db_dependency, + db: DbSession, su: super_admin_dependency, service_model: service_model_body_dependency, request_model: ServicePatchKeyRequest, @@ -162,7 +162,7 @@ async def regenerate_api_key( }, ) async def remove_service( - db: db_dependency, + db: DbSession, service_model: service_model_query_dependency, su: super_admin_dependency, ): @@ -185,7 +185,7 @@ async def remove_service( }, ) async def service_create_new_permissions( - db: db_dependency, + db: DbSession, request_model: ServicePostPermissionsRequest, valid_key: service_key_dependency, ): diff --git a/src/user/dependencies.py b/src/user/dependencies.py index 0d50daa..de23693 100644 --- a/src/user/dependencies.py +++ b/src/user/dependencies.py @@ -14,11 +14,11 @@ from src.user.exceptions import UserNotFoundException from src.user.models import User from src.auth.service import claims_dependency -from src.database import db_dependency +from src.database import DbSession from src.schemas import UserIDMixin -async def get_user_model_claims(claims: claims_dependency, db: db_dependency): +async def get_user_model_claims(claims: claims_dependency, db: DbSession): user_id = claims.get("db_id", None) if user_id is None: raise UserNotFoundException() @@ -33,7 +33,7 @@ async def get_user_model_claims(claims: claims_dependency, db: db_dependency): user_model_claims_dependency = Annotated[User, Depends(get_user_model_claims)] -async def get_user_model_query(db: db_dependency, user_id: Annotated[int, Query(gt=0)]): +async def get_user_model_query(db: DbSession, user_id: Annotated[int, Query(gt=0)]): user_model = db.get(User, user_id) if user_model is None: raise UserNotFoundException(user_id=user_id) @@ -44,7 +44,7 @@ async def get_user_model_query(db: db_dependency, user_id: Annotated[int, Query( user_model_query_dependency = Annotated[User, Depends(get_user_model_query)] -async def get_user_model_body(db: db_dependency, request_model: UserIDMixin): +async def get_user_model_body(db: DbSession, request_model: UserIDMixin): user_model = db.get(User, request_model.user_id) if user_model is None: raise UserNotFoundException(user_id=request_model.user_id) diff --git a/src/user/router.py b/src/user/router.py index a5ae4f5..ad0ccb1 100644 --- a/src/user/router.py +++ b/src/user/router.py @@ -33,7 +33,7 @@ from src.auth.dependencies import ( org_model_root_claim_body_dependency, ) from src.auth.service import claims_dependency -from src.database import db_dependency +from src.database import DbSession from src.utils import verify_email_token router = APIRouter( @@ -105,7 +105,7 @@ async def get_user_by_id( }, ) async def delete_user_by_id( - db: db_dependency, + db: DbSession, user_model: user_model_query_dependency, su: super_admin_dependency, ): @@ -186,7 +186,7 @@ async def invitation( response_model=UserPostInvitationAcceptResponse, ) async def accept_invitation( - db: db_dependency, + db: DbSession, user_model: user_model_claims_dependency, request_model: UserPostInvitationAcceptRequest, ): From a343b76f63be88cbf92cdb5d3f77f4a1800b7405 Mon Sep 17 00:00:00 2001 From: luxferre Date: Mon, 22 Jun 2026 15:01:36 +0100 Subject: [PATCH 17/33] fix: invalid toml syntax --- pyproject.toml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bb7d902..0bcba23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,9 +38,7 @@ quote-style = "double" [tool.uv] add-bounds = "major" exclude-newer = "P2W" -exclude-newer-package = { - "fastapi" = "2026-06-22T00:00:00Z" -} +exclude-newer-package = { "fastapi" = "2026-06-22T00:00:00Z" } [dependency-groups] dev = [ From 1a851859d04df23b4a92f732ceb2f2a27dcd05d6 Mon Sep 17 00:00:00 2001 From: luxferre Date: Mon, 22 Jun 2026 15:02:04 +0100 Subject: [PATCH 18/33] fix: logging import for email --- src/user/service.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/user/service.py b/src/user/service.py index b70056f..8721b88 100644 --- a/src/user/service.py +++ b/src/user/service.py @@ -5,6 +5,7 @@ Module specific business logic for user module from typing import Any from datetime import datetime, timedelta, timezone +import logging from sqlalchemy.orm import Session from src.exceptions import UnprocessableContentException From b2921b73b8f2912088b28e18ac7fd19150be8465 Mon Sep 17 00:00:00 2001 From: luxferre Date: Mon, 22 Jun 2026 15:02:39 +0100 Subject: [PATCH 19/33] fix: conftest match db changes --- test/conftest.py | 100 ++++++++++++++++++++++++----------------------- 1 file changed, 51 insertions(+), 49 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index 6411b96..6821b9d 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -14,7 +14,7 @@ from src.iam.models import Group, Permission, OrgPermissions from src.auth.service import get_current_user, get_dev_user from src.auth.dependencies import empty_su_list, get_super_admin_list, testing_su_list from src.main import app # inited FastAPI app -from src.database import engine, get_db +from src.database import engine, get_db_session from src.models import CustomBase SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) @@ -35,51 +35,51 @@ def db_session(): @pytest.fixture async def default_client(db_session) -> AsyncGenerator[AsyncClient, None]: - def get_db_override(): - return db_session + def get_db_override(): + return db_session - app.dependency_overrides[get_db] = get_db_override - app.dependency_overrides[get_current_user] = get_dev_user - app.dependency_overrides[get_super_admin_list] = testing_su_list - transport = ASGITransport(app=app) - async with AsyncClient( - transport=transport, base_url="http://localhost:8000/api/v1" - ) as ac: - yield ac + app.dependency_overrides[get_db_session] = get_db_override + app.dependency_overrides[get_current_user] = get_dev_user + app.dependency_overrides[get_super_admin_list] = testing_su_list + transport = ASGITransport(app=app) + async with AsyncClient( + transport=transport, base_url="http://localhost:8000/api/v1" + ) as ac: + yield ac - app.dependency_overrides.clear() + app.dependency_overrides.clear() @pytest.fixture async def no_user_client(db_session) -> AsyncGenerator[AsyncClient, None]: - def get_db_override(): - return db_session + def get_db_override(): + return db_session - app.dependency_overrides[get_db] = get_db_override - transport = ASGITransport(app=app) - async with AsyncClient( - transport=transport, base_url="http://localhost:8000/api/v1" - ) as ac: - yield ac + app.dependency_overrides[get_db_session] = get_db_override + transport = ASGITransport(app=app) + async with AsyncClient( + transport=transport, base_url="http://localhost:8000/api/v1" + ) as ac: + yield ac - app.dependency_overrides.clear() + app.dependency_overrides.clear() @pytest.fixture async def no_su_client(db_session) -> AsyncGenerator[AsyncClient, None]: - def get_db_override(): - return db_session + def get_db_override(): + return db_session - app.dependency_overrides[get_db] = get_db_override - app.dependency_overrides[get_current_user] = get_dev_user - app.dependency_overrides[get_super_admin_list] = empty_su_list - transport = ASGITransport(app=app) - async with AsyncClient( - transport=transport, base_url="http://localhost:8000/api/v1" - ) as ac: - yield ac + app.dependency_overrides[get_db_session] = get_db_override + app.dependency_overrides[get_current_user] = get_dev_user + app.dependency_overrides[get_super_admin_list] = empty_su_list + transport = ASGITransport(app=app) + async with AsyncClient( + transport=transport, base_url="http://localhost:8000/api/v1" + ) as ac: + yield ac - app.dependency_overrides.clear() + app.dependency_overrides.clear() def _seed(db): @@ -256,27 +256,29 @@ def generate_body_and_status(params: dict[str, str]) -> list[tuple[dict, int]]: def get_testable_routes(): - routes = [] + routes = [] - for route in app.routes: - if not isinstance(route, APIRoute): - continue + for route in app.routes: + if not isinstance(route, APIRoute): + continue + if not route.methods: + continue - for method in route.methods: - if method in {"HEAD", "OPTIONS"}: - continue + for method in route.methods: + if method in {"HEAD", "OPTIONS"}: + continue - routes.append( - ( - method, - route.path, - route.status_code, - route.response_model, - route.summary, - ) - ) + routes.append( + ( + method, + route.path, + route.status_code, + route.response_model, + route.summary, + ) + ) - return routes + return routes # with open("endpoints.txt", "w") as f: From fab228bf8f78418702ca4e224038941f32ce14c3 Mon Sep 17 00:00:00 2001 From: luxferre Date: Mon, 22 Jun 2026 15:04:11 +0100 Subject: [PATCH 20/33] minor: ruff format Tabs -> spaces --- src/_module_template/router.py | 2 +- src/admin/router.py | 4 +- src/api.py | 14 +- src/auth/config.py | 6 +- src/auth/dependencies.py | 78 +- src/auth/router.py | 2 +- src/auth/service.py | 72 +- src/config.py | 46 +- src/constants.py | 42 +- src/contact/exceptions.py | 20 +- src/contact/models.py | 32 +- src/contact/router.py | 4 +- src/contact/schemas.py | 28 +- src/database.py | 6 + src/exceptions.py | 48 +- src/iam/dependencies.py | 56 +- src/iam/exceptions.py | 40 +- src/iam/models.py | 124 +-- src/iam/router.py | 1050 ++++++++++---------- src/iam/schemas.py | 134 +-- src/iam/service.py | 124 +-- src/main.py | 71 +- src/models.py | 20 +- src/organisation/constants.py | 70 +- src/organisation/dependencies.py | 22 +- src/organisation/exceptions.py | 40 +- src/organisation/models.py | 74 +- src/organisation/router.py | 1000 ++++++++++--------- src/organisation/schemas.py | 132 +-- src/organisation/schemas_questionnaires.py | 8 +- src/organisation/service.py | 82 +- src/schemas.py | 34 +- src/service/dependencies.py | 20 +- src/service/exceptions.py | 20 +- src/service/models.py | 12 +- src/service/router.py | 342 ++++--- src/service/schemas.py | 42 +- src/service/utils.py | 2 +- src/user/dependencies.py | 30 +- src/user/exceptions.py | 20 +- src/user/models.py | 36 +- src/user/router.py | 280 +++--- src/user/schemas.py | 78 +- src/user/service.py | 76 +- src/utils.py | 66 +- test/conftest.py | 320 +++--- test/test_auth_approval.py | 182 ++-- test/test_auth_general.py | 12 +- test/test_auth_root.py | 172 ++-- test/test_auth_su.py | 74 +- test/test_auth_user.py | 20 +- test/test_healthcheck.py | 6 +- test/test_iam.py | 1000 +++++++++---------- test/test_organisation.py | 644 ++++++------ test/test_service.py | 78 +- test/test_user.py | 242 ++--- 56 files changed, 3629 insertions(+), 3630 deletions(-) diff --git a/src/_module_template/router.py b/src/_module_template/router.py index 09250df..7d29b56 100644 --- a/src/_module_template/router.py +++ b/src/_module_template/router.py @@ -22,5 +22,5 @@ from fastapi import APIRouter router = APIRouter( - tags=[""], + tags=[""], ) diff --git a/src/admin/router.py b/src/admin/router.py index 9fe91eb..8742405 100644 --- a/src/admin/router.py +++ b/src/admin/router.py @@ -8,6 +8,6 @@ Exports: from fastapi import APIRouter router = APIRouter( - tags=["admin"], - prefix="/admin", + tags=["admin"], + prefix="/admin", ) diff --git a/src/api.py b/src/api.py index 0e46a53..9ad8b5d 100644 --- a/src/api.py +++ b/src/api.py @@ -26,15 +26,15 @@ api_router.include_router(iam_router) class HealthCheckResponse(CustomBaseModel): - status: str + status: str @api_router.get( - path="/healthcheck", - status_code=status.HTTP_200_OK, - response_model=HealthCheckResponse, - include_in_schema=False, + path="/healthcheck", + status_code=status.HTTP_200_OK, + response_model=HealthCheckResponse, + include_in_schema=False, ) def healthcheck(): - """Simple health check endpoint.""" - return {"status": "ok"} + """Simple health check endpoint.""" + return {"status": "ok"} diff --git a/src/auth/config.py b/src/auth/config.py index 030c36e..0e2734a 100644 --- a/src/auth/config.py +++ b/src/auth/config.py @@ -9,9 +9,9 @@ from src.config import CustomBaseSettings class AuthConfig(CustomBaseSettings): - OIDC_CONFIG: str = "" - OIDC_ISSUER: str = "" - CLIENT_ID: str = "" + OIDC_CONFIG: str = "" + OIDC_ISSUER: str = "" + CLIENT_ID: str = "" auth_settings = AuthConfig() diff --git a/src/auth/dependencies.py b/src/auth/dependencies.py index 7cf4e9f..65fb16d 100644 --- a/src/auth/dependencies.py +++ b/src/auth/dependencies.py @@ -16,92 +16,92 @@ from src.exceptions import ForbiddenException from src.user.dependencies import user_model_claims_dependency from src.user.models import User from src.organisation.dependencies import ( - org_model_query_dependency, - org_model_body_dependency, + org_model_query_dependency, + org_model_body_dependency, ) from src.organisation.models import Organisation as Org async def org_query_user_claims( - org_model: org_model_query_dependency, user_model: user_model_claims_dependency + org_model: org_model_query_dependency, user_model: user_model_claims_dependency ): - if user_model in org_model.user_rel: - return True + if user_model in org_model.user_rel: + return True - raise ForbiddenException() + raise ForbiddenException() org_query_user_claims_dependency = Annotated[bool, Depends(org_query_user_claims)] def get_super_admin_list(): - return [] + return [] def empty_su_list(): - return [] + return [] def testing_su_list(): - return ["admin@test.com"] + return ["admin@test.com"] su_list_dependency = Annotated[list[str | None], Depends(get_super_admin_list)] async def user_model_super_admin( - user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency + user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency ): - if user_model.email in super_admin_emails: - return user_model + if user_model.email in super_admin_emails: + return user_model - raise ForbiddenException(message="Must be super admin") + raise ForbiddenException(message="Must be super admin") super_admin_dependency = Annotated[User, Depends(user_model_super_admin)] async def org_query_root_claims( - user_model: user_model_claims_dependency, - org_model: org_model_query_dependency, - su_emails: su_list_dependency, - request: Request, + user_model: user_model_claims_dependency, + org_model: org_model_query_dependency, + su_emails: su_list_dependency, + request: Request, ): - try: - if await user_model_super_admin(user_model, su_emails): - return org_model - except ForbiddenException: - pass + try: + if await user_model_super_admin(user_model, su_emails): + return org_model + except ForbiddenException: + pass - await org_status_check(org_model, request) + await org_status_check(org_model, request) - if org_model.root_user_id == user_model.id: - return org_model + if org_model.root_user_id == user_model.id: + return org_model - raise ForbiddenException(message="Must be the org's root user") + raise ForbiddenException(message="Must be the org's root user") org_model_root_claim_query_dependency = Annotated[Org, Depends(org_query_root_claims)] async def org_body_root_claims( - user_model: user_model_claims_dependency, - org_model: org_model_body_dependency, - su_emails: su_list_dependency, - request: Request, + user_model: user_model_claims_dependency, + org_model: org_model_body_dependency, + su_emails: su_list_dependency, + request: Request, ): - try: - if await user_model_super_admin(user_model, su_emails): - return org_model - except ForbiddenException: - pass + try: + if await user_model_super_admin(user_model, su_emails): + return org_model + except ForbiddenException: + pass - await org_status_check(org_model, request) + await org_status_check(org_model, request) - if org_model.root_user_id == user_model.id: - return org_model + if org_model.root_user_id == user_model.id: + return org_model - raise ForbiddenException(message="Must be the org's root user") + raise ForbiddenException(message="Must be the org's root user") org_model_root_claim_body_dependency = Annotated[Org, Depends(org_body_root_claims)] diff --git a/src/auth/router.py b/src/auth/router.py index ee32033..b472586 100644 --- a/src/auth/router.py +++ b/src/auth/router.py @@ -8,5 +8,5 @@ Exports: from fastapi import APIRouter router = APIRouter( - tags=["auth"], + tags=["auth"], ) diff --git a/src/auth/service.py b/src/auth/service.py index 4b42590..1808341 100644 --- a/src/auth/service.py +++ b/src/auth/service.py @@ -31,56 +31,56 @@ oidc_dependency = Annotated[str, Depends(oidc)] async def get_dev_user(): - return {"db_id": 1, "email": "chris@sr2.uk"} + return {"db_id": 1, "email": "chris@sr2.uk"} async def get_current_user( - oidc_auth_string: oidc_dependency, db: DbSession + oidc_auth_string: oidc_dependency, db: DbSession ) -> dict[str, Any]: - config_url = urlopen(auth_settings.OIDC_CONFIG) - config = json.loads(config_url.read()) - jwks_uri = config["jwks_uri"] - key_response = requests.get(jwks_uri) - jwk_keys = KeySet.import_key_set(key_response.json()) + config_url = urlopen(auth_settings.OIDC_CONFIG) + config = json.loads(config_url.read()) + jwks_uri = config["jwks_uri"] + key_response = requests.get(jwks_uri) + jwk_keys = KeySet.import_key_set(key_response.json()) - token = jwt.decode(oidc_auth_string.replace("Bearer ", ""), jwk_keys) + token = jwt.decode(oidc_auth_string.replace("Bearer ", ""), jwk_keys) - claims_requests = jwt.JWTClaimsRegistry( - exp={"essential": True}, iss={"essential": True, "value": auth_settings.OIDC_ISSUER} - ) + claims_requests = jwt.JWTClaimsRegistry( + exp={"essential": True}, iss={"essential": True, "value": auth_settings.OIDC_ISSUER} + ) - try: - claims_requests.validate(token.claims) - except ExpiredTokenError: - raise UnauthorizedException(message="Token is expired") - db_id = await add_user(db, token.claims) + try: + claims_requests.validate(token.claims) + except ExpiredTokenError: + raise UnauthorizedException(message="Token is expired") + db_id = await add_user(db, token.claims) - token.claims["db_id"] = db_id + token.claims["db_id"] = db_id - return token.claims + return token.claims claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)] async def org_status_check(org_model: Org, request: Request): - org_status = OrgStatus(org_model.status) - if org_status.is_blocked: - raise ForbiddenException("This organisation cannot perform this action.") + org_status = OrgStatus(org_model.status) + if org_status.is_blocked: + raise ForbiddenException("This organisation cannot perform this action.") - root = "/api/v1" + root = "/api/v1" - pre_approval_endpoints = [ - f"PATCH{root}/org/status", - f"PATCH{root}/org/questionnaire", - f"GET{root}/org", - f"GET{root}/org/contact", - f"PATCH{root}/org/contact", - f"DELETE{root}/org/self", - ] - current_request = f"{request.method}{request.url.path}" - if ( - current_request not in pre_approval_endpoints - and org_model.status != OrgStatus.APPROVED - ): - raise AwaitingApprovalException(org_model.id) + pre_approval_endpoints = [ + f"PATCH{root}/org/status", + f"PATCH{root}/org/questionnaire", + f"GET{root}/org", + f"GET{root}/org/contact", + f"PATCH{root}/org/contact", + f"DELETE{root}/org/self", + ] + current_request = f"{request.method}{request.url.path}" + if ( + current_request not in pre_approval_endpoints + and org_model.status != OrgStatus.APPROVED + ): + raise AwaitingApprovalException(org_model.id) diff --git a/src/config.py b/src/config.py index 6fc68f0..1b46118 100644 --- a/src/config.py +++ b/src/config.py @@ -16,31 +16,31 @@ from src.constants import Environment class CustomBaseSettings(BaseSettings): - model_config = SettingsConfigDict( - env_file=".env", env_file_encoding="utf-8", extra="ignore" - ) + model_config = SettingsConfigDict( + env_file=".env", env_file_encoding="utf-8", extra="ignore" + ) class Config(CustomBaseSettings): - APP_VERSION: str = "0.1" - ENVIRONMENT: Environment = Environment.PRODUCTION - SECRET_KEY: SecretStr = SecretStr("") - DISABLE_AUTH: bool = False + APP_VERSION: str = "0.1" + ENVIRONMENT: Environment = Environment.PRODUCTION + SECRET_KEY: SecretStr = SecretStr("") + DISABLE_AUTH: bool = False - CORS_ORIGINS: list[str] = ["*"] - CORS_ORIGINS_REGEX: str | None = None - CORS_HEADERS: list[str] = ["*"] + CORS_ORIGINS: list[str] = ["*"] + CORS_ORIGINS_REGEX: str | None = None + CORS_HEADERS: list[str] = ["*"] - DATABASE_NAME: str = "fastapi-exp" - DATABASE_PORT: str = "5432" - DATABASE_HOSTNAME: str = "localhost" - DATABASE_CREDENTIALS: SecretStr = SecretStr(":") + DATABASE_NAME: str = "fastapi-exp" + DATABASE_PORT: str = "5432" + DATABASE_HOSTNAME: str = "localhost" + DATABASE_CREDENTIALS: SecretStr = SecretStr(":") - DATABASE_POOL_SIZE: int = 16 - DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes - DATABASE_POOL_PRE_PING: bool = True + DATABASE_POOL_SIZE: int = 16 + DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes + DATABASE_POOL_PRE_PING: bool = True - LETTERMINT_API_TOKEN: SecretStr = SecretStr("") + LETTERMINT_API_TOKEN: SecretStr = SecretStr("") settings = Config() @@ -51,20 +51,20 @@ DATABASE_HOSTNAME = settings.DATABASE_HOSTNAME DATABASE_CREDENTIALS = settings.DATABASE_CREDENTIALS.get_secret_value() # this will support special chars for credentials _DATABASE_CREDENTIAL_USER, _DATABASE_CREDENTIAL_PASSWORD = str(DATABASE_CREDENTIALS).split( - ":" + ":" ) _QUOTED_DATABASE_PASSWORD = parse.quote_plus(str(_DATABASE_CREDENTIAL_PASSWORD)) SQLALCHEMY_DATABASE_URI = SecretStr( - f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}" + f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}" ) if settings.ENVIRONMENT == Environment.TESTING: - SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:") + SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:") app_configs: dict[str, Any] = {"title": "App API"} if settings.ENVIRONMENT.is_deployed: - app_configs["root_path"] = f"/v{settings.APP_VERSION}" + app_configs["root_path"] = f"/v{settings.APP_VERSION}" if not settings.ENVIRONMENT.is_debug: - app_configs["openapi_url"] = None # hide docs + app_configs["openapi_url"] = None # hide docs diff --git a/src/constants.py b/src/constants.py index f9237e4..d18e300 100644 --- a/src/constants.py +++ b/src/constants.py @@ -9,29 +9,29 @@ from enum import StrEnum, auto class Environment(StrEnum): - """ - Enumeration of environments. + """ + Enumeration of environments. - Attributes: - LOCAL (str): Application is running locally - TESTING (str): Application is running in testing mode - STAGING (str): Application is running in staging mode (ie not testing) - PRODUCTION (str): Application is running in production mode - """ + Attributes: + LOCAL (str): Application is running locally + TESTING (str): Application is running in testing mode + STAGING (str): Application is running in staging mode (ie not testing) + PRODUCTION (str): Application is running in production mode + """ - LOCAL = auto() - TESTING = auto() - STAGING = auto() - PRODUCTION = auto() + LOCAL = auto() + TESTING = auto() + STAGING = auto() + PRODUCTION = auto() - @property - def is_debug(self): - return self in (self.LOCAL, self.STAGING, self.TESTING) + @property + def is_debug(self): + return self in (self.LOCAL, self.STAGING, self.TESTING) - @property - def is_testing(self): - return self == self.TESTING + @property + def is_testing(self): + return self == self.TESTING - @property - def is_deployed(self) -> bool: - return self in (self.STAGING, self.PRODUCTION) + @property + def is_deployed(self) -> bool: + return self in (self.STAGING, self.PRODUCTION) diff --git a/src/contact/exceptions.py b/src/contact/exceptions.py index 55e9e30..7c36941 100644 --- a/src/contact/exceptions.py +++ b/src/contact/exceptions.py @@ -11,13 +11,13 @@ from fastapi import HTTPException, status class ContactNotFoundException(HTTPException): - def __init__(self, contact_id: Optional[int] = None) -> None: - detail = ( - "Contact not found" - if contact_id is None - else f"Contact with ID '{contact_id}' was not found." - ) - super().__init__( - status_code=status.HTTP_404_NOT_FOUND, - detail=detail, - ) + def __init__(self, contact_id: Optional[int] = None) -> None: + detail = ( + "Contact not found" + if contact_id is None + else f"Contact with ID '{contact_id}' was not found." + ) + super().__init__( + status_code=status.HTTP_404_NOT_FOUND, + detail=detail, + ) diff --git a/src/contact/models.py b/src/contact/models.py index 37665e5..4bf3a79 100644 --- a/src/contact/models.py +++ b/src/contact/models.py @@ -15,22 +15,22 @@ from src.models import CustomBase class Contact(CustomBase, IdMixin): - __tablename__ = "contact" + __tablename__ = "contact" - email: Mapped[str] = mapped_column(default=None, nullable=True) - first_name: Mapped[str] = mapped_column(default=None, nullable=True) - last_name: Mapped[str] = mapped_column(default=None, nullable=True) - phonenumber: Mapped[str] = mapped_column(default=None, nullable=True) - vat_number: Mapped[str | None] = mapped_column(default=None, nullable=True) + email: Mapped[str] = mapped_column(default=None, nullable=True) + first_name: Mapped[str] = mapped_column(default=None, nullable=True) + last_name: Mapped[str] = mapped_column(default=None, nullable=True) + phonenumber: Mapped[str] = mapped_column(default=None, nullable=True) + vat_number: Mapped[str | None] = mapped_column(default=None, nullable=True) - street_address: Mapped[str] = mapped_column(default=None, nullable=True) - street_address_line_2: Mapped[str] = mapped_column(default=None, nullable=True) - post_office_box_number: Mapped[str | None] = mapped_column(default=None, nullable=True) - locality: Mapped[str] = mapped_column(default=None, nullable=True) # Ie City - country_code: Mapped[str] = mapped_column(default=None, nullable=True) # Eg GB - address_region: Mapped[str | None] = mapped_column(default=None, nullable=True) - postal_code: Mapped[str] = mapped_column(default=None, nullable=True) + street_address: Mapped[str] = mapped_column(default=None, nullable=True) + street_address_line_2: Mapped[str] = mapped_column(default=None, nullable=True) + post_office_box_number: Mapped[str | None] = mapped_column(default=None, nullable=True) + locality: Mapped[str] = mapped_column(default=None, nullable=True) # Ie City + country_code: Mapped[str] = mapped_column(default=None, nullable=True) # Eg GB + address_region: Mapped[str | None] = mapped_column(default=None, nullable=True) + postal_code: Mapped[str] = mapped_column(default=None, nullable=True) - org_id: Mapped[int] = mapped_column( - ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False - ) + org_id: Mapped[int] = mapped_column( + ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False + ) diff --git a/src/contact/router.py b/src/contact/router.py index 2e5f8f4..6f2376b 100644 --- a/src/contact/router.py +++ b/src/contact/router.py @@ -6,6 +6,6 @@ from fastapi import APIRouter router = APIRouter( - prefix="/contact", - tags=["contact"], + prefix="/contact", + tags=["contact"], ) diff --git a/src/contact/schemas.py b/src/contact/schemas.py index b9cec61..f9d4635 100644 --- a/src/contact/schemas.py +++ b/src/contact/schemas.py @@ -14,22 +14,22 @@ from src.schemas import CustomBaseModel class ContactAddress(CustomBaseModel): - model_config = ConfigDict(from_attributes=True, extra="ignore") + model_config = ConfigDict(from_attributes=True, extra="ignore") - post_office_box_number: Optional[str] = None - street_address: Optional[str] = None - street_address_line_2: Optional[str] = None - locality: Optional[str] = None - address_region: Optional[str] = None - country_code: Optional[str] = None - postal_code: Optional[str] = None + post_office_box_number: Optional[str] = None + street_address: Optional[str] = None + street_address_line_2: Optional[str] = None + locality: Optional[str] = None + address_region: Optional[str] = None + country_code: Optional[str] = None + postal_code: Optional[str] = None class ContactModel(CustomBaseModel): - email: Optional[EmailStr] = None - first_name: Optional[str] = None - last_name: Optional[str] = None - phonenumber: Optional[str] = None - vat_number: Optional[str] = None + email: Optional[EmailStr] = None + first_name: Optional[str] = None + last_name: Optional[str] = None + phonenumber: Optional[str] = None + vat_number: Optional[str] = None - address: ContactAddress + address: ContactAddress diff --git a/src/database.py b/src/database.py index 41509d4..01c55c8 100644 --- a/src/database.py +++ b/src/database.py @@ -1,6 +1,7 @@ """ Database connection and session utilities """ + from contextlib import contextmanager from typing import Annotated, Generator from sqlalchemy import create_engine, StaticPool, Connection @@ -29,6 +30,7 @@ else: sm = sessionmaker(autocommit=False, expire_on_commit=False, bind=engine) + @contextmanager def get_db_connection() -> Generator[Connection, None, None]: with engine.connect() as connection: @@ -38,12 +40,15 @@ def get_db_connection() -> Generator[Connection, None, None]: connection.rollback() raise + def _get_db_connection() -> Generator[Connection, None]: with get_db_connection() as connection: yield connection + DbConnection = Annotated[Connection, Depends(_get_db_connection)] + @contextmanager def get_db_session() -> Generator[Session, None, None]: session = sm() @@ -60,4 +65,5 @@ def _get_db_session() -> Generator[Session, None]: with get_db_session() as session: yield session + DbSession = Annotated[Session, Depends(_get_db_session)] diff --git a/src/exceptions.py b/src/exceptions.py index 1f5fdcb..f34ad35 100644 --- a/src/exceptions.py +++ b/src/exceptions.py @@ -12,36 +12,36 @@ from fastapi import HTTPException, status class UnprocessableContentException(HTTPException): - def __init__(self, message: Optional[str] = None) -> None: - detail = "Unprocessable content" if not message else message - super().__init__( - status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, - detail=detail, - ) + def __init__(self, message: Optional[str] = None) -> None: + detail = "Unprocessable content" if not message else message + super().__init__( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail=detail, + ) class ConflictException(HTTPException): - def __init__(self, message: Optional[str] = None) -> None: - detail = "Conflict" if not message else message - super().__init__( - status_code=status.HTTP_409_CONFLICT, - detail=detail, - ) + def __init__(self, message: Optional[str] = None) -> None: + detail = "Conflict" if not message else message + super().__init__( + status_code=status.HTTP_409_CONFLICT, + detail=detail, + ) class ForbiddenException(HTTPException): - def __init__(self, message: Optional[str] = None) -> None: - detail = "Forbidden" if not message else message - super().__init__( - status_code=status.HTTP_403_FORBIDDEN, - detail=detail, - ) + def __init__(self, message: Optional[str] = None) -> None: + detail = "Forbidden" if not message else message + super().__init__( + status_code=status.HTTP_403_FORBIDDEN, + detail=detail, + ) class UnauthorizedException(HTTPException): - def __init__(self, message: Optional[str] = None) -> None: - detail = "Not authorized" if not message else message - super().__init__( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=detail, - ) + def __init__(self, message: Optional[str] = None) -> None: + detail = "Not authorized" if not message else message + super().__init__( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=detail, + ) diff --git a/src/iam/dependencies.py b/src/iam/dependencies.py index 113a3c6..e559bc1 100644 --- a/src/iam/dependencies.py +++ b/src/iam/dependencies.py @@ -18,59 +18,55 @@ from src.iam.exceptions import GroupNotFoundException, PermNotFoundException from src.iam.schemas import GroupIDMixin, PermIDMixin -def get_group_model_query( - db: DbSession, group_id: Annotated[int, Query(gt=0)] -) -> Group: - group_model = db.get(Group, group_id) - if group_model is None: - raise GroupNotFoundException(group_id) +def get_group_model_query(db: DbSession, group_id: Annotated[int, Query(gt=0)]) -> Group: + group_model = db.get(Group, group_id) + if group_model is None: + raise GroupNotFoundException(group_id) - return group_model + return group_model group_model_query_dependency = Annotated[Group, Depends(get_group_model_query)] def get_group_model_body( - db: DbSession, request_model: Optional[GroupIDMixin] = None + db: DbSession, request_model: Optional[GroupIDMixin] = None ) -> Group: - group_id = getattr(request_model, "group_id", None) - if group_id is None: - raise GroupNotFoundException() - group_model = db.get(Group, group_id) - if group_model is None: - raise GroupNotFoundException(group_id) + group_id = getattr(request_model, "group_id", None) + if group_id is None: + raise GroupNotFoundException() + group_model = db.get(Group, group_id) + if group_model is None: + raise GroupNotFoundException(group_id) - return group_model + return group_model group_model_body_dependency = Annotated[Group, Depends(get_group_model_body)] def get_perm_model_body( - db: DbSession, request_model: Optional[PermIDMixin] = None + db: DbSession, request_model: Optional[PermIDMixin] = None ) -> Permission: - perm_id = getattr(request_model, "permission_id", None) - if perm_id is None: - raise PermNotFoundException - perm_model = db.get(Permission, perm_id) - if perm_model is None: - raise PermNotFoundException(perm_id) + perm_id = getattr(request_model, "permission_id", None) + if perm_id is None: + raise PermNotFoundException + perm_model = db.get(Permission, perm_id) + if perm_model is None: + raise PermNotFoundException(perm_id) - return perm_model + return perm_model perm_model_body_dependency = Annotated[Permission, Depends(get_perm_model_body)] -def get_perm_model_query( - db: DbSession, perm_id: Annotated[int, Query(gt=0)] -) -> Permission: - perm_model = db.get(Permission, perm_id) - if perm_model is None: - raise PermNotFoundException(perm_id) +def get_perm_model_query(db: DbSession, perm_id: Annotated[int, Query(gt=0)]) -> Permission: + perm_model = db.get(Permission, perm_id) + if perm_model is None: + raise PermNotFoundException(perm_id) - return perm_model + return perm_model perm_model_query_dependency = Annotated[Permission, Depends(get_perm_model_query)] diff --git a/src/iam/exceptions.py b/src/iam/exceptions.py index 503b844..7d38887 100644 --- a/src/iam/exceptions.py +++ b/src/iam/exceptions.py @@ -12,26 +12,26 @@ from fastapi import HTTPException, status class GroupNotFoundException(HTTPException): - def __init__(self, group_id: Optional[int] = None) -> None: - detail = ( - "Group not found" - if group_id is None - else f"User with ID '{group_id}' was not found." - ) - super().__init__( - status_code=status.HTTP_404_NOT_FOUND, - detail=detail, - ) + def __init__(self, group_id: Optional[int] = None) -> None: + detail = ( + "Group not found" + if group_id is None + else f"User with ID '{group_id}' was not found." + ) + super().__init__( + status_code=status.HTTP_404_NOT_FOUND, + detail=detail, + ) class PermNotFoundException(HTTPException): - def __init__(self, perm_id: Optional[int] = None) -> None: - detail = ( - "Permission not found" - if perm_id is None - else f"User with ID '{perm_id}' was not found." - ) - super().__init__( - status_code=status.HTTP_404_NOT_FOUND, - detail=detail, - ) + def __init__(self, perm_id: Optional[int] = None) -> None: + detail = ( + "Permission not found" + if perm_id is None + else f"User with ID '{perm_id}' was not found." + ) + super().__init__( + status_code=status.HTTP_404_NOT_FOUND, + detail=detail, + ) diff --git a/src/iam/models.py b/src/iam/models.py index 7fdb7c7..9a0fc36 100644 --- a/src/iam/models.py +++ b/src/iam/models.py @@ -25,90 +25,90 @@ from src.models import CustomBase, IdMixin class Permission(CustomBase, IdMixin): - __tablename__ = "permission" + __tablename__ = "permission" - resource: Mapped[str] - action: Mapped[str] + resource: Mapped[str] + action: Mapped[str] - service_id: Mapped[int] = mapped_column(ForeignKey("service.id", ondelete="CASCADE")) + service_id: Mapped[int] = mapped_column(ForeignKey("service.id", ondelete="CASCADE")) - __table_args__ = ( - UniqueConstraint( - "service_id", - "resource", - "action", - name="uniq_permission_resource_and_action", - ), - ) + __table_args__ = ( + UniqueConstraint( + "service_id", + "resource", + "action", + name="uniq_permission_resource_and_action", + ), + ) - service_rel = relationship( - "Service", - back_populates="permission_rel", - foreign_keys="Permission.service_id", - ) + service_rel = relationship( + "Service", + back_populates="permission_rel", + foreign_keys="Permission.service_id", + ) - group_rel = relationship( - "Group", secondary="group_permissions", back_populates="permission_rel" - ) + group_rel = relationship( + "Group", secondary="group_permissions", back_populates="permission_rel" + ) - org_rel = relationship( - "Organisation", secondary="org_permissions", back_populates="permission_rel" - ) + org_rel = relationship( + "Organisation", secondary="org_permissions", back_populates="permission_rel" + ) - @property - def service_name(self): - return self.service_rel.name + @property + def service_name(self): + return self.service_rel.name class Group(CustomBase, IdMixin): - __tablename__ = "group" + __tablename__ = "group" - name: Mapped[str] + name: Mapped[str] - org_id: Mapped[int] = mapped_column(ForeignKey("organisation.id", ondelete="CASCADE")) + org_id: Mapped[int] = mapped_column(ForeignKey("organisation.id", ondelete="CASCADE")) - __table_args__ = ( - UniqueConstraint( - "name", - "org_id", - name="uniq_group_name_org_id", - ), - ) + __table_args__ = ( + UniqueConstraint( + "name", + "org_id", + name="uniq_group_name_org_id", + ), + ) - user_rel = relationship("User", secondary="user_groups", back_populates="group_rel") + user_rel = relationship("User", secondary="user_groups", back_populates="group_rel") - org_rel = relationship("Organisation", back_populates="group_rel") + org_rel = relationship("Organisation", back_populates="group_rel") - permission_rel = relationship( - "Permission", secondary="group_permissions", back_populates="group_rel" - ) + permission_rel = relationship( + "Permission", secondary="group_permissions", back_populates="group_rel" + ) class GroupPermissions(CustomBase): - __tablename__ = "group_permissions" - group_id: Mapped[int] = mapped_column( - ForeignKey("group.id", ondelete="CASCADE"), primary_key=True - ) - permission_id: Mapped[int] = mapped_column( - ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True - ) + __tablename__ = "group_permissions" + group_id: Mapped[int] = mapped_column( + ForeignKey("group.id", ondelete="CASCADE"), primary_key=True + ) + permission_id: Mapped[int] = mapped_column( + ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True + ) class UserGroups(CustomBase): - __tablename__ = "user_groups" - user_id: Mapped[int] = mapped_column( - ForeignKey("user.id", ondelete="CASCADE"), primary_key=True - ) - group_id: Mapped[int] = mapped_column( - ForeignKey("group.id", ondelete="CASCADE"), primary_key=True - ) + __tablename__ = "user_groups" + user_id: Mapped[int] = mapped_column( + ForeignKey("user.id", ondelete="CASCADE"), primary_key=True + ) + group_id: Mapped[int] = mapped_column( + ForeignKey("group.id", ondelete="CASCADE"), primary_key=True + ) class OrgPermissions(CustomBase): - __tablename__ = "org_permissions" - org_id: Mapped[int] = mapped_column( - ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True - ) - permission_id: Mapped[int] = mapped_column( - ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True - ) + __tablename__ = "org_permissions" + org_id: Mapped[int] = mapped_column( + ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True + ) + permission_id: Mapped[int] = mapped_column( + ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True + ) diff --git a/src/iam/router.py b/src/iam/router.py index 1b2c297..2d50b36 100644 --- a/src/iam/router.py +++ b/src/iam/router.py @@ -28,674 +28,672 @@ from src.organisation.exceptions import OrgNotFoundException from src.schemas import GroupSummary, OrgSummary from src.service.dependencies import service_model_body_dependency from src.exceptions import ( - ConflictException, - ForbiddenException, - UnprocessableContentException, + ConflictException, + ForbiddenException, + UnprocessableContentException, ) from src.database import DbSession from src.auth.service import claims_dependency from src.auth.dependencies import ( - org_model_root_claim_query_dependency, - org_model_root_claim_body_dependency, - super_admin_dependency, + org_model_root_claim_query_dependency, + org_model_root_claim_body_dependency, + super_admin_dependency, ) from src.user.models import User from src.user.dependencies import ( - user_model_body_dependency, - user_model_query_dependency, - user_model_claims_dependency, + user_model_body_dependency, + user_model_query_dependency, + user_model_claims_dependency, ) from src.organisation.models import Organisation as Org from src.service.models import Service from src.iam.service import service_key_dependency, send_user_group_invitation from src.iam.models import ( - Permission as Perm, - GroupPermissions as GPerms, - Group, - UserGroups, + Permission as Perm, + GroupPermissions as GPerms, + Group, + UserGroups, ) from src.iam.dependencies import ( - group_model_query_dependency, - group_model_body_dependency, - perm_model_body_dependency, - perm_model_query_dependency, + group_model_query_dependency, + group_model_body_dependency, + perm_model_body_dependency, + perm_model_query_dependency, ) from src.iam.schemas import ( - IAMCAoRRequest, - IAMGetGroupPermissionsResponse, - IAMGetGroupUsersResponse, - IAMPostGroupRequest, - IAMPostGroupResponse, - IAMPutGroupPermissionRequest, - IAMPutGroupPermissionResponse, - IAMPutGroupUserRequest, - IAMPutGroupUserResponse, - IAMDeleteGroupPermissionResponse, - IAMDeleteGroupUserResponse, - IAMGetPermissionsResponse, - IAMPostPermissionRequest, - IAMPostPermissionResponse, - IAMGetPermissionsSearchRequest, - IAMGetPermissionsSearchResponse, - IAMPutGroupInvitationRequest, - IAMPutGroupInvitationAcceptRequest, - IAMCAoRResponse, - IAMPutGroupInvitationAcceptResponse, - IAMPutGroupInvitationResponse, - IAMPutOrgPermissionsRequest, - IAMPutOrgPermissionsResponse, + IAMCAoRRequest, + IAMGetGroupPermissionsResponse, + IAMGetGroupUsersResponse, + IAMPostGroupRequest, + IAMPostGroupResponse, + IAMPutGroupPermissionRequest, + IAMPutGroupPermissionResponse, + IAMPutGroupUserRequest, + IAMPutGroupUserResponse, + IAMDeleteGroupPermissionResponse, + IAMDeleteGroupUserResponse, + IAMGetPermissionsResponse, + IAMPostPermissionRequest, + IAMPostPermissionResponse, + IAMGetPermissionsSearchRequest, + IAMGetPermissionsSearchResponse, + IAMPutGroupInvitationRequest, + IAMPutGroupInvitationAcceptRequest, + IAMCAoRResponse, + IAMPutGroupInvitationAcceptResponse, + IAMPutGroupInvitationResponse, + IAMPutOrgPermissionsRequest, + IAMPutOrgPermissionsResponse, ) from src.utils import verify_email_token router = APIRouter( - tags=["IAM"], - prefix="/iam", + tags=["IAM"], + prefix="/iam", ) @router.post( - path="/can_act_on_resource", - summary="Used for services to check user access permission", - status_code=status.HTTP_200_OK, - response_model=IAMCAoRResponse, - responses={ - status.HTTP_401_UNAUTHORIZED: { - "description": "API Key missing or invalid | Issue verifying user OIDC claims" - }, - }, + path="/can_act_on_resource", + summary="Used for services to check user access permission", + status_code=status.HTTP_200_OK, + response_model=IAMCAoRResponse, + responses={ + status.HTTP_401_UNAUTHORIZED: { + "description": "API Key missing or invalid | Issue verifying user OIDC claims" + }, + }, ) async def can_act_on_resource( - valid_key: service_key_dependency, - db: DbSession, - user_claims: claims_dependency, - request_model: IAMCAoRRequest, + valid_key: service_key_dependency, + db: DbSession, + user_claims: claims_dependency, + request_model: IAMCAoRRequest, ): - """ - This endpoint is not meant for the Hub frontend to interact with.\n - Services accessing this endpoint must be already registered within the Hub and been issued an API key.\n - Resource Names have an instance property but permissions do not presently have that level of granularity.\n - """ - response = { - "allowed": False, - "rn": request_model.rn, - "action": "", - "user": {"id": 0, "email": ""}, - } + """ + This endpoint is not meant for the Hub frontend to interact with.\n + Services accessing this endpoint must be already registered within the Hub and been issued an API key.\n + Resource Names have an instance property but permissions do not presently have that level of granularity.\n + """ + response = { + "allowed": False, + "rn": request_model.rn, + "action": "", + "user": {"id": 0, "email": ""}, + } - try: - rn = request_model.rn - action = request_model.action - user_id = user_claims["db_id"] - rn_org = rn.organisation_id - rn_service = rn.service - rn_resource = rn.resource + try: + rn = request_model.rn + action = request_model.action + user_id = user_claims["db_id"] + rn_org = rn.organisation_id + rn_service = rn.service + rn_resource = rn.resource - response["user"] = {"id": user_id, "email": user_claims["email"]} - response["action"] = action - response["rn"] = rn + response["user"] = {"id": user_id, "email": user_claims["email"]} + response["action"] = action + response["rn"] = rn - result = ( - db.query(Perm) - .join(Service, Service.id == Perm.service_id) - .join(GPerms, GPerms.permission_id == Perm.id) - .join(Group, Group.id == GPerms.group_id) - .join(Org, Org.id == Group.org_id) - .join(UserGroups, UserGroups.group_id == Group.id) - .join(User, User.id == UserGroups.user_id) - .filter(User.id == user_id) - .filter(Org.id == rn_org) - .filter(Service.name == rn_service) - .filter(Perm.resource == rn_resource) - .filter(Perm.action == action) - ).first() + result = ( + db.query(Perm) + .join(Service, Service.id == Perm.service_id) + .join(GPerms, GPerms.permission_id == Perm.id) + .join(Group, Group.id == GPerms.group_id) + .join(Org, Org.id == Group.org_id) + .join(UserGroups, UserGroups.group_id == Group.id) + .join(User, User.id == UserGroups.user_id) + .filter(User.id == user_id) + .filter(Org.id == rn_org) + .filter(Service.name == rn_service) + .filter(Perm.resource == rn_resource) + .filter(Perm.action == action) + ).first() - if result: - response["allowed"] = True - else: - response["allowed"] = False - except Exception as e: - print(e) - response["allowed"] = False + if result: + response["allowed"] = True + else: + response["allowed"] = False + except Exception as e: + print(e) + response["allowed"] = False - return response + return response @router.get( - path="/group/permissions", - summary="Gets a list of permissions granted to a group", - status_code=status.HTTP_200_OK, - response_model=IAMGetGroupPermissionsResponse, - responses={ - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Unprocessable content.", - "content": { - "application/json": { - "examples": { - "org_id": {"summary": "Invalid or missing org ID."}, - "oidc_claims": {"summary": "Invalid or missing OIDC claims."}, - } - } - }, - }, - status.HTTP_401_UNAUTHORIZED: { - "description": "Unauthorized", - "content": { - "application/json": { - "examples": { - "awaiting_approval": { - "summary": "Organisation has not yet been approved." - }, - "expired_token": {"summary": "User token has expired."}, - "oidc": {"summary": "Failed to verify OIDC claims."}, - } - } - }, - }, - status.HTTP_403_FORBIDDEN: { - "description": "Forbidden", - "content": { - "application/json": { - "examples": { - "not_root": {"summary": "Not authorised. Must be root user."}, - } - } - }, - }, - status.HTTP_404_NOT_FOUND: { - "description": "Not found", - "content": { - "application/json": { - "examples": { - "db_id": {"summary": "User not found in db when checking claims."}, - "user_model": {"summary": "User model not found in db."}, - "org_model": {"summary": "Org model not found in db."}, - "group_model": {"summary": "Group model not found in db."}, - } - } - }, - }, - }, + path="/group/permissions", + summary="Gets a list of permissions granted to a group", + status_code=status.HTTP_200_OK, + response_model=IAMGetGroupPermissionsResponse, + responses={ + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Unprocessable content.", + "content": { + "application/json": { + "examples": { + "org_id": {"summary": "Invalid or missing org ID."}, + "oidc_claims": {"summary": "Invalid or missing OIDC claims."}, + } + } + }, + }, + status.HTTP_401_UNAUTHORIZED: { + "description": "Unauthorized", + "content": { + "application/json": { + "examples": { + "awaiting_approval": { + "summary": "Organisation has not yet been approved." + }, + "expired_token": {"summary": "User token has expired."}, + "oidc": {"summary": "Failed to verify OIDC claims."}, + } + } + }, + }, + status.HTTP_403_FORBIDDEN: { + "description": "Forbidden", + "content": { + "application/json": { + "examples": { + "not_root": {"summary": "Not authorised. Must be root user."}, + } + } + }, + }, + status.HTTP_404_NOT_FOUND: { + "description": "Not found", + "content": { + "application/json": { + "examples": { + "db_id": {"summary": "User not found in db when checking claims."}, + "user_model": {"summary": "User model not found in db."}, + "org_model": {"summary": "Org model not found in db."}, + "group_model": {"summary": "Group model not found in db."}, + } + } + }, + }, + }, ) async def get_group_permissions( - group_model: group_model_query_dependency, - org_model: org_model_root_claim_query_dependency, + group_model: group_model_query_dependency, + org_model: org_model_root_claim_query_dependency, ): - """ - Gets a list of permissions granted to the group. Also returns a summary for the org and group. - """ - if group_model.org_id != org_model.id: - raise ForbiddenException("Group does not belong to this organization") - return { - "organisation": org_model, - "group": group_model, - "permissions": group_model.permission_rel, - } + """ + Gets a list of permissions granted to the group. Also returns a summary for the org and group. + """ + if group_model.org_id != org_model.id: + raise ForbiddenException("Group does not belong to this organization") + return { + "organisation": org_model, + "group": group_model, + "permissions": group_model.permission_rel, + } @router.get( - path="/group/users", - summary="Gets a list of users assigned to a group", - status_code=status.HTTP_200_OK, - response_model=IAMGetGroupUsersResponse, - responses={ - status.HTTP_401_UNAUTHORIZED: { - "description": "Group does not belong to this organization" - }, - }, + path="/group/users", + summary="Gets a list of users assigned to a group", + status_code=status.HTTP_200_OK, + response_model=IAMGetGroupUsersResponse, + responses={ + status.HTTP_401_UNAUTHORIZED: { + "description": "Group does not belong to this organization" + }, + }, ) async def get_group_users( - group_model: group_model_query_dependency, - org_model: org_model_root_claim_query_dependency, + group_model: group_model_query_dependency, + org_model: org_model_root_claim_query_dependency, ): - """ - Gets a list of users assigned to the group. Also returns a summary for the org and group. - """ - if group_model.org_id != org_model.id: - raise ForbiddenException("Group does not belong to this organization") - return { - "organisation": org_model, - "group": group_model, - "users": group_model.user_rel, - } + """ + Gets a list of users assigned to the group. Also returns a summary for the org and group. + """ + if group_model.org_id != org_model.id: + raise ForbiddenException("Group does not belong to this organization") + return { + "organisation": org_model, + "group": group_model, + "users": group_model.user_rel, + } @router.post( - path="/group", - summary="Creates a new group", - status_code=status.HTTP_201_CREATED, - response_model=IAMPostGroupResponse, - responses={ - status.HTTP_409_CONFLICT: {"description": "Group with this name already exists"}, - }, + path="/group", + summary="Creates a new group", + status_code=status.HTTP_201_CREATED, + response_model=IAMPostGroupResponse, + responses={ + status.HTTP_409_CONFLICT: {"description": "Group with this name already exists"}, + }, ) async def create_group( - db: DbSession, - org_model: org_model_root_claim_body_dependency, - request_model: IAMPostGroupRequest, + db: DbSession, + org_model: org_model_root_claim_body_dependency, + request_model: IAMPostGroupRequest, ): - """ - Creates a new IAM group. - """ - group_model = Group(name=request_model.name, org_id=org_model.id) + """ + Creates a new IAM group. + """ + group_model = Group(name=request_model.name, org_id=org_model.id) - db.add(group_model) - try: - db.flush() - except IntegrityError as e: - if ( - isinstance(e.orig, UniqueViolation) # Postgres unique violation - or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation - ): - raise ConflictException("Group with this name already exists") - raise - group_response = GroupSummary(**group_model.__dict__) - org_response = OrgSummary(**org_model.__dict__) - db.commit() - return {"group": group_response, "organisation": org_response} + db.add(group_model) + try: + db.flush() + except IntegrityError as e: + if ( + isinstance(e.orig, UniqueViolation) # Postgres unique violation + or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation + ): + raise ConflictException("Group with this name already exists") + raise + group_response = GroupSummary(**group_model.__dict__) + org_response = OrgSummary(**org_model.__dict__) + db.commit() + return {"group": group_response, "organisation": org_response} @router.put( - path="/group/permission", - summary="Grants a permission to a group", - status_code=status.HTTP_200_OK, - response_model=IAMPutGroupPermissionResponse, - responses={ - status.HTTP_401_UNAUTHORIZED: { - "description": "Group does not belong to this organization" - }, - status.HTTP_409_CONFLICT: { - "description": "This permission is already granted to this group" - }, - }, + path="/group/permission", + summary="Grants a permission to a group", + status_code=status.HTTP_200_OK, + response_model=IAMPutGroupPermissionResponse, + responses={ + status.HTTP_401_UNAUTHORIZED: { + "description": "Group does not belong to this organization" + }, + status.HTTP_409_CONFLICT: { + "description": "This permission is already granted to this group" + }, + }, ) async def add_group_permission( - db: DbSession, - group_model: group_model_body_dependency, - perm_model: perm_model_body_dependency, - org_model: org_model_root_claim_body_dependency, - request_model: IAMPutGroupPermissionRequest, + db: DbSession, + group_model: group_model_body_dependency, + perm_model: perm_model_body_dependency, + org_model: org_model_root_claim_body_dependency, + request_model: IAMPutGroupPermissionRequest, ): - """ - Grants a permission to a group. Returns a list of the permissions in the group as well as a summary for the org and group. - """ - if group_model.org_id != org_model.id: - raise ForbiddenException("Group does not belong to this organization") + """ + Grants a permission to a group. Returns a list of the permissions in the group as well as a summary for the org and group. + """ + if group_model.org_id != org_model.id: + raise ForbiddenException("Group does not belong to this organization") - if perm_model in group_model.permission_rel: - raise ConflictException("Group already has this permission") + if perm_model in group_model.permission_rel: + raise ConflictException("Group already has this permission") - if perm_model not in org_model.permission_rel: # TODO: and not su - raise ForbiddenException("You cannot grant this permission") + if perm_model not in org_model.permission_rel: # TODO: and not su + raise ForbiddenException("You cannot grant this permission") - group_model.permission_rel.append(perm_model) + group_model.permission_rel.append(perm_model) - db.flush() - response = IAMPutGroupPermissionResponse( - organisation=OrgSummary(**org_model.__dict__), - group=GroupSummary(**group_model.__dict__), - permissions=group_model.permission_rel, - ) - db.commit() - return response + db.flush() + response = IAMPutGroupPermissionResponse( + organisation=OrgSummary(**org_model.__dict__), + group=GroupSummary(**group_model.__dict__), + permissions=group_model.permission_rel, + ) + db.commit() + return response @router.put( - path="/group/user", - summary="Directly adds a user to the group", - status_code=status.HTTP_200_OK, - response_model=IAMPutGroupUserResponse, - responses={ - status.HTTP_401_UNAUTHORIZED: { - "description": "Group not in org | User not authenticated | User does not have permission" - }, - status.HTTP_409_CONFLICT: {"description": "User is already in group"}, - status.HTTP_403_FORBIDDEN: { - "description": "Only existing org members can be added directly." - }, - }, + path="/group/user", + summary="Directly adds a user to the group", + status_code=status.HTTP_200_OK, + response_model=IAMPutGroupUserResponse, + responses={ + status.HTTP_401_UNAUTHORIZED: { + "description": "Group not in org | User not authenticated | User does not have permission" + }, + status.HTTP_409_CONFLICT: {"description": "User is already in group"}, + status.HTTP_403_FORBIDDEN: { + "description": "Only existing org members can be added directly." + }, + }, ) async def add_group_user( - db: DbSession, - group_model: group_model_body_dependency, - user_model: user_model_body_dependency, - org_model: org_model_root_claim_body_dependency, - request_model: IAMPutGroupUserRequest, + db: DbSession, + group_model: group_model_body_dependency, + user_model: user_model_body_dependency, + org_model: org_model_root_claim_body_dependency, + request_model: IAMPutGroupUserRequest, ): - """ - Directly adds an organisation member to a group.\n - To add a non-member, use an email invitation instead.\n - The user's email address must match the email on their OIDC profile. - """ - if group_model.org_id != org_model.id: - raise ForbiddenException("Group does not belong to this organization") + """ + Directly adds an organisation member to a group.\n + To add a non-member, use an email invitation instead.\n + The user's email address must match the email on their OIDC profile. + """ + if group_model.org_id != org_model.id: + raise ForbiddenException("Group does not belong to this organization") - if user_model in group_model.user_rel: - raise ConflictException("User already in group") + if user_model in group_model.user_rel: + raise ConflictException("User already in group") - if user_model not in org_model.user_rel: - raise ForbiddenException( - "Adding users directly can only be done with org members. Use email invitation instead." - ) + if user_model not in org_model.user_rel: + raise ForbiddenException( + "Adding users directly can only be done with org members. Use email invitation instead." + ) - group_model.user_rel.append(user_model) - db.flush() - response = IAMPutGroupUserResponse( - group=GroupSummary(**group_model.__dict__), users=group_model.user_rel - ) - db.commit() - return response + group_model.user_rel.append(user_model) + db.flush() + response = IAMPutGroupUserResponse( + group=GroupSummary(**group_model.__dict__), users=group_model.user_rel + ) + db.commit() + return response @router.delete( - path="/group/permission", - summary="Removes a permission from the group", - status_code=status.HTTP_200_OK, - response_model=IAMDeleteGroupPermissionResponse, - responses={ - status.HTTP_401_UNAUTHORIZED: { - "description": "Group not in org | User not authenticated | User does not have permission" - }, - }, + path="/group/permission", + summary="Removes a permission from the group", + status_code=status.HTTP_200_OK, + response_model=IAMDeleteGroupPermissionResponse, + responses={ + status.HTTP_401_UNAUTHORIZED: { + "description": "Group not in org | User not authenticated | User does not have permission" + }, + }, ) async def remove_group_permission( - db: DbSession, - group_model: group_model_query_dependency, - perm_model: perm_model_query_dependency, - org_model: org_model_root_claim_query_dependency, + db: DbSession, + group_model: group_model_query_dependency, + perm_model: perm_model_query_dependency, + org_model: org_model_root_claim_query_dependency, ): - """ - Removes a permission from the group. - """ - if group_model.org_id != org_model.id: - raise ForbiddenException("Group does not belong to this organization") + """ + Removes a permission from the group. + """ + if group_model.org_id != org_model.id: + raise ForbiddenException("Group does not belong to this organization") - if perm_model not in group_model.permission_rel: - raise UnprocessableContentException("Permission not granted to group") + if perm_model not in group_model.permission_rel: + raise UnprocessableContentException("Permission not granted to group") - group_model.permission_rel.remove(perm_model) - db.flush() - response = IAMDeleteGroupPermissionResponse( - group=GroupSummary(**group_model.__dict__), - permissions=group_model.permission_rel, - ) - db.commit() - return response + group_model.permission_rel.remove(perm_model) + db.flush() + response = IAMDeleteGroupPermissionResponse( + group=GroupSummary(**group_model.__dict__), + permissions=group_model.permission_rel, + ) + db.commit() + return response @router.delete( - path="/group/user", - summary="Removes a user from the group", - status_code=status.HTTP_200_OK, - response_model=IAMDeleteGroupUserResponse, - responses={ - status.HTTP_401_UNAUTHORIZED: { - "description": "User not authenticated | User does not have permission" - }, - status.HTTP_403_FORBIDDEN: {"description": "Group not in org"}, - }, + path="/group/user", + summary="Removes a user from the group", + status_code=status.HTTP_200_OK, + response_model=IAMDeleteGroupUserResponse, + responses={ + status.HTTP_401_UNAUTHORIZED: { + "description": "User not authenticated | User does not have permission" + }, + status.HTTP_403_FORBIDDEN: {"description": "Group not in org"}, + }, ) async def remove_group_user( - db: DbSession, - group_model: group_model_query_dependency, - user_model: user_model_query_dependency, - org_model: org_model_root_claim_query_dependency, + db: DbSession, + group_model: group_model_query_dependency, + user_model: user_model_query_dependency, + org_model: org_model_root_claim_query_dependency, ): - """ - Removes a user from the group. - """ - if group_model.org_id != org_model.id: - raise ForbiddenException("Group does not belong to this organization") + """ + Removes a user from the group. + """ + if group_model.org_id != org_model.id: + raise ForbiddenException("Group does not belong to this organization") - user_model.group_rel.remove(group_model) - db.flush() - response = IAMDeleteGroupUserResponse( - group=GroupSummary(**group_model.__dict__), users=group_model.user_rel - ) - db.commit() + user_model.group_rel.remove(group_model) + db.flush() + response = IAMDeleteGroupUserResponse( + group=GroupSummary(**group_model.__dict__), users=group_model.user_rel + ) + db.commit() - return response + return response @router.get( - path="/permissions", - summary="Returns a full list of permissions", - status_code=status.HTTP_200_OK, - response_model=IAMGetPermissionsResponse, - responses={ - status.HTTP_401_UNAUTHORIZED: { - "description": "User must be root user of an organisation." - }, - }, + path="/permissions", + summary="Returns a full list of permissions", + status_code=status.HTTP_200_OK, + response_model=IAMGetPermissionsResponse, + responses={ + status.HTTP_401_UNAUTHORIZED: { + "description": "User must be root user of an organisation." + }, + }, ) -async def get_permissions( - db: DbSession, org_model: org_model_root_claim_query_dependency -): - """ - Returns a full list of permissions. - """ - # TODO: if su: - # permission_models = db.query(Perm).all() - # else - permission_models = db.query(Perm).filter(Perm.org_rel.any(id=org_model.id)).all() - return {"permissions": permission_models} +async def get_permissions(db: DbSession, org_model: org_model_root_claim_query_dependency): + """ + Returns a full list of permissions. + """ + # TODO: if su: + # permission_models = db.query(Perm).all() + # else + permission_models = db.query(Perm).filter(Perm.org_rel.any(id=org_model.id)).all() + return {"permissions": permission_models} @router.post( - path="/permission", - summary="Creates a new permission for a service", - status_code=status.HTTP_201_CREATED, - response_model=IAMPostPermissionResponse, - responses={ - status.HTTP_401_UNAUTHORIZED: {"description": "Must be super user."}, - status.HTTP_404_NOT_FOUND: {"description": "Service does not exist"}, - status.HTTP_409_CONFLICT: {"description": "Permission already exists"}, - }, + path="/permission", + summary="Creates a new permission for a service", + status_code=status.HTTP_201_CREATED, + response_model=IAMPostPermissionResponse, + responses={ + status.HTTP_401_UNAUTHORIZED: {"description": "Must be super user."}, + status.HTTP_404_NOT_FOUND: {"description": "Service does not exist"}, + status.HTTP_409_CONFLICT: {"description": "Permission already exists"}, + }, ) async def create_new_permission( - db: DbSession, - su: super_admin_dependency, - request_model: IAMPostPermissionRequest, - service_model: service_model_body_dependency, # Used to verify service model exists + db: DbSession, + su: super_admin_dependency, + request_model: IAMPostPermissionRequest, + service_model: service_model_body_dependency, # Used to verify service model exists ): - """ - Allows a super admin to create a new IAM permission for a service. - """ - perm_model = Perm(**request_model.__dict__) - db.add(perm_model) - try: - db.flush() - except IntegrityError as e: - if ( - isinstance(e.orig, UniqueViolation) # Postgres unique violation - or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation - ): - raise ConflictException(message="Permission already exists") - raise - response = { - "id": perm_model.id, - "service_name": perm_model.service_name, - "resource": perm_model.resource, - "action": perm_model.action, - } - db.commit() - return {"permission": response} + """ + Allows a super admin to create a new IAM permission for a service. + """ + perm_model = Perm(**request_model.__dict__) + db.add(perm_model) + try: + db.flush() + except IntegrityError as e: + if ( + isinstance(e.orig, UniqueViolation) # Postgres unique violation + or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation + ): + raise ConflictException(message="Permission already exists") + raise + response = { + "id": perm_model.id, + "service_name": perm_model.service_name, + "resource": perm_model.resource, + "action": perm_model.action, + } + db.commit() + return {"permission": response} @router.delete( - path="/permission", - summary="Deletes a permission", - status_code=status.HTTP_204_NO_CONTENT, - responses={}, + path="/permission", + summary="Deletes a permission", + status_code=status.HTTP_204_NO_CONTENT, + responses={}, ) async def delete_permission( - db: DbSession, - su: super_admin_dependency, - perm_model: perm_model_query_dependency, + db: DbSession, + su: super_admin_dependency, + perm_model: perm_model_query_dependency, ): - """ - Allows a super admin to remove a permission. - """ - db.delete(perm_model) - db.commit() + """ + Allows a super admin to remove a permission. + """ + db.delete(perm_model) + db.commit() @router.post( - path="/permissions/search", - summary="Search list of permissions", - status_code=status.HTTP_200_OK, - response_model=IAMGetPermissionsSearchResponse, - responses={}, + path="/permissions/search", + summary="Search list of permissions", + status_code=status.HTTP_200_OK, + response_model=IAMGetPermissionsSearchResponse, + responses={}, ) async def permissions_search( - db: DbSession, - org_model: org_model_root_claim_body_dependency, - request_model: IAMGetPermissionsSearchRequest, + db: DbSession, + org_model: org_model_root_claim_body_dependency, + request_model: IAMGetPermissionsSearchRequest, ): - """ - Returns a list of permissions filtered by the queries provided.\n - If a query is null, it will be ignored. - """ - permission_query = db.query(Perm) + """ + Returns a list of permissions filtered by the queries provided.\n + If a query is null, it will be ignored. + """ + permission_query = db.query(Perm) - if not (request_model.service_id is None or request_model.service_id == ""): - permission_query = permission_query.filter( - Perm.service_id == request_model.service_id - ) + if not (request_model.service_id is None or request_model.service_id == ""): + permission_query = permission_query.filter( + Perm.service_id == request_model.service_id + ) - if not (request_model.resource is None or request_model.resource == ""): - permission_query = permission_query.filter(Perm.resource == request_model.resource) + if not (request_model.resource is None or request_model.resource == ""): + permission_query = permission_query.filter(Perm.resource == request_model.resource) - if not (request_model.action is None or request_model.action == ""): - permission_query = permission_query.filter(Perm.action == request_model.action) + if not (request_model.action is None or request_model.action == ""): + permission_query = permission_query.filter(Perm.action == request_model.action) - # TODO: if not su: - permission_query = permission_query.filter(Perm.org_rel.any(id=org_model.id)) + # TODO: if not su: + permission_query = permission_query.filter(Perm.org_rel.any(id=org_model.id)) - permission_models = permission_query.all() + permission_models = permission_query.all() - return {"permissions": permission_models} + return {"permissions": permission_models} @router.put( - path="/group/user/invitation", - summary="Send an email invitation for non-org member to join a group", - status_code=status.HTTP_200_OK, - response_model=IAMPutGroupInvitationResponse, - responses={}, + path="/group/user/invitation", + summary="Send an email invitation for non-org member to join a group", + status_code=status.HTTP_200_OK, + response_model=IAMPutGroupInvitationResponse, + responses={}, ) async def invitation( - background_tasks: BackgroundTasks, - org_model: org_model_root_claim_body_dependency, - group_model: group_model_body_dependency, - request_model: IAMPutGroupInvitationRequest, + background_tasks: BackgroundTasks, + org_model: org_model_root_claim_body_dependency, + group_model: group_model_body_dependency, + request_model: IAMPutGroupInvitationRequest, ): - """ - Sends an email invitation to join a group.\n - This is intended for inviting no-members to a group, giving them permission to access org resources.\n - i.e. Allowing somebody in a partner organisation to view metrics.\n - Can also be used for inviting organisaion members if needed. - """ - org_id: int = org_model.id - org_name: str = org_model.name - user_email = request_model.user_email - group_id: int = group_model.id - group_name: str = group_model.name + """ + Sends an email invitation to join a group.\n + This is intended for inviting no-members to a group, giving them permission to access org resources.\n + i.e. Allowing somebody in a partner organisation to view metrics.\n + Can also be used for inviting organisaion members if needed. + """ + org_id: int = org_model.id + org_name: str = org_model.name + user_email = request_model.user_email + group_id: int = group_model.id + group_name: str = group_model.name - background_tasks.add_task( - send_user_group_invitation, - org_id=org_id, - org_name=org_name, - user_email=user_email, - group_id=group_id, - group_name=group_name, - ) + background_tasks.add_task( + send_user_group_invitation, + org_id=org_id, + org_name=org_name, + user_email=user_email, + group_id=group_id, + group_name=group_name, + ) - response = { - "organisation": org_model, - "group": group_model, - "invited_email": user_email, - } + response = { + "organisation": org_model, + "group": group_model, + "invited_email": user_email, + } - return response + return response @router.put( - path="/group/user/invitation/accept", - summary="Accept email invitation to join an org's group", - status_code=status.HTTP_200_OK, - response_model=IAMPutGroupInvitationAcceptResponse, - responses={ - status.HTTP_404_NOT_FOUND: {"description": "User|Org|Group not found"}, - status.HTTP_403_FORBIDDEN: {"description": "Group and organisation do not match"}, - status.HTTP_409_CONFLICT: {"description": "User is already in the group"}, - }, + path="/group/user/invitation/accept", + summary="Accept email invitation to join an org's group", + status_code=status.HTTP_200_OK, + response_model=IAMPutGroupInvitationAcceptResponse, + responses={ + status.HTTP_404_NOT_FOUND: {"description": "User|Org|Group not found"}, + status.HTTP_403_FORBIDDEN: {"description": "Group and organisation do not match"}, + status.HTTP_409_CONFLICT: {"description": "User is already in the group"}, + }, ) async def accept_invitation( - db: DbSession, - user_model: user_model_claims_dependency, - request_model: IAMPutGroupInvitationAcceptRequest, + db: DbSession, + user_model: user_model_claims_dependency, + request_model: IAMPutGroupInvitationAcceptRequest, ): - """ - Accepts an invitation to join an org's group - """ - email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model) + """ + Accepts an invitation to join an org's group + """ + email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model) - org_model = db.get(Org, email_claims["org_id"]) - if org_model is None: - raise OrgNotFoundException(email_claims["org_id"]) + org_model = db.get(Org, email_claims["org_id"]) + if org_model is None: + raise OrgNotFoundException(email_claims["org_id"]) - group_model = db.get(Group, email_claims["group_id"]) - if group_model is None: - raise GroupNotFoundException(email_claims["group_id"]) + group_model = db.get(Group, email_claims["group_id"]) + if group_model is None: + raise GroupNotFoundException(email_claims["group_id"]) - if group_model not in org_model.group_rel: - raise ForbiddenException("Group and org do not match.") + if group_model not in org_model.group_rel: + raise ForbiddenException("Group and org do not match.") - if user_model in group_model.user_rel: - raise ConflictException("User already in group.") + if user_model in group_model.user_rel: + raise ConflictException("User already in group.") - group_model.user_rel.append(user_model) - db.flush() + group_model.user_rel.append(user_model) + db.flush() - response = { - "organisation": org_model, - "user": user_model, - "group": {"details": group_model, "permissions": group_model.permission_rel}, - } - db.commit() + response = { + "organisation": org_model, + "user": user_model, + "group": {"details": group_model, "permissions": group_model.permission_rel}, + } + db.commit() - return response + return response @router.put( - path="/org/permissions", - summary="Grants an org access to permissions", - status_code=status.HTTP_200_OK, - response_model=IAMPutOrgPermissionsResponse, - responses={ - status.HTTP_401_UNAUTHORIZED: {"description": "Must be super user."}, - }, + path="/org/permissions", + summary="Grants an org access to permissions", + status_code=status.HTTP_200_OK, + response_model=IAMPutOrgPermissionsResponse, + responses={ + status.HTTP_401_UNAUTHORIZED: {"description": "Must be super user."}, + }, ) async def add_org_permissions( - db: DbSession, - su: super_admin_dependency, - org_model: org_model_body_dependency, - request_model: IAMPutOrgPermissionsRequest, + db: DbSession, + su: super_admin_dependency, + org_model: org_model_body_dependency, + request_model: IAMPutOrgPermissionsRequest, ): - """ - Grants a permission to a group. Returns a list of the permissions in the group as well as a summary for the org and group. - """ - for permission in request_model.permissions: - perm_model = db.get(Perm, permission) + """ + Grants a permission to a group. Returns a list of the permissions in the group as well as a summary for the org and group. + """ + for permission in request_model.permissions: + perm_model = db.get(Perm, permission) - if perm_model not in org_model.permission_rel: - org_model.permission_rel.append(perm_model) + if perm_model not in org_model.permission_rel: + org_model.permission_rel.append(perm_model) - db.flush() - response = IAMPutOrgPermissionsResponse( - organisation=OrgSummary(**org_model.__dict__), - permissions=org_model.permission_rel, - ) - db.commit() - return response + db.flush() + response = IAMPutOrgPermissionsResponse( + organisation=OrgSummary(**org_model.__dict__), + permissions=org_model.permission_rel, + ) + db.commit() + return response diff --git a/src/iam/schemas.py b/src/iam/schemas.py index 8072914..3a8380c 100644 --- a/src/iam/schemas.py +++ b/src/iam/schemas.py @@ -11,151 +11,151 @@ from typing import Optional, Annotated from pydantic import EmailStr, ConfigDict, Field from src.schemas import ( - CustomBaseModel, - ResourceName, - ServiceIDMixin, - OrgIDMixin, - UserIDMixin, - PermIDMixin, - GroupIDMixin, - GroupSummary, - OrgSummary, - UserSummary, + CustomBaseModel, + ResourceName, + ServiceIDMixin, + OrgIDMixin, + UserIDMixin, + PermIDMixin, + GroupIDMixin, + GroupSummary, + OrgSummary, + UserSummary, ) class UserSchema(CustomBaseModel): - model_config = ConfigDict(from_attributes=True, extra="ignore") + model_config = ConfigDict(from_attributes=True, extra="ignore") - id: int - first_name: str - last_name: str - email: EmailStr + id: int + first_name: str + last_name: str + email: EmailStr class PermissionSchema(CustomBaseModel): - model_config = ConfigDict(from_attributes=True, extra="ignore") + model_config = ConfigDict(from_attributes=True, extra="ignore") - id: int - service_name: str - resource: str - action: str + id: int + service_name: str + resource: str + action: str class GroupDetails(CustomBaseModel): - details: GroupSummary - permissions: list[PermissionSchema] + details: GroupSummary + permissions: list[PermissionSchema] class IAMCAoRRequest(CustomBaseModel): - action: str - rn: ResourceName + action: str + rn: ResourceName class IAMCAoRResponse(CustomBaseModel): - allowed: bool - user: UserSummary - action: str - rn: ResourceName + allowed: bool + user: UserSummary + action: str + rn: ResourceName class IAMGetGroupPermissionsResponse(CustomBaseModel): - organisation: OrgSummary - group: GroupSummary - permissions: list[PermissionSchema] + organisation: OrgSummary + group: GroupSummary + permissions: list[PermissionSchema] class IAMGetGroupUsersResponse(CustomBaseModel): - organisation: OrgSummary - group: GroupSummary - users: list[UserSummary] + organisation: OrgSummary + group: GroupSummary + users: list[UserSummary] class IAMPostGroupRequest(OrgIDMixin): - name: str = Field(min_length=3) + name: str = Field(min_length=3) class IAMPostGroupResponse(CustomBaseModel): - organisation: OrgSummary - group: GroupSummary + organisation: OrgSummary + group: GroupSummary class IAMPutGroupPermissionRequest(GroupIDMixin, PermIDMixin, OrgIDMixin): - pass + pass class IAMPutGroupPermissionResponse(CustomBaseModel): - organisation: OrgSummary - group: GroupSummary - permissions: list[PermissionSchema] + organisation: OrgSummary + group: GroupSummary + permissions: list[PermissionSchema] class IAMPutGroupUserRequest(GroupIDMixin, UserIDMixin, OrgIDMixin): - pass + pass class IAMPutGroupUserResponse(CustomBaseModel): - group: GroupSummary - users: list[UserSchema] + group: GroupSummary + users: list[UserSchema] class IAMDeleteGroupPermissionResponse(CustomBaseModel): - group: GroupSummary - permissions: list[PermissionSchema] + group: GroupSummary + permissions: list[PermissionSchema] class IAMDeleteGroupUserResponse(CustomBaseModel): - group: GroupSummary - users: list[UserSchema] + group: GroupSummary + users: list[UserSchema] class IAMGetPermissionsResponse(CustomBaseModel): - permissions: list[PermissionSchema] + permissions: list[PermissionSchema] class IAMPostPermissionRequest(ServiceIDMixin): - resource: str - action: str + resource: str + action: str class IAMPostPermissionResponse(CustomBaseModel): - permission: PermissionSchema + permission: PermissionSchema class IAMGetPermissionsSearchRequest(OrgIDMixin): - service_id: Annotated[int | None, Field(gt=0)] = None - resource: Optional[str] = None - action: Optional[str] = None + service_id: Annotated[int | None, Field(gt=0)] = None + resource: Optional[str] = None + action: Optional[str] = None class IAMGetPermissionsSearchResponse(CustomBaseModel): - permissions: list[PermissionSchema] + permissions: list[PermissionSchema] class IAMPutGroupInvitationRequest(OrgIDMixin, GroupIDMixin): - user_email: EmailStr + user_email: EmailStr class IAMPutGroupInvitationResponse(CustomBaseModel): - organisation: OrgSummary - group: GroupSummary - invited_email: EmailStr + organisation: OrgSummary + group: GroupSummary + invited_email: EmailStr class IAMPutGroupInvitationAcceptRequest(CustomBaseModel): - jwt: str + jwt: str class IAMPutGroupInvitationAcceptResponse(CustomBaseModel): - organisation: OrgSummary - user: UserSummary - group: GroupDetails + organisation: OrgSummary + user: UserSummary + group: GroupDetails class IAMPutOrgPermissionsRequest(OrgIDMixin): - permissions: list[int] + permissions: list[int] class IAMPutOrgPermissionsResponse(CustomBaseModel): - organisation: OrgSummary - permissions: list[PermissionSchema] + organisation: OrgSummary + permissions: list[PermissionSchema] diff --git a/src/iam/service.py b/src/iam/service.py index 2112c6c..0af4bfd 100644 --- a/src/iam/service.py +++ b/src/iam/service.py @@ -23,90 +23,90 @@ from src.service.schemas import HasServiceName def valid_service_key( - db: DbSession, request: Request, request_model: HasServiceName + db: DbSession, request: Request, request_model: HasServiceName ) -> bool: - rn = request_model.rn - api_key = request.headers.get("X-API-Key", None) - if not api_key: - raise UnauthorizedException("Missing API key") - service = rn.service - result = ( - db.query(Service) - .filter(Service.name == service) - .filter(Service.api_key == api_key) - .first() - ) - if result is None: - raise UnauthorizedException("Invalid API key") + rn = request_model.rn + api_key = request.headers.get("X-API-Key", None) + if not api_key: + raise UnauthorizedException("Missing API key") + service = rn.service + result = ( + db.query(Service) + .filter(Service.name == service) + .filter(Service.api_key == api_key) + .first() + ) + if result is None: + raise UnauthorizedException("Invalid API key") - return True + return True service_key_dependency = Annotated[bool, Depends(valid_service_key)] async def send_user_group_invitation( - user_email: str, org_name: str, org_id: int, group_id: int, group_name: str + user_email: str, org_name: str, org_id: int, group_id: int, group_name: str ): - expiry_delta = timedelta(hours=24) - expiry = datetime.now(timezone.utc) + expiry_delta - claims = { - "email": user_email, - "org_id": org_id, - "group_id": group_id, - "group_name": group_name, - "exp": expiry, - "type": "group_invite", - } + expiry_delta = timedelta(hours=24) + expiry = datetime.now(timezone.utc) + expiry_delta + claims = { + "email": user_email, + "org_id": org_id, + "group_id": group_id, + "group_name": group_name, + "exp": expiry, + "type": "group_invite", + } - token = await generate_jwt(claims) - subject = f"You have been invited to join a group of {org_name}" - body = f"You have been invited to join {group_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}" + token = await generate_jwt(claims) + subject = f"You have been invited to join a group of {org_name}" + body = f"You have been invited to join {group_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}" - await send_email( - recipient=user_email, - subject=subject, - body=body, - ) + await send_email( + recipient=user_email, + subject=subject, + body=body, + ) async def create_group_and_assign_perms( - db: Session, org_model: Org, group_name: str, perm_list: list[int] + db: Session, org_model: Org, group_name: str, perm_list: list[int] ): - new_group = Group(name=group_name, org_id=org_model.id) - db.add(new_group) - db.flush() + new_group = Group(name=group_name, org_id=org_model.id) + db.add(new_group) + db.flush() - for permission in perm_list: - perm_model = db.get(Perm, permission) + for permission in perm_list: + perm_model = db.get(Perm, permission) - if perm_model is None: - continue + if perm_model is None: + continue - new_group.permission_rel.append(perm_model) - db.flush() + new_group.permission_rel.append(perm_model) + db.flush() - return new_group + return new_group async def assign_default_group( - db: DbSession, - org_model: Org, - user_model: User, - group_name: str, - perm_list: list[int], + db: DbSession, + org_model: Org, + user_model: User, + group_name: str, + perm_list: list[int], ): - group_model = ( - db.query(Group) - .filter(Group.org_id == org_model.id) - .filter(Group.name == group_name) - .first() - ) + group_model = ( + db.query(Group) + .filter(Group.org_id == org_model.id) + .filter(Group.name == group_name) + .first() + ) - if group_model is None: - group_model = await create_group_and_assign_perms( - db=db, group_name=group_name, org_model=org_model, perm_list=perm_list - ) + if group_model is None: + group_model = await create_group_and_assign_perms( + db=db, group_name=group_name, org_model=org_model, perm_list=perm_list + ) - user_model.group_rel.append(group_model) - db.flush() + user_model.group_rel.append(group_model) + db.flush() diff --git a/src/main.py b/src/main.py index a96eddc..787cb36 100644 --- a/src/main.py +++ b/src/main.py @@ -1,6 +1,7 @@ """ Application root file: Inits the FastAPI application """ + import os.path from contextlib import asynccontextmanager from typing import AsyncGenerator @@ -19,43 +20,43 @@ from src.auth.service import get_current_user, get_dev_user @asynccontextmanager async def lifespan(_application: FastAPI) -> AsyncGenerator: - # Startup - yield - # Shutdown + # Startup + yield + # Shutdown if settings.ENVIRONMENT.is_deployed: - # Just a precaution, should be False anyway - settings.DISABLE_AUTH = False + # Just a precaution, should be False anyway + settings.DISABLE_AUTH = False tags_metadata = [ - { - "name": "User", - "description": "User related operations, includes getting information about the current user", - }, - { - "name": "Organisation", - "description": "Organisation related operations, includes getting lists of users etc associated with orgs", - }, - { - "name": "Service", - "description": "Services related operations, includes registering services and reissuing API keys", - }, - { - "name": "IAM", - "description": "Operations related to the role based identity and access management system. This includes management of groups, permissions, and related users.", - }, + { + "name": "User", + "description": "User related operations, includes getting information about the current user", + }, + { + "name": "Organisation", + "description": "Organisation related operations, includes getting lists of users etc associated with orgs", + }, + { + "name": "Service", + "description": "Services related operations, includes registering services and reissuing API keys", + }, + { + "name": "IAM", + "description": "Operations related to the role based identity and access management system. This includes management of groups, permissions, and related users.", + }, ] app = FastAPI( - swagger_ui_init_oauth={ - "clientId": auth_settings.CLIENT_ID, - "usePkceWithAuthorizationCodeGrant": True, - "scopes": "openid profile email", - }, - openapi_tags=tags_metadata, + swagger_ui_init_oauth={ + "clientId": auth_settings.CLIENT_ID, + "usePkceWithAuthorizationCodeGrant": True, + "scopes": "openid profile email", + }, + openapi_tags=tags_metadata, ) # Type inspection disabled for middleware injection. @@ -64,19 +65,19 @@ app = FastAPI( app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY.get_secret_value()) # noinspection PyTypeChecker app.add_middleware( - CORSMiddleware, - allow_origins=settings.CORS_ORIGINS, - allow_origin_regex=settings.CORS_ORIGINS_REGEX, - allow_credentials=True, - allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"), - allow_headers=settings.CORS_HEADERS, + CORSMiddleware, + allow_origins=settings.CORS_ORIGINS, + allow_origin_regex=settings.CORS_ORIGINS_REGEX, + allow_credentials=True, + allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"), + allow_headers=settings.CORS_HEADERS, ) if settings.DISABLE_AUTH and (settings.ENVIRONMENT == Environment.LOCAL): - app.dependency_overrides[get_current_user] = get_dev_user + app.dependency_overrides[get_current_user] = get_dev_user app.include_router(api_router) if os.path.exists("/app/static"): - app.frontend("/ui", directory="/app/static", fallback="index.html") + app.frontend("/ui", directory="/app/static", fallback="index.html") diff --git a/src/models.py b/src/models.py index 3f2295d..3dbc891 100644 --- a/src/models.py +++ b/src/models.py @@ -10,28 +10,28 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column class CustomBase(DeclarativeBase): - type_annotation_map = { - datetime: DateTime(timezone=True), - dict[str, Any]: JSON, - } + type_annotation_map = { + datetime: DateTime(timezone=True), + dict[str, Any]: JSON, + } class ActivatedMixin: - active: Mapped[bool] = mapped_column(default=True) + active: Mapped[bool] = mapped_column(default=True) class DeletedTimestampMixin: - deleted_at: Mapped[datetime | None] = mapped_column(nullable=True) + deleted_at: Mapped[datetime | None] = mapped_column(nullable=True) class DescriptionMixin: - description: Mapped[str] + description: Mapped[str] class IdMixin: - id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) class TimestampMixin: - created_at: Mapped[datetime] = mapped_column(default=func.now()) - updated_at: Mapped[datetime] = mapped_column(default=func.now(), onupdate=func.now()) + created_at: Mapped[datetime] = mapped_column(default=func.now()) + updated_at: Mapped[datetime] = mapped_column(default=func.now(), onupdate=func.now()) diff --git a/src/organisation/constants.py b/src/organisation/constants.py index 94bcc2d..dc3d4c2 100644 --- a/src/organisation/constants.py +++ b/src/organisation/constants.py @@ -10,48 +10,48 @@ from enum import StrEnum, auto class Status(StrEnum): - """ - Enumeration of organisation statuses. + """ + Enumeration of organisation statuses. - Attributes: - PARTIAL(str): Organisation has been created but questionnaire hasn't been submitted. - SUBMITTED (str): Questionnaire submitted but not approved. - REMEDIATION (str): Questionnaire submitted but requires revisions. - APPROVED (str): Questionnaire has been approved by an admin. - REJECTED (str): Questionnaire has been rejected by an admin. - REMOVED (str): Organisation has been removed. - """ + Attributes: + PARTIAL(str): Organisation has been created but questionnaire hasn't been submitted. + SUBMITTED (str): Questionnaire submitted but not approved. + REMEDIATION (str): Questionnaire submitted but requires revisions. + APPROVED (str): Questionnaire has been approved by an admin. + REJECTED (str): Questionnaire has been rejected by an admin. + REMOVED (str): Organisation has been removed. + """ - PARTIAL = auto() - SUBMITTED = auto() - REMEDIATION = auto() - APPROVED = auto() - REJECTED = auto() - REMOVED = auto() + PARTIAL = auto() + SUBMITTED = auto() + REMEDIATION = auto() + APPROVED = auto() + REJECTED = auto() + REMOVED = auto() - @property - def is_pre_approval(self): - return self in (self.PARTIAL, self.SUBMITTED, self.REMEDIATION) + @property + def is_pre_approval(self): + return self in (self.PARTIAL, self.SUBMITTED, self.REMEDIATION) - @property - def is_pre_submission(self): - return self in (self.PARTIAL, self.REMEDIATION) + @property + def is_pre_submission(self): + return self in (self.PARTIAL, self.REMEDIATION) - @property - def is_blocked(self): - return self in (self.REMOVED, self.REJECTED) + @property + def is_blocked(self): + return self in (self.REMOVED, self.REJECTED) class ContactType(StrEnum): - """ - Enumeration of organisation contact types. + """ + Enumeration of organisation contact types. - Attributes: - BILLING(str): Billing contact. - SECURITY (str): Security contact. - OWNER (str): Owner contact. - """ + Attributes: + BILLING(str): Billing contact. + SECURITY (str): Security contact. + OWNER (str): Owner contact. + """ - BILLING = auto() - SECURITY = auto() - OWNER = auto() + BILLING = auto() + SECURITY = auto() + OWNER = auto() diff --git a/src/organisation/dependencies.py b/src/organisation/dependencies.py index 2686560..e9f6f6b 100644 --- a/src/organisation/dependencies.py +++ b/src/organisation/dependencies.py @@ -18,25 +18,25 @@ from src.organisation.exceptions import OrgNotFoundException def get_org_model_query(db: DbSession, org_id: Annotated[int, Query(gt=0)]) -> Org: - org_model = db.get(Org, org_id) - if org_model is None: - raise OrgNotFoundException(org_id) - return org_model + org_model = db.get(Org, org_id) + if org_model is None: + raise OrgNotFoundException(org_id) + return org_model org_model_query_dependency = Annotated[Org, Depends(get_org_model_query)] def get_org_model_body(db: DbSession, request_model: OrgIDMixin) -> Org: - org_id: Optional[int] = getattr(request_model, "organisation_id", None) - if org_id is None: - raise OrgNotFoundException() + org_id: Optional[int] = getattr(request_model, "organisation_id", None) + if org_id is None: + raise OrgNotFoundException() - org_model = db.get(Org, org_id) - if org_model is None: - raise OrgNotFoundException(org_id) + org_model = db.get(Org, org_id) + if org_model is None: + raise OrgNotFoundException(org_id) - return org_model + return org_model org_model_body_dependency = Annotated[Org, Depends(get_org_model_body)] diff --git a/src/organisation/exceptions.py b/src/organisation/exceptions.py index a56b395..930aaf5 100644 --- a/src/organisation/exceptions.py +++ b/src/organisation/exceptions.py @@ -12,26 +12,26 @@ from fastapi import HTTPException, status class OrgNotFoundException(HTTPException): - def __init__(self, org_id: Optional[int] = None) -> None: - detail = ( - "Organisation not found" - if org_id is None - else f"Organisation with ID '{org_id}' was not found." - ) - super().__init__( - status_code=status.HTTP_404_NOT_FOUND, - detail=detail, - ) + def __init__(self, org_id: Optional[int] = None) -> None: + detail = ( + "Organisation not found" + if org_id is None + else f"Organisation with ID '{org_id}' was not found." + ) + super().__init__( + status_code=status.HTTP_404_NOT_FOUND, + detail=detail, + ) class AwaitingApprovalException(HTTPException): - def __init__(self, org_id: Optional[int] = None) -> None: - detail = ( - "Organisation has not been approved." - if org_id is None - else f"Organisation with ID '{org_id}' has not been approved." - ) - super().__init__( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=detail, - ) + def __init__(self, org_id: Optional[int] = None) -> None: + detail = ( + "Organisation has not been approved." + if org_id is None + else f"Organisation with ID '{org_id}' has not been approved." + ) + super().__init__( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=detail, + ) diff --git a/src/organisation/models.py b/src/organisation/models.py index 97d1247..e28c7cc 100644 --- a/src/organisation/models.py +++ b/src/organisation/models.py @@ -25,51 +25,51 @@ from src.models import CustomBase, TimestampMixin class Organisation(CustomBase, IdMixin, TimestampMixin, DeletedTimestampMixin): - __tablename__ = "organisation" + __tablename__ = "organisation" - name: Mapped[str] - status: Mapped[str] = mapped_column(default="partial") - intake_questionnaire: Mapped[dict[str, Any] | None] + name: Mapped[str] + status: Mapped[str] = mapped_column(default="partial") + intake_questionnaire: Mapped[dict[str, Any] | None] - root_user_id: Mapped[int] = mapped_column(ForeignKey("user.id")) - billing_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True) - security_contact_id: Mapped[int] = mapped_column( - ForeignKey("contact.id"), nullable=True - ) - owner_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True) + root_user_id: Mapped[int] = mapped_column(ForeignKey("user.id")) + billing_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True) + security_contact_id: Mapped[int] = mapped_column( + ForeignKey("contact.id"), nullable=True + ) + owner_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True) - user_rel = relationship("User", secondary="orgusers", back_populates="organisation_rel") + user_rel = relationship("User", secondary="orgusers", back_populates="organisation_rel") - group_rel = relationship( - "Group", back_populates="org_rel", cascade="all, delete-orphan" - ) - root_user_rel = relationship("User", foreign_keys="Organisation.root_user_id") + group_rel = relationship( + "Group", back_populates="org_rel", cascade="all, delete-orphan" + ) + root_user_rel = relationship("User", foreign_keys="Organisation.root_user_id") - billing_contact_rel = relationship( - "Contact", foreign_keys="Organisation.billing_contact_id" - ) - security_contact_rel = relationship( - "Contact", foreign_keys="Organisation.security_contact_id" - ) - owner_contact_rel = relationship( - "Contact", foreign_keys="Organisation.owner_contact_id" - ) + billing_contact_rel = relationship( + "Contact", foreign_keys="Organisation.billing_contact_id" + ) + security_contact_rel = relationship( + "Contact", foreign_keys="Organisation.security_contact_id" + ) + owner_contact_rel = relationship( + "Contact", foreign_keys="Organisation.owner_contact_id" + ) - permission_rel = relationship( - "Permission", secondary="org_permissions", back_populates="org_rel" - ) + permission_rel = relationship( + "Permission", secondary="org_permissions", back_populates="org_rel" + ) - @property - def root_user_email(self) -> str: - return self.root_user_rel.email if self.root_user_rel else "" + @property + def root_user_email(self) -> str: + return self.root_user_rel.email if self.root_user_rel else "" class OrgUsers(CustomBase): - __tablename__ = "orgusers" + __tablename__ = "orgusers" - org_id: Mapped[int] = mapped_column( - ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True - ) - user_id: Mapped[int] = mapped_column( - ForeignKey("user.id", ondelete="CASCADE"), primary_key=True - ) + org_id: Mapped[int] = mapped_column( + ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True + ) + user_id: Mapped[int] = mapped_column( + ForeignKey("user.id", ondelete="CASCADE"), primary_key=True + ) diff --git a/src/organisation/router.py b/src/organisation/router.py index 78159ca..3c06b91 100644 --- a/src/organisation/router.py +++ b/src/organisation/router.py @@ -26,9 +26,9 @@ from fastapi.params import Query from src.contact.schemas import ContactModel from src.exceptions import ( - UnprocessableContentException, - ConflictException, - ForbiddenException, + UnprocessableContentException, + ConflictException, + ForbiddenException, ) from src.contact.models import Contact from src.contact.schemas import ContactAddress @@ -37,614 +37,612 @@ from src.database import DbSession from src.organisation.schemas_questionnaires import QuestionnaireQuestionsVersion0 from src.organisation.service import assign_defaults from src.user.dependencies import ( - user_model_body_dependency, - user_model_claims_dependency, - user_model_query_dependency, + user_model_body_dependency, + user_model_claims_dependency, + user_model_query_dependency, ) from src.auth.dependencies import ( - super_admin_dependency, - org_model_root_claim_query_dependency, - org_model_root_claim_body_dependency, + super_admin_dependency, + org_model_root_claim_query_dependency, + org_model_root_claim_body_dependency, ) from src.iam.models import Group from src.organisation.dependencies import ( - org_model_body_dependency, - org_model_query_dependency, + org_model_body_dependency, + org_model_query_dependency, ) from src.organisation.constants import ContactType, Status as StatusEnum from src.organisation.models import Organisation as Org from src.organisation.schemas import ( - OrgPostOrgRequest, - OrgPatchQuestionnaireRequest, - OrgPatchStatusRequest, - OrgPatchContactRequest, - OrgPostUserRequest, - OrgGetUserResponse, - OrgGetContactResponse, - OrgGetOrgResponse, - OrgPatchRootRequest, - OrgGetGroupResponse, - OrgPostOrgResponse, - OrgPatchQuestionnaireResponse, - OrgPatchStatusResponse, - OrgPostUserResponse, - OrgPatchRootResponse, - Questionnaire, - OrgPatchContactResponse, - QuestionnaireMetadata, + OrgPostOrgRequest, + OrgPatchQuestionnaireRequest, + OrgPatchStatusRequest, + OrgPatchContactRequest, + OrgPostUserRequest, + OrgGetUserResponse, + OrgGetContactResponse, + OrgGetOrgResponse, + OrgPatchRootRequest, + OrgGetGroupResponse, + OrgPostOrgResponse, + OrgPatchQuestionnaireResponse, + OrgPatchStatusResponse, + OrgPostUserResponse, + OrgPatchRootResponse, + Questionnaire, + OrgPatchContactResponse, + QuestionnaireMetadata, ) router = APIRouter( - prefix="/org", - tags=["Organisation"], + prefix="/org", + tags=["Organisation"], ) @router.get( - "", - summary="Get org details by ID.", - response_model=OrgGetOrgResponse, - status_code=status.HTTP_200_OK, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, - status.HTTP_404_NOT_FOUND: {"description": "Organisation not found"}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Missing or invalid org_id query parameter" - }, - status.HTTP_403_FORBIDDEN: { - "description": "Not authorised. Must be org root user." - }, - }, + "", + summary="Get org details by ID.", + response_model=OrgGetOrgResponse, + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, + status.HTTP_404_NOT_FOUND: {"description": "Organisation not found"}, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Missing or invalid org_id query parameter" + }, + status.HTTP_403_FORBIDDEN: { + "description": "Not authorised. Must be org root user." + }, + }, ) -async def get_org_by_id( - db: DbSession, org_model: org_model_root_claim_query_dependency -): - """ - Returns organisation details including key member email addresses - """ - response = { - "organisation_id": org_model.id, - "name": org_model.name, - "status": org_model.status, - "intake_questionnaire": org_model.intake_questionnaire, - "root_user_email": org_model.root_user_email, - "billing_contact": { - "id": org_model.billing_contact_id, - "email": org_model.billing_contact_rel.email, - }, - "owner_contact": { - "id": org_model.owner_contact_id, - "email": org_model.owner_contact_rel.email, - }, - "security_contact": { - "id": org_model.security_contact_id, - "email": org_model.security_contact_rel.email, - }, - } +async def get_org_by_id(db: DbSession, org_model: org_model_root_claim_query_dependency): + """ + Returns organisation details including key member email addresses + """ + response = { + "organisation_id": org_model.id, + "name": org_model.name, + "status": org_model.status, + "intake_questionnaire": org_model.intake_questionnaire, + "root_user_email": org_model.root_user_email, + "billing_contact": { + "id": org_model.billing_contact_id, + "email": org_model.billing_contact_rel.email, + }, + "owner_contact": { + "id": org_model.owner_contact_id, + "email": org_model.owner_contact_rel.email, + }, + "security_contact": { + "id": org_model.security_contact_id, + "email": org_model.security_contact_rel.email, + }, + } - return {"organisations": [response]} + return {"organisations": [response]} @router.post( - "", - summary="Create new organisation.", - status_code=status.HTTP_201_CREATED, - response_model=OrgPostOrgResponse, - responses={ - status.HTTP_201_CREATED: {"description": "Successfully created organisation."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - status.HTTP_401_UNAUTHORIZED: { - "description": "User must be logged in with OIDC to create organisation." - }, - status.HTTP_409_CONFLICT: { - "description": "Organisation with this name already exists." - }, - }, + "", + summary="Create new organisation.", + status_code=status.HTTP_201_CREATED, + response_model=OrgPostOrgResponse, + responses={ + status.HTTP_201_CREATED: {"description": "Successfully created organisation."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_401_UNAUTHORIZED: { + "description": "User must be logged in with OIDC to create organisation." + }, + status.HTTP_409_CONFLICT: { + "description": "Organisation with this name already exists." + }, + }, ) async def create_org( - db: DbSession, - user_model: user_model_claims_dependency, - request_model: OrgPostOrgRequest, - background_tasks: BackgroundTasks, + db: DbSession, + user_model: user_model_claims_dependency, + request_model: OrgPostOrgRequest, + background_tasks: BackgroundTasks, ): - """ - Creates a new organisation with optional questionnaire (to be completed or submitted). - ALl organisations are given the "partial" status on creation. See update_questionnaire() for more details. - """ - if request_model.intake_questionnaire: - questionnaire_questions = request_model.intake_questionnaire.model_dump() - else: - questionnaire_questions = QuestionnaireQuestionsVersion0().model_dump() + """ + Creates a new organisation with optional questionnaire (to be completed or submitted). + ALl organisations are given the "partial" status on creation. See update_questionnaire() for more details. + """ + if request_model.intake_questionnaire: + questionnaire_questions = request_model.intake_questionnaire.model_dump() + else: + questionnaire_questions = QuestionnaireQuestionsVersion0().model_dump() - questionnaire_metadata = QuestionnaireMetadata(version=0, submission_date=None) + questionnaire_metadata = QuestionnaireMetadata(version=0, submission_date=None) - intake_questionnaire = Questionnaire( - metadata=questionnaire_metadata, - questions=questionnaire_questions, - ) + intake_questionnaire = Questionnaire( + metadata=questionnaire_metadata, + questions=questionnaire_questions, + ) - org_model = Org( - name=request_model.name, - intake_questionnaire=intake_questionnaire.model_dump(mode="json"), - root_user_id=user_model.id, - ) + org_model = Org( + name=request_model.name, + intake_questionnaire=intake_questionnaire.model_dump(mode="json"), + root_user_id=user_model.id, + ) - org_model.status = "partial" + org_model.status = "partial" - db.add(org_model) - try: - db.flush() - except IntegrityError as e: - if ( - isinstance(e.orig, UniqueViolation) # Postgres unique violation - or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation - ): - raise ConflictException(message="Organisation with this name already exists") - raise - # Adds currently logged-in user to org users list and sets them as root_user - org_model.user_rel.append(user_model) + db.add(org_model) + try: + db.flush() + except IntegrityError as e: + if ( + isinstance(e.orig, UniqueViolation) # Postgres unique violation + or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation + ): + raise ConflictException(message="Organisation with this name already exists") + raise + # Adds currently logged-in user to org users list and sets them as root_user + org_model.user_rel.append(user_model) - background_tasks.add_task( - assign_defaults, db, org_id=org_model.id, user_id=user_model.id - ) + background_tasks.add_task( + assign_defaults, db, org_id=org_model.id, user_id=user_model.id + ) - for contact_type in [ - "billing_contact_id", - "security_contact_id", - "owner_contact_id", - ]: - contact_model = Contact(org_id=org_model.id) - db.add(contact_model) - db.flush() - org_model.__setattr__(contact_type, contact_model.id) - response = OrgPostOrgResponse(**org_model.__dict__) - db.commit() - return response + for contact_type in [ + "billing_contact_id", + "security_contact_id", + "owner_contact_id", + ]: + contact_model = Contact(org_id=org_model.id) + db.add(contact_model) + db.flush() + org_model.__setattr__(contact_type, contact_model.id) + response = OrgPostOrgResponse(**org_model.__dict__) + db.commit() + return response @router.patch( - "/questionnaire", - summary="Update questionnaire.", - status_code=status.HTTP_200_OK, - response_model=OrgPatchQuestionnaireResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully updated questionnaire."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - status.HTTP_403_FORBIDDEN: { - "description": "Not authorised. Must be org root user." - }, - }, + "/questionnaire", + summary="Update questionnaire.", + status_code=status.HTTP_200_OK, + response_model=OrgPatchQuestionnaireResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully updated questionnaire."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_403_FORBIDDEN: { + "description": "Not authorised. Must be org root user." + }, + }, ) async def update_questionnaire( - db: DbSession, - org_model: org_model_root_claim_body_dependency, - request_model: OrgPatchQuestionnaireRequest, + db: DbSession, + org_model: org_model_root_claim_body_dependency, + request_model: OrgPatchQuestionnaireRequest, ): - """ - Route for updating questionnaire. - The partial bool allows for submission of partially completed questionnaire and/or - final "are you sure" check before setting the org to be in "submitted" status, awaiting admin approval. - """ - org_status = StatusEnum(org_model.status) - if not org_status.is_pre_submission: - raise ForbiddenException("Questionnaire may only be modified prior to submission.") - update_data: dict = request_model.intake_questionnaire.model_dump(exclude_none=True) - questionnaire = org_model.intake_questionnaire - if questionnaire is None: - questionnaire_questions = QuestionnaireQuestionsVersion0().model_dump() + """ + Route for updating questionnaire. + The partial bool allows for submission of partially completed questionnaire and/or + final "are you sure" check before setting the org to be in "submitted" status, awaiting admin approval. + """ + org_status = StatusEnum(org_model.status) + if not org_status.is_pre_submission: + raise ForbiddenException("Questionnaire may only be modified prior to submission.") + update_data: dict = request_model.intake_questionnaire.model_dump(exclude_none=True) + questionnaire = org_model.intake_questionnaire + if questionnaire is None: + questionnaire_questions = QuestionnaireQuestionsVersion0().model_dump() - questionnaire_metadata = QuestionnaireMetadata(version=0, submission_date=None) + questionnaire_metadata = QuestionnaireMetadata(version=0, submission_date=None) - questionnaire = Questionnaire( - metadata=questionnaire_metadata, - questions=questionnaire_questions, - ).model_dump() + questionnaire = Questionnaire( + metadata=questionnaire_metadata, + questions=questionnaire_questions, + ).model_dump() - questions_model = QuestionnaireQuestionsVersion0(**questionnaire["questions"]) - else: - questions_model = QuestionnaireQuestionsVersion0(**questionnaire["questions"]) - for key, value in update_data.items(): - if hasattr(questions_model, key): - setattr(questions_model, key, value) - else: - raise UnprocessableContentException("Invalid keys in update request") + questions_model = QuestionnaireQuestionsVersion0(**questionnaire["questions"]) + else: + questions_model = QuestionnaireQuestionsVersion0(**questionnaire["questions"]) + for key, value in update_data.items(): + if hasattr(questions_model, key): + setattr(questions_model, key, value) + else: + raise UnprocessableContentException("Invalid keys in update request") - metadata = QuestionnaireMetadata(version=questionnaire["metadata"]["version"]) + metadata = QuestionnaireMetadata(version=questionnaire["metadata"]["version"]) - # Allows for partially completed questionnaires to be saved without being submitted for review - if not request_model.partial: - org_model.status = "submitted" - metadata.submission_date = datetime.now(timezone.utc) + # Allows for partially completed questionnaires to be saved without being submitted for review + if not request_model.partial: + org_model.status = "submitted" + metadata.submission_date = datetime.now(timezone.utc) - questionnaire_model = Questionnaire( - metadata=metadata, - questions=questions_model, - ) + questionnaire_model = Questionnaire( + metadata=metadata, + questions=questions_model, + ) - org_model.intake_questionnaire = questionnaire_model.model_dump(mode="json") - db.flush() - response = OrgPatchQuestionnaireResponse(**org_model.__dict__) - db.commit() - return response + org_model.intake_questionnaire = questionnaire_model.model_dump(mode="json") + db.flush() + response = OrgPatchQuestionnaireResponse(**org_model.__dict__) + db.commit() + return response @router.patch( - "/status", - summary="Update status of organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgPatchStatusResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully updated organisation status."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - status.HTTP_403_FORBIDDEN: {"description": "Not authorised. Must be super admin."}, - }, + "/status", + summary="Update status of organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgPatchStatusResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully updated organisation status."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_403_FORBIDDEN: {"description": "Not authorised. Must be super admin."}, + }, ) async def update_status( - db: DbSession, - org_model: org_model_body_dependency, - su: super_admin_dependency, - request_model: OrgPatchStatusRequest, + db: DbSession, + org_model: org_model_body_dependency, + su: super_admin_dependency, + request_model: OrgPatchStatusRequest, ): - """ - Sets an organisation's status. This is the endpoint for approving or denying an organisation after reviewing the questionnaire. - """ - org_model.status = request_model.status - db.flush() - response = OrgPatchStatusResponse(**org_model.__dict__) - db.commit() - return response + """ + Sets an organisation's status. This is the endpoint for approving or denying an organisation after reviewing the questionnaire. + """ + org_model.status = request_model.status + db.flush() + response = OrgPatchStatusResponse(**org_model.__dict__) + db.commit() + return response @router.get( - "/users", - summary="Get email addresses of users of the organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgGetUserResponse, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval of users."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Org ID missing or invalid." - }, - status.HTTP_403_FORBIDDEN: { - "description": "Not authorised. Must be org root user." - }, - }, + "/users", + summary="Get email addresses of users of the organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgGetUserResponse, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval of users."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Org ID missing or invalid." + }, + status.HTTP_403_FORBIDDEN: { + "description": "Not authorised. Must be org root user." + }, + }, ) async def get_users(org_model: org_model_root_claim_query_dependency): - """ - Returns a list of the email addresses of all users of the organisation. - """ - return { - "users": [{"email": user.email, "id": user.id} for user in org_model.user_rel], - "organisation": org_model, - } + """ + Returns a list of the email addresses of all users of the organisation. + """ + return { + "users": [{"email": user.email, "id": user.id} for user in org_model.user_rel], + "organisation": org_model, + } @router.post( - "/user", - summary="Add user to the organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgPostUserResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully added user to the organisation."}, - status.HTTP_403_FORBIDDEN: { - "description": "Not authorised. Must be org root user." - }, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - status.HTTP_409_CONFLICT: { - "description": "User is already a member of the organisation." - }, - }, + "/user", + summary="Add user to the organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgPostUserResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully added user to the organisation."}, + status.HTTP_403_FORBIDDEN: { + "description": "Not authorised. Must be org root user." + }, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_409_CONFLICT: { + "description": "User is already a member of the organisation." + }, + }, ) async def add_user_to_org( - db: DbSession, - org_model: org_model_body_dependency, - user_model: user_model_body_dependency, - su: super_admin_dependency, - request_model: OrgPostUserRequest, + db: DbSession, + org_model: org_model_body_dependency, + user_model: user_model_body_dependency, + su: super_admin_dependency, + request_model: OrgPostUserRequest, ): - """ - Adds a user to the organisation. - """ - if user_model in org_model.user_rel: - raise ConflictException(message="User already a part of this organisation") - org_model.user_rel.append(user_model) - db.flush() - group_model = ( - db.query(Group) - .filter(Group.org_id == org_model.id) - .filter(Group.name == "Default Users") - .first() - ) - if group_model is not None: - user_model.group_rel.append(group_model) - response = { - "organisation": org_model, - "users": [{"id": user.id, "email": user.email} for user in org_model.user_rel], - } - db.commit() - return response + """ + Adds a user to the organisation. + """ + if user_model in org_model.user_rel: + raise ConflictException(message="User already a part of this organisation") + org_model.user_rel.append(user_model) + db.flush() + group_model = ( + db.query(Group) + .filter(Group.org_id == org_model.id) + .filter(Group.name == "Default Users") + .first() + ) + if group_model is not None: + user_model.group_rel.append(group_model) + response = { + "organisation": org_model, + "users": [{"id": user.id, "email": user.email} for user in org_model.user_rel], + } + db.commit() + return response @router.delete( - "", - summary="Delete organisation from the hub.", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - status.HTTP_204_NO_CONTENT: {"description": "Successfully deleted organisation."}, - status.HTTP_403_FORBIDDEN: {"description": "Not authorised. Must be super admin."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Org ID missing or invalid." - }, - }, + "", + summary="Delete organisation from the hub.", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + status.HTTP_204_NO_CONTENT: {"description": "Successfully deleted organisation."}, + status.HTTP_403_FORBIDDEN: {"description": "Not authorised. Must be super admin."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Org ID missing or invalid." + }, + }, ) async def delete_organisation_by_id( - db: DbSession, - org_model: org_model_query_dependency, - su: super_admin_dependency, + db: DbSession, + org_model: org_model_query_dependency, + su: super_admin_dependency, ): - """ - Removes an organisation from the hub. - """ - org_model.status = "removed" - org_model.deleted_at = datetime.now(tz=timezone.utc) - db.commit() + """ + Removes an organisation from the hub. + """ + org_model.status = "removed" + org_model.deleted_at = datetime.now(tz=timezone.utc) + db.commit() @router.delete( - "/self", - summary="Delete organisation from the hub as root user before it has been approved.", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - status.HTTP_204_NO_CONTENT: {"description": "Successfully deleted organisation."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Unprocessable content.", - "content": { - "application/json": { - "examples": { - "org_id": {"summary": "Invalid or missing org ID."}, - "oidc_claims": {"summary": "Invalid or missing OIDC claims."}, - } - } - }, - }, - status.HTTP_401_UNAUTHORIZED: { - "description": "Unauthorized", - "content": { - "application/json": { - "examples": { - "awaiting_approval": { - "summary": "Organisation has not yet been approved." - }, - "expired_token": {"summary": "User token has expired."}, - "oidc": {"summary": "Failed to verify OIDC claims."}, - } - } - }, - }, - status.HTTP_403_FORBIDDEN: { - "description": "Forbidden", - "content": { - "application/json": { - "examples": { - "invalid_state": { - "summary": "Organisation is no longer in pre-approval state." - }, - "not_root": {"summary": "Not authorised. Must be root user."}, - } - } - }, - }, - status.HTTP_404_NOT_FOUND: { - "description": "Not found", - "content": { - "application/json": { - "examples": { - "db_id": {"summary": "User not found in db when checking claims."}, - "user_model": {"summary": "User model not found in db."}, - "org_model": {"summary": "Org model not found in db."}, - } - } - }, - }, - }, + "/self", + summary="Delete organisation from the hub as root user before it has been approved.", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + status.HTTP_204_NO_CONTENT: {"description": "Successfully deleted organisation."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Unprocessable content.", + "content": { + "application/json": { + "examples": { + "org_id": {"summary": "Invalid or missing org ID."}, + "oidc_claims": {"summary": "Invalid or missing OIDC claims."}, + } + } + }, + }, + status.HTTP_401_UNAUTHORIZED: { + "description": "Unauthorized", + "content": { + "application/json": { + "examples": { + "awaiting_approval": { + "summary": "Organisation has not yet been approved." + }, + "expired_token": {"summary": "User token has expired."}, + "oidc": {"summary": "Failed to verify OIDC claims."}, + } + } + }, + }, + status.HTTP_403_FORBIDDEN: { + "description": "Forbidden", + "content": { + "application/json": { + "examples": { + "invalid_state": { + "summary": "Organisation is no longer in pre-approval state." + }, + "not_root": {"summary": "Not authorised. Must be root user."}, + } + } + }, + }, + status.HTTP_404_NOT_FOUND: { + "description": "Not found", + "content": { + "application/json": { + "examples": { + "db_id": {"summary": "User not found in db when checking claims."}, + "user_model": {"summary": "User model not found in db."}, + "org_model": {"summary": "Org model not found in db."}, + } + } + }, + }, + }, ) async def delete_preapproved_organisation_by_id( - db: DbSession, - org_model: org_model_root_claim_query_dependency, + db: DbSession, + org_model: org_model_root_claim_query_dependency, ): - """ - Removes an organisation from the hub before it has been approved, if user is root. - """ - org_status = StatusEnum(org_model.status) - if not org_status.is_pre_approval: - raise ForbiddenException(message="Organisation is no longer in pre-approval state.") + """ + Removes an organisation from the hub before it has been approved, if user is root. + """ + org_status = StatusEnum(org_model.status) + if not org_status.is_pre_approval: + raise ForbiddenException(message="Organisation is no longer in pre-approval state.") - db.delete(org_model) - db.commit() + db.delete(org_model) + db.commit() @router.patch( - "/root_user", - summary="Update the root user of the organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgPatchRootResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully updated root user."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - status.HTTP_401_UNAUTHORIZED: { - "description": "Not authorised. Must be super admin." - }, - }, + "/root_user", + summary="Update the root user of the organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgPatchRootResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully updated root user."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_401_UNAUTHORIZED: { + "description": "Not authorised. Must be super admin." + }, + }, ) async def update_root_user( - db: DbSession, - org_model: org_model_body_dependency, - user_model: user_model_body_dependency, - su: super_admin_dependency, - request_model: OrgPatchRootRequest, + db: DbSession, + org_model: org_model_body_dependency, + user_model: user_model_body_dependency, + su: super_admin_dependency, + request_model: OrgPatchRootRequest, ): - """ - Promotes an existing organisation user to the root user, giving them full control of the org. - """ - if user_model not in org_model.user_rel: - raise UnprocessableContentException( - message="This user does not belong to your organisation." - ) - org_model.root_user_rel = user_model - db.flush() - response = OrgPatchRootResponse( - name=org_model.name, root_user_email=org_model.root_user_email - ) - db.commit() - return response + """ + Promotes an existing organisation user to the root user, giving them full control of the org. + """ + if user_model not in org_model.user_rel: + raise UnprocessableContentException( + message="This user does not belong to your organisation." + ) + org_model.root_user_rel = user_model + db.flush() + response = OrgPatchRootResponse( + name=org_model.name, root_user_email=org_model.root_user_email + ) + db.commit() + return response @router.get( - "/groups", - summary="Get all organisation IAM groups.", - status_code=status.HTTP_200_OK, - response_model=OrgGetGroupResponse, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval of IAM groups."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Org ID missing or invalid." - }, - status.HTTP_403_FORBIDDEN: { - "description": "Not authorised. Must be org root user." - }, - }, + "/groups", + summary="Get all organisation IAM groups.", + status_code=status.HTTP_200_OK, + response_model=OrgGetGroupResponse, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval of IAM groups."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Org ID missing or invalid." + }, + status.HTTP_403_FORBIDDEN: { + "description": "Not authorised. Must be org root user." + }, + }, ) async def get_org_groups(org_model: org_model_root_claim_query_dependency): - """ - Returns a list of the names of all IAM groups created by the organisation. - """ - return { - "organisation": org_model, - "groups": [{"id": group.id, "name": group.name} for group in org_model.group_rel], - } + """ + Returns a list of the names of all IAM groups created by the organisation. + """ + return { + "organisation": org_model, + "groups": [{"id": group.id, "name": group.name} for group in org_model.group_rel], + } @router.delete( - "/user", - summary="Remove user from organisation.", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - status.HTTP_204_NO_CONTENT: {"description": "Successfully removed user."}, - status.HTTP_403_FORBIDDEN: { - "description": "Not authorised. Must be org root user." - }, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - }, + "/user", + summary="Remove user from organisation.", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + status.HTTP_204_NO_CONTENT: {"description": "Successfully removed user."}, + status.HTTP_403_FORBIDDEN: { + "description": "Not authorised. Must be org root user." + }, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + }, ) async def remove_user_from_org( - db: DbSession, - org_model: org_model_root_claim_query_dependency, - user_model: user_model_query_dependency, + db: DbSession, + org_model: org_model_root_claim_query_dependency, + user_model: user_model_query_dependency, ): - """ - Revokes a user's membership in an organisation. - """ - if user_model not in org_model.user_rel: - return + """ + Revokes a user's membership in an organisation. + """ + if user_model not in org_model.user_rel: + return - org_model.user_rel.remove(user_model) - db.commit() + org_model.user_rel.remove(user_model) + db.commit() @router.get( - "/contact", - summary="Get contact for organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgGetContactResponse, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval of contact."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - status.HTTP_403_FORBIDDEN: { - "description": "Not authorised. Must be org root user." - }, - }, + "/contact", + summary="Get contact for organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgGetContactResponse, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval of contact."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_403_FORBIDDEN: { + "description": "Not authorised. Must be org root user." + }, + }, ) async def get_contact( - org_model: org_model_root_claim_query_dependency, - contact_type: Annotated[ - ContactType, Query(description="Must be billing|security|owner") - ], + org_model: org_model_root_claim_query_dependency, + contact_type: Annotated[ + ContactType, Query(description="Must be billing|security|owner") + ], ): - """ - Gets full details for a contact point at an organisation. - """ - match contact_type: - case "billing": - contact_model = org_model.billing_contact_rel - case "security": - contact_model = org_model.security_contact_rel - case "owner": - contact_model = org_model.owner_contact_rel - case _: - raise UnprocessableContentException("Invalid contact type") + """ + Gets full details for a contact point at an organisation. + """ + match contact_type: + case "billing": + contact_model = org_model.billing_contact_rel + case "security": + contact_model = org_model.security_contact_rel + case "owner": + contact_model = org_model.owner_contact_rel + case _: + raise UnprocessableContentException("Invalid contact type") - if contact_model is None: - raise ContactNotFoundException() + if contact_model is None: + raise ContactNotFoundException() - address = ContactAddress.model_validate(contact_model) - contact_response = ContactModel.model_construct( - **contact_model.__dict__, address=address - ) + address = ContactAddress.model_validate(contact_model) + contact_response = ContactModel.model_construct( + **contact_model.__dict__, address=address + ) - return {"contact": contact_response, "organisation": org_model} + return {"contact": contact_response, "organisation": org_model} @router.patch( - "/contact", - summary="Update contact for organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgPatchContactResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully updated contact."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - status.HTTP_403_FORBIDDEN: { - "description": "Not authorised. Must be org root user." - }, - }, + "/contact", + summary="Update contact for organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgPatchContactResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully updated contact."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_403_FORBIDDEN: { + "description": "Not authorised. Must be org root user." + }, + }, ) async def update_contact( - db: DbSession, - org_model: org_model_root_claim_body_dependency, - request_model: OrgPatchContactRequest, + db: DbSession, + org_model: org_model_root_claim_body_dependency, + request_model: OrgPatchContactRequest, ): - """ - Updates details for a contact point at an organisation. - """ - match request_model.contact_type: - case "billing": - contact_model = org_model.billing_contact_rel - case "security": - contact_model = org_model.security_contact_rel - case "owner": - contact_model = org_model.owner_contact_rel - case _: - raise UnprocessableContentException("Invalid contact type") + """ + Updates details for a contact point at an organisation. + """ + match request_model.contact_type: + case "billing": + contact_model = org_model.billing_contact_rel + case "security": + contact_model = org_model.security_contact_rel + case "owner": + contact_model = org_model.owner_contact_rel + case _: + raise UnprocessableContentException("Invalid contact type") - if contact_model is None: - raise ContactNotFoundException() + if contact_model is None: + raise ContactNotFoundException() - update_data = request_model.model_dump(exclude_none=True) - for key, value in update_data.items(): - if hasattr(contact_model, key): - setattr(contact_model, key, value) - else: - if key == "contact_type" or key == "organisation_id": - continue - raise UnprocessableContentException("Invalid keys in update request") - db.flush() + update_data = request_model.model_dump(exclude_none=True) + for key, value in update_data.items(): + if hasattr(contact_model, key): + setattr(contact_model, key, value) + else: + if key == "contact_type" or key == "organisation_id": + continue + raise UnprocessableContentException("Invalid keys in update request") + db.flush() - address = ContactAddress.model_validate(contact_model) - contact_response = ContactModel.model_construct( - **contact_model.__dict__, address=address - ) + address = ContactAddress.model_validate(contact_model) + contact_response = ContactModel.model_construct( + **contact_model.__dict__, address=address + ) - db.commit() + db.commit() - return {"contact": contact_response, "organisation": org_model} + return {"contact": contact_response, "organisation": org_model} diff --git a/src/organisation/schemas.py b/src/organisation/schemas.py index 522158e..d4bf592 100644 --- a/src/organisation/schemas.py +++ b/src/organisation/schemas.py @@ -12,139 +12,139 @@ from datetime import datetime from pydantic import EmailStr, ConfigDict, Field from src.schemas import ( - CustomBaseModel, - OrgIDMixin, - UserIDMixin, - GroupSummary, - OrgSummary, - UserSummary, + CustomBaseModel, + OrgIDMixin, + UserIDMixin, + GroupSummary, + OrgSummary, + UserSummary, ) from src.contact.schemas import ContactModel from src.organisation.constants import Status, ContactType from src.organisation.schemas_questionnaires import ( - QuestionnaireQuestionsVersion0 as CurrentQuestions, - questionnaire_union, + QuestionnaireQuestionsVersion0 as CurrentQuestions, + questionnaire_union, ) class QuestionnaireMetadata(CustomBaseModel): - version: int - submission_date: Optional[datetime] = None + version: int + submission_date: Optional[datetime] = None class Questionnaire(CustomBaseModel): - metadata: QuestionnaireMetadata - questions: questionnaire_union + metadata: QuestionnaireMetadata + questions: questionnaire_union class ContactSummary(CustomBaseModel): - id: int - email: Optional[EmailStr] = None + id: int + email: Optional[EmailStr] = None class OrgSchema(OrgIDMixin): - name: str - status: Status - root_user_email: EmailStr - intake_questionnaire: Optional[Questionnaire] = None + name: str + status: Status + root_user_email: EmailStr + intake_questionnaire: Optional[Questionnaire] = None - billing_contact: ContactSummary - owner_contact: ContactSummary - security_contact: ContactSummary + billing_contact: ContactSummary + owner_contact: ContactSummary + security_contact: ContactSummary class OrgPostOrgRequest(CustomBaseModel): - name: str = Field(min_length=3) - intake_questionnaire: Optional[CurrentQuestions] = None + name: str = Field(min_length=3) + intake_questionnaire: Optional[CurrentQuestions] = None class OrgPostOrgResponse(CustomBaseModel): - id: int - name: str - status: Status + id: int + name: str + status: Status class OrgPatchQuestionnaireRequest(OrgIDMixin): - intake_questionnaire: CurrentQuestions - partial: bool + intake_questionnaire: CurrentQuestions + partial: bool class OrgPatchQuestionnaireResponse(CustomBaseModel): - id: int - name: str - intake_questionnaire: Questionnaire - status: Status + id: int + name: str + intake_questionnaire: Questionnaire + status: Status class OrgPatchStatusRequest(OrgIDMixin): - status: Status + status: Status class OrgPatchStatusResponse(CustomBaseModel): - id: int - name: str - status: Status + id: int + name: str + status: Status class OrgPatchContactRequest(OrgIDMixin): - contact_type: ContactType + contact_type: ContactType - email: Optional[EmailStr] = None - first_name: Optional[str] = None - last_name: Optional[str] = None - phonenumber: Optional[str] = None - vat_number: Optional[str] = None - post_office_box_number: Optional[str] = None - street_address: Optional[str] = None - street_address_line_2: Optional[str] = None - locality: Optional[str] = None - address_region: Optional[str] = None - country_code: Optional[str] = None - postal_code: Optional[str] = None + email: Optional[EmailStr] = None + first_name: Optional[str] = None + last_name: Optional[str] = None + phonenumber: Optional[str] = None + vat_number: Optional[str] = None + post_office_box_number: Optional[str] = None + street_address: Optional[str] = None + street_address_line_2: Optional[str] = None + locality: Optional[str] = None + address_region: Optional[str] = None + country_code: Optional[str] = None + postal_code: Optional[str] = None class OrgPostUserRequest(OrgIDMixin, UserIDMixin): - pass + pass class OrgPostUserResponse(CustomBaseModel): - organisation: OrgSummary - users: list[UserSummary] + organisation: OrgSummary + users: list[UserSummary] class OrgPatchRootRequest(OrgIDMixin, UserIDMixin): - pass + pass class OrgPatchRootResponse(CustomBaseModel): - name: str - root_user_email: str + name: str + root_user_email: str class OrgGetUserResponse(CustomBaseModel): - users: list[dict[str, str | int]] - organisation: OrgSummary + users: list[dict[str, str | int]] + organisation: OrgSummary class OrgGetGroupResponse(CustomBaseModel): - organisation: OrgSummary - groups: list[GroupSummary] + organisation: OrgSummary + groups: list[GroupSummary] class OrgGetContactResponse(CustomBaseModel): - model_config = ConfigDict(from_attributes=True, extra="ignore") + model_config = ConfigDict(from_attributes=True, extra="ignore") - contact: ContactModel - organisation: OrgSummary + contact: ContactModel + organisation: OrgSummary class OrgPatchContactResponse(CustomBaseModel): - model_config = ConfigDict(from_attributes=True, extra="ignore") + model_config = ConfigDict(from_attributes=True, extra="ignore") - contact: ContactModel - organisation: OrgSummary + contact: ContactModel + organisation: OrgSummary class OrgGetOrgResponse(CustomBaseModel): - organisations: list[OrgSchema] + organisations: list[OrgSchema] diff --git a/src/organisation/schemas_questionnaires.py b/src/organisation/schemas_questionnaires.py index 491dfe5..b257c53 100644 --- a/src/organisation/schemas_questionnaires.py +++ b/src/organisation/schemas_questionnaires.py @@ -4,13 +4,13 @@ from src.schemas import CustomBaseModel class QuestionnaireQuestions(CustomBaseModel): - pass + pass class QuestionnaireQuestionsVersion0(QuestionnaireQuestions): - question_one: Optional[str] = None - question_two: Optional[str] = None - question_three: Optional[str] = None + question_one: Optional[str] = None + question_two: Optional[str] = None + question_three: Optional[str] = None questionnaire_union = QuestionnaireQuestionsVersion0 # | QuestionnaireQuestionsVersion1 diff --git a/src/organisation/service.py b/src/organisation/service.py index 4fc04f6..dffd044 100644 --- a/src/organisation/service.py +++ b/src/organisation/service.py @@ -11,57 +11,57 @@ from src.user.models import User async def add_default_org_permissions( - db: Session, - org_model: Org, - perm_list: list[int], + db: Session, + org_model: Org, + perm_list: list[int], ): - for permission in perm_list: - perm_model = db.get(Perm, permission) + for permission in perm_list: + perm_model = db.get(Perm, permission) - if perm_model is None: - continue + if perm_model is None: + continue - if perm_model in org_model.permission_rel: - continue + if perm_model in org_model.permission_rel: + continue - org_model.permission_rel.append(perm_model) - db.flush() + org_model.permission_rel.append(perm_model) + db.flush() - db.commit() + db.commit() async def assign_defaults( - db: Session, - org_id: int, - user_id: int, + db: Session, + org_id: int, + user_id: int, ): - default_org_permissions = [] + default_org_permissions = [] - default_user_permissions = [] + default_user_permissions = [] - org_model = db.get(Org, org_id) - if org_model is None: - print("Org not found while adding defaults") - return + org_model = db.get(Org, org_id) + if org_model is None: + print("Org not found while adding defaults") + return - user_model = db.get(User, user_id) - if user_model is None: - print("User not found while adding defaults") - return + user_model = db.get(User, user_id) + if user_model is None: + print("User not found while adding defaults") + return - await add_default_org_permissions(db, org_model, default_org_permissions) - await assign_default_group( - db=db, - org_model=org_model, - user_model=user_model, - group_name="Default Users", - perm_list=default_user_permissions, - ) - await assign_default_group( - db=db, - org_model=org_model, - user_model=user_model, - group_name="Root User", - perm_list=default_org_permissions, - ) - db.commit() + await add_default_org_permissions(db, org_model, default_org_permissions) + await assign_default_group( + db=db, + org_model=org_model, + user_model=user_model, + group_name="Default Users", + perm_list=default_user_permissions, + ) + await assign_default_group( + db=db, + org_model=org_model, + user_model=user_model, + group_name="Root User", + perm_list=default_org_permissions, + ) + db.commit() diff --git a/src/schemas.py b/src/schemas.py index 8bb4fa2..33178ce 100644 --- a/src/schemas.py +++ b/src/schemas.py @@ -11,54 +11,54 @@ from typing import Optional class CustomBaseModel(BaseModel): - pass + pass ### Mixins ### class OrgIDMixin(CustomBaseModel): - organisation_id: int = Field(gt=0) + organisation_id: int = Field(gt=0) class GroupIDMixin(CustomBaseModel): - group_id: int = Field(gt=0) + group_id: int = Field(gt=0) class PermIDMixin(CustomBaseModel): - permission_id: int = Field(gt=0) + permission_id: int = Field(gt=0) class ServiceIDMixin(CustomBaseModel): - service_id: int = Field(gt=0) + service_id: int = Field(gt=0) class UserIDMixin(CustomBaseModel): - user_id: int = Field(gt=0) + user_id: int = Field(gt=0) class ServiceNameMixin(CustomBaseModel): - service: str + service: str class OrgSummary(CustomBaseModel): - id: int - name: str + id: int + name: str class GroupSummary(CustomBaseModel): - id: int - name: str + id: int + name: str class UserSummary(CustomBaseModel): - id: int - email: str + id: int + email: str class ServiceSummary(CustomBaseModel): - id: int - name: str + id: int + name: str class ResourceName(ServiceNameMixin, OrgIDMixin): - resource: str - instance: Optional[str] = None + resource: str + instance: Optional[str] = None diff --git a/src/service/dependencies.py b/src/service/dependencies.py index ea42c89..5104035 100644 --- a/src/service/dependencies.py +++ b/src/service/dependencies.py @@ -16,25 +16,23 @@ from src.service.models import Service from src.service.schemas import ServiceIDMixin -async def get_service_model_query( - db: DbSession, service_id: Annotated[int, Query(gt=0)] -): - service_model = db.get(Service, service_id) - if service_model is None: - raise ServiceNotFoundException(service_id=service_id) +async def get_service_model_query(db: DbSession, service_id: Annotated[int, Query(gt=0)]): + service_model = db.get(Service, service_id) + if service_model is None: + raise ServiceNotFoundException(service_id=service_id) - return service_model + return service_model service_model_query_dependency = Annotated[Service, Depends(get_service_model_query)] async def get_service_model_body(db: DbSession, request_model: ServiceIDMixin): - service_model = db.get(Service, request_model.service_id) - if service_model is None: - raise ServiceNotFoundException(service_id=request_model.service_id) + service_model = db.get(Service, request_model.service_id) + if service_model is None: + raise ServiceNotFoundException(service_id=request_model.service_id) - return service_model + return service_model service_model_body_dependency = Annotated[Service, Depends(get_service_model_body)] diff --git a/src/service/exceptions.py b/src/service/exceptions.py index 36a927d..bece7d3 100644 --- a/src/service/exceptions.py +++ b/src/service/exceptions.py @@ -11,13 +11,13 @@ from fastapi import HTTPException, status class ServiceNotFoundException(HTTPException): - def __init__(self, service_id: Optional[int] = None) -> None: - detail = ( - "Service not found" - if service_id is None - else f"Service with ID '{service_id}' was not found." - ) - super().__init__( - status_code=status.HTTP_404_NOT_FOUND, - detail=detail, - ) + def __init__(self, service_id: Optional[int] = None) -> None: + detail = ( + "Service not found" + if service_id is None + else f"Service with ID '{service_id}' was not found." + ) + super().__init__( + status_code=status.HTTP_404_NOT_FOUND, + detail=detail, + ) diff --git a/src/service/models.py b/src/service/models.py index de6dcd6..7ab26fb 100644 --- a/src/service/models.py +++ b/src/service/models.py @@ -12,11 +12,11 @@ from src.models import CustomBase, IdMixin class Service(CustomBase, IdMixin): - __tablename__ = "service" + __tablename__ = "service" - name: Mapped[str] = mapped_column(unique=True) - api_key: Mapped[str] + name: Mapped[str] = mapped_column(unique=True) + api_key: Mapped[str] - permission_rel = relationship( - "Permission", back_populates="service_rel", cascade="all, delete-orphan" - ) + permission_rel = relationship( + "Permission", back_populates="service_rel", cascade="all, delete-orphan" + ) diff --git a/src/service/router.py b/src/service/router.py index 54a7a3a..25d4925 100644 --- a/src/service/router.py +++ b/src/service/router.py @@ -15,8 +15,8 @@ from psycopg.errors import UniqueViolation from src.exceptions import ConflictException from src.database import DbSession from src.auth.dependencies import ( - super_admin_dependency, - org_model_root_claim_query_dependency, + super_admin_dependency, + org_model_root_claim_query_dependency, ) from src.iam.service import service_key_dependency from src.iam.models import Permission as Perm @@ -25,212 +25,210 @@ from src.service.exceptions import ServiceNotFoundException from src.service.models import Service from src.service.utils import generate_api_key from src.service.dependencies import ( - service_model_body_dependency, - service_model_query_dependency, + service_model_body_dependency, + service_model_query_dependency, ) from src.service.schemas import ( - ServiceGetServiceResponse, - ServicePostServiceRequest, - ServicePostServiceResponse, - ServiceWithKeySchema, - ServicePatchKeyResponse, - ServicePatchKeyRequest, - ServicePostPermissionsResponse, - ServicePostPermissionsRequest, + ServiceGetServiceResponse, + ServicePostServiceRequest, + ServicePostServiceResponse, + ServiceWithKeySchema, + ServicePatchKeyResponse, + ServicePatchKeyRequest, + ServicePostPermissionsResponse, + ServicePostPermissionsRequest, ) router = APIRouter( - tags=["Service"], - prefix="/service", + tags=["Service"], + prefix="/service", ) @router.get( - "", - summary="Get all services", - status_code=status.HTTP_200_OK, - response_model=ServiceGetServiceResponse, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, - status.HTTP_401_UNAUTHORIZED: { - "description": "Unauthorized", - "content": { - "application/json": { - "examples": { - "awaiting_approval": { - "summary": "Organisation has not yet been approved." - }, - } - } - }, - }, - status.HTTP_403_FORBIDDEN: { - "description": "Forbidden", - "content": { - "application/json": { - "examples": { - "not_root": {"summary": "Not authorised. Must be root user."}, - } - } - }, - }, - }, + "", + summary="Get all services", + status_code=status.HTTP_200_OK, + response_model=ServiceGetServiceResponse, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, + status.HTTP_401_UNAUTHORIZED: { + "description": "Unauthorized", + "content": { + "application/json": { + "examples": { + "awaiting_approval": { + "summary": "Organisation has not yet been approved." + }, + } + } + }, + }, + status.HTTP_403_FORBIDDEN: { + "description": "Forbidden", + "content": { + "application/json": { + "examples": { + "not_root": {"summary": "Not authorised. Must be root user."}, + } + } + }, + }, + }, ) -async def get_all_services( - db: DbSession, org_model: org_model_root_claim_query_dependency -): - """ - Returns the ID and name of all services registered to the hub. - """ - permission_models = db.query(Service).all() +async def get_all_services(db: DbSession, org_model: org_model_root_claim_query_dependency): + """ + Returns the ID and name of all services registered to the hub. + """ + permission_models = db.query(Service).all() - return {"services": permission_models} + return {"services": permission_models} @router.post( - "", - summary="Register a new service.", - status_code=status.HTTP_200_OK, - response_model=ServicePostServiceResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully registered a new service"}, - status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, - status.HTTP_409_CONFLICT: {"description": "Service with this name already exists"}, - }, + "", + summary="Register a new service.", + status_code=status.HTTP_200_OK, + response_model=ServicePostServiceResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully registered a new service"}, + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + status.HTTP_409_CONFLICT: {"description": "Service with this name already exists"}, + }, ) async def register_service( - db: DbSession, - su: super_admin_dependency, - request_model: ServicePostServiceRequest, + db: DbSession, + su: super_admin_dependency, + request_model: ServicePostServiceRequest, ): - """ - Registers a new service to the hub, generating and returning an API key for it. - """ - key = generate_api_key() - service_model = Service(name=request_model.name, api_key=key) + """ + Registers a new service to the hub, generating and returning an API key for it. + """ + key = generate_api_key() + service_model = Service(name=request_model.name, api_key=key) - db.add(service_model) - try: - db.flush() - except IntegrityError as e: - if ( - isinstance(e.orig, UniqueViolation) # Postgres unique violation - or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation - ): - raise ConflictException(message="Service with this name already exists") - raise - response = ServiceWithKeySchema(**service_model.__dict__) - db.commit() - return {"service": response} + db.add(service_model) + try: + db.flush() + except IntegrityError as e: + if ( + isinstance(e.orig, UniqueViolation) # Postgres unique violation + or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation + ): + raise ConflictException(message="Service with this name already exists") + raise + response = ServiceWithKeySchema(**service_model.__dict__) + db.commit() + return {"service": response} @router.patch( - "/key", - summary="Regenerate service API key.", - status_code=status.HTTP_200_OK, - response_model=ServicePatchKeyResponse, - responses={ - status.HTTP_200_OK: {"description": "Successful update of API key"}, - status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, - }, + "/key", + summary="Regenerate service API key.", + status_code=status.HTTP_200_OK, + response_model=ServicePatchKeyResponse, + responses={ + status.HTTP_200_OK: {"description": "Successful update of API key"}, + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + }, ) async def regenerate_api_key( - db: DbSession, - su: super_admin_dependency, - service_model: service_model_body_dependency, - request_model: ServicePatchKeyRequest, + db: DbSession, + su: super_admin_dependency, + service_model: service_model_body_dependency, + request_model: ServicePatchKeyRequest, ): - """ - Generates and returns a new API key for the service to access the hub. - """ - key = generate_api_key() - service_model.api_key = key + """ + Generates and returns a new API key for the service to access the hub. + """ + key = generate_api_key() + service_model.api_key = key - db.flush() - response = ServiceWithKeySchema(**service_model.__dict__) - db.commit() - return {"service": response} + db.flush() + response = ServiceWithKeySchema(**service_model.__dict__) + db.commit() + return {"service": response} @router.delete( - "", - summary="Remove a service.", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - status.HTTP_204_NO_CONTENT: {"description": "Successfully removed service from db"}, - status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, - }, + "", + summary="Remove a service.", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + status.HTTP_204_NO_CONTENT: {"description": "Successfully removed service from db"}, + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + }, ) async def remove_service( - db: DbSession, - service_model: service_model_query_dependency, - su: super_admin_dependency, + db: DbSession, + service_model: service_model_query_dependency, + su: super_admin_dependency, ): - """ - Removes a service from the hub. - """ - db.delete(service_model) - db.commit() + """ + Removes a service from the hub. + """ + db.delete(service_model) + db.commit() @router.post( - path="/permissions", - summary="Service endpoint for creating its own permissions.", - status_code=status.HTTP_200_OK, - response_model=ServicePostPermissionsResponse, - responses={ - status.HTTP_401_UNAUTHORIZED: { - "description": "API Key missing or invalid | Issue verifying user OIDC claims" - }, - }, + path="/permissions", + summary="Service endpoint for creating its own permissions.", + status_code=status.HTTP_200_OK, + response_model=ServicePostPermissionsResponse, + responses={ + status.HTTP_401_UNAUTHORIZED: { + "description": "API Key missing or invalid | Issue verifying user OIDC claims" + }, + }, ) async def service_create_new_permissions( - db: DbSession, - request_model: ServicePostPermissionsRequest, - valid_key: service_key_dependency, + db: DbSession, + request_model: ServicePostPermissionsRequest, + valid_key: service_key_dependency, ): - """ - Allows a service to register its own set of permissions. - """ - service_model = ( - db.query(Service).filter(Service.name == request_model.rn.service).first() - ) - if service_model is None: - raise ServiceNotFoundException() - else: - service_id = service_model.id - response_list = [] - for new_permission in request_model.permissions: - perm_model = ( - db.query(Perm) - .filter(Perm.service_id == service_id) - .filter(Perm.resource == new_permission.resource) - .filter(Perm.action == new_permission.action) - .first() - ) - if perm_model is not None: - response_code = 409 - response = { - "id": perm_model.id, - "service_name": perm_model.service_name, - "resource": perm_model.resource, - "action": perm_model.action, - } - response_list.append((response, response_code)) - continue + """ + Allows a service to register its own set of permissions. + """ + service_model = ( + db.query(Service).filter(Service.name == request_model.rn.service).first() + ) + if service_model is None: + raise ServiceNotFoundException() + else: + service_id = service_model.id + response_list = [] + for new_permission in request_model.permissions: + perm_model = ( + db.query(Perm) + .filter(Perm.service_id == service_id) + .filter(Perm.resource == new_permission.resource) + .filter(Perm.action == new_permission.action) + .first() + ) + if perm_model is not None: + response_code = 409 + response = { + "id": perm_model.id, + "service_name": perm_model.service_name, + "resource": perm_model.resource, + "action": perm_model.action, + } + response_list.append((response, response_code)) + continue - new_perm_model = Perm(**new_permission.__dict__) - new_perm_model.service_id = service_id - db.add(new_perm_model) - db.flush() - response_code = 201 - response = { - "id": new_perm_model.id, - "service_name": new_perm_model.service_name, - "resource": new_perm_model.resource, - "action": new_perm_model.action, - } - response_list.append((response, response_code)) + new_perm_model = Perm(**new_permission.__dict__) + new_perm_model.service_id = service_id + db.add(new_perm_model) + db.flush() + response_code = 201 + response = { + "id": new_perm_model.id, + "service_name": new_perm_model.service_name, + "resource": new_perm_model.resource, + "action": new_perm_model.action, + } + response_list.append((response, response_code)) - db.commit() - return {"permissions": response_list} + db.commit() + return {"permissions": response_list} diff --git a/src/service/schemas.py b/src/service/schemas.py index 544dac3..e618d6d 100644 --- a/src/service/schemas.py +++ b/src/service/schemas.py @@ -10,10 +10,10 @@ from typing import Generic, TypeVar from pydantic import Field, ConfigDict from src.schemas import ( - CustomBaseModel, - ServiceIDMixin, - ServiceSummary, - ServiceNameMixin, + CustomBaseModel, + ServiceIDMixin, + ServiceSummary, + ServiceNameMixin, ) @@ -21,51 +21,51 @@ T = TypeVar("T", bound=ServiceNameMixin) class HasServiceName(CustomBaseModel, Generic[T]): - rn: T + rn: T class PermissionResponseSchema(CustomBaseModel): - model_config = ConfigDict(from_attributes=True, extra="ignore") + model_config = ConfigDict(from_attributes=True, extra="ignore") - id: int - service_name: str - resource: str - action: str + id: int + service_name: str + resource: str + action: str class PermissionRequestSchema(CustomBaseModel): - resource: str - action: str + resource: str + action: str class ServiceWithKeySchema(ServiceSummary): - api_key: str + api_key: str class ServiceGetServiceResponse(CustomBaseModel): - services: list[ServiceSummary] + services: list[ServiceSummary] class ServicePostServiceRequest(CustomBaseModel): - name: str = Field(min_length=3) + name: str = Field(min_length=3) class ServicePostServiceResponse(CustomBaseModel): - service: ServiceWithKeySchema + service: ServiceWithKeySchema class ServicePatchKeyRequest(ServiceIDMixin): - pass + pass class ServicePatchKeyResponse(CustomBaseModel): - service: ServiceWithKeySchema + service: ServiceWithKeySchema class ServicePostPermissionsRequest(CustomBaseModel): - rn: ServiceNameMixin - permissions: list[PermissionRequestSchema] + rn: ServiceNameMixin + permissions: list[PermissionRequestSchema] class ServicePostPermissionsResponse(CustomBaseModel): - permissions: list[tuple[PermissionResponseSchema, int]] + permissions: list[tuple[PermissionResponseSchema, int]] diff --git a/src/service/utils.py b/src/service/utils.py index 79bb91f..27cfc27 100644 --- a/src/service/utils.py +++ b/src/service/utils.py @@ -9,4 +9,4 @@ import uuid def generate_api_key() -> str: - return str(uuid.uuid4()) + return str(uuid.uuid4()) diff --git a/src/user/dependencies.py b/src/user/dependencies.py index de23693..7518c16 100644 --- a/src/user/dependencies.py +++ b/src/user/dependencies.py @@ -19,37 +19,37 @@ from src.schemas import UserIDMixin async def get_user_model_claims(claims: claims_dependency, db: DbSession): - user_id = claims.get("db_id", None) - if user_id is None: - raise UserNotFoundException() + user_id = claims.get("db_id", None) + if user_id is None: + raise UserNotFoundException() - user_model = db.get(User, user_id) - if user_model is None: - raise UserNotFoundException(user_id=user_id) + user_model = db.get(User, user_id) + if user_model is None: + raise UserNotFoundException(user_id=user_id) - return user_model + return user_model user_model_claims_dependency = Annotated[User, Depends(get_user_model_claims)] async def get_user_model_query(db: DbSession, user_id: Annotated[int, Query(gt=0)]): - user_model = db.get(User, user_id) - if user_model is None: - raise UserNotFoundException(user_id=user_id) + user_model = db.get(User, user_id) + if user_model is None: + raise UserNotFoundException(user_id=user_id) - return user_model + return user_model user_model_query_dependency = Annotated[User, Depends(get_user_model_query)] async def get_user_model_body(db: DbSession, request_model: UserIDMixin): - user_model = db.get(User, request_model.user_id) - if user_model is None: - raise UserNotFoundException(user_id=request_model.user_id) + user_model = db.get(User, request_model.user_id) + if user_model is None: + raise UserNotFoundException(user_id=request_model.user_id) - return user_model + return user_model user_model_body_dependency = Annotated[User, Depends(get_user_model_body)] diff --git a/src/user/exceptions.py b/src/user/exceptions.py index fa4db2f..9b0509f 100644 --- a/src/user/exceptions.py +++ b/src/user/exceptions.py @@ -11,13 +11,13 @@ from fastapi import HTTPException, status class UserNotFoundException(HTTPException): - def __init__(self, user_id: Optional[int] = None) -> None: - detail = ( - "User not found" - if user_id is None - else f"User with ID '{user_id}' was not found." - ) - super().__init__( - status_code=status.HTTP_404_NOT_FOUND, - detail=detail, - ) + def __init__(self, user_id: Optional[int] = None) -> None: + detail = ( + "User not found" + if user_id is None + else f"User with ID '{user_id}' was not found." + ) + super().__init__( + status_code=status.HTTP_404_NOT_FOUND, + detail=detail, + ) diff --git a/src/user/models.py b/src/user/models.py index 4803d54..333e78a 100644 --- a/src/user/models.py +++ b/src/user/models.py @@ -20,26 +20,26 @@ from src.models import CustomBase class User(CustomBase, IdMixin, ActivatedMixin, TimestampMixin, DeletedTimestampMixin): - __tablename__ = "user" + __tablename__ = "user" - email: Mapped[str] - first_name: Mapped[str] - last_name: Mapped[str] - oidc_id: Mapped[str] = mapped_column(index=True, unique=True) + email: Mapped[str] + first_name: Mapped[str] + last_name: Mapped[str] + oidc_id: Mapped[str] = mapped_column(index=True, unique=True) - organisation_rel = relationship( - "Organisation", secondary="orgusers", back_populates="user_rel" - ) + organisation_rel = relationship( + "Organisation", secondary="orgusers", back_populates="user_rel" + ) - group_rel = relationship("Group", secondary="user_groups", back_populates="user_rel") + group_rel = relationship("Group", secondary="user_groups", back_populates="user_rel") - @property - def organisations(self): - return [{"name": org.name, "id": org.id} for org in self.organisation_rel] + @property + def organisations(self): + return [{"name": org.name, "id": org.id} for org in self.organisation_rel] - @property - def groups(self): - result = defaultdict(list) - for group in self.group_rel: - result[group.org_rel.name].append({"name": group.name, "id": group.id}) - return dict(result) + @property + def groups(self): + result = defaultdict(list) + for group in self.group_rel: + result[group.org_rel.name].append({"name": group.name, "id": group.id}) + return dict(result) diff --git a/src/user/router.py b/src/user/router.py index ad0ccb1..6c87e24 100644 --- a/src/user/router.py +++ b/src/user/router.py @@ -13,205 +13,205 @@ from fastapi import APIRouter, status, BackgroundTasks from src.iam.models import Group from src.organisation.exceptions import OrgNotFoundException from src.user.schemas import ( - UserResponse, - OIDCClaims, - UserPostInvitationRequest, - UserPostInvitationAcceptRequest, - UserGetSelfOrgsResponse, - UserPostInvitationResponse, - UserPostInvitationAcceptResponse, + UserResponse, + OIDCClaims, + UserPostInvitationRequest, + UserPostInvitationAcceptRequest, + UserGetSelfOrgsResponse, + UserPostInvitationResponse, + UserPostInvitationAcceptResponse, ) from src.user.dependencies import ( - user_model_claims_dependency, - user_model_query_dependency, + user_model_claims_dependency, + user_model_query_dependency, ) from src.user.service import send_invitation from src.organisation.models import Organisation as Org from src.auth.dependencies import ( - super_admin_dependency, - org_model_root_claim_body_dependency, + super_admin_dependency, + org_model_root_claim_body_dependency, ) from src.auth.service import claims_dependency from src.database import DbSession from src.utils import verify_email_token router = APIRouter( - prefix="/user", - tags=["User"], + prefix="/user", + tags=["User"], ) @router.get( - "/self/claims", - summary="Get current user OIDC claims.", - response_model=OIDCClaims, - status_code=status.HTTP_200_OK, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, - }, + "/self/claims", + summary="Get current user OIDC claims.", + response_model=OIDCClaims, + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, + }, ) async def current_user_claims(user: claims_dependency): - """ - Returns the full OIDC claims associated with the currently logged-in user. - """ - user["allowed_origins"] = user.get("allowed-origins", []) - return user + """ + Returns the full OIDC claims associated with the currently logged-in user. + """ + user["allowed_origins"] = user.get("allowed-origins", []) + return user @router.get( - "/self/db", - summary="Get current user hub details.", - response_model=UserResponse, - status_code=status.HTTP_200_OK, - responses={ - status.HTTP_404_NOT_FOUND: {"description": "User not found"}, - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, - }, + "/self/db", + summary="Get current user hub details.", + response_model=UserResponse, + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_404_NOT_FOUND: {"description": "User not found"}, + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, + }, ) async def current_user(user_model: user_model_claims_dependency): - """ - Returns the database details associated with the currently logged-in user. - """ - return user_model + """ + Returns the database details associated with the currently logged-in user. + """ + return user_model @router.get( - "", - summary="Get user hub details by ID.", - response_model=UserResponse, - status_code=status.HTTP_200_OK, - responses={ - status.HTTP_404_NOT_FOUND: {"description": "User not found"}, - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, - }, + "", + summary="Get user hub details by ID.", + response_model=UserResponse, + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_404_NOT_FOUND: {"description": "User not found"}, + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, + }, ) async def get_user_by_id( - user_model: user_model_query_dependency, su: super_admin_dependency + user_model: user_model_query_dependency, su: super_admin_dependency ): - """ - Returns the database details associated with the provided user ID. - """ - return user_model + """ + Returns the database details associated with the provided user ID. + """ + return user_model @router.delete( - "", - summary="Delete user from hub by ID.", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - status.HTTP_204_NO_CONTENT: {"description": "User deleted"}, - status.HTTP_404_NOT_FOUND: {"description": "User not found"}, - }, + "", + summary="Delete user from hub by ID.", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + status.HTTP_204_NO_CONTENT: {"description": "User deleted"}, + status.HTTP_404_NOT_FOUND: {"description": "User not found"}, + }, ) async def delete_user_by_id( - db: DbSession, - user_model: user_model_query_dependency, - su: super_admin_dependency, + db: DbSession, + user_model: user_model_query_dependency, + su: super_admin_dependency, ): - """ - Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login. - """ - db.delete(user_model) - db.commit() + """ + Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login. + """ + db.delete(user_model) + db.commit() @router.get( - "/self/orgs", - summary="Get all orgs the current user is a member of", - status_code=status.HTTP_200_OK, - response_model=UserGetSelfOrgsResponse, - responses={}, + "/self/orgs", + summary="Get all orgs the current user is a member of", + status_code=status.HTTP_200_OK, + response_model=UserGetSelfOrgsResponse, + responses={}, ) async def get_user_orgs(user_model: user_model_claims_dependency): - user_orgs = user_model.organisation_rel - response = [] - for org in user_orgs: - response.append( - { - "organisation_id": org.id, - "name": org.name, - "status": org.status, - "intake_questionnaire": org.intake_questionnaire, - "root_user_email": org.root_user_email, - "billing_contact": { - "id": org.billing_contact_id, - "email": org.billing_contact_rel.email, - }, - "owner_contact": { - "id": org.owner_contact_id, - "email": org.owner_contact_rel.email, - }, - "security_contact": { - "id": org.security_contact_id, - "email": org.security_contact_rel.email, - }, - } - ) + user_orgs = user_model.organisation_rel + response = [] + for org in user_orgs: + response.append( + { + "organisation_id": org.id, + "name": org.name, + "status": org.status, + "intake_questionnaire": org.intake_questionnaire, + "root_user_email": org.root_user_email, + "billing_contact": { + "id": org.billing_contact_id, + "email": org.billing_contact_rel.email, + }, + "owner_contact": { + "id": org.owner_contact_id, + "email": org.owner_contact_rel.email, + }, + "security_contact": { + "id": org.security_contact_id, + "email": org.security_contact_rel.email, + }, + } + ) - return {"organisations": response} + return {"organisations": response} @router.post( - "/invitation", - summary="Send an email invitation for a user to join an org", - status_code=status.HTTP_200_OK, - response_model=UserPostInvitationResponse, + "/invitation", + summary="Send an email invitation for a user to join an org", + status_code=status.HTTP_200_OK, + response_model=UserPostInvitationResponse, ) async def invitation( - background_tasks: BackgroundTasks, - org_model: org_model_root_claim_body_dependency, - request_model: UserPostInvitationRequest, + background_tasks: BackgroundTasks, + org_model: org_model_root_claim_body_dependency, + request_model: UserPostInvitationRequest, ): - org_id = org_model.id - org_name = org_model.name - user_email = request_model.user_email + org_id = org_model.id + org_name = org_model.name + user_email = request_model.user_email - background_tasks.add_task( - send_invitation, org_id=org_id, org_name=org_name, user_email=user_email - ) + background_tasks.add_task( + send_invitation, org_id=org_id, org_name=org_name, user_email=user_email + ) - response = { - "organisation": org_model, - "invited_email": user_email, - } + response = { + "organisation": org_model, + "invited_email": user_email, + } - return response + return response @router.post( - "/invitation/accept", - summary="Accept email invitation to join an org", - status_code=status.HTTP_200_OK, - response_model=UserPostInvitationAcceptResponse, + "/invitation/accept", + summary="Accept email invitation to join an org", + status_code=status.HTTP_200_OK, + response_model=UserPostInvitationAcceptResponse, ) async def accept_invitation( - db: DbSession, - user_model: user_model_claims_dependency, - request_model: UserPostInvitationAcceptRequest, + db: DbSession, + user_model: user_model_claims_dependency, + request_model: UserPostInvitationAcceptRequest, ): - email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model) + email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model) - org_model = db.get(Org, email_claims["org_id"]) - if org_model is None: - raise OrgNotFoundException() + org_model = db.get(Org, email_claims["org_id"]) + if org_model is None: + raise OrgNotFoundException() - org_model.user_rel.append(user_model) - db.flush() - group_model = ( - db.query(Group) - .filter(Group.org_id == org_model.id) - .filter(Group.name == "Default Users") - .first() - ) - if group_model is not None: - user_model.group_rel.append(group_model) + org_model.user_rel.append(user_model) + db.flush() + group_model = ( + db.query(Group) + .filter(Group.org_id == org_model.id) + .filter(Group.name == "Default Users") + .first() + ) + if group_model is not None: + user_model.group_rel.append(group_model) - response = { - "organisation": org_model, - "user": user_model, - } + response = { + "organisation": org_model, + "user": user_model, + } - db.commit() + db.commit() - return response + return response diff --git a/src/user/schemas.py b/src/user/schemas.py index 65ab530..666f91f 100644 --- a/src/user/schemas.py +++ b/src/user/schemas.py @@ -10,63 +10,63 @@ from src.schemas import CustomBaseModel, OrgIDMixin, OrgSummary, UserSummary class OIDCClaims(CustomBaseModel): - exp: int - iat: int - auth_time: int - jti: str - iss: str - aud: str - sub: str - typ: str - azp: str - sid: str - acr: str - allowed_origins: list[str] - realm_access: dict[str, list[str]] - resource_access: dict[str, dict[str, list[str]]] - scope: str - email_verified: bool - name: str - preferred_username: str - given_name: str - family_name: str - email: str - db_id: int + exp: int + iat: int + auth_time: int + jti: str + iss: str + aud: str + sub: str + typ: str + azp: str + sid: str + acr: str + allowed_origins: list[str] + realm_access: dict[str, list[str]] + resource_access: dict[str, dict[str, list[str]]] + scope: str + email_verified: bool + name: str + preferred_username: str + given_name: str + family_name: str + email: str + db_id: int class OIDCUser(CustomBaseModel): - first_name: str - last_name: str - email: str - oidc_id: str + first_name: str + last_name: str + email: str + oidc_id: str class UserResponse(CustomBaseModel): - id: int - first_name: str - last_name: str - email: str - organisations: list[Optional[dict[str, str | int]]] - groups: Optional[dict[str, list[dict[str, str | int]]]] = None + id: int + first_name: str + last_name: str + email: str + organisations: list[Optional[dict[str, str | int]]] + groups: Optional[dict[str, list[dict[str, str | int]]]] = None class UserPostInvitationRequest(OrgIDMixin): - user_email: EmailStr + user_email: EmailStr class UserPostInvitationAcceptRequest(CustomBaseModel): - jwt: str + jwt: str class UserGetSelfOrgsResponse(CustomBaseModel): - organisations: list[OrgSchema] + organisations: list[OrgSchema] class UserPostInvitationResponse(CustomBaseModel): - organisation: OrgSummary - invited_email: EmailStr + organisation: OrgSummary + invited_email: EmailStr class UserPostInvitationAcceptResponse(CustomBaseModel): - organisation: OrgSummary - user: UserSummary + organisation: OrgSummary + user: UserSummary diff --git a/src/user/service.py b/src/user/service.py index 8721b88..c450cbc 100644 --- a/src/user/service.py +++ b/src/user/service.py @@ -16,49 +16,49 @@ from src.user.models import User async def add_user(db: Session, user_claims: dict[str, Any]) -> int: - try: - valid_user = OIDCUser( - first_name=user_claims["given_name"], - last_name=user_claims["family_name"], - email=user_claims["email"], - oidc_id=user_claims["sub"], - ) - except Exception as e: - logging.exception(e) - raise UnprocessableContentException("Invalid or missing OIDC data") + try: + valid_user = OIDCUser( + first_name=user_claims["given_name"], + last_name=user_claims["family_name"], + email=user_claims["email"], + oidc_id=user_claims["sub"], + ) + except Exception as e: + logging.exception(e) + raise UnprocessableContentException("Invalid or missing OIDC data") - db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first() + db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first() - if not db_user: - user_model = User(**valid_user.model_dump()) - db.add(user_model) - user_id = user_model.id - db.commit() - return user_id + if not db_user: + user_model = User(**valid_user.model_dump()) + db.add(user_model) + user_id = user_model.id + db.commit() + return user_id - user_id = db_user.id - db_user.first_name = valid_user.first_name - db_user.last_name = valid_user.last_name - db.commit() - return user_id + user_id = db_user.id + db_user.first_name = valid_user.first_name + db_user.last_name = valid_user.last_name + db.commit() + return user_id async def send_invitation(user_email: str, org_name: str, org_id: int): - expiry_delta = timedelta(hours=24) - expiry = datetime.now(timezone.utc) + expiry_delta - claims = { - "email": user_email, - "org_id": org_id, - "exp": expiry, - "type": "org_invite", - } + expiry_delta = timedelta(hours=24) + expiry = datetime.now(timezone.utc) + expiry_delta + claims = { + "email": user_email, + "org_id": org_id, + "exp": expiry, + "type": "org_invite", + } - token = await generate_jwt(claims) - subject = f"You have been invited to join {org_name}" - body = f"You have been invited to join {org_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}" + token = await generate_jwt(claims) + subject = f"You have been invited to join {org_name}" + body = f"You have been invited to join {org_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}" - await send_email( - recipient=user_email, - subject=subject, - body=body, - ) + await send_email( + recipient=user_email, + subject=subject, + body=body, + ) diff --git a/src/utils.py b/src/utils.py index ff8cb89..27f5d90 100644 --- a/src/utils.py +++ b/src/utils.py @@ -11,52 +11,56 @@ KEY = jwk.import_key(settings.SECRET_KEY.get_secret_value(), "oct") async def generate_jwt(claims): - jwt_token = jwt.encode(header={"alg": "HS256"}, key=KEY, claims=claims) + jwt_token = jwt.encode(header={"alg": "HS256"}, key=KEY, claims=claims) - return jwt_token + return jwt_token async def decode_jwt(encoded): - try: - token = jwt.decode(encoded, key=KEY) - return token.claims - except errors.DecodeError: - raise UnauthorizedException("Invalid JWS") + try: + token = jwt.decode(encoded, key=KEY) + return token.claims + except errors.DecodeError: + raise UnauthorizedException("Invalid JWS") async def verify_email_token(user_model, token): - email_claims = await decode_jwt(token) + email_claims = await decode_jwt(token) - claimed_email = email_claims["email"] + claimed_email = email_claims["email"] - expiry = datetime.fromtimestamp(email_claims["exp"], timezone.utc) + expiry = datetime.fromtimestamp(email_claims["exp"], timezone.utc) - if expiry < datetime.now(timezone.utc): - raise UnauthorizedException("Invitation expired.") + if expiry < datetime.now(timezone.utc): + raise UnauthorizedException("Invitation expired.") - if user_model.email != claimed_email: - raise ForbiddenException("The logged in user and email do not match.") + if user_model.email != claimed_email: + raise ForbiddenException("The logged in user and email do not match.") - return email_claims + return email_claims async def send_email(recipient: str, subject: str, body: str): - if settings.ENVIRONMENT.is_testing: - return + if settings.ENVIRONMENT.is_testing: + return - lettermint = Lettermint(api_token=settings.LETTERMINT_API_TOKEN.get_secret_value()) + lettermint = Lettermint(api_token=settings.LETTERMINT_API_TOKEN.get_secret_value()) - if settings.ENVIRONMENT == "local": - recipient = "ok@testing.lettermint.co" + if settings.ENVIRONMENT == "local": + recipient = "ok@testing.lettermint.co" - try: - response = ( - lettermint.email.from_("noreply@sr2.uk") - .to(recipient) - .subject(subject) - .text(body) - .send() - ) - logging.info("Email sent to {} with subject {} (Status: {})".format(recipient, subject, response.status_code)) - except ValidationError as e: - logging.exception(e) + try: + response = ( + lettermint.email.from_("noreply@sr2.uk") + .to(recipient) + .subject(subject) + .text(body) + .send() + ) + logging.info( + "Email sent to {} with subject {} (Status: {})".format( + recipient, subject, response.status_code + ) + ) + except ValidationError as e: + logging.exception(e) diff --git a/test/conftest.py b/test/conftest.py index 6821b9d..022b8dc 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -22,15 +22,15 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) @pytest.fixture() def db_session(): - CustomBase.metadata.drop_all(bind=engine) - CustomBase.metadata.create_all(bind=engine) - db = SessionLocal() - try: - _seed(db) # extracted seeding logic into a plain function - yield db - finally: - db.rollback() - db.close() + CustomBase.metadata.drop_all(bind=engine) + CustomBase.metadata.create_all(bind=engine) + db = SessionLocal() + try: + _seed(db) # extracted seeding logic into a plain function + yield db + finally: + db.rollback() + db.close() @pytest.fixture @@ -83,176 +83,176 @@ async def no_su_client(db_session) -> AsyncGenerator[AsyncClient, None]: def _seed(db): - db.add( - User( - email="admin@test.com", - first_name="Admin", - last_name="Test", - oidc_id="abcd-efgh-ijkl-mnop", - ) - ) - db.add( - User( - email="user@orgone.com", - first_name="User", - last_name="Test", - oidc_id="abcd-efgh-ijkl-qwer", - ) - ) - db.add( - User( - email="root@orgtwo.com", - first_name="Root", - last_name="Test", - oidc_id="abcd-efgh-ijkl-hjkl", - ) - ) - db.add(Contact(org_id=1, email="billing@orgone.com", phonenumber="07521539927")) - db.add(Contact(org_id=1, email="owner@orgone.com", phonenumber="07521539927")) - db.add(Contact(org_id=1, email="security@orgone.com", phonenumber="07521539927")) - db.add(Contact(org_id=2, email="billing@orgtwo.com", phonenumber="07521539927")) - db.add(Contact(org_id=2, email="owner@orgtwo.com", phonenumber="07521539927")) - db.add(Contact(org_id=2, email="security@orgtwo.com", phonenumber="07521539927")) - db.add(Contact(org_id=3, email="billing@orgthree.com", phonenumber="07521539927")) - db.add(Contact(org_id=3, email="owner@orgthree.com", phonenumber="07521539927")) - db.add(Contact(org_id=3, email="security@orgthree.com", phonenumber="07521539927")) - db.flush() - db.add( - Org( - name="Org One", - root_user_id=1, - billing_contact_id=1, - owner_contact_id=2, - security_contact_id=3, - status="approved", - intake_questionnaire={ - "metadata": {"version": 0, "submission_date": None}, - "questions": {"question_two": "answer two"}, - }, - ) - ) - db.add( - Org( - name="Org Two", - root_user_id=3, - billing_contact_id=4, - owner_contact_id=5, - security_contact_id=6, - status="approved", - intake_questionnaire={ - "metadata": {"version": 0, "submission_date": None}, - "questions": {"question_two": "answer two"}, - }, - ) - ) - db.add( - Org( - name="Org Three", - root_user_id=1, - billing_contact_id=7, - owner_contact_id=8, - security_contact_id=9, - status="partial", - intake_questionnaire={ - "metadata": {"version": 0, "submission_date": None}, - "questions": {"question_two": "answer two"}, - }, - ) - ) - db.add(OrgUsers(org_id=1, user_id=2)) - db.add(Service(name="Test Service", api_key="123456789")) - db.add(Permission(service_id=1, resource="test_resource", action="read")) - db.add(Permission(service_id=1, resource="test_resource", action="move")) - db.add(Permission(service_id=1, resource="test_resource", action="delete")) - db.add(OrgPermissions(org_id=1, permission_id=1)) - db.add(OrgPermissions(org_id=1, permission_id=2)) - db.add(Group(name="Org One Group", org_id=1)) - db.add(Group(name="Org Two Group", org_id=2)) - db.add(Group(name="Org One Group Two", org_id=1)) - db.flush() - group_model = db.get(Group, 1) - perm_model = db.get(Permission, 1) - group_model.permission_rel.append(perm_model) - user_model = db.get(User, 1) - org_model = db.get(Org, 1) - org_model.user_rel.append(user_model) - org_model.group_rel.append(group_model) - db.flush() - group_model.user_rel.append(user_model) - db.commit() + db.add( + User( + email="admin@test.com", + first_name="Admin", + last_name="Test", + oidc_id="abcd-efgh-ijkl-mnop", + ) + ) + db.add( + User( + email="user@orgone.com", + first_name="User", + last_name="Test", + oidc_id="abcd-efgh-ijkl-qwer", + ) + ) + db.add( + User( + email="root@orgtwo.com", + first_name="Root", + last_name="Test", + oidc_id="abcd-efgh-ijkl-hjkl", + ) + ) + db.add(Contact(org_id=1, email="billing@orgone.com", phonenumber="07521539927")) + db.add(Contact(org_id=1, email="owner@orgone.com", phonenumber="07521539927")) + db.add(Contact(org_id=1, email="security@orgone.com", phonenumber="07521539927")) + db.add(Contact(org_id=2, email="billing@orgtwo.com", phonenumber="07521539927")) + db.add(Contact(org_id=2, email="owner@orgtwo.com", phonenumber="07521539927")) + db.add(Contact(org_id=2, email="security@orgtwo.com", phonenumber="07521539927")) + db.add(Contact(org_id=3, email="billing@orgthree.com", phonenumber="07521539927")) + db.add(Contact(org_id=3, email="owner@orgthree.com", phonenumber="07521539927")) + db.add(Contact(org_id=3, email="security@orgthree.com", phonenumber="07521539927")) + db.flush() + db.add( + Org( + name="Org One", + root_user_id=1, + billing_contact_id=1, + owner_contact_id=2, + security_contact_id=3, + status="approved", + intake_questionnaire={ + "metadata": {"version": 0, "submission_date": None}, + "questions": {"question_two": "answer two"}, + }, + ) + ) + db.add( + Org( + name="Org Two", + root_user_id=3, + billing_contact_id=4, + owner_contact_id=5, + security_contact_id=6, + status="approved", + intake_questionnaire={ + "metadata": {"version": 0, "submission_date": None}, + "questions": {"question_two": "answer two"}, + }, + ) + ) + db.add( + Org( + name="Org Three", + root_user_id=1, + billing_contact_id=7, + owner_contact_id=8, + security_contact_id=9, + status="partial", + intake_questionnaire={ + "metadata": {"version": 0, "submission_date": None}, + "questions": {"question_two": "answer two"}, + }, + ) + ) + db.add(OrgUsers(org_id=1, user_id=2)) + db.add(Service(name="Test Service", api_key="123456789")) + db.add(Permission(service_id=1, resource="test_resource", action="read")) + db.add(Permission(service_id=1, resource="test_resource", action="move")) + db.add(Permission(service_id=1, resource="test_resource", action="delete")) + db.add(OrgPermissions(org_id=1, permission_id=1)) + db.add(OrgPermissions(org_id=1, permission_id=2)) + db.add(Group(name="Org One Group", org_id=1)) + db.add(Group(name="Org Two Group", org_id=2)) + db.add(Group(name="Org One Group Two", org_id=1)) + db.flush() + group_model = db.get(Group, 1) + perm_model = db.get(Permission, 1) + group_model.permission_rel.append(perm_model) + user_model = db.get(User, 1) + org_model = db.get(Org, 1) + org_model.user_rel.append(user_model) + org_model.group_rel.append(group_model) + db.flush() + group_model.user_rel.append(user_model) + db.commit() def generate_query_and_status(params) -> list[tuple[str, int]]: - possible_values = [0, -1, 42, "banana", ""] + possible_values = [0, -1, 42, "banana", ""] - defaults = [f"{param}=1" for param in params] + defaults = [f"{param}=1" for param in params] - # Missing params - query_list = [ - "&".join(combo) - for r in range(len(defaults) + 1) - for combo in combinations(defaults, r) - ] + # Missing params + query_list = [ + "&".join(combo) + for r in range(len(defaults) + 1) + for combo in combinations(defaults, r) + ] - # Complete query as default for invalid checks - default_query = query_list.pop(-1) + # Complete query as default for invalid checks + default_query = query_list.pop(-1) - # Checks for each param being invalid - for param in params: - for value in possible_values: - new_value = f"&{param}={value}" - query_list.append(default_query.replace(f"{param}=1", new_value)) + # Checks for each param being invalid + for param in params: + for value in possible_values: + new_value = f"&{param}={value}" + query_list.append(default_query.replace(f"{param}=1", new_value)) - query_and_status = [] + query_and_status = [] - # Assign expected status - for query in query_list: - # ID 42 is used to represent a non-existent entry. So it should 404. - status = 404 if "42" in query else 422 - # Remove leading "&" if present - query = query if len(query) > 1 and query[0] != "&" else query[1:] - query_and_status.append((query, status)) + # Assign expected status + for query in query_list: + # ID 42 is used to represent a non-existent entry. So it should 404. + status = 404 if "42" in query else 422 + # Remove leading "&" if present + query = query if len(query) > 1 and query[0] != "&" else query[1:] + query_and_status.append((query, status)) - return query_and_status + return query_and_status def generate_body_and_status(params: dict[str, str]) -> list[tuple[dict, int]]: - possible_values_int = [0, -1, 42, "banana", ""] - possible_values_str = [0, "", "a"] + possible_values_int = [0, -1, 42, "banana", ""] + possible_values_str = [0, "", "a"] - defaults = [{param: 1 for param in params.keys()}] + defaults = [{param: 1 for param in params.keys()}] - # Missing params - body_list = [ - {key: ("valid string" if params[key] == "str" else 1) for key in combo} - for r in range(len(defaults[0].keys()) + 1) - for combo in combinations(defaults[0].keys(), r) - ] + # Missing params + body_list = [ + {key: ("valid string" if params[key] == "str" else 1) for key in combo} + for r in range(len(defaults[0].keys()) + 1) + for combo in combinations(defaults[0].keys(), r) + ] - # Complete body as default for generating invalid checks - default_body = body_list.pop(-1) + # Complete body as default for generating invalid checks + default_body = body_list.pop(-1) - # Generates checks for each param being invalid - for param, typ in params.items(): - if typ == "int": - possible_values = possible_values_int - elif typ == "str": - possible_values = possible_values_str - else: - raise TypeError(f"Unknown type {typ}") - for value in possible_values: - new_record = default_body.copy() - new_record[param] = value - body_list.append(new_record) + # Generates checks for each param being invalid + for param, typ in params.items(): + if typ == "int": + possible_values = possible_values_int + elif typ == "str": + possible_values = possible_values_str + else: + raise TypeError(f"Unknown type {typ}") + for value in possible_values: + new_record = default_body.copy() + new_record[param] = value + body_list.append(new_record) - body_and_status = [] + body_and_status = [] - # Assign expected status - for body in body_list: - # ID 42 is used to represent a non-existent entry. So it should 404. - status = 404 if 42 in body.values() else 422 - body_and_status.append((body, status)) - return body_and_status + # Assign expected status + for body in body_list: + # ID 42 is used to represent a non-existent entry. So it should 404. + status = 404 if 42 in body.values() else 422 + body_and_status.append((body, status)) + return body_and_status def get_testable_routes(): diff --git a/test/test_auth_approval.py b/test/test_auth_approval.py index 42d0cd9..a23e410 100644 --- a/test/test_auth_approval.py +++ b/test/test_auth_approval.py @@ -8,181 +8,181 @@ import pytest from httpx import AsyncClient pytestmark = [ - pytest.mark.auth, - pytest.mark.preapproval, + pytest.mark.auth, + pytest.mark.preapproval, ] @pytest.mark.anyio async def test_get_org_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/org?org_id=3") - assert resp.status_code != 422 - assert resp.status_code == 200 + resp = await no_su_client.get("/org?org_id=3") + assert resp.status_code != 422 + assert resp.status_code == 200 @pytest.mark.anyio async def test_patch_org_questionnaire_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.patch( - "/org/questionnaire", - json={ - "organisation_id": 3, - "intake_questionnaire": { - "question_one": "new answer one", - "question_two": None, - "question_three": None, - }, - "partial": True, - }, - ) - assert resp.status_code != 422 - assert resp.status_code == 200 + resp = await no_su_client.patch( + "/org/questionnaire", + json={ + "organisation_id": 3, + "intake_questionnaire": { + "question_one": "new answer one", + "question_two": None, + "question_three": None, + }, + "partial": True, + }, + ) + assert resp.status_code != 422 + assert resp.status_code == 200 @pytest.mark.anyio async def test_get_org_users_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/org/users?org_id=3") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + resp = await no_su_client.get("/org/users?org_id=3") + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_get_org_groups_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/org/groups?org_id=3") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + resp = await no_su_client.get("/org/groups?org_id=3") + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_get_org_contact_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/org/contact?org_id=3&contact_type=billing") - assert resp.status_code != 422 - assert resp.status_code == 200 + resp = await no_su_client.get("/org/contact?org_id=3&contact_type=billing") + assert resp.status_code != 422 + assert resp.status_code == 200 @pytest.mark.anyio async def test_patch_org_contact_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.patch( - "/org/contact", - json={ - "organisation_id": 3, - "contact_type": "billing", - "email": "user@example.com", - }, - ) - assert resp.status_code != 422 - assert resp.status_code == 200 + resp = await no_su_client.patch( + "/org/contact", + json={ + "organisation_id": 3, + "contact_type": "billing", + "email": "user@example.com", + }, + ) + assert resp.status_code != 422 + assert resp.status_code == 200 @pytest.mark.anyio async def test_get_service_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/service?org_id=3") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + resp = await no_su_client.get("/service?org_id=3") + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_get_iam_group_permissions_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/iam/group/permissions?org_id=3&group_id=1") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + resp = await no_su_client.get("/iam/group/permissions?org_id=3&group_id=1") + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_get_iam_group_users_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/iam/group/users?org_id=3&group_id=1") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + resp = await no_su_client.get("/iam/group/users?org_id=3&group_id=1") + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_post_iam_group_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.post( - "/iam/group", json={"name": "New Group", "organisation_id": 3} - ) - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + resp = await no_su_client.post( + "/iam/group", json={"name": "New Group", "organisation_id": 3} + ) + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_put_iam_group_permission_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.put( - "/iam/group/permission", - json={"permission_id": 1, "group_id": 2, "organisation_id": 3}, - ) - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + resp = await no_su_client.put( + "/iam/group/permission", + json={"permission_id": 1, "group_id": 2, "organisation_id": 3}, + ) + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_put_iam_group_user_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.put( - "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 3} - ) - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + resp = await no_su_client.put( + "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 3} + ) + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_get_iam_permissions_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/iam/permissions?org_id=3") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + resp = await no_su_client.get("/iam/permissions?org_id=3") + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_post_iam_permissions_search_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.post( - "/iam/permissions/search", json={"organisation_id": 3, "action": "read"} - ) - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + resp = await no_su_client.post( + "/iam/permissions/search", json={"organisation_id": 3, "action": "read"} + ) + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_delete_org_user_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.delete("/org/user?org_id=3&user_id=1") + resp = await no_su_client.delete("/org/user?org_id=3&user_id=1") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_delete_preapproval_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.delete("/org/self?org_id=3") + resp = await no_su_client.delete("/org/self?org_id=3") - assert resp.status_code != 422 - assert resp.status_code == 204 + assert resp.status_code != 422 + assert resp.status_code == 204 @pytest.mark.anyio async def test_post_user_invitation_auth_approval(no_su_client: AsyncClient): - body = {"user_email": "admin@test.com", "organisation_id": 3} - resp = await no_su_client.post("/user/invitation", json=body) + body = {"user_email": "admin@test.com", "organisation_id": 3} + resp = await no_su_client.post("/user/invitation", json=body) - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_delete_group_permissions_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.delete("/iam/group/permission?org_id=3&group_id=1&perm_id=1") + resp = await no_su_client.delete("/iam/group/permission?org_id=3&group_id=1&perm_id=1") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_delete_group_users_success(no_su_client: AsyncClient): - resp = await no_su_client.delete("/iam/group/user?org_id=3&group_id=1&user_id=1") + resp = await no_su_client.delete("/iam/group/user?org_id=3&group_id=1&user_id=1") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio async def test_put_group_user_invitation_success(no_su_client: AsyncClient): - body = {"user_email": "admin@test.com", "organisation_id": 3, "group_id": 1} - resp = await no_su_client.put("/iam/group/user/invitation", json=body) + body = {"user_email": "admin@test.com", "organisation_id": 3, "group_id": 1} + resp = await no_su_client.put("/iam/group/user/invitation", json=body) - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] diff --git a/test/test_auth_general.py b/test/test_auth_general.py index 0599c33..543af7a 100644 --- a/test/test_auth_general.py +++ b/test/test_auth_general.py @@ -5,14 +5,14 @@ from httpx import AsyncClient pytestmark = [ - pytest.mark.auth, + pytest.mark.auth, ] @pytest.mark.anyio async def test_get_org_auth_root_su(default_client: AsyncClient): - # If a super admin can access a resource when not the root user - resp = await default_client.get("/org?org_id=2") - assert resp.status_code != 422 - assert resp.status_code == 200 - assert resp.json()["organisations"][0]["name"] == "Org Two" + # If a super admin can access a resource when not the root user + resp = await default_client.get("/org?org_id=2") + assert resp.status_code != 422 + assert resp.status_code == 200 + assert resp.json()["organisations"][0]["name"] == "Org Two" diff --git a/test/test_auth_root.py b/test/test_auth_root.py index d28d6b5..429bf4b 100644 --- a/test/test_auth_root.py +++ b/test/test_auth_root.py @@ -7,147 +7,147 @@ import pytest from httpx import AsyncClient pytestmark = [ - pytest.mark.auth, - pytest.mark.root_user, + pytest.mark.auth, + pytest.mark.root_user, ] @pytest.mark.anyio async def test_get_org_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.get("/org?org_id=2") - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.get("/org?org_id=2") + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_patch_org_questionnaire_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.patch( - "/org/questionnaire", - json={ - "organisation_id": 2, - "intake_questionnaire": { - "question_one": "new answer one", - "question_two": None, - "question_three": None, - }, - "partial": True, - }, - ) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.patch( + "/org/questionnaire", + json={ + "organisation_id": 2, + "intake_questionnaire": { + "question_one": "new answer one", + "question_two": None, + "question_three": None, + }, + "partial": True, + }, + ) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_get_org_users_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.get("/org/users?org_id=2") - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.get("/org/users?org_id=2") + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_get_org_groups_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.get("/org/groups?org_id=2") - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.get("/org/groups?org_id=2") + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_get_org_contact_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.get("/org/contact?org_id=2&contact_type=billing") - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.get("/org/contact?org_id=2&contact_type=billing") + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_patch_org_contact_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.patch( - "/org/contact", - json={ - "organisation_id": 2, - "contact_type": "billing", - "email": "user@example.com", - }, - ) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.patch( + "/org/contact", + json={ + "organisation_id": 2, + "contact_type": "billing", + "email": "user@example.com", + }, + ) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_get_service_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.get("/service?org_id=2") - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.get("/service?org_id=2") + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_get_iam_group_permissions_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.get("/iam/group/permissions?org_id=2&group_id=1") - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.get("/iam/group/permissions?org_id=2&group_id=1") + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_get_iam_group_users_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.get("/iam/group/users?org_id=2&group_id=1") - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.get("/iam/group/users?org_id=2&group_id=1") + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_post_iam_group_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.post( - "/iam/group", json={"name": "New Group", "organisation_id": 2} - ) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.post( + "/iam/group", json={"name": "New Group", "organisation_id": 2} + ) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_put_iam_group_permission_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.put( - "/iam/group/permission", - json={"permission_id": 1, "group_id": 2, "organisation_id": 2}, - ) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.put( + "/iam/group/permission", + json={"permission_id": 1, "group_id": 2, "organisation_id": 2}, + ) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_put_iam_group_user_auth_root( - no_su_client: AsyncClient, + no_su_client: AsyncClient, ): - resp = await no_su_client.put( - "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2} - ) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.put( + "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2} + ) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_get_iam_permissions_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.get("/iam/permissions?org_id=2") - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.get("/iam/permissions?org_id=2") + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio async def test_post_iam_permissions_search_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.post( - "/iam/permissions/search", json={"organisation_id": 2, "action": "read"} - ) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be the org's root user" in resp.json()["detail"] + resp = await no_su_client.post( + "/iam/permissions/search", json={"organisation_id": 2, "action": "read"} + ) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be the org's root user" in resp.json()["detail"] diff --git a/test/test_auth_su.py b/test/test_auth_su.py index 09ce558..f0136bf 100644 --- a/test/test_auth_su.py +++ b/test/test_auth_su.py @@ -7,69 +7,69 @@ import pytest from httpx import AsyncClient pytestmark = [ - pytest.mark.auth, - pytest.mark.super_admin, + pytest.mark.auth, + pytest.mark.super_admin, ] @pytest.mark.anyio async def test_get_user_auth_su(no_su_client: AsyncClient): - resp = await no_su_client.get("/user?user_id=1") - assert resp.status_code != 422 - assert resp.status_code == 403 - assert resp.json()["detail"] == "Must be super admin" + resp = await no_su_client.get("/user?user_id=1") + assert resp.status_code != 422 + assert resp.status_code == 403 + assert resp.json()["detail"] == "Must be super admin" @pytest.mark.anyio async def test_patch_org_status_auth_su(no_su_client: AsyncClient): - resp = await no_su_client.patch( - "/org/status", json={"organisation_id": 1, "status": "submitted"} - ) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert resp.json()["detail"] == "Must be super admin" + resp = await no_su_client.patch( + "/org/status", json={"organisation_id": 1, "status": "submitted"} + ) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert resp.json()["detail"] == "Must be super admin" @pytest.mark.anyio async def test_patch_org_root_user_auth_su(no_su_client: AsyncClient): - resp = await no_su_client.patch( - "/org/root_user", json={"organisation_id": 1, "user_id": 2} - ) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert resp.json()["detail"] == "Must be super admin" + resp = await no_su_client.patch( + "/org/root_user", json={"organisation_id": 1, "user_id": 2} + ) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert resp.json()["detail"] == "Must be super admin" @pytest.mark.anyio async def test_patch_service_key_auth_su(no_su_client: AsyncClient): - resp = await no_su_client.patch("/service/key", json={"service_id": 1}) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert resp.json()["detail"] == "Must be super admin" + resp = await no_su_client.patch("/service/key", json={"service_id": 1}) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert resp.json()["detail"] == "Must be super admin" @pytest.mark.anyio async def test_post_service_auth_su(no_su_client: AsyncClient): - resp = await no_su_client.post("/service", json={"name": "New Test Service"}) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert resp.json()["detail"] == "Must be super admin" + resp = await no_su_client.post("/service", json={"name": "New Test Service"}) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert resp.json()["detail"] == "Must be super admin" @pytest.mark.anyio async def test_post_perm_auth_su(no_su_client: AsyncClient): - resp = await no_su_client.post( - "/iam/permission", - json={"service_id": 1, "resource": "test_resource", "action": "create"}, - ) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert resp.json()["detail"] == "Must be super admin" + resp = await no_su_client.post( + "/iam/permission", + json={"service_id": 1, "resource": "test_resource", "action": "create"}, + ) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert resp.json()["detail"] == "Must be super admin" @pytest.mark.anyio async def test_post_org_user_auth_su(no_su_client: AsyncClient): - resp = await no_su_client.post("/org/user", json={"organisation_id": 1, "user_id": 2}) - assert resp.status_code != 422 - assert resp.status_code == 403 - assert "Must be super admin" in resp.json()["detail"] + resp = await no_su_client.post("/org/user", json={"organisation_id": 1, "user_id": 2}) + assert resp.status_code != 422 + assert resp.status_code == 403 + assert "Must be super admin" in resp.json()["detail"] diff --git a/test/test_auth_user.py b/test/test_auth_user.py index 99f5579..3d2f6ce 100644 --- a/test/test_auth_user.py +++ b/test/test_auth_user.py @@ -7,22 +7,22 @@ from httpx import AsyncClient pytestmark = [ - pytest.mark.auth, - pytest.mark.user, + pytest.mark.auth, + pytest.mark.user, ] @pytest.mark.anyio async def test_get_self_db_auth_user(no_user_client: AsyncClient): - resp = await no_user_client.get("/user/self/db") - assert resp.status_code != 422 - assert resp.status_code == 401 - assert resp.json()["detail"] == "Not authenticated" + resp = await no_user_client.get("/user/self/db") + assert resp.status_code != 422 + assert resp.status_code == 401 + assert resp.json()["detail"] == "Not authenticated" @pytest.mark.anyio async def test_post_org_success_auth_user(no_user_client: AsyncClient): - resp = await no_user_client.post("/org", json={"name": "New Test Org"}) - assert resp.status_code != 422 - assert resp.status_code == 401 - assert resp.json()["detail"] == "Not authenticated" + resp = await no_user_client.post("/org", json={"name": "New Test Org"}) + assert resp.status_code != 422 + assert resp.status_code == 401 + assert resp.json()["detail"] == "Not authenticated" diff --git a/test/test_healthcheck.py b/test/test_healthcheck.py index 47a3993..6fdb9be 100644 --- a/test/test_healthcheck.py +++ b/test/test_healthcheck.py @@ -4,7 +4,7 @@ from httpx import AsyncClient @pytest.mark.anyio async def test_healthcheck(default_client: AsyncClient): - resp = await default_client.get("/healthcheck") + resp = await default_client.get("/healthcheck") - assert resp.status_code == 200 - assert resp.json() == {"status": "ok"} + assert resp.status_code == 200 + assert resp.json() == {"status": "ok"} diff --git a/test/test_iam.py b/test/test_iam.py index ad45543..c890925 100644 --- a/test/test_iam.py +++ b/test/test_iam.py @@ -6,747 +6,747 @@ from httpx import AsyncClient from .conftest import generate_query_and_status, generate_body_and_status pytestmark = [ - pytest.mark.iam_module, + pytest.mark.iam_module, ] @pytest.mark.anyio async def test_post_act_on_resource_endpoint_success(default_client: AsyncClient): - body = { - "rn": { - "service": "Test Service", - "organisation_id": 1, - "resource": "test_resource", - "instance": None, - }, - "action": "read", - } - headers = { - "Authorization": "Bearer not_checked_when_auth_is_disabled", - "X-API-Key": "123456789", - } - resp = await default_client.post("/iam/can_act_on_resource", json=body, headers=headers) - data = resp.json() + body = { + "rn": { + "service": "Test Service", + "organisation_id": 1, + "resource": "test_resource", + "instance": None, + }, + "action": "read", + } + headers = { + "Authorization": "Bearer not_checked_when_auth_is_disabled", + "X-API-Key": "123456789", + } + resp = await default_client.post("/iam/can_act_on_resource", json=body, headers=headers) + data = resp.json() - assert resp.status_code == 200 - assert data["allowed"] is True + assert resp.status_code == 200 + assert data["allowed"] is True - print(data) + print(data) @pytest.mark.parametrize( - "service, api_key", - [("Test Service", "not_the_correct_key"), ("Test Service Two", "123456789")], + "service, api_key", + [("Test Service", "not_the_correct_key"), ("Test Service Two", "123456789")], ) @pytest.mark.anyio async def test_act_on_resource_wrong_key( - default_client: AsyncClient, service: str, api_key: str + default_client: AsyncClient, service: str, api_key: str ): - body = { - "rn": { - "service": service, - "organisation": "Test Org", - "resource": "test_resource", - }, - "action": "read", - } - headers = { - "Authorization": "Bearer not_checked_when_auth_is_disabled", - "X-API-Key": api_key, - } - resp = await default_client.post("/iam/can_act_on_resource", json=body, headers=headers) - data = resp.json() + body = { + "rn": { + "service": service, + "organisation": "Test Org", + "resource": "test_resource", + }, + "action": "read", + } + headers = { + "Authorization": "Bearer not_checked_when_auth_is_disabled", + "X-API-Key": api_key, + } + resp = await default_client.post("/iam/can_act_on_resource", json=body, headers=headers) + data = resp.json() - assert resp.status_code == 401 - assert data["detail"] == "Invalid API key" + assert resp.status_code == 401 + assert data["detail"] == "Invalid API key" @pytest.mark.anyio async def test_act_on_resource_missing_key(default_client: AsyncClient): - body = { - "rn": { - "service": "Test Service", - "organisation": "Test Org", - "resource": "test_resource", - }, - "action": "read", - } - headers = {"Authorization": "Bearer not_checked_when_auth_is_disabled"} - resp = await default_client.post( - "/iam/can_act_on_resource?action=read", json=body, headers=headers - ) - data = resp.json() + body = { + "rn": { + "service": "Test Service", + "organisation": "Test Org", + "resource": "test_resource", + }, + "action": "read", + } + headers = {"Authorization": "Bearer not_checked_when_auth_is_disabled"} + resp = await default_client.post( + "/iam/can_act_on_resource?action=read", json=body, headers=headers + ) + data = resp.json() - assert resp.status_code == 401 - assert data["detail"] == "Missing API key" + assert resp.status_code == 401 + assert data["detail"] == "Missing API key" @pytest.mark.parametrize( - "service, org, resource, action, expected_status", - [ - (None, "Test Org", "test_resource", "read", 422), - (42, "Test Org", "test_resource", "read", 422), - ("Test Service", None, "test_resource", "read", 422), - ("Test Service", 42, "test_resource", "read", 422), - ("Test Service", "Test Org", None, "read", 422), - ("Test Service", "Test Org", 42, "read", 422), - ("Test Service", "Test Org", "test_resource", None, 422), - ("Test Service", "Test Org", "test_resource", 42, 422), - ], + "service, org, resource, action, expected_status", + [ + (None, "Test Org", "test_resource", "read", 422), + (42, "Test Org", "test_resource", "read", 422), + ("Test Service", None, "test_resource", "read", 422), + ("Test Service", 42, "test_resource", "read", 422), + ("Test Service", "Test Org", None, "read", 422), + ("Test Service", "Test Org", 42, "read", 422), + ("Test Service", "Test Org", "test_resource", None, 422), + ("Test Service", "Test Org", "test_resource", 42, 422), + ], ) @pytest.mark.anyio async def test_act_on_resource_endpoint_status_checks( - default_client: AsyncClient, service, org, resource, action, expected_status: int + default_client: AsyncClient, service, org, resource, action, expected_status: int ): - body = { - "rn": {"service": service, "organisation": org, "resource": resource}, - "action": action, - } - headers = { - "Authorization": "Bearer not_checked_when_auth_is_disabled", - "X-API-Key": "123456789", - } - resp = await default_client.post("/iam/can_act_on_resource", json=body, headers=headers) + body = { + "rn": {"service": service, "organisation": org, "resource": resource}, + "action": action, + } + headers = { + "Authorization": "Bearer not_checked_when_auth_is_disabled", + "X-API-Key": "123456789", + } + resp = await default_client.post("/iam/can_act_on_resource", json=body, headers=headers) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.parametrize( - "service, org, resource, action, expected_response", - [ - ("Test Service", 1, "test_resource", "read", True), - ("Test Service", 1, "test_resource", "create", False), - ("Test Service", 1, "no_access_here", "read", False), - ("Test Service", 2, "test_resource", "read", False), - ], + "service, org, resource, action, expected_response", + [ + ("Test Service", 1, "test_resource", "read", True), + ("Test Service", 1, "test_resource", "create", False), + ("Test Service", 1, "no_access_here", "read", False), + ("Test Service", 2, "test_resource", "read", False), + ], ) @pytest.mark.anyio async def test_act_on_resource_logic( - default_client: AsyncClient, - service, - org, - resource, - action, - expected_response: bool, + default_client: AsyncClient, + service, + org, + resource, + action, + expected_response: bool, ): - body = { - "rn": {"service": service, "organisation_id": org, "resource": resource}, - "action": action, - } - headers = { - "Authorization": "Bearer not_checked_when_auth_is_disabled", - "X-API-Key": "123456789", - } - resp = await default_client.post("/iam/can_act_on_resource", json=body, headers=headers) - data = resp.json() + body = { + "rn": {"service": service, "organisation_id": org, "resource": resource}, + "action": action, + } + headers = { + "Authorization": "Bearer not_checked_when_auth_is_disabled", + "X-API-Key": "123456789", + } + resp = await default_client.post("/iam/can_act_on_resource", json=body, headers=headers) + data = resp.json() - assert resp.status_code == 200 - assert data["allowed"] == expected_response + assert resp.status_code == 200 + assert data["allowed"] == expected_response @pytest.mark.anyio async def test_get_group_permissions_success(default_client: AsyncClient): - resp = await default_client.get("/iam/group/permissions?org_id=1&group_id=1") - assert resp.status_code == 200 + resp = await default_client.get("/iam/group/permissions?org_id=1&group_id=1") + assert resp.status_code == 200 - data = resp.json() + data = resp.json() - assert "permissions" in data - assert isinstance(data["permissions"], list) + assert "permissions" in data + assert isinstance(data["permissions"], list) - permission = data["permissions"][0] - assert permission["id"] == 1 - assert permission["service_name"] == "Test Service" - assert permission["resource"] == "test_resource" - assert permission["action"] == "read" + permission = data["permissions"][0] + assert permission["id"] == 1 + assert permission["service_name"] == "Test Service" + assert permission["resource"] == "test_resource" + assert permission["action"] == "read" @pytest.mark.parametrize( - "query, expected_status", generate_query_and_status(["group_id", "org_id"]) + "query, expected_status", generate_query_and_status(["group_id", "org_id"]) ) @pytest.mark.anyio async def test_get_group_permissions_status_checks( - default_client: AsyncClient, query: str, expected_status: int + default_client: AsyncClient, query: str, expected_status: int ): - resp = await default_client.get(f"/iam/group/permissions?{query}") + resp = await default_client.get(f"/iam/group/permissions?{query}") - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.parametrize( - "query", - [ - "org_id=1&group_id=2", - "org_id=2&group_id=1", - ], + "query", + [ + "org_id=1&group_id=2", + "org_id=2&group_id=1", + ], ) @pytest.mark.anyio async def test_get_group_permissions_mismatch(default_client: AsyncClient, query: str): - resp = await default_client.get(f"/iam/group/permissions?{query}") + resp = await default_client.get(f"/iam/group/permissions?{query}") - assert resp.status_code == 403 - assert resp.json()["detail"] == "Group does not belong to this organization" + assert resp.status_code == 403 + assert resp.json()["detail"] == "Group does not belong to this organization" @pytest.mark.anyio async def test_get_group_users_success(default_client: AsyncClient): - resp = await default_client.get("/iam/group/users?org_id=1&group_id=1") - assert resp.status_code == 200 + resp = await default_client.get("/iam/group/users?org_id=1&group_id=1") + assert resp.status_code == 200 - data = resp.json() + data = resp.json() - assert "users" in data - assert isinstance(data["users"], list) + assert "users" in data + assert isinstance(data["users"], list) - user = data["users"][0] - assert user["id"] == 1 - assert user["email"] == "admin@test.com" + user = data["users"][0] + assert user["id"] == 1 + assert user["email"] == "admin@test.com" - assert "group" in data - assert isinstance(data["group"], dict) - assert data["group"]["id"] == 1 - assert data["group"]["name"] == "Org One Group" + assert "group" in data + assert isinstance(data["group"], dict) + assert data["group"]["id"] == 1 + assert data["group"]["name"] == "Org One Group" - assert "organisation" in data - assert isinstance(data["organisation"], dict) - assert data["organisation"]["id"] == 1 - assert data["organisation"]["name"] == "Org One" + assert "organisation" in data + assert isinstance(data["organisation"], dict) + assert data["organisation"]["id"] == 1 + assert data["organisation"]["name"] == "Org One" @pytest.mark.parametrize( - "query, expected_status", generate_query_and_status(["group_id", "org_id"]) + "query, expected_status", generate_query_and_status(["group_id", "org_id"]) ) @pytest.mark.anyio async def test_get_group_users_status_checks( - default_client: AsyncClient, query: str, expected_status: int + default_client: AsyncClient, query: str, expected_status: int ): - resp = await default_client.get(f"/iam/group/users?{query}") + resp = await default_client.get(f"/iam/group/users?{query}") - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.parametrize( - "query", - [ - "org_id=1&group_id=2", - "org_id=2&group_id=1", - ], + "query", + [ + "org_id=1&group_id=2", + "org_id=2&group_id=1", + ], ) @pytest.mark.anyio async def test_get_group_users_mismatch(default_client: AsyncClient, query: str): - resp = await default_client.get(f"/iam/group/users?{query}") + resp = await default_client.get(f"/iam/group/users?{query}") - assert resp.status_code == 403 - assert resp.json()["detail"] == "Group does not belong to this organization" + assert resp.status_code == 403 + assert resp.json()["detail"] == "Group does not belong to this organization" @pytest.mark.anyio async def test_post_group_success(default_client: AsyncClient): - resp = await default_client.post( - "/iam/group", json={"name": "New Group", "organisation_id": 1} - ) - assert resp.status_code == 201 + resp = await default_client.post( + "/iam/group", json={"name": "New Group", "organisation_id": 1} + ) + assert resp.status_code == 201 - data = resp.json() + data = resp.json() - assert "group" in data - assert isinstance(data["group"], dict) - assert data["group"]["name"] == "New Group" - assert data["group"]["id"] == 4 + assert "group" in data + assert isinstance(data["group"], dict) + assert data["group"]["name"] == "New Group" + assert data["group"]["id"] == 4 @pytest.mark.parametrize( - "body, expected_status", - generate_body_and_status({"organisation_id": "int", "name": "str"}), + "body, expected_status", + generate_body_and_status({"organisation_id": "int", "name": "str"}), ) @pytest.mark.anyio async def test_post_group_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.post("/iam/group", json=body) + resp = await default_client.post("/iam/group", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_post_group_conflict(default_client: AsyncClient): - resp = await default_client.post( - "/iam/group", json={"organisation_id": 1, "name": "Org One Group"} - ) + resp = await default_client.post( + "/iam/group", json={"organisation_id": 1, "name": "Org One Group"} + ) - assert resp.status_code == 409 + assert resp.status_code == 409 @pytest.mark.anyio async def test_post_group_non_conflict(default_client: AsyncClient): - resp = await default_client.post( - "/iam/group", json={"organisation_id": 2, "name": "Org One Group"} - ) + resp = await default_client.post( + "/iam/group", json={"organisation_id": 2, "name": "Org One Group"} + ) - assert resp.status_code == 201 + assert resp.status_code == 201 @pytest.mark.anyio async def test_put_group_perm_success(default_client: AsyncClient): - resp = await default_client.put( - "/iam/group/permission", - json={"permission_id": 1, "group_id": 3, "organisation_id": 1}, - ) - assert resp.status_code == 200 + resp = await default_client.put( + "/iam/group/permission", + json={"permission_id": 1, "group_id": 3, "organisation_id": 1}, + ) + assert resp.status_code == 200 - data = resp.json() + data = resp.json() - assert "group" in data - assert isinstance(data["group"], dict) - assert data["group"]["name"] == "Org One Group Two" - assert data["group"]["id"] == 3 + assert "group" in data + assert isinstance(data["group"], dict) + assert data["group"]["name"] == "Org One Group Two" + assert data["group"]["id"] == 3 - assert "permissions" in data - assert isinstance(data["permissions"], list) + assert "permissions" in data + assert isinstance(data["permissions"], list) - permission = data["permissions"][0] - assert permission["id"] == 1 - assert permission["service_name"] == "Test Service" - assert permission["resource"] == "test_resource" - assert permission["action"] == "read" + permission = data["permissions"][0] + assert permission["id"] == 1 + assert permission["service_name"] == "Test Service" + assert permission["resource"] == "test_resource" + assert permission["action"] == "read" @pytest.mark.parametrize( - "body, expected_status", - generate_body_and_status( - {"organisation_id": "int", "group_id": "int", "permission_id": "int"} - ), + "body, expected_status", + generate_body_and_status( + {"organisation_id": "int", "group_id": "int", "permission_id": "int"} + ), ) @pytest.mark.anyio async def test_put_group_perm_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.put("/iam/group/permission", json=body) + resp = await default_client.put("/iam/group/permission", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_put_group_perm_conflict(default_client: AsyncClient): - resp = await default_client.put( - "/iam/group/permission", - json={"organisation_id": 1, "group_id": 1, "permission_id": 1}, - ) + resp = await default_client.put( + "/iam/group/permission", + json={"organisation_id": 1, "group_id": 1, "permission_id": 1}, + ) - assert resp.status_code == 409 + assert resp.status_code == 409 @pytest.mark.parametrize( - "body", - [ - {"organisation_id": 1, "group_id": 2, "permission_id": 1}, - {"organisation_id": 2, "group_id": 1, "permission_id": 1}, - ], + "body", + [ + {"organisation_id": 1, "group_id": 2, "permission_id": 1}, + {"organisation_id": 2, "group_id": 1, "permission_id": 1}, + ], ) @pytest.mark.anyio async def test_put_group_perm_mismatch(default_client: AsyncClient, body: dict): - resp = await default_client.put("/iam/group/permission", json=body) + resp = await default_client.put("/iam/group/permission", json=body) - assert resp.status_code == 403 - assert resp.json()["detail"] == "Group does not belong to this organization" + assert resp.status_code == 403 + assert resp.json()["detail"] == "Group does not belong to this organization" @pytest.mark.anyio async def test_put_group_user_success(default_client: AsyncClient): - resp = await default_client.put( - "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1} - ) - assert resp.status_code == 200 + resp = await default_client.put( + "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1} + ) + assert resp.status_code == 200 - data = resp.json() + data = resp.json() - assert "group" in data - assert isinstance(data["group"], dict) - assert data["group"]["name"] == "Org One Group" - assert data["group"]["id"] == 1 + assert "group" in data + assert isinstance(data["group"], dict) + assert data["group"]["name"] == "Org One Group" + assert data["group"]["id"] == 1 - assert "users" in data - assert isinstance(data["users"], list) + assert "users" in data + assert isinstance(data["users"], list) - user = data["users"][1] - assert user["id"] == 2 - assert user["first_name"] == "User" - assert user["last_name"] == "Test" - assert user["email"] == "user@orgone.com" + user = data["users"][1] + assert user["id"] == 2 + assert user["first_name"] == "User" + assert user["last_name"] == "Test" + assert user["email"] == "user@orgone.com" @pytest.mark.parametrize( - "body, expected_status", - generate_body_and_status( - {"organisation_id": "int", "group_id": "int", "user_id": "int"} - ), + "body, expected_status", + generate_body_and_status( + {"organisation_id": "int", "group_id": "int", "user_id": "int"} + ), ) @pytest.mark.anyio async def test_put_group_user_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.put("/iam/group/user", json=body) + resp = await default_client.put("/iam/group/user", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_get_permissions_success(default_client: AsyncClient): - resp = await default_client.get("/iam/permissions?org_id=1") - data = resp.json() + resp = await default_client.get("/iam/permissions?org_id=1") + data = resp.json() - assert resp.status_code == 200 - assert "permissions" in data - assert isinstance(data["permissions"], list) + assert resp.status_code == 200 + assert "permissions" in data + assert isinstance(data["permissions"], list) - permission = data["permissions"][0] - assert permission["id"] == 1 - assert permission["service_name"] == "Test Service" - assert permission["resource"] == "test_resource" - assert permission["action"] == "read" + permission = data["permissions"][0] + assert permission["id"] == 1 + assert permission["service_name"] == "Test Service" + assert permission["resource"] == "test_resource" + assert permission["action"] == "read" @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.anyio async def test_get_permissions_status_checks( - default_client: AsyncClient, query: str, expected_status: int + default_client: AsyncClient, query: str, expected_status: int ): - resp = await default_client.get(f"/iam/permissions?{query}") + resp = await default_client.get(f"/iam/permissions?{query}") - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_post_perm_success(default_client: AsyncClient): - resp = await default_client.post( - "/iam/permission", - json={"service_id": 1, "resource": "test_resource", "action": "create"}, - ) - assert resp.status_code == 201 + resp = await default_client.post( + "/iam/permission", + json={"service_id": 1, "resource": "test_resource", "action": "create"}, + ) + assert resp.status_code == 201 - data = resp.json() + data = resp.json() - assert "permission" in data - assert isinstance(data["permission"], dict) + assert "permission" in data + assert isinstance(data["permission"], dict) - assert data["permission"]["id"] == 4 - assert data["permission"]["service_name"] == "Test Service" - assert data["permission"]["resource"] == "test_resource" - assert data["permission"]["action"] == "create" + assert data["permission"]["id"] == 4 + assert data["permission"]["service_name"] == "Test Service" + assert data["permission"]["resource"] == "test_resource" + assert data["permission"]["action"] == "create" @pytest.mark.parametrize( - "body, expected_status", - [ - ( - {"service_id": 1, "resource": "test_resource", "action": "read"}, - 409, - ), - # service_id tests - ( - {"service_id": 42, "resource": "test_resource", "action": "read"}, - 404, - ), # Non-existent service - ( - {"service_id": "banana", "resource": "test_resource", "action": "read"}, - 422, - ), # Invalid service ID - ( - {"service_id": "", "resource": "test_resource", "action": "read"}, - 422, - ), # Blank service ID - ( - {"service_id": -1, "resource": "test_resource", "action": "read"}, - 422, - ), # Negative service ID - # resource tests - ( - {"service_id": 1, "resource": 42, "action": "read"}, - 422, - ), # Invalid resource type - # action tests - ( - {"service_id": 1, "resource": "test_resource", "action": 42}, - 422, - ), # Invalid action type - # missing/partial body tests - ({}, 422), # Blank body - ({"resource": "test_resource"}, 422), # Only resource - ({"action": "read"}, 422), # Only action - ({"service_id": 1}, 422), # Only service - ({"service_id": 1, "action": "read"}, 422), # Missing resource - ({"service_id": 1, "resource": "test_resource"}, 422), # Missing action - ({"resource": "test_resource", "action": "read"}, 422), # Missing service - ], + "body, expected_status", + [ + ( + {"service_id": 1, "resource": "test_resource", "action": "read"}, + 409, + ), + # service_id tests + ( + {"service_id": 42, "resource": "test_resource", "action": "read"}, + 404, + ), # Non-existent service + ( + {"service_id": "banana", "resource": "test_resource", "action": "read"}, + 422, + ), # Invalid service ID + ( + {"service_id": "", "resource": "test_resource", "action": "read"}, + 422, + ), # Blank service ID + ( + {"service_id": -1, "resource": "test_resource", "action": "read"}, + 422, + ), # Negative service ID + # resource tests + ( + {"service_id": 1, "resource": 42, "action": "read"}, + 422, + ), # Invalid resource type + # action tests + ( + {"service_id": 1, "resource": "test_resource", "action": 42}, + 422, + ), # Invalid action type + # missing/partial body tests + ({}, 422), # Blank body + ({"resource": "test_resource"}, 422), # Only resource + ({"action": "read"}, 422), # Only action + ({"service_id": 1}, 422), # Only service + ({"service_id": 1, "action": "read"}, 422), # Missing resource + ({"service_id": 1, "resource": "test_resource"}, 422), # Missing action + ({"resource": "test_resource", "action": "read"}, 422), # Missing service + ], ) @pytest.mark.anyio async def test_post_perm_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.post("/iam/permission", json=body) + resp = await default_client.post("/iam/permission", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.parametrize( - "body", - [ - { - "organisation_id": 1, - "service_id": 1, - "resource": "test_resource", - "action": "read", - }, - {"organisation_id": 1, "service_id": 1}, - {"organisation_id": 1, "resource": "test_resource"}, - {"organisation_id": 1, "action": "read"}, - {"organisation_id": 1, "service_id": 1, "action": "read"}, - {"organisation_id": 1, "service_id": 1, "resource": "test_resource"}, - {"organisation_id": 1, "resource": "test_resource", "action": "read"}, - ], + "body", + [ + { + "organisation_id": 1, + "service_id": 1, + "resource": "test_resource", + "action": "read", + }, + {"organisation_id": 1, "service_id": 1}, + {"organisation_id": 1, "resource": "test_resource"}, + {"organisation_id": 1, "action": "read"}, + {"organisation_id": 1, "service_id": 1, "action": "read"}, + {"organisation_id": 1, "service_id": 1, "resource": "test_resource"}, + {"organisation_id": 1, "resource": "test_resource", "action": "read"}, + ], ) @pytest.mark.anyio async def test_post_perm_search_success(default_client: AsyncClient, body): - resp = await default_client.post("/iam/permissions/search", json=body) - data = resp.json() + resp = await default_client.post("/iam/permissions/search", json=body) + data = resp.json() - assert resp.status_code == 200 - assert "permissions" in data - assert isinstance(data["permissions"], list) + assert resp.status_code == 200 + assert "permissions" in data + assert isinstance(data["permissions"], list) - permissions_filtered = [ - permission for permission in data["permissions"] if permission["id"] == 1 - ] - assert len(permissions_filtered) == 1 - permission = permissions_filtered[0] - assert permission["id"] == 1 - assert permission["service_name"] == "Test Service" - assert permission["resource"] == "test_resource" - assert permission["action"] == "read" + permissions_filtered = [ + permission for permission in data["permissions"] if permission["id"] == 1 + ] + assert len(permissions_filtered) == 1 + permission = permissions_filtered[0] + assert permission["id"] == 1 + assert permission["service_name"] == "Test Service" + assert permission["resource"] == "test_resource" + assert permission["action"] == "read" @pytest.mark.parametrize( - "body, expected_status", - [ - # organisation_id tests - ( - { - "organisation_id": 42, - "service_id": 1, - "resource": "test_resource", - "action": "read", - }, - 404, - ), # Non-existent organisation - ( - { - "organisation_id": "banana", - "service_id": 1, - "resource": "test_resource", - "action": "read", - }, - 422, - ), # Invalid organisation ID - ( - { - "organisation_id": "", - "service_id": 1, - "resource": "test_resource", - "action": "read", - }, - 422, - ), # Blank organisation ID - ( - { - "organisation_id": -1, - "service_id": 1, - "resource": "test_resource", - "action": "read", - }, - 422, - ), # Negative organisation ID - # service_id tests - ( - { - "organisation_id": 1, - "service_id": "banana", - "resource": "test_resource", - "action": "read", - }, - 422, - ), # Invalid service ID - ( - { - "organisation_id": 1, - "service_id": "", - "resource": "test_resource", - "action": "read", - }, - 422, - ), # Blank service ID - ( - { - "organisation_id": 1, - "service_id": -1, - "resource": "test_resource", - "action": "read", - }, - 422, - ), # Negative service ID - # resource tests - ( - {"organisation_id": 1, "service_id": 1, "resource": 42, "action": "read"}, - 422, - ), # Invalid resource type - # action tests - ( - { - "organisation_id": 1, - "service_id": 1, - "resource": "test_resource", - "action": 42, - }, - 422, - ), # Invalid action type - # missing/partial body tests - ({}, 422), # Blank body - ], + "body, expected_status", + [ + # organisation_id tests + ( + { + "organisation_id": 42, + "service_id": 1, + "resource": "test_resource", + "action": "read", + }, + 404, + ), # Non-existent organisation + ( + { + "organisation_id": "banana", + "service_id": 1, + "resource": "test_resource", + "action": "read", + }, + 422, + ), # Invalid organisation ID + ( + { + "organisation_id": "", + "service_id": 1, + "resource": "test_resource", + "action": "read", + }, + 422, + ), # Blank organisation ID + ( + { + "organisation_id": -1, + "service_id": 1, + "resource": "test_resource", + "action": "read", + }, + 422, + ), # Negative organisation ID + # service_id tests + ( + { + "organisation_id": 1, + "service_id": "banana", + "resource": "test_resource", + "action": "read", + }, + 422, + ), # Invalid service ID + ( + { + "organisation_id": 1, + "service_id": "", + "resource": "test_resource", + "action": "read", + }, + 422, + ), # Blank service ID + ( + { + "organisation_id": 1, + "service_id": -1, + "resource": "test_resource", + "action": "read", + }, + 422, + ), # Negative service ID + # resource tests + ( + {"organisation_id": 1, "service_id": 1, "resource": 42, "action": "read"}, + 422, + ), # Invalid resource type + # action tests + ( + { + "organisation_id": 1, + "service_id": 1, + "resource": "test_resource", + "action": 42, + }, + 422, + ), # Invalid action type + # missing/partial body tests + ({}, 422), # Blank body + ], ) @pytest.mark.anyio async def test_post_perm_search_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.post("/iam/permissions/search", json=body) + resp = await default_client.post("/iam/permissions/search", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_delete_group_permissions_success(default_client: AsyncClient): - resp = await default_client.delete( - "/iam/group/permission?org_id=1&group_id=1&perm_id=1" - ) - data = resp.json() + resp = await default_client.delete( + "/iam/group/permission?org_id=1&group_id=1&perm_id=1" + ) + data = resp.json() - assert resp.status_code == 200 - assert "permissions" in data - assert isinstance(data["permissions"], list) - assert len(data["permissions"]) == 0 - assert "group" in data - assert data["group"]["id"] == 1 - assert data["group"]["name"] == "Org One Group" + assert resp.status_code == 200 + assert "permissions" in data + assert isinstance(data["permissions"], list) + assert len(data["permissions"]) == 0 + assert "group" in data + assert data["group"]["id"] == 1 + assert data["group"]["name"] == "Org One Group" @pytest.mark.parametrize( - "query, expected_status", - generate_query_and_status(["group_id", "org_id", "perm_id"]), + "query, expected_status", + generate_query_and_status(["group_id", "org_id", "perm_id"]), ) @pytest.mark.anyio async def test_delete_group_permissions_status_checks( - default_client: AsyncClient, query: str, expected_status: int + default_client: AsyncClient, query: str, expected_status: int ): - resp = await default_client.delete(f"/iam/group/permission?{query}") - assert resp.status_code == expected_status + resp = await default_client.delete(f"/iam/group/permission?{query}") + assert resp.status_code == expected_status @pytest.mark.anyio async def test_delete_group_permissions_not_in_group(default_client: AsyncClient): - resp = await default_client.delete( - "/iam/group/permission?org_id=1&group_id=1&perm_id=2" - ) - assert resp.status_code == 422 + resp = await default_client.delete( + "/iam/group/permission?org_id=1&group_id=1&perm_id=2" + ) + assert resp.status_code == 422 @pytest.mark.anyio async def test_delete_permissions_success(default_client: AsyncClient): - resp = await default_client.delete("/iam/permission?perm_id=1") + resp = await default_client.delete("/iam/permission?perm_id=1") - assert resp.status_code == 204 + assert resp.status_code == 204 @pytest.mark.anyio async def test_delete_group_users_success(default_client: AsyncClient): - resp = await default_client.delete("/iam/group/user?org_id=1&group_id=1&user_id=1") - data = resp.json() + resp = await default_client.delete("/iam/group/user?org_id=1&group_id=1&user_id=1") + data = resp.json() - assert resp.status_code == 200 - assert "users" in data - assert isinstance(data["users"], list) - assert len(data["users"]) == 0 - assert "group" in data - assert data["group"]["id"] == 1 - assert data["group"]["name"] == "Org One Group" + assert resp.status_code == 200 + assert "users" in data + assert isinstance(data["users"], list) + assert len(data["users"]) == 0 + assert "group" in data + assert data["group"]["id"] == 1 + assert data["group"]["name"] == "Org One Group" @pytest.mark.anyio async def test_put_group_user_invitation_success(default_client: AsyncClient): - body = {"user_email": "admin@test.com", "organisation_id": 1, "group_id": 1} - resp = await default_client.put("/iam/group/user/invitation", json=body) + body = {"user_email": "admin@test.com", "organisation_id": 1, "group_id": 1} + resp = await default_client.put("/iam/group/user/invitation", json=body) - assert resp.status_code == 200 - data = resp.json() - assert "organisation" in data - assert isinstance(data["organisation"], dict) - assert data["organisation"]["id"] == 1 - assert data["organisation"]["name"] == "Org One" + assert resp.status_code == 200 + data = resp.json() + assert "organisation" in data + assert isinstance(data["organisation"], dict) + assert data["organisation"]["id"] == 1 + assert data["organisation"]["name"] == "Org One" - assert "invited_email" in data - assert isinstance(data["invited_email"], str) - assert data["invited_email"] == "admin@test.com" + assert "invited_email" in data + assert isinstance(data["invited_email"], str) + assert data["invited_email"] == "admin@test.com" - assert "group" in data - assert isinstance(data["group"], dict) - assert data["group"]["name"] == "Org One Group" - assert data["group"]["id"] == 1 + assert "group" in data + assert isinstance(data["group"], dict) + assert data["group"]["name"] == "Org One Group" + assert data["group"]["id"] == 1 @pytest.mark.parametrize( - "body, expected_status", - [ - ({"organisation_id": 42, "user_email": "admin@test.com", "group_id": 1}, 404), - ( - { - "organisation_id": "Test Org", - "user_email": "admin@test.com", - "group_id": 1, - }, - 422, - ), - ({"organisation_id": "", "user_email": "admin@test.com", "group_id": 1}, 422), - ({}, 422), - ({"user_email": 42, "group_id": 1}, 422), - ({"organisation_id": 1, "user_email": "Test User", "group_id": 1}, 422), - ({"organisation_id": 1, "user_email": "admin@test.com", "group_id": 42}, 404), - ({"organisation_id": "Test Org", "user_email": "admin@test.com"}, 422), - ({"organisation_id": "", "user_email": "admin@test.com"}, 422), - ({"user_email": 42}, 422), - ], + "body, expected_status", + [ + ({"organisation_id": 42, "user_email": "admin@test.com", "group_id": 1}, 404), + ( + { + "organisation_id": "Test Org", + "user_email": "admin@test.com", + "group_id": 1, + }, + 422, + ), + ({"organisation_id": "", "user_email": "admin@test.com", "group_id": 1}, 422), + ({}, 422), + ({"user_email": 42, "group_id": 1}, 422), + ({"organisation_id": 1, "user_email": "Test User", "group_id": 1}, 422), + ({"organisation_id": 1, "user_email": "admin@test.com", "group_id": 42}, 404), + ({"organisation_id": "Test Org", "user_email": "admin@test.com"}, 422), + ({"organisation_id": "", "user_email": "admin@test.com"}, 422), + ({"user_email": 42}, 422), + ], ) @pytest.mark.anyio async def test_put_group_user_invitation_status_checks( - default_client: AsyncClient, body, expected_status + default_client: AsyncClient, body, expected_status ): - resp = await default_client.put("/iam/group/user/invitation", json=body) + resp = await default_client.put("/iam/group/user/invitation", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.parametrize( - "body, expected_status", - [ - ({"jwt": "invalid"}, 401), - ({"jwt": ""}, 401), - ({"jwt": None}, 422), - ({"jwt": 42}, 422), - ], + "body, expected_status", + [ + ({"jwt": "invalid"}, 401), + ({"jwt": ""}, 401), + ({"jwt": None}, 422), + ({"jwt": 42}, 422), + ], ) @pytest.mark.anyio async def test_put_group_user_invitation_accept_status_checks( - default_client: AsyncClient, body, expected_status + default_client: AsyncClient, body, expected_status ): - resp = await default_client.put("/iam/group/user/invitation/accept", json=body) + resp = await default_client.put("/iam/group/user/invitation/accept", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status - if resp.status_code == 401: - assert resp.json()["detail"] == "Invalid JWS" + if resp.status_code == 401: + assert resp.json()["detail"] == "Invalid JWS" diff --git a/test/test_organisation.py b/test/test_organisation.py index 8c9cff6..49dd0a3 100644 --- a/test/test_organisation.py +++ b/test/test_organisation.py @@ -9,506 +9,506 @@ from .conftest import generate_query_and_status pytestmark = [ - pytest.mark.org_module, + pytest.mark.org_module, ] @pytest.mark.anyio async def test_get_org_success(default_client: AsyncClient): - resp = await default_client.get("/org?org_id=1") - data = resp.json() + resp = await default_client.get("/org?org_id=1") + data = resp.json() - assert resp.status_code == 200 + assert resp.status_code == 200 - org = data["organisations"][0] + org = data["organisations"][0] - assert isinstance(org, dict) - assert org["organisation_id"] == 1 - assert org["name"] == "Org One" - assert org["status"] == "approved" - assert org["root_user_email"] == "admin@test.com" - assert "intake_questionnaire" in org - assert isinstance(org["intake_questionnaire"], dict) + assert isinstance(org, dict) + assert org["organisation_id"] == 1 + assert org["name"] == "Org One" + assert org["status"] == "approved" + assert org["root_user_email"] == "admin@test.com" + assert "intake_questionnaire" in org + assert isinstance(org["intake_questionnaire"], dict) - assert org["billing_contact"]["email"] == "billing@orgone.com" - assert org["owner_contact"]["email"] == "owner@orgone.com" - assert org["security_contact"]["email"] == "security@orgone.com" + assert org["billing_contact"]["email"] == "billing@orgone.com" + assert org["owner_contact"]["email"] == "owner@orgone.com" + assert org["security_contact"]["email"] == "security@orgone.com" @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.anyio async def test_get_org_status_checks( - default_client: AsyncClient, query: str, expected_status: int + default_client: AsyncClient, query: str, expected_status: int ): - resp = await default_client.get(f"/org?{query}") + resp = await default_client.get(f"/org?{query}") - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_post_org_success(default_client: AsyncClient): - resp = await default_client.post("/org", json={"name": "New Test Org"}) - data = resp.json() + resp = await default_client.post("/org", json={"name": "New Test Org"}) + data = resp.json() - assert resp.status_code == 201 - assert data["name"] == "New Test Org" - assert data["status"] == "partial" + assert resp.status_code == 201 + assert data["name"] == "New Test Org" + assert data["status"] == "partial" @pytest.mark.parametrize( - "body, expected_status", - [ - ({"name": 42}, 422), - ({}, 422), - ({"name": "New Test Org", "intake_questionnaire": {"question_one": 42}}, 422), - ], + "body, expected_status", + [ + ({"name": 42}, 422), + ({}, 422), + ({"name": "New Test Org", "intake_questionnaire": {"question_one": 42}}, 422), + ], ) @pytest.mark.anyio async def test_post_org_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.post("/org", json=body) + resp = await default_client.post("/org", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_patch_org_questionnaire_partial_success(default_client: AsyncClient): - resp = await default_client.patch( - "/org/questionnaire", - json={ - "organisation_id": 3, - "intake_questionnaire": { - "question_one": "new answer one", - "question_two": None, - "question_three": None, - }, - "partial": True, - }, - ) - data = resp.json() + resp = await default_client.patch( + "/org/questionnaire", + json={ + "organisation_id": 3, + "intake_questionnaire": { + "question_one": "new answer one", + "question_two": None, + "question_three": None, + }, + "partial": True, + }, + ) + data = resp.json() - assert resp.status_code == 200 - assert data["name"] == "Org Three" - assert data["status"] == "partial" - assert "intake_questionnaire" in data - assert isinstance(data["intake_questionnaire"], dict) - metadata = data["intake_questionnaire"]["metadata"] - assert metadata["version"] == 0 - assert metadata["submission_date"] is None - questions = data["intake_questionnaire"]["questions"] - assert questions["question_one"] == "new answer one" - assert questions["question_two"] == "answer two" - assert questions["question_three"] is None + assert resp.status_code == 200 + assert data["name"] == "Org Three" + assert data["status"] == "partial" + assert "intake_questionnaire" in data + assert isinstance(data["intake_questionnaire"], dict) + metadata = data["intake_questionnaire"]["metadata"] + assert metadata["version"] == 0 + assert metadata["submission_date"] is None + questions = data["intake_questionnaire"]["questions"] + assert questions["question_one"] == "new answer one" + assert questions["question_two"] == "answer two" + assert questions["question_three"] is None @pytest.mark.parametrize( - "body, expected_status", - [ - ({"organisation_id": 42}, 404), - ({"organisation_id": "Org One"}, 422), - ({"organisation_id": ""}, 422), - ({}, 422), - ( - { - "organisation_id": "1", - "intake_questionnaire": {"question_one": 42}, - "partial": True, - }, - 422, - ), - ( - {"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}}, - 422, - ), - ( - { - "organisation_id": "1", - "intake_questionnaire": {"question_one": "valid"}, - "partial": 42, - }, - 422, - ), - ], + "body, expected_status", + [ + ({"organisation_id": 42}, 404), + ({"organisation_id": "Org One"}, 422), + ({"organisation_id": ""}, 422), + ({}, 422), + ( + { + "organisation_id": "1", + "intake_questionnaire": {"question_one": 42}, + "partial": True, + }, + 422, + ), + ( + {"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}}, + 422, + ), + ( + { + "organisation_id": "1", + "intake_questionnaire": {"question_one": "valid"}, + "partial": 42, + }, + 422, + ), + ], ) @pytest.mark.anyio async def test_patch_questionnaire_partial_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.patch("/org/questionnaire", json=body) + resp = await default_client.patch("/org/questionnaire", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_patch_org_questionnaire_submit_success(default_client: AsyncClient): - resp = await default_client.patch( - "/org/questionnaire", - json={ - "organisation_id": 3, - "intake_questionnaire": { - "question_one": "new answer one", - "question_two": None, - "question_three": None, - }, - "partial": False, - }, - ) - data = resp.json() + resp = await default_client.patch( + "/org/questionnaire", + json={ + "organisation_id": 3, + "intake_questionnaire": { + "question_one": "new answer one", + "question_two": None, + "question_three": None, + }, + "partial": False, + }, + ) + data = resp.json() - assert resp.status_code == 200 - assert data["name"] == "Org Three" - assert data["status"] == "submitted" - assert "intake_questionnaire" in data - assert isinstance(data["intake_questionnaire"], dict) - metadata = data["intake_questionnaire"]["metadata"] - assert metadata["version"] == 0 - assert metadata["submission_date"] is not None - questions = data["intake_questionnaire"]["questions"] - assert questions["question_one"] == "new answer one" - assert questions["question_two"] == "answer two" - assert questions["question_three"] is None + assert resp.status_code == 200 + assert data["name"] == "Org Three" + assert data["status"] == "submitted" + assert "intake_questionnaire" in data + assert isinstance(data["intake_questionnaire"], dict) + metadata = data["intake_questionnaire"]["metadata"] + assert metadata["version"] == 0 + assert metadata["submission_date"] is not None + questions = data["intake_questionnaire"]["questions"] + assert questions["question_one"] == "new answer one" + assert questions["question_two"] == "answer two" + assert questions["question_three"] is None @pytest.mark.parametrize( - "status", ["partial", "submitted", "remediation", "approved", "rejected", "removed"] + "status", ["partial", "submitted", "remediation", "approved", "rejected", "removed"] ) @pytest.mark.anyio async def test_patch_org_status_success(default_client: AsyncClient, status: str): - resp = await default_client.patch( - "/org/status", json={"organisation_id": 1, "status": status} - ) - data = resp.json() + resp = await default_client.patch( + "/org/status", json={"organisation_id": 1, "status": status} + ) + data = resp.json() - assert resp.status_code == 200 - assert data["name"] == "Org One" - assert data["status"] == status + assert resp.status_code == 200 + assert data["name"] == "Org One" + assert data["status"] == status @pytest.mark.parametrize( - "body, expected_status", - [ - ({"organisation_id": 42}, 404), - ({"organisation_id": "Org One"}, 422), - ({"organisation_id": ""}, 422), - ({}, 422), - ({"organisation_id": "1", "status": True}, 422), - ({"organisation_id": "1", "status": 42}, 422), - ], + "body, expected_status", + [ + ({"organisation_id": 42}, 404), + ({"organisation_id": "Org One"}, 422), + ({"organisation_id": ""}, 422), + ({}, 422), + ({"organisation_id": "1", "status": True}, 422), + ({"organisation_id": "1", "status": 42}, 422), + ], ) @pytest.mark.anyio async def test_patch_org_status_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.patch("/org/status", json=body) + resp = await default_client.patch("/org/status", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_get_org_users_success(default_client: AsyncClient): - resp = await default_client.get("/org/users?org_id=1") - data = resp.json() + resp = await default_client.get("/org/users?org_id=1") + data = resp.json() - assert resp.status_code == 200 + assert resp.status_code == 200 - assert "users" in data - assert isinstance(data["users"], list) - assert len(data["users"]) == 2 + assert "users" in data + assert isinstance(data["users"], list) + assert len(data["users"]) == 2 - user = data["users"][0] - assert isinstance(user, dict) - assert user["email"] == "admin@test.com" - assert user["id"] == 1 + user = data["users"][0] + assert isinstance(user, dict) + assert user["email"] == "admin@test.com" + assert user["id"] == 1 - assert "organisation" in data - assert data["organisation"]["name"] == "Org One" - assert data["organisation"]["id"] == 1 + assert "organisation" in data + assert data["organisation"]["name"] == "Org One" + assert data["organisation"]["id"] == 1 @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.anyio async def test_get_org_users_status_checks( - default_client: AsyncClient, query: str, expected_status: int + default_client: AsyncClient, query: str, expected_status: int ): - resp = await default_client.get(f"/org/users?{query}") + resp = await default_client.get(f"/org/users?{query}") - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_post_org_user_success(default_client: AsyncClient): - resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 3}) + resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 3}) - assert resp.status_code == 200 + assert resp.status_code == 200 - data = resp.json() + data = resp.json() - assert "organisation" in data - assert isinstance(data["organisation"], dict) - assert data["organisation"]["id"] == 1 - assert data["organisation"]["name"] == "Org One" + assert "organisation" in data + assert isinstance(data["organisation"], dict) + assert data["organisation"]["id"] == 1 + assert data["organisation"]["name"] == "Org One" - assert "users" in data - assert isinstance(data["users"], list) - assert len([user for user in data["users"] if user["email"] == "root@orgtwo.com"]) == 1 + assert "users" in data + assert isinstance(data["users"], list) + assert len([user for user in data["users"] if user["email"] == "root@orgtwo.com"]) == 1 @pytest.mark.parametrize( - "body, expected_status", - [ - ({"organisation_id": 42}, 404), - ({}, 422), - ({"organisation_id": 1, "user_id": "id"}, 422), - ({"user_id": 2}, 422), - ({"organisation_id": 1, "user_id": 42}, 404), - ({"organisation_id": 1, "user_id": 1}, 409), - ], + "body, expected_status", + [ + ({"organisation_id": 42}, 404), + ({}, 422), + ({"organisation_id": 1, "user_id": "id"}, 422), + ({"user_id": 2}, 422), + ({"organisation_id": 1, "user_id": 42}, 404), + ({"organisation_id": 1, "user_id": 1}, 409), + ], ) @pytest.mark.anyio async def test_post_org_user_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.post("/org/user", json=body) + resp = await default_client.post("/org/user", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_patch_org_root_user_success(default_client: AsyncClient): - resp = await default_client.patch( - "/org/root_user", json={"organisation_id": 1, "user_id": 2} - ) - assert resp.status_code == 200 + resp = await default_client.patch( + "/org/root_user", json={"organisation_id": 1, "user_id": 2} + ) + assert resp.status_code == 200 - data = resp.json() + data = resp.json() - assert data["name"] == "Org One" - assert data["root_user_email"] == "user@orgone.com" + assert data["name"] == "Org One" + assert data["root_user_email"] == "user@orgone.com" @pytest.mark.parametrize( - "body, expected_status", - [ - ({"organisation_id": 42, "user_id": 2}, 404), - ({"organisation_id": "Org One", "user_id": 2}, 422), - ({"organisation_id": "", "user_id": 2}, 422), - ({}, 422), - ({"user_id": 2}, 422), - ({"user_id": 42}, 404), - ({"organisation_id": 1, "user_id": "Test User"}, 422), - ], + "body, expected_status", + [ + ({"organisation_id": 42, "user_id": 2}, 404), + ({"organisation_id": "Org One", "user_id": 2}, 422), + ({"organisation_id": "", "user_id": 2}, 422), + ({}, 422), + ({"user_id": 2}, 422), + ({"user_id": 42}, 404), + ({"organisation_id": 1, "user_id": "Test User"}, 422), + ], ) @pytest.mark.anyio async def test_patch_root_user_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.patch("/org/root_user", json=body) + resp = await default_client.patch("/org/root_user", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_patch_org_root_user_non_member(default_client: AsyncClient): - resp = await default_client.patch( - "/org/root_user", json={"organisation_id": 1, "user_id": 3} - ) - data = resp.json() + resp = await default_client.patch( + "/org/root_user", json={"organisation_id": 1, "user_id": 3} + ) + data = resp.json() - assert resp.status_code == 422 - assert data["detail"] == "This user does not belong to your organisation." + assert resp.status_code == 422 + assert data["detail"] == "This user does not belong to your organisation." @pytest.mark.anyio async def test_get_org_groups_success(default_client: AsyncClient): - resp = await default_client.get("/org/groups?org_id=1") - assert resp.status_code == 200 + resp = await default_client.get("/org/groups?org_id=1") + assert resp.status_code == 200 - data = resp.json() + data = resp.json() - assert "organisation" in data - assert isinstance(data["organisation"], dict) - assert data["organisation"]["id"] == 1 - assert data["organisation"]["name"] == "Org One" + assert "organisation" in data + assert isinstance(data["organisation"], dict) + assert data["organisation"]["id"] == 1 + assert data["organisation"]["name"] == "Org One" - assert "groups" in data - assert isinstance(data["groups"], list) - group = data["groups"][0] - assert isinstance(group, dict) - assert group["id"] == 1 - assert group["name"] == "Org One Group" + assert "groups" in data + assert isinstance(data["groups"], list) + group = data["groups"][0] + assert isinstance(group, dict) + assert group["id"] == 1 + assert group["name"] == "Org One Group" @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.anyio async def test_get_org_groups_status_checks( - default_client: AsyncClient, query: str, expected_status: int + default_client: AsyncClient, query: str, expected_status: int ): - resp = await default_client.get(f"/org/groups?{query}") + resp = await default_client.get(f"/org/groups?{query}") - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.parametrize("contact_type", ["billing", "security", "owner"]) @pytest.mark.anyio async def test_get_org_contact_success(default_client: AsyncClient, contact_type: str): - resp = await default_client.get(f"/org/contact?org_id=1&contact_type={contact_type}") - data = resp.json() + resp = await default_client.get(f"/org/contact?org_id=1&contact_type={contact_type}") + data = resp.json() - assert resp.status_code == 200 + assert resp.status_code == 200 - assert "organisation" in data - assert data["organisation"]["id"] == 1 - assert data["organisation"]["name"] == "Org One" + assert "organisation" in data + assert data["organisation"]["id"] == 1 + assert data["organisation"]["name"] == "Org One" - attributes = [ - "email", - "first_name", - "last_name", - "phonenumber", - "vat_number", - "address", - ] + attributes = [ + "email", + "first_name", + "last_name", + "phonenumber", + "vat_number", + "address", + ] - for attribute in attributes: - assert attribute in data["contact"] + for attribute in attributes: + assert attribute in data["contact"] - address_attributes = [ - "post_office_box_number", - "street_address", - "street_address_line_2", - "locality", - "address_region", - "country_code", - "postal_code", - ] + address_attributes = [ + "post_office_box_number", + "street_address", + "street_address_line_2", + "locality", + "address_region", + "country_code", + "postal_code", + ] - for attribute in address_attributes: - assert attribute in data["contact"]["address"] + for attribute in address_attributes: + assert attribute in data["contact"]["address"] @pytest.mark.parametrize( - "query, expected_status", - [ - ("org_id=42&contact_type=billing", 404), - ("org_id=banana&contact_type=billing", 422), - ("", 422), - ("org_id=1&contact_type=contact", 422), - ("contact_type=billing", 422), - ], + "query, expected_status", + [ + ("org_id=42&contact_type=billing", 404), + ("org_id=banana&contact_type=billing", 422), + ("", 422), + ("org_id=1&contact_type=contact", 422), + ("contact_type=billing", 422), + ], ) @pytest.mark.anyio async def test_get_org_contact_status_checks( - default_client: AsyncClient, query: str, expected_status: int + default_client: AsyncClient, query: str, expected_status: int ): - resp = await default_client.get(f"/org/contact?{query}") + resp = await default_client.get(f"/org/contact?{query}") - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.parametrize( - "key, value", - [ - ("email", "user@example.com"), - ("first_name", "John"), - ("last_name", "Doe"), - ("phonenumber", "+441234567890"), - ("vat_number", "GB123456789"), - ("post_office_box_number", "PO Box 123"), - ("street_address", "123 Example Street"), - ("street_address_line_2", "Suite 4B"), - ("locality", "Glasgow"), - ("address_region", "Glasgow City"), - ("country_code", "GB"), - ("postal_code", "G1 1AA"), - ], + "key, value", + [ + ("email", "user@example.com"), + ("first_name", "John"), + ("last_name", "Doe"), + ("phonenumber", "+441234567890"), + ("vat_number", "GB123456789"), + ("post_office_box_number", "PO Box 123"), + ("street_address", "123 Example Street"), + ("street_address_line_2", "Suite 4B"), + ("locality", "Glasgow"), + ("address_region", "Glasgow City"), + ("country_code", "GB"), + ("postal_code", "G1 1AA"), + ], ) @pytest.mark.anyio async def test_patch_org_contact_success(default_client: AsyncClient, key: str, value: str): - resp = await default_client.patch( - "/org/contact", - json={"organisation_id": 1, "contact_type": "billing", key: value}, - ) - assert resp.status_code == 200 + resp = await default_client.patch( + "/org/contact", + json={"organisation_id": 1, "contact_type": "billing", key: value}, + ) + assert resp.status_code == 200 - data = resp.json() + data = resp.json() - assert "organisation" in data - assert data["organisation"]["id"] == 1 - assert data["organisation"]["name"] == "Org One" + assert "organisation" in data + assert data["organisation"]["id"] == 1 + assert data["organisation"]["name"] == "Org One" - attributes = [ - "email", - "first_name", - "last_name", - "phonenumber", - "vat_number", - "address", - ] + attributes = [ + "email", + "first_name", + "last_name", + "phonenumber", + "vat_number", + "address", + ] - for attribute in attributes: - assert attribute in data["contact"] + for attribute in attributes: + assert attribute in data["contact"] - address_attributes = [ - "post_office_box_number", - "street_address", - "street_address_line_2", - "locality", - "address_region", - "country_code", - "postal_code", - ] + address_attributes = [ + "post_office_box_number", + "street_address", + "street_address_line_2", + "locality", + "address_region", + "country_code", + "postal_code", + ] - for attribute in address_attributes: - assert attribute in data["contact"]["address"] + for attribute in address_attributes: + assert attribute in data["contact"]["address"] - if key in data["contact"]: - assert data["contact"][key] == value - elif key in data["contact"]["address"]: - assert data["contact"]["address"][key] == value - else: - pytest.fail(f"Invalid contact key: {key}") + if key in data["contact"]: + assert data["contact"][key] == value + elif key in data["contact"]["address"]: + assert data["contact"]["address"][key] == value + else: + pytest.fail(f"Invalid contact key: {key}") @pytest.mark.parametrize( - "body, expected_status", - [ - ({"organisation_id": 42, "contact_type": "billing"}, 404), - ({"organisation_id": 1, "contact_type": "security"}, 200), - ({"organisation_id": 1, "contact_type": "owner"}, 200), - ({"organisation_id": "Org One", "contact_type": "billing"}, 422), - ({"organisation_id": "", "contact_type": "billing"}, 422), - ({}, 422), - ({"organisation_id": 1, "contact_type": "not_real"}, 422), - ({"organisation_id": 1, "contact_type": 42}, 422), - ({"organisation_id": 1, "contact_type": ""}, 422), - ], + "body, expected_status", + [ + ({"organisation_id": 42, "contact_type": "billing"}, 404), + ({"organisation_id": 1, "contact_type": "security"}, 200), + ({"organisation_id": 1, "contact_type": "owner"}, 200), + ({"organisation_id": "Org One", "contact_type": "billing"}, 422), + ({"organisation_id": "", "contact_type": "billing"}, 422), + ({}, 422), + ({"organisation_id": 1, "contact_type": "not_real"}, 422), + ({"organisation_id": 1, "contact_type": 42}, 422), + ({"organisation_id": 1, "contact_type": ""}, 422), + ], ) @pytest.mark.anyio async def test_patch_org_contact_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.patch("/org/contact", json=body) + resp = await default_client.patch("/org/contact", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_delete_org_success(default_client: AsyncClient): - resp = await default_client.delete("/org?org_id=1") + resp = await default_client.delete("/org?org_id=1") - assert resp.status_code == 204 + assert resp.status_code == 204 @pytest.mark.anyio async def test_delete_org_users_success(default_client: AsyncClient): - resp = await default_client.delete("/org/user?org_id=1&user_id=2") + resp = await default_client.delete("/org/user?org_id=1&user_id=2") - assert resp.status_code == 204 + assert resp.status_code == 204 @pytest.mark.anyio async def test_delete_preapproval_org_success(default_client: AsyncClient): - resp = await default_client.delete("/org/self?org_id=3") + resp = await default_client.delete("/org/self?org_id=3") - assert resp.status_code == 204 + assert resp.status_code == 204 diff --git a/test/test_service.py b/test/test_service.py index 43e8bc4..9b9b98b 100644 --- a/test/test_service.py +++ b/test/test_service.py @@ -8,90 +8,90 @@ from httpx import AsyncClient from .conftest import generate_query_and_status, generate_body_and_status pytestmark = [ - pytest.mark.service_module, + pytest.mark.service_module, ] @pytest.mark.anyio async def test_get_services_success(default_client: AsyncClient): - resp = await default_client.get("/service?org_id=1") - data = resp.json() + resp = await default_client.get("/service?org_id=1") + data = resp.json() - assert resp.status_code == 200 - assert "services" in data - assert isinstance(data["services"], list) - assert data["services"][0]["id"] == 1 - assert data["services"][0]["name"] == "Test Service" + assert resp.status_code == 200 + assert "services" in data + assert isinstance(data["services"], list) + assert data["services"][0]["id"] == 1 + assert data["services"][0]["name"] == "Test Service" @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.anyio async def test_get_services_status_checks( - default_client: AsyncClient, query: str, expected_status: int + default_client: AsyncClient, query: str, expected_status: int ): - resp = await default_client.get(f"/service?{query}") + resp = await default_client.get(f"/service?{query}") - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_post_service_success(default_client: AsyncClient): - resp = await default_client.post("/service", json={"name": "New Test Service"}) - data = resp.json() + resp = await default_client.post("/service", json={"name": "New Test Service"}) + data = resp.json() - assert resp.status_code == 200 - assert "service" in data - assert isinstance(data["service"], dict) - assert data["service"]["name"] == "New Test Service" - assert data["service"]["id"] == 2 - assert isinstance(data["service"]["api_key"], str) + assert resp.status_code == 200 + assert "service" in data + assert isinstance(data["service"], dict) + assert data["service"]["name"] == "New Test Service" + assert data["service"]["id"] == 2 + assert isinstance(data["service"]["api_key"], str) @pytest.mark.parametrize("body, expected_status", generate_body_and_status({"name": "str"})) @pytest.mark.anyio async def test_post_service_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.post("/service", json=body) + resp = await default_client.post("/service", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_post_service_conflict(default_client: AsyncClient): - resp = await default_client.post("/service", json={"name": "Test Service"}) + resp = await default_client.post("/service", json={"name": "Test Service"}) - assert resp.status_code == 409 + assert resp.status_code == 409 @pytest.mark.anyio async def test_patch_service_success(default_client: AsyncClient): - resp = await default_client.patch("/service/key", json={"service_id": 1}) - data = resp.json() + resp = await default_client.patch("/service/key", json={"service_id": 1}) + data = resp.json() - assert resp.status_code == 200 - assert "service" in data - assert isinstance(data["service"], dict) - assert data["service"]["name"] == "Test Service" - assert data["service"]["id"] == 1 - assert isinstance(data["service"]["api_key"], str) + assert resp.status_code == 200 + assert "service" in data + assert isinstance(data["service"], dict) + assert data["service"]["name"] == "Test Service" + assert data["service"]["id"] == 1 + assert isinstance(data["service"]["api_key"], str) @pytest.mark.parametrize( - "body, expected_status", - generate_body_and_status({"service_id": "int"}), + "body, expected_status", + generate_body_and_status({"service_id": "int"}), ) @pytest.mark.anyio async def test_patch_services_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int + default_client: AsyncClient, body: dict[str, str], expected_status: int ): - resp = await default_client.patch("/service/key", json=body) + resp = await default_client.patch("/service/key", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_delete_service_success(default_client: AsyncClient): - resp = await default_client.delete("/service?service_id=1") + resp = await default_client.delete("/service?service_id=1") - assert resp.status_code == 204 + assert resp.status_code == 204 diff --git a/test/test_user.py b/test/test_user.py index a497fa9..b018841 100644 --- a/test/test_user.py +++ b/test/test_user.py @@ -11,191 +11,191 @@ from .conftest import generate_query_and_status pytestmark = [ - pytest.mark.user_module, + pytest.mark.user_module, ] @pytest.mark.anyio async def test_get_self_db_success(default_client: AsyncClient): - resp = await default_client.get("/user/self/db") - data = resp.json() + resp = await default_client.get("/user/self/db") + data = resp.json() - assert resp.status_code == 200 - assert data["first_name"] == "Admin" - assert data["last_name"] == "Test" - assert data["email"] == "admin@test.com" - assert "organisations" in data - assert isinstance(data["organisations"], list) - assert "groups" in data - assert isinstance(data["groups"], dict) + assert resp.status_code == 200 + assert data["first_name"] == "Admin" + assert data["last_name"] == "Test" + assert data["email"] == "admin@test.com" + assert "organisations" in data + assert isinstance(data["organisations"], list) + assert "groups" in data + assert isinstance(data["groups"], dict) @pytest.mark.anyio async def test_get_user_success(default_client: AsyncClient): - resp = await default_client.get("/user?user_id=1") - data = resp.json() + resp = await default_client.get("/user?user_id=1") + data = resp.json() - assert resp.status_code == 200 - assert data["first_name"] == "Admin" - assert data["last_name"] == "Test" - assert data["email"] == "admin@test.com" - assert "organisations" in data - assert isinstance(data["organisations"], list) - assert "groups" in data - assert isinstance(data["groups"], dict) + assert resp.status_code == 200 + assert data["first_name"] == "Admin" + assert data["last_name"] == "Test" + assert data["email"] == "admin@test.com" + assert "organisations" in data + assert isinstance(data["organisations"], list) + assert "groups" in data + assert isinstance(data["groups"], dict) @pytest.mark.anyio @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["user_id"])) async def test_get_user_status_checks( - default_client: AsyncClient, query: str, expected_status: int + default_client: AsyncClient, query: str, expected_status: int ): - resp = await default_client.get(f"/user?{query}") + resp = await default_client.get(f"/user?{query}") - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.anyio async def test_delete_user_success(default_client: AsyncClient): - resp = await default_client.delete("/user?user_id=1") + resp = await default_client.delete("/user?user_id=1") - assert resp.status_code == 204 + assert resp.status_code == 204 @pytest.mark.anyio async def test_post_user_invitation_success(default_client: AsyncClient): - body = {"user_email": "admin@test.com", "organisation_id": 1} - resp = await default_client.post("/user/invitation", json=body) + body = {"user_email": "admin@test.com", "organisation_id": 1} + resp = await default_client.post("/user/invitation", json=body) - assert resp.status_code == 200 - data = resp.json() - assert "organisation" in data - assert isinstance(data["organisation"], dict) - assert data["organisation"]["id"] == 1 - assert data["organisation"]["name"] == "Org One" + assert resp.status_code == 200 + data = resp.json() + assert "organisation" in data + assert isinstance(data["organisation"], dict) + assert data["organisation"]["id"] == 1 + assert data["organisation"]["name"] == "Org One" - assert "invited_email" in data - assert isinstance(data["invited_email"], str) - assert data["invited_email"] == "admin@test.com" + assert "invited_email" in data + assert isinstance(data["invited_email"], str) + assert data["invited_email"] == "admin@test.com" @pytest.mark.parametrize( - "body, expected_status", - [ - ({"organisation_id": 42, "user_email": "admin@test.com"}, 404), - ({"organisation_id": "Test Org", "user_email": "admin@test.com"}, 422), - ({"organisation_id": "", "user_email": "admin@test.com"}, 422), - ({}, 422), - ({"user_email": 42}, 422), - ({"organisation_id": 1, "user_email": "Test User"}, 422), - ], + "body, expected_status", + [ + ({"organisation_id": 42, "user_email": "admin@test.com"}, 404), + ({"organisation_id": "Test Org", "user_email": "admin@test.com"}, 422), + ({"organisation_id": "", "user_email": "admin@test.com"}, 422), + ({}, 422), + ({"user_email": 42}, 422), + ({"organisation_id": 1, "user_email": "Test User"}, 422), + ], ) @pytest.mark.anyio async def test_post_user_invitation_status_checks( - default_client: AsyncClient, body, expected_status + default_client: AsyncClient, body, expected_status ): - resp = await default_client.post("/user/invitation", json=body) + resp = await default_client.post("/user/invitation", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status @pytest.mark.parametrize( - "body, expected_status", - [ - ({"jwt": "invalid"}, 401), - ({"jwt": ""}, 401), - ({"jwt": None}, 422), - ({"jwt": 42}, 422), - ], + "body, expected_status", + [ + ({"jwt": "invalid"}, 401), + ({"jwt": ""}, 401), + ({"jwt": None}, 422), + ({"jwt": 42}, 422), + ], ) @pytest.mark.anyio async def test_post_user_invitation_accept_status_checks( - default_client: AsyncClient, body, expected_status + default_client: AsyncClient, body, expected_status ): - resp = await default_client.post("/user/invitation/accept", json=body) + resp = await default_client.post("/user/invitation/accept", json=body) - assert resp.status_code == expected_status + assert resp.status_code == expected_status - if resp.status_code == 401: - assert resp.json()["detail"] == "Invalid JWS" + if resp.status_code == 401: + assert resp.json()["detail"] == "Invalid JWS" @pytest.mark.anyio async def test_get_self_orgs_success(default_client: AsyncClient): - resp = await default_client.get("/user/self/orgs") - assert resp.status_code == 200 + resp = await default_client.get("/user/self/orgs") + assert resp.status_code == 200 - data = resp.json() + data = resp.json() - assert "organisations" in data - assert isinstance(data["organisations"], list) - assert len(data["organisations"]) > 0 + assert "organisations" in data + assert isinstance(data["organisations"], list) + assert len(data["organisations"]) > 0 - org = data["organisations"][0] - assert org["organisation_id"] == 1 - assert org["name"] == "Org One" - assert org["status"] == "approved" - assert org["root_user_email"] == "admin@test.com" - assert "intake_questionnaire" in org - assert isinstance(org["intake_questionnaire"], dict) + org = data["organisations"][0] + assert org["organisation_id"] == 1 + assert org["name"] == "Org One" + assert org["status"] == "approved" + assert org["root_user_email"] == "admin@test.com" + assert "intake_questionnaire" in org + assert isinstance(org["intake_questionnaire"], dict) - assert isinstance(org["billing_contact"], dict) - assert org["billing_contact"]["email"] == "billing@orgone.com" - assert org["billing_contact"]["id"] == 1 + assert isinstance(org["billing_contact"], dict) + assert org["billing_contact"]["email"] == "billing@orgone.com" + assert org["billing_contact"]["id"] == 1 - assert isinstance(org["owner_contact"], dict) - assert org["owner_contact"]["email"] == "owner@orgone.com" - assert org["owner_contact"]["id"] == 2 + assert isinstance(org["owner_contact"], dict) + assert org["owner_contact"]["email"] == "owner@orgone.com" + assert org["owner_contact"]["id"] == 2 - assert isinstance(org["security_contact"], dict) - assert org["security_contact"]["email"] == "security@orgone.com" - assert org["security_contact"]["id"] == 3 + assert isinstance(org["security_contact"], dict) + assert org["security_contact"]["email"] == "security@orgone.com" + assert org["security_contact"]["id"] == 3 @pytest.mark.anyio async def test_get_self_orgs_dynamic(default_client: AsyncClient): - method = "GET" - path = "/user/self/orgs" - expected_data = { - "organisations": [ - { - "organisation_id": 1, - "name": "Org One", - "status": "approved", - "root_user_email": "admin@test.com", - "owner_contact": {"email": "owner@orgone.com", "id": 2}, - "security_contact": {"email": "security@orgone.com", "id": 3}, - "billing_contact": {"email": "billing@orgone.com", "id": 1}, - "intake_questionnaire": { - "questions": { - "question_one": None, - "question_three": None, - "question_two": "answer two", - }, - "metadata": {"version": 0, "submission_date": None}, - }, - } - ] - } + method = "GET" + path = "/user/self/orgs" + expected_data = { + "organisations": [ + { + "organisation_id": 1, + "name": "Org One", + "status": "approved", + "root_user_email": "admin@test.com", + "owner_contact": {"email": "owner@orgone.com", "id": 2}, + "security_contact": {"email": "security@orgone.com", "id": 3}, + "billing_contact": {"email": "billing@orgone.com", "id": 1}, + "intake_questionnaire": { + "questions": { + "question_one": None, + "question_three": None, + "question_two": "answer two", + }, + "metadata": {"version": 0, "submission_date": None}, + }, + } + ] + } - resp = await default_client.get(path) + resp = await default_client.get(path) - route = next( - route - for route in default_client._transport.app.routes # ty:ignore[unresolved-attribute] - if isinstance(route, APIRoute) and path in route.path and method in route.methods - ) + route = next( + route + for route in default_client._transport.app.routes # ty:ignore[unresolved-attribute] + if isinstance(route, APIRoute) and path in route.path and method in route.methods + ) - assert resp.status_code == route.status_code - if route.status_code == 204: - return + assert resp.status_code == route.status_code + if route.status_code == 204: + return - expected_response_schema = route.response_model - data = resp.json() + expected_response_schema = route.response_model + data = resp.json() - response_model = expected_response_schema(**data) - assert isinstance(response_model, expected_response_schema) + response_model = expected_response_schema(**data) + assert isinstance(response_model, expected_response_schema) - expected_response_model = expected_response_schema(**expected_data) + expected_response_model = expected_response_schema(**expected_data) - assert response_model == expected_response_model + assert response_model == expected_response_model From ee47186c5a69acb9f3e469877e991bef5369d5bb Mon Sep 17 00:00:00 2001 From: irl Date: Mon, 22 Jun 2026 15:12:34 +0100 Subject: [PATCH 21/33] fix(db): generator types --- src/database.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/database.py b/src/database.py index 01c55c8..60ec470 100644 --- a/src/database.py +++ b/src/database.py @@ -41,7 +41,7 @@ def get_db_connection() -> Generator[Connection, None, None]: raise -def _get_db_connection() -> Generator[Connection, None]: +def _get_db_connection() -> Generator[Connection, None, None]: with get_db_connection() as connection: yield connection @@ -61,7 +61,7 @@ def get_db_session() -> Generator[Session, None, None]: session.close() -def _get_db_session() -> Generator[Session, None]: +def _get_db_session() -> Generator[Session, None, None]: with get_db_session() as session: yield session From 4b3ab92d2aeeff2e62a88913cfe695d82630f70c Mon Sep 17 00:00:00 2001 From: luxferre Date: Mon, 22 Jun 2026 15:15:31 +0100 Subject: [PATCH 22/33] fix: fastapi 0.137 router.route changes --- test/conftest.py | 29 +++++++++++++++-------------- test/test_user.py | 13 +++++++++---- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index 022b8dc..cfa329d 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -2,7 +2,7 @@ import pytest from typing import AsyncGenerator from itertools import combinations -from fastapi.routing import APIRoute +from fastapi.routing import APIRoute, iter_route_contexts from httpx import AsyncClient, ASGITransport from sqlalchemy.orm import sessionmaker @@ -258,12 +258,13 @@ def generate_body_and_status(params: dict[str, str]) -> list[tuple[dict, int]]: def get_testable_routes(): routes = [] - for route in app.routes: - if not isinstance(route, APIRoute): - continue + contexts = list(iter_route_contexts(app.routes)) + + for route in contexts: if not route.methods: continue - + if not isinstance(route.route, APIRoute): + continue for method in route.methods: if method in {"HEAD", "OPTIONS"}: continue @@ -271,10 +272,10 @@ def get_testable_routes(): routes.append( ( method, - route.path, - route.status_code, - route.response_model, - route.summary, + route.route.path, + route.route.status_code, + route.route.response_model, + route.route.summary, ) ) @@ -282,11 +283,11 @@ def get_testable_routes(): # with open("endpoints.txt", "w") as f: -# for ep in get_testable_routes(): -# f.write(f"[{ep[0]}]({ep[1]}) -> {ep[2]}: {ep[3]}\n") +# for ep in get_testable_routes(): +# f.write(f"[{ep[0]}]({ep[1]}) -> {ep[2]}: {ep[3]}\n") # # ### Docstring formatted output ### -# with open("endpoints.txt", "w") as f: -# for ep in get_testable_routes(): -# f.write(f"- [{ep[0]}]({ep[1]}): []: {ep[4]}\n") +with open("endpoints.txt", "w") as f: + for ep in get_testable_routes(): + f.write(f"- [{ep[0]}]({ep[1]}): []: {ep[4]}\n") diff --git a/test/test_user.py b/test/test_user.py index b018841..a1f693c 100644 --- a/test/test_user.py +++ b/test/test_user.py @@ -5,7 +5,7 @@ import pytest from httpx import AsyncClient -from fastapi.routing import APIRoute +from fastapi.routing import APIRoute, iter_route_contexts from .conftest import generate_query_and_status @@ -180,10 +180,15 @@ async def test_get_self_orgs_dynamic(default_client: AsyncClient): resp = await default_client.get(path) + contexts = list(iter_route_contexts(default_client._transport.app.routes)) # ty:ignore[unresolved-attribute] + route = next( - route - for route in default_client._transport.app.routes # ty:ignore[unresolved-attribute] - if isinstance(route, APIRoute) and path in route.path and method in route.methods + route.route + for route in contexts + if isinstance(route.route, APIRoute) + and path in route.route.path + and isinstance(route.methods, set) + and method in route.methods ) assert resp.status_code == route.status_code From e7bd455b2dd878f77ab4f1908eaf4234795263ad Mon Sep 17 00:00:00 2001 From: irl Date: Mon, 22 Jun 2026 15:18:13 +0100 Subject: [PATCH 23/33] ci: run the build step somewhere --- .forgejo/workflows/publish.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.forgejo/workflows/publish.yaml b/.forgejo/workflows/publish.yaml index a16e5a9..941a8e6 100644 --- a/.forgejo/workflows/publish.yaml +++ b/.forgejo/workflows/publish.yaml @@ -56,6 +56,7 @@ jobs: build: needs: [ ruff, ty, tests ] if: ${{ always() && needs.ruff.result == 'success' && needs.ty.result == 'success' && needs.tests.result == 'success' }} + runs-on: docker container: image: ghcr.io/catthehacker/ubuntu:act-latest options: -v /dind/docker.sock:/var/run/docker.sock From a481be835294860187a8b33ff8f3d82ca5f449db Mon Sep 17 00:00:00 2001 From: irl Date: Mon, 22 Jun 2026 15:20:14 +0100 Subject: [PATCH 24/33] ci: check out the frontend repo --- .forgejo/workflows/publish.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.forgejo/workflows/publish.yaml b/.forgejo/workflows/publish.yaml index 941a8e6..f8e3490 100644 --- a/.forgejo/workflows/publish.yaml +++ b/.forgejo/workflows/publish.yaml @@ -63,6 +63,11 @@ jobs: steps: - name: Checkout the repo uses: actions/checkout@v4 + - name: Checkout the frontend + uses: actions/checkout@v4 + with: + repository: https://guardianproject.dev/sr2/cloud-portal.git + path: frontend - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Login to the registry From 20615f438a01b743f4f3c46e3309fa031e0404d7 Mon Sep 17 00:00:00 2001 From: irl Date: Mon, 22 Jun 2026 15:21:31 +0100 Subject: [PATCH 25/33] ci: fix branch name tag --- .forgejo/workflows/publish.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.forgejo/workflows/publish.yaml b/.forgejo/workflows/publish.yaml index f8e3490..ff365c5 100644 --- a/.forgejo/workflows/publish.yaml +++ b/.forgejo/workflows/publish.yaml @@ -81,4 +81,4 @@ jobs: with: file: Containerfile push: true - tags: guardianproject.dev/${{ github.repository }}:${{ github.branch }} + tags: guardianproject.dev/${{ github.repository }}:${{ github.ref }} From 44e1d4986f51374b384f26f61105db9c842926de Mon Sep 17 00:00:00 2001 From: irl Date: Mon, 22 Jun 2026 15:22:42 +0100 Subject: [PATCH 26/33] ci: relative repo path --- .forgejo/workflows/publish.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.forgejo/workflows/publish.yaml b/.forgejo/workflows/publish.yaml index ff365c5..3474a10 100644 --- a/.forgejo/workflows/publish.yaml +++ b/.forgejo/workflows/publish.yaml @@ -66,7 +66,7 @@ jobs: - name: Checkout the frontend uses: actions/checkout@v4 with: - repository: https://guardianproject.dev/sr2/cloud-portal.git + repository: sr2/cloud-portal.git path: frontend - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 From cc4ae4264654a744cea3cf911be340008b0311eb Mon Sep 17 00:00:00 2001 From: irl Date: Mon, 22 Jun 2026 15:24:33 +0100 Subject: [PATCH 27/33] ci: adds frontend ref --- .forgejo/workflows/publish.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.forgejo/workflows/publish.yaml b/.forgejo/workflows/publish.yaml index 3474a10..53d3d90 100644 --- a/.forgejo/workflows/publish.yaml +++ b/.forgejo/workflows/publish.yaml @@ -68,6 +68,7 @@ jobs: with: repository: sr2/cloud-portal.git path: frontend + ref: main - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 - name: Login to the registry From 8ab0390977a3486e5b1ba30e8534ba5bb45b6310 Mon Sep 17 00:00:00 2001 From: irl Date: Mon, 22 Jun 2026 15:26:39 +0100 Subject: [PATCH 28/33] ci: fix branch name tag again --- .forgejo/workflows/publish.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.forgejo/workflows/publish.yaml b/.forgejo/workflows/publish.yaml index 53d3d90..7273f2a 100644 --- a/.forgejo/workflows/publish.yaml +++ b/.forgejo/workflows/publish.yaml @@ -82,4 +82,4 @@ jobs: with: file: Containerfile push: true - tags: guardianproject.dev/${{ github.repository }}:${{ github.ref }} + tags: guardianproject.dev/${{ github.repository }}:${{ github.ref_name }} From be46e43042893053f0ea4c935bb941e2ce34a9e0 Mon Sep 17 00:00:00 2001 From: luxferre Date: Mon, 22 Jun 2026 15:28:41 +0100 Subject: [PATCH 29/33] fix(db): user active default true --- .../2026-06-22_fix_user_activated_default.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 alembic/versions/2026-06-22_fix_user_activated_default.py diff --git a/alembic/versions/2026-06-22_fix_user_activated_default.py b/alembic/versions/2026-06-22_fix_user_activated_default.py new file mode 100644 index 0000000..dfe789b --- /dev/null +++ b/alembic/versions/2026-06-22_fix_user_activated_default.py @@ -0,0 +1,32 @@ +"""fix user activated default + +Revision ID: ae433e1c3b20 +Revises: 661202797ecd +Create Date: 2026-06-22 15:26:57.805129 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'ae433e1c3b20' +down_revision: Union[str, Sequence[str], None] = '661202797ecd' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('user', 'active', server_default=sa.true()) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('user', 'active', server_default=sa.false()) + # ### end Alembic commands ### From 5b98be9787f1480a9331b5d4d39e612d732ca511 Mon Sep 17 00:00:00 2001 From: irl Date: Mon, 22 Jun 2026 15:30:19 +0100 Subject: [PATCH 30/33] ci: define context for docker --- .forgejo/workflows/publish.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.forgejo/workflows/publish.yaml b/.forgejo/workflows/publish.yaml index 7273f2a..2cdff17 100644 --- a/.forgejo/workflows/publish.yaml +++ b/.forgejo/workflows/publish.yaml @@ -80,6 +80,7 @@ jobs: - name: Build and push uses: docker/build-push-action@v6 with: - file: Containerfile + file: /workspace/sr2/cloud-api/Containerfile + context: /workspace/sr2/cloud-api/ push: true tags: guardianproject.dev/${{ github.repository }}:${{ github.ref_name }} From a9e059bf0a0415ec9c967a6f3ed456957d51f529 Mon Sep 17 00:00:00 2001 From: luxferre Date: Mon, 22 Jun 2026 15:36:32 +0100 Subject: [PATCH 31/33] feat: user soft delete --- src/user/router.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/user/router.py b/src/user/router.py index 6c87e24..adcfb57 100644 --- a/src/user/router.py +++ b/src/user/router.py @@ -7,6 +7,7 @@ Endpoints: - [GET](/user/): [super admin]: Returns user(id) details. - [DELETE](/user/): [super admin]: Removes a User(id) from the hub database. """ +from datetime import datetime, timezone from fastapi import APIRouter, status, BackgroundTasks @@ -104,7 +105,7 @@ async def get_user_by_id( status.HTTP_404_NOT_FOUND: {"description": "User not found"}, }, ) -async def delete_user_by_id( +async def soft_delete_user_by_id( db: DbSession, user_model: user_model_query_dependency, su: super_admin_dependency, @@ -112,7 +113,8 @@ async def delete_user_by_id( """ Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login. """ - db.delete(user_model) + user_model.active = False + user_model.deleted_at = datetime.now(tz=timezone.utc) db.commit() From bee0dcd4fe1203fccca6e96f17f1c300a3ff3dce Mon Sep 17 00:00:00 2001 From: luxferre Date: Mon, 22 Jun 2026 16:12:03 +0100 Subject: [PATCH 32/33] feat: soft deleted users access blocked --- src/user/dependencies.py | 10 +++++++--- src/user/router.py | 1 + 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/user/dependencies.py b/src/user/dependencies.py index 7518c16..d846d79 100644 --- a/src/user/dependencies.py +++ b/src/user/dependencies.py @@ -10,12 +10,13 @@ Exports: from typing import Annotated from fastapi import Depends, Query -from src.user.exceptions import UserNotFoundException -from src.user.models import User - from src.auth.service import claims_dependency from src.database import DbSession from src.schemas import UserIDMixin +from src.exceptions import ForbiddenException + +from src.user.exceptions import UserNotFoundException +from src.user.models import User async def get_user_model_claims(claims: claims_dependency, db: DbSession): @@ -27,6 +28,9 @@ async def get_user_model_claims(claims: claims_dependency, db: DbSession): if user_model is None: raise UserNotFoundException(user_id=user_id) + if not user_model.active: + raise ForbiddenException("User account is not active") + return user_model diff --git a/src/user/router.py b/src/user/router.py index adcfb57..6ea07d8 100644 --- a/src/user/router.py +++ b/src/user/router.py @@ -7,6 +7,7 @@ Endpoints: - [GET](/user/): [super admin]: Returns user(id) details. - [DELETE](/user/): [super admin]: Removes a User(id) from the hub database. """ + from datetime import datetime, timezone from fastapi import APIRouter, status, BackgroundTasks From 7dad2e920e68c004fe44476441bb7248d405738d Mon Sep 17 00:00:00 2001 From: luxferre Date: Mon, 22 Jun 2026 16:45:50 +0100 Subject: [PATCH 33/33] tests: get_testable_routes finds auth level Checks all dependencies used on each endpoint and determines the highest level of auth applied to each endpoint. API Key>SU>Root>User>None --- test/conftest.py | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/test/conftest.py b/test/conftest.py index cfa329d..ffe4065 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,3 +1,4 @@ +from fastapi.dependencies.models import Dependant import pytest from typing import AsyncGenerator @@ -265,6 +266,33 @@ def get_testable_routes(): continue if not isinstance(route.route, APIRoute): continue + + dep_func_names = set() + + unchecked = [] + unchecked.append(route.route.dependant) + while unchecked: + dependant = unchecked.pop(0) + ck = dependant.cache_key[0] + if hasattr(ck, "__name__"): + dep_func_names.add(ck.__name__) + unchecked += [ + dep for dep in dependant.dependencies if isinstance(dep, Dependant) + ] + + auth_level = None + if "get_current_user" in dep_func_names: + auth_level = "User" + if ( + "org_body_root_claims" in dep_func_names + or "org_query_root_claims" in dep_func_names + ): + auth_level = "Root User" + if "user_model_super_admin" in dep_func_names: + auth_level = "Super Admin" + if "valid_service_key" in dep_func_names: + auth_level = "API Key" + for method in route.methods: if method in {"HEAD", "OPTIONS"}: continue @@ -276,18 +304,14 @@ def get_testable_routes(): route.route.status_code, route.route.response_model, route.route.summary, + auth_level, ) ) return routes -# with open("endpoints.txt", "w") as f: -# for ep in get_testable_routes(): -# f.write(f"[{ep[0]}]({ep[1]}) -> {ep[2]}: {ep[3]}\n") -# -# ### Docstring formatted output ### with open("endpoints.txt", "w") as f: for ep in get_testable_routes(): - f.write(f"- [{ep[0]}]({ep[1]}): []: {ep[4]}\n") + f.write(f"- [{ep[0]}]({ep[1]}): [{ep[5]}]: {ep[4]}\n")