diff --git a/src/iam/router.py b/src/iam/router.py index 316c5fa..2895baa 100644 --- a/src/iam/router.py +++ b/src/iam/router.py @@ -19,7 +19,7 @@ from fastapi import APIRouter, status from sqlalchemy.exc import IntegrityError from psycopg import errors -from service.exceptions import ServiceNotFoundException +from src.service.exceptions import ServiceNotFoundException from src.exceptions import ConflictException from src.database import db_dependency from src.schemas import ResourceName diff --git a/src/iam/schemas.py b/src/iam/schemas.py index fa6adfc..0d370b8 100644 --- a/src/iam/schemas.py +++ b/src/iam/schemas.py @@ -13,7 +13,7 @@ from pydantic import EmailStr, ConfigDict, Field from src.service.schemas import ServiceIDMixin from src.organisation.schemas import OrgIDMixin from src.schemas import CustomBaseModel -from user.schemas import UserIDMixin +from src.user.schemas import UserIDMixin class UserSchema(CustomBaseModel): diff --git a/src/organisation/dependencies.py b/src/organisation/dependencies.py index ec8805c..728b8d0 100644 --- a/src/organisation/dependencies.py +++ b/src/organisation/dependencies.py @@ -25,7 +25,7 @@ def get_org_model(db: Session, request: Request, org_id: int): root = "/api/v1" - pre_approval_endpoints = [f"PATCH{root}/org/status", f"PATCH{root}/org/questionnaire", f"GET{root}/org/id", f"GET{root}/org/contact", f"PATCH{root}/org/contact"] + pre_approval_endpoints = [f"PATCH{root}/org/status", f"PATCH{root}/org/questionnaire", f"GET{root}/org", f"GET{root}/org/contact", f"PATCH{root}/org/contact"] current_request = f"{request.method}{request.url.path}" if current_request not in pre_approval_endpoints and org_model.status != OrgStatus.APPROVED: raise AwaitingApprovalException(org_id) diff --git a/src/organisation/router.py b/src/organisation/router.py index 65d2d71..a9672b1 100644 --- a/src/organisation/router.py +++ b/src/organisation/router.py @@ -39,7 +39,7 @@ from src.organisation.schemas import OrgPostOrgRequest, OrgPatchQuestionnaireReq OrgPatchContactRequest, \ OrgPostUserRequest, OrgGetUserResponse, OrgGetContactResponse, OrgGetOrgResponse, OrgPatchRootRequest, \ OrgGetGroupResponse, OrgDeleteUserRequest, OrgDeleteOrgRequest, OrgPostOrgResponse, OrgPatchQuestionnaireResponse, \ - OrgPatchStatusResponse, OrgPostUserResponse, OrgPatchRootResponse + OrgPatchStatusResponse, OrgPostUserResponse, OrgPatchRootResponse, Questionnaire, OrgPatchContactResponse router = APIRouter( prefix="/org", @@ -47,7 +47,7 @@ router = APIRouter( ) -@router.get("/id", +@router.get("", summary="Get org details by ID.", response_model=OrgGetOrgResponse, status_code=status.HTTP_200_OK, @@ -67,13 +67,14 @@ async def get_org_by_id(org_model: org_model_root_claim_query_dependency): "owner_contact": org_model.owner_contact_rel.email, "billing_contact": org_model.billing_contact_rel.email, "security_contact": org_model.security_contact_rel.email, - "root_user": org_model.root_user_email + "root_user": org_model.root_user_email, + "intake_questionnaire": org_model.intake_questionnaire } return response -@router.post("/", +@router.post("", summary="Create new organisation.", status_code=status.HTTP_201_CREATED, response_model=OrgPostOrgResponse, @@ -130,12 +131,21 @@ async def update_questionnaire(db: db_dependency, org_model: org_model_root_clai The partial bool allows for submission of partially completed questionnaire and/or final "are you sure" check before setting the org to be in "submitted" status, awaiting admin approval. """ - org_model.intake_questionnaire = request_model.intake_questionnaire.model_dump() + update_data = request_model.intake_questionnaire.model_dump(exclude_none=True) + questionnaire_model = Questionnaire(**org_model.intake_questionnaire) + for key, value in update_data.items(): + if hasattr(questionnaire_model, key): + setattr(questionnaire_model, key, value) + else: + if key == "partial" or key == "organisation_id": + continue + raise UnprocessableContentException("Invalid keys in update request") # Allows for partially completed questionnaires to be saved without being submitted for review if not request_model.partial: org_model.status = "submitted" + org_model.intake_questionnaire = questionnaire_model.model_dump() db.flush() response = OrgPatchQuestionnaireResponse(**org_model.__dict__) db.commit() @@ -175,7 +185,7 @@ async def get_users(org_model: org_model_root_claim_query_dependency): """ Returns a list of the email addresses of all users of the organisation. """ - return {"users": [user.email for user in org_model.user_rel]} + return {"users": [user.email for user in org_model.user_rel], "organisation": org_model} @router.post("/user", @@ -201,7 +211,7 @@ async def add_user_to_org(db: db_dependency, org_model: org_model_root_claim_bod return response -@router.delete("/", +@router.delete("", summary="Delete organisation from the hub.", status_code=status.HTTP_204_NO_CONTENT, responses={ @@ -303,13 +313,13 @@ async def get_contact(org_model: org_model_root_claim_query_dependency, contact_ address = ContactAddress.model_validate(contact_model) contact_response = ContactModel.model_construct(**contact_model.__dict__, address=address) - return {"contact": contact_response} + return {"contact": contact_response, "organisation": org_model} @router.patch("/contact", summary="Update contact for organisation.", status_code=status.HTTP_200_OK, - response_model=OrgGetContactResponse, + response_model=OrgPatchContactResponse, responses={ status.HTTP_200_OK: {"description": "Successfully updated contact."}, status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, @@ -347,4 +357,4 @@ async def update_contact(db: db_dependency, org_model: org_model_root_claim_body db.commit() - return {"contact": contact_response} + return {"contact": contact_response, "organisation": org_model} diff --git a/src/organisation/schemas.py b/src/organisation/schemas.py index c34ef16..e2fd6a4 100644 --- a/src/organisation/schemas.py +++ b/src/organisation/schemas.py @@ -17,13 +17,17 @@ from src.user.schemas import UserIDMixin from src.organisation.constants import Status, ContactType +class OrgIDMixin(CustomBaseModel): + organisation_id: int = Field(gt=0) + class Questionnaire(CustomBaseModel): question_one: Optional[str] = None question_two: Optional[str] = None question_three: Optional[str] = None -class OrgIDMixin(CustomBaseModel): - organisation_id: int = Field(gt=0) +class OrgSchema(CustomBaseModel): + id: int + name: str class OrgPostOrgRequest(CustomBaseModel): @@ -84,6 +88,7 @@ class OrgPatchRootResponse(CustomBaseModel): class OrgGetUserResponse(CustomBaseModel): users: list[str] + organisation: OrgSchema class OrgGetGroupResponse(CustomBaseModel): groups: list[str] @@ -92,6 +97,13 @@ class OrgGetContactResponse(CustomBaseModel): model_config = ConfigDict(from_attributes=True, extra="ignore") contact: ContactModel + organisation: OrgSchema + +class OrgPatchContactResponse(CustomBaseModel): + model_config = ConfigDict(from_attributes=True, extra="ignore") + + contact: ContactModel + organisation: OrgSchema class OrgGetOrgResponse(CustomBaseModel): name: str @@ -100,6 +112,7 @@ class OrgGetOrgResponse(CustomBaseModel): owner_contact: Optional[str] = None billing_contact: Optional[str] = None security_contact: Optional[str] = None + intake_questionnaire: Optional[Questionnaire] = None class OrgDeleteOrgRequest(OrgIDMixin): pass diff --git a/test/test_organisation.py b/test/test_organisation.py index a629d04..6dfaa21 100644 --- a/test/test_organisation.py +++ b/test/test_organisation.py @@ -11,7 +11,7 @@ from .conftest import client @pytest.mark.anyio async def test_get_org_success(client: AsyncClient): - resp = await client.get("/org/id?org_id=1") + resp = await client.get("/org?org_id=1") data = resp.json() assert resp.status_code == 200 @@ -33,14 +33,14 @@ async def test_get_org_success(client: AsyncClient): ) @pytest.mark.anyio async def test_get_org_failure(client: AsyncClient, query: str, expected_status: int): - resp = await client.get(f"/org/id?{query}") + resp = await client.get(f"/org?{query}") assert resp.status_code == expected_status @pytest.mark.anyio async def test_post_org_success(client: AsyncClient): - resp = await client.post("/org/", json={"name": "New Test Org"}) + resp = await client.post("/org", json={"name": "New Test Org"}) data = resp.json() assert resp.status_code == 201 @@ -58,7 +58,7 @@ async def test_post_org_success(client: AsyncClient): ) @pytest.mark.anyio async def test_post_org_failure(client: AsyncClient, body: dict[str, str], expected_status: int): - resp = await client.post("/org/", json=body) + resp = await client.post("/org", json=body) assert resp.status_code == expected_status @@ -76,7 +76,7 @@ async def test_patch_org_questionnaire_partial_success(client: AsyncClient, db_s assert data["name"] == "Test Org" assert data["intake_questionnaire"]["question_one"] == "new answer one" assert data["status"] == "partial" - # assert type(data["intake_questionnaire"]["question_two"]) == str + assert data["intake_questionnaire"]["question_two"] == "answer two" assert data["intake_questionnaire"]["question_three"] is None @@ -159,6 +159,10 @@ async def test_get_org_users_success(client: AsyncClient): assert len(data["users"]) == 1 assert data["users"][0] == "admin@test.com" + assert "organisation" in data + assert data["organisation"]["name"] == "Test Org" + assert data["organisation"]["id"] == 1 + @pytest.mark.parametrize( "query, expected_status",