From 868e56ce40f15c2c89fc806c2715e5d75857211d Mon Sep 17 00:00:00 2001 From: luxferre Date: Wed, 27 May 2026 14:58:10 +0100 Subject: [PATCH] feat: custom exceptions instead of direct fastapi.httpexceptions Resolves #2 --- src/auth/dependencies.py | 12 ++++++------ src/auth/exceptions.py | 14 +++++++++++++- src/auth/service.py | 6 +++--- src/contact/exceptions.py | 14 +++++++++++++- src/exceptions.py | 14 +++++++++++++- src/iam/router.py | 24 +++++++----------------- src/iam/schemas.py | 10 ++++++---- src/organisation/router.py | 16 +++++++++------- src/user/service.py | 6 +++--- 9 files changed, 73 insertions(+), 43 deletions(-) diff --git a/src/auth/dependencies.py b/src/auth/dependencies.py index 7fbb96d..439b23c 100644 --- a/src/auth/dependencies.py +++ b/src/auth/dependencies.py @@ -10,18 +10,19 @@ Functions: - Functions: Description """ from typing import Annotated, Any -from fastapi import Depends, HTTPException +from fastapi import Depends from src.user.dependencies import user_model_claims_dependency - from src.organisation.dependencies import org_model_query_dependency +from src.auth.exceptions import UnauthorizedException + async def org_query_user_claims(org_model: org_model_query_dependency, user_model: user_model_claims_dependency): if user_model in org_model.user_rel: return True - raise HTTPException(status_code=401, detail="Not authorised") + raise UnauthorizedException() org_query_user_claims_dependency = Annotated[dict[str, Any], Depends(org_query_user_claims)] @@ -31,7 +32,7 @@ async def org_query_root_claims(user_model: user_model_claims_dependency, org_mo if org_model.root_user_id == user_model.id: return True - raise HTTPException(status_code=401, detail="Not authorised") + raise UnauthorizedException() org_query_root_claims_dependency = Annotated[dict[str, Any], Depends(org_query_root_claims)] @@ -40,8 +41,7 @@ org_query_root_claims_dependency = Annotated[dict[str, Any], Depends(org_query_r async def is_super_admin(user_model: user_model_claims_dependency): super_admin_emails = [] if user_model.email not in super_admin_emails: - raise HTTPException(status_code=401, detail="Not authorised") - + raise UnauthorizedException() return True diff --git a/src/auth/exceptions.py b/src/auth/exceptions.py index 3861aad..71aede1 100644 --- a/src/auth/exceptions.py +++ b/src/auth/exceptions.py @@ -4,4 +4,16 @@ Module specific exceptions for auth module Exceptions: - List: Description - Exceptions: Description -""" \ No newline at end of file +""" +from typing import Optional + +from fastapi import HTTPException, status + + +class UnauthorizedException(HTTPException): + def __init__(self, message: Optional[str] = None) -> None: + detail = "Not authorized" if not message else message + super().__init__( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=detail, + ) diff --git a/src/auth/service.py b/src/auth/service.py index 60bd9c3..e0a764e 100644 --- a/src/auth/service.py +++ b/src/auth/service.py @@ -13,9 +13,10 @@ from joserfc.errors import ExpiredTokenError from joserfc.jwk import KeySet from urllib.request import urlopen -from fastapi import Depends, HTTPException +from fastapi import Depends from fastapi.security import OpenIdConnect +from src.auth.exceptions import UnauthorizedException from src.auth.config import auth_settings from src.user.service import add_user_to_db @@ -50,8 +51,7 @@ async def get_current_user(oidc_auth_string: oidc_dependency) -> dict[str, Any]: try: claims_requests.validate(token.claims) except ExpiredTokenError: - raise HTTPException(status_code=401, detail="Token expired") - + raise UnauthorizedException(message="Token is expired") db_id = await add_user_to_db(token.claims) token.claims["db_id"] = db_id diff --git a/src/contact/exceptions.py b/src/contact/exceptions.py index 58e6e30..b3f8e11 100644 --- a/src/contact/exceptions.py +++ b/src/contact/exceptions.py @@ -4,4 +4,16 @@ Module specific exceptions for contact module Exceptions: - List: Description - Exceptions: Description -""" \ No newline at end of file +""" +from typing import Optional + +from fastapi import HTTPException, status + + +class ContactNotFoundException(HTTPException): + def __init__(self, contact_id: Optional[int] = None) -> None: + detail = "Contact not found" if contact_id is None else f"Contact with ID '{contact_id}' was not found." + super().__init__( + status_code=status.HTTP_404_NOT_FOUND, + detail=detail, + ) diff --git a/src/exceptions.py b/src/exceptions.py index b18e221..5d90f95 100644 --- a/src/exceptions.py +++ b/src/exceptions.py @@ -1,3 +1,15 @@ """ Global exceptions -""" \ No newline at end of file +""" +from typing import Optional + +from fastapi import HTTPException, status + + +class UnprocessableContent(HTTPException): + def __init__(self, message: Optional[str] = None) -> None: + detail = "Not authorized" if not message else message + super().__init__( + status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, + detail=detail, + ) diff --git a/src/iam/router.py b/src/iam/router.py index 2e2ad7c..0c9c59a 100644 --- a/src/iam/router.py +++ b/src/iam/router.py @@ -5,8 +5,9 @@ Endpoints: - List: Description - Endpoints: Description """ -from fastapi import APIRouter, HTTPException, status +from fastapi import APIRouter, status +from auth.exceptions import UnauthorizedException from src.database import db_dependency from src.iam.schemas import IAMGetGroupPermissionsResponse, IAMGetGroupUsersResponse, IAMPostGroupRequest, \ GroupResponse, IAMPostGroupResponse, IAMPutGroupPermissionRequest, IAMPutGroupPermissionResponse, \ @@ -15,8 +16,8 @@ from src.iam.schemas import IAMGetGroupPermissionsResponse, IAMGetGroupUsersResp IAMPostPermissionResponse, PermissionResponse, IAMDeletePermissionRequest, IAMGetPermissionsSearchRequest, IAMGetPermissionsSearchResponse from src.schemas import ResourceName from src.auth.service import claims_dependency -from src.user.exceptions import UserNotFoundException from src.user.models import User +from src.user.dependencies import user_model_body_dependency from src.organisation.models import Organisation as Org from src.service.models import Service from src.organisation.dependencies import org_model_body_dependency @@ -58,9 +59,8 @@ async def can_act_on_resource(valid_key: service_key_dependency, db: db_dependen return True else: return False - except Exception as e: - print(e) - raise HTTPException(status_code=500, detail="Internal server error") + except Exception: + raise UnauthorizedException() @router.get("/group/permissions", response_model=IAMGetGroupPermissionsResponse) @@ -100,13 +100,8 @@ async def add_group_permission(db: db_dependency, group_model: group_model_body_ @router.put("/group/user") -async def add_group_user(db: db_dependency, group_model: group_model_body_dependency, request_model: IAMPutGroupUserRequest): +async def add_group_user(db: db_dependency, group_model: group_model_body_dependency, user_model: user_model_body_dependency, request_model: IAMPutGroupUserRequest): # TODO: root_user_dependency - # TODO: user_model_dependency - user_model = db.get(User, request_model.user_id) - if user_model is None: - raise UserNotFoundException(user_id=request_model.user_id) - group_model.user_rel.append(user_model) db.flush() response = IAMPutGroupUserResponse(group=GroupResponse(**group_model.__dict__), users=group_model.user_rel) @@ -126,13 +121,8 @@ async def remove_group_permissions(db: db_dependency, group_model: group_model_b @router.delete("/group/user") -async def remove_group_user(db: db_dependency, group_model: group_model_body_dependency, request_model: IAMDeleteGroupUserRequest): +async def remove_group_user(db: db_dependency, group_model: group_model_body_dependency, user_model: user_model_body_dependency, request_model: IAMDeleteGroupUserRequest): # TODO: root_user_dependency - # TODO: User model dependency - user_model = db.get(User, request_model.user_id) - if user_model is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") - user_model.group_rel.remove(group_model) db.flush() response = IAMDeleteGroupUserResponse(group=GroupResponse(**group_model.__dict__), users=group_model.user_rel) diff --git a/src/iam/schemas.py b/src/iam/schemas.py index 70af0f8..f6cd7bb 100644 --- a/src/iam/schemas.py +++ b/src/iam/schemas.py @@ -11,6 +11,8 @@ from pydantic import EmailStr, ConfigDict from src.organisation.schemas import OrgIDMixin from src.schemas import CustomBaseModel +from user.schemas import UserIDMixin + class UserResponse(CustomBaseModel): id: int @@ -54,8 +56,8 @@ class IAMPutGroupPermissionResponse(CustomBaseModel): group: GroupResponse permissions: list[PermissionResponse] -class IAMPutGroupUserRequest(GroupIDMixin): - user_id: int +class IAMPutGroupUserRequest(GroupIDMixin, UserIDMixin): + pass class IAMPutGroupUserResponse(CustomBaseModel): group: GroupResponse @@ -68,8 +70,8 @@ class IAMDeleteGroupPermissionResponse(CustomBaseModel): group: GroupResponse permissions: list[PermissionResponse] -class IAMDeleteGroupUserRequest(GroupIDMixin): - user_id: int +class IAMDeleteGroupUserRequest(GroupIDMixin, UserIDMixin): + pass class IAMDeleteGroupUserResponse(CustomBaseModel): group: GroupResponse diff --git a/src/organisation/router.py b/src/organisation/router.py index 413e233..f198bdc 100644 --- a/src/organisation/router.py +++ b/src/organisation/router.py @@ -14,10 +14,12 @@ Endpoints: """ from typing import Annotated, Optional -from fastapi import APIRouter, HTTPException, status +from fastapi import APIRouter, status from fastapi.params import Query +from src.exceptions import UnprocessableContent from src.contact.schemas import ContactAddress +from src.contact.exceptions import ContactNotFoundException from src.database import db_dependency from src.contact.models import Contact from src.user.models import User @@ -150,7 +152,7 @@ async def remove_user_from_org(db: db_dependency, org_model: org_model_body_depe raise UserNotFoundException(user_id=user_id) if user not in org_model.user_rel: - raise HTTPException(status_code=status.HTTP_204_NOT_FOUND) + return org_model.user_rel.remove(user) db.commit() @@ -166,10 +168,10 @@ async def get_contact(org_model: org_model_query_dependency, contact_type: Annot case "owner": contact_model = org_model.owner_contact_rel case _: - raise HTTPException(status_code=422, detail="Invalid contact type") + raise UnprocessableContent("Invalid contact type") if contact_model is None: - raise HTTPException(status_code=404, detail="Contact not found") + raise ContactNotFoundException() return OrgContactGetResponse.model_construct( **contact_model.__dict__, @@ -187,17 +189,17 @@ async def update_contact(db: db_dependency, org_model: org_model_body_dependency case "owner": contact_model = org_model.owner_contact_rel case _: - raise HTTPException(status_code=422, detail="Invalid contact type") + raise UnprocessableContent("Invalid contact type") if contact_model is None: - raise HTTPException(status_code=404, detail="Contact not found") + raise ContactNotFoundException() update_data = request_model.model_dump(exclude_none=True) for key, value in update_data.items(): if hasattr(contact_model, key): setattr(contact_model, key, value) else: - raise HTTPException(status_code=422, detail="Invalid keys in update request") + raise UnprocessableContent("Invalid keys in update request") db.flush() response = OrgContactGetResponse.model_construct( diff --git a/src/user/service.py b/src/user/service.py index ed706b2..378d837 100644 --- a/src/user/service.py +++ b/src/user/service.py @@ -9,11 +9,11 @@ Exports: """ from typing import Any -from fastapi import HTTPException +from src.database import get_db +from src.exceptions import UnprocessableContent from src.user.schemas import OIDCUser from src.user.models import User -from src.database import get_db async def add_user_to_db(user_claims: dict[str, Any]) -> int: @@ -21,7 +21,7 @@ async def add_user_to_db(user_claims: dict[str, Any]) -> int: valid_user = OIDCUser(first_name=user_claims["given_name"], last_name=user_claims["family_name"], email=user_claims["email"], oidc_id=user_claims["sub"]) except Exception as e: print(e) - raise HTTPException(status_code=422, detail="Invalid or missing OIDC data") + raise UnprocessableContent("Invalid or missing OIDC data") db = next(get_db()) db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first()