diff --git a/src/database.py b/src/database.py index 1761560..819faeb 100644 --- a/src/database.py +++ b/src/database.py @@ -17,13 +17,16 @@ engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value()) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + def get_db(): - with SessionLocal.begin() as db: - try: - yield db - finally: - db.rollback() # Anything not explicitly commited is rolled back - db.close() + db = SessionLocal() + try: + yield db + except: + db.rollback() + raise + finally: + db.close() db_dependency = Annotated[Session, Depends(get_db)]