diff --git a/src/iam/router.py b/src/iam/router.py index 2895baa..316c5fa 100644 --- a/src/iam/router.py +++ b/src/iam/router.py @@ -19,7 +19,7 @@ from fastapi import APIRouter, status from sqlalchemy.exc import IntegrityError from psycopg import errors -from src.service.exceptions import ServiceNotFoundException +from service.exceptions import ServiceNotFoundException from src.exceptions import ConflictException from src.database import db_dependency from src.schemas import ResourceName diff --git a/src/iam/schemas.py b/src/iam/schemas.py index 0d370b8..fa6adfc 100644 --- a/src/iam/schemas.py +++ b/src/iam/schemas.py @@ -13,7 +13,7 @@ from pydantic import EmailStr, ConfigDict, Field from src.service.schemas import ServiceIDMixin from src.organisation.schemas import OrgIDMixin from src.schemas import CustomBaseModel -from src.user.schemas import UserIDMixin +from user.schemas import UserIDMixin class UserSchema(CustomBaseModel): diff --git a/src/organisation/dependencies.py b/src/organisation/dependencies.py index 728b8d0..ec8805c 100644 --- a/src/organisation/dependencies.py +++ b/src/organisation/dependencies.py @@ -25,7 +25,7 @@ def get_org_model(db: Session, request: Request, org_id: int): root = "/api/v1" - pre_approval_endpoints = [f"PATCH{root}/org/status", f"PATCH{root}/org/questionnaire", f"GET{root}/org", f"GET{root}/org/contact", f"PATCH{root}/org/contact"] + pre_approval_endpoints = [f"PATCH{root}/org/status", f"PATCH{root}/org/questionnaire", f"GET{root}/org/id", f"GET{root}/org/contact", f"PATCH{root}/org/contact"] current_request = f"{request.method}{request.url.path}" if current_request not in pre_approval_endpoints and org_model.status != OrgStatus.APPROVED: raise AwaitingApprovalException(org_id) diff --git a/src/organisation/router.py b/src/organisation/router.py index a9672b1..65d2d71 100644 --- a/src/organisation/router.py +++ b/src/organisation/router.py @@ -39,7 +39,7 @@ from src.organisation.schemas import OrgPostOrgRequest, OrgPatchQuestionnaireReq OrgPatchContactRequest, \ OrgPostUserRequest, OrgGetUserResponse, OrgGetContactResponse, OrgGetOrgResponse, OrgPatchRootRequest, \ OrgGetGroupResponse, OrgDeleteUserRequest, OrgDeleteOrgRequest, OrgPostOrgResponse, OrgPatchQuestionnaireResponse, \ - OrgPatchStatusResponse, OrgPostUserResponse, OrgPatchRootResponse, Questionnaire, OrgPatchContactResponse + OrgPatchStatusResponse, OrgPostUserResponse, OrgPatchRootResponse router = APIRouter( prefix="/org", @@ -47,7 +47,7 @@ router = APIRouter( ) -@router.get("", +@router.get("/id", summary="Get org details by ID.", response_model=OrgGetOrgResponse, status_code=status.HTTP_200_OK, @@ -67,14 +67,13 @@ async def get_org_by_id(org_model: org_model_root_claim_query_dependency): "owner_contact": org_model.owner_contact_rel.email, "billing_contact": org_model.billing_contact_rel.email, "security_contact": org_model.security_contact_rel.email, - "root_user": org_model.root_user_email, - "intake_questionnaire": org_model.intake_questionnaire + "root_user": org_model.root_user_email } return response -@router.post("", +@router.post("/", summary="Create new organisation.", status_code=status.HTTP_201_CREATED, response_model=OrgPostOrgResponse, @@ -131,21 +130,12 @@ async def update_questionnaire(db: db_dependency, org_model: org_model_root_clai The partial bool allows for submission of partially completed questionnaire and/or final "are you sure" check before setting the org to be in "submitted" status, awaiting admin approval. """ - update_data = request_model.intake_questionnaire.model_dump(exclude_none=True) - questionnaire_model = Questionnaire(**org_model.intake_questionnaire) - for key, value in update_data.items(): - if hasattr(questionnaire_model, key): - setattr(questionnaire_model, key, value) - else: - if key == "partial" or key == "organisation_id": - continue - raise UnprocessableContentException("Invalid keys in update request") + org_model.intake_questionnaire = request_model.intake_questionnaire.model_dump() # Allows for partially completed questionnaires to be saved without being submitted for review if not request_model.partial: org_model.status = "submitted" - org_model.intake_questionnaire = questionnaire_model.model_dump() db.flush() response = OrgPatchQuestionnaireResponse(**org_model.__dict__) db.commit() @@ -185,7 +175,7 @@ async def get_users(org_model: org_model_root_claim_query_dependency): """ Returns a list of the email addresses of all users of the organisation. """ - return {"users": [user.email for user in org_model.user_rel], "organisation": org_model} + return {"users": [user.email for user in org_model.user_rel]} @router.post("/user", @@ -211,7 +201,7 @@ async def add_user_to_org(db: db_dependency, org_model: org_model_root_claim_bod return response -@router.delete("", +@router.delete("/", summary="Delete organisation from the hub.", status_code=status.HTTP_204_NO_CONTENT, responses={ @@ -313,13 +303,13 @@ async def get_contact(org_model: org_model_root_claim_query_dependency, contact_ address = ContactAddress.model_validate(contact_model) contact_response = ContactModel.model_construct(**contact_model.__dict__, address=address) - return {"contact": contact_response, "organisation": org_model} + return {"contact": contact_response} @router.patch("/contact", summary="Update contact for organisation.", status_code=status.HTTP_200_OK, - response_model=OrgPatchContactResponse, + response_model=OrgGetContactResponse, responses={ status.HTTP_200_OK: {"description": "Successfully updated contact."}, status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, @@ -357,4 +347,4 @@ async def update_contact(db: db_dependency, org_model: org_model_root_claim_body db.commit() - return {"contact": contact_response, "organisation": org_model} + return {"contact": contact_response} diff --git a/src/organisation/schemas.py b/src/organisation/schemas.py index e2fd6a4..c34ef16 100644 --- a/src/organisation/schemas.py +++ b/src/organisation/schemas.py @@ -17,17 +17,13 @@ from src.user.schemas import UserIDMixin from src.organisation.constants import Status, ContactType -class OrgIDMixin(CustomBaseModel): - organisation_id: int = Field(gt=0) - class Questionnaire(CustomBaseModel): question_one: Optional[str] = None question_two: Optional[str] = None question_three: Optional[str] = None -class OrgSchema(CustomBaseModel): - id: int - name: str +class OrgIDMixin(CustomBaseModel): + organisation_id: int = Field(gt=0) class OrgPostOrgRequest(CustomBaseModel): @@ -88,7 +84,6 @@ class OrgPatchRootResponse(CustomBaseModel): class OrgGetUserResponse(CustomBaseModel): users: list[str] - organisation: OrgSchema class OrgGetGroupResponse(CustomBaseModel): groups: list[str] @@ -97,13 +92,6 @@ class OrgGetContactResponse(CustomBaseModel): model_config = ConfigDict(from_attributes=True, extra="ignore") contact: ContactModel - organisation: OrgSchema - -class OrgPatchContactResponse(CustomBaseModel): - model_config = ConfigDict(from_attributes=True, extra="ignore") - - contact: ContactModel - organisation: OrgSchema class OrgGetOrgResponse(CustomBaseModel): name: str @@ -112,7 +100,6 @@ class OrgGetOrgResponse(CustomBaseModel): owner_contact: Optional[str] = None billing_contact: Optional[str] = None security_contact: Optional[str] = None - intake_questionnaire: Optional[Questionnaire] = None class OrgDeleteOrgRequest(OrgIDMixin): pass diff --git a/test/test_organisation.py b/test/test_organisation.py index 6dfaa21..a629d04 100644 --- a/test/test_organisation.py +++ b/test/test_organisation.py @@ -11,7 +11,7 @@ from .conftest import client @pytest.mark.anyio async def test_get_org_success(client: AsyncClient): - resp = await client.get("/org?org_id=1") + resp = await client.get("/org/id?org_id=1") data = resp.json() assert resp.status_code == 200 @@ -33,14 +33,14 @@ async def test_get_org_success(client: AsyncClient): ) @pytest.mark.anyio async def test_get_org_failure(client: AsyncClient, query: str, expected_status: int): - resp = await client.get(f"/org?{query}") + resp = await client.get(f"/org/id?{query}") assert resp.status_code == expected_status @pytest.mark.anyio async def test_post_org_success(client: AsyncClient): - resp = await client.post("/org", json={"name": "New Test Org"}) + resp = await client.post("/org/", json={"name": "New Test Org"}) data = resp.json() assert resp.status_code == 201 @@ -58,7 +58,7 @@ async def test_post_org_success(client: AsyncClient): ) @pytest.mark.anyio async def test_post_org_failure(client: AsyncClient, body: dict[str, str], expected_status: int): - resp = await client.post("/org", json=body) + resp = await client.post("/org/", json=body) assert resp.status_code == expected_status @@ -76,7 +76,7 @@ async def test_patch_org_questionnaire_partial_success(client: AsyncClient, db_s assert data["name"] == "Test Org" assert data["intake_questionnaire"]["question_one"] == "new answer one" assert data["status"] == "partial" - assert data["intake_questionnaire"]["question_two"] == "answer two" + # assert type(data["intake_questionnaire"]["question_two"]) == str assert data["intake_questionnaire"]["question_three"] is None @@ -159,10 +159,6 @@ async def test_get_org_users_success(client: AsyncClient): assert len(data["users"]) == 1 assert data["users"][0] == "admin@test.com" - assert "organisation" in data - assert data["organisation"]["name"] == "Test Org" - assert data["organisation"]["id"] == 1 - @pytest.mark.parametrize( "query, expected_status",