Compare commits

..

5 commits

Author SHA1 Message Date
d3d9316741 tests: query generator 2026-06-05 12:17:32 +01:00
29245e5c13 tests: questionnaire submission test
Assertion to verify that None type question answers don't overwrite preexisting answers.
2026-06-05 09:31:56 +01:00
b8b39188f6 tests/minor: rename test functions
`_failure` in name wasn't necessarily an accurate descriptor for parameterized testing with expected statuses. `_status_checks` used instead.
2026-06-05 09:22:44 +01:00
f600664789 tests: improved coverage 2026-06-05 09:10:55 +01:00
c8024daa97 minor: renames and error messages 2026-06-04 14:53:35 +01:00
10 changed files with 253 additions and 106 deletions

View file

@ -75,20 +75,20 @@ async def can_act_on_resource(valid_key: service_key_dependency, db: db_dependen
else: else:
return False return False
except Exception: except Exception:
raise UnauthorizedException() return False
@router.get("/group/permissions", response_model=IAMGetGroupPermissionsResponse) @router.get("/group/permissions", response_model=IAMGetGroupPermissionsResponse)
async def get_group_permissions(group_model: group_model_query_dependency, org_model: org_model_root_claim_query_dependency): async def get_group_permissions(group_model: group_model_query_dependency, org_model: org_model_root_claim_query_dependency):
if group_model.org_id != org_model.id: if group_model.org_id != org_model.id:
raise UnauthorizedException() raise UnauthorizedException("Group does not belong to this organization")
return {"permissions": group_model.permission_rel} return {"permissions": group_model.permission_rel}
@router.get("/group/users", response_model=IAMGetGroupUsersResponse) @router.get("/group/users", response_model=IAMGetGroupUsersResponse)
async def get_group_users(group_model: group_model_query_dependency, org_model: org_model_root_claim_query_dependency): async def get_group_users(group_model: group_model_query_dependency, org_model: org_model_root_claim_query_dependency):
if group_model.org_id != org_model.id: if group_model.org_id != org_model.id:
raise UnauthorizedException() raise UnauthorizedException("Group does not belong to this organization")
return {"users": group_model.user_rel} return {"users": group_model.user_rel}
@ -110,7 +110,7 @@ async def create_group(db: db_dependency, org_model: org_model_root_claim_body_d
@router.put("/group/permission", response_model=IAMPutGroupPermissionResponse) @router.put("/group/permission", response_model=IAMPutGroupPermissionResponse)
async def add_group_permission(db: db_dependency, group_model: group_model_body_dependency, perm_model: perm_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMPutGroupPermissionRequest): async def add_group_permission(db: db_dependency, group_model: group_model_body_dependency, perm_model: perm_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMPutGroupPermissionRequest):
if group_model.org_id != org_model.id: if group_model.org_id != org_model.id:
raise UnauthorizedException() raise UnauthorizedException("Group does not belong to this organization")
if perm_model in group_model.permission_rel: if perm_model in group_model.permission_rel:
raise ConflictException("Group already has this permission") raise ConflictException("Group already has this permission")
@ -126,7 +126,7 @@ async def add_group_permission(db: db_dependency, group_model: group_model_body_
@router.put("/group/user", response_model=IAMPutGroupUserResponse) @router.put("/group/user", response_model=IAMPutGroupUserResponse)
async def add_group_user(db: db_dependency, group_model: group_model_body_dependency, user_model: user_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMPutGroupUserRequest): async def add_group_user(db: db_dependency, group_model: group_model_body_dependency, user_model: user_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMPutGroupUserRequest):
if group_model.org_id != org_model.id: if group_model.org_id != org_model.id:
raise UnauthorizedException() raise UnauthorizedException("Group does not belong to this organization")
if user_model in group_model.user_rel: if user_model in group_model.user_rel:
raise ConflictException("User already in group") raise ConflictException("User already in group")
@ -141,7 +141,7 @@ async def add_group_user(db: db_dependency, group_model: group_model_body_depend
@router.delete("/group/permissions") @router.delete("/group/permissions")
async def remove_group_permissions(db: db_dependency, group_model: group_model_body_dependency, perm_model: perm_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMDeleteGroupPermissionRequest): async def remove_group_permissions(db: db_dependency, group_model: group_model_body_dependency, perm_model: perm_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMDeleteGroupPermissionRequest):
if group_model.org_id != org_model.id: if group_model.org_id != org_model.id:
raise UnauthorizedException() raise UnauthorizedException("Group does not belong to this organization")
group_model.permission_rel.remove(perm_model) group_model.permission_rel.remove(perm_model)
db.flush() db.flush()
@ -154,7 +154,7 @@ async def remove_group_permissions(db: db_dependency, group_model: group_model_b
@router.delete("/group/user") @router.delete("/group/user")
async def remove_group_user(db: db_dependency, group_model: group_model_body_dependency, user_model: user_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMDeleteGroupUserRequest): async def remove_group_user(db: db_dependency, group_model: group_model_body_dependency, user_model: user_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMDeleteGroupUserRequest):
if group_model.org_id != org_model.id: if group_model.org_id != org_model.id:
raise UnauthorizedException() raise UnauthorizedException("Group does not belong to this organization")
user_model.group_rel.remove(group_model) user_model.group_rel.remove(group_model)
db.flush() db.flush()

View file

@ -9,6 +9,7 @@ from typing import Annotated
from src.service.models import Service from src.service.models import Service
from src.database import db_dependency from src.database import db_dependency
from src.schemas import ResourceName from src.schemas import ResourceName
from src.auth.exceptions import UnauthorizedException
from fastapi import HTTPException, status, Request, Depends from fastapi import HTTPException, status, Request, Depends
@ -16,11 +17,11 @@ from fastapi import HTTPException, status, Request, Depends
def valid_service_key(db: db_dependency, request: Request, rn: ResourceName) -> bool: def valid_service_key(db: db_dependency, request: Request, rn: ResourceName) -> bool:
api_key = request.headers.get("X-API-Key", None) api_key = request.headers.get("X-API-Key", None)
if not api_key: if not api_key:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise UnauthorizedException("Missing API key")
service = rn.service service = rn.service
result = db.query(Service).filter(Service.name == service).filter(Service.api_key == api_key).first() result = db.query(Service).filter(Service.name == service).filter(Service.api_key == api_key).first()
if result is None: if result is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) raise UnauthorizedException("Invalid API key")
return True return True

View file

@ -137,8 +137,6 @@ async def update_questionnaire(db: db_dependency, org_model: org_model_root_clai
if hasattr(questionnaire_model, key): if hasattr(questionnaire_model, key):
setattr(questionnaire_model, key, value) setattr(questionnaire_model, key, value)
else: else:
if key == "partial" or key == "organisation_id":
continue
raise UnprocessableContentException("Invalid keys in update request") raise UnprocessableContentException("Invalid keys in update request")
# Allows for partially completed questionnaires to be saved without being submitted for review # Allows for partially completed questionnaires to be saved without being submitted for review
@ -241,7 +239,7 @@ async def update_root_user(db: db_dependency, org_model: org_model_body_dependen
Promotes an existing organisation user to the root user, giving them full control of the org. Promotes an existing organisation user to the root user, giving them full control of the org.
""" """
if user_model not in org_model.user_rel: if user_model not in org_model.user_rel:
raise UnauthorizedException(message="This user does not belong to your organisation.") raise UnprocessableContentException(message="This user does not belong to your organisation.")
org_model.root_user_rel = user_model org_model.root_user_rel = user_model
db.flush() db.flush()
response = OrgPatchRootResponse(name=org_model.name, root_user_email=org_model.root_user_email) response = OrgPatchRootResponse(name=org_model.name, root_user_email=org_model.root_user_email)

View file

@ -1,4 +1,5 @@
from typing import AsyncGenerator from typing import AsyncGenerator
from itertools import combinations
import pytest import pytest
from httpx import AsyncClient, ASGITransport from httpx import AsyncClient, ASGITransport
@ -95,6 +96,40 @@ def _seed(db):
group_model.user_rel.append(user_model) group_model.user_rel.append(user_model)
db.commit() db.commit()
def generate_query_and_status(params) -> list[tuple[str, int]]:
possible_values = [0, -1, 42, "banana", ""]
defaults = [f"{param}=1" for param in params]
# Missing params
query_list = [
"&".join(combo)
for r in range(len(defaults) + 1)
for combo in combinations(defaults, r)
]
# Complete query as default for invalid checks
default_query = query_list.pop(-1)
# Checks for each param being invalid
for param in params:
for value in possible_values:
new_value = f"&{param}={value}"
query_list.append(default_query.replace(f"{param}=1", new_value))
query_and_status = []
# Assign expected status
for query in query_list:
# ID 42 is used to represent a non-existent entry. So it should 404.
status = 404 if "42" in query else 422
# Remove leading "&" if present
query = query if len(query) > 1 and query[0] != "&" else query[1:]
query_and_status.append((query, status))
return query_and_status
# # Produces a text file with method and path for every endpoint in the API # # Produces a text file with method and path for every endpoint in the API
# from fastapi.routing import APIRoute # from fastapi.routing import APIRoute
# #

26
test/test_auth_general.py Normal file
View file

@ -0,0 +1,26 @@
"""
"""
import pytest
from httpx import AsyncClient
from .conftest import no_su_client
from src.organisation.models import Organisation as Org
from src.user.models import User
from src.iam.models import Group
@pytest.mark.anyio
async def test_get_org_auth_root_su(default_client: AsyncClient, db_session):
# If a super admin can access a resource when not the root user
db_session.add(User(email="admin@test.org", first_name="Admin", last_name="Test", oidc_id="abcd-efgh-ijkl-4321"))
db_session.flush()
db_session.add(
Org(name="Test Org Two", root_user_id=2, billing_contact_id=1, owner_contact_id=2, security_contact_id=3,
status="approved", intake_questionnaire={}))
db_session.flush()
resp = await default_client.get("/org?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 200
assert resp.json()["name"] == "Test Org Two"

View file

@ -8,7 +8,7 @@ from .conftest import no_user_client
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_self_db(no_user_client: AsyncClient): async def test_get_self_db_auth_user(no_user_client: AsyncClient):
resp = await no_user_client.get("/user/self/db") resp = await no_user_client.get("/user/self/db")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 401 assert resp.status_code == 401
@ -16,7 +16,7 @@ async def test_get_self_db(no_user_client: AsyncClient):
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_success(no_user_client: AsyncClient): async def test_post_org_success_auth_user(no_user_client: AsyncClient):
resp = await no_user_client.post("/org", json={"name": "New Test Org"}) resp = await no_user_client.post("/org", json={"name": "New Test Org"})
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 401 assert resp.status_code == 401

View file

@ -1,14 +1,14 @@
""" """
Act on resource tests only check for pass/fail on input validation. Logic is not tested.
""" """
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
from src.user.models import User from src.user.models import User
from .conftest import default_client, db_session from src.organisation.models import Organisation as Org
from src.iam.models import Group from src.iam.models import Group
from .conftest import default_client, db_session, generate_query_and_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_act_on_resource_endpoint_success(default_client: AsyncClient): async def test_post_act_on_resource_endpoint_success(default_client: AsyncClient):
@ -28,7 +28,48 @@ async def test_post_act_on_resource_endpoint_success(default_client: AsyncClient
assert data == True assert data == True
@pytest.mark.parametrize(
"service, api_key",
[
("Test Service", "not_the_correct_key"),
("Test Service Two", "123456789")
],
)
@pytest.mark.anyio @pytest.mark.anyio
async def test_act_on_resource_wrong_key(default_client: AsyncClient, db_session, service: str, api_key: str):
body = {
"service": service,
"organisation": "Test Org",
"resource": "test_resource"
}
headers = {
"Authorization": "Bearer not_checked_when_auth_is_disabled",
"X-API-Key": api_key
}
resp = await default_client.post("/iam/can_act_on_resource?action=read", json=body, headers=headers)
data = resp.json()
assert resp.status_code == 401
assert data["detail"] == "Invalid API key"
@pytest.mark.anyio
async def test_act_on_resource_missing_key(default_client: AsyncClient):
body = {
"service": "Test Service",
"organisation": "Test Org",
"resource": "test_resource"
}
headers = {
"Authorization": "Bearer not_checked_when_auth_is_disabled"
}
resp = await default_client.post("/iam/can_act_on_resource?action=read", json=body, headers=headers)
data = resp.json()
assert resp.status_code == 401
assert data["detail"] == "Missing API key"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"service, org, resource, action, expected_status", "service, org, resource, action, expected_status",
[ [
@ -41,7 +82,7 @@ async def test_post_act_on_resource_endpoint_success(default_client: AsyncClient
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_act_on_resource_endpoint_failure(default_client: AsyncClient, service, org, resource, action, async def test_act_on_resource_endpoint_status_checks(default_client: AsyncClient, service, org, resource, action,
expected_status: int): expected_status: int):
body = { body = {
"service": service, "service": service,
@ -57,6 +98,34 @@ async def test_post_act_on_resource_endpoint_failure(default_client: AsyncClient
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.parametrize(
"service, org, resource, action, expected_response",
[
("Test Service", "Test Org", "test_resource", "read", True),
("Test Service", "Test Org", "test_resource", "create", False),
("Test Service", "Test Org", "no_access_here", "read", False),
("Test Service", "Test Org Two", "test_resource", "read", False),
],
)
@pytest.mark.anyio
async def test_act_on_resource_logic(default_client: AsyncClient, db_session, service, org, resource, action,
expected_response: bool):
body = {
"service": service,
"organisation": org,
"resource": resource
}
headers = {
"Authorization": "Bearer not_checked_when_auth_is_disabled",
"X-API-Key": "123456789"
}
resp = await default_client.post(f"/iam/can_act_on_resource?action={action}", json=body, headers=headers)
data = resp.json()
assert resp.status_code == 200
assert data == expected_response
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_group_permissions_success(default_client: AsyncClient): async def test_get_group_permissions_success(default_client: AsyncClient):
resp = await default_client.get("/iam/group/permissions?org_id=1&group_id=1") resp = await default_client.get("/iam/group/permissions?org_id=1&group_id=1")
@ -72,28 +141,33 @@ async def test_get_group_permissions_success(default_client: AsyncClient):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"query, expected_status", "query, expected_status",
[ generate_query_and_status(["group_id", "org_id"])
("org_id=2&group_id=1", 404), # Non-exists org, valid group
("org_id=banana&group_id=1", 422), # Invalid org, valid group
("org_id=&group_id=1", 422), # Blank org, valid group
("org_id=-1&group_id=1", 422), # Negative org, valid group
("group_id=1", 422), # Only group
("", 422), # Blank query
("org_id=&group_id=", 422), # Both blank
("org_id=1&group_id=2", 404), # Valid org, non-exists group
("org_id=1&group_id=banana", 422), # Valid org, invalid group
("org_id=1&group_id=", 422), # Valid org, blank group
("org_id=1&group_id=-1", 422), # Valid org, negative group
("org_id=1", 422), # Only org
],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_group_permissions_failure(default_client: AsyncClient, query: str, expected_status: int): async def test_get_group_permissions_status_checks(default_client: AsyncClient, db_session, query: str, expected_status: int):
resp = await default_client.get(f"/iam/group/permissions?{query}") resp = await default_client.get(f"/iam/group/permissions?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.parametrize(
"query",
[
"org_id=1&group_id=2",
"org_id=2&group_id=1",
],
)
@pytest.mark.anyio
async def test_get_group_permissions_mismatch(default_client: AsyncClient, db_session, query: str):
db_session.add(Org(name="Another Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=2, security_contact_id=3, status="approved"))
db_session.add(Group(name="Another Test Group", org_id=2))
db_session.flush()
resp = await default_client.get(f"/iam/group/permissions?{query}")
assert resp.status_code == 401
assert resp.json()["detail"] == "Group does not belong to this organization"
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_group_users_success(default_client: AsyncClient): async def test_get_group_users_success(default_client: AsyncClient):
resp = await default_client.get("/iam/group/users?org_id=1&group_id=1") resp = await default_client.get("/iam/group/users?org_id=1&group_id=1")
@ -110,28 +184,33 @@ async def test_get_group_users_success(default_client: AsyncClient):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"query, expected_status", "query, expected_status",
[ generate_query_and_status(["group_id", "org_id"])
("org_id=2&group_id=1", 404), # Non-exists org, valid group
("org_id=banana&group_id=1", 422), # Invalid org, valid group
("org_id=&group_id=1", 422), # Blank org, valid group
("org_id=-1&group_id=1", 422), # Negative org, valid group
("group_id=1", 422), # Only group
("", 422), # Blank query
("org_id=&group_id=", 422), # Both blank
("org_id=1&group_id=2", 404), # Valid org, non-exists group
("org_id=1&group_id=banana", 422), # Valid org, invalid group
("org_id=1&group_id=", 422), # Valid org, blank group
("org_id=1&group_id=-1", 422), # Valid org, negative group
("org_id=1", 422), # Only org
],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_group_users_failure(default_client: AsyncClient, query: str, expected_status: int): async def test_get_group_users_status_checks(default_client: AsyncClient, query: str, expected_status: int):
resp = await default_client.get(f"/iam/group/users?{query}") resp = await default_client.get(f"/iam/group/users?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.parametrize(
"query",
[
"org_id=1&group_id=2",
"org_id=2&group_id=1",
],
)
@pytest.mark.anyio
async def test_get_group_users_mismatch(default_client: AsyncClient, db_session, query: str):
db_session.add(Org(name="Another Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=2, security_contact_id=3, status="approved"))
db_session.add(Group(name="Another Test Group", org_id=2))
db_session.flush()
resp = await default_client.get(f"/iam/group/users?{query}")
assert resp.status_code == 401
assert resp.json()["detail"] == "Group does not belong to this organization"
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_group_success(default_client: AsyncClient): async def test_post_group_success(default_client: AsyncClient):
resp = await default_client.post("/iam/group", json={"name": "New Group", "organisation_id": 1}) resp = await default_client.post("/iam/group", json={"name": "New Group", "organisation_id": 1})
@ -161,7 +240,7 @@ async def test_post_group_success(default_client: AsyncClient):
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_group_failure(default_client: AsyncClient, body: dict[str, str], expected_status: int): async def test_post_group_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
resp = await default_client.post("/iam/group", json=body) resp = await default_client.post("/iam/group", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -214,15 +293,35 @@ async def test_put_group_perm_success(default_client: AsyncClient, db_session):
({"group_id": 1, "permission_id": 1}, 422), # Missing organisation ({"group_id": 1, "permission_id": 1}, 422), # Missing organisation
({"organisation_id": 1, "group_id": 1}, 422), # Missing permission ({"organisation_id": 1, "group_id": 1}, 422), # Missing permission
({"organisation_id": 1, "group_id": 1, "permission_id": 1}, 409), # Permission already in group
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_group_perm_failure(default_client: AsyncClient, body: dict[str, str], expected_status: int): async def test_put_group_perm_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
resp = await default_client.put("/iam/group/permission", json=body) resp = await default_client.put("/iam/group/permission", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.parametrize(
"body",
[
{"organisation_id": 1, "group_id": 2, "permission_id": 1},
{"organisation_id": 2, "group_id": 1, "permission_id": 1},
],
)
@pytest.mark.anyio
async def test_put_group_perm_mismatch(default_client: AsyncClient, db_session, body: dict):
db_session.add(Org(name="Another Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=2, security_contact_id=3, status="approved"))
db_session.add(Group(name="Another Test Group", org_id=2))
db_session.flush()
resp = await default_client.put(f"/iam/group/permission", json=body)
assert resp.status_code == 401
assert resp.json()["detail"] == "Group does not belong to this organization"
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_group_user_success(default_client: AsyncClient, db_session): async def test_put_group_user_success(default_client: AsyncClient, db_session):
db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
@ -275,7 +374,7 @@ async def test_put_group_user_success(default_client: AsyncClient, db_session):
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_group_user_failure(default_client: AsyncClient, body: dict[str, str], expected_status: int): async def test_put_group_user_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
resp = await default_client.put("/iam/group/user", json=body) resp = await default_client.put("/iam/group/user", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -296,16 +395,10 @@ async def test_get_permissions_success(default_client: AsyncClient):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"query, expected_status", "query, expected_status",
[ generate_query_and_status(["org_id"])
("org_id=42", 404), # Non-exists org
("org_id=banana", 422), # Invalid org
("org_id=", 422), # Blank org
("org_id=-1", 422), # Negative org
("", 422), # Blank query
],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_permissions_failure(default_client: AsyncClient, query: str, expected_status: int): async def test_get_permissions_status_checks(default_client: AsyncClient, query: str, expected_status: int):
resp = await default_client.get(f"/iam/permissions?{query}") resp = await default_client.get(f"/iam/permissions?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -351,7 +444,7 @@ async def test_post_perm_success(default_client: AsyncClient, db_session):
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_perm_failure(default_client: AsyncClient, body: dict[str, str], expected_status: int): async def test_post_perm_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
resp = await default_client.post("/iam/permission", json=body) resp = await default_client.post("/iam/permission", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -408,7 +501,7 @@ async def test_post_perm_search_success(default_client: AsyncClient, db_session,
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_perm_search_failure(default_client: AsyncClient, body: dict[str, str], expected_status: int): async def test_post_perm_search_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
resp = await default_client.post("/iam/permissions/search", json=body) resp = await default_client.post("/iam/permissions/search", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status

View file

@ -6,7 +6,7 @@ from httpx import AsyncClient
from src.organisation.models import Organisation, OrgUsers from src.organisation.models import Organisation, OrgUsers
from src.user.models import User from src.user.models import User
from .conftest import default_client from .conftest import default_client, generate_query_and_status
@pytest.mark.anyio @pytest.mark.anyio
@ -25,14 +25,10 @@ async def test_get_org_success(default_client: AsyncClient):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"query, expected_status", "query, expected_status",
[ generate_query_and_status(["org_id"])
("org_id=2", 404),
("org_id=banana", 422),
("", 422),
],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_failure(default_client: AsyncClient, query: str, expected_status: int): async def test_get_org_status_checks(default_client: AsyncClient, query: str, expected_status: int):
resp = await default_client.get(f"/org?{query}") resp = await default_client.get(f"/org?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -57,7 +53,7 @@ async def test_post_org_success(default_client: AsyncClient):
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_failure(default_client: AsyncClient, body: dict[str, str], expected_status: int): async def test_post_org_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
resp = await default_client.post("/org", json=body) resp = await default_client.post("/org", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -93,7 +89,7 @@ async def test_patch_org_questionnaire_partial_success(default_client: AsyncClie
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_questionnaire_partial_failure(default_client: AsyncClient, body: dict[str, str], expected_status: int): async def test_patch_questionnaire_partial_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
resp = await default_client.patch("/org/questionnaire", json=body) resp = await default_client.patch("/org/questionnaire", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -112,7 +108,7 @@ async def test_patch_org_questionnaire_submit_success(default_client: AsyncClien
assert data["name"] == "Test Org" assert data["name"] == "Test Org"
assert data["intake_questionnaire"]["question_one"] == "new answer one" assert data["intake_questionnaire"]["question_one"] == "new answer one"
assert data["status"] == "submitted" assert data["status"] == "submitted"
# assert type(data["intake_questionnaire"]["question_two"]) == str assert data["intake_questionnaire"]["question_two"] == "answer two"
assert data["intake_questionnaire"]["question_three"] is None assert data["intake_questionnaire"]["question_three"] is None
@ -142,7 +138,7 @@ async def test_patch_org_status_success(default_client: AsyncClient, status: str
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_status_failure(default_client: AsyncClient, body: dict[str, str], expected_status: int): async def test_patch_org_status_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
resp = await default_client.patch("/org/status", json=body) resp = await default_client.patch("/org/status", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -166,14 +162,10 @@ async def test_get_org_users_success(default_client: AsyncClient):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"query, expected_status", "query, expected_status",
[ generate_query_and_status(["org_id"])
("org_id=2", 404),
("org_id=banana", 422),
("", 422),
],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_users_failure(default_client: AsyncClient, query: str, expected_status: int): async def test_get_org_users_status_checks(default_client: AsyncClient, query: str, expected_status: int):
resp = await default_client.get(f"/org/users?{query}") resp = await default_client.get(f"/org/users?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -201,10 +193,11 @@ async def test_post_org_user_success(default_client: AsyncClient, db_session):
({"organisation_id": 1, "user_id": "id"}, 422), ({"organisation_id": 1, "user_id": "id"}, 422),
({"user_id": 2}, 422), ({"user_id": 2}, 422),
({"organisation_id": 1, "user_id": 42}, 404), ({"organisation_id": 1, "user_id": 42}, 404),
({"organisation_id": 1, "user_id": 1}, 409),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_failure(default_client: AsyncClient, body: dict[str, str], expected_status: int, db_session): async def test_post_org_user_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int, db_session):
db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
db_session.flush() db_session.flush()
@ -241,7 +234,7 @@ async def test_patch_org_root_user_success(default_client: AsyncClient, db_sessi
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_root_user_failure(default_client: AsyncClient, body: dict[str, str], expected_status: int, db_session): async def test_patch_root_user_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int, db_session):
db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
db_session.flush() db_session.flush()
db_session.add(OrgUsers(org_id=1, user_id=2)) db_session.add(OrgUsers(org_id=1, user_id=2))
@ -252,6 +245,18 @@ async def test_patch_root_user_failure(default_client: AsyncClient, body: dict[s
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_patch_org_root_user_non_member(default_client: AsyncClient, db_session):
db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
db_session.flush()
resp = await default_client.patch("/org/root_user", json={"organisation_id": 1, "user_id": 2})
data = resp.json()
assert resp.status_code == 422
assert data["detail"] == "This user does not belong to your organisation."
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_groups_success(default_client: AsyncClient): async def test_get_org_groups_success(default_client: AsyncClient):
resp = await default_client.get("/org/groups?org_id=1") resp = await default_client.get("/org/groups?org_id=1")
@ -265,14 +270,10 @@ async def test_get_org_groups_success(default_client: AsyncClient):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"query, expected_status", "query, expected_status",
[ generate_query_and_status(["org_id"])
("org_id=2", 404),
("org_id=banana", 422),
("", 422),
],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_groups_failure(default_client: AsyncClient, query: str, expected_status: int): async def test_get_org_groups_status_checks(default_client: AsyncClient, query: str, expected_status: int):
resp = await default_client.get(f"/org/groups?{query}") resp = await default_client.get(f"/org/groups?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -326,7 +327,7 @@ async def test_get_org_contact_success(default_client: AsyncClient, contact_type
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_contact_failure(default_client: AsyncClient, query: str, expected_status: int): async def test_get_org_contact_status_checks(default_client: AsyncClient, query: str, expected_status: int):
resp = await default_client.get(f"/org/contact?{query}") resp = await default_client.get(f"/org/contact?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -367,6 +368,8 @@ async def test_patch_org_contact_success(default_client: AsyncClient, key: str,
"body, expected_status", "body, expected_status",
[ [
({"organisation_id": 42, "contact_type": "billing"}, 404), ({"organisation_id": 42, "contact_type": "billing"}, 404),
({"organisation_id": 1, "contact_type": "security"}, 200),
({"organisation_id": 1, "contact_type": "owner"}, 200),
({"organisation_id": "Test Org", "contact_type": "billing"}, 422), ({"organisation_id": "Test Org", "contact_type": "billing"}, 422),
({"organisation_id": "", "contact_type": "billing"}, 422), ({"organisation_id": "", "contact_type": "billing"}, 422),
({}, 422), ({}, 422),
@ -376,7 +379,7 @@ async def test_patch_org_contact_success(default_client: AsyncClient, key: str,
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_status_failure(default_client: AsyncClient, body: dict[str, str], expected_status: int): async def test_patch_org_status_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
resp = await default_client.patch("/org/contact", json=body) resp = await default_client.patch("/org/contact", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status

View file

@ -4,7 +4,7 @@
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
from .conftest import default_client from .conftest import default_client, generate_query_and_status
@pytest.mark.anyio @pytest.mark.anyio
@ -20,14 +20,10 @@ async def test_get_services_success(default_client: AsyncClient):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"query, expected_status", "query, expected_status",
[ generate_query_and_status(["org_id"])
("org_id=2", 404),
("org_id=banana", 422),
("", 422),
],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_services_failure(default_client: AsyncClient, query: str, expected_status: int): async def test_get_services_status_checks(default_client: AsyncClient, query: str, expected_status: int):
resp = await default_client.get(f"/service/?{query}") resp = await default_client.get(f"/service/?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -53,7 +49,7 @@ async def test_post_service_success(default_client: AsyncClient):
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_services_failure(default_client: AsyncClient, body: dict[str, str], expected_status: int): async def test_post_services_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
resp = await default_client.post("/service/", json=body) resp = await default_client.post("/service/", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -81,7 +77,7 @@ async def test_patch_service_success(default_client: AsyncClient):
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_services_failure(default_client: AsyncClient, body: dict[str, str], expected_status: int): async def test_patch_services_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
resp = await default_client.patch("/service/key", json=body) resp = await default_client.patch("/service/key", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status

View file

@ -6,10 +6,10 @@
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
from .conftest import default_client from .conftest import default_client, generate_query_and_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_self_db(default_client: AsyncClient): async def test_get_self_db_success(default_client: AsyncClient):
resp = await default_client.get("/user/self/db") resp = await default_client.get("/user/self/db")
data = resp.json() data = resp.json()
@ -37,14 +37,9 @@ async def test_get_user_success(default_client: AsyncClient):
@pytest.mark.anyio @pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"query, expected_status", "query, expected_status",
[ generate_query_and_status(["user_id"])
("user_id=1", 200),
("user_id=2", 404),
("user_id=banana", 422),
("", 422),
],
) )
async def test_get_user_fail(default_client: AsyncClient, query: str, expected_status: int): async def test_get_user_status_checks(default_client: AsyncClient, query: str, expected_status: int):
resp = await default_client.get(f"/user/?{query}") resp = await default_client.get(f"/user/?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status