cloud-api/test/conftest.py
luxferre 43ed768f66
All checks were successful
ci / lint_and_test (push) Successful in 15s
feat: minimum lengths for names
2026-06-12 15:58:20 +01:00

287 lines
7.7 KiB
Python

import pytest
from typing import AsyncGenerator
from itertools import combinations
from fastapi.routing import APIRoute
from httpx import AsyncClient, ASGITransport
from sqlalchemy.orm import sessionmaker
from src.user.models import User
from src.service.models import Service
from src.organisation.models import Organisation as Org, OrgUsers
from src.contact.models import Contact
from src.iam.models import Group, Permission
from src.auth.service import get_current_user, get_dev_user
from src.auth.dependencies import empty_su_list, get_super_admin_list, testing_su_list
from src.main import app # inited FastAPI app
from src.database import engine, Base, get_db
# Session factory shared by the fixtures in this file. Autoflush and
# autocommit are both disabled, so pending rows only reach the database on an
# explicit flush()/commit().
SessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)
@pytest.fixture()
def db_session():
    """Yield a session on a freshly rebuilt and seeded schema.

    Every table registered on ``Base`` is dropped and recreated before the
    session is opened, then ``_seed`` loads the baseline rows. On teardown any
    uncommitted work is rolled back and the session is closed.
    """
    # Rebuild the schema from scratch for each test that requests the fixture.
    Base.metadata.drop_all(bind=engine)
    Base.metadata.create_all(bind=engine)

    session = SessionLocal()
    try:
        _seed(session)
        yield session
    finally:
        # Runs even when seeding fails, so the session is never left open.
        session.rollback()
        session.close()
@pytest.fixture
async def default_client(db_session) -> AsyncGenerator[AsyncClient, None]:
    """Yield an ``AsyncClient`` that talks to the app in-process.

    Overrides installed on the app for the lifetime of the fixture:

    * ``get_db`` returns the ``db_session`` fixture's session.
    * ``get_current_user`` is replaced by ``get_dev_user``.
    * ``get_super_admin_list`` is replaced by ``testing_su_list``.

    The overrides live on the module-level ``app`` object, so they are cleared
    in a ``finally`` block: a failure while building or closing the client
    must not leak them into later tests.
    """

    def get_db_override():
        return db_session

    try:
        app.dependency_overrides[get_db] = get_db_override
        app.dependency_overrides[get_current_user] = get_dev_user
        app.dependency_overrides[get_super_admin_list] = testing_su_list
        transport = ASGITransport(app=app)
        async with AsyncClient(
            transport=transport, base_url="http://localhost:8000/api/v1"
        ) as ac:
            yield ac
    finally:
        app.dependency_overrides.clear()
@pytest.fixture
async def no_user_client(db_session) -> AsyncGenerator[AsyncClient, None]:
    """Yield an ``AsyncClient`` with only the database dependency overridden.

    ``get_db`` returns the ``db_session`` fixture's session. Neither
    ``get_current_user`` nor ``get_super_admin_list`` is overridden, so the
    app's own implementations of those dependencies run.

    The override lives on the module-level ``app`` object, so it is cleared in
    a ``finally`` block: a failure while building or closing the client must
    not leak it into later tests.
    """

    def get_db_override():
        return db_session

    try:
        app.dependency_overrides[get_db] = get_db_override
        transport = ASGITransport(app=app)
        async with AsyncClient(
            transport=transport, base_url="http://localhost:8000/api/v1"
        ) as ac:
            yield ac
    finally:
        app.dependency_overrides.clear()
@pytest.fixture
async def no_su_client(db_session) -> AsyncGenerator[AsyncClient, None]:
    """Yield an ``AsyncClient`` whose super-admin list comes from ``empty_su_list``.

    Overrides installed on the app for the lifetime of the fixture:

    * ``get_db`` returns the ``db_session`` fixture's session.
    * ``get_current_user`` is replaced by ``get_dev_user``.
    * ``get_super_admin_list`` is replaced by ``empty_su_list``.

    The overrides live on the module-level ``app`` object, so they are cleared
    in a ``finally`` block: a failure while building or closing the client
    must not leak them into later tests.
    """

    def get_db_override():
        return db_session

    try:
        app.dependency_overrides[get_db] = get_db_override
        app.dependency_overrides[get_current_user] = get_dev_user
        app.dependency_overrides[get_super_admin_list] = empty_su_list
        transport = ASGITransport(app=app)
        async with AsyncClient(
            transport=transport, base_url="http://localhost:8000/api/v1"
        ) as ac:
            yield ac
    finally:
        app.dependency_overrides.clear()
