Compare commits

...

6 commits

Author SHA1 Message Date
b8b6e6c7ee test: services router 2026-05-29 16:55:21 +01:00
c4b4000d62 fix: user response groups optional 2026-05-29 16:55:05 +01:00
7e89ab0afd test: broader db seeding 2026-05-29 16:54:37 +01:00
a2e18300b9 fix: post service extra db commit 2026-05-29 16:38:16 +01:00
79f8104f2f tests: user router 2026-05-29 15:18:19 +01:00
19145271ae tests: test init 2026-05-29 15:18:10 +01:00
9 changed files with 241 additions and 6 deletions

View file

@ -6,7 +6,7 @@ Exports:
- Base (sqlalchemy base model)
"""
from typing import Annotated
from sqlalchemy import create_engine
from sqlalchemy import create_engine, StaticPool
from sqlalchemy.orm import DeclarativeBase, sessionmaker, Session
from fastapi import Depends
@ -16,10 +16,10 @@ from src.config import SQLALCHEMY_DATABASE_URI, settings as global_settings
if global_settings.ENVIRONMENT == Environment.TESTING:
connect_args = {"check_same_thread": False}
engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value(), connect_args=connect_args, poolclass=StaticPool)
else:
connect_args = {}
engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value())
engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value(), connect_args=connect_args)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

View file

@ -67,7 +67,7 @@ app.add_middleware(
allow_headers=settings.CORS_HEADERS,
)
if settings.DISABLE_AUTH and (settings.ENVIRONMENT == Environment.LOCAL or settings.ENVIRONMENT == Environment.TESTING):
if settings.DISABLE_AUTH and (settings.ENVIRONMENT == Environment.LOCAL):
app.dependency_overrides[get_current_user] = get_dev_user

View file

@ -66,7 +66,6 @@ async def register_service(db: db_dependency, su: super_admin_dependency, reques
except IntegrityError as e:
if isinstance(e.orig, UniqueViolation):
raise ConflictException(message="Service with this name already exists")
db.commit()
response = ServiceWithKeySchema(**service_model.__dict__)
db.commit()
return {"service": response}

View file

@ -46,7 +46,7 @@ class UserResponse(CustomBaseModel):
last_name: str
email: str
organisations: list[Optional[str]]
groups: dict[str, list[str]]
groups: Optional[dict[str, list[str]]] = None
class OrgResponse(CustomBaseModel):

0
test/__init__.py Normal file
View file

85
test/conftest.py Normal file
View file

@ -0,0 +1,85 @@
from typing import AsyncGenerator
import pytest
from httpx import AsyncClient, ASGITransport
from sqlalchemy.orm import sessionmaker
from src.user.models import User
from src.service.models import Service
from src.organisation.models import Organisation as Org
from src.contact.models import Contact
from src.iam.models import Group, Permission
from src.auth.service import get_current_user, get_dev_user
from src.main import app # inited FastAPI app
from src.database import engine, Base, get_db
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@pytest.fixture()
def db_session():
db = SessionLocal()
try:
yield db
except:
db.rollback()
raise
finally:
db.close()
@pytest.fixture
async def client(db_session) -> AsyncGenerator[AsyncClient, None]:
def get_db_override():
return db_session
app.dependency_overrides[get_db] = get_db_override
app.dependency_overrides[get_current_user] = get_dev_user
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://localhost:8000/api/v1") as ac:
yield ac
app.dependency_overrides.clear()
@pytest.fixture(scope="session")
def setup_database():
Base.metadata.create_all(bind=engine)
yield
Base.metadata.drop_all(bind=engine)
@pytest.fixture(scope="session")
def seed_db():
db = SessionLocal()
try:
db.add(User(email="admin@test.com", first_name="Admin", last_name="Test", oidc_id="abcd-efgh-ijkl-mnop"))
db.add(Contact(org_id=1))
db.add(Contact(org_id=1))
db.add(Contact(org_id=1))
db.flush()
db.add(Org(name="Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=1, security_contact_id=1,
status="approved", intake_questionnaire="{}"))
db.add(Service(name="Test Service", api_key="123456789"))
db.add(Permission(service_id=1, resource="test_resource", action="read"))
db.add(Group(name="Test Group"))
db.flush()
group_model = db.get(Group, 1)
perm_model = db.get(Permission, 1)
group_model.permission_rel.append(perm_model)
user_model = db.get(User, 1)
org_model = db.get(Org, 1)
org_model.user_rel.append(user_model)
db.flush()
group_model.user_rel.append(user_model)
db.commit()
yield db
finally:
db.close()
@pytest.fixture(scope="session", autouse=True)
def seed_data(setup_database, seed_db):
yield

11
test/test_healthcheck.py Normal file
View file

@ -0,0 +1,11 @@
import pytest
from httpx import AsyncClient
from .conftest import client
@pytest.mark.anyio
async def test_healthcheck(client: AsyncClient):
resp = await client.get("/healthcheck")
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}

90
test/test_service.py Normal file
View file

@ -0,0 +1,90 @@
"""
409 on [POST]/service/ not tested because SQLite throws a different error than Postgres
"""
import pytest
from httpx import AsyncClient
from .conftest import client
@pytest.mark.anyio
async def test_get_services_success(client: AsyncClient):
resp = await client.get("/service/?org_id=1")
data = resp.json()
assert resp.status_code == 200
assert "services" in data
assert data["services"][0]["id"] == 1
assert data["services"][0]["name"] == "Test Service"
@pytest.mark.anyio
@pytest.mark.parametrize(
"query, expected_status",
[
("org_id=2", 404),
("org_id=banana", 422),
("", 422),
],
)
@pytest.mark.anyio
async def test_get_services_failure(client: AsyncClient, query: str, expected_status: int):
resp = await client.get(f"/service/?{query}")
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_post_service_success(client: AsyncClient):
resp = await client.post("/service/", json={"name": "New Test Service"})
data = resp.json()
assert resp.status_code == 200
assert "service" in data
assert data["service"]["name"] == "New Test Service"
assert data["service"]["id"] == 2
assert type(data["service"]["api_key"]) == str
@pytest.mark.anyio
@pytest.mark.parametrize(
"body, expected_status",
[
({"name": 42}, 422),
({}, 422),
],
)
@pytest.mark.anyio
async def test_post_services_failure(client: AsyncClient, body: dict[str, str], expected_status: int):
resp = await client.post("/service/", json=body)
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_patch_service_success(client: AsyncClient):
resp = await client.patch("/service/key", json={"service_id": 1})
data = resp.json()
assert resp.status_code == 200
assert "service" in data
assert data["service"]["name"] == "Test Service"
assert data["service"]["id"] == 1
assert type(data["service"]["api_key"]) == str
@pytest.mark.anyio
@pytest.mark.parametrize(
"body, expected_status",
[
({"service_id": 42}, 404),
({"service_id": "Test Service"}, 422),
({"service_id": ""}, 422),
({}, 422),
],
)
@pytest.mark.anyio
async def test_patch_services_failure(client: AsyncClient, body: dict[str, str], expected_status: int):
resp = await client.patch("/service/key", json=body)
assert resp.status_code == expected_status

50
test/test_user.py Normal file
View file

@ -0,0 +1,50 @@
"""
[GET]/user/self/claims is not tested because it requires OIDC authentication.
[DELETE/user/ is not tested because the testing client cannot attach a body to a delete request.
"""
import pytest
from httpx import AsyncClient
from .conftest import client
@pytest.mark.anyio
async def test_get_self_db(client: AsyncClient):
resp = await client.get("/user/self/db")
data = resp.json()
assert resp.status_code == 200
assert data["first_name"] == "Admin"
assert data["last_name"] == "Test"
assert data["email"] == "admin@test.com"
assert "organisations" in data
assert "groups" in data
@pytest.mark.anyio
async def test_get_user_success(client: AsyncClient):
resp = await client.get("/user/?user_id=1")
data = resp.json()
assert resp.status_code == 200
assert data["first_name"] == "Admin"
assert data["last_name"] == "Test"
assert data["email"] == "admin@test.com"
assert "organisations" in data
assert "groups" in data
@pytest.mark.anyio
@pytest.mark.parametrize(
"query, expected_status",
[
("user_id=1", 200),
("user_id=2", 404),
("user_id=banana", 422),
("", 422),
],
)
async def test_get_user_fail(client: AsyncClient, query: str, expected_status: int):
resp = await client.get(f"/user/?{query}")
assert resp.status_code == expected_status