feat: custom exceptions instead of direct fastapi.httpexceptions

Resolves #2
This commit is contained in:
Chris Milne 2026-05-27 14:58:10 +01:00
parent d3d3b2ca63
commit 868e56ce40
9 changed files with 73 additions and 43 deletions

View file

@ -10,18 +10,19 @@ Functions:
- Functions: Description
"""
from typing import Annotated, Any
from fastapi import Depends, HTTPException
from fastapi import Depends
from src.user.dependencies import user_model_claims_dependency
from src.organisation.dependencies import org_model_query_dependency
from src.auth.exceptions import UnauthorizedException
async def org_query_user_claims(org_model: org_model_query_dependency, user_model: user_model_claims_dependency):
if user_model in org_model.user_rel:
return True
raise HTTPException(status_code=401, detail="Not authorised")
raise UnauthorizedException()
org_query_user_claims_dependency = Annotated[dict[str, Any], Depends(org_query_user_claims)]
@ -31,7 +32,7 @@ async def org_query_root_claims(user_model: user_model_claims_dependency, org_mo
if org_model.root_user_id == user_model.id:
return True
raise HTTPException(status_code=401, detail="Not authorised")
raise UnauthorizedException()
org_query_root_claims_dependency = Annotated[dict[str, Any], Depends(org_query_root_claims)]
@ -40,8 +41,7 @@ org_query_root_claims_dependency = Annotated[dict[str, Any], Depends(org_query_r
async def is_super_admin(user_model: user_model_claims_dependency):
super_admin_emails = []
if user_model.email not in super_admin_emails:
raise HTTPException(status_code=401, detail="Not authorised")
raise UnauthorizedException()
return True

View file

@ -4,4 +4,16 @@ Module specific exceptions for auth module
Exceptions:
- List: Description
- Exceptions: Description
"""
"""
from typing import Optional
from fastapi import HTTPException, status
class UnauthorizedException(HTTPException):
def __init__(self, message: Optional[str] = None) -> None:
detail = "Not authorized" if not message else message
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,
)

View file

@ -13,9 +13,10 @@ from joserfc.errors import ExpiredTokenError
from joserfc.jwk import KeySet
from urllib.request import urlopen
from fastapi import Depends, HTTPException
from fastapi import Depends
from fastapi.security import OpenIdConnect
from src.auth.exceptions import UnauthorizedException
from src.auth.config import auth_settings
from src.user.service import add_user_to_db
@ -50,8 +51,7 @@ async def get_current_user(oidc_auth_string: oidc_dependency) -> dict[str, Any]:
try:
claims_requests.validate(token.claims)
except ExpiredTokenError:
raise HTTPException(status_code=401, detail="Token expired")
raise UnauthorizedException(message="Token is expired")
db_id = await add_user_to_db(token.claims)
token.claims["db_id"] = db_id

View file

@ -4,4 +4,16 @@ Module specific exceptions for contact module
Exceptions:
- List: Description
- Exceptions: Description
"""
"""
from typing import Optional
from fastapi import HTTPException, status
class ContactNotFoundException(HTTPException):
def __init__(self, contact_id: Optional[int] = None) -> None:
detail = "Contact not found" if contact_id is None else f"Contact with ID '{contact_id}' was not found."
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)

View file

@ -1,3 +1,15 @@
"""
Global exceptions
"""
"""
from typing import Optional
from fastapi import HTTPException, status
class UnprocessableContent(HTTPException):
def __init__(self, message: Optional[str] = None) -> None:
detail = "Not authorized" if not message else message
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail=detail,
)

View file