def _seed(db):
    """Insert the baseline rows the test-suite relies on and commit them.

    Rows are added in a fixed order and later referenced by hard-coded integer
    ids (``org_id=1``, ``root_user_id=3``, ``db.get(Group, 1)`` ...).
    NOTE(review): this presumes the primary keys are assigned 1, 2, 3 ... in
    insertion order on a freshly created schema -- confirm against the models.
    """
    # Three users; all share the same last name.
    for email, first_name, oidc_id in (
        ("admin@test.com", "Admin", "abcd-efgh-ijkl-mnop"),
        ("user@orgone.com", "User", "abcd-efgh-ijkl-qwer"),
        ("root@orgtwo.com", "Root", "abcd-efgh-ijkl-hjkl"),
    ):
        db.add(
            User(
                email=email,
                first_name=first_name,
                last_name="Test",
                oidc_id=oidc_id,
            )
        )

    # Billing, owner and security contacts for each of the three orgs, added
    # org by org so the organisations below can reference them as 1..9.
    for org_id, domain in enumerate(("orgone", "orgtwo", "orgthree"), start=1):
        for role in ("billing", "owner", "security"):
            db.add(
                Contact(
                    org_id=org_id,
                    email=f"{role}@{domain}.com",
                    phonenumber="07521539927",
                )
            )
    db.flush()

    # (name, root user id, (billing, owner, security) contact ids, status)
    org_rows = (
        ("Org One", 1, (1, 2, 3), "approved"),
        ("Org Two", 3, (4, 5, 6), "approved"),
        ("Org Three", 1, (7, 8, 9), "partial"),
    )
    for name, root_user_id, contact_ids, status in org_rows:
        billing_id, owner_id, security_id = contact_ids
        db.add(
            Org(
                name=name,
                root_user_id=root_user_id,
                billing_contact_id=billing_id,
                owner_contact_id=owner_id,
                security_contact_id=security_id,
                status=status,
                # Built inside the loop so every org gets its own dict object.
                intake_questionnaire={
                    "metadata": {"version": 0, "submission_date": None},
                    "questions": {"question_two": "answer two"},
                },
            )
        )

    db.add(OrgUsers(org_id=1, user_id=2))
    db.add(Service(name="Test Service", api_key="123456789"))
    for action in ("read", "move"):
        db.add(Permission(service_id=1, resource="test_resource", action=action))
    for group_name, group_org_id in (
        ("Org One Group", 1),
        ("Org Two Group", 2),
        ("Org One Group Two", 1),
    ):
        db.add(Group(name=group_name, org_id=group_org_id))
    db.flush()

    # Wire up relationships through the ORM: the first group receives the
    # first permission, and the first org receives the first user and group.
    first_group = db.get(Group, 1)
    first_permission = db.get(Permission, 1)
    first_group.permission_rel.append(first_permission)
    first_user = db.get(User, 1)
    first_org = db.get(Org, 1)
    first_org.user_rel.append(first_user)
    first_org.group_rel.append(first_group)
    db.flush()

    # The user joins the group only after the org links above are flushed.
    first_group.user_rel.append(first_user)
    db.commit()
