feat: custom exceptions instead of direct fastapi.httpexceptions

Resolves #2
This commit is contained in:
Chris Milne 2026-05-27 14:58:10 +01:00
parent d3d3b2ca63
commit 868e56ce40
9 changed files with 73 additions and 43 deletions

View file

@ -10,18 +10,19 @@ Functions:
- Functions: Description - Functions: Description
""" """
from typing import Annotated, Any from typing import Annotated, Any
from fastapi import Depends, HTTPException from fastapi import Depends
from src.user.dependencies import user_model_claims_dependency from src.user.dependencies import user_model_claims_dependency
from src.organisation.dependencies import org_model_query_dependency from src.organisation.dependencies import org_model_query_dependency
from src.auth.exceptions import UnauthorizedException
async def org_query_user_claims(org_model: org_model_query_dependency, user_model: user_model_claims_dependency): async def org_query_user_claims(org_model: org_model_query_dependency, user_model: user_model_claims_dependency):
if user_model in org_model.user_rel: if user_model in org_model.user_rel:
return True return True
raise HTTPException(status_code=401, detail="Not authorised") raise UnauthorizedException()
org_query_user_claims_dependency = Annotated[dict[str, Any], Depends(org_query_user_claims)] org_query_user_claims_dependency = Annotated[dict[str, Any], Depends(org_query_user_claims)]
@ -31,7 +32,7 @@ async def org_query_root_claims(user_model: user_model_claims_dependency, org_mo
if org_model.root_user_id == user_model.id: if org_model.root_user_id == user_model.id:
return True return True
raise HTTPException(status_code=401, detail="Not authorised") raise UnauthorizedException()
org_query_root_claims_dependency = Annotated[dict[str, Any], Depends(org_query_root_claims)] org_query_root_claims_dependency = Annotated[dict[str, Any], Depends(org_query_root_claims)]
@ -40,8 +41,7 @@ org_query_root_claims_dependency = Annotated[dict[str, Any], Depends(org_query_r
async def is_super_admin(user_model: user_model_claims_dependency): async def is_super_admin(user_model: user_model_claims_dependency):
super_admin_emails = [] super_admin_emails = []
if user_model.email not in super_admin_emails: if user_model.email not in super_admin_emails:
raise HTTPException(status_code=401, detail="Not authorised") raise UnauthorizedException()
return True return True

View file

@ -5,3 +5,15 @@ Exceptions:
- List: Description - List: Description
- Exceptions: Description - Exceptions: Description
""" """
from typing import Optional
from fastapi import HTTPException, status
class UnauthorizedException(HTTPException):
def __init__(self, message: Optional[str] = None) -> None:
detail = "Not authorized" if not message else message
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,
)

View file

@ -13,9 +13,10 @@ from joserfc.errors import ExpiredTokenError
from joserfc.jwk import KeySet from joserfc.jwk import KeySet
from urllib.request import urlopen from urllib.request import urlopen
from fastapi import Depends, HTTPException from fastapi import Depends
from fastapi.security import OpenIdConnect from fastapi.security import OpenIdConnect
from src.auth.exceptions import UnauthorizedException
from src.auth.config import auth_settings from src.auth.config import auth_settings
from src.user.service import add_user_to_db from src.user.service import add_user_to_db
@ -50,8 +51,7 @@ async def get_current_user(oidc_auth_string: oidc_dependency) -> dict[str, Any]:
try: try:
claims_requests.validate(token.claims) claims_requests.validate(token.claims)
except ExpiredTokenError: except ExpiredTokenError:
raise HTTPException(status_code=401, detail="Token expired") raise UnauthorizedException(message="Token is expired")
db_id = await add_user_to_db(token.claims) db_id = await add_user_to_db(token.claims)
token.claims["db_id"] = db_id token.claims["db_id"] = db_id

View file

@ -5,3 +5,15 @@ Exceptions:
- List: Description - List: Description
- Exceptions: Description - Exceptions: Description
""" """
from typing import Optional
from fastapi import HTTPException, status
class ContactNotFoundException(HTTPException):
def __init__(self, contact_id: Optional[int] = None) -> None:
detail = "Contact not found" if contact_id is None else f"Contact with ID '{contact_id}' was not found."
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)

View file

@ -1,3 +1,15 @@
""" """
Global exceptions Global exceptions
""" """
from typing import Optional
from fastapi import HTTPException, status
class UnprocessableContent(HTTPException):
def __init__(self, message: Optional[str] = None) -> None:
detail = "Not authorized" if not message else message
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail=detail,
)

View file

