Compare commits

..

No commits in common. "b8b6e6c7eeb4bbf7844dedbbf1a34c82890c5419" and "1a81be210ac81f506aa7ef0c0a8cb1cdb8f83106" have entirely different histories.

9 changed files with 6 additions and 241 deletions

View file

@ -6,7 +6,7 @@ Exports:
- Base (sqlalchemy base model) - Base (sqlalchemy base model)
""" """
from typing import Annotated from typing import Annotated
from sqlalchemy import create_engine, StaticPool from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, sessionmaker, Session from sqlalchemy.orm import DeclarativeBase, sessionmaker, Session
from fastapi import Depends from fastapi import Depends
@ -16,10 +16,10 @@ from src.config import SQLALCHEMY_DATABASE_URI, settings as global_settings
if global_settings.ENVIRONMENT == Environment.TESTING: if global_settings.ENVIRONMENT == Environment.TESTING:
connect_args = {"check_same_thread": False} connect_args = {"check_same_thread": False}
engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value(), connect_args=connect_args, poolclass=StaticPool)
else: else:
engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value()) connect_args = {}
engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value(), connect_args=connect_args)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

View file

@ -67,7 +67,7 @@ app.add_middleware(
allow_headers=settings.CORS_HEADERS, allow_headers=settings.CORS_HEADERS,
) )
if settings.DISABLE_AUTH and (settings.ENVIRONMENT == Environment.LOCAL): if settings.DISABLE_AUTH and (settings.ENVIRONMENT == Environment.LOCAL or settings.ENVIRONMENT == Environment.TESTING):
app.dependency_overrides[get_current_user] = get_dev_user app.dependency_overrides[get_current_user] = get_dev_user

View file

@ -66,6 +66,7 @@ async def register_service(db: db_dependency, su: super_admin_dependency, reques
except IntegrityError as e: except IntegrityError as e:
if isinstance(e.orig, UniqueViolation): if isinstance(e.orig, UniqueViolation):
raise ConflictException(message="Service with this name already exists") raise ConflictException(message="Service with this name already exists")
db.commit()
response = ServiceWithKeySchema(**service_model.__dict__) response = ServiceWithKeySchema(**service_model.__dict__)
db.commit() db.commit()
return {"service": response} return {"service": response}

View file

@ -46,7 +46,7 @@ class UserResponse(CustomBaseModel):
last_name: str last_name: str
email: str email: str
organisations: list[Optional[str]] organisations: list[Optional[str]]
groups: Optional[dict[str, list[str]]] = None groups: dict[str, list[str]]
class OrgResponse(CustomBaseModel): class OrgResponse(CustomBaseModel):

View file

View file

@ -1,85 +0,0 @@
from typing import AsyncGenerator
import pytest
from httpx import AsyncClient, ASGITransport
from sqlalchemy.orm import sessionmaker
from src.user.models import User
from src.service.models import Service
from src.organisation.models import Organisation as Org
from src.contact.models import Contact
from src.iam.models import Group, Permission
from src.auth.service import get_current_user, get_dev_user
from src.main import app # inited FastAPI app
from src.database import engine, Base, get_db
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@pytest.fixture()
def db_session():
db = SessionLocal()
try:
yield db
except:
db.rollback()
raise
finally:
db.close()
@pytest.fixture
async def client(db_session) -> AsyncGenerator[AsyncClient, None]:
def get_db_override():
return db_session
app.dependency_overrides[get_db] = get_db_override
app.dependency_overrides[get_current_user] = get_dev_user
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://localhost:8000/api/v1") as ac:
yield ac
app.dependency_overrides.clear()
@pytest.fixture(scope="session")
def setup_database():
Base.metadata.create_all(bind=engine)
yield
Base.metadata.drop_all(bind=engine)
@pytest.fixture(scope="session")
def seed_db():
db = SessionLocal()
try:
db.add(User(email="admin@test.com", first_name="Admin", last_name="Test", oidc_id="abcd-efgh-ijkl-mnop"))
db.add(Contact(org_id=1))
db.add(Contact(org_id=1))
db.add(Contact(org_id=1))
db.flush()
db.add(Org(name="Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=1, security_contact_id=1,
status="approved", intake_questionnaire="{}"))
db.add(Service(name="Test Service", api_key="123456789"))
db.add(Permission(service_id=1, resource="test_resource", action="read"))
db.add(Group(name="Test Group"))
db.flush()
group_model = db.get(Group, 1)
perm_model = db.get(Permission, 1)
group_model.permission_rel.append(perm_model)
user_model = db.get(User, 1)
org_model = db.get(Org, 1)
org_model.user_rel.append(user_model)
db.flush()
group_model.user_rel.append(user_model)
db.commit()
yield db
finally:
db.close()
@pytest.fixture(scope="session", autouse=True)
def seed_data(setup_database, seed_db):
yield

View file

@ -1,11 +0,0 @@
import pytest
from httpx import AsyncClient
from .conftest import client
@pytest.mark.anyio
async def test_healthcheck(client: AsyncClient):
resp = await client.get("/healthcheck")
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}