@ -5,8 +5,9 @@ Endpoints:
- List: Description
- Endpoints: Description
"""
from fastapi import APIRouter, HTTPException, status
from fastapi import APIRouter, status
from auth.exceptions import UnauthorizedException
from src.database import db_dependency
from src.iam.schemas import IAMGetGroupPermissionsResponse, IAMGetGroupUsersResponse, IAMPostGroupRequest, \
GroupResponse, IAMPostGroupResponse, IAMPutGroupPermissionRequest, IAMPutGroupPermissionResponse, \
@ -15,8 +16,8 @@ from src.iam.schemas import IAMGetGroupPermissionsResponse, IAMGetGroupUsersResp
IAMPostPermissionResponse, PermissionResponse, IAMDeletePermissionRequest, IAMGetPermissionsSearchRequest, IAMGetPermissionsSearchResponse
from src.schemas import ResourceName
from src.auth.service import claims_dependency
from src.user.exceptions import UserNotFoundException
from src.user.models import User
from src.user.dependencies import user_model_body_dependency
from src.organisation.models import Organisation as Org
from src.service.models import Service
from src.organisation.dependencies import org_model_body_dependency
@ -58,9 +59,8 @@ async def can_act_on_resource(valid_key: service_key_dependency, db: db_dependen
return True
else:
return False
except Exception as e:
print(e)
raise HTTPException(status_code=500, detail="Internal server error")
except Exception:
raise UnauthorizedException()
@router.get("/group/permissions", response_model=IAMGetGroupPermissionsResponse)
@ -100,13 +100,8 @@ async def add_group_permission(db: db_dependency, group_model: group_model_body_
@router.put("/group/user")
async def add_group_user(db: db_dependency, group_model: group_model_body_dependency, request_model: IAMPutGroupUserRequest):
async def add_group_user(db: db_dependency, group_model: group_model_body_dependency, user_model: user_model_body_dependency, request_model: IAMPutGroupUserRequest):
# TODO: root_user_dependency
# TODO: user_model_dependency
user_model = db.get(User, request_model.user_id)
if user_model is None:
raise UserNotFoundException(user_id=request_model.user_id)
group_model.user_rel.append(user_model)
db.flush()
response = IAMPutGroupUserResponse(group=GroupResponse(**group_model.__dict__), users=group_model.user_rel)
@ -126,13 +121,8 @@ async def remove_group_permissions(db: db_dependency, group_model: group_model_b
@router.delete("/group/user")
async def remove_group_user(db: db_dependency, group_model: group_model_body_dependency, request_model: IAMDeleteGroupUserRequest):
async def remove_group_user(db: db_dependency, group_model: group_model_body_dependency, user_model: user_model_body_dependency, request_model: IAMDeleteGroupUserRequest):
# TODO: root_user_dependency
# TODO: User model dependency
user_model = db.get(User, request_model.user_id)
if user_model is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
user_model.group_rel.remove(group_model)
db.flush()
response = IAMDeleteGroupUserResponse(group=GroupResponse(**group_model.__dict__), users=group_model.user_rel)

View file

@ -11,6 +11,8 @@ from pydantic import EmailStr, ConfigDict
from src.organisation.schemas import OrgIDMixin
from src.schemas import CustomBaseModel
from user.schemas import UserIDMixin
class UserResponse(CustomBaseModel):
id: int
@ -54,8 +56,8 @@ class IAMPutGroupPermissionResponse(CustomBaseModel):
group: GroupResponse
permissions: list[PermissionResponse]
class IAMPutGroupUserRequest(GroupIDMixin):
user_id: int
class IAMPutGroupUserRequest(GroupIDMixin, UserIDMixin):
pass
class IAMPutGroupUserResponse(CustomBaseModel):
group: GroupResponse
@ -68,8 +70,8 @@ class IAMDeleteGroupPermissionResponse(CustomBaseModel):
group: GroupResponse
permissions: list[PermissionResponse]
class IAMDeleteGroupUserRequest(GroupIDMixin):
user_id: int
class IAMDeleteGroupUserRequest(GroupIDMixin, UserIDMixin):
pass
class IAMDeleteGroupUserResponse(CustomBaseModel):
group: GroupResponse

View file

@ -14,10 +14,12 @@ Endpoints:
"""
from typing import Annotated, Optional
from fastapi import APIRouter, HTTPException, status
from fastapi import APIRouter, status
from fastapi.params import Query
from src.exceptions import UnprocessableContent
from src.contact.schemas import ContactAddress
from src.contact.exceptions import ContactNotFoundException
from src.database import db_dependency
from src.contact.models import Contact
from src.user.models import User
@ -150,7 +152,7 @@ async def remove_user_from_org(db: db_dependency, org_model: org_model_body_depe
raise UserNotFoundException(user_id=user_id)
if user not in org_model.user_rel:
raise HTTPException(status_code=status.HTTP_204_NOT_FOUND)
return
org_model.user_rel.remove(user)
db.commit()
@ -166,10 +168,10 @@ async def get_contact(org_model: org_model_query_dependency, contact_type: Annot
case "owner":
contact_model = org_model.owner_contact_rel
case _:
raise HTTPException(status_code=422, detail="Invalid contact type")
raise UnprocessableContent("Invalid contact type")
if contact_model is None:
raise HTTPException(status_code=404, detail="Contact not found")
raise ContactNotFoundException()
return OrgContactGetResponse.model_construct(
**contact_model.__dict__,
@ -187,17 +189,17 @@ async def update_contact(db: db_dependency, org_model: org_model_body_dependency
case "owner":
contact_model = org_model.owner_contact_rel
case _:
raise HTTPException(status_code=422, detail="Invalid contact type")
raise UnprocessableContent("Invalid contact type")
if contact_model is None:
raise HTTPException(status_code=404, detail="Contact not found")
raise ContactNotFoundException()
update_data = request_model.model_dump(exclude_none=True)
for key, value in update_data.items():
if hasattr(contact_model, key):
setattr(contact_model, key, value)
else:
raise HTTPException(status_code=422, detail="Invalid keys in update request")
raise UnprocessableContent("Invalid keys in update request")
db.flush()
response = OrgContactGetResponse.model_construct(

View file

@ -9,11 +9,11 @@ Exports:
"""
from typing import Any
from fastapi import HTTPException
from src.database import get_db
from src.exceptions import UnprocessableContent
from src.user.schemas import OIDCUser
from src.user.models import User
from src.database import get_db
async def add_user_to_db(user_claims: dict[str, Any]) -> int:
@ -21,7 +21,7 @@ async def add_user_to_db(user_claims: dict[str, Any]) -> int:
valid_user = OIDCUser(first_name=user_claims["given_name"], last_name=user_claims["family_name"], email=user_claims["email"], oidc_id=user_claims["sub"])
except Exception as e:
print(e)
raise HTTPException(status_code=422, detail="Invalid or missing OIDC data")
raise UnprocessableContent("Invalid or missing OIDC data")
db = next(get_db())
db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first()