diff --git a/src/database.py b/src/database.py index a56f80d..673af86 100644 --- a/src/database.py +++ b/src/database.py @@ -6,7 +6,7 @@ Exports: - Base (sqlalchemy base model) """ from typing import Annotated -from sqlalchemy import create_engine, StaticPool +from sqlalchemy import create_engine from sqlalchemy.orm import DeclarativeBase, sessionmaker, Session from fastapi import Depends @@ -16,10 +16,10 @@ from src.config import SQLALCHEMY_DATABASE_URI, settings as global_settings if global_settings.ENVIRONMENT == Environment.TESTING: connect_args = {"check_same_thread": False} - engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value(), connect_args=connect_args, poolclass=StaticPool) else: - engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value()) + connect_args = {} +engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value(), connect_args=connect_args) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) diff --git a/src/main.py b/src/main.py index c421a94..d3fc945 100644 --- a/src/main.py +++ b/src/main.py @@ -67,7 +67,7 @@ app.add_middleware( allow_headers=settings.CORS_HEADERS, ) -if settings.DISABLE_AUTH and (settings.ENVIRONMENT == Environment.LOCAL): +if settings.DISABLE_AUTH and (settings.ENVIRONMENT == Environment.LOCAL or settings.ENVIRONMENT == Environment.TESTING): app.dependency_overrides[get_current_user] = get_dev_user diff --git a/src/service/router.py b/src/service/router.py index a8f93ea..d1a6a41 100644 --- a/src/service/router.py +++ b/src/service/router.py @@ -66,6 +66,7 @@ async def register_service(db: db_dependency, su: super_admin_dependency, reques except IntegrityError as e: if isinstance(e.orig, UniqueViolation): raise ConflictException(message="Service with this name already exists") + db.commit() response = ServiceWithKeySchema(**service_model.__dict__) db.commit() return {"service": response} diff --git a/src/user/schemas.py b/src/user/schemas.py index b688412..27d455e 100644 --- a/src/user/schemas.py +++ b/src/user/schemas.py @@ -46,7 +46,7 @@ class UserResponse(CustomBaseModel): last_name: str email: str organisations: list[Optional[str]] - groups: Optional[dict[str, list[str]]] = None + groups: dict[str, list[str]] class OrgResponse(CustomBaseModel): diff --git a/test/__init__.py b/test/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/test/conftest.py b/test/conftest.py deleted file mode 100644 index b6fde60..0000000 --- a/test/conftest.py +++ /dev/null @@ -1,85 +0,0 @@ -from typing import AsyncGenerator - -import pytest -from httpx import AsyncClient, ASGITransport -from sqlalchemy.orm import sessionmaker - -from src.user.models import User -from src.service.models import Service -from src.organisation.models import Organisation as Org -from src.contact.models import Contact -from src.iam.models import Group, Permission -from src.auth.service import get_current_user, get_dev_user - -from src.main import app # inited FastAPI app -from src.database import engine, Base, get_db - - -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - - -@pytest.fixture() -def db_session(): - db = SessionLocal() - try: - yield db - except: - db.rollback() - raise - finally: - db.close() - - -@pytest.fixture -async def client(db_session) -> AsyncGenerator[AsyncClient, None]: - def get_db_override(): - return db_session - app.dependency_overrides[get_db] = get_db_override - app.dependency_overrides[get_current_user] = get_dev_user - transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://localhost:8000/api/v1") as ac: - yield ac - - app.dependency_overrides.clear() - - -@pytest.fixture(scope="session") -def setup_database(): - Base.metadata.create_all(bind=engine) - yield - Base.metadata.drop_all(bind=engine) - - -@pytest.fixture(scope="session") -def seed_db(): - db = SessionLocal() - try: - db.add(User(email="admin@test.com", first_name="Admin", last_name="Test", oidc_id="abcd-efgh-ijkl-mnop")) - db.add(Contact(org_id=1)) - db.add(Contact(org_id=1)) - db.add(Contact(org_id=1)) - db.flush() - db.add(Org(name="Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=1, security_contact_id=1, - status="approved", intake_questionnaire="{}")) - db.add(Service(name="Test Service", api_key="123456789")) - db.add(Permission(service_id=1, resource="test_resource", action="read")) - db.add(Group(name="Test Group")) - db.flush() - group_model = db.get(Group, 1) - perm_model = db.get(Permission, 1) - group_model.permission_rel.append(perm_model) - user_model = db.get(User, 1) - org_model = db.get(Org, 1) - org_model.user_rel.append(user_model) - db.flush() - group_model.user_rel.append(user_model) - db.commit() - yield db - finally: - db.close() - - -@pytest.fixture(scope="session", autouse=True) -def seed_data(setup_database, seed_db): - yield - \ No newline at end of file diff --git a/test/test_healthcheck.py b/test/test_healthcheck.py deleted file mode 100644 index 8b18e34..0000000 --- a/test/test_healthcheck.py +++ /dev/null @@ -1,11 +0,0 @@ -import pytest -from httpx import AsyncClient - -from .conftest import client - -@pytest.mark.anyio -async def test_healthcheck(client: AsyncClient): - resp = await client.get("/healthcheck") - - assert resp.status_code == 200 - assert resp.json() == {"status": "ok"} diff --git a/test/test_service.py b/test/test_service.py deleted file mode 100644 index 4b5c7d2..0000000 --- a/test/test_service.py +++ /dev/null @@ -1,90 +0,0 @@ -""" -409 on [POST]/service/ not tested because SQLite throws a different error than Postgres -""" -import pytest -from httpx import AsyncClient - -from .conftest import client - - -@pytest.mark.anyio -async def test_get_services_success(client: AsyncClient): - resp = await client.get("/service/?org_id=1") - data = resp.json() - - assert resp.status_code == 200 - assert "services" in data - assert data["services"][0]["id"] == 1 - assert data["services"][0]["name"] == "Test Service" - - -@pytest.mark.anyio -@pytest.mark.parametrize( - "query, expected_status", - [ - ("org_id=2", 404), - ("org_id=banana", 422), - ("", 422), - ], -) -@pytest.mark.anyio -async def test_get_services_failure(client: AsyncClient, query: str, expected_status: int): - resp = await client.get(f"/service/?{query}") - - assert resp.status_code == expected_status - - -@pytest.mark.anyio -async def test_post_service_success(client: AsyncClient): - resp = await client.post("/service/", json={"name": "New Test Service"}) - data = resp.json() - - assert resp.status_code == 200 - assert "service" in data - assert data["service"]["name"] == "New Test Service" - assert data["service"]["id"] == 2 - assert type(data["service"]["api_key"]) == str - - -@pytest.mark.anyio -@pytest.mark.parametrize( - "body, expected_status", - [ - ({"name": 42}, 422), - ({}, 422), - ], -) -@pytest.mark.anyio -async def test_post_services_failure(client: AsyncClient, body: dict[str, str], expected_status: int): - resp = await client.post("/service/", json=body) - - assert resp.status_code == expected_status - - -@pytest.mark.anyio -async def test_patch_service_success(client: AsyncClient): - resp = await client.patch("/service/key", json={"service_id": 1}) - data = resp.json() - - assert resp.status_code == 200 - assert "service" in data - assert data["service"]["name"] == "Test Service" - assert data["service"]["id"] == 1 - assert type(data["service"]["api_key"]) == str - - -@pytest.mark.anyio -@pytest.mark.parametrize( - "body, expected_status", - [ - ({"service_id": 42}, 404), - ({"service_id": "Test Service"}, 422), - ({"service_id": ""}, 422), - ({}, 422), - ], -) -@pytest.mark.anyio -async def test_patch_services_failure(client: AsyncClient, body: dict[str, str], expected_status: int): - resp = await client.patch("/service/key", json=body) - - assert resp.status_code == expected_status diff --git a/test/test_user.py b/test/test_user.py deleted file mode 100644 index 7563c00..0000000 --- a/test/test_user.py +++ /dev/null @@ -1,50 +0,0 @@ -""" -[GET]/user/self/claims is not tested because it requires OIDC authentication. -[DELETE/user/ is not tested because the testing client cannot attach a body to a delete request. -""" - -import pytest -from httpx import AsyncClient - -from .conftest import client - -@pytest.mark.anyio -async def test_get_self_db(client: AsyncClient): - resp = await client.get("/user/self/db") - data = resp.json() - - assert resp.status_code == 200 - assert data["first_name"] == "Admin" - assert data["last_name"] == "Test" - assert data["email"] == "admin@test.com" - assert "organisations" in data - assert "groups" in data - - -@pytest.mark.anyio -async def test_get_user_success(client: AsyncClient): - resp = await client.get("/user/?user_id=1") - data = resp.json() - - assert resp.status_code == 200 - assert data["first_name"] == "Admin" - assert data["last_name"] == "Test" - assert data["email"] == "admin@test.com" - assert "organisations" in data - assert "groups" in data - - -@pytest.mark.anyio -@pytest.mark.parametrize( - "query, expected_status", - [ - ("user_id=1", 200), - ("user_id=2", 404), - ("user_id=banana", 422), - ("", 422), - ], -) -async def test_get_user_fail(client: AsyncClient, query: str, expected_status: int): - resp = await client.get(f"/user/?{query}") - - assert resp.status_code == expected_status