From 2a1d28bc5488c1b68bfffade3d1c444777949e47 Mon Sep 17 00:00:00 2001 From: irl Date: Mon, 22 Jun 2026 12:58:37 +0100 Subject: [PATCH] feat(db): db tuning options and consistency --- src/auth/service.py | 4 +- src/config.py | 4 ++ src/database.py | 74 ++++++++++++++++++++------------ src/iam/dependencies.py | 10 ++--- src/iam/router.py | 26 +++++------ src/iam/service.py | 6 +-- src/organisation/dependencies.py | 6 +-- src/organisation/router.py | 22 +++++----- src/service/dependencies.py | 6 +-- src/service/router.py | 12 +++--- src/user/dependencies.py | 8 ++-- src/user/router.py | 6 +-- 12 files changed, 104 insertions(+), 80 deletions(-) diff --git a/src/auth/service.py b/src/auth/service.py index 1b90b8c..130a0f2 100644 --- a/src/auth/service.py +++ b/src/auth/service.py @@ -23,7 +23,7 @@ from src.organisation.models import Organisation as Org from src.exceptions import UnauthorizedException, ForbiddenException from src.auth.config import auth_settings from src.user.service import add_user_to_db -from src.database import db_dependency +from src.database import DbSession oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG) @@ -35,7 +35,7 @@ async def get_dev_user(): async def get_current_user( - oidc_auth_string: oidc_dependency, db: db_dependency + oidc_auth_string: oidc_dependency, db: DbSession ) -> dict[str, Any]: config_url = urlopen(auth_settings.OIDC_CONFIG) config = json.loads(config_url.read()) diff --git a/src/config.py b/src/config.py index 6ebfcf5..6fc68f0 100644 --- a/src/config.py +++ b/src/config.py @@ -36,6 +36,10 @@ class Config(CustomBaseSettings): DATABASE_HOSTNAME: str = "localhost" DATABASE_CREDENTIALS: SecretStr = SecretStr(":") + DATABASE_POOL_SIZE: int = 16 + DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes + DATABASE_POOL_PRE_PING: bool = True + LETTERMINT_API_TOKEN: SecretStr = SecretStr("") diff --git a/src/database.py b/src/database.py index fb29a41..41509d4 100644 --- a/src/database.py +++ b/src/database.py @@ -1,13 +1,9 @@ """ -Database connections and init - -Exports: - - db_dependency - - Base (sqlalchemy base model) +Database connection and session utilities """ - -from typing import Annotated -from sqlalchemy import create_engine, StaticPool +from contextlib import contextmanager +from typing import Annotated, Generator +from sqlalchemy import create_engine, StaticPool, Connection from sqlalchemy.orm import sessionmaker, Session from fastapi import Depends @@ -16,28 +12,52 @@ from src.constants import Environment from src.config import SQLALCHEMY_DATABASE_URI, settings as global_settings if global_settings.ENVIRONMENT == Environment.TESTING: - connect_args = {"check_same_thread": False} - engine = create_engine( - SQLALCHEMY_DATABASE_URI.get_secret_value(), - connect_args=connect_args, - poolclass=StaticPool, - ) + connect_args = {"check_same_thread": False} + engine = create_engine( + SQLALCHEMY_DATABASE_URI.get_secret_value(), + connect_args=connect_args, + poolclass=StaticPool, + ) else: - engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value()) + engine = create_engine( + SQLALCHEMY_DATABASE_URI.get_secret_value(), + pool_size=global_settings.DATABASE_POOL_SIZE, + pool_recycle=global_settings.DATABASE_POOL_TTL, + pool_pre_ping=global_settings.DATABASE_POOL_PRE_PING, + ) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +sm = sessionmaker(autocommit=False, expire_on_commit=False, bind=engine) + +@contextmanager +def get_db_connection() -> Generator[Connection, None, None]: + with engine.connect() as connection: + try: + yield connection + except Exception: + connection.rollback() + raise + +def _get_db_connection() -> Generator[Connection, None]: + with get_db_connection() as connection: + yield connection + +DbConnection = Annotated[Connection, Depends(_get_db_connection)] + +@contextmanager +def get_db_session() -> Generator[Session, None, None]: + session = sm() + try: + yield session + except Exception: + session.rollback() + raise + finally: + session.close() -def get_db(): - db = SessionLocal() - try: - yield db - except: - db.rollback() - raise - finally: - db.close() +def _get_db_session() -> Generator[Session, None]: + with get_db_session() as session: + yield session - -db_dependency = Annotated[Session, Depends(get_db)] +DbSession = Annotated[Session, Depends(_get_db_session)] diff --git a/src/iam/dependencies.py b/src/iam/dependencies.py index 14fb4ad..113a3c6 100644 --- a/src/iam/dependencies.py +++ b/src/iam/dependencies.py @@ -11,7 +11,7 @@ from typing import Annotated, Optional from fastapi import Depends, Query -from src.database import db_dependency +from src.database import DbSession from src.iam.models import Group, Permission from src.iam.exceptions import GroupNotFoundException, PermNotFoundException @@ -19,7 +19,7 @@ from src.iam.schemas import GroupIDMixin, PermIDMixin def get_group_model_query( - db: db_dependency, group_id: Annotated[int, Query(gt=0)] + db: DbSession, group_id: Annotated[int, Query(gt=0)] ) -> Group: group_model = db.get(Group, group_id) if group_model is None: @@ -32,7 +32,7 @@ group_model_query_dependency = Annotated[Group, Depends(get_group_model_query)] def get_group_model_body( - db: db_dependency, request_model: Optional[GroupIDMixin] = None + db: DbSession, request_model: Optional[GroupIDMixin] = None ) -> Group: group_id = getattr(request_model, "group_id", None) if group_id is None: @@ -48,7 +48,7 @@ group_model_body_dependency = Annotated[Group, Depends(get_group_model_body)] def get_perm_model_body( - db: db_dependency, request_model: Optional[PermIDMixin] = None + db: DbSession, request_model: Optional[PermIDMixin] = None ) -> Permission: perm_id = getattr(request_model, "permission_id", None) if perm_id is None: @@ -64,7 +64,7 @@ perm_model_body_dependency = Annotated[Permission, Depends(get_perm_model_body)] def get_perm_model_query( - db: db_dependency, perm_id: Annotated[int, Query(gt=0)] + db: DbSession, perm_id: Annotated[int, Query(gt=0)] ) -> Permission: perm_model = db.get(Permission, perm_id) if perm_model is None: diff --git a/src/iam/router.py b/src/iam/router.py index af783a1..1b2c297 100644 --- a/src/iam/router.py +++ b/src/iam/router.py @@ -32,7 +32,7 @@ from src.exceptions import ( ForbiddenException, UnprocessableContentException, ) -from src.database import db_dependency +from src.database import DbSession from src.auth.service import claims_dependency from src.auth.dependencies import ( org_model_root_claim_query_dependency, @@ -107,7 +107,7 @@ router = APIRouter( ) async def can_act_on_resource( valid_key: service_key_dependency, - db: db_dependency, + db: DbSession, user_claims: claims_dependency, request_model: IAMCAoRRequest, ): @@ -270,7 +270,7 @@ async def get_group_users( }, ) async def create_group( - db: db_dependency, + db: DbSession, org_model: org_model_root_claim_body_dependency, request_model: IAMPostGroupRequest, ): @@ -310,7 +310,7 @@ async def create_group( }, ) async def add_group_permission( - db: db_dependency, + db: DbSession, group_model: group_model_body_dependency, perm_model: perm_model_body_dependency, org_model: org_model_root_claim_body_dependency, @@ -356,7 +356,7 @@ async def add_group_permission( }, ) async def add_group_user( - db: db_dependency, + db: DbSession, group_model: group_model_body_dependency, user_model: user_model_body_dependency, org_model: org_model_root_claim_body_dependency, @@ -399,7 +399,7 @@ async def add_group_user( }, ) async def remove_group_permission( - db: db_dependency, + db: DbSession, group_model: group_model_query_dependency, perm_model: perm_model_query_dependency, org_model: org_model_root_claim_query_dependency, @@ -436,7 +436,7 @@ async def remove_group_permission( }, ) async def remove_group_user( - db: db_dependency, + db: DbSession, group_model: group_model_query_dependency, user_model: user_model_query_dependency, org_model: org_model_root_claim_query_dependency, @@ -469,7 +469,7 @@ async def remove_group_user( }, ) async def get_permissions( - db: db_dependency, org_model: org_model_root_claim_query_dependency + db: DbSession, org_model: org_model_root_claim_query_dependency ): """ Returns a full list of permissions. @@ -493,7 +493,7 @@ async def get_permissions( }, ) async def create_new_permission( - db: db_dependency, + db: DbSession, su: super_admin_dependency, request_model: IAMPostPermissionRequest, service_model: service_model_body_dependency, # Used to verify service model exists @@ -529,7 +529,7 @@ async def create_new_permission( responses={}, ) async def delete_permission( - db: db_dependency, + db: DbSession, su: super_admin_dependency, perm_model: perm_model_query_dependency, ): @@ -548,7 +548,7 @@ async def delete_permission( responses={}, ) async def permissions_search( - db: db_dependency, + db: DbSession, org_model: org_model_root_claim_body_dependency, request_model: IAMGetPermissionsSearchRequest, ): @@ -632,7 +632,7 @@ async def invitation( }, ) async def accept_invitation( - db: db_dependency, + db: DbSession, user_model: user_model_claims_dependency, request_model: IAMPutGroupInvitationAcceptRequest, ): @@ -678,7 +678,7 @@ async def accept_invitation( }, ) async def add_org_permissions( - db: db_dependency, + db: DbSession, su: super_admin_dependency, org_model: org_model_body_dependency, request_model: IAMPutOrgPermissionsRequest, diff --git a/src/iam/service.py b/src/iam/service.py index 4f23969..2112c6c 100644 --- a/src/iam/service.py +++ b/src/iam/service.py @@ -10,7 +10,7 @@ from datetime import datetime, timedelta, timezone from fastapi import Request, Depends from sqlalchemy.orm import Session -from src.database import db_dependency +from src.database import DbSession from src.exceptions import UnauthorizedException from src.utils import send_email, generate_jwt from src.iam.models import Group @@ -23,7 +23,7 @@ from src.service.schemas import HasServiceName def valid_service_key( - db: db_dependency, request: Request, request_model: HasServiceName + db: DbSession, request: Request, request_model: HasServiceName ) -> bool: rn = request_model.rn api_key = request.headers.get("X-API-Key", None) @@ -90,7 +90,7 @@ async def create_group_and_assign_perms( async def assign_default_group( - db: db_dependency, + db: DbSession, org_model: Org, user_model: User, group_name: str, diff --git a/src/organisation/dependencies.py b/src/organisation/dependencies.py index 1ecdca8..2686560 100644 --- a/src/organisation/dependencies.py +++ b/src/organisation/dependencies.py @@ -10,14 +10,14 @@ from typing import Annotated, Optional from fastapi import Depends, Query -from src.database import db_dependency +from src.database import DbSession from src.organisation.schemas import OrgIDMixin from src.organisation.models import Organisation as Org from src.organisation.exceptions import OrgNotFoundException -def get_org_model_query(db: db_dependency, org_id: Annotated[int, Query(gt=0)]) -> Org: +def get_org_model_query(db: DbSession, org_id: Annotated[int, Query(gt=0)]) -> Org: org_model = db.get(Org, org_id) if org_model is None: raise OrgNotFoundException(org_id) @@ -27,7 +27,7 @@ def get_org_model_query(db: db_dependency, org_id: Annotated[int, Query(gt=0)]) org_model_query_dependency = Annotated[Org, Depends(get_org_model_query)] -def get_org_model_body(db: db_dependency, request_model: OrgIDMixin) -> Org: +def get_org_model_body(db: DbSession, request_model: OrgIDMixin) -> Org: org_id: Optional[int] = getattr(request_model, "organisation_id", None) if org_id is None: raise OrgNotFoundException() diff --git a/src/organisation/router.py b/src/organisation/router.py index d968e8e..5f58852 100644 --- a/src/organisation/router.py +++ b/src/organisation/router.py @@ -33,7 +33,7 @@ from src.exceptions import ( from src.contact.models import Contact from src.contact.schemas import ContactAddress from src.contact.exceptions import ContactNotFoundException -from src.database import db_dependency +from src.database import DbSession from src.organisation.schemas_questionnaires import QuestionnaireQuestionsVersion0 from src.organisation.service import assign_defaults from src.user.dependencies import ( @@ -98,7 +98,7 @@ router = APIRouter( }, ) async def get_org_by_id( - db: db_dependency, org_model: org_model_root_claim_query_dependency + db: DbSession, org_model: org_model_root_claim_query_dependency ): """ Returns organisation details including key member email addresses @@ -143,7 +143,7 @@ async def get_org_by_id( }, ) async def create_org( - db: db_dependency, + db: DbSession, user_model: user_model_claims_dependency, request_model: OrgPostOrgRequest, background_tasks: BackgroundTasks, @@ -217,7 +217,7 @@ async def create_org( }, ) async def update_questionnaire( - db: db_dependency, + db: DbSession, org_model: org_model_root_claim_body_dependency, request_model: OrgPatchQuestionnaireRequest, ): @@ -281,7 +281,7 @@ async def update_questionnaire( }, ) async def update_status( - db: db_dependency, + db: DbSession, org_model: org_model_body_dependency, su: super_admin_dependency, request_model: OrgPatchStatusRequest, @@ -338,7 +338,7 @@ async def get_users(org_model: org_model_root_claim_query_dependency): }, ) async def add_user_to_org( - db: db_dependency, + db: DbSession, org_model: org_model_body_dependency, user_model: user_model_body_dependency, su: super_admin_dependency, @@ -380,7 +380,7 @@ async def add_user_to_org( }, ) async def delete_organisation_by_id( - db: db_dependency, + db: DbSession, org_model: org_model_query_dependency, su: super_admin_dependency, ): @@ -450,7 +450,7 @@ async def delete_organisation_by_id( }, ) async def delete_preapproved_organisation_by_id( - db: db_dependency, + db: DbSession, org_model: org_model_root_claim_query_dependency, ): """ @@ -478,7 +478,7 @@ async def delete_preapproved_organisation_by_id( }, ) async def update_root_user( - db: db_dependency, + db: DbSession, org_model: org_model_body_dependency, user_model: user_model_body_dependency, su: super_admin_dependency, @@ -538,7 +538,7 @@ async def get_org_groups(org_model: org_model_root_claim_query_dependency): }, ) async def remove_user_from_org( - db: db_dependency, + db: DbSession, org_model: org_model_root_claim_query_dependency, user_model: user_model_query_dependency, ): @@ -609,7 +609,7 @@ async def get_contact( }, ) async def update_contact( - db: db_dependency, + db: DbSession, org_model: org_model_root_claim_body_dependency, request_model: OrgPatchContactRequest, ): diff --git a/src/service/dependencies.py b/src/service/dependencies.py index bf6b314..ea42c89 100644 --- a/src/service/dependencies.py +++ b/src/service/dependencies.py @@ -9,7 +9,7 @@ Exports: from typing import Annotated from fastapi import Depends, Query -from src.database import db_dependency +from src.database import DbSession from src.service.exceptions import ServiceNotFoundException from src.service.models import Service @@ -17,7 +17,7 @@ from src.service.schemas import ServiceIDMixin async def get_service_model_query( - db: db_dependency, service_id: Annotated[int, Query(gt=0)] + db: DbSession, service_id: Annotated[int, Query(gt=0)] ): service_model = db.get(Service, service_id) if service_model is None: @@ -29,7 +29,7 @@ async def get_service_model_query( service_model_query_dependency = Annotated[Service, Depends(get_service_model_query)] -async def get_service_model_body(db: db_dependency, request_model: ServiceIDMixin): +async def get_service_model_body(db: DbSession, request_model: ServiceIDMixin): service_model = db.get(Service, request_model.service_id) if service_model is None: raise ServiceNotFoundException(service_id=request_model.service_id) diff --git a/src/service/router.py b/src/service/router.py index 143fd7a..54a7a3a 100644 --- a/src/service/router.py +++ b/src/service/router.py @@ -13,7 +13,7 @@ from sqlalchemy.exc import IntegrityError from psycopg.errors import UniqueViolation from src.exceptions import ConflictException -from src.database import db_dependency +from src.database import DbSession from src.auth.dependencies import ( super_admin_dependency, org_model_root_claim_query_dependency, @@ -77,7 +77,7 @@ router = APIRouter( }, ) async def get_all_services( - db: db_dependency, org_model: org_model_root_claim_query_dependency + db: DbSession, org_model: org_model_root_claim_query_dependency ): """ Returns the ID and name of all services registered to the hub. @@ -99,7 +99,7 @@ async def get_all_services( }, ) async def register_service( - db: db_dependency, + db: DbSession, su: super_admin_dependency, request_model: ServicePostServiceRequest, ): @@ -135,7 +135,7 @@ async def register_service( }, ) async def regenerate_api_key( - db: db_dependency, + db: DbSession, su: super_admin_dependency, service_model: service_model_body_dependency, request_model: ServicePatchKeyRequest, @@ -162,7 +162,7 @@ async def regenerate_api_key( }, ) async def remove_service( - db: db_dependency, + db: DbSession, service_model: service_model_query_dependency, su: super_admin_dependency, ): @@ -185,7 +185,7 @@ async def remove_service( }, ) async def service_create_new_permissions( - db: db_dependency, + db: DbSession, request_model: ServicePostPermissionsRequest, valid_key: service_key_dependency, ): diff --git a/src/user/dependencies.py b/src/user/dependencies.py index 0d50daa..de23693 100644 --- a/src/user/dependencies.py +++ b/src/user/dependencies.py @@ -14,11 +14,11 @@ from src.user.exceptions import UserNotFoundException from src.user.models import User from src.auth.service import claims_dependency -from src.database import db_dependency +from src.database import DbSession from src.schemas import UserIDMixin -async def get_user_model_claims(claims: claims_dependency, db: db_dependency): +async def get_user_model_claims(claims: claims_dependency, db: DbSession): user_id = claims.get("db_id", None) if user_id is None: raise UserNotFoundException() @@ -33,7 +33,7 @@ async def get_user_model_claims(claims: claims_dependency, db: db_dependency): user_model_claims_dependency = Annotated[User, Depends(get_user_model_claims)] -async def get_user_model_query(db: db_dependency, user_id: Annotated[int, Query(gt=0)]): +async def get_user_model_query(db: DbSession, user_id: Annotated[int, Query(gt=0)]): user_model = db.get(User, user_id) if user_model is None: raise UserNotFoundException(user_id=user_id) @@ -44,7 +44,7 @@ async def get_user_model_query(db: db_dependency, user_id: Annotated[int, Query( user_model_query_dependency = Annotated[User, Depends(get_user_model_query)] -async def get_user_model_body(db: db_dependency, request_model: UserIDMixin): +async def get_user_model_body(db: DbSession, request_model: UserIDMixin): user_model = db.get(User, request_model.user_id) if user_model is None: raise UserNotFoundException(user_id=request_model.user_id) diff --git a/src/user/router.py b/src/user/router.py index a5ae4f5..ad0ccb1 100644 --- a/src/user/router.py +++ b/src/user/router.py @@ -33,7 +33,7 @@ from src.auth.dependencies import ( org_model_root_claim_body_dependency, ) from src.auth.service import claims_dependency -from src.database import db_dependency +from src.database import DbSession from src.utils import verify_email_token router = APIRouter( @@ -105,7 +105,7 @@ async def get_user_by_id( }, ) async def delete_user_by_id( - db: db_dependency, + db: DbSession, user_model: user_model_query_dependency, su: super_admin_dependency, ): @@ -186,7 +186,7 @@ async def invitation( response_model=UserPostInvitationAcceptResponse, ) async def accept_invitation( - db: db_dependency, + db: DbSession, user_model: user_model_claims_dependency, request_model: UserPostInvitationAcceptRequest, ):