2026-04-06 12:41:49 +01:00
|
|
|
"""
|
2026-06-22 12:58:37 +01:00
|
|
|
Database connection and session utilities
|
2026-04-06 12:41:49 +01:00
|
|
|
"""
|
2026-06-22 15:04:11 +01:00
|
|
|
|
2026-06-22 12:58:37 +01:00
|
|
|
from contextlib import contextmanager
|
|
|
|
|
from typing import Annotated, Generator
|
|
|
|
|
from sqlalchemy import create_engine, StaticPool, Connection
|
2026-06-20 18:42:36 +01:00
|
|
|
from sqlalchemy.orm import sessionmaker, Session
|
2026-04-06 12:41:49 +01:00
|
|
|
|
|
|
|
|
from fastapi import Depends
|
|
|
|
|
|
2026-05-29 11:00:01 +01:00
|
|
|
from src.constants import Environment
|
|
|
|
|
from src.config import SQLALCHEMY_DATABASE_URI, settings as global_settings
|
2026-04-06 12:41:49 +01:00
|
|
|
|
2026-05-29 11:00:01 +01:00
|
|
|
# Engine options depend on the runtime environment. The TESTING branch passes
# ``check_same_thread=False`` (a SQLite driver option) plus a StaticPool, so the
# test database URI is presumably SQLite — NOTE(review): confirm in src.config.
if global_settings.ENVIRONMENT == Environment.TESTING:
    connect_args = {"check_same_thread": False}
    _engine_options = {"connect_args": connect_args, "poolclass": StaticPool}
else:
    # Pool size, recycle interval and pre-ping all come from application settings.
    _engine_options = {
        "pool_size": global_settings.DATABASE_POOL_SIZE,
        "pool_recycle": global_settings.DATABASE_POOL_TTL,
        "pool_pre_ping": global_settings.DATABASE_POOL_PRE_PING,
    }

engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value(), **_engine_options)
# The options dict was only needed to build the engine; keep the module namespace tidy.
del _engine_options

# ORM session factory bound to the engine above. expire_on_commit=False means
# committed instances keep their loaded attribute values instead of being expired.
sm = sessionmaker(bind=engine, autocommit=False, expire_on_commit=False)
|
|
|
|
|
|
2026-06-22 15:04:11 +01:00
|
|
|
|
2026-06-22 12:58:37 +01:00
|
|
|
@contextmanager
def get_db_connection() -> Generator[Connection, None, None]:
    """Provide a Core ``Connection`` checked out from the module-level engine.

    If the body of the ``with`` block raises, the connection is rolled back and
    the exception is re-raised. The connection is closed when the block exits.
    This helper never calls ``commit()``; callers commit explicitly if needed.
    """
    conn = engine.connect()
    with conn:
        try:
            yield conn
        except Exception:
            # Undo any partial work before propagating the failure.
            conn.rollback()
            raise
|
|
|
|
|
|
2026-06-22 15:04:11 +01:00
|
|
|
|
2026-06-22 15:12:34 +01:00
|
|
|
def _get_db_connection() -> Generator[Connection, None, None]:
    """Generator-style dependency wrapping :func:`get_db_connection`.

    FastAPI expects a plain generator function for yield-style dependencies, so
    this adapts the ``@contextmanager`` helper: it yields the managed connection
    and lets the context manager handle rollback-on-error and closing.
    """
    managed = get_db_connection()
    with managed as conn:
        yield conn
|
|
|
|
|
|
2026-06-22 15:04:11 +01:00
|
|
|
|
2026-06-22 12:58:37 +01:00
|
|
|
# Annotated parameter type for endpoints: FastAPI runs _get_db_connection and
# injects the yielded Connection.
DbConnection = Annotated[Connection, Depends(dependency=_get_db_connection)]
|
|
|
|
|
|
2026-06-22 15:04:11 +01:00
|
|
|
|
2026-06-22 12:58:37 +01:00
|
|
|
@contextmanager
def get_db_session() -> Generator[Session, None, None]:
    """Provide an ORM ``Session`` created from the module-level ``sm`` factory.

    If the body of the ``with`` block raises, the session is rolled back and the
    exception is re-raised. The session is always closed on exit. This helper
    never calls ``commit()``; callers commit explicitly if needed.
    """
    # Session's own context-manager protocol calls close() on exit, covering
    # normal completion, exceptions and generator shutdown alike.
    with sm() as db:
        try:
            yield db
        except Exception:
            db.rollback()
            raise
|
|
|
|
|
|
|
|
|
|
|
2026-06-22 15:12:34 +01:00
|
|
|
def _get_db_session() -> Generator[Session, None, None]:
    """Generator-style dependency wrapping :func:`get_db_session`.

    FastAPI expects a plain generator function for yield-style dependencies, so
    this adapts the ``@contextmanager`` helper: it yields the managed session
    and lets the context manager handle rollback-on-error and closing.
    """
    managed = get_db_session()
    with managed as db:
        yield db
|
|
|
|
|
|
2026-06-22 15:04:11 +01:00
|
|
|
|
2026-06-22 12:58:37 +01:00
|
|
|
# Annotated parameter type for endpoints: FastAPI runs _get_db_session and
# injects the yielded Session.
DbSession = Annotated[Session, Depends(dependency=_get_db_session)]
|