def generate_query_and_status(params) -> list[tuple[str, int]]:
    """Build invalid query strings paired with the status each should produce.

    Two families of cases are generated, in this order:

    1. Missing parameters: every proper subset of ``params`` (including the
       empty query), each present parameter set to ``1``. Expected status 422.
    2. Invalid values: the complete query with exactly one parameter replaced
       by each of ``0``, ``-1``, ``42``, ``"banana"`` and ``""``. The value
       ``42`` stands for an id that does not exist and expects 404; every
       other value expects 422.

    Queries are assembled from (name, value) pairs rather than by substring
    replacement on a finished string. Replacement corrupted the query when one
    parameter name was a suffix of another (``id`` / ``org_id``) and left a
    stray ``&&`` in front of every parameter but the first.

    Args:
        params: Iterable of query parameter names.

    Returns:
        A list of ``(query_string, expected_status)`` tuples. Query strings
        carry no leading ``?`` or ``&``.
    """
    invalid_values = [0, -1, 42, "banana", ""]
    names = list(params)
    query_and_status: list[tuple[str, int]] = []

    # Missing params: sizes 0 .. len - 1, i.e. everything but the full query.
    for size in range(len(names)):
        for subset in combinations(names, size):
            query = "&".join(f"{name}=1" for name in subset)
            query_and_status.append((query, 422))

    # Each param in turn takes every invalid value while the others stay at 1.
    # Positions (not names) select the target so duplicate names stay distinct.
    for position in range(len(names)):
        for value in invalid_values:
            query = "&".join(
                f"{name}={value if index == position else 1}"
                for index, name in enumerate(names)
            )
            # ID 42 is used to represent a non-existent entry. So it should 404.
            status = 404 if value == 42 else 422
            query_and_status.append((query, status))

    return query_and_status
def generate_body_and_status(params: dict[str, str]) -> list[tuple[dict, int]]:
    """Build invalid request bodies paired with the status each should produce.

    ``params`` maps each field name to ``"int"`` or ``"str"``. The result
    contains, in order:

    1. Every proper subset of the fields (including the empty body), with
       ``"str"`` fields set to ``"valid string"`` and all others set to ``1``.
    2. The complete body with one field at a time replaced by each probe value
       for its declared type.

    A body containing the value ``42`` is expected to yield 404; every other
    body is expected to yield 422.

    Raises:
        TypeError: If a field's declared type is neither ``"int"`` nor ``"str"``.
    """
    int_probes = [0, -1, 42, "banana", ""]
    str_probes = [0, "", "a"]

    # The fully valid body; it is only a template and is never emitted itself.
    valid_body = {
        name: ("valid string" if kind == "str" else 1)
        for name, kind in params.items()
    }
    names = list(valid_body)

    # Missing fields: subsets of size 0 .. len - 1.
    bodies = [
        {name: valid_body[name] for name in subset}
        for size in range(len(names))
        for subset in combinations(names, size)
    ]

    # Invalid values: overwrite one field at a time in a copy of the template.
    for name, kind in params.items():
        if kind == "str":
            probes = str_probes
        elif kind == "int":
            probes = int_probes
        else:
            raise TypeError(f"Unknown type {kind}")
        bodies.extend({**valid_body, name: probe} for probe in probes)

    # ID 42 is used to represent a non-existent entry. So it should 404.
    return [(body, 404 if 42 in body.values() else 422) for body in bodies]
def get_testable_routes():
    """Collect one entry per (HTTP method, API route) registered on the app.

    Routes that are not ``APIRoute`` instances are skipped, as are the
    ``HEAD`` and ``OPTIONS`` methods.

    ``route.methods`` is a set, so its iteration order can differ between
    interpreter runs. Methods are sorted so that the returned list -- and any
    test parametrization built from it -- has a stable order.

    Returns:
        A list of ``(method, path, status_code, response_model, summary)``
        tuples, in route registration order and then method name order.
    """
    skipped_methods = {"HEAD", "OPTIONS"}
    routes = []
    for route in app.routes:
        if not isinstance(route, APIRoute):
            continue
        for method in sorted(route.methods):
            if method in skipped_methods:
                continue
            routes.append(
                (
                    method,
                    route.path,
                    route.status_code,
                    route.response_model,
                    route.summary,
                )
            )
    return routes
# with open("endpoints.txt", "w") as f:
# for ep in get_testable_routes():
# f.write(f"[{ep[0]}]({ep[1]}) -> {ep[2]}: {ep[3]}\n")
#
#
### Docstring formatted output ###
# with open("endpoints.txt", "w") as f:
# for ep in get_testable_routes():
# f.write(f"- [{ep[0]}]({ep[1]}): []: {ep[4]}\n")