forked from sr2/cloud-api
feat(db): db tuning options and consistency
This commit is contained in:
parent
40918fd8b8
commit
84ba3b6bee
12 changed files with 104 additions and 80 deletions
|
|
@ -1,13 +1,9 @@
|
|||
"""
|
||||
Database connections and init
|
||||
|
||||
Exports:
|
||||
- db_dependency
|
||||
- Base (sqlalchemy base model)
|
||||
Database connection and session utilities
|
||||
"""
|
||||
|
||||
from typing import Annotated
|
||||
from sqlalchemy import create_engine, StaticPool
|
||||
from contextlib import contextmanager
|
||||
from typing import Annotated, Generator
|
||||
from sqlalchemy import create_engine, StaticPool, Connection
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
|
||||
from fastapi import Depends
|
||||
|
|
@ -16,28 +12,52 @@ from src.constants import Environment
|
|||
from src.config import SQLALCHEMY_DATABASE_URI, settings as global_settings
|
||||
|
||||
if global_settings.ENVIRONMENT == Environment.TESTING:
|
||||
connect_args = {"check_same_thread": False}
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URI.get_secret_value(),
|
||||
connect_args=connect_args,
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
connect_args = {"check_same_thread": False}
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URI.get_secret_value(),
|
||||
connect_args=connect_args,
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
else:
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value())
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URI.get_secret_value(),
|
||||
pool_size=global_settings.DATABASE_POOL_SIZE,
|
||||
pool_recycle=global_settings.DATABASE_POOL_TTL,
|
||||
pool_pre_ping=global_settings.DATABASE_POOL_PRE_PING,
|
||||
)
|
||||
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
sm = sessionmaker(autocommit=False, expire_on_commit=False, bind=engine)
|
||||
|
||||
@contextmanager
|
||||
def get_db_connection() -> Generator[Connection, None, None]:
|
||||
with engine.connect() as connection:
|
||||
try:
|
||||
yield connection
|
||||
except Exception:
|
||||
connection.rollback()
|
||||
raise
|
||||
|
||||
def _get_db_connection() -> Generator[Connection, None]:
|
||||
with get_db_connection() as connection:
|
||||
yield connection
|
||||
|
||||
DbConnection = Annotated[Connection, Depends(_get_db_connection)]
|
||||
|
||||
@contextmanager
|
||||
def get_db_session() -> Generator[Session, None, None]:
|
||||
session = sm()
|
||||
try:
|
||||
yield session
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
except:
|
||||
db.rollback()
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
def _get_db_session() -> Generator[Session, None]:
|
||||
with get_db_session() as session:
|
||||
yield session
|
||||
|
||||
|
||||
db_dependency = Annotated[Session, Depends(get_db)]
|
||||
DbSession = Annotated[Session, Depends(_get_db_session)]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue