Compare commits

...

2 commits

Author SHA1 Message Date
43ed768f66 feat: minimum lengths for names
All checks were successful
ci / lint_and_test (push) Successful in 15s
2026-06-12 15:58:20 +01:00
092e12a892 feat: org status check moved
Accessing endpoints as super admin no longer requires the org to be approved.
2026-06-12 14:50:32 +01:00
8 changed files with 107 additions and 130 deletions

View file

@ -9,8 +9,9 @@ Exports:
""" """
from typing import Annotated from typing import Annotated
from fastapi import Depends from fastapi import Depends, Request
from src.auth.service import org_status_check
from src.exceptions import ForbiddenException from src.exceptions import ForbiddenException
from src.user.dependencies import user_model_claims_dependency from src.user.dependencies import user_model_claims_dependency
from src.user.models import User from src.user.models import User
@ -37,16 +38,19 @@ async def org_query_root_claims(
user_model: user_model_claims_dependency, user_model: user_model_claims_dependency,
org_model: org_model_query_dependency, org_model: org_model_query_dependency,
su_emails: su_list_dependency, su_emails: su_list_dependency,
request: Request,
): ):
if org_model.root_user_id == user_model.id:
return org_model
try: try:
if await user_model_super_admin(user_model, su_emails): if await user_model_super_admin(user_model, su_emails):
return org_model return org_model
except ForbiddenException: except ForbiddenException:
pass pass
await org_status_check(org_model, request)
if org_model.root_user_id == user_model.id:
return org_model
raise ForbiddenException(message="Must be the org's root user") raise ForbiddenException(message="Must be the org's root user")
@ -59,16 +63,19 @@ async def org_body_root_claims(
user_model: user_model_claims_dependency, user_model: user_model_claims_dependency,
org_model: org_model_body_dependency, org_model: org_model_body_dependency,
su_emails: su_list_dependency, su_emails: su_list_dependency,
request: Request,
): ):
if org_model.root_user_id == user_model.id:
return org_model
try: try:
if await user_model_super_admin(user_model, su_emails): if await user_model_super_admin(user_model, su_emails):
return org_model return org_model
except ForbiddenException: except ForbiddenException:
pass pass
await org_status_check(org_model, request)
if org_model.root_user_id == user_model.id:
return org_model
raise ForbiddenException(message="Must be the org's root user") raise ForbiddenException(message="Must be the org's root user")

View file

@ -14,10 +14,13 @@ from joserfc.errors import ExpiredTokenError
from joserfc.jwk import KeySet from joserfc.jwk import KeySet
from urllib.request import urlopen from urllib.request import urlopen
from fastapi import Depends from fastapi import Depends, Request
from fastapi.security import OpenIdConnect from fastapi.security import OpenIdConnect
from src.exceptions import UnauthorizedException from src.organisation.constants import Status as OrgStatus
from src.organisation.exceptions import AwaitingApprovalException
from src.organisation.models import Organisation as Org
from src.exceptions import UnauthorizedException, ForbiddenException
from src.auth.config import auth_settings from src.auth.config import auth_settings
from src.user.service import add_user_to_db from src.user.service import add_user_to_db
from src.database import db_dependency from src.database import db_dependency
@ -27,7 +30,7 @@ oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG)
oidc_dependency = Annotated[str, Depends(oidc)] oidc_dependency = Annotated[str, Depends(oidc)]
def get_dev_user(): async def get_dev_user():
return {"db_id": 1, "email": "chris@sr2.uk"} return {"db_id": 1, "email": "chris@sr2.uk"}
@ -61,3 +64,26 @@ async def get_current_user(
claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)] claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)]
async def org_status_check(org_model: Org, request: Request):
org_status = OrgStatus(org_model.status)
if org_status.is_blocked:
raise ForbiddenException("This organisation cannot perform this action.")
root = "/api/v1"
pre_approval_endpoints = [
f"PATCH{root}/org/status",
f"PATCH{root}/org/questionnaire",
f"GET{root}/org",
f"GET{root}/org/contact",
f"PATCH{root}/org/contact",
f"DELETE{root}/org/self",
]
current_request = f"{request.method}{request.url.path}"
if (
current_request not in pre_approval_endpoints
and org_model.status != OrgStatus.APPROVED
):
raise AwaitingApprovalException(org_model.id)

View file

@ -7,65 +7,38 @@ Exports:
""" """
from typing import Annotated, Optional from typing import Annotated, Optional
from sqlalchemy.orm import Session
from fastapi import Depends, Query, Request from fastapi import Depends, Query
from src.database import db_dependency from src.database import db_dependency
from src.exceptions import ForbiddenException
from src.organisation.schemas import OrgIDMixin from src.organisation.schemas import OrgIDMixin
from src.organisation.models import Organisation as Org from src.organisation.models import Organisation as Org
from src.organisation.exceptions import OrgNotFoundException, AwaitingApprovalException from src.organisation.exceptions import OrgNotFoundException
from src.organisation.constants import Status as OrgStatus
def get_org_model(db: Session, request: Request, org_id: int):
org_model = db.get(Org, org_id)
if org_model is None:
raise OrgNotFoundException(org_id)
org_status = OrgStatus(org_model.status)
if org_status.is_blocked:
raise ForbiddenException("This organisation cannot perform this action.")
root = "/api/v1"
pre_approval_endpoints = [
f"PATCH{root}/org/status",
f"PATCH{root}/org/questionnaire",
f"GET{root}/org",
f"GET{root}/org/contact",
f"PATCH{root}/org/contact",
f"DELETE{root}/org/self",
]
current_request = f"{request.method}{request.url.path}"
if (
current_request not in pre_approval_endpoints
and org_model.status != OrgStatus.APPROVED
):
raise AwaitingApprovalException(org_id)
return org_model
def get_org_model_query( def get_org_model_query(
db: db_dependency, request: Request, org_id: Annotated[int, Query(gt=0)] db: db_dependency, org_id: Annotated[int, Query(gt=0)]
) -> type[Org]: ) -> type[Org]:
return get_org_model(db, request, org_id) org_model = db.get(Org, org_id)
if org_model is None:
raise OrgNotFoundException(org_id)
return org_model
org_model_query_dependency = Annotated[type[Org], Depends(get_org_model_query)] org_model_query_dependency = Annotated[type[Org], Depends(get_org_model_query)]
def get_org_model_body( def get_org_model_body(db: db_dependency, request_model: OrgIDMixin) -> type[Org]:
db: db_dependency, request: Request, request_model: OrgIDMixin
) -> type[Org]:
org_id: Optional[int] = getattr(request_model, "organisation_id", None) org_id: Optional[int] = getattr(request_model, "organisation_id", None)
if org_id is None: if org_id is None:
raise OrgNotFoundException() raise OrgNotFoundException()
return get_org_model(db, request, org_id) org_model = db.get(Org, org_id)
if org_model is None:
raise OrgNotFoundException(org_id)
return org_model
org_model_body_dependency = Annotated[type[Org], Depends(get_org_model_body)] org_model_body_dependency = Annotated[type[Org], Depends(get_org_model_body)]

View file

