Compare commits

...

5 commits

Author SHA1 Message Date
88a64d2047 feat: root user dependencies also allow super admins 2026-05-28 10:56:45 +01:00
9efd86cd5f feat: org status check in auth dependencies
There is a hardcoded list of methods/endpoints for which the status check isn't done. i.e. the endpoints which need to be accessed before the org is approved.

Resolves #11
2026-05-28 10:56:45 +01:00
4bf5933376 minor: org pydantic model cleanup
Contact models also updated since they are now fully incorporated into orgs.

Issue #9
2026-05-27 16:51:46 +01:00
216836e2fd minor: cleanup service router imports 2026-05-27 16:30:12 +01:00
1ed0cfb38c feat: handling for integrity errors
Resolves: #7
2026-05-27 16:26:34 +01:00
9 changed files with 138 additions and 117 deletions

View file

@ -19,6 +19,13 @@ from src.organisation.models import Organisation as Org
from src.auth.exceptions import UnauthorizedException
def is_super_admin(user_model) -> bool:
super_admin_emails = ["chris@sr2.uk"]
if user_model.email not in super_admin_emails:
raise UnauthorizedException()
return True
async def org_query_user_claims(org_model: org_model_query_dependency, user_model: user_model_claims_dependency):
if user_model in org_model.user_rel:
return True
@ -33,6 +40,9 @@ async def org_query_root_claims(user_model: user_model_claims_dependency, org_mo
if org_model.root_user_id == user_model.id:
return org_model
if is_super_admin(user_model):
return org_model
raise UnauthorizedException()
@ -43,17 +53,20 @@ async def org_body_root_claims(user_model: user_model_claims_dependency, org_mod
if org_model.root_user_id == user_model.id:
return org_model
if is_super_admin(user_model):
return org_model
raise UnauthorizedException()
org_model_root_claim_body_dependency = Annotated[type[Org], Depends(org_body_root_claims)]
async def is_super_admin(user_model: user_model_claims_dependency):
super_admin_emails = []
if user_model.email not in super_admin_emails:
raise UnauthorizedException()
return True
async def user_model_super_admin(user_model: user_model_claims_dependency):
if is_super_admin(user_model):
return user_model
raise UnauthorizedException()
super_admin_dependency = Annotated[bool, Depends(is_super_admin)]
super_admin_dependency = Annotated[bool, Depends(user_model_super_admin)]

View file

@ -9,7 +9,6 @@ from typing import Optional
from pydantic import EmailStr, ConfigDict
from src.organisation.constants import ContactType
from src.schemas import CustomBaseModel
@ -25,50 +24,11 @@ class ContactAddress(CustomBaseModel):
postal_code: Optional[str] = None
class ContactContactGetResponse(CustomBaseModel):
email: str
first_name: str
last_name: str
phonenumber: str
vat_number: Optional[str] = None
class ContactAddressGetResponse(CustomBaseModel):
post_office_box_number: Optional[str] = None
street_address: Optional[str] = None # If using a PO box, there would be no street address
street_address_line_2: Optional[str] = None
locality: str
address_region: Optional[str] = None
country_code: str
postal_code: str
class ContactContactPostRequest(CustomBaseModel):
email: EmailStr
first_name: str
last_name: str
phonenumber: str
vat_number: Optional[str] = None
post_office_box_number: Optional[str] = None
street_address: Optional[str] = None
street_address_line_2: Optional[str] = None
locality: str
address_region: Optional[str] = None
country_code: str
postal_code: str
class ContactUpdateRequest(CustomBaseModel):
class ContactModel(CustomBaseModel):
email: Optional[EmailStr] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
phonenumber: Optional[str] = None
vat_number: Optional[str] = None
post_office_box_number: Optional[str] = None
street_address: Optional[str] = None
street_address_line_2: Optional[str] = None
locality: Optional[str] = None
address_region: Optional[str] = None
country_code: Optional[str] = None
postal_code: Optional[str] = None
class ContactOrgGetResponse(CustomBaseModel):
name: str
contact_types: list[ContactType]
address: ContactAddress

View file

@ -13,3 +13,12 @@ class UnprocessableContent(HTTPException):
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail=detail,
)
class Conflict(HTTPException):
def __init__(self, message: Optional[str] = None) -> None:
detail = "Conflict" if not message else message
super().__init__(
status_code=status.HTTP_409_CONFLICT,
detail=detail,
)

View file

@ -6,8 +6,10 @@ Endpoints:
- Endpoints: Description
"""
from fastapi import APIRouter, status
from sqlalchemy.exc import IntegrityError
from psycopg import errors
from src.exceptions import Conflict
from src.database import db_dependency
from src.schemas import ResourceName
from src.auth.exceptions import UnauthorizedException
@ -84,7 +86,11 @@ async def create_group(db: db_dependency, org_model: org_model_root_claim_body_d
group_model = Group(name=request_model.name, org_id=org_model.id)
db.add(group_model)
db.flush()
try:
db.flush()
except IntegrityError as e:
if isinstance(e.orig, errors.UniqueViolation):
raise Conflict("Group with this name already exists")
response = GroupResponse(**group_model.__dict__)
db.commit()
return {"group": response}
@ -95,6 +101,9 @@ async def add_group_permission(db: db_dependency, group_model: group_model_body_
if group_model.org_id == org_model.id:
raise UnauthorizedException()
if perm_model in group_model.permission_rel:
raise Conflict("Group already has this permission")
group_model.permission_rel.append(perm_model)
db.flush()
@ -108,6 +117,9 @@ async def add_group_user(db: db_dependency, group_model: group_model_body_depend
if group_model.org_id == org_model.id:
raise UnauthorizedException()
if user_model in group_model.user_rel:
raise Conflict("User already in group")
group_model.user_rel.append(user_model)
db.flush()
response = IAMPutGroupUserResponse(group=GroupResponse(**group_model.__dict__), users=group_model.user_rel)
@ -151,8 +163,11 @@ async def get_permissions(db: db_dependency, org_model: org_model_root_claim_bod
@router.post("/permission")
async def create_new_permission(db: db_dependency, su: super_admin_dependency, request_mode: IAMPostPermissionRequest):
perm_model = Perm(**request_mode.__dict__)
db.add(perm_model)
try:
db.add(perm_model)
except IntegrityError as e:
if isinstance(e.orig, errors.UniqueViolation):
raise Conflict(message="Permission already exists")
db.flush()
response = IAMPostPermissionResponse(permission=PermissionResponse(**perm_model.__dict__))
db.commit()

View file

@ -9,35 +9,42 @@ Functions:
- List: Description
- Functions: Description
"""
from typing import Annotated
from typing import Annotated, Optional
from fastapi import Depends, Query
from fastapi import Depends, Query, Request
from src.database import db_dependency
from src.organisation.schemas import OrgIDMixin
from src.organisation.models import Organisation as Org
from src.organisation.exceptions import OrgNotFoundException
from src.organisation.exceptions import OrgNotFoundException, AwaitingApprovalException
from src.organisation.constants import Status as OrgStatus
def get_org_model_query(db: db_dependency, org_id: Annotated[int, Query(gt=0)]) -> type[Org]:
def get_org_model(db, request: Request, org_id: int):
org_model = db.get(Org, org_id)
if org_model is None:
raise OrgNotFoundException(org_id)
pre_approval_endpoints = ["PATCH/org/status", "PATCH/org/questionnaire", "GET/org/id"]
current_request = f"{request.method}{request.url.path}"
if current_request not in pre_approval_endpoints and org_model.status != OrgStatus.APPROVED:
raise AwaitingApprovalException(org_id)
return org_model
def get_org_model_query(db: db_dependency, request: Request, org_id: Annotated[int, Query(gt=0)]) -> type[Org]:
return get_org_model(db, request, org_id)
org_model_query_dependency = Annotated[type[Org], Depends(get_org_model_query)]
def get_org_model_body(db: db_dependency, request_model: OrgIDMixin) -> type[Org]:
org_id = getattr(request_model, "organisation_id", None)
def get_org_model_body(db: db_dependency, request: Request, request_model: OrgIDMixin) -> type[Org]:
org_id: Optional[int] = getattr(request_model, "organisation_id", None)
if org_id is None:
raise OrgNotFoundException
org_model = db.get(Org, org_id)
if org_model is None:
raise OrgNotFoundException(org_id)
return org_model
return get_org_model(db, request, org_id)
org_model_body_dependency = Annotated[type[Org], Depends(get_org_model_body)]

View file

@ -17,3 +17,11 @@ class OrgNotFoundException(HTTPException):
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
class AwaitingApprovalException(HTTPException):
def __init__(self, org_id: Optional[int] = None) -> None:
detail = "Organisation has not been approved." if org_id is None else f"Organisation with ID '{org_id}' has not been approved."
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,
)

