test: db setup fix

Previous set up seemed to work but the db was persisting between tests, allowing for contamination.
This commit is contained in:
Chris Milne 2026-06-01 15:25:50 +01:00
parent fc6990c43d
commit c0d353077b

View file

@ -20,13 +20,14 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@pytest.fixture() @pytest.fixture()
def db_session(): def db_session():
Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)
db = SessionLocal() db = SessionLocal()
try: try:
_seed(db) # extracted seeding logic into a plain function
yield db yield db
except:
db.rollback()
raise
finally: finally:
db.rollback()
db.close() db.close()
@ -43,43 +44,25 @@ async def client(db_session) -> AsyncGenerator[AsyncClient, None]:
app.dependency_overrides.clear() app.dependency_overrides.clear()
@pytest.fixture(scope="session") def _seed(db):
def setup_database(): db.add(User(email="admin@test.com", first_name="Admin", last_name="Test", oidc_id="abcd-efgh-ijkl-mnop"))
Base.metadata.create_all(bind=engine) db.add(Contact(org_id=1, email="billing@test.org", phonenumber="07521539927"))
yield db.add(Contact(org_id=1, email="owner@test.org", phonenumber="07521539927"))
Base.metadata.drop_all(bind=engine) db.add(Contact(org_id=1, email="security@test.org", phonenumber="07521539927"))
db.flush()
db.add(Org(name="Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=2, security_contact_id=3,
@pytest.fixture(scope="session") status="approved", intake_questionnaire={"question_two": "answer two"}))
def seed_db(): db.add(Service(name="Test Service", api_key="123456789"))
db = SessionLocal() db.add(Permission(service_id=1, resource="test_resource", action="read"))
try: db.add(Group(name="Test Group"))
db.add(User(email="admin@test.com", first_name="Admin", last_name="Test", oidc_id="abcd-efgh-ijkl-mnop")) db.flush()
db.add(Contact(org_id=1)) group_model = db.get(Group, 1)
db.add(Contact(org_id=1)) perm_model = db.get(Permission, 1)
db.add(Contact(org_id=1)) group_model.permission_rel.append(perm_model)
db.flush() user_model = db.get(User, 1)
db.add(Org(name="Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=1, security_contact_id=1, org_model = db.get(Org, 1)
status="approved", intake_questionnaire="{}")) org_model.user_rel.append(user_model)
db.add(Service(name="Test Service", api_key="123456789")) org_model.group_rel.append(group_model)
db.add(Permission(service_id=1, resource="test_resource", action="read")) db.flush()
db.add(Group(name="Test Group")) group_model.user_rel.append(user_model)
db.flush() db.commit()
group_model = db.get(Group, 1)
perm_model = db.get(Permission, 1)
group_model.permission_rel.append(perm_model)
user_model = db.get(User, 1)
org_model = db.get(Org, 1)
org_model.user_rel.append(user_model)
db.flush()
group_model.user_rel.append(user_model)
db.commit()
yield db
finally:
db.close()
@pytest.fixture(scope="session", autouse=True)
def seed_data(setup_database, seed_db):
yield