diff --git a/renovate.json b/renovate.json new file mode 100644 index 0000000..c433e3a --- /dev/null +++ b/renovate.json @@ -0,0 +1,8 @@ +{ + "$schema": "https://docs.renovatebot.com/renovate-schema.json", + "extends": [ + "config:recommended" + ], + "minimumReleaseAge": "14 days", + "gitAuthor": "Renovate" +} diff --git a/requirements.txt b/requirements.txt index b48c39d..007020b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,4 +18,3 @@ httptools psycopg email-validator alembic -joserfc diff --git a/src/admin/router.py b/src/admin/router.py index e6df7de..595335d 100644 --- a/src/admin/router.py +++ b/src/admin/router.py @@ -5,17 +5,20 @@ Endpoints: - List: Description - Endpoints: Description """ -from typing import Annotated - from fastapi import APIRouter, HTTPException from fastapi.params import Path +from sqlalchemy.sql import exists from src.organisation.constants import ContactType from src.organisation.schemas import OrgContactGetResponse from src.organisation.models import Organisation as Org from src.contact.models import Contact +from src.user.models import User +from src.user.schemas import UserResponse, OIDCUser, OrgResponse -from src.auth.service import claims_dependency, org_or_super_admin_dependency +from src.organisation.models import OrgUsers, Organisation + +from src.auth.service import claims_dependency, org_user_dependency, org_admin_dependency from src.database import db_dependency @@ -26,7 +29,7 @@ router = APIRouter( @router.get("/{org_id}/contact/{contact_type}", response_model=OrgContactGetResponse) -async def get_contact(db: db_dependency, user: claims_dependency, is_admin: org_or_super_admin_dependency, contact_type: ContactType, org_id: Annotated[int, Path(gt=0)]): +async def get_contact(db: db_dependency, user: claims_dependency, is_org_admin: org_user_dependency, contact_type: ContactType, org_id: int = Path(gt=0)): org_model = db.query(Org).filter(Org.id == org_id).first() if org_model is None: raise HTTPException(status_code=404, detail="Organisation not found") diff --git a/src/auth/service.py b/src/auth/service.py index fc3cd4c..80d49c5 100644 --- a/src/auth/service.py +++ b/src/auth/service.py @@ -5,16 +5,16 @@ Exports: - claims_dependency """ import json -import requests -from typing import Annotated, Any -from joserfc import jwt -from joserfc.errors import ExpiredTokenError -from joserfc.jwk import KeySet +from typing import Annotated +from authlib.jose import jwt from urllib.request import urlopen from fastapi import Depends, HTTPException, Path from fastapi.security import OpenIdConnect +from authlib.jose.rfc7517.jwk import JsonWebKey +from authlib.jose.rfc7517.key_set import KeySet +from authlib.oauth2.rfc7523.validator import JWTBearerToken from sqlalchemy.sql import exists from src.auth.config import auth_settings @@ -27,12 +27,12 @@ oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG) oidc_dependency = Annotated[str, Depends(oidc)] -async def get_current_user(oidc_auth_string: oidc_dependency) -> dict[str, Any]: +async def get_current_user(oidc_auth_string: oidc_dependency) -> JWTBearerToken: config_url = urlopen(auth_settings.OIDC_CONFIG) config = json.loads(config_url.read()) jwks_uri = config["jwks_uri"] - key_response = requests.get(jwks_uri) - jwk_keys = KeySet.import_key_set(key_response.json()) + key_response = urlopen(jwks_uri) + jwk_keys: KeySet = JsonWebKey.import_key_set(json.loads(key_response.read())) claims_options = { "exp": {"essential": True}, @@ -40,26 +40,22 @@ async def get_current_user(oidc_auth_string: oidc_dependency) -> dict[str, Any]: "iss": {"essential": True, "value": auth_settings.OIDC_ISSUER}, } - token = jwt.decode( + claims: JWTBearerToken = jwt.decode( oidc_auth_string.replace("Bearer ", ""), - jwk_keys + jwk_keys, + claims_options=claims_options, + claims_cls=JWTBearerToken, ) - claims_requests = jwt.JWTClaimsRegistry(**claims_options) + claims.validate() + db_id = await add_user_to_db(claims) - try: - claims_requests.validate(token.claims) - except ExpiredTokenError as e: - raise HTTPException(status_code=401, detail="Token expired") + claims["db_id"] = db_id - db_id = await add_user_to_db(token.claims) - - token.claims["db_id"] = db_id - - return token.claims + return claims -claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)] +claims_dependency = Annotated[JWTBearerToken, Depends(get_current_user)] async def is_org_user(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)): @@ -85,7 +81,7 @@ async def is_org_user(claims: claims_dependency, db: db_dependency, org_id: int return org_user_exists -org_user_dependency = Annotated[dict[str, Any], Depends(is_org_user)] +org_user_dependency = Annotated[JWTBearerToken, Depends(is_org_user)] async def is_org_admin(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)): @@ -112,7 +108,7 @@ async def is_org_admin(claims: claims_dependency, db: db_dependency, org_id: int return org_admin_exists -org_admin_dependency = Annotated[dict[str, Any], Depends(is_org_admin)] +org_admin_dependency = Annotated[JWTBearerToken, Depends(is_org_admin)] async def is_super_admin(claims: claims_dependency): @@ -127,22 +123,10 @@ async def is_super_admin(claims: claims_dependency): return True -super_admin_dependency = Annotated[dict[str, Any], Depends(is_super_admin)] +super_admin_dependency = Annotated[JWTBearerToken, Depends(is_super_admin)] -async def is_admin(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)): - try: - await is_super_admin(claims) - return True - except HTTPException as e: - pass - try: - await is_org_admin(claims, db, org_id) - return True - except HTTPException as e: - raise HTTPException(status_code=401, detail="Not authorised") -org_or_super_admin_dependency = Annotated[dict[str, Any], Depends(is_admin)] # Middleware version of user auth # import json diff --git a/src/contact/router.py b/src/contact/router.py index 3ed9a0f..2ef34ab 100644 --- a/src/contact/router.py +++ b/src/contact/router.py @@ -9,8 +9,6 @@ Endpoints: - [patch]/{contact_id} - Updates the details of an existing contact - [delete]/{contact_id} - Deletes a contact by ID """ -from typing import Annotated - from fastapi import APIRouter, HTTPException from fastapi.params import Path @@ -57,7 +55,7 @@ async def create_contact(db: db_dependency, contact_request: ContactContactPostR @router.patch("/{contact_id}") -async def update_contact(db: db_dependency, contact_request: ContactUpdateRequest, contact_id: Annotated[int, Path(gt=0)]): +async def update_contact(db: db_dependency, contact_request: ContactUpdateRequest, contact_id: int = Path(gt=0)): contact_model = (db.query(Contact).filter(Contact.id == contact_id).first()) if contact_model is None: raise HTTPException(status_code=404, detail="Contact not found") @@ -74,7 +72,7 @@ async def update_contact(db: db_dependency, contact_request: ContactUpdateReques @router.delete("/{contact_id}") -async def delete_contact(db: db_dependency, contact_id: Annotated[int, Path(gt=0)]): +async def delete_contact(db: db_dependency, contact_id: int = Path(gt=0)): contact_model = (db.query(Contact).filter(Contact.id == contact_id).first()) if contact_model is None: raise HTTPException(status_code=404, detail="Contact not found") @@ -84,7 +82,7 @@ async def delete_contact(db: db_dependency, contact_id: Annotated[int, Path(gt=0 @router.get("/{contact_id}/orgs", response_model=list[ContactOrgGetResponse]) -async def get_contact_orgs(db: db_dependency, contact_id: Annotated[int, Path(gt=0)]): +async def get_contact_orgs(db: db_dependency, contact_id: int = Path(gt=0)): contact_model = (db.query(Contact).filter(Contact.id == contact_id).first()) if contact_model is None: raise HTTPException(status_code=404, detail="Contact not found") diff --git a/src/contact/schemas.py b/src/contact/schemas.py index d64e265..9e27c81 100644 --- a/src/contact/schemas.py +++ b/src/contact/schemas.py @@ -7,7 +7,7 @@ Models: """ from typing import Optional -from pydantic import EmailStr +from pydantic import Field, EmailStr from src.organisation.constants import ContactType from src.schemas import CustomBaseModel diff --git a/src/main.py b/src/main.py index 4b4d1a9..4f122c6 100644 --- a/src/main.py +++ b/src/main.py @@ -26,28 +26,15 @@ if settings.ENVIRONMENT.is_deployed: pass -tags_metadata = [ - { - "name": "User", - "description": "User related operations, includes getting information about the current user", - } -] - - app = FastAPI( swagger_ui_init_oauth={ "clientId": auth_settings.CLIENT_ID, "usePkceWithAuthorizationCodeGrant": True, "scopes": "openid profile email", - }, - openapi_tags=tags_metadata, + } ) -# Type inspection disabled for middleware injection. -# Known bug in FastAPI type checking: https://github.com/astral-sh/ty/issues/1635 -# noinspection PyTypeChecker app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY.get_secret_value()) -# noinspection PyTypeChecker app.add_middleware( CORSMiddleware, allow_origins=settings.CORS_ORIGINS, diff --git a/src/organisation/models.py b/src/organisation/models.py index de2b17a..db56359 100644 --- a/src/organisation/models.py +++ b/src/organisation/models.py @@ -9,6 +9,8 @@ Models: from sqlalchemy import Column, Integer, String, Boolean, ForeignKey, JSON, false from src.database import Base +from src.contact.models import Contact +from src.user.models import User class Organisation(Base): diff --git a/src/organisation/router.py b/src/organisation/router.py index 3313158..40f4ac4 100644 --- a/src/organisation/router.py +++ b/src/organisation/router.py @@ -14,13 +14,12 @@ Endpoints: - [delete]/{org_id} - Deletes an organisation by ID - [get]/{org_id}/contact/{contact_type} - Retrieves the contact of a specific type (owner, billing, security) for an organisation """ -from typing import Annotated - from fastapi import APIRouter, HTTPException from fastapi.params import Path from sqlalchemy.sql import exists +from src.auth.service import super_admin_dependency from src.database import db_dependency from src.contact.models import Contact @@ -37,7 +36,7 @@ router = APIRouter( @router.get("/id/{org_id}", response_model=OrgOrgGetResponse) -async def get_org_by_id(db: db_dependency, org_id: Annotated[int, Path(gt=0)]): +async def get_org_by_id(db: db_dependency, org_id: int = Path(gt=0)): org_model = (db.query(Org).filter(Org.id == org_id).first()) if org_model is None: raise HTTPException(status_code=404, detail="Organisation not found") @@ -64,7 +63,7 @@ async def create_org(db: db_dependency, org_request: OrgOrgPostRequest): @router.patch("/{org_id}/questionnaire") -async def update_questionnaire(db: db_dependency, q_request: OrgQuestionnairePatchRequest, org_id: Annotated[int, Path(gt=0)]): +async def update_questionnaire(db: db_dependency, q_request: OrgQuestionnairePatchRequest, org_id: int = Path(gt=0)): """ Route for updating questionnaire. The partial bool allows for submission of partially completed questionnaire and/or @@ -85,7 +84,7 @@ async def update_questionnaire(db: db_dependency, q_request: OrgQuestionnairePat @router.patch("/{org_id}/status") -async def update_status(db: db_dependency, status_request: OrgStatusPatchRequest, org_id: Annotated[int, Path(gt=0)]): +async def update_status(db: db_dependency, status_request: OrgStatusPatchRequest, org_id: int = Path(gt=0)): org_model = db.query(Org).filter(Org.id == org_id).first() if org_model is None: raise HTTPException(status_code=404, detail="Organisation not found") @@ -97,7 +96,7 @@ async def update_status(db: db_dependency, status_request: OrgStatusPatchRequest @router.patch("/{org_id}/contact") -async def update_contact(db: db_dependency, contact_request: OrgContactPatchRequest, org_id: Annotated[int, Path(gt=0)]): +async def update_contact(db: db_dependency, contact_request: OrgContactPatchRequest, org_id: int = Path(gt=0)): org_model = db.query(Org).filter(Org.id == org_id).first() if org_model is None: raise HTTPException(status_code=404, detail="Organisation not found") @@ -117,7 +116,7 @@ async def update_contact(db: db_dependency, contact_request: OrgContactPatchRequ @router.get("/{org_id}/users", response_model=list[OrgUserGetResponse]) -async def get_users(db: db_dependency, org_id: Annotated[int, Path(gt=0)]): +async def get_users(db: db_dependency, org_id: int = Path(gt=0)): org_exists = db.query(exists().where(Org.id == org_id)).scalar() if not org_exists: raise HTTPException(status_code=404, detail="Organisation not found") @@ -128,7 +127,7 @@ async def get_users(db: db_dependency, org_id: Annotated[int, Path(gt=0)]): @router.get("/{org_id}/users/admins", response_model=list[OrgUserGetResponse]) -async def get_admin_users(db: db_dependency, org_id: Annotated[int, Path(gt=0)]): +async def get_admin_users(db: db_dependency, org_id: int = Path(gt=0)): org_exists = db.query(exists().where(Org.id == org_id)).scalar() if not org_exists: raise HTTPException(status_code=404, detail="Organisation not found") @@ -139,11 +138,7 @@ async def get_admin_users(db: db_dependency, org_id: Annotated[int, Path(gt=0)]) @router.post("/{org_id}/users") -async def add_user_to_org(db: db_dependency, user_request: OrgUserPostRequest, org_id: Annotated[int, Path(gt=0)]): - org_model = (db.query(Org).filter(Org.id == org_id).first()) - if org_model is None: - raise HTTPException(status_code=404, detail="Organisation not found") - +async def add_user_to_org(db: db_dependency, user_request: OrgUserPostRequest, org_id: int = Path(gt=0)): org_user_model = OrgUsers(**user_request.model_dump(), org_id=org_id) db.add(org_user_model) @@ -151,14 +146,11 @@ async def add_user_to_org(db: db_dependency, user_request: OrgUserPostRequest, o @router.patch("/{org_id}/users") -async def update_user_details(db: db_dependency, user_request: OrgUserPostRequest, org_id: Annotated[int, Path(gt=0)]): +async def update_user_details(db: db_dependency, user_request: OrgUserPostRequest, org_id: int = Path(gt=0)): """ Currently used only to update user admin status for organisation. """ - org_model = (db.query(Org).filter(Org.id == org_id).first()) - if org_model is None: - raise HTTPException(status_code=404, detail="Organisation not found") - + # TODO: Check if org exists org_user_model = db.query(OrgUsers).filter(OrgUsers.org_id == org_id).filter(OrgUsers.user_id == user_request.user_id).first() if org_user_model is None: @@ -172,7 +164,7 @@ async def update_user_details(db: db_dependency, user_request: OrgUserPostReques @router.delete("/{org_id}") -async def delete_organisation_by_id(db: db_dependency, org_id: Annotated[int, Path(gt=0)]): +async def delete_organisation_by_id(db: db_dependency, org_id: int = Path(gt=0)): org_model = (db.query(Org).filter(Org.id == org_id).first()) if org_model is None: raise HTTPException(status_code=404, detail="Organisation not found") @@ -181,11 +173,10 @@ async def delete_organisation_by_id(db: db_dependency, org_id: Annotated[int, Pa @router.get("/{org_id}/contact/{contact_type}", response_model=OrgContactGetResponse) -async def get_contact(db: db_dependency, contact_type: ContactType, org_id: Annotated[int, Path(gt=0)]): +async def get_contact(db: db_dependency, contact_type: ContactType, org_id: int = Path(gt=0)): org_model = db.query(Org).filter(Org.id == org_id).first() if org_model is None: raise HTTPException(status_code=404, detail="Organisation not found") - match contact_type: case "billing": contact_id = org_model.billing_contact_id diff --git a/src/organisation/schemas.py b/src/organisation/schemas.py index b4f5019..cb11fb1 100644 --- a/src/organisation/schemas.py +++ b/src/organisation/schemas.py @@ -6,27 +6,22 @@ Models: - Models: Description """ from typing import Optional +from pydantic import Json from src.schemas import CustomBaseModel from src.organisation.constants import Status, ContactType -class OrgQuestionnaire(CustomBaseModel): - question_one: str - question_two: str - question_three: str - - class OrgOrgPostRequest(CustomBaseModel): name: str - intake_questionnaire: Optional[OrgQuestionnaire] = None + intake_questionnaire: Optional[Json] = None billing_contact_id: Optional[int] = None security_contact_id: Optional[int] = None owner_contact_id: Optional[int] = None class OrgQuestionnairePatchRequest(CustomBaseModel): - intake_questionnaire: OrgQuestionnaire + intake_questionnaire: Json partial: bool class OrgStatusPatchRequest(CustomBaseModel): diff --git a/src/user/exceptions.py b/src/user/exceptions.py index 6f2a669..5c1087e 100644 --- a/src/user/exceptions.py +++ b/src/user/exceptions.py @@ -4,16 +4,4 @@ Module specific exceptions for user module Exceptions: - List: Description - Exceptions: Description -""" -from typing import Optional - -from fastapi import HTTPException, status - - -class UserNotFoundException(HTTPException): - def __init__(self, user_id: Optional[int] = None) -> None: - detail = "User not found" if user_id is None else f"User with ID '{user_id}' was not found." - super().__init__( - status_code=status.HTTP_404_NOT_FOUND, - detail=detail, - ) +""" \ No newline at end of file diff --git a/src/user/router.py b/src/user/router.py index fb2b632..899f145 100644 --- a/src/user/router.py +++ b/src/user/router.py @@ -2,25 +2,21 @@ Router endpoints for user module Endpoints: - - [get]/self/claims - Retrieves user's OIDC claims - - [get]/self/db - Retrieves the user data from the db that corresponds to the current OIDC user - - [get]/self/orgs - Retrieves all organisations associated with the current user - - [get]/self/orgs/admin - Retrieves only admin organisations for the current user + - [get]/me/claims - Retrieves user's OIDC claims + - [get]/me/db - Retrieves the user data from the db that corresponds to the current OIDC user + - [get]/me/orgs - Retrieves all organisations associated with the current user + - [get]/me/orgs/admin - Retrieves only admin organisations for the current user - [get]/{user_id} - Retrieves a specific user by their ID - [get]/{user_id}/orgs - Retrieves all organisations associated with a specific user - [get]/{user_id}/orgs/admin - Retrieves only admin organisations for a specific user - [delete]/{user_id} - Deletes a user from the db by their db ID """ -from typing import Annotated - -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException from fastapi.params import Path from sqlalchemy.sql import exists -from starlette import status from src.user.models import User -from src.user.schemas import UserResponse, OrgResponse, OIDCClaims -from src.user.exceptions import UserNotFoundException +from src.user.schemas import UserResponse, OIDCUser, OrgResponse from src.organisation.models import OrgUsers, Organisation @@ -29,54 +25,36 @@ from src.database import db_dependency router = APIRouter( prefix="/user", - tags=["User"], + tags=["user"], ) -@router.get("/self/claims", response_model=OIDCClaims, status_code=status.HTTP_200_OK, responses={ - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, -}) +@router.get("/me/claims") async def current_user_claims(user: claims_dependency): - """ - Returns the full OIDC claims associated with the currently logged-in user. - """ - user["allowed_origins"] = user.get("allowed-origins", []) return user -@router.get("/self/db", response_model=UserResponse, status_code=status.HTTP_200_OK, responses={ - status.HTTP_404_NOT_FOUND: {"description": "User not found"}, - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, -}) +@router.get("/me/db", response_model=OIDCUser) async def current_user(user: claims_dependency, db: db_dependency): - """ - Returns the database details associated with the currently logged-in user. - """ - user_id = user.get("db_id", None) - if user_id is None: - raise UserNotFoundException() + db_id = user.get("db_id", None) + if db_id is None: + raise HTTPException(status_code=404, detail="User not found in db") - user_model = (db.query(User).filter(User.id == user_id).first()) + user_model = (db.query(User).filter(User.id == db_id).first()) if user_model is None: - raise UserNotFoundException(user_id=user_id) + raise HTTPException(status_code=404, detail="User not found") return user_model -@router.get("/self/orgs", response_model=list[OrgResponse], status_code=status.HTTP_200_OK, responses={ - status.HTTP_404_NOT_FOUND: {"description": "User not found"}, - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, -}) +@router.get("/me/orgs", response_model=list[OrgResponse]) async def get_current_organisations(db: db_dependency, user: claims_dependency): - """ - Returns all organisations associated with the currently logged-in user. - """ user_id = user.get("db_id", None) if user_id is None: - raise UserNotFoundException() + raise HTTPException(status_code=404, detail="User not found") user_exists = db.query(exists().where(User.id == user_id)).scalar() if not user_exists: - raise UserNotFoundException(user_id=user_id) + raise HTTPException(status_code=404, detail="User not found") org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name) .join(OrgUsers, Organisation.id == OrgUsers.org_id) @@ -87,20 +65,14 @@ async def get_current_organisations(db: db_dependency, user: claims_dependency): return org_user_models -@router.get("/self/orgs/admin", response_model=list[OrgResponse], status_code=status.HTTP_200_OK, responses={ - status.HTTP_404_NOT_FOUND: {"description": "User not found"}, - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, -}) +@router.get("/me/orgs/admin", response_model=list[OrgResponse]) async def get_current_admin_organisations(db: db_dependency, user: claims_dependency): - """ - Returns the organisations for which the currently logged-in user is an admin. - """ user_id = user.get("db_id", None) if user_id is None: - raise UserNotFoundException() + raise HTTPException(status_code=404, detail="User not found") user_exists = db.query(exists().where(User.id == user_id)).scalar() if not user_exists: - raise UserNotFoundException(user_id=user_id) + raise HTTPException(status_code=404, detail="User not found") org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name) .join(OrgUsers, Organisation.id == OrgUsers.org_id) @@ -112,32 +84,20 @@ async def get_current_admin_organisations(db: db_dependency, user: claims_depend return org_user_models -@router.get("/{user_id}", response_model=UserResponse, status_code=status.HTTP_200_OK, responses={ - status.HTTP_404_NOT_FOUND: {"description": "User not found"}, - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, -}) -async def get_user_by_id(db: db_dependency, user_id: Annotated[int, Path(gt=0,description="User database ID")]): - """ - Returns the database details associated with the provided user ID. - """ +@router.get("/{user_id}", response_model=UserResponse) +async def get_user_by_id(user_id: int, db: db_dependency): user_model = (db.query(User).filter(User.id == user_id).first()) if user_model is None: - raise UserNotFoundException(user_id=user_id) + raise HTTPException(status_code=404, detail="User not found") return user_model -@router.get("/{user_id}/orgs", response_model=list[OrgResponse], status_code=status.HTTP_200_OK, responses={ - status.HTTP_404_NOT_FOUND: {"description": "User not found"}, - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, -}) -async def get_organisations(db: db_dependency, user_id: Annotated[int, Path(gt=0,description="User database ID")]): - """ - Returns all organisations associated with the provided user ID. - """ +@router.get("/{user_id}/orgs", response_model=list[OrgResponse]) +async def get_organisations(db: db_dependency, user_id: int = Path(gt=0)): user_exists = db.query(exists().where(User.id == user_id)).scalar() if not user_exists: - raise UserNotFoundException(user_id=user_id) + raise HTTPException(status_code=404, detail="User not found") org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name) .join(OrgUsers, Organisation.id == OrgUsers.org_id) @@ -148,17 +108,11 @@ async def get_organisations(db: db_dependency, user_id: Annotated[int, Path(gt=0 return org_user_models -@router.get("/{user_id}/orgs/admin", response_model=list[OrgResponse], status_code=status.HTTP_200_OK, responses={ - status.HTTP_404_NOT_FOUND: {"description": "User not found"}, - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, -}) -async def get_admin_organisations(db: db_dependency, user_id: Annotated[int, Path(gt=0,description="User database ID")]): - """ - Returns the organisations for which the user with the provided user ID is an admin. - """ +@router.get("/{user_id}/orgs/admin", response_model=list[OrgResponse]) +async def get_admin_organisations(db: db_dependency, user_id: int = Path(gt=0)): user_exists = db.query(exists().where(User.id == user_id)).scalar() if not user_exists: - raise UserNotFoundException(user_id=user_id) + raise HTTPException(status_code=404, detail="User not found") org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name) .join(OrgUsers, Organisation.id == OrgUsers.org_id) @@ -170,16 +124,10 @@ async def get_admin_organisations(db: db_dependency, user_id: Annotated[int, Pat return org_user_models -@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT, responses={ - status.HTTP_204_NO_CONTENT: {"description": "User deleted"}, - status.HTTP_404_NOT_FOUND: {"description": "User not found"}, - }) -async def delete_user_by_id(user_id: Annotated[int, Path(gt=0)], db: db_dependency): - """ - Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login. - """ +@router.delete("/{user_id}") +async def delete_user_by_id(user_id: int, db: db_dependency): user_model = (db.query(User).filter(User.id == user_id).first()) if user_model is None: - raise UserNotFoundException(user_id=user_id) + raise HTTPException(status_code=404, detail="User not found") db.delete(user_model) db.commit() diff --git a/src/user/schemas.py b/src/user/schemas.py index 2dd29ab..c4753d9 100644 --- a/src/user/schemas.py +++ b/src/user/schemas.py @@ -6,31 +6,7 @@ Models: - Models: Description """ from src.schemas import CustomBaseModel - - -class OIDCClaims(CustomBaseModel): - exp: int - iat: int - auth_time: int - jti: str - iss: str - aud: str - sub: str - typ: str - azp: str - sid: str - acr: str - allowed_origins: list[str] - realm_access: dict[str, list[str]] - resource_access: dict[str, dict[str, list[str]]] - scope: str - email_verified: bool - name: str - preferred_username: str - given_name: str - family_name: str - email: str - db_id: int +from pydantic import Field class OIDCUser(CustomBaseModel): diff --git a/src/user/service.py b/src/user/service.py index ed706b2..8549c4c 100644 --- a/src/user/service.py +++ b/src/user/service.py @@ -7,8 +7,7 @@ Functions: Exports: - add_user_to_db """ -from typing import Any - +from authlib.jose import JWTClaims from fastapi import HTTPException from src.user.schemas import OIDCUser @@ -16,7 +15,7 @@ from src.user.models import User from src.database import get_db -async def add_user_to_db(user_claims: dict[str, Any]) -> int: +async def add_user_to_db(user_claims: JWTClaims) -> int: try: valid_user = OIDCUser(first_name=user_claims["given_name"], last_name=user_claims["family_name"], email=user_claims["email"], oidc_id=user_claims["sub"]) except Exception as e: @@ -32,14 +31,5 @@ async def add_user_to_db(user_claims: dict[str, Any]) -> int: db.commit() return user_model.id else: - change = False - if db_user.first_name != valid_user.first_name: - db_user.first_name = valid_user.first_name - change = True - if db_user.last_name != valid_user.last_name: - db_user.last_name = valid_user.last_name - change = True - if change: - db.add(db_user) - db.commit() + # Verify details still match and update accordingly. return db_user.id