diff --git a/src/auth/dependencies.py b/src/auth/dependencies.py index ddba545..aaabdbb 100644 --- a/src/auth/dependencies.py +++ b/src/auth/dependencies.py @@ -9,8 +9,9 @@ Exports: """ from typing import Annotated -from fastapi import Depends +from fastapi import Depends, Request +from src.auth.service import org_status_check from src.exceptions import ForbiddenException from src.user.dependencies import user_model_claims_dependency from src.user.models import User @@ -37,16 +38,19 @@ async def org_query_root_claims( user_model: user_model_claims_dependency, org_model: org_model_query_dependency, su_emails: su_list_dependency, + request: Request, ): - if org_model.root_user_id == user_model.id: - return org_model - try: if await user_model_super_admin(user_model, su_emails): return org_model except ForbiddenException: pass + await org_status_check(org_model, request) + + if org_model.root_user_id == user_model.id: + return org_model + raise ForbiddenException(message="Must be the org's root user") @@ -59,16 +63,19 @@ async def org_body_root_claims( user_model: user_model_claims_dependency, org_model: org_model_body_dependency, su_emails: su_list_dependency, + request: Request, ): - if org_model.root_user_id == user_model.id: - return org_model - try: if await user_model_super_admin(user_model, su_emails): return org_model except ForbiddenException: pass + await org_status_check(org_model, request) + + if org_model.root_user_id == user_model.id: + return org_model + raise ForbiddenException(message="Must be the org's root user") diff --git a/src/auth/service.py b/src/auth/service.py index 25c2fa7..aa1b060 100644 --- a/src/auth/service.py +++ b/src/auth/service.py @@ -14,10 +14,13 @@ from joserfc.errors import ExpiredTokenError from joserfc.jwk import KeySet from urllib.request import urlopen -from fastapi import Depends +from fastapi import Depends, Request from fastapi.security import OpenIdConnect -from src.exceptions import UnauthorizedException +from src.organisation.constants import Status as OrgStatus +from src.organisation.exceptions import AwaitingApprovalException +from src.organisation.models import Organisation as Org +from src.exceptions import UnauthorizedException, ForbiddenException from src.auth.config import auth_settings from src.user.service import add_user_to_db from src.database import db_dependency @@ -27,7 +30,7 @@ oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG) oidc_dependency = Annotated[str, Depends(oidc)] -def get_dev_user(): +async def get_dev_user(): return {"db_id": 1, "email": "chris@sr2.uk"} @@ -61,3 +64,26 @@ async def get_current_user( claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)] + + +async def org_status_check(org_model: Org, request: Request): + org_status = OrgStatus(org_model.status) + if org_status.is_blocked: + raise ForbiddenException("This organisation cannot perform this action.") + + root = "/api/v1" + + pre_approval_endpoints = [ + f"PATCH{root}/org/status", + f"PATCH{root}/org/questionnaire", + f"GET{root}/org", + f"GET{root}/org/contact", + f"PATCH{root}/org/contact", + f"DELETE{root}/org/self", + ] + current_request = f"{request.method}{request.url.path}" + if ( + current_request not in pre_approval_endpoints + and org_model.status != OrgStatus.APPROVED + ): + raise AwaitingApprovalException(org_model.id) diff --git a/src/organisation/dependencies.py b/src/organisation/dependencies.py index 20c50a8..4c22685 100644 --- a/src/organisation/dependencies.py +++ b/src/organisation/dependencies.py @@ -7,65 +7,38 @@ Exports: """ from typing import Annotated, Optional -from sqlalchemy.orm import Session -from fastapi import Depends, Query, Request +from fastapi import Depends, Query from src.database import db_dependency -from src.exceptions import ForbiddenException from src.organisation.schemas import OrgIDMixin from src.organisation.models import Organisation as Org -from src.organisation.exceptions import OrgNotFoundException, AwaitingApprovalException -from src.organisation.constants import Status as OrgStatus - - -def get_org_model(db: Session, request: Request, org_id: int): - org_model = db.get(Org, org_id) - if org_model is None: - raise OrgNotFoundException(org_id) - - org_status = OrgStatus(org_model.status) - if org_status.is_blocked: - raise ForbiddenException("This organisation cannot perform this action.") - - root = "/api/v1" - - pre_approval_endpoints = [ - f"PATCH{root}/org/status", - f"PATCH{root}/org/questionnaire", - f"GET{root}/org", - f"GET{root}/org/contact", - f"PATCH{root}/org/contact", - f"DELETE{root}/org/self", - ] - current_request = f"{request.method}{request.url.path}" - if ( - current_request not in pre_approval_endpoints - and org_model.status != OrgStatus.APPROVED - ): - raise AwaitingApprovalException(org_id) - - return org_model +from src.organisation.exceptions import OrgNotFoundException def get_org_model_query( - db: db_dependency, request: Request, org_id: Annotated[int, Query(gt=0)] + db: db_dependency, org_id: Annotated[int, Query(gt=0)] ) -> type[Org]: - return get_org_model(db, request, org_id) + org_model = db.get(Org, org_id) + if org_model is None: + raise OrgNotFoundException(org_id) + return org_model org_model_query_dependency = Annotated[type[Org], Depends(get_org_model_query)] -def get_org_model_body( - db: db_dependency, request: Request, request_model: OrgIDMixin -) -> type[Org]: +def get_org_model_body(db: db_dependency, request_model: OrgIDMixin) -> type[Org]: org_id: Optional[int] = getattr(request_model, "organisation_id", None) if org_id is None: raise OrgNotFoundException() - return get_org_model(db, request, org_id) + org_model = db.get(Org, org_id) + if org_model is None: + raise OrgNotFoundException(org_id) + + return org_model org_model_body_dependency = Annotated[type[Org], Depends(get_org_model_body)] diff --git a/src/organisation/schemas.py b/src/organisation/schemas.py index 4641118..522158e 100644 --- a/src/organisation/schemas.py +++ b/src/organisation/schemas.py @@ -9,7 +9,7 @@ Models follow the nomenclature of: from typing import Optional from datetime import datetime -from pydantic import EmailStr, ConfigDict +from pydantic import EmailStr, ConfigDict, Field from src.schemas import ( CustomBaseModel, @@ -55,7 +55,7 @@ class OrgSchema(OrgIDMixin): class OrgPostOrgRequest(CustomBaseModel): - name: str + name: str = Field(min_length=3) intake_questionnaire: Optional[CurrentQuestions] = None diff --git a/src/schemas.py b/src/schemas.py index 0244c11..cb2e742 100644 --- a/src/schemas.py +++ b/src/schemas.py @@ -55,3 +55,8 @@ class GroupSummary(CustomBaseModel): class UserSummary(CustomBaseModel): id: int email: str + + +class ServiceSummary(CustomBaseModel): + id: int + name: str diff --git a/src/service/schemas.py b/src/service/schemas.py index 531951f..71bb215 100644 --- a/src/service/schemas.py +++ b/src/service/schemas.py @@ -6,28 +6,21 @@ Models follow the nomenclature of: - Models: "" ie "ServiceGetServiceResponse" """ -from pydantic import ConfigDict +from pydantic import Field -from src.schemas import CustomBaseModel, ServiceIDMixin +from src.schemas import CustomBaseModel, ServiceIDMixin, ServiceSummary -class ServiceSchema(CustomBaseModel): - model_config = ConfigDict(from_attributes=True, extra="ignore") - - id: int - name: str - - -class ServiceWithKeySchema(ServiceSchema): +class ServiceWithKeySchema(ServiceSummary): api_key: str class ServiceGetServiceResponse(CustomBaseModel): - services: list[ServiceSchema] + services: list[ServiceSummary] class ServicePostServiceRequest(CustomBaseModel): - name: str + name: str = Field(min_length=3) class ServicePostServiceResponse(CustomBaseModel): diff --git a/test/conftest.py b/test/conftest.py index 750d2f5..a36e9a3 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -215,7 +215,7 @@ def generate_query_and_status(params) -> list[tuple[str, int]]: def generate_body_and_status(params: dict[str, str]) -> list[tuple[dict, int]]: possible_values_int = [0, -1, 42, "banana", ""] - possible_values_str = [0] + possible_values_str = [0, "", "a"] defaults = [{param: 1 for param in params.keys()}] diff --git a/test/test_auth_approval.py b/test/test_auth_approval.py index 82ea50a..3f4dd03 100644 --- a/test/test_auth_approval.py +++ b/test/test_auth_approval.py @@ -14,15 +14,15 @@ pytestmark = [ @pytest.mark.anyio -async def test_get_org_auth_approval(default_client: AsyncClient): - resp = await default_client.get("/org?org_id=3") +async def test_get_org_auth_approval(no_su_client: AsyncClient): + resp = await no_su_client.get("/org?org_id=3") assert resp.status_code != 422 assert resp.status_code == 200 @pytest.mark.anyio -async def test_patch_org_questionnaire_auth_approval(default_client: AsyncClient): - resp = await default_client.patch( +async def test_patch_org_questionnaire_auth_approval(no_su_client: AsyncClient): + resp = await no_su_client.patch( "/org/questionnaire", json={ "organisation_id": 3, @@ -39,56 +39,29 @@ async def test_patch_org_questionnaire_auth_approval(default_client: AsyncClient @pytest.mark.anyio -async def test_patch_org_status_auth_approval(default_client: AsyncClient): - resp = await default_client.patch( - "/org/status", json={"organisation_id": 3, "status": "submitted"} - ) +async def test_get_org_users_auth_approval(no_su_client: AsyncClient): + resp = await no_su_client.get("/org/users?org_id=3") + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] + + +@pytest.mark.anyio +async def test_get_org_groups_auth_approval(no_su_client: AsyncClient): + resp = await no_su_client.get("/org/groups?org_id=3") + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] + + +@pytest.mark.anyio +async def test_get_org_contact_auth_approval(no_su_client: AsyncClient): + resp = await no_su_client.get("/org/contact?org_id=3&contact_type=billing") assert resp.status_code != 422 assert resp.status_code == 200 @pytest.mark.anyio -async def test_get_org_users_auth_approval(default_client: AsyncClient): - resp = await default_client.get("/org/users?org_id=3") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] - - -@pytest.mark.anyio -async def test_post_org_user_auth_approval(default_client: AsyncClient): - resp = await default_client.post( - "/org/user", json={"organisation_id": 3, "user_id": 2} - ) - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] - - -@pytest.mark.anyio -async def test_patch_org_root_user_auth_approval(default_client: AsyncClient): - resp = await default_client.patch( - "/org/root_user", json={"organisation_id": 3, "user_id": 2} - ) - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] - - -@pytest.mark.anyio -async def test_get_org_groups_auth_approval(default_client: AsyncClient): - resp = await default_client.get("/org/groups?org_id=3") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] - - -@pytest.mark.anyio -async def test_get_org_contact_auth_approval(default_client: AsyncClient): - resp = await default_client.get("/org/contact?org_id=3&contact_type=billing") - assert resp.status_code != 422 - assert resp.status_code == 200 - - -@pytest.mark.anyio -async def test_patch_org_contact_auth_approval(default_client: AsyncClient): - resp = await default_client.patch( +async def test_patch_org_contact_auth_approval(no_su_client: AsyncClient): + resp = await no_su_client.patch( "/org/contact", json={ "organisation_id": 3, @@ -101,29 +74,29 @@ async def test_patch_org_contact_auth_approval(default_client: AsyncClient): @pytest.mark.anyio -async def test_get_service_auth_approval(default_client: AsyncClient): - resp = await default_client.get("/service?org_id=3") +async def test_get_service_auth_approval(no_su_client: AsyncClient): + resp = await no_su_client.get("/service?org_id=3") assert resp.status_code != 422 assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio -async def test_get_iam_group_permissions_auth_approval(default_client: AsyncClient): - resp = await default_client.get("/iam/group/permissions?org_id=3&group_id=1") +async def test_get_iam_group_permissions_auth_approval(no_su_client: AsyncClient): + resp = await no_su_client.get("/iam/group/permissions?org_id=3&group_id=1") assert resp.status_code != 422 assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio -async def test_get_iam_group_users_auth_approval(default_client: AsyncClient): - resp = await default_client.get("/iam/group/users?org_id=3&group_id=1") +async def test_get_iam_group_users_auth_approval(no_su_client: AsyncClient): + resp = await no_su_client.get("/iam/group/users?org_id=3&group_id=1") assert resp.status_code != 422 assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio -async def test_post_iam_group_auth_approval(default_client: AsyncClient): - resp = await default_client.post( +async def test_post_iam_group_auth_approval(no_su_client: AsyncClient): + resp = await no_su_client.post( "/iam/group", json={"name": "New Group", "organisation_id": 3} ) assert resp.status_code != 422 @@ -131,8 +104,8 @@ async def test_post_iam_group_auth_approval(default_client: AsyncClient): @pytest.mark.anyio -async def test_put_iam_group_permission_auth_approval(default_client: AsyncClient): - resp = await default_client.put( +async def test_put_iam_group_permission_auth_approval(no_su_client: AsyncClient): + resp = await no_su_client.put( "/iam/group/permission", json={"permission_id": 1, "group_id": 2, "organisation_id": 3}, ) @@ -141,8 +114,8 @@ async def test_put_iam_group_permission_auth_approval(default_client: AsyncClien @pytest.mark.anyio -async def test_put_iam_group_user_auth_approval(default_client: AsyncClient): - resp = await default_client.put( +async def test_put_iam_group_user_auth_approval(no_su_client: AsyncClient): + resp = await no_su_client.put( "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 3} ) assert resp.status_code != 422 @@ -150,15 +123,15 @@ async def test_put_iam_group_user_auth_approval(default_client: AsyncClient): @pytest.mark.anyio -async def test_get_iam_permissions_auth_approval(default_client: AsyncClient): - resp = await default_client.get("/iam/permissions?org_id=3") +async def test_get_iam_permissions_auth_approval(no_su_client: AsyncClient): + resp = await no_su_client.get("/iam/permissions?org_id=3") assert resp.status_code != 422 assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio -async def test_post_iam_permissions_search_auth_approval(default_client: AsyncClient): - resp = await default_client.post( +async def test_post_iam_permissions_search_auth_approval(no_su_client: AsyncClient): + resp = await no_su_client.post( "/iam/permissions/search", json={"organisation_id": 3, "action": "read"} ) assert resp.status_code != 422