Compare commits

..

8 commits

Author SHA1 Message Date
fc835dc982 feat: missing dependency injections on org endpoints 2026-05-27 15:59:12 +01:00
689443c05e feat: auth requirements to service endpoints 2026-05-27 15:45:31 +01:00
66c2a71c8a feat: auth requirements to org endpoints 2026-05-27 15:42:53 +01:00
789d7d9f7a feat: auth requirements to user endpoints 2026-05-27 15:36:21 +01:00
7e8ec08283 feat: auth requirements to iam endpoints 2026-05-27 15:35:06 +01:00
51bb48372c feat: auth dependency for root user with org in body 2026-05-27 15:34:18 +01:00
36736e5142 fix: auth dependency return values and types
Return values were all labelled as dicts instead of bools. Root user dependency now returns the org for which they are root user.
2026-05-27 15:22:32 +01:00
868e56ce40 feat: custom exceptions instead of direct fastapi.httpexceptions
Resolves #2
2026-05-27 14:58:10 +01:00
11 changed files with 154 additions and 125 deletions

View file

@ -9,40 +9,51 @@ Functions:
- List: Description - List: Description
- Functions: Description - Functions: Description
""" """
from typing import Annotated, Any from typing import Annotated
from fastapi import Depends, HTTPException from fastapi import Depends
from src.user.dependencies import user_model_claims_dependency from src.user.dependencies import user_model_claims_dependency
from src.organisation.dependencies import org_model_query_dependency, org_model_body_dependency
from src.organisation.models import Organisation as Org
from src.organisation.dependencies import org_model_query_dependency from src.auth.exceptions import UnauthorizedException
async def org_query_user_claims(org_model: org_model_query_dependency, user_model: user_model_claims_dependency): async def org_query_user_claims(org_model: org_model_query_dependency, user_model: user_model_claims_dependency):
if user_model in org_model.user_rel: if user_model in org_model.user_rel:
return True return True
raise HTTPException(status_code=401, detail="Not authorised") raise UnauthorizedException()
org_query_user_claims_dependency = Annotated[dict[str, Any], Depends(org_query_user_claims)] org_query_user_claims_dependency = Annotated[bool, Depends(org_query_user_claims)]
async def org_query_root_claims(user_model: user_model_claims_dependency, org_model: org_model_query_dependency): async def org_query_root_claims(user_model: user_model_claims_dependency, org_model: org_model_query_dependency):
if org_model.root_user_id == user_model.id: if org_model.root_user_id == user_model.id:
return True return org_model
raise HTTPException(status_code=401, detail="Not authorised") raise UnauthorizedException()
org_query_root_claims_dependency = Annotated[dict[str, Any], Depends(org_query_root_claims)] org_model_root_claim_query_dependency = Annotated[type[Org], Depends(org_query_root_claims)]
async def org_body_root_claims(user_model: user_model_claims_dependency, org_model: org_model_body_dependency):
if org_model.root_user_id == user_model.id:
return org_model
raise UnauthorizedException()
org_model_root_claim_body_dependency = Annotated[type[Org], Depends(org_body_root_claims)]
async def is_super_admin(user_model: user_model_claims_dependency): async def is_super_admin(user_model: user_model_claims_dependency):
super_admin_emails = [] super_admin_emails = []
if user_model.email not in super_admin_emails: if user_model.email not in super_admin_emails:
raise HTTPException(status_code=401, detail="Not authorised") raise UnauthorizedException()
return True return True
super_admin_dependency = Annotated[dict[str, Any], Depends(is_super_admin)] super_admin_dependency = Annotated[bool, Depends(is_super_admin)]

View file

@ -5,3 +5,15 @@ Exceptions:
- List: Description - List: Description
- Exceptions: Description - Exceptions: Description
""" """
from typing import Optional
from fastapi import HTTPException, status
class UnauthorizedException(HTTPException):
def __init__(self, message: Optional[str] = None) -> None:
detail = "Not authorized" if not message else message
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,
)

View file

