"""Database connection and session utilities.

Builds the module-level SQLAlchemy ``engine`` and ``sessionmaker`` (``sm``),
provides context managers that lend out a Core ``Connection`` or an ORM
``Session``, and exposes ``Annotated`` aliases (``DbConnection`` /
``DbSession``) that wire those context managers into FastAPI's ``Depends``.
"""

from contextlib import contextmanager
from typing import Annotated, Generator

from sqlalchemy import create_engine, StaticPool, Connection
from sqlalchemy.orm import sessionmaker, Session
from fastapi import Depends

from src.constants import Environment
from src.config import SQLALCHEMY_DATABASE_URI, settings as global_settings


if global_settings.ENVIRONMENT == Environment.TESTING:
    # ``check_same_thread`` is a sqlite3-driver connect argument, so this
    # branch presumably expects a SQLite test database (NOTE(review): confirm
    # against the test configuration). Disabling the check allows the
    # connection to be used from threads other than the one that opened it.
    connect_args = {"check_same_thread": False}
    engine = create_engine(
        SQLALCHEMY_DATABASE_URI.get_secret_value(),
        connect_args=connect_args,
        # StaticPool hands the same single connection to every checkout.
        poolclass=StaticPool,
    )
else:
    engine = create_engine(
        SQLALCHEMY_DATABASE_URI.get_secret_value(),
        pool_size=global_settings.DATABASE_POOL_SIZE,
        pool_recycle=global_settings.DATABASE_POOL_TTL,
        pool_pre_ping=global_settings.DATABASE_POOL_PRE_PING,
    )

# Session factory bound to ``engine``. Sessions never autocommit. The legacy
# ``autocommit`` keyword is omitted because False is its default and
# SQLAlchemy 2.0 no longer accepts any other value. With
# ``expire_on_commit=False``, attributes loaded on ORM objects stay readable
# after a commit without triggering a refresh.
sm = sessionmaker(expire_on_commit=False, bind=engine)


@contextmanager
def get_db_connection() -> Generator[Connection, None, None]:
    """Yield a Core ``Connection`` checked out from ``engine``.

    If the body of the caller's ``with`` block raises an ``Exception``, the
    connection is rolled back and the exception is re-raised. The connection
    is closed when the enclosing ``engine.connect()`` context exits. No
    commit is issued here, so callers commit their own work explicitly.
    """
    with engine.connect() as connection:
        try:
            yield connection
        except Exception:
            connection.rollback()
            raise


def _get_db_connection() -> Generator[Connection, None, None]:
    """Generator adapter so ``get_db_connection`` can be used with ``Depends``."""
    with get_db_connection() as connection:
        yield connection


# Annotated type for injecting a ``Connection`` via FastAPI's ``Depends``.
DbConnection = Annotated[Connection, Depends(_get_db_connection)]


@contextmanager
def get_db_session() -> Generator[Session, None, None]:
    """Yield a new ORM ``Session`` created by ``sm``.

    If the body of the caller's ``with`` block raises an ``Exception``, the
    session is rolled back and the exception is re-raised. The session is
    always closed on exit. No commit is issued here, so callers commit their
    own work explicitly.
    """
    session = sm()
    try:
        yield session
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()


def _get_db_session() -> Generator[Session, None, None]:
    """Generator adapter so ``get_db_session`` can be used with ``Depends``."""
    with get_db_session() as session:
        yield session


# Annotated type for injecting a ``Session`` via FastAPI's ``Depends``.
DbSession = Annotated[Session, Depends(_get_db_session)]