cloud-api/test/conftest.py

106 lines
3.3 KiB
Python

from typing import AsyncGenerator
import pytest
from httpx import AsyncClient, ASGITransport
from sqlalchemy.orm import sessionmaker
from src.user.models import User
from src.service.models import Service
from src.organisation.models import Organisation as Org
from src.contact.models import Contact
from src.iam.models import Group, Permission
from src.auth.service import get_current_user, get_dev_user
from src.auth.dependencies import user_model_super_admin, never_super_admin
from src.main import app # inited FastAPI app
from src.database import engine, Base, get_db
# Session factory bound to the application's engine; sessions it creates have
# autoflush and autocommit turned off.
SessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)
@pytest.fixture()
def db_session():
    """Yield a session on a freshly rebuilt and seeded database schema.

    Every table known to ``Base.metadata`` is dropped and recreated on
    ``engine`` before the session is opened, then ``_seed`` loads the baseline
    rows. Once the consumer is done, pending work is rolled back and the
    session is closed.
    """
    # Rebuild the schema each time the fixture is set up.
    Base.metadata.drop_all(bind=engine)
    Base.metadata.create_all(bind=engine)

    session = SessionLocal()
    try:
        _seed(session)
        yield session
    finally:
        # Runs whether seeding failed or the test finished.
        session.rollback()
        session.close()
@pytest.fixture
async def default_client(db_session) -> AsyncGenerator[AsyncClient, None]:
    """Yield an ``AsyncClient`` that talks to ``app`` with test dependencies.

    ``get_db`` is overridden to return the session from the ``db_session``
    fixture, and ``get_current_user`` is replaced by ``get_dev_user``.

    All dependency overrides on ``app`` are cleared when the fixture finishes,
    including when creating or closing the client raises.
    """

    def get_db_override():
        # Hand the app the same session the test itself is using.
        return db_session

    app.dependency_overrides[get_db] = get_db_override
    app.dependency_overrides[get_current_user] = get_dev_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://localhost:8000/api/v1") as ac:
            yield ac
    finally:
        # ``app`` is imported once and shared, so overrides must not outlive
        # this fixture even if client setup or shutdown fails.
        app.dependency_overrides.clear()
@pytest.fixture
async def no_su_client(db_session) -> AsyncGenerator[AsyncClient, None]:
    """Yield an ``AsyncClient`` like ``default_client`` plus one extra override.

    ``get_db`` is overridden to return the session from the ``db_session``
    fixture, ``get_current_user`` is replaced by ``get_dev_user``, and
    ``user_model_super_admin`` is replaced by ``never_super_admin``
    (presumably so requests are never treated as coming from a super admin;
    confirm against ``src.auth.dependencies``).

    All dependency overrides on ``app`` are cleared when the fixture finishes,
    including when creating or closing the client raises.
    """

    def get_db_override():
        # Hand the app the same session the test itself is using.
        return db_session

    app.dependency_overrides[get_db] = get_db_override
    app.dependency_overrides[get_current_user] = get_dev_user
    app.dependency_overrides[user_model_super_admin] = never_super_admin
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://localhost:8000/api/v1") as ac:
            yield ac
    finally:
        # ``app`` is imported once and shared, so overrides must not outlive
        # this fixture even if client setup or shutdown fails.
        app.dependency_overrides.clear()
def _seed(db):
    """Populate ``db`` with the baseline rows and commit them.

    Adds one user, three contacts, one organisation, one service, one
    permission and one group, links them through their relationship
    collections, and commits the session.
    """
    # Batch 1: the user and the three contacts.
    db.add(User(email="admin@test.com", first_name="Admin", last_name="Test", oidc_id="abcd-efgh-ijkl-mnop"))
    for contact_email in ("billing@test.org", "owner@test.org", "security@test.org"):
        db.add(Contact(org_id=1, email=contact_email, phonenumber="07521539927"))
    db.flush()

    # Batch 2 refers to batch 1 by literal ids.
    # NOTE(review): assumes a fresh schema assigns ids 1..3 in insertion
    # order; confirm for the database backend in use.
    db.add_all([
        Org(
            name="Test Org",
            root_user_id=1,
            billing_contact_id=1,
            owner_contact_id=2,
            security_contact_id=3,
            status="approved",
            intake_questionnaire={"question_two": "answer two"},
        ),
        Service(name="Test Service", api_key="123456789"),
        Permission(service_id=1, resource="test_resource", action="read"),
        Group(name="Test Group", org_id=1),
    ])
    db.flush()

    # Wire up the relationships between the rows just written.
    group = db.get(Group, 1)
    group.permission_rel.append(db.get(Permission, 1))
    admin = db.get(User, 1)
    org = db.get(Org, 1)
    org.user_rel.append(admin)
    org.group_rel.append(group)
    db.flush()

    # Group membership is added only after the organisation links are flushed.
    group.user_rel.append(admin)
    db.commit()
# # Disabled helper: when uncommented, writes the HTTP method and path of every API endpoint to endpoints.txt
# from fastapi.routing import APIRoute
#
# def get_testable_routes():
# routes = []
#
# for route in app.routes:
# if not isinstance(route, APIRoute):
# continue
#
# for method in route.methods:
# if method in {"HEAD", "OPTIONS"}:
# continue
#
# routes.append((route.path, method))
#
# return routes
#
#
# with open("endpoints.txt", "w") as f:
# for ep in get_testable_routes():
# f.write(f"{ep[1]} {ep[0]}\n")