Compare commits

..

6 commits

Author SHA1 Message Date
af680dbc38 feat: get/patch contact includes org info
Resolves #19
2026-06-03 09:45:48 +01:00
8a9f03ee0b feat: get users includes org info 2026-06-03 09:38:54 +01:00
7833386350 feat: patch questionnaire doesn't overwrite with none 2026-06-03 09:29:06 +01:00
c9cd75a7ad fix: missing src in imports 2026-06-03 09:15:25 +01:00
1845012cb7 feat: get org endpoint returns questionnaire 2026-06-02 16:36:56 +01:00
81a4cc6cca feat: org router endpoint cleanup
`/id/` removed from GET
Trailing `/` removed from POST and DELETE
2026-06-02 16:36:11 +01:00
6 changed files with 47 additions and 20 deletions

View file

@ -19,7 +19,7 @@ from fastapi import APIRouter, status
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from psycopg import errors from psycopg import errors
from service.exceptions import ServiceNotFoundException from src.service.exceptions import ServiceNotFoundException
from src.exceptions import ConflictException from src.exceptions import ConflictException
from src.database import db_dependency from src.database import db_dependency
from src.schemas import ResourceName from src.schemas import ResourceName

View file

@ -13,7 +13,7 @@ from pydantic import EmailStr, ConfigDict, Field
from src.service.schemas import ServiceIDMixin from src.service.schemas import ServiceIDMixin
from src.organisation.schemas import OrgIDMixin from src.organisation.schemas import OrgIDMixin
from src.schemas import CustomBaseModel from src.schemas import CustomBaseModel
from user.schemas import UserIDMixin from src.user.schemas import UserIDMixin
class UserSchema(CustomBaseModel): class UserSchema(CustomBaseModel):

View file

@ -25,7 +25,7 @@ def get_org_model(db: Session, request: Request, org_id: int):
root = "/api/v1" root = "/api/v1"
pre_approval_endpoints = [f"PATCH{root}/org/status", f"PATCH{root}/org/questionnaire", f"GET{root}/org/id", f"GET{root}/org/contact", f"PATCH{root}/org/contact"] pre_approval_endpoints = [f"PATCH{root}/org/status", f"PATCH{root}/org/questionnaire", f"GET{root}/org", f"GET{root}/org/contact", f"PATCH{root}/org/contact"]
current_request = f"{request.method}{request.url.path}" current_request = f"{request.method}{request.url.path}"
if current_request not in pre_approval_endpoints and org_model.status != OrgStatus.APPROVED: if current_request not in pre_approval_endpoints and org_model.status != OrgStatus.APPROVED:
raise AwaitingApprovalException(org_id) raise AwaitingApprovalException(org_id)

View file