View file

@ -16,8 +16,11 @@ from typing import Annotated, Optional
from fastapi import APIRouter, status
from fastapi.params import Query
from psycopg.errors import UniqueViolation
from sqlalchemy.exc import IntegrityError
from src.exceptions import UnprocessableContent
from contact.schemas import ContactModel
from src.exceptions import UnprocessableContent, Conflict
from src.contact.models import Contact
from src.contact.schemas import ContactAddress
from src.contact.exceptions import ContactNotFoundException
@ -29,10 +32,10 @@ from src.auth.dependencies import super_admin_dependency, org_model_root_claim_q
from src.organisation.dependencies import org_model_body_dependency
from src.organisation.constants import ContactType
from src.organisation.models import Organisation as Org
from src.organisation.schemas import OrgOrgPostRequest, OrgQuestionnairePatchRequest, OrgStatusPatchRequest, \
OrgContactPatchRequest, \
OrgUserPostRequest, OrgUserGetResponse, OrgContactGetResponse, OrgOrgGetResponse, OrgRootPatchRequest, \
OrgGroupGetResponse, OrgUserDeleteRequest, OrgDeleteOrgRequest
from src.organisation.schemas import OrgPostOrgRequest, OrgPatchQuestionnaireRequest, OrgPatchStatusRequest, \
OrgPatchContactRequest, \
OrgPostUserRequest, OrgGetUserResponse, OrgGetContactResponse, OrgGetOrgResponse, OrgPatchRootRequest, \
OrgGetGroupResponse, OrgDeleteUserRequest, OrgDeleteOrgRequest
router = APIRouter(
@ -41,7 +44,7 @@ router = APIRouter(
)
@router.get("/id", response_model=OrgOrgGetResponse)
@router.get("/id", response_model=OrgGetOrgResponse)
async def get_org_by_id(org_model: org_model_root_claim_query_dependency):
response = {
"name": org_model.name,
@ -56,7 +59,7 @@ async def get_org_by_id(org_model: org_model_root_claim_query_dependency):
@router.post("/")
async def create_org(db: db_dependency, user_model: user_model_claims_dependency, request_model: OrgOrgPostRequest):
async def create_org(db: db_dependency, user_model: user_model_claims_dependency, request_model: OrgPostOrgRequest):
if request_model.intake_questionnaire:
intake_questionnaire = request_model.intake_questionnaire.model_dump()
else:
@ -66,9 +69,12 @@ async def create_org(db: db_dependency, user_model: user_model_claims_dependency
org_model.status = "partial" # Status is always set to partial at first, see update_questionnaire() doc
db.add(org_model)
db.flush()
try:
db.flush()
except IntegrityError as e:
if isinstance(e.orig, UniqueViolation):
raise Conflict(message="Organisation with this name already exists")
# Adds currently logged-in user to org users list and sets them as root_user
user_model = db.get(User, db_id)
org_model.user_rel.append(user_model)
org_model.root_user_rel = user_model
for contact_type in ["billing_contact_id", "security_contact_id", "owner_contact_id"]:
@ -80,7 +86,7 @@ async def create_org(db: db_dependency, user_model: user_model_claims_dependency
@router.patch("/questionnaire")
async def update_questionnaire(db: db_dependency, org_model: org_model_root_claim_query_dependency, request_model: OrgQuestionnairePatchRequest):
async def update_questionnaire(db: db_dependency, org_model: org_model_root_claim_query_dependency, request_model: OrgPatchQuestionnaireRequest):
"""
Route for updating questionnaire.
The partial bool allows for submission of partially completed questionnaire and/or
@ -96,21 +102,21 @@ async def update_questionnaire(db: db_dependency, org_model: org_model_root_clai
@router.patch("/status")
async def update_status(db: db_dependency, org_model: org_model_body_dependency, su: super_admin_dependency, request_model: OrgStatusPatchRequest):
async def update_status(db: db_dependency, org_model: org_model_body_dependency, su: super_admin_dependency, request_model: OrgPatchStatusRequest):
org_model.status = request_model.status
db.commit()
@router.get("/users", response_model=OrgUserGetResponse)
@router.get("/users", response_model=OrgGetUserResponse)
async def get_users(org_model: org_model_root_claim_query_dependency):
return {"users": [user.email for user in org_model.user_rel]}
@router.post("/users")
async def add_user_to_org(db: db_dependency, org_model: org_model_root_claim_body_dependency, user_model: user_model_body_dependency, request_model: OrgUserPostRequest):
async def add_user_to_org(db: db_dependency, org_model: org_model_root_claim_body_dependency, user_model: user_model_body_dependency, request_model: OrgPostUserRequest):
if user_model in org_model.user_rel:
return
raise Conflict(message="User already a part of this organisation")
org_model.user_rel.append(user_model)
db.commit()
@ -122,18 +128,18 @@ async def delete_organisation_by_id(db: db_dependency, org_model: org_model_body
@router.patch("/root_user", status_code=status.HTTP_204_NO_CONTENT)
async def update_root_user(db: db_dependency, org_model: org_model_body_dependency, user_model: user_model_body_dependency, su: super_admin_dependency, request_model: OrgRootPatchRequest):
async def update_root_user(db: db_dependency, org_model: org_model_body_dependency, user_model: user_model_body_dependency, su: super_admin_dependency, request_model: OrgPatchRootRequest):
org_model.root_user_rel = user_model
db.commit()
@router.get("/groups", response_model=OrgGroupGetResponse)
@router.get("/groups", response_model=OrgGetGroupResponse)
async def get_org_groups(org_model: org_model_root_claim_query_dependency):
return {"groups": [group.name for group in org_model.group_rel]}
@router.delete("/user", status_code=status.HTTP_204_NO_CONTENT)
async def remove_user_from_org(db: db_dependency, org_model: org_model_root_claim_body_dependency, user_model: user_model_body_dependency, request_model: OrgUserDeleteRequest):
async def remove_user_from_org(db: db_dependency, org_model: org_model_root_claim_body_dependency, user_model: user_model_body_dependency, request_model: OrgDeleteUserRequest):
if user_model not in org_model.user_rel:
return
@ -141,7 +147,7 @@ async def remove_user_from_org(db: db_dependency, org_model: org_model_root_clai
db.commit()
@router.get("/contact", response_model=OrgContactGetResponse)
@router.get("/contact", response_model=OrgGetContactResponse)
async def get_contact(org_model: org_model_root_claim_query_dependency, contact_type: Annotated[ContactType, Query()]):
match contact_type:
case "billing":
@ -156,14 +162,14 @@ async def get_contact(org_model: org_model_root_claim_query_dependency, contact_
if contact_model is None:
raise ContactNotFoundException()
return OrgContactGetResponse.model_construct(
**contact_model.__dict__,
address=ContactAddress.model_validate(contact_model)
)
address = ContactAddress.model_validate(contact_model)
contact_response = ContactModel.model_construct(**contact_model.__dict__, address=address)
return {"contact": contact_response}
@router.patch("/contact", response_model=OrgContactGetResponse)
async def update_contact(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: OrgContactPatchRequest):
@router.patch("/contact", response_model=OrgGetContactResponse)
async def update_contact(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: OrgPatchContactRequest):
match request_model.contact_type:
case "billing":
contact_model = org_model.billing_contact_rel
@ -185,11 +191,9 @@ async def update_contact(db: db_dependency, org_model: org_model_root_claim_body
raise UnprocessableContent("Invalid keys in update request")
db.flush()
response = OrgContactGetResponse.model_construct(
**contact_model.__dict__,
address=ContactAddress.model_validate(contact_model)
)
address = ContactAddress.model_validate(contact_model)
contact_response = ContactModel.model_construct(**contact_model.__dict__, address=address)
db.commit()
return response
return {"contact": contact_response}

View file

@ -10,8 +10,11 @@ from typing import Optional
from pydantic import EmailStr, ConfigDict
from src.schemas import CustomBaseModel
from src.contact.schemas import ContactModel
from src.user.schemas import UserIDMixin
from src.organisation.constants import Status, ContactType
from src.contact.schemas import ContactAddress
class OrgQuestionnaire(CustomBaseModel):
question_one: str
@ -21,18 +24,19 @@ class OrgQuestionnaire(CustomBaseModel):
class OrgIDMixin(CustomBaseModel):
organisation_id: int
class OrgOrgPostRequest(CustomBaseModel):
class OrgPostOrgRequest(CustomBaseModel):
name: str
intake_questionnaire: Optional[OrgQuestionnaire] = None
class OrgQuestionnairePatchRequest(OrgIDMixin):
class OrgPatchQuestionnaireRequest(OrgIDMixin):
intake_questionnaire: OrgQuestionnaire
partial: bool
class OrgStatusPatchRequest(OrgIDMixin):
class OrgPatchStatusRequest(OrgIDMixin):
status: Status
class OrgContactPatchRequest(OrgIDMixin):
class OrgPatchContactRequest(OrgIDMixin):
contact_type: ContactType
email: Optional[EmailStr] = None
@ -48,33 +52,27 @@ class OrgContactPatchRequest(OrgIDMixin):
country_code: Optional[str] = None
postal_code: Optional[str] = None
class OrgUserPostRequest(OrgIDMixin):
user_id: int
class OrgPostUserRequest(OrgIDMixin, UserIDMixin):
pass
class OrgUserDeleteRequest(OrgIDMixin):
user_id: int
class OrgDeleteUserRequest(OrgIDMixin, UserIDMixin):
pass
class OrgRootPatchRequest(OrgIDMixin):
user_id: int
class OrgPatchRootRequest(OrgIDMixin, UserIDMixin):
pass
class OrgUserGetResponse(CustomBaseModel):
class OrgGetUserResponse(CustomBaseModel):
users: list[str]
class OrgGroupGetResponse(CustomBaseModel):
class OrgGetGroupResponse(CustomBaseModel):
groups: list[str]
class OrgContactGetResponse(CustomBaseModel):
class OrgGetContactResponse(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore")
email: Optional[str] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
phonenumber: Optional[str] = None
vat_number: Optional[str] = None
contact: ContactModel
address: ContactAddress
class OrgOrgGetResponse(CustomBaseModel):
class OrgGetOrgResponse(CustomBaseModel):
name: str
status: Status
root_user: Optional[str] = None
@ -83,4 +81,4 @@ class OrgOrgGetResponse(CustomBaseModel):
security_contact: Optional[str] = None
class OrgDeleteOrgRequest(OrgIDMixin):
pass
pass

View file

@ -6,10 +6,12 @@ Endpoints:
- Endpoints: Description
"""
from fastapi import APIRouter, status
from psycopg.errors import UniqueViolation
from sqlalchemy.exc import IntegrityError
from src.exceptions import Conflict
from src.database import db_dependency
from src.auth.service import claims_dependency
from src.auth.dependencies import super_admin_dependency, org_model_root_claim_query_dependency, org_model_root_claim_body_dependency
from src.auth.dependencies import super_admin_dependency, org_model_root_claim_query_dependency
from src.service.models import Service
from src.service.utils import generate_api_key
@ -34,7 +36,12 @@ async def register_service(db: db_dependency, su: super_admin_dependency, servic
service_model = Service(name=service_request.name, api_key=key)
db.add(service_model)
db.flush()
try:
db.flush()
except IntegrityError as e:
if isinstance(e.orig, UniqueViolation):
raise Conflict(message="Service with this name already exists")
db.commit()
response = ServiceWithKeyResponse(**service_model.__dict__)
db.commit()
return {"service": response}