@ -5,8 +5,9 @@ Endpoints:
- List: Description - List: Description
- Endpoints: Description - Endpoints: Description
""" """
from fastapi import APIRouter, HTTPException, status from fastapi import APIRouter, status
from auth.exceptions import UnauthorizedException
from src.database import db_dependency from src.database import db_dependency
from src.iam.schemas import IAMGetGroupPermissionsResponse, IAMGetGroupUsersResponse, IAMPostGroupRequest, \ from src.iam.schemas import IAMGetGroupPermissionsResponse, IAMGetGroupUsersResponse, IAMPostGroupRequest, \
GroupResponse, IAMPostGroupResponse, IAMPutGroupPermissionRequest, IAMPutGroupPermissionResponse, \ GroupResponse, IAMPostGroupResponse, IAMPutGroupPermissionRequest, IAMPutGroupPermissionResponse, \
@ -15,8 +16,8 @@ from src.iam.schemas import IAMGetGroupPermissionsResponse, IAMGetGroupUsersResp
IAMPostPermissionResponse, PermissionResponse, IAMDeletePermissionRequest, IAMGetPermissionsSearchRequest, IAMGetPermissionsSearchResponse IAMPostPermissionResponse, PermissionResponse, IAMDeletePermissionRequest, IAMGetPermissionsSearchRequest, IAMGetPermissionsSearchResponse
from src.schemas import ResourceName from src.schemas import ResourceName
from src.auth.service import claims_dependency from src.auth.service import claims_dependency
from src.user.exceptions import UserNotFoundException
from src.user.models import User from src.user.models import User
from src.user.dependencies import user_model_body_dependency
from src.organisation.models import Organisation as Org from src.organisation.models import Organisation as Org
from src.service.models import Service from src.service.models import Service
from src.organisation.dependencies import org_model_body_dependency from src.organisation.dependencies import org_model_body_dependency
@ -58,9 +59,8 @@ async def can_act_on_resource(valid_key: service_key_dependency, db: db_dependen
return True return True
else: else:
return False return False
except Exception as e: except Exception:
print(e) raise UnauthorizedException()
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/group/permissions", response_model=IAMGetGroupPermissionsResponse) @router.get("/group/permissions", response_model=IAMGetGroupPermissionsResponse)
@ -100,13 +100,8 @@ async def add_group_permission(db: db_dependency, group_model: group_model_body_
@router.put("/group/user") @router.put("/group/user")
async def add_group_user(db: db_dependency, group_model: group_model_body_dependency, request_model: IAMPutGroupUserRequest): async def add_group_user(db: db_dependency, group_model: group_model_body_dependency, user_model: user_model_body_dependency, request_model: IAMPutGroupUserRequest):
# TODO: root_user_dependency # TODO: root_user_dependency
# TODO: user_model_dependency
user_model = db.get(User, request_model.user_id)
if user_model is None:
raise UserNotFoundException(user_id=request_model.user_id)
group_model.user_rel.append(user_model) group_model.user_rel.append(user_model)
db.flush() db.flush()
response = IAMPutGroupUserResponse(group=GroupResponse(**group_model.__dict__), users=group_model.user_rel) response = IAMPutGroupUserResponse(group=GroupResponse(**group_model.__dict__), users=group_model.user_rel)
@ -126,13 +121,8 @@ async def remove_group_permissions(db: db_dependency, group_model: group_model_b
@router.delete("/group/user") @router.delete("/group/user")
async def remove_group_user(db: db_dependency, group_model: group_model_body_dependency, request_model: IAMDeleteGroupUserRequest): async def remove_group_user(db: db_dependency, group_model: group_model_body_dependency, user_model: user_model_body_dependency, request_model: IAMDeleteGroupUserRequest):
# TODO: root_user_dependency # TODO: root_user_dependency
# TODO: User model dependency
user_model = db.get(User, request_model.user_id)
if user_model is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
user_model.group_rel.remove(group_model) user_model.group_rel.remove(group_model)
db.flush() db.flush()
response = IAMDeleteGroupUserResponse(group=GroupResponse(**group_model.__dict__), users=group_model.user_rel) response = IAMDeleteGroupUserResponse(group=GroupResponse(**group_model.__dict__), users=group_model.user_rel)

View file

@ -11,6 +11,8 @@ from pydantic import EmailStr, ConfigDict
from src.organisation.schemas import OrgIDMixin from src.organisation.schemas import OrgIDMixin
from src.schemas import CustomBaseModel from src.schemas import CustomBaseModel
from user.schemas import UserIDMixin
class UserResponse(CustomBaseModel): class UserResponse(CustomBaseModel):
id: int id: int
@ -54,8 +56,8 @@ class IAMPutGroupPermissionResponse(CustomBaseModel):
group: GroupResponse group: GroupResponse
permissions: list[PermissionResponse] permissions: list[PermissionResponse]
class IAMPutGroupUserRequest(GroupIDMixin): class IAMPutGroupUserRequest(GroupIDMixin, UserIDMixin):
user_id: int pass
class IAMPutGroupUserResponse(CustomBaseModel): class IAMPutGroupUserResponse(CustomBaseModel):
group: GroupResponse group: GroupResponse
@ -68,8 +70,8 @@ class IAMDeleteGroupPermissionResponse(CustomBaseModel):
group: GroupResponse group: GroupResponse
permissions: list[PermissionResponse] permissions: list[PermissionResponse]
class IAMDeleteGroupUserRequest(GroupIDMixin): class IAMDeleteGroupUserRequest(GroupIDMixin, UserIDMixin):
user_id: int pass
class IAMDeleteGroupUserResponse(CustomBaseModel): class IAMDeleteGroupUserResponse(CustomBaseModel):
group: GroupResponse group: GroupResponse

View file

@ -14,10 +14,12 @@ Endpoints:
""" """
from typing import Annotated, Optional from typing import Annotated, Optional
from fastapi import APIRouter, HTTPException, status from fastapi import APIRouter, status
from fastapi.params import Query from fastapi.params import Query
from src.exceptions import UnprocessableContent
from src.contact.schemas import ContactAddress from src.contact.schemas import ContactAddress
from src.contact.exceptions import ContactNotFoundException
from src.database import db_dependency from src.database import db_dependency
from src.contact.models import Contact from src.contact.models import Contact
from src.user.models import User from src.user.models import User
@ -150,7 +152,7 @@ async def remove_user_from_org(db: db_dependency, org_model: org_model_body_depe
raise UserNotFoundException(user_id=user_id) raise UserNotFoundException(user_id=user_id)
if user not in org_model.user_rel: if user not in org_model.user_rel:
raise HTTPException(status_code=status.HTTP_204_NOT_FOUND) return
org_model.user_rel.remove(user) org_model.user_rel.remove(user)
db.commit() db.commit()
@ -166,10 +168,10 @@ async def get_contact(org_model: org_model_query_dependency, contact_type: Annot
case "owner": case "owner":
contact_model = org_model.owner_contact_rel contact_model = org_model.owner_contact_rel
case _: case _:
raise HTTPException(status_code=422, detail="Invalid contact type") raise UnprocessableContent("Invalid contact type")
if contact_model is None: if contact_model is None:
raise HTTPException(status_code=404, detail="Contact not found") raise ContactNotFoundException()
return OrgContactGetResponse.model_construct( return OrgContactGetResponse.model_construct(
**contact_model.__dict__, **contact_model.__dict__,
@ -187,17 +189,17 @@ async def update_contact(db: db_dependency, org_model: org_model_body_dependency
case "owner": case "owner":
contact_model = org_model.owner_contact_rel contact_model = org_model.owner_contact_rel
case _: case _:
raise HTTPException(status_code=422, detail="Invalid contact type") raise UnprocessableContent("Invalid contact type")
if contact_model is None: if contact_model is None:
raise HTTPException(status_code=404, detail="Contact not found") raise ContactNotFoundException()
update_data = request_model.model_dump(exclude_none=True) update_data = request_model.model_dump(exclude_none=True)
for key, value in update_data.items(): for key, value in update_data.items():
if hasattr(contact_model, key): if hasattr(contact_model, key):
setattr(contact_model, key, value) setattr(contact_model, key, value)
else: else:
raise HTTPException(status_code=422, detail="Invalid keys in update request") raise UnprocessableContent("Invalid keys in update request")
db.flush() db.flush()
response = OrgContactGetResponse.model_construct( response = OrgContactGetResponse.model_construct(

View file

@ -9,11 +9,11 @@ Exports:
""" """
from typing import Any from typing import Any
from fastapi import HTTPException from src.database import get_db
from src.exceptions import UnprocessableContent
from src.user.schemas import OIDCUser from src.user.schemas import OIDCUser
from src.user.models import User from src.user.models import User
from src.database import get_db
async def add_user_to_db(user_claims: dict[str, Any]) -> int: async def add_user_to_db(user_claims: dict[str, Any]) -> int:
@ -21,7 +21,7 @@ async def add_user_to_db(user_claims: dict[str, Any]) -> int:
valid_user = OIDCUser(first_name=user_claims["given_name"], last_name=user_claims["family_name"], email=user_claims["email"], oidc_id=user_claims["sub"]) valid_user = OIDCUser(first_name=user_claims["given_name"], last_name=user_claims["family_name"], email=user_claims["email"], oidc_id=user_claims["sub"])
except Exception as e: except Exception as e:
print(e) print(e)
raise HTTPException(status_code=422, detail="Invalid or missing OIDC data") raise UnprocessableContent("Invalid or missing OIDC data")
db = next(get_db()) db = next(get_db())
db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first() db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first()