diff --git a/requirements.txt b/requirements.txt index 007020b..b48c39d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,3 +18,4 @@ httptools psycopg email-validator alembic +joserfc diff --git a/src/admin/router.py b/src/admin/router.py index 595335d..e6df7de 100644 --- a/src/admin/router.py +++ b/src/admin/router.py @@ -5,20 +5,17 @@ Endpoints: - List: Description - Endpoints: Description """ +from typing import Annotated + from fastapi import APIRouter, HTTPException from fastapi.params import Path -from sqlalchemy.sql import exists from src.organisation.constants import ContactType from src.organisation.schemas import OrgContactGetResponse from src.organisation.models import Organisation as Org from src.contact.models import Contact -from src.user.models import User -from src.user.schemas import UserResponse, OIDCUser, OrgResponse -from src.organisation.models import OrgUsers, Organisation - -from src.auth.service import claims_dependency, org_user_dependency, org_admin_dependency +from src.auth.service import claims_dependency, org_or_super_admin_dependency from src.database import db_dependency @@ -29,7 +26,7 @@ router = APIRouter( @router.get("/{org_id}/contact/{contact_type}", response_model=OrgContactGetResponse) -async def get_contact(db: db_dependency, user: claims_dependency, is_org_admin: org_user_dependency, contact_type: ContactType, org_id: int = Path(gt=0)): +async def get_contact(db: db_dependency, user: claims_dependency, is_admin: org_or_super_admin_dependency, contact_type: ContactType, org_id: Annotated[int, Path(gt=0)]): org_model = db.query(Org).filter(Org.id == org_id).first() if org_model is None: raise HTTPException(status_code=404, detail="Organisation not found") diff --git a/src/auth/service.py b/src/auth/service.py index 80d49c5..fc3cd4c 100644 --- a/src/auth/service.py +++ b/src/auth/service.py @@ -5,16 +5,16 @@ Exports: - claims_dependency """ import json +import requests -from typing import Annotated -from authlib.jose import jwt +from typing import Annotated, Any +from joserfc import jwt +from joserfc.errors import ExpiredTokenError +from joserfc.jwk import KeySet from urllib.request import urlopen from fastapi import Depends, HTTPException, Path from fastapi.security import OpenIdConnect -from authlib.jose.rfc7517.jwk import JsonWebKey -from authlib.jose.rfc7517.key_set import KeySet -from authlib.oauth2.rfc7523.validator import JWTBearerToken from sqlalchemy.sql import exists from src.auth.config import auth_settings @@ -27,12 +27,12 @@ oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG) oidc_dependency = Annotated[str, Depends(oidc)] -async def get_current_user(oidc_auth_string: oidc_dependency) -> JWTBearerToken: +async def get_current_user(oidc_auth_string: oidc_dependency) -> dict[str, Any]: config_url = urlopen(auth_settings.OIDC_CONFIG) config = json.loads(config_url.read()) jwks_uri = config["jwks_uri"] - key_response = urlopen(jwks_uri) - jwk_keys: KeySet = JsonWebKey.import_key_set(json.loads(key_response.read())) + key_response = requests.get(jwks_uri) + jwk_keys = KeySet.import_key_set(key_response.json()) claims_options = { "exp": {"essential": True}, @@ -40,22 +40,26 @@ async def get_current_user(oidc_auth_string: oidc_dependency) -> JWTBearerToken: "iss": {"essential": True, "value": auth_settings.OIDC_ISSUER}, } - claims: JWTBearerToken = jwt.decode( + token = jwt.decode( oidc_auth_string.replace("Bearer ", ""), - jwk_keys, - claims_options=claims_options, - claims_cls=JWTBearerToken, + jwk_keys ) - claims.validate() - db_id = await add_user_to_db(claims) + claims_requests = jwt.JWTClaimsRegistry(**claims_options) - claims["db_id"] = db_id + try: + claims_requests.validate(token.claims) + except ExpiredTokenError as e: + raise HTTPException(status_code=401, detail="Token expired") - return claims + db_id = await add_user_to_db(token.claims) + + token.claims["db_id"] = db_id + + return token.claims -claims_dependency = Annotated[JWTBearerToken, Depends(get_current_user)] +claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)] async def is_org_user(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)): @@ -81,7 +85,7 @@ async def is_org_user(claims: claims_dependency, db: db_dependency, org_id: int return org_user_exists -org_user_dependency = Annotated[JWTBearerToken, Depends(is_org_user)] +org_user_dependency = Annotated[dict[str, Any], Depends(is_org_user)] async def is_org_admin(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)): @@ -108,7 +112,7 @@ async def is_org_admin(claims: claims_dependency, db: db_dependency, org_id: int return org_admin_exists -org_admin_dependency = Annotated[JWTBearerToken, Depends(is_org_admin)] +org_admin_dependency = Annotated[dict[str, Any], Depends(is_org_admin)] async def is_super_admin(claims: claims_dependency): @@ -123,10 +127,22 @@ async def is_super_admin(claims: claims_dependency): return True -super_admin_dependency = Annotated[JWTBearerToken, Depends(is_super_admin)] +super_admin_dependency = Annotated[dict[str, Any], Depends(is_super_admin)] +async def is_admin(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)): + try: + await is_super_admin(claims) + return True + except HTTPException as e: + pass + try: + await is_org_admin(claims, db, org_id) + return True + except HTTPException as e: + raise HTTPException(status_code=401, detail="Not authorised") +org_or_super_admin_dependency = Annotated[dict[str, Any], Depends(is_admin)] # Middleware version of user auth # import json diff --git a/src/contact/router.py b/src/contact/router.py index 2ef34ab..3ed9a0f 100644 --- a/src/contact/router.py +++ b/src/contact/router.py @@ -9,6 +9,8 @@ Endpoints: - [patch]/{contact_id} - Updates the details of an existing contact - [delete]/{contact_id} - Deletes a contact by ID """ +from typing import Annotated + from fastapi import APIRouter, HTTPException from fastapi.params import Path @@ -55,7 +57,7 @@ async def create_contact(db: db_dependency, contact_request: ContactContactPostR @router.patch("/{contact_id}") -async def update_contact(db: db_dependency, contact_request: ContactUpdateRequest, contact_id: int = Path(gt=0)): +async def update_contact(db: db_dependency, contact_request: ContactUpdateRequest, contact_id: Annotated[int, Path(gt=0)]): contact_model = (db.query(Contact).filter(Contact.id == contact_id).first()) if contact_model is None: raise HTTPException(status_code=404, detail="Contact not found") @@ -72,7 +74,7 @@ async def update_contact(db: db_dependency, contact_request: ContactUpdateReques @router.delete("/{contact_id}") -async def delete_contact(db: db_dependency, contact_id: int = Path(gt=0)): +async def delete_contact(db: db_dependency, contact_id: Annotated[int, Path(gt=0)]): contact_model = (db.query(Contact).filter(Contact.id == contact_id).first()) if contact_model is None: raise HTTPException(status_code=404, detail="Contact not found") @@ -82,7 +84,7 @@ async def delete_contact(db: db_dependency, contact_id: int = Path(gt=0)): @router.get("/{contact_id}/orgs", response_model=list[ContactOrgGetResponse]) -async def get_contact_orgs(db: db_dependency, contact_id: int = Path(gt=0)): +async def get_contact_orgs(db: db_dependency, contact_id: Annotated[int, Path(gt=0)]): contact_model = (db.query(Contact).filter(Contact.id == contact_id).first()) if contact_model is None: raise HTTPException(status_code=404, detail="Contact not found") diff --git a/src/contact/schemas.py b/src/contact/schemas.py index 9e27c81..d64e265 100644 --- a/src/contact/schemas.py +++ b/src/contact/schemas.py @@ -7,7 +7,7 @@ Models: """ from typing import Optional -from pydantic import Field, EmailStr +from pydantic import EmailStr from src.organisation.constants import ContactType from src.schemas import CustomBaseModel diff --git a/src/main.py b/src/main.py index 4f122c6..4b4d1a9 100644 --- a/src/main.py +++ b/src/main.py @@ -26,15 +26,28 @@ if settings.ENVIRONMENT.is_deployed: pass +tags_metadata = [ + { + "name": "User", + "description": "User related operations, includes getting information about the current user", + } +] + + app = FastAPI( swagger_ui_init_oauth={ "clientId": auth_settings.CLIENT_ID, "usePkceWithAuthorizationCodeGrant": True, "scopes": "openid profile email", - } + }, + openapi_tags=tags_metadata, ) +# Type inspection disabled for middleware injection. +# Known bug in FastAPI type checking: https://github.com/astral-sh/ty/issues/1635 +# noinspection PyTypeChecker app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY.get_secret_value()) +# noinspection PyTypeChecker app.add_middleware( CORSMiddleware, allow_origins=settings.CORS_ORIGINS, diff --git a/src/organisation/models.py b/src/organisation/models.py index db56359..de2b17a 100644 --- a/src/organisation/models.py +++ b/src/organisation/models.py @@ -9,8 +9,6 @@ Models: from sqlalchemy import Column, Integer, String, Boolean, ForeignKey, JSON, false from src.database import Base -from src.contact.models import Contact -from src.user.models import User class Organisation(Base): diff --git a/src/organisation/router.py b/src/organisation/router.py index 40f4ac4..3313158 100644 --- a/src/organisation/router.py +++ b/src/organisation/router.py @@ -14,12 +14,13 @@ Endpoints: - [delete]/{org_id} - Deletes an organisation by ID - [get]/{org_id}/contact/{contact_type} - Retrieves the contact of a specific type (owner, billing, security) for an organisation """ +from typing import Annotated + from fastapi import APIRouter, HTTPException from fastapi.params import Path from sqlalchemy.sql import exists -from src.auth.service import super_admin_dependency from src.database import db_dependency from src.contact.models import Contact @@ -36,7 +37,7 @@ router = APIRouter( @router.get("/id/{org_id}", response_model=OrgOrgGetResponse) -async def get_org_by_id(db: db_dependency, org_id: int = Path(gt=0)): +async def get_org_by_id(db: db_dependency, org_id: Annotated[int, Path(gt=0)]): org_model = (db.query(Org).filter(Org.id == org_id).first()) if org_model is None: raise HTTPException(status_code=404, detail="Organisation not found") @@ -63,7 +64,7 @@ async def create_org(db: db_dependency, org_request: OrgOrgPostRequest): @router.patch("/{org_id}/questionnaire") -async def update_questionnaire(db: db_dependency, q_request: OrgQuestionnairePatchRequest, org_id: int = Path(gt=0)): +async def update_questionnaire(db: db_dependency, q_request: OrgQuestionnairePatchRequest, org_id: Annotated[int, Path(gt=0)]): """ Route for updating questionnaire. The partial bool allows for submission of partially completed questionnaire and/or @@ -84,7 +85,7 @@ async def update_questionnaire(db: db_dependency, q_request: OrgQuestionnairePat @router.patch("/{org_id}/status") -async def update_status(db: db_dependency, status_request: OrgStatusPatchRequest, org_id: int = Path(gt=0)): +async def update_status(db: db_dependency, status_request: OrgStatusPatchRequest, org_id: Annotated[int, Path(gt=0)]): org_model = db.query(Org).filter(Org.id == org_id).first() if org_model is None: raise HTTPException(status_code=404, detail="Organisation not found") @@ -96,7 +97,7 @@ async def update_status(db: db_dependency, status_request: OrgStatusPatchRequest @router.patch("/{org_id}/contact") -async def update_contact(db: db_dependency, contact_request: OrgContactPatchRequest, org_id: int = Path(gt=0)): +async def update_contact(db: db_dependency, contact_request: OrgContactPatchRequest, org_id: Annotated[int, Path(gt=0)]): org_model = db.query(Org).filter(Org.id == org_id).first() if org_model is None: raise HTTPException(status_code=404, detail="Organisation not found") @@ -116,7 +117,7 @@ async def update_contact(db: db_dependency, contact_request: OrgContactPatchRequ @router.get("/{org_id}/users", response_model=list[OrgUserGetResponse]) -async def get_users(db: db_dependency, org_id: int = Path(gt=0)): +async def get_users(db: db_dependency, org_id: Annotated[int, Path(gt=0)]): org_exists = db.query(exists().where(Org.id == org_id)).scalar() if not org_exists: raise HTTPException(status_code=404, detail="Organisation not found") @@ -127,7 +128,7 @@ async def get_users(db: db_dependency, org_id: int = Path(gt=0)): @router.get("/{org_id}/users/admins", response_model=list[OrgUserGetResponse]) -async def get_admin_users(db: db_dependency, org_id: int = Path(gt=0)): +async def get_admin_users(db: db_dependency, org_id: Annotated[int, Path(gt=0)]): org_exists = db.query(exists().where(Org.id == org_id)).scalar() if not org_exists: raise HTTPException(status_code=404, detail="Organisation not found") @@ -138,7 +139,11 @@ async def get_admin_users(db: db_dependency, org_id: int = Path(gt=0)): @router.post("/{org_id}/users") -async def add_user_to_org(db: db_dependency, user_request: OrgUserPostRequest, org_id: int = Path(gt=0)): +async def add_user_to_org(db: db_dependency, user_request: OrgUserPostRequest, org_id: Annotated[int, Path(gt=0)]): + org_model = (db.query(Org).filter(Org.id == org_id).first()) + if org_model is None: + raise HTTPException(status_code=404, detail="Organisation not found") + org_user_model = OrgUsers(**user_request.model_dump(), org_id=org_id) db.add(org_user_model) @@ -146,11 +151,14 @@ async def add_user_to_org(db: db_dependency, user_request: OrgUserPostRequest, o @router.patch("/{org_id}/users") -async def update_user_details(db: db_dependency, user_request: OrgUserPostRequest, org_id: int = Path(gt=0)): +async def update_user_details(db: db_dependency, user_request: OrgUserPostRequest, org_id: Annotated[int, Path(gt=0)]): """ Currently used only to update user admin status for organisation. """ - # TODO: Check if org exists + org_model = (db.query(Org).filter(Org.id == org_id).first()) + if org_model is None: + raise HTTPException(status_code=404, detail="Organisation not found") + org_user_model = db.query(OrgUsers).filter(OrgUsers.org_id == org_id).filter(OrgUsers.user_id == user_request.user_id).first() if org_user_model is None: @@ -164,7 +172,7 @@ async def update_user_details(db: db_dependency, user_request: OrgUserPostReques @router.delete("/{org_id}") -async def delete_organisation_by_id(db: db_dependency, org_id: int = Path(gt=0)): +async def delete_organisation_by_id(db: db_dependency, org_id: Annotated[int, Path(gt=0)]): org_model = (db.query(Org).filter(Org.id == org_id).first()) if org_model is None: raise HTTPException(status_code=404, detail="Organisation not found") @@ -173,10 +181,11 @@ async def delete_organisation_by_id(db: db_dependency, org_id: int = Path(gt=0)) @router.get("/{org_id}/contact/{contact_type}", response_model=OrgContactGetResponse) -async def get_contact(db: db_dependency, contact_type: ContactType, org_id: int = Path(gt=0)): +async def get_contact(db: db_dependency, contact_type: ContactType, org_id: Annotated[int, Path(gt=0)]): org_model = db.query(Org).filter(Org.id == org_id).first() if org_model is None: raise HTTPException(status_code=404, detail="Organisation not found") + match contact_type: case "billing": contact_id = org_model.billing_contact_id diff --git a/src/organisation/schemas.py b/src/organisation/schemas.py index cb11fb1..b4f5019 100644 --- a/src/organisation/schemas.py +++ b/src/organisation/schemas.py @@ -6,22 +6,27 @@ Models: - Models: Description """ from typing import Optional -from pydantic import Json from src.schemas import CustomBaseModel from src.organisation.constants import Status, ContactType +class OrgQuestionnaire(CustomBaseModel): + question_one: str + question_two: str + question_three: str + + class OrgOrgPostRequest(CustomBaseModel): name: str - intake_questionnaire: Optional[Json] = None + intake_questionnaire: Optional[OrgQuestionnaire] = None billing_contact_id: Optional[int] = None security_contact_id: Optional[int] = None owner_contact_id: Optional[int] = None class OrgQuestionnairePatchRequest(CustomBaseModel): - intake_questionnaire: Json + intake_questionnaire: OrgQuestionnaire partial: bool class OrgStatusPatchRequest(CustomBaseModel): diff --git a/src/user/exceptions.py b/src/user/exceptions.py index 5c1087e..6f2a669 100644 --- a/src/user/exceptions.py +++ b/src/user/exceptions.py @@ -4,4 +4,16 @@ Module specific exceptions for user module Exceptions: - List: Description - Exceptions: Description -""" \ No newline at end of file +""" +from typing import Optional + +from fastapi import HTTPException, status + + +class UserNotFoundException(HTTPException): + def __init__(self, user_id: Optional[int] = None) -> None: + detail = "User not found" if user_id is None else f"User with ID '{user_id}' was not found." + super().__init__( + status_code=status.HTTP_404_NOT_FOUND, + detail=detail, + ) diff --git a/src/user/router.py b/src/user/router.py index 899f145..fb2b632 100644 --- a/src/user/router.py +++ b/src/user/router.py @@ -2,21 +2,25 @@ Router endpoints for user module Endpoints: - - [get]/me/claims - Retrieves user's OIDC claims - - [get]/me/db - Retrieves the user data from the db that corresponds to the current OIDC user - - [get]/me/orgs - Retrieves all organisations associated with the current user - - [get]/me/orgs/admin - Retrieves only admin organisations for the current user + - [get]/self/claims - Retrieves user's OIDC claims + - [get]/self/db - Retrieves the user data from the db that corresponds to the current OIDC user + - [get]/self/orgs - Retrieves all organisations associated with the current user + - [get]/self/orgs/admin - Retrieves only admin organisations for the current user - [get]/{user_id} - Retrieves a specific user by their ID - [get]/{user_id}/orgs - Retrieves all organisations associated with a specific user - [get]/{user_id}/orgs/admin - Retrieves only admin organisations for a specific user - [delete]/{user_id} - Deletes a user from the db by their db ID """ -from fastapi import APIRouter, HTTPException +from typing import Annotated + +from fastapi import APIRouter from fastapi.params import Path from sqlalchemy.sql import exists +from starlette import status from src.user.models import User -from src.user.schemas import UserResponse, OIDCUser, OrgResponse +from src.user.schemas import UserResponse, OrgResponse, OIDCClaims +from src.user.exceptions import UserNotFoundException from src.organisation.models import OrgUsers, Organisation @@ -25,36 +29,54 @@ from src.database import db_dependency router = APIRouter( prefix="/user", - tags=["user"], + tags=["User"], ) -@router.get("/me/claims") +@router.get("/self/claims", response_model=OIDCClaims, status_code=status.HTTP_200_OK, responses={ + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, +}) async def current_user_claims(user: claims_dependency): + """ + Returns the full OIDC claims associated with the currently logged-in user. + """ + user["allowed_origins"] = user.get("allowed-origins", []) return user -@router.get("/me/db", response_model=OIDCUser) +@router.get("/self/db", response_model=UserResponse, status_code=status.HTTP_200_OK, responses={ + status.HTTP_404_NOT_FOUND: {"description": "User not found"}, + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, +}) async def current_user(user: claims_dependency, db: db_dependency): - db_id = user.get("db_id", None) - if db_id is None: - raise HTTPException(status_code=404, detail="User not found in db") + """ + Returns the database details associated with the currently logged-in user. + """ + user_id = user.get("db_id", None) + if user_id is None: + raise UserNotFoundException() - user_model = (db.query(User).filter(User.id == db_id).first()) + user_model = (db.query(User).filter(User.id == user_id).first()) if user_model is None: - raise HTTPException(status_code=404, detail="User not found") + raise UserNotFoundException(user_id=user_id) return user_model -@router.get("/me/orgs", response_model=list[OrgResponse]) +@router.get("/self/orgs", response_model=list[OrgResponse], status_code=status.HTTP_200_OK, responses={ + status.HTTP_404_NOT_FOUND: {"description": "User not found"}, + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, +}) async def get_current_organisations(db: db_dependency, user: claims_dependency): + """ + Returns all organisations associated with the currently logged-in user. + """ user_id = user.get("db_id", None) if user_id is None: - raise HTTPException(status_code=404, detail="User not found") + raise UserNotFoundException() user_exists = db.query(exists().where(User.id == user_id)).scalar() if not user_exists: - raise HTTPException(status_code=404, detail="User not found") + raise UserNotFoundException(user_id=user_id) org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name) .join(OrgUsers, Organisation.id == OrgUsers.org_id) @@ -65,14 +87,20 @@ async def get_current_organisations(db: db_dependency, user: claims_dependency): return org_user_models -@router.get("/me/orgs/admin", response_model=list[OrgResponse]) +@router.get("/self/orgs/admin", response_model=list[OrgResponse], status_code=status.HTTP_200_OK, responses={ + status.HTTP_404_NOT_FOUND: {"description": "User not found"}, + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, +}) async def get_current_admin_organisations(db: db_dependency, user: claims_dependency): + """ + Returns the organisations for which the currently logged-in user is an admin. + """ user_id = user.get("db_id", None) if user_id is None: - raise HTTPException(status_code=404, detail="User not found") + raise UserNotFoundException() user_exists = db.query(exists().where(User.id == user_id)).scalar() if not user_exists: - raise HTTPException(status_code=404, detail="User not found") + raise UserNotFoundException(user_id=user_id) org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name) .join(OrgUsers, Organisation.id == OrgUsers.org_id) @@ -84,20 +112,32 @@ async def get_current_admin_organisations(db: db_dependency, user: claims_depend return org_user_models -@router.get("/{user_id}", response_model=UserResponse) -async def get_user_by_id(user_id: int, db: db_dependency): +@router.get("/{user_id}", response_model=UserResponse, status_code=status.HTTP_200_OK, responses={ + status.HTTP_404_NOT_FOUND: {"description": "User not found"}, + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, +}) +async def get_user_by_id(db: db_dependency, user_id: Annotated[int, Path(gt=0,description="User database ID")]): + """ + Returns the database details associated with the provided user ID. + """ user_model = (db.query(User).filter(User.id == user_id).first()) if user_model is None: - raise HTTPException(status_code=404, detail="User not found") + raise UserNotFoundException(user_id=user_id) return user_model -@router.get("/{user_id}/orgs", response_model=list[OrgResponse]) -async def get_organisations(db: db_dependency, user_id: int = Path(gt=0)): +@router.get("/{user_id}/orgs", response_model=list[OrgResponse], status_code=status.HTTP_200_OK, responses={ + status.HTTP_404_NOT_FOUND: {"description": "User not found"}, + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, +}) +async def get_organisations(db: db_dependency, user_id: Annotated[int, Path(gt=0,description="User database ID")]): + """ + Returns all organisations associated with the provided user ID. + """ user_exists = db.query(exists().where(User.id == user_id)).scalar() if not user_exists: - raise HTTPException(status_code=404, detail="User not found") + raise UserNotFoundException(user_id=user_id) org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name) .join(OrgUsers, Organisation.id == OrgUsers.org_id) @@ -108,11 +148,17 @@ async def get_organisations(db: db_dependency, user_id: int = Path(gt=0)): return org_user_models -@router.get("/{user_id}/orgs/admin", response_model=list[OrgResponse]) -async def get_admin_organisations(db: db_dependency, user_id: int = Path(gt=0)): +@router.get("/{user_id}/orgs/admin", response_model=list[OrgResponse], status_code=status.HTTP_200_OK, responses={ + status.HTTP_404_NOT_FOUND: {"description": "User not found"}, + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, +}) +async def get_admin_organisations(db: db_dependency, user_id: Annotated[int, Path(gt=0,description="User database ID")]): + """ + Returns the organisations for which the user with the provided user ID is an admin. + """ user_exists = db.query(exists().where(User.id == user_id)).scalar() if not user_exists: - raise HTTPException(status_code=404, detail="User not found") + raise UserNotFoundException(user_id=user_id) org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name) .join(OrgUsers, Organisation.id == OrgUsers.org_id) @@ -124,10 +170,16 @@ async def get_admin_organisations(db: db_dependency, user_id: int = Path(gt=0)): return org_user_models -@router.delete("/{user_id}") -async def delete_user_by_id(user_id: int, db: db_dependency): +@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT, responses={ + status.HTTP_204_NO_CONTENT: {"description": "User deleted"}, + status.HTTP_404_NOT_FOUND: {"description": "User not found"}, + }) +async def delete_user_by_id(user_id: Annotated[int, Path(gt=0)], db: db_dependency): + """ + Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login. + """ user_model = (db.query(User).filter(User.id == user_id).first()) if user_model is None: - raise HTTPException(status_code=404, detail="User not found") + raise UserNotFoundException(user_id=user_id) db.delete(user_model) db.commit() diff --git a/src/user/schemas.py b/src/user/schemas.py index c4753d9..2dd29ab 100644 --- a/src/user/schemas.py +++ b/src/user/schemas.py @@ -6,7 +6,31 @@ Models: - Models: Description """ from src.schemas import CustomBaseModel -from pydantic import Field + + +class OIDCClaims(CustomBaseModel): + exp: int + iat: int + auth_time: int + jti: str + iss: str + aud: str + sub: str + typ: str + azp: str + sid: str + acr: str + allowed_origins: list[str] + realm_access: dict[str, list[str]] + resource_access: dict[str, dict[str, list[str]]] + scope: str + email_verified: bool + name: str + preferred_username: str + given_name: str + family_name: str + email: str + db_id: int class OIDCUser(CustomBaseModel): diff --git a/src/user/service.py b/src/user/service.py index 8549c4c..ed706b2 100644 --- a/src/user/service.py +++ b/src/user/service.py @@ -7,7 +7,8 @@ Functions: Exports: - add_user_to_db """ -from authlib.jose import JWTClaims +from typing import Any + from fastapi import HTTPException from src.user.schemas import OIDCUser @@ -15,7 +16,7 @@ from src.user.models import User from src.database import get_db -async def add_user_to_db(user_claims: JWTClaims) -> int: +async def add_user_to_db(user_claims: dict[str, Any]) -> int: try: valid_user = OIDCUser(first_name=user_claims["given_name"], last_name=user_claims["family_name"], email=user_claims["email"], oidc_id=user_claims["sub"]) except Exception as e: @@ -31,5 +32,14 @@ async def add_user_to_db(user_claims: JWTClaims) -> int: db.commit() return user_model.id else: - # Verify details still match and update accordingly. + change = False + if db_user.first_name != valid_user.first_name: + db_user.first_name = valid_user.first_name + change = True + if db_user.last_name != valid_user.last_name: + db_user.last_name = valid_user.last_name + change = True + if change: + db.add(db_user) + db.commit() return db_user.id