View file

@ -1,90 +0,0 @@
"""
409 on [POST]/service/ not tested because SQLite throws a different error than Postgres
"""
import pytest
from httpx import AsyncClient
from .conftest import client
@pytest.mark.anyio
async def test_get_services_success(client: AsyncClient):
resp = await client.get("/service/?org_id=1")
data = resp.json()
assert resp.status_code == 200
assert "services" in data
assert data["services"][0]["id"] == 1
assert data["services"][0]["name"] == "Test Service"
@pytest.mark.anyio
@pytest.mark.parametrize(
"query, expected_status",
[
("org_id=2", 404),
("org_id=banana", 422),
("", 422),
],
)
@pytest.mark.anyio
async def test_get_services_failure(client: AsyncClient, query: str, expected_status: int):
resp = await client.get(f"/service/?{query}")
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_post_service_success(client: AsyncClient):
resp = await client.post("/service/", json={"name": "New Test Service"})
data = resp.json()
assert resp.status_code == 200
assert "service" in data
assert data["service"]["name"] == "New Test Service"
assert data["service"]["id"] == 2
assert type(data["service"]["api_key"]) == str
@pytest.mark.anyio
@pytest.mark.parametrize(
"body, expected_status",
[
({"name": 42}, 422),
({}, 422),
],
)
@pytest.mark.anyio
async def test_post_services_failure(client: AsyncClient, body: dict[str, str], expected_status: int):
resp = await client.post("/service/", json=body)
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_patch_service_success(client: AsyncClient):
resp = await client.patch("/service/key", json={"service_id": 1})
data = resp.json()
assert resp.status_code == 200
assert "service" in data
assert data["service"]["name"] == "Test Service"
assert data["service"]["id"] == 1
assert type(data["service"]["api_key"]) == str
@pytest.mark.anyio
@pytest.mark.parametrize(
"body, expected_status",
[
({"service_id": 42}, 404),
({"service_id": "Test Service"}, 422),
({"service_id": ""}, 422),
({}, 422),
],
)
@pytest.mark.anyio
async def test_patch_services_failure(client: AsyncClient, body: dict[str, str], expected_status: int):
resp = await client.patch("/service/key", json=body)
assert resp.status_code == expected_status

View file

@ -1,50 +0,0 @@
"""
[GET]/user/self/claims is not tested because it requires OIDC authentication.
[DELETE/user/ is not tested because the testing client cannot attach a body to a delete request.
"""
import pytest
from httpx import AsyncClient
from .conftest import client
@pytest.mark.anyio
async def test_get_self_db(client: AsyncClient):
resp = await client.get("/user/self/db")
data = resp.json()
assert resp.status_code == 200
assert data["first_name"] == "Admin"
assert data["last_name"] == "Test"
assert data["email"] == "admin@test.com"
assert "organisations" in data
assert "groups" in data
@pytest.mark.anyio
async def test_get_user_success(client: AsyncClient):
resp = await client.get("/user/?user_id=1")
data = resp.json()
assert resp.status_code == 200
assert data["first_name"] == "Admin"
assert data["last_name"] == "Test"
assert data["email"] == "admin@test.com"
assert "organisations" in data
assert "groups" in data
@pytest.mark.anyio
@pytest.mark.parametrize(
"query, expected_status",
[
("user_id=1", 200),
("user_id=2", 404),
("user_id=banana", 422),
("", 422),
],
)
async def test_get_user_fail(client: AsyncClient, query: str, expected_status: int):
resp = await client.get(f"/user/?{query}")
assert resp.status_code == expected_status