@ -9,7 +9,7 @@ Models follow the nomenclature of:
from typing import Optional from typing import Optional
from datetime import datetime from datetime import datetime
from pydantic import EmailStr, ConfigDict from pydantic import EmailStr, ConfigDict, Field
from src.schemas import ( from src.schemas import (
CustomBaseModel, CustomBaseModel,
@ -55,7 +55,7 @@ class OrgSchema(OrgIDMixin):
class OrgPostOrgRequest(CustomBaseModel): class OrgPostOrgRequest(CustomBaseModel):
name: str name: str = Field(min_length=3)
intake_questionnaire: Optional[CurrentQuestions] = None intake_questionnaire: Optional[CurrentQuestions] = None

View file

@ -55,3 +55,8 @@ class GroupSummary(CustomBaseModel):
class UserSummary(CustomBaseModel): class UserSummary(CustomBaseModel):
id: int id: int
email: str email: str
class ServiceSummary(CustomBaseModel):
id: int
name: str

View file

@ -6,28 +6,21 @@ Models follow the nomenclature of:
- Models: "<Module><Method><Resource><Opt:Resource><Direction>" ie "ServiceGetServiceResponse" - Models: "<Module><Method><Resource><Opt:Resource><Direction>" ie "ServiceGetServiceResponse"
""" """
from pydantic import ConfigDict from pydantic import Field
from src.schemas import CustomBaseModel, ServiceIDMixin from src.schemas import CustomBaseModel, ServiceIDMixin, ServiceSummary
class ServiceSchema(CustomBaseModel): class ServiceWithKeySchema(ServiceSummary):
model_config = ConfigDict(from_attributes=True, extra="ignore")
id: int
name: str
class ServiceWithKeySchema(ServiceSchema):
api_key: str api_key: str
class ServiceGetServiceResponse(CustomBaseModel): class ServiceGetServiceResponse(CustomBaseModel):
services: list[ServiceSchema] services: list[ServiceSummary]
class ServicePostServiceRequest(CustomBaseModel): class ServicePostServiceRequest(CustomBaseModel):
name: str name: str = Field(min_length=3)
class ServicePostServiceResponse(CustomBaseModel): class ServicePostServiceResponse(CustomBaseModel):

View file

@ -215,7 +215,7 @@ def generate_query_and_status(params) -> list[tuple[str, int]]:
def generate_body_and_status(params: dict[str, str]) -> list[tuple[dict, int]]: def generate_body_and_status(params: dict[str, str]) -> list[tuple[dict, int]]:
possible_values_int = [0, -1, 42, "banana", ""] possible_values_int = [0, -1, 42, "banana", ""]
possible_values_str = [0] possible_values_str = [0, "", "a"]
defaults = [{param: 1 for param in params.keys()}] defaults = [{param: 1 for param in params.keys()}]

View file

@ -14,15 +14,15 @@ pytestmark = [
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_auth_approval(default_client: AsyncClient): async def test_get_org_auth_approval(no_su_client: AsyncClient):
resp = await default_client.get("/org?org_id=3") resp = await no_su_client.get("/org?org_id=3")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 200 assert resp.status_code == 200
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_questionnaire_auth_approval(default_client: AsyncClient): async def test_patch_org_questionnaire_auth_approval(no_su_client: AsyncClient):
resp = await default_client.patch( resp = await no_su_client.patch(
"/org/questionnaire", "/org/questionnaire",
json={ json={
"organisation_id": 3, "organisation_id": 3,
@ -39,56 +39,29 @@ async def test_patch_org_questionnaire_auth_approval(default_client: AsyncClient
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_status_auth_approval(default_client: AsyncClient): async def test_get_org_users_auth_approval(no_su_client: AsyncClient):
resp = await default_client.patch( resp = await no_su_client.get("/org/users?org_id=3")
"/org/status", json={"organisation_id": 3, "status": "submitted"} assert resp.status_code != 422
) assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_org_groups_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/groups?org_id=3")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_org_contact_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/contact?org_id=3&contact_type=billing")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 200 assert resp.status_code == 200
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_users_auth_approval(default_client: AsyncClient): async def test_patch_org_contact_auth_approval(no_su_client: AsyncClient):
resp = await default_client.get("/org/users?org_id=3") resp = await no_su_client.patch(
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_post_org_user_auth_approval(default_client: AsyncClient):
resp = await default_client.post(
"/org/user", json={"organisation_id": 3, "user_id": 2}
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_patch_org_root_user_auth_approval(default_client: AsyncClient):
resp = await default_client.patch(
"/org/root_user", json={"organisation_id": 3, "user_id": 2}
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_org_groups_auth_approval(default_client: AsyncClient):
resp = await default_client.get("/org/groups?org_id=3")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_org_contact_auth_approval(default_client: AsyncClient):
resp = await default_client.get("/org/contact?org_id=3&contact_type=billing")
assert resp.status_code != 422
assert resp.status_code == 200
@pytest.mark.anyio
async def test_patch_org_contact_auth_approval(default_client: AsyncClient):
resp = await default_client.patch(
"/org/contact", "/org/contact",
json={ json={
"organisation_id": 3, "organisation_id": 3,
@ -101,29 +74,29 @@ async def test_patch_org_contact_auth_approval(default_client: AsyncClient):
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_service_auth_approval(default_client: AsyncClient): async def test_get_service_auth_approval(no_su_client: AsyncClient):
resp = await default_client.get("/service?org_id=3") resp = await no_su_client.get("/service?org_id=3")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_iam_group_permissions_auth_approval(default_client: AsyncClient): async def test_get_iam_group_permissions_auth_approval(no_su_client: AsyncClient):
resp = await default_client.get("/iam/group/permissions?org_id=3&group_id=1") resp = await no_su_client.get("/iam/group/permissions?org_id=3&group_id=1")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_iam_group_users_auth_approval(default_client: AsyncClient): async def test_get_iam_group_users_auth_approval(no_su_client: AsyncClient):
resp = await default_client.get("/iam/group/users?org_id=3&group_id=1") resp = await no_su_client.get("/iam/group/users?org_id=3&group_id=1")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_iam_group_auth_approval(default_client: AsyncClient): async def test_post_iam_group_auth_approval(no_su_client: AsyncClient):
resp = await default_client.post( resp = await no_su_client.post(
"/iam/group", json={"name": "New Group", "organisation_id": 3} "/iam/group", json={"name": "New Group", "organisation_id": 3}
) )
assert resp.status_code != 422 assert resp.status_code != 422
@ -131,8 +104,8 @@ async def test_post_iam_group_auth_approval(default_client: AsyncClient):
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_iam_group_permission_auth_approval(default_client: AsyncClient): async def test_put_iam_group_permission_auth_approval(no_su_client: AsyncClient):
resp = await default_client.put( resp = await no_su_client.put(
"/iam/group/permission", "/iam/group/permission",
json={"permission_id": 1, "group_id": 2, "organisation_id": 3}, json={"permission_id": 1, "group_id": 2, "organisation_id": 3},
) )
@ -141,8 +114,8 @@ async def test_put_iam_group_permission_auth_approval(default_client: AsyncClien
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_iam_group_user_auth_approval(default_client: AsyncClient): async def test_put_iam_group_user_auth_approval(no_su_client: AsyncClient):
resp = await default_client.put( resp = await no_su_client.put(
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 3} "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 3}
) )
assert resp.status_code != 422 assert resp.status_code != 422
@ -150,15 +123,15 @@ async def test_put_iam_group_user_auth_approval(default_client: AsyncClient):
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_iam_permissions_auth_approval(default_client: AsyncClient): async def test_get_iam_permissions_auth_approval(no_su_client: AsyncClient):
resp = await default_client.get("/iam/permissions?org_id=3") resp = await no_su_client.get("/iam/permissions?org_id=3")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_iam_permissions_search_auth_approval(default_client: AsyncClient): async def test_post_iam_permissions_search_auth_approval(no_su_client: AsyncClient):
resp = await default_client.post( resp = await no_su_client.post(
"/iam/permissions/search", json={"organisation_id": 3, "action": "read"} "/iam/permissions/search", json={"organisation_id": 3, "action": "read"}
) )
assert resp.status_code != 422 assert resp.status_code != 422