from collections.abc import AsyncIterator, Sequence
from contextlib import asynccontextmanager
from itertools import combinations
from typing import AsyncGenerator

import pytest
from httpx import AsyncClient, ASGITransport
from sqlalchemy.orm import sessionmaker

# NOTE: project imports are kept in their original order on purpose; reordering
# them could change import-time side effects (model registration, app setup).
from src.user.models import User
from src.service.models import Service
from src.organisation.models import Organisation as Org
from src.contact.models import Contact
from src.iam.models import Group, Permission
from src.auth.service import get_current_user, get_dev_user
from src.auth.dependencies import empty_su_list, get_super_admin_list, testing_su_list
from src.main import app  # initialised FastAPI app
from src.database import engine, Base, get_db

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Base URL shared by every test client created in this module.
API_BASE_URL = "http://localhost:8000/api/v1"


@pytest.fixture()
def db_session():
    """Yield a Session on a freshly recreated and seeded schema.

    All tables are dropped and recreated for each test (function-scoped
    fixture), then populated by ``_seed``. On teardown any uncommitted work is
    rolled back and the session is closed.
    """
    Base.metadata.drop_all(bind=engine)
    Base.metadata.create_all(bind=engine)
    db = SessionLocal()
    try:
        _seed(db)
        yield db
    finally:
        db.rollback()
        db.close()


@asynccontextmanager
async def _api_client(db_session, overrides=None) -> AsyncIterator[AsyncClient]:
    """Yield an ``AsyncClient`` talking to ``app`` in-process via ASGI.

    ``get_db`` is always overridden to return *db_session*; *overrides* is an
    optional mapping of further ``{dependency: replacement}`` overrides. All
    dependency overrides are cleared on exit, even if opening or closing the
    client raises, so they cannot leak into later tests.
    """
    def get_db_override():
        return db_session

    app.dependency_overrides[get_db] = get_db_override
    app.dependency_overrides.update(overrides or {})
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url=API_BASE_URL) as ac:
            yield ac
    finally:
        app.dependency_overrides.clear()


@pytest.fixture
async def default_client(db_session) -> AsyncGenerator[AsyncClient, None]:
    """Client with ``get_current_user`` -> ``get_dev_user`` and
    ``get_super_admin_list`` -> ``testing_su_list`` overrides applied."""
    overrides = {
        get_current_user: get_dev_user,
        get_super_admin_list: testing_su_list,
    }
    async with _api_client(db_session, overrides) as ac:
        yield ac


@pytest.fixture
async def no_user_client(db_session) -> AsyncGenerator[AsyncClient, None]:
    """Client with only the database overridden.

    ``get_current_user`` is NOT overridden, so the app's real authentication
    dependency runs (presumably rejecting unauthenticated requests — confirm
    against ``src.auth.service``).
    """
    async with _api_client(db_session) as ac:
        yield ac


@pytest.fixture
async def no_su_client(db_session) -> AsyncGenerator[AsyncClient, None]:
    """Client with ``get_current_user`` -> ``get_dev_user`` and
    ``get_super_admin_list`` -> ``empty_su_list`` overrides applied."""
    overrides = {
        get_current_user: get_dev_user,
        get_super_admin_list: empty_su_list,
    }
    async with _api_client(db_session, overrides) as ac:
        yield ac


def _seed(db):
    """Populate *db* with one linked set of test records and commit.

    Creates a user, three contacts, an organisation using those contacts, a
    service, a permission on that service and a group; then links the
    permission to the group, the user and group to the organisation, and the
    user to the group. Literal ids (``1``, ``2``, ``3``) assume primary keys
    start at 1 on the freshly created tables.
    """
    db.add(User(
        email="admin@test.com",
        first_name="Admin",
        last_name="Test",
        oidc_id="abcd-efgh-ijkl-mnop",
    ))
    # NOTE(review): these contacts reference org_id=1 before the Org row is
    # added; this relies on the FK not being enforced at flush time — confirm.
    db.add(Contact(org_id=1, email="billing@test.org", phonenumber="07521539927"))
    db.add(Contact(org_id=1, email="owner@test.org", phonenumber="07521539927"))
    db.add(Contact(org_id=1, email="security@test.org", phonenumber="07521539927"))
    db.flush()

    db.add(Org(
        name="Test Org",
        root_user_id=1,
        billing_contact_id=1,
        owner_contact_id=2,
        security_contact_id=3,
        status="approved",
        intake_questionnaire={"question_two": "answer two"},
    ))
    db.add(Service(name="Test Service", api_key="123456789"))
    db.add(Permission(service_id=1, resource="test_resource", action="read"))
    db.add(Group(name="Test Group", org_id=1))
    db.flush()

    group_model = db.get(Group, 1)
    perm_model = db.get(Permission, 1)
    group_model.permission_rel.append(perm_model)

    user_model = db.get(User, 1)
    org_model = db.get(Org, 1)
    org_model.user_rel.append(user_model)
    org_model.group_rel.append(group_model)
    db.flush()

    group_model.user_rel.append(user_model)
    db.commit()


def generate_query_and_status(params: Sequence[str]) -> list[tuple[str, int]]:
    """Build ``(query_string, expected_status)`` cases for query-param tests.

    Cases are produced in this order:

    1. Missing params: every proper subset of *params* (including the empty
       query), each present param set to ``1`` — expected status 422.
    2. Invalid params: the complete query with exactly one param replaced by
       each of ``0, -1, 42, "banana", ""`` — expected 404 for ``42`` (used as
       a non-existent ID) and 422 for the rest.

    The complete all-``1`` query itself is never emitted. Query strings carry
    no leading ``?`` or ``&``.

    Args:
        params: Names of the required query parameters.

    Returns:
        List of ``(query, status)`` tuples; empty if *params* is empty.
    """
    # Materialise once: params is indexed and iterated more than once below.
    params = list(params)
    invalid_values = [0, -1, 42, "banana", ""]
    # ID 42 is used to represent a non-existent entry, so it should 404.
    not_found_value = 42
    defaults = [f"{param}=1" for param in params]

    query_and_status: list[tuple[str, int]] = []

    # Missing params: all combinations smaller than the full set.
    for size in range(len(defaults)):
        for combo in combinations(defaults, size):
            query_and_status.append(("&".join(combo), 422))

    # Invalid params: swap exactly one position of the complete query. Working
    # on a parts list (not str.replace) avoids clobbering params whose names
    # share a suffix and avoids stray "&&" separators.
    for index, param in enumerate(params):
        for value in invalid_values:
            parts = defaults.copy()
            parts[index] = f"{param}={value}"
            status = 404 if value == not_found_value else 422
            query_and_status.append(("&".join(parts), status))

    return query_and_status


# # Produces a text file with method and path for every endpoint in the API
# from fastapi.routing import APIRoute
#
# def get_testable_routes():
#     routes = []
#
#     for route in app.routes:
#         if not isinstance(route, APIRoute):
#             continue
#
#         for method in route.methods:
#             if method in {"HEAD", "OPTIONS"}:
#                 continue
#
#             routes.append((route.path, method))
#
#     return routes
#
#
# with open("endpoints.txt", "w") as f:
#     for ep in get_testable_routes():
#         f.write(f"{ep[1]} {ep[0]}\n")