@ -39,7 +39,7 @@ from src.organisation.schemas import OrgPostOrgRequest, OrgPatchQuestionnaireReq
OrgPatchContactRequest, \ OrgPatchContactRequest, \
OrgPostUserRequest, OrgGetUserResponse, OrgGetContactResponse, OrgGetOrgResponse, OrgPatchRootRequest, \ OrgPostUserRequest, OrgGetUserResponse, OrgGetContactResponse, OrgGetOrgResponse, OrgPatchRootRequest, \
OrgGetGroupResponse, OrgDeleteUserRequest, OrgDeleteOrgRequest, OrgPostOrgResponse, OrgPatchQuestionnaireResponse, \ OrgGetGroupResponse, OrgDeleteUserRequest, OrgDeleteOrgRequest, OrgPostOrgResponse, OrgPatchQuestionnaireResponse, \
OrgPatchStatusResponse, OrgPostUserResponse, OrgPatchRootResponse OrgPatchStatusResponse, OrgPostUserResponse, OrgPatchRootResponse, Questionnaire, OrgPatchContactResponse
router = APIRouter( router = APIRouter(
prefix="/org", prefix="/org",
@ -47,7 +47,7 @@ router = APIRouter(
) )
@router.get("/id", @router.get("",
summary="Get org details by ID.", summary="Get org details by ID.",
response_model=OrgGetOrgResponse, response_model=OrgGetOrgResponse,
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
@ -67,13 +67,14 @@ async def get_org_by_id(org_model: org_model_root_claim_query_dependency):
"owner_contact": org_model.owner_contact_rel.email, "owner_contact": org_model.owner_contact_rel.email,
"billing_contact": org_model.billing_contact_rel.email, "billing_contact": org_model.billing_contact_rel.email,
"security_contact": org_model.security_contact_rel.email, "security_contact": org_model.security_contact_rel.email,
"root_user": org_model.root_user_email "root_user": org_model.root_user_email,
"intake_questionnaire": org_model.intake_questionnaire
} }
return response return response
@router.post("/", @router.post("",
summary="Create new organisation.", summary="Create new organisation.",
status_code=status.HTTP_201_CREATED, status_code=status.HTTP_201_CREATED,
response_model=OrgPostOrgResponse, response_model=OrgPostOrgResponse,
@ -130,12 +131,21 @@ async def update_questionnaire(db: db_dependency, org_model: org_model_root_clai
The partial bool allows for submission of partially completed questionnaire and/or The partial bool allows for submission of partially completed questionnaire and/or
final "are you sure" check before setting the org to be in "submitted" status, awaiting admin approval. final "are you sure" check before setting the org to be in "submitted" status, awaiting admin approval.
""" """
org_model.intake_questionnaire = request_model.intake_questionnaire.model_dump() update_data = request_model.intake_questionnaire.model_dump(exclude_none=True)
questionnaire_model = Questionnaire(**org_model.intake_questionnaire)
for key, value in update_data.items():
if hasattr(questionnaire_model, key):
setattr(questionnaire_model, key, value)
else:
if key == "partial" or key == "organisation_id":
continue
raise UnprocessableContentException("Invalid keys in update request")
# Allows for partially completed questionnaires to be saved without being submitted for review # Allows for partially completed questionnaires to be saved without being submitted for review
if not request_model.partial: if not request_model.partial:
org_model.status = "submitted" org_model.status = "submitted"
org_model.intake_questionnaire = questionnaire_model.model_dump()
db.flush() db.flush()
response = OrgPatchQuestionnaireResponse(**org_model.__dict__) response = OrgPatchQuestionnaireResponse(**org_model.__dict__)
db.commit() db.commit()
@ -175,7 +185,7 @@ async def get_users(org_model: org_model_root_claim_query_dependency):
""" """
Returns a list of the email addresses of all users of the organisation. Returns a list of the email addresses of all users of the organisation.
""" """
return {"users": [user.email for user in org_model.user_rel]} return {"users": [user.email for user in org_model.user_rel], "organisation": org_model}
@router.post("/user", @router.post("/user",
@ -201,7 +211,7 @@ async def add_user_to_org(db: db_dependency, org_model: org_model_root_claim_bod
return response return response
@router.delete("/", @router.delete("",
summary="Delete organisation from the hub.", summary="Delete organisation from the hub.",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
responses={ responses={
@ -303,13 +313,13 @@ async def get_contact(org_model: org_model_root_claim_query_dependency, contact_
address = ContactAddress.model_validate(contact_model) address = ContactAddress.model_validate(contact_model)
contact_response = ContactModel.model_construct(**contact_model.__dict__, address=address) contact_response = ContactModel.model_construct(**contact_model.__dict__, address=address)
return {"contact": contact_response} return {"contact": contact_response, "organisation": org_model}
@router.patch("/contact", @router.patch("/contact",
summary="Update contact for organisation.", summary="Update contact for organisation.",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=OrgGetContactResponse, response_model=OrgPatchContactResponse,
responses={ responses={
status.HTTP_200_OK: {"description": "Successfully updated contact."}, status.HTTP_200_OK: {"description": "Successfully updated contact."},
status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."},
@ -347,4 +357,4 @@ async def update_contact(db: db_dependency, org_model: org_model_root_claim_body
db.commit() db.commit()
return {"contact": contact_response} return {"contact": contact_response, "organisation": org_model}

View file

@ -17,13 +17,17 @@ from src.user.schemas import UserIDMixin
from src.organisation.constants import Status, ContactType from src.organisation.constants import Status, ContactType
class OrgIDMixin(CustomBaseModel):
organisation_id: int = Field(gt=0)
class Questionnaire(CustomBaseModel): class Questionnaire(CustomBaseModel):
question_one: Optional[str] = None question_one: Optional[str] = None
question_two: Optional[str] = None question_two: Optional[str] = None
question_three: Optional[str] = None question_three: Optional[str] = None
class OrgIDMixin(CustomBaseModel): class OrgSchema(CustomBaseModel):
organisation_id: int = Field(gt=0) id: int
name: str
class OrgPostOrgRequest(CustomBaseModel): class OrgPostOrgRequest(CustomBaseModel):
@ -84,6 +88,7 @@ class OrgPatchRootResponse(CustomBaseModel):
class OrgGetUserResponse(CustomBaseModel): class OrgGetUserResponse(CustomBaseModel):
users: list[str] users: list[str]
organisation: OrgSchema
class OrgGetGroupResponse(CustomBaseModel): class OrgGetGroupResponse(CustomBaseModel):
groups: list[str] groups: list[str]
@ -92,6 +97,13 @@ class OrgGetContactResponse(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore") model_config = ConfigDict(from_attributes=True, extra="ignore")
contact: ContactModel contact: ContactModel
organisation: OrgSchema
class OrgPatchContactResponse(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore")
contact: ContactModel
organisation: OrgSchema
class OrgGetOrgResponse(CustomBaseModel): class OrgGetOrgResponse(CustomBaseModel):
name: str name: str
@ -100,6 +112,7 @@ class OrgGetOrgResponse(CustomBaseModel):
owner_contact: Optional[str] = None owner_contact: Optional[str] = None
billing_contact: Optional[str] = None billing_contact: Optional[str] = None
security_contact: Optional[str] = None security_contact: Optional[str] = None
intake_questionnaire: Optional[Questionnaire] = None
class OrgDeleteOrgRequest(OrgIDMixin): class OrgDeleteOrgRequest(OrgIDMixin):
pass pass

View file

@ -11,7 +11,7 @@ from .conftest import client
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_success(client: AsyncClient): async def test_get_org_success(client: AsyncClient):
resp = await client.get("/org/id?org_id=1") resp = await client.get("/org?org_id=1")
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
@ -33,14 +33,14 @@ async def test_get_org_success(client: AsyncClient):
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_failure(client: AsyncClient, query: str, expected_status: int): async def test_get_org_failure(client: AsyncClient, query: str, expected_status: int):
resp = await client.get(f"/org/id?{query}") resp = await client.get(f"/org?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_success(client: AsyncClient): async def test_post_org_success(client: AsyncClient):
resp = await client.post("/org/", json={"name": "New Test Org"}) resp = await client.post("/org", json={"name": "New Test Org"})
data = resp.json() data = resp.json()
assert resp.status_code == 201 assert resp.status_code == 201
@ -58,7 +58,7 @@ async def test_post_org_success(client: AsyncClient):
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_failure(client: AsyncClient, body: dict[str, str], expected_status: int): async def test_post_org_failure(client: AsyncClient, body: dict[str, str], expected_status: int):
resp = await client.post("/org/", json=body) resp = await client.post("/org", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -76,7 +76,7 @@ async def test_patch_org_questionnaire_partial_success(client: AsyncClient, db_s
assert data["name"] == "Test Org" assert data["name"] == "Test Org"
assert data["intake_questionnaire"]["question_one"] == "new answer one" assert data["intake_questionnaire"]["question_one"] == "new answer one"
assert data["status"] == "partial" assert data["status"] == "partial"
# assert type(data["intake_questionnaire"]["question_two"]) == str assert data["intake_questionnaire"]["question_two"] == "answer two"
assert data["intake_questionnaire"]["question_three"] is None assert data["intake_questionnaire"]["question_three"] is None
@ -159,6 +159,10 @@ async def test_get_org_users_success(client: AsyncClient):
assert len(data["users"]) == 1 assert len(data["users"]) == 1
assert data["users"][0] == "admin@test.com" assert data["users"][0] == "admin@test.com"
assert "organisation" in data
assert data["organisation"]["name"] == "Test Org"
assert data["organisation"]["id"] == 1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"query, expected_status", "query, expected_status",