diff --git a/src/auth/dependencies.py b/src/auth/dependencies.py index aaabdbb..ddba545 100644 --- a/src/auth/dependencies.py +++ b/src/auth/dependencies.py @@ -9,9 +9,8 @@ Exports: """ from typing import Annotated -from fastapi import Depends, Request +from fastapi import Depends -from src.auth.service import org_status_check from src.exceptions import ForbiddenException from src.user.dependencies import user_model_claims_dependency from src.user.models import User @@ -38,19 +37,16 @@ async def org_query_root_claims( user_model: user_model_claims_dependency, org_model: org_model_query_dependency, su_emails: su_list_dependency, - request: Request, ): + if org_model.root_user_id == user_model.id: + return org_model + try: if await user_model_super_admin(user_model, su_emails): return org_model except ForbiddenException: pass - await org_status_check(org_model, request) - - if org_model.root_user_id == user_model.id: - return org_model - raise ForbiddenException(message="Must be the org's root user") @@ -63,19 +59,16 @@ async def org_body_root_claims( user_model: user_model_claims_dependency, org_model: org_model_body_dependency, su_emails: su_list_dependency, - request: Request, ): + if org_model.root_user_id == user_model.id: + return org_model + try: if await user_model_super_admin(user_model, su_emails): return org_model except ForbiddenException: pass - await org_status_check(org_model, request) - - if org_model.root_user_id == user_model.id: - return org_model - raise ForbiddenException(message="Must be the org's root user") diff --git a/src/auth/service.py b/src/auth/service.py index aa1b060..25c2fa7 100644 --- a/src/auth/service.py +++ b/src/auth/service.py @@ -14,13 +14,10 @@ from joserfc.errors import ExpiredTokenError from joserfc.jwk import KeySet from urllib.request import urlopen -from fastapi import Depends, Request +from fastapi import Depends from fastapi.security import OpenIdConnect -from src.organisation.constants import Status as OrgStatus -from src.organisation.exceptions import AwaitingApprovalException -from src.organisation.models import Organisation as Org -from src.exceptions import UnauthorizedException, ForbiddenException +from src.exceptions import UnauthorizedException from src.auth.config import auth_settings from src.user.service import add_user_to_db from src.database import db_dependency @@ -30,7 +27,7 @@ oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG) oidc_dependency = Annotated[str, Depends(oidc)] -async def get_dev_user(): +def get_dev_user(): return {"db_id": 1, "email": "chris@sr2.uk"} @@ -64,26 +61,3 @@ async def get_current_user( claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)] - - -async def org_status_check(org_model: Org, request: Request): - org_status = OrgStatus(org_model.status) - if org_status.is_blocked: - raise ForbiddenException("This organisation cannot perform this action.") - - root = "/api/v1" - - pre_approval_endpoints = [ - f"PATCH{root}/org/status", - f"PATCH{root}/org/questionnaire", - f"GET{root}/org", - f"GET{root}/org/contact", - f"PATCH{root}/org/contact", - f"DELETE{root}/org/self", - ] - current_request = f"{request.method}{request.url.path}" - if ( - current_request not in pre_approval_endpoints - and org_model.status != OrgStatus.APPROVED - ): - raise AwaitingApprovalException(org_model.id) diff --git a/src/organisation/dependencies.py b/src/organisation/dependencies.py index 4c22685..20c50a8 100644 --- a/src/organisation/dependencies.py +++ b/src/organisation/dependencies.py @@ -7,38 +7,65 @@ Exports: """ from typing import Annotated, Optional +from sqlalchemy.orm import Session -from fastapi import Depends, Query +from fastapi import Depends, Query, Request from src.database import db_dependency +from src.exceptions import ForbiddenException from src.organisation.schemas import OrgIDMixin from src.organisation.models import Organisation as Org -from src.organisation.exceptions import OrgNotFoundException +from src.organisation.exceptions import OrgNotFoundException, AwaitingApprovalException +from src.organisation.constants import Status as OrgStatus -def get_org_model_query( - db: db_dependency, org_id: Annotated[int, Query(gt=0)] -) -> type[Org]: +def get_org_model(db: Session, request: Request, org_id: int): org_model = db.get(Org, org_id) if org_model is None: raise OrgNotFoundException(org_id) + + org_status = OrgStatus(org_model.status) + if org_status.is_blocked: + raise ForbiddenException("This organisation cannot perform this action.") + + root = "/api/v1" + + pre_approval_endpoints = [ + f"PATCH{root}/org/status", + f"PATCH{root}/org/questionnaire", + f"GET{root}/org", + f"GET{root}/org/contact", + f"PATCH{root}/org/contact", + f"DELETE{root}/org/self", + ] + current_request = f"{request.method}{request.url.path}" + if ( + current_request not in pre_approval_endpoints + and org_model.status != OrgStatus.APPROVED + ): + raise AwaitingApprovalException(org_id) + return org_model +def get_org_model_query( + db: db_dependency, request: Request, org_id: Annotated[int, Query(gt=0)] +) -> type[Org]: + return get_org_model(db, request, org_id) + + org_model_query_dependency = Annotated[type[Org], Depends(get_org_model_query)] -def get_org_model_body(db: db_dependency, request_model: OrgIDMixin) -> type[Org]: +def get_org_model_body( + db: db_dependency, request: Request, request_model: OrgIDMixin +) -> type[Org]: org_id: Optional[int] = getattr(request_model, "organisation_id", None) if org_id is None: raise OrgNotFoundException() - org_model = db.get(Org, org_id) - if org_model is None: - raise OrgNotFoundException(org_id) - - return org_model + return get_org_model(db, request, org_id) org_model_body_dependency = Annotated[type[Org], Depends(get_org_model_body)] diff --git a/src/organisation/schemas.py b/src/organisation/schemas.py index 522158e..4641118 100644 --- a/src/organisation/schemas.py +++ b/src/organisation/schemas.py @@ -9,7 +9,7 @@ Models follow the nomenclature of: from typing import Optional from datetime import datetime -from pydantic import EmailStr, ConfigDict, Field +from pydantic import EmailStr, ConfigDict from src.schemas import ( CustomBaseModel, @@ -55,7 +55,7 @@ class OrgSchema(OrgIDMixin): class OrgPostOrgRequest(CustomBaseModel): - name: str = Field(min_length=3) + name: str intake_questionnaire: Optional[CurrentQuestions] = None diff --git a/src/schemas.py b/src/schemas.py index cb2e742..0244c11 100644 --- a/src/schemas.py +++ b/src/schemas.py @@ -55,8 +55,3 @@ class GroupSummary(CustomBaseModel): class UserSummary(CustomBaseModel): id: int email: str - - -class ServiceSummary(CustomBaseModel): - id: int - name: str diff --git a/src/service/schemas.py b/src/service/schemas.py index 71bb215..531951f 100644 --- a/src/service/schemas.py +++ b/src/service/schemas.py @@ -6,21 +6,28 @@ Models follow the nomenclature of: - Models: "" ie "ServiceGetServiceResponse" """ -from pydantic import Field +from pydantic import ConfigDict -from src.schemas import CustomBaseModel, ServiceIDMixin, ServiceSummary +from src.schemas import CustomBaseModel, ServiceIDMixin -class ServiceWithKeySchema(ServiceSummary): +class ServiceSchema(CustomBaseModel): + model_config = ConfigDict(from_attributes=True, extra="ignore") + + id: int + name: str + + +class ServiceWithKeySchema(ServiceSchema): api_key: str class ServiceGetServiceResponse(CustomBaseModel): - services: list[ServiceSummary] + services: list[ServiceSchema] class ServicePostServiceRequest(CustomBaseModel): - name: str = Field(min_length=3) + name: str class ServicePostServiceResponse(CustomBaseModel): diff --git a/test/conftest.py b/test/conftest.py index a36e9a3..750d2f5 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -215,7 +215,7 @@ def generate_query_and_status(params) -> list[tuple[str, int]]: def generate_body_and_status(params: dict[str, str]) -> list[tuple[dict, int]]: possible_values_int = [0, -1, 42, "banana", ""] - possible_values_str = [0, "", "a"] + possible_values_str = [0] defaults = [{param: 1 for param in params.keys()}] diff --git a/test/test_auth_approval.py b/test/test_auth_approval.py index 3f4dd03..82ea50a 100644 --- a/test/test_auth_approval.py +++ b/test/test_auth_approval.py @@ -14,15 +14,15 @@ pytestmark = [ @pytest.mark.anyio -async def test_get_org_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/org?org_id=3") +async def test_get_org_auth_approval(default_client: AsyncClient): + resp = await default_client.get("/org?org_id=3") assert resp.status_code != 422 assert resp.status_code == 200 @pytest.mark.anyio -async def test_patch_org_questionnaire_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.patch( +async def test_patch_org_questionnaire_auth_approval(default_client: AsyncClient): + resp = await default_client.patch( "/org/questionnaire", json={ "organisation_id": 3, @@ -39,29 +39,56 @@ async def test_patch_org_questionnaire_auth_approval(no_su_client: AsyncClient): @pytest.mark.anyio -async def test_get_org_users_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/org/users?org_id=3") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] - - -@pytest.mark.anyio -async def test_get_org_groups_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/org/groups?org_id=3") - assert resp.status_code != 422 - assert "has not been approved." in resp.json()["detail"] - - -@pytest.mark.anyio -async def test_get_org_contact_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/org/contact?org_id=3&contact_type=billing") +async def test_patch_org_status_auth_approval(default_client: AsyncClient): + resp = await default_client.patch( + "/org/status", json={"organisation_id": 3, "status": "submitted"} + ) assert resp.status_code != 422 assert resp.status_code == 200 @pytest.mark.anyio -async def test_patch_org_contact_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.patch( +async def test_get_org_users_auth_approval(default_client: AsyncClient): + resp = await default_client.get("/org/users?org_id=3") + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] + + +@pytest.mark.anyio +async def test_post_org_user_auth_approval(default_client: AsyncClient): + resp = await default_client.post( + "/org/user", json={"organisation_id": 3, "user_id": 2} + ) + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] + + +@pytest.mark.anyio +async def test_patch_org_root_user_auth_approval(default_client: AsyncClient): + resp = await default_client.patch( + "/org/root_user", json={"organisation_id": 3, "user_id": 2} + ) + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] + + +@pytest.mark.anyio +async def test_get_org_groups_auth_approval(default_client: AsyncClient): + resp = await default_client.get("/org/groups?org_id=3") + assert resp.status_code != 422 + assert "has not been approved." in resp.json()["detail"] + + +@pytest.mark.anyio +async def test_get_org_contact_auth_approval(default_client: AsyncClient): + resp = await default_client.get("/org/contact?org_id=3&contact_type=billing") + assert resp.status_code != 422 + assert resp.status_code == 200 + + +@pytest.mark.anyio +async def test_patch_org_contact_auth_approval(default_client: AsyncClient): + resp = await default_client.patch( "/org/contact", json={ "organisation_id": 3, @@ -74,29 +101,29 @@ async def test_patch_org_contact_auth_approval(no_su_client: AsyncClient): @pytest.mark.anyio -async def test_get_service_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/service?org_id=3") +async def test_get_service_auth_approval(default_client: AsyncClient): + resp = await default_client.get("/service?org_id=3") assert resp.status_code != 422 assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio -async def test_get_iam_group_permissions_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/iam/group/permissions?org_id=3&group_id=1") +async def test_get_iam_group_permissions_auth_approval(default_client: AsyncClient): + resp = await default_client.get("/iam/group/permissions?org_id=3&group_id=1") assert resp.status_code != 422 assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio -async def test_get_iam_group_users_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/iam/group/users?org_id=3&group_id=1") +async def test_get_iam_group_users_auth_approval(default_client: AsyncClient): + resp = await default_client.get("/iam/group/users?org_id=3&group_id=1") assert resp.status_code != 422 assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio -async def test_post_iam_group_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.post( +async def test_post_iam_group_auth_approval(default_client: AsyncClient): + resp = await default_client.post( "/iam/group", json={"name": "New Group", "organisation_id": 3} ) assert resp.status_code != 422 @@ -104,8 +131,8 @@ async def test_post_iam_group_auth_approval(no_su_client: AsyncClient): @pytest.mark.anyio -async def test_put_iam_group_permission_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.put( +async def test_put_iam_group_permission_auth_approval(default_client: AsyncClient): + resp = await default_client.put( "/iam/group/permission", json={"permission_id": 1, "group_id": 2, "organisation_id": 3}, ) @@ -114,8 +141,8 @@ async def test_put_iam_group_permission_auth_approval(no_su_client: AsyncClient) @pytest.mark.anyio -async def test_put_iam_group_user_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.put( +async def test_put_iam_group_user_auth_approval(default_client: AsyncClient): + resp = await default_client.put( "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 3} ) assert resp.status_code != 422 @@ -123,15 +150,15 @@ async def test_put_iam_group_user_auth_approval(no_su_client: AsyncClient): @pytest.mark.anyio -async def test_get_iam_permissions_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.get("/iam/permissions?org_id=3") +async def test_get_iam_permissions_auth_approval(default_client: AsyncClient): + resp = await default_client.get("/iam/permissions?org_id=3") assert resp.status_code != 422 assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio -async def test_post_iam_permissions_search_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.post( +async def test_post_iam_permissions_search_auth_approval(default_client: AsyncClient): + resp = await default_client.post( "/iam/permissions/search", json={"organisation_id": 3, "action": "read"} ) assert resp.status_code != 422