@ -13,9 +13,10 @@ from joserfc.errors import ExpiredTokenError
from joserfc.jwk import KeySet from joserfc.jwk import KeySet
from urllib.request import urlopen from urllib.request import urlopen
from fastapi import Depends, HTTPException from fastapi import Depends
from fastapi.security import OpenIdConnect from fastapi.security import OpenIdConnect
from src.auth.exceptions import UnauthorizedException
from src.auth.config import auth_settings from src.auth.config import auth_settings
from src.user.service import add_user_to_db from src.user.service import add_user_to_db
@ -50,8 +51,7 @@ async def get_current_user(oidc_auth_string: oidc_dependency) -> dict[str, Any]:
try: try:
claims_requests.validate(token.claims) claims_requests.validate(token.claims)
except ExpiredTokenError: except ExpiredTokenError:
raise HTTPException(status_code=401, detail="Token expired") raise UnauthorizedException(message="Token is expired")
db_id = await add_user_to_db(token.claims) db_id = await add_user_to_db(token.claims)
token.claims["db_id"] = db_id token.claims["db_id"] = db_id

View file

@ -5,3 +5,15 @@ Exceptions:
- List: Description - List: Description
- Exceptions: Description - Exceptions: Description
""" """
from typing import Optional
from fastapi import HTTPException, status
class ContactNotFoundException(HTTPException):
def __init__(self, contact_id: Optional[int] = None) -> None:
detail = "Contact not found" if contact_id is None else f"Contact with ID '{contact_id}' was not found."
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)

View file

@ -1,3 +1,15 @@
""" """
Global exceptions Global exceptions
""" """
from typing import Optional
from fastapi import HTTPException, status
class UnprocessableContent(HTTPException):
def __init__(self, message: Optional[str] = None) -> None:
detail = "Not authorized" if not message else message
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail=detail,
)

View file

@ -5,25 +5,28 @@ Endpoints:
- List: Description - List: Description
- Endpoints: Description - Endpoints: Description
""" """
from fastapi import APIRouter, HTTPException, status from fastapi import APIRouter, status
from src.database import db_dependency from src.database import db_dependency
from src.schemas import ResourceName
from src.auth.exceptions import UnauthorizedException
from src.auth.service import claims_dependency
from src.auth.dependencies import org_model_root_claim_query_dependency, org_model_root_claim_body_dependency, \
super_admin_dependency
from src.user.models import User
from src.user.dependencies import user_model_body_dependency
from src.organisation.models import Organisation as Org
from src.service.models import Service
from src.iam.service import service_key_dependency
from src.iam.models import Permission as Perm, GroupPermissions as GPerms, Group, UserGroups
from src.iam.dependencies import group_model_query_dependency, group_model_body_dependency, perm_model_body_dependency
from src.iam.schemas import IAMGetGroupPermissionsResponse, IAMGetGroupUsersResponse, IAMPostGroupRequest, \ from src.iam.schemas import IAMGetGroupPermissionsResponse, IAMGetGroupUsersResponse, IAMPostGroupRequest, \
GroupResponse, IAMPostGroupResponse, IAMPutGroupPermissionRequest, IAMPutGroupPermissionResponse, \ GroupResponse, IAMPostGroupResponse, IAMPutGroupPermissionRequest, IAMPutGroupPermissionResponse, \
IAMPutGroupUserRequest, IAMPutGroupUserResponse, IAMDeleteGroupPermissionRequest, IAMDeleteGroupPermissionResponse, \ IAMPutGroupUserRequest, IAMPutGroupUserResponse, IAMDeleteGroupPermissionRequest, IAMDeleteGroupPermissionResponse, \
IAMDeleteGroupUserRequest, IAMDeleteGroupUserResponse, IAMGetPermissionsResponse, IAMPostPermissionRequest, \ IAMDeleteGroupUserRequest, IAMDeleteGroupUserResponse, IAMGetPermissionsResponse, IAMPostPermissionRequest, \
IAMPostPermissionResponse, PermissionResponse, IAMDeletePermissionRequest, IAMGetPermissionsSearchRequest, IAMGetPermissionsSearchResponse IAMPostPermissionResponse, PermissionResponse, IAMDeletePermissionRequest, IAMGetPermissionsSearchRequest, IAMGetPermissionsSearchResponse
from src.schemas import ResourceName
from src.auth.service import claims_dependency
from src.user.exceptions import UserNotFoundException
from src.user.models import User
from src.organisation.models import Organisation as Org
from src.service.models import Service
from src.organisation.dependencies import org_model_body_dependency
from src.iam.service import service_key_dependency
from src.iam.models import Permission as Perm, GroupPermissions as GPerms, Group, UserGroups
from src.iam.dependencies import group_model_query_dependency, group_model_body_dependency, perm_model_body_dependency
router = APIRouter( router = APIRouter(
tags=["IAM"], tags=["IAM"],
@ -58,27 +61,26 @@ async def can_act_on_resource(valid_key: service_key_dependency, db: db_dependen
return True return True
else: else:
return False return False
except Exception as e: except Exception:
print(e) raise UnauthorizedException()
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/group/permissions", response_model=IAMGetGroupPermissionsResponse) @router.get("/group/permissions", response_model=IAMGetGroupPermissionsResponse)
async def get_group_permissions(group_model: group_model_query_dependency): async def get_group_permissions(group_model: group_model_query_dependency, org_model: org_model_root_claim_query_dependency):
# TODO: root_user_dependency if group_model.org_id != org_model.id:
raise UnauthorizedException()
return {"permissions": group_model.permission_rel} return {"permissions": group_model.permission_rel}
@router.get("/group/users", response_model=IAMGetGroupUsersResponse) @router.get("/group/users", response_model=IAMGetGroupUsersResponse)
async def get_group_users(group_model: group_model_query_dependency): async def get_group_users(group_model: group_model_query_dependency, org_model: org_model_root_claim_query_dependency):
# TODO: root_user_dependency if group_model.org_id == org_model.id:
raise UnauthorizedException()
return {"users": group_model.user_rel} return {"users": group_model.user_rel}
@router.post("/group", response_model=IAMPostGroupResponse) @router.post("/group", response_model=IAMPostGroupResponse)
async def create_group(db: db_dependency, request_model: IAMPostGroupRequest, org_model: org_model_body_dependency): async def create_group(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMPostGroupRequest):
# TODO: root_user_dependency
# TODO: get org ID from dependency instead of query (needs updated dep first)
group_model = Group(name=request_model.name, org_id=org_model.id) group_model = Group(name=request_model.name, org_id=org_model.id)
db.add(group_model) db.add(group_model)
@ -89,8 +91,10 @@ async def create_group(db: db_dependency, request_model: IAMPostGroupRequest, or
@router.put("/group/permission", response_model=IAMPutGroupPermissionResponse) @router.put("/group/permission", response_model=IAMPutGroupPermissionResponse)
async def add_group_permission(db: db_dependency, group_model: group_model_body_dependency, perm_model: perm_model_body_dependency, request_model: IAMPutGroupPermissionRequest): async def add_group_permission(db: db_dependency, group_model: group_model_body_dependency, perm_model: perm_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMPutGroupPermissionRequest):
# TODO: root_user_dependency if group_model.org_id == org_model.id:
raise UnauthorizedException()
group_model.permission_rel.append(perm_model) group_model.permission_rel.append(perm_model)
db.flush() db.flush()
@ -100,12 +104,9 @@ async def add_group_permission(db: db_dependency, group_model: group_model_body_
@router.put("/group/user") @router.put("/group/user")
async def add_group_user(db: db_dependency, group_model: group_model_body_dependency, request_model: IAMPutGroupUserRequest): async def add_group_user(db: db_dependency, group_model: group_model_body_dependency, user_model: user_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMPutGroupUserRequest):
# TODO: root_user_dependency if group_model.org_id == org_model.id:
# TODO: user_model_dependency raise UnauthorizedException()
user_model = db.get(User, request_model.user_id)
if user_model is None:
raise UserNotFoundException(user_id=request_model.user_id)
group_model.user_rel.append(user_model) group_model.user_rel.append(user_model)
db.flush() db.flush()
@ -115,8 +116,10 @@ async def add_group_user(db: db_dependency, group_model: group_model_body_depend
@router.delete("/group/permissions") @router.delete("/group/permissions")
async def remove_group_permissions(db: db_dependency, group_model: group_model_body_dependency, perm_model: perm_model_body_dependency, request_model: IAMDeleteGroupPermissionRequest): async def remove_group_permissions(db: db_dependency, group_model: group_model_body_dependency, perm_model: perm_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMDeleteGroupPermissionRequest):
# TODO: root_user_dependency if group_model.org_id == org_model.id:
raise UnauthorizedException()
group_model.permission_rel.remove(perm_model) group_model.permission_rel.remove(perm_model)
db.flush() db.flush()
response = IAMDeleteGroupPermissionResponse(group=GroupResponse(**group_model.__dict__), response = IAMDeleteGroupPermissionResponse(group=GroupResponse(**group_model.__dict__),
@ -126,12 +129,9 @@ async def remove_group_permissions(db: db_dependency, group_model: group_model_b
@router.delete("/group/user") @router.delete("/group/user")
async def remove_group_user(db: db_dependency, group_model: group_model_body_dependency, request_model: IAMDeleteGroupUserRequest): async def remove_group_user(db: db_dependency, group_model: group_model_body_dependency, user_model: user_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMDeleteGroupUserRequest):
# TODO: root_user_dependency if group_model.org_id == org_model.id:
# TODO: User model dependency raise UnauthorizedException()
user_model = db.get(User, request_model.user_id)
if user_model is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
user_model.group_rel.remove(group_model) user_model.group_rel.remove(group_model)
db.flush() db.flush()
@ -142,16 +142,14 @@ async def remove_group_user(db: db_dependency, group_model: group_model_body_dep
@router.get("/permissions", response_model=IAMGetPermissionsResponse) @router.get("/permissions", response_model=IAMGetPermissionsResponse)
async def get_permissions(db: db_dependency): async def get_permissions(db: db_dependency, org_model: org_model_root_claim_body_dependency):
# TODO: root_user_dependency
permission_models = db.query(Perm).all() permission_models = db.query(Perm).all()
return {"permissions": permission_models} return {"permissions": permission_models}
@router.post("/permission") @router.post("/permission")
async def create_new_permission(db: db_dependency, request_mode: IAMPostPermissionRequest): async def create_new_permission(db: db_dependency, su: super_admin_dependency, request_mode: IAMPostPermissionRequest):
# TODO: super_admin_dependency
perm_model = Perm(**request_mode.__dict__) perm_model = Perm(**request_mode.__dict__)
db.add(perm_model) db.add(perm_model)
@ -162,15 +160,13 @@ async def create_new_permission(db: db_dependency, request_mode: IAMPostPermissi
@router.delete("/permission", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/permission", status_code=status.HTTP_204_NO_CONTENT)
async def delete_permission(db: db_dependency, perm_model: perm_model_body_dependency, request_model: IAMDeletePermissionRequest): async def delete_permission(db: db_dependency, su: super_admin_dependency, perm_model: perm_model_body_dependency, request_model: IAMDeletePermissionRequest):
# TODO: super_admin_dependency
db.delete(perm_model) db.delete(perm_model)
db.commit() db.commit()
@router.get("/permissions/search", response_model=IAMGetPermissionsSearchResponse) @router.get("/permissions/search", response_model=IAMGetPermissionsSearchResponse)
async def get_permissions(db: db_dependency, search: IAMGetPermissionsSearchRequest): async def get_permissions(db: db_dependency, org_model: org_model_root_claim_body_dependency, search: IAMGetPermissionsSearchRequest):
# TODO: root_user_dependency
permission_query = db.query(Perm) permission_query = db.query(Perm)
if search.service_id is not None: if search.service_id is not None:

View file

@ -11,6 +11,8 @@ from pydantic import EmailStr, ConfigDict
from src.organisation.schemas import OrgIDMixin from src.organisation.schemas import OrgIDMixin
from src.schemas import CustomBaseModel from src.schemas import CustomBaseModel
from user.schemas import UserIDMixin
class UserResponse(CustomBaseModel): class UserResponse(CustomBaseModel):
id: int id: int
@ -54,8 +56,8 @@ class IAMPutGroupPermissionResponse(CustomBaseModel):
group: GroupResponse group: GroupResponse
permissions: list[PermissionResponse] permissions: list[PermissionResponse]
class IAMPutGroupUserRequest(GroupIDMixin): class IAMPutGroupUserRequest(GroupIDMixin, UserIDMixin):
user_id: int pass
class IAMPutGroupUserResponse(CustomBaseModel): class IAMPutGroupUserResponse(CustomBaseModel):
group: GroupResponse group: GroupResponse
@ -68,8 +70,8 @@ class IAMDeleteGroupPermissionResponse(CustomBaseModel):
group: GroupResponse group: GroupResponse
permissions: list[PermissionResponse] permissions: list[PermissionResponse]
class IAMDeleteGroupUserRequest(GroupIDMixin): class IAMDeleteGroupUserRequest(GroupIDMixin, UserIDMixin):
user_id: int pass
class IAMDeleteGroupUserResponse(CustomBaseModel): class IAMDeleteGroupUserResponse(CustomBaseModel):
group: GroupResponse group: GroupResponse

View file

@ -14,17 +14,19 @@ Endpoints:
""" """
from typing import Annotated, Optional from typing import Annotated, Optional
from fastapi import APIRouter, HTTPException, status from fastapi import APIRouter, status
from fastapi.params import Query from fastapi.params import Query
from src.contact.schemas import ContactAddress from src.exceptions import UnprocessableContent
from src.database import db_dependency
from src.contact.models import Contact from src.contact.models import Contact
from src.contact.schemas import ContactAddress
from src.contact.exceptions import ContactNotFoundException
from src.database import db_dependency
from src.user.models import User from src.user.models import User
from src.user.exceptions import UserNotFoundException from user.dependencies import user_model_body_dependency, user_model_claims_dependency
from src.auth.service import claims_dependency from src.auth.dependencies import super_admin_dependency, org_model_root_claim_query_dependency, org_model_root_claim_body_dependency
from src.organisation.dependencies import org_model_query_dependency, org_model_body_dependency from src.organisation.dependencies import org_model_body_dependency
from src.organisation.constants import ContactType from src.organisation.constants import ContactType
from src.organisation.models import Organisation as Org from src.organisation.models import Organisation as Org
from src.organisation.schemas import OrgOrgPostRequest, OrgQuestionnairePatchRequest, OrgStatusPatchRequest, \ from src.organisation.schemas import OrgOrgPostRequest, OrgQuestionnairePatchRequest, OrgStatusPatchRequest, \
@ -32,6 +34,7 @@ from src.organisation.schemas import OrgOrgPostRequest, OrgQuestionnairePatchReq
OrgUserPostRequest, OrgUserGetResponse, OrgContactGetResponse, OrgOrgGetResponse, OrgRootPatchRequest, \ OrgUserPostRequest, OrgUserGetResponse, OrgContactGetResponse, OrgOrgGetResponse, OrgRootPatchRequest, \
OrgGroupGetResponse, OrgUserDeleteRequest, OrgDeleteOrgRequest OrgGroupGetResponse, OrgUserDeleteRequest, OrgDeleteOrgRequest
router = APIRouter( router = APIRouter(
prefix="/org", prefix="/org",
tags=["org"], tags=["org"],
@ -39,7 +42,7 @@ router = APIRouter(
@router.get("/id", response_model=OrgOrgGetResponse) @router.get("/id", response_model=OrgOrgGetResponse)
async def get_org_by_id(org_model: org_model_query_dependency): async def get_org_by_id(org_model: org_model_root_claim_query_dependency):
response = { response = {
"name": org_model.name, "name": org_model.name,
"status": org_model.status, "status": org_model.status,
@ -53,11 +56,7 @@ async def get_org_by_id(org_model: org_model_query_dependency):
@router.post("/") @router.post("/")
async def create_org(db: db_dependency, user: claims_dependency, request_model: OrgOrgPostRequest): async def create_org(db: db_dependency, user_model: user_model_claims_dependency, request_model: OrgOrgPostRequest):
db_id: Optional[int] = user.get("db_id", None)
if db_id is None:
raise UserNotFoundException()
if request_model.intake_questionnaire: if request_model.intake_questionnaire:
intake_questionnaire = request_model.intake_questionnaire.model_dump() intake_questionnaire = request_model.intake_questionnaire.model_dump()
else: else:
@ -81,7 +80,7 @@ async def create_org(db: db_dependency, user: claims_dependency, request_model:
@router.patch("/questionnaire") @router.patch("/questionnaire")
async def update_questionnaire(db: db_dependency, org_model: org_model_body_dependency, request_model: OrgQuestionnairePatchRequest): async def update_questionnaire(db: db_dependency, org_model: org_model_root_claim_query_dependency, request_model: OrgQuestionnairePatchRequest):
""" """
Route for updating questionnaire. Route for updating questionnaire.
The partial bool allows for submission of partially completed questionnaire and/or The partial bool allows for submission of partially completed questionnaire and/or
@ -97,21 +96,19 @@ async def update_questionnaire(db: db_dependency, org_model: org_model_body_depe
@router.patch("/status") @router.patch("/status")
async def update_status(db: db_dependency, org_model: org_model_body_dependency, request_model: OrgStatusPatchRequest): async def update_status(db: db_dependency, org_model: org_model_body_dependency, su: super_admin_dependency, request_model: OrgStatusPatchRequest):
org_model.status = request_model.status org_model.status = request_model.status
db.commit() db.commit()
@router.get("/users", response_model=OrgUserGetResponse) @router.get("/users", response_model=OrgUserGetResponse)
async def get_users(org_model: org_model_query_dependency): async def get_users(org_model: org_model_root_claim_query_dependency):
return {"users": [user.email for user in org_model.user_rel]} return {"users": [user.email for user in org_model.user_rel]}
@router.post("/users") @router.post("/users")
async def add_user_to_org(db: db_dependency, org_model: org_model_body_dependency, request_model: OrgUserPostRequest): async def add_user_to_org(db: db_dependency, org_model: org_model_root_claim_body_dependency, user_model: user_model_body_dependency, request_model: OrgUserPostRequest):
# TODO: user_model_body_dependency
user_model = db.get(User, request_model.user_id)
if user_model in org_model.user_rel: if user_model in org_model.user_rel:
return return
org_model.user_rel.append(user_model) org_model.user_rel.append(user_model)
@ -119,45 +116,33 @@ async def add_user_to_org(db: db_dependency, org_model: org_model_body_dependenc
@router.delete("/", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/", status_code=status.HTTP_204_NO_CONTENT)
async def delete_organisation_by_id(db: db_dependency, org_model: org_model_body_dependency, request_model: OrgDeleteOrgRequest): async def delete_organisation_by_id(db: db_dependency, org_model: org_model_body_dependency, su: super_admin_dependency, request_model: OrgDeleteOrgRequest):
db.delete(org_model) db.delete(org_model)
db.commit() db.commit()
@router.patch("/root_user", status_code=status.HTTP_204_NO_CONTENT) @router.patch("/root_user", status_code=status.HTTP_204_NO_CONTENT)
async def update_root_user(db: db_dependency, org_model: org_model_body_dependency, request_model: OrgRootPatchRequest): async def update_root_user(db: db_dependency, org_model: org_model_body_dependency, user_model: user_model_body_dependency, su: super_admin_dependency, request_model: OrgRootPatchRequest):
# TODO: user_model_body_dependency org_model.root_user_rel = user_model
root_user_model = db.get(User, request_model.user_id)
if root_user_model is None:
raise UserNotFoundException(user_id=request_model.user_id)
org_model.root_user_rel = root_user_model
db.commit() db.commit()
@router.get("/groups", response_model=OrgGroupGetResponse) @router.get("/groups", response_model=OrgGroupGetResponse)
async def get_org_groups(org_model: org_model_query_dependency): async def get_org_groups(org_model: org_model_root_claim_query_dependency):
return {"groups": [group.name for group in org_model.group_rel]} return {"groups": [group.name for group in org_model.group_rel]}
@router.delete("/user", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/user", status_code=status.HTTP_204_NO_CONTENT)
async def remove_user_from_org(db: db_dependency, org_model: org_model_body_dependency, request_model: OrgUserDeleteRequest): async def remove_user_from_org(db: db_dependency, org_model: org_model_root_claim_body_dependency, user_model: user_model_body_dependency, request_model: OrgUserDeleteRequest):
# TODO: user_model_body_dependency if user_model not in org_model.user_rel:
user_id = request_model.user_id return
user = db.get(User, user_id)
if user is None: org_model.user_rel.remove(user_model)
raise UserNotFoundException(user_id=user_id)
if user not in org_model.user_rel:
raise HTTPException(status_code=status.HTTP_204_NOT_FOUND)
org_model.user_rel.remove(user)
db.commit() db.commit()
@router.get("/contact", response_model=OrgContactGetResponse) @router.get("/contact", response_model=OrgContactGetResponse)
async def get_contact(org_model: org_model_query_dependency, contact_type: Annotated[ContactType, Query()]): async def get_contact(org_model: org_model_root_claim_query_dependency, contact_type: Annotated[ContactType, Query()]):
match contact_type: match contact_type:
case "billing": case "billing":
contact_model = org_model.billing_contact_rel contact_model = org_model.billing_contact_rel
@ -166,10 +151,10 @@ async def get_contact(org_model: org_model_query_dependency, contact_type: Annot
case "owner": case "owner":
contact_model = org_model.owner_contact_rel contact_model = org_model.owner_contact_rel
case _: case _:
raise HTTPException(status_code=422, detail="Invalid contact type") raise UnprocessableContent("Invalid contact type")
if contact_model is None: if contact_model is None:
raise HTTPException(status_code=404, detail="Contact not found") raise ContactNotFoundException()
return OrgContactGetResponse.model_construct( return OrgContactGetResponse.model_construct(
**contact_model.__dict__, **contact_model.__dict__,
@ -178,7 +163,7 @@ async def get_contact(org_model: org_model_query_dependency, contact_type: Annot
@router.patch("/contact", response_model=OrgContactGetResponse) @router.patch("/contact", response_model=OrgContactGetResponse)
async def update_contact(db: db_dependency, org_model: org_model_body_dependency, request_model: OrgContactPatchRequest): async def update_contact(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: OrgContactPatchRequest):
match request_model.contact_type: match request_model.contact_type:
case "billing": case "billing":
contact_model = org_model.billing_contact_rel contact_model = org_model.billing_contact_rel
@ -187,17 +172,17 @@ async def update_contact(db: db_dependency, org_model: org_model_body_dependency
case "owner": case "owner":
contact_model = org_model.owner_contact_rel contact_model = org_model.owner_contact_rel
case _: case _:
raise HTTPException(status_code=422, detail="Invalid contact type") raise UnprocessableContent("Invalid contact type")
if contact_model is None: if contact_model is None:
raise HTTPException(status_code=404, detail="Contact not found") raise ContactNotFoundException()
update_data = request_model.model_dump(exclude_none=True) update_data = request_model.model_dump(exclude_none=True)
for key, value in update_data.items(): for key, value in update_data.items():
if hasattr(contact_model, key): if hasattr(contact_model, key):
setattr(contact_model, key, value) setattr(contact_model, key, value)
else: else:
raise HTTPException(status_code=422, detail="Invalid keys in update request") raise UnprocessableContent("Invalid keys in update request")
db.flush() db.flush()
response = OrgContactGetResponse.model_construct( response = OrgContactGetResponse.model_construct(

View file

@ -8,6 +8,8 @@ Endpoints:
from fastapi import APIRouter, status from fastapi import APIRouter, status
from src.database import db_dependency from src.database import db_dependency
from src.auth.service import claims_dependency
from src.auth.dependencies import super_admin_dependency, org_model_root_claim_query_dependency, org_model_root_claim_body_dependency
from src.service.models import Service from src.service.models import Service
from src.service.utils import generate_api_key from src.service.utils import generate_api_key
@ -21,15 +23,13 @@ router = APIRouter(
) )
@router.get("/", response_model=ServiceGetServiceResponse) @router.get("/", response_model=ServiceGetServiceResponse)
async def get_all_services(db: db_dependency): async def get_all_services(db: db_dependency, org_model: org_model_root_claim_query_dependency):
# TODO: user_dependency
permission_models = db.query(Service).all() permission_models = db.query(Service).all()
return {"services": permission_models} return {"services": permission_models}
@router.post("/", response_model=ServicePostServiceResponse) @router.post("/", response_model=ServicePostServiceResponse)
async def register_service(db: db_dependency, service_request: ServicePostServiceRequest): async def register_service(db: db_dependency, su: super_admin_dependency, service_request: ServicePostServiceRequest):
# TODO: super_admin_dependency
key = generate_api_key() key = generate_api_key()
service_model = Service(name=service_request.name, api_key=key) service_model = Service(name=service_request.name, api_key=key)
@ -40,8 +40,7 @@ async def register_service(db: db_dependency, service_request: ServicePostServic
return {"service": response} return {"service": response}
@router.patch("/key", response_model=ServicePatchKeyResponse) @router.patch("/key", response_model=ServicePatchKeyResponse)
async def regenerate_api_key(db: db_dependency, service_model: service_model_body_dependency, request_model: ServicePatchKeyRequest): async def regenerate_api_key(db: db_dependency, su: super_admin_dependency, service_model: service_model_body_dependency, request_model: ServicePatchKeyRequest):
# TODO: super_admin_dependency
key = generate_api_key() key = generate_api_key()
service_model.api_key = key service_model.api_key = key
@ -51,7 +50,6 @@ async def regenerate_api_key(db: db_dependency, service_model: service_model_bod
return {"service": response} return {"service": response}
@router.delete("/", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/", status_code=status.HTTP_204_NO_CONTENT)
async def remove_service(db: db_dependency, service_model: service_model_body_dependency, request_model: ServiceDeleteServiceRequest): async def remove_service(db: db_dependency, service_model: service_model_body_dependency, su: super_admin_dependency, request_model: ServiceDeleteServiceRequest):
# TODO: super_admin_dependency
db.delete(service_model) db.delete(service_model)
db.commit() db.commit()

View file

@ -17,6 +17,7 @@ from starlette import status
from src.user.schemas import UserResponse, OIDCClaims, UserDeleteUserRequest from src.user.schemas import UserResponse, OIDCClaims, UserDeleteUserRequest
from src.user.dependencies import user_model_claims_dependency, user_model_query_dependency, user_model_body_dependency from src.user.dependencies import user_model_claims_dependency, user_model_query_dependency, user_model_body_dependency
from src.auth.dependencies import super_admin_dependency
from src.auth.service import claims_dependency from src.auth.service import claims_dependency
from src.database import db_dependency from src.database import db_dependency
@ -52,7 +53,7 @@ async def current_user(user_model: user_model_claims_dependency):
status.HTTP_404_NOT_FOUND: {"description": "User not found"}, status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"}, status.HTTP_200_OK: {"description": "Successful retrieval from database"},
}) })
async def get_user_by_id(user_model: user_model_query_dependency): async def get_user_by_id(user_model: user_model_query_dependency, su: super_admin_dependency):
""" """
Returns the database details associated with the provided user ID. Returns the database details associated with the provided user ID.
""" """
@ -63,7 +64,7 @@ async def get_user_by_id(user_model: user_model_query_dependency):
status.HTTP_204_NO_CONTENT: {"description": "User deleted"}, status.HTTP_204_NO_CONTENT: {"description": "User deleted"},
status.HTTP_404_NOT_FOUND: {"description": "User not found"}, status.HTTP_404_NOT_FOUND: {"description": "User not found"},
}) })
async def delete_user_by_id(db: db_dependency, user_model: user_model_body_dependency, request_model: UserDeleteUserRequest): async def delete_user_by_id(db: db_dependency, user_model: user_model_body_dependency, su: super_admin_dependency, request_model: UserDeleteUserRequest):
""" """
Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login. Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login.
""" """

View file

@ -9,11 +9,11 @@ Exports:
""" """
from typing import Any from typing import Any
from fastapi import HTTPException from src.database import get_db
from src.exceptions import UnprocessableContent
from src.user.schemas import OIDCUser from src.user.schemas import OIDCUser
from src.user.models import User from src.user.models import User
from src.database import get_db
async def add_user_to_db(user_claims: dict[str, Any]) -> int: async def add_user_to_db(user_claims: dict[str, Any]) -> int:
@ -21,7 +21,7 @@ async def add_user_to_db(user_claims: dict[str, Any]) -> int:
valid_user = OIDCUser(first_name=user_claims["given_name"], last_name=user_claims["family_name"], email=user_claims["email"], oidc_id=user_claims["sub"]) valid_user = OIDCUser(first_name=user_claims["given_name"], last_name=user_claims["family_name"], email=user_claims["email"], oidc_id=user_claims["sub"])
except Exception as e: except Exception as e:
print(e) print(e)
raise HTTPException(status_code=422, detail="Invalid or missing OIDC data") raise UnprocessableContent("Invalid or missing OIDC data")
db = next(get_db()) db = next(get_db())
db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first() db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first()