Compare commits

...
Sign in to create a new pull request.

10 commits

Author SHA1 Message Date
83a24a91f4 docs: user module
In-line and Swagger docs improvements on the User module and endpoints
2026-05-20 15:23:40 +01:00
6871fcd75d feat: handling for expired token
Returns a 401 with "Token expired" as the detail
2026-05-20 10:50:49 +01:00
34413b3fc5 feat: oidc claims response model 2026-05-20 10:42:07 +01:00
26db93b769 feat: user details updated in db on login 2026-05-20 10:06:36 +01:00
7d84f33bfa fix: intake questionnaire typing
The docs were not generating correctly when using the Json type. A class with placeholder properties has been created instead.
2026-05-19 12:55:46 +01:00
f54876eac6 minor: cleanup
Minor tweaks to reduce warnings in IDE e.g. unused imports.
2026-05-19 12:10:06 +01:00
d89c926a38 feat: org exists checks on orguser routes
Routes modifying the org-user table did not check if the org existed first.
2026-05-19 11:49:54 +01:00
6f4556a44b fix: corrected use of path param
Previously used `param: int = Path()` this worked but was incorrect.
Correct usage is `param: Annotated[int, Path()]`
2026-05-19 11:11:03 +01:00
2b8296d622 feat: combined admin dependency
org_or_super_admin_dependency can be used in place of org_admin_dependency to also allow super admins.
2026-05-19 11:08:22 +01:00
34cb4414c9 feat: auth library upgrade
The parts of Authlib used are now deprecated in favour of JoseRFC.
2026-05-19 09:49:27 +01:00
13 changed files with 225 additions and 86 deletions

View file

@ -18,3 +18,4 @@ httptools
psycopg psycopg
email-validator email-validator
alembic alembic
joserfc

View file

@ -5,20 +5,17 @@ Endpoints:
- List: Description - List: Description
- Endpoints: Description - Endpoints: Description
""" """
from typing import Annotated
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from fastapi.params import Path from fastapi.params import Path
from sqlalchemy.sql import exists
from src.organisation.constants import ContactType from src.organisation.constants import ContactType
from src.organisation.schemas import OrgContactGetResponse from src.organisation.schemas import OrgContactGetResponse
from src.organisation.models import Organisation as Org from src.organisation.models import Organisation as Org
from src.contact.models import Contact from src.contact.models import Contact
from src.user.models import User
from src.user.schemas import UserResponse, OIDCUser, OrgResponse
from src.organisation.models import OrgUsers, Organisation from src.auth.service import claims_dependency, org_or_super_admin_dependency
from src.auth.service import claims_dependency, org_user_dependency, org_admin_dependency
from src.database import db_dependency from src.database import db_dependency
@ -29,7 +26,7 @@ router = APIRouter(
@router.get("/{org_id}/contact/{contact_type}", response_model=OrgContactGetResponse) @router.get("/{org_id}/contact/{contact_type}", response_model=OrgContactGetResponse)
async def get_contact(db: db_dependency, user: claims_dependency, is_org_admin: org_user_dependency, contact_type: ContactType, org_id: int = Path(gt=0)): async def get_contact(db: db_dependency, user: claims_dependency, is_admin: org_or_super_admin_dependency, contact_type: ContactType, org_id: Annotated[int, Path(gt=0)]):
org_model = db.query(Org).filter(Org.id == org_id).first() org_model = db.query(Org).filter(Org.id == org_id).first()
if org_model is None: if org_model is None:
raise HTTPException(status_code=404, detail="Organisation not found") raise HTTPException(status_code=404, detail="Organisation not found")

View file

@ -5,16 +5,16 @@ Exports:
- claims_dependency - claims_dependency
""" """
import json import json
import requests
from typing import Annotated from typing import Annotated, Any
from authlib.jose import jwt from joserfc import jwt
from joserfc.errors import ExpiredTokenError
from joserfc.jwk import KeySet
from urllib.request import urlopen from urllib.request import urlopen
from fastapi import Depends, HTTPException, Path from fastapi import Depends, HTTPException, Path
from fastapi.security import OpenIdConnect from fastapi.security import OpenIdConnect
from authlib.jose.rfc7517.jwk import JsonWebKey
from authlib.jose.rfc7517.key_set import KeySet
from authlib.oauth2.rfc7523.validator import JWTBearerToken
from sqlalchemy.sql import exists from sqlalchemy.sql import exists
from src.auth.config import auth_settings from src.auth.config import auth_settings
@ -27,12 +27,12 @@ oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG)
oidc_dependency = Annotated[str, Depends(oidc)] oidc_dependency = Annotated[str, Depends(oidc)]
async def get_current_user(oidc_auth_string: oidc_dependency) -> JWTBearerToken: async def get_current_user(oidc_auth_string: oidc_dependency) -> dict[str, Any]:
config_url = urlopen(auth_settings.OIDC_CONFIG) config_url = urlopen(auth_settings.OIDC_CONFIG)
config = json.loads(config_url.read()) config = json.loads(config_url.read())
jwks_uri = config["jwks_uri"] jwks_uri = config["jwks_uri"]
key_response = urlopen(jwks_uri) key_response = requests.get(jwks_uri)
jwk_keys: KeySet = JsonWebKey.import_key_set(json.loads(key_response.read())) jwk_keys = KeySet.import_key_set(key_response.json())
claims_options = { claims_options = {
"exp": {"essential": True}, "exp": {"essential": True},
@ -40,22 +40,26 @@ async def get_current_user(oidc_auth_string: oidc_dependency) -> JWTBearerToken:
"iss": {"essential": True, "value": auth_settings.OIDC_ISSUER}, "iss": {"essential": True, "value": auth_settings.OIDC_ISSUER},
} }
claims: JWTBearerToken = jwt.decode( token = jwt.decode(
oidc_auth_string.replace("Bearer ", ""), oidc_auth_string.replace("Bearer ", ""),
jwk_keys, jwk_keys
claims_options=claims_options,
claims_cls=JWTBearerToken,
) )
claims.validate() claims_requests = jwt.JWTClaimsRegistry(**claims_options)
db_id = await add_user_to_db(claims)
claims["db_id"] = db_id try:
claims_requests.validate(token.claims)
except ExpiredTokenError as e:
raise HTTPException(status_code=401, detail="Token expired")
return claims db_id = await add_user_to_db(token.claims)
token.claims["db_id"] = db_id
return token.claims
claims_dependency = Annotated[JWTBearerToken, Depends(get_current_user)] claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)]
async def is_org_user(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)): async def is_org_user(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)):
@ -81,7 +85,7 @@ async def is_org_user(claims: claims_dependency, db: db_dependency, org_id: int
return org_user_exists return org_user_exists
org_user_dependency = Annotated[JWTBearerToken, Depends(is_org_user)] org_user_dependency = Annotated[dict[str, Any], Depends(is_org_user)]
async def is_org_admin(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)): async def is_org_admin(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)):
@ -108,7 +112,7 @@ async def is_org_admin(claims: claims_dependency, db: db_dependency, org_id: int
return org_admin_exists return org_admin_exists
org_admin_dependency = Annotated[JWTBearerToken, Depends(is_org_admin)] org_admin_dependency = Annotated[dict[str, Any], Depends(is_org_admin)]
async def is_super_admin(claims: claims_dependency): async def is_super_admin(claims: claims_dependency):
@ -123,10 +127,22 @@ async def is_super_admin(claims: claims_dependency):
return True return True
super_admin_dependency = Annotated[JWTBearerToken, Depends(is_super_admin)] super_admin_dependency = Annotated[dict[str, Any], Depends(is_super_admin)]
async def is_admin(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)):
try:
await is_super_admin(claims)
return True
except HTTPException as e:
pass
try:
await is_org_admin(claims, db, org_id)
return True
except HTTPException as e:
raise HTTPException(status_code=401, detail="Not authorised")
org_or_super_admin_dependency = Annotated[dict[str, Any], Depends(is_admin)]
# Middleware version of user auth # Middleware version of user auth
# import json # import json

View file

@ -9,6 +9,8 @@ Endpoints:
- [patch]/{contact_id} - Updates the details of an existing contact - [patch]/{contact_id} - Updates the details of an existing contact
- [delete]/{contact_id} - Deletes a contact by ID - [delete]/{contact_id} - Deletes a contact by ID
""" """
from typing import Annotated
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from fastapi.params import Path from fastapi.params import Path
@ -55,7 +57,7 @@ async def create_contact(db: db_dependency, contact_request: ContactContactPostR
@router.patch("/{contact_id}") @router.patch("/{contact_id}")
async def update_contact(db: db_dependency, contact_request: ContactUpdateRequest, contact_id: int = Path(gt=0)): async def update_contact(db: db_dependency, contact_request: ContactUpdateRequest, contact_id: Annotated[int, Path(gt=0)]):
contact_model = (db.query(Contact).filter(Contact.id == contact_id).first()) contact_model = (db.query(Contact).filter(Contact.id == contact_id).first())
if contact_model is None: if contact_model is None:
raise HTTPException(status_code=404, detail="Contact not found") raise HTTPException(status_code=404, detail="Contact not found")
@ -72,7 +74,7 @@ async def update_contact(db: db_dependency, contact_request: ContactUpdateReques
@router.delete("/{contact_id}") @router.delete("/{contact_id}")
async def delete_contact(db: db_dependency, contact_id: int = Path(gt=0)): async def delete_contact(db: db_dependency, contact_id: Annotated[int, Path(gt=0)]):
contact_model = (db.query(Contact).filter(Contact.id == contact_id).first()) contact_model = (db.query(Contact).filter(Contact.id == contact_id).first())
if contact_model is None: if contact_model is None:
raise HTTPException(status_code=404, detail="Contact not found") raise HTTPException(status_code=404, detail="Contact not found")
@ -82,7 +84,7 @@ async def delete_contact(db: db_dependency, contact_id: int = Path(gt=0)):
@router.get("/{contact_id}/orgs", response_model=list[ContactOrgGetResponse]) @router.get("/{contact_id}/orgs", response_model=list[ContactOrgGetResponse])
async def get_contact_orgs(db: db_dependency, contact_id: int = Path(gt=0)): async def get_contact_orgs(db: db_dependency, contact_id: Annotated[int, Path(gt=0)]):
contact_model = (db.query(Contact).filter(Contact.id == contact_id).first()) contact_model = (db.query(Contact).filter(Contact.id == contact_id).first())
if contact_model is None: if contact_model is None:
raise HTTPException(status_code=404, detail="Contact not found") raise HTTPException(status_code=404, detail="Contact not found")

View file

@ -7,7 +7,7 @@ Models:
""" """
from typing import Optional from typing import Optional
from pydantic import Field, EmailStr from pydantic import EmailStr
from src.organisation.constants import ContactType from src.organisation.constants import ContactType
from src.schemas import CustomBaseModel from src.schemas import CustomBaseModel

View file

@ -26,15 +26,28 @@ if settings.ENVIRONMENT.is_deployed:
pass pass
tags_metadata = [
{
"name": "User",
"description": "User related operations, includes getting information about the current user",
}
]
app = FastAPI( app = FastAPI(
swagger_ui_init_oauth={ swagger_ui_init_oauth={
"clientId": auth_settings.CLIENT_ID, "clientId": auth_settings.CLIENT_ID,
"usePkceWithAuthorizationCodeGrant": True, "usePkceWithAuthorizationCodeGrant": True,
"scopes": "openid profile email", "scopes": "openid profile email",
} },
openapi_tags=tags_metadata,
) )
# Type inspection disabled for middleware injection.
# Known bug in FastAPI type checking: https://github.com/astral-sh/ty/issues/1635
# noinspection PyTypeChecker
app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY.get_secret_value()) app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY.get_secret_value())
# noinspection PyTypeChecker
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=settings.CORS_ORIGINS, allow_origins=settings.CORS_ORIGINS,

View file

@ -9,8 +9,6 @@ Models:
from sqlalchemy import Column, Integer, String, Boolean, ForeignKey, JSON, false from sqlalchemy import Column, Integer, String, Boolean, ForeignKey, JSON, false
from src.database import Base from src.database import Base
from src.contact.models import Contact
from src.user.models import User
class Organisation(Base): class Organisation(Base):

View file

@ -14,12 +14,13 @@ Endpoints:
- [delete]/{org_id} - Deletes an organisation by ID - [delete]/{org_id} - Deletes an organisation by ID
- [get]/{org_id}/contact/{contact_type} - Retrieves the contact of a specific type (owner, billing, security) for an organisation - [get]/{org_id}/contact/{contact_type} - Retrieves the contact of a specific type (owner, billing, security) for an organisation
""" """
from typing import Annotated
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from fastapi.params import Path from fastapi.params import Path
from sqlalchemy.sql import exists from sqlalchemy.sql import exists
from src.auth.service import super_admin_dependency
from src.database import db_dependency from src.database import db_dependency
from src.contact.models import Contact from src.contact.models import Contact
@ -36,7 +37,7 @@ router = APIRouter(
@router.get("/id/{org_id}", response_model=OrgOrgGetResponse) @router.get("/id/{org_id}", response_model=OrgOrgGetResponse)
async def get_org_by_id(db: db_dependency, org_id: int = Path(gt=0)): async def get_org_by_id(db: db_dependency, org_id: Annotated[int, Path(gt=0)]):
org_model = (db.query(Org).filter(Org.id == org_id).first()) org_model = (db.query(Org).filter(Org.id == org_id).first())
if org_model is None: if org_model is None:
raise HTTPException(status_code=404, detail="Organisation not found") raise HTTPException(status_code=404, detail="Organisation not found")
@ -63,7 +64,7 @@ async def create_org(db: db_dependency, org_request: OrgOrgPostRequest):
@router.patch("/{org_id}/questionnaire") @router.patch("/{org_id}/questionnaire")
async def update_questionnaire(db: db_dependency, q_request: OrgQuestionnairePatchRequest, org_id: int = Path(gt=0)): async def update_questionnaire(db: db_dependency, q_request: OrgQuestionnairePatchRequest, org_id: Annotated[int, Path(gt=0)]):
""" """
Route for updating questionnaire. Route for updating questionnaire.
The partial bool allows for submission of partially completed questionnaire and/or The partial bool allows for submission of partially completed questionnaire and/or
@ -84,7 +85,7 @@ async def update_questionnaire(db: db_dependency, q_request: OrgQuestionnairePat
@router.patch("/{org_id}/status") @router.patch("/{org_id}/status")
async def update_status(db: db_dependency, status_request: OrgStatusPatchRequest, org_id: int = Path(gt=0)): async def update_status(db: db_dependency, status_request: OrgStatusPatchRequest, org_id: Annotated[int, Path(gt=0)]):
org_model = db.query(Org).filter(Org.id == org_id).first() org_model = db.query(Org).filter(Org.id == org_id).first()
if org_model is None: if org_model is None:
raise HTTPException(status_code=404, detail="Organisation not found") raise HTTPException(status_code=404, detail="Organisation not found")
@ -96,7 +97,7 @@ async def update_status(db: db_dependency, status_request: OrgStatusPatchRequest
@router.patch("/{org_id}/contact") @router.patch("/{org_id}/contact")
async def update_contact(db: db_dependency, contact_request: OrgContactPatchRequest, org_id: int = Path(gt=0)): async def update_contact(db: db_dependency, contact_request: OrgContactPatchRequest, org_id: Annotated[int, Path(gt=0)]):
org_model = db.query(Org).filter(Org.id == org_id).first() org_model = db.query(Org).filter(Org.id == org_id).first()
if org_model is None: if org_model is None:
raise HTTPException(status_code=404, detail="Organisation not found") raise HTTPException(status_code=404, detail="Organisation not found")
@ -116,7 +117,7 @@ async def update_contact(db: db_dependency, contact_request: OrgContactPatchRequ
@router.get("/{org_id}/users", response_model=list[OrgUserGetResponse]) @router.get("/{org_id}/users", response_model=list[OrgUserGetResponse])
async def get_users(db: db_dependency, org_id: int = Path(gt=0)): async def get_users(db: db_dependency, org_id: Annotated[int, Path(gt=0)]):
org_exists = db.query(exists().where(Org.id == org_id)).scalar() org_exists = db.query(exists().where(Org.id == org_id)).scalar()
if not org_exists: if not org_exists:
raise HTTPException(status_code=404, detail="Organisation not found") raise HTTPException(status_code=404, detail="Organisation not found")
@ -127,7 +128,7 @@ async def get_users(db: db_dependency, org_id: int = Path(gt=0)):
@router.get("/{org_id}/users/admins", response_model=list[OrgUserGetResponse]) @router.get("/{org_id}/users/admins", response_model=list[OrgUserGetResponse])
async def get_admin_users(db: db_dependency, org_id: int = Path(gt=0)): async def get_admin_users(db: db_dependency, org_id: Annotated[int, Path(gt=0)]):
org_exists = db.query(exists().where(Org.id == org_id)).scalar() org_exists = db.query(exists().where(Org.id == org_id)).scalar()
if not org_exists: if not org_exists:
raise HTTPException(status_code=404, detail="Organisation not found") raise HTTPException(status_code=404, detail="Organisation not found")
@ -138,7 +139,11 @@ async def get_admin_users(db: db_dependency, org_id: int = Path(gt=0)):
@router.post("/{org_id}/users") @router.post("/{org_id}/users")
async def add_user_to_org(db: db_dependency, user_request: OrgUserPostRequest, org_id: int = Path(gt=0)): async def add_user_to_org(db: db_dependency, user_request: OrgUserPostRequest, org_id: Annotated[int, Path(gt=0)]):
org_model = (db.query(Org).filter(Org.id == org_id).first())
if org_model is None:
raise HTTPException(status_code=404, detail="Organisation not found")
org_user_model = OrgUsers(**user_request.model_dump(), org_id=org_id) org_user_model = OrgUsers(**user_request.model_dump(), org_id=org_id)
db.add(org_user_model) db.add(org_user_model)
@ -146,11 +151,14 @@ async def add_user_to_org(db: db_dependency, user_request: OrgUserPostRequest, o
@router.patch("/{org_id}/users") @router.patch("/{org_id}/users")
async def update_user_details(db: db_dependency, user_request: OrgUserPostRequest, org_id: int = Path(gt=0)): async def update_user_details(db: db_dependency, user_request: OrgUserPostRequest, org_id: Annotated[int, Path(gt=0)]):
""" """
Currently used only to update user admin status for organisation. Currently used only to update user admin status for organisation.
""" """
# TODO: Check if org exists org_model = (db.query(Org).filter(Org.id == org_id).first())
if org_model is None:
raise HTTPException(status_code=404, detail="Organisation not found")
org_user_model = db.query(OrgUsers).filter(OrgUsers.org_id == org_id).filter(OrgUsers.user_id == user_request.user_id).first() org_user_model = db.query(OrgUsers).filter(OrgUsers.org_id == org_id).filter(OrgUsers.user_id == user_request.user_id).first()
if org_user_model is None: if org_user_model is None:
@ -164,7 +172,7 @@ async def update_user_details(db: db_dependency, user_request: OrgUserPostReques
@router.delete("/{org_id}") @router.delete("/{org_id}")
async def delete_organisation_by_id(db: db_dependency, org_id: int = Path(gt=0)): async def delete_organisation_by_id(db: db_dependency, org_id: Annotated[int, Path(gt=0)]):
org_model = (db.query(Org).filter(Org.id == org_id).first()) org_model = (db.query(Org).filter(Org.id == org_id).first())
if org_model is None: if org_model is None:
raise HTTPException(status_code=404, detail="Organisation not found") raise HTTPException(status_code=404, detail="Organisation not found")
@ -173,10 +181,11 @@ async def delete_organisation_by_id(db: db_dependency, org_id: int = Path(gt=0))
@router.get("/{org_id}/contact/{contact_type}", response_model=OrgContactGetResponse) @router.get("/{org_id}/contact/{contact_type}", response_model=OrgContactGetResponse)
async def get_contact(db: db_dependency, contact_type: ContactType, org_id: int = Path(gt=0)): async def get_contact(db: db_dependency, contact_type: ContactType, org_id: Annotated[int, Path(gt=0)]):
org_model = db.query(Org).filter(Org.id == org_id).first() org_model = db.query(Org).filter(Org.id == org_id).first()
if org_model is None: if org_model is None:
raise HTTPException(status_code=404, detail="Organisation not found") raise HTTPException(status_code=404, detail="Organisation not found")
match contact_type: match contact_type:
case "billing": case "billing":
contact_id = org_model.billing_contact_id contact_id = org_model.billing_contact_id

View file

@ -6,22 +6,27 @@ Models:
- Models: Description - Models: Description
""" """
from typing import Optional from typing import Optional
from pydantic import Json
from src.schemas import CustomBaseModel from src.schemas import CustomBaseModel
from src.organisation.constants import Status, ContactType from src.organisation.constants import Status, ContactType
class OrgQuestionnaire(CustomBaseModel):
question_one: str
question_two: str
question_three: str
class OrgOrgPostRequest(CustomBaseModel): class OrgOrgPostRequest(CustomBaseModel):
name: str name: str
intake_questionnaire: Optional[Json] = None intake_questionnaire: Optional[OrgQuestionnaire] = None
billing_contact_id: Optional[int] = None billing_contact_id: Optional[int] = None
security_contact_id: Optional[int] = None security_contact_id: Optional[int] = None
owner_contact_id: Optional[int] = None owner_contact_id: Optional[int] = None
class OrgQuestionnairePatchRequest(CustomBaseModel): class OrgQuestionnairePatchRequest(CustomBaseModel):
intake_questionnaire: Json intake_questionnaire: OrgQuestionnaire
partial: bool partial: bool
class OrgStatusPatchRequest(CustomBaseModel): class OrgStatusPatchRequest(CustomBaseModel):

View file

@ -5,3 +5,15 @@ Exceptions:
- List: Description - List: Description
- Exceptions: Description - Exceptions: Description
""" """
from typing import Optional
from fastapi import HTTPException, status
class UserNotFoundException(HTTPException):
def __init__(self, user_id: Optional[int] = None) -> None:
detail = "User not found" if user_id is None else f"User with ID '{user_id}' was not found."
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)

View file

@ -2,21 +2,25 @@
Router endpoints for user module Router endpoints for user module
Endpoints: Endpoints:
- [get]/me/claims - Retrieves user's OIDC claims - [get]/self/claims - Retrieves user's OIDC claims
- [get]/me/db - Retrieves the user data from the db that corresponds to the current OIDC user - [get]/self/db - Retrieves the user data from the db that corresponds to the current OIDC user
- [get]/me/orgs - Retrieves all organisations associated with the current user - [get]/self/orgs - Retrieves all organisations associated with the current user
- [get]/me/orgs/admin - Retrieves only admin organisations for the current user - [get]/self/orgs/admin - Retrieves only admin organisations for the current user
- [get]/{user_id} - Retrieves a specific user by their ID - [get]/{user_id} - Retrieves a specific user by their ID
- [get]/{user_id}/orgs - Retrieves all organisations associated with a specific user - [get]/{user_id}/orgs - Retrieves all organisations associated with a specific user
- [get]/{user_id}/orgs/admin - Retrieves only admin organisations for a specific user - [get]/{user_id}/orgs/admin - Retrieves only admin organisations for a specific user
- [delete]/{user_id} - Deletes a user from the db by their db ID - [delete]/{user_id} - Deletes a user from the db by their db ID
""" """
from fastapi import APIRouter, HTTPException from typing import Annotated
from fastapi import APIRouter
from fastapi.params import Path from fastapi.params import Path
from sqlalchemy.sql import exists from sqlalchemy.sql import exists
from starlette import status
from src.user.models import User from src.user.models import User
from src.user.schemas import UserResponse, OIDCUser, OrgResponse from src.user.schemas import UserResponse, OrgResponse, OIDCClaims
from src.user.exceptions import UserNotFoundException
from src.organisation.models import OrgUsers, Organisation from src.organisation.models import OrgUsers, Organisation
@ -25,36 +29,54 @@ from src.database import db_dependency
router = APIRouter( router = APIRouter(
prefix="/user", prefix="/user",
tags=["user"], tags=["User"],
) )
@router.get("/me/claims") @router.get("/self/claims", response_model=OIDCClaims, status_code=status.HTTP_200_OK, responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
})
async def current_user_claims(user: claims_dependency): async def current_user_claims(user: claims_dependency):
"""
Returns the full OIDC claims associated with the currently logged-in user.
"""
user["allowed_origins"] = user.get("allowed-origins", [])
return user return user
@router.get("/me/db", response_model=OIDCUser) @router.get("/self/db", response_model=UserResponse, status_code=status.HTTP_200_OK, responses={
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
})
async def current_user(user: claims_dependency, db: db_dependency): async def current_user(user: claims_dependency, db: db_dependency):
db_id = user.get("db_id", None) """
if db_id is None: Returns the database details associated with the currently logged-in user.
raise HTTPException(status_code=404, detail="User not found in db") """
user_id = user.get("db_id", None)
if user_id is None:
raise UserNotFoundException()
user_model = (db.query(User).filter(User.id == db_id).first()) user_model = (db.query(User).filter(User.id == user_id).first())
if user_model is None: if user_model is None:
raise HTTPException(status_code=404, detail="User not found") raise UserNotFoundException(user_id=user_id)
return user_model return user_model
@router.get("/me/orgs", response_model=list[OrgResponse]) @router.get("/self/orgs", response_model=list[OrgResponse], status_code=status.HTTP_200_OK, responses={
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
})
async def get_current_organisations(db: db_dependency, user: claims_dependency): async def get_current_organisations(db: db_dependency, user: claims_dependency):
"""
Returns all organisations associated with the currently logged-in user.
"""
user_id = user.get("db_id", None) user_id = user.get("db_id", None)
if user_id is None: if user_id is None:
raise HTTPException(status_code=404, detail="User not found") raise UserNotFoundException()
user_exists = db.query(exists().where(User.id == user_id)).scalar() user_exists = db.query(exists().where(User.id == user_id)).scalar()
if not user_exists: if not user_exists:
raise HTTPException(status_code=404, detail="User not found") raise UserNotFoundException(user_id=user_id)
org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name) org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name)
.join(OrgUsers, Organisation.id == OrgUsers.org_id) .join(OrgUsers, Organisation.id == OrgUsers.org_id)
@ -65,14 +87,20 @@ async def get_current_organisations(db: db_dependency, user: claims_dependency):
return org_user_models return org_user_models
@router.get("/me/orgs/admin", response_model=list[OrgResponse]) @router.get("/self/orgs/admin", response_model=list[OrgResponse], status_code=status.HTTP_200_OK, responses={
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
})
async def get_current_admin_organisations(db: db_dependency, user: claims_dependency): async def get_current_admin_organisations(db: db_dependency, user: claims_dependency):
"""
Returns the organisations for which the currently logged-in user is an admin.
"""
user_id = user.get("db_id", None) user_id = user.get("db_id", None)
if user_id is None: if user_id is None:
raise HTTPException(status_code=404, detail="User not found") raise UserNotFoundException()
user_exists = db.query(exists().where(User.id == user_id)).scalar() user_exists = db.query(exists().where(User.id == user_id)).scalar()
if not user_exists: if not user_exists:
raise HTTPException(status_code=404, detail="User not found") raise UserNotFoundException(user_id=user_id)
org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name) org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name)
.join(OrgUsers, Organisation.id == OrgUsers.org_id) .join(OrgUsers, Organisation.id == OrgUsers.org_id)
@ -84,20 +112,32 @@ async def get_current_admin_organisations(db: db_dependency, user: claims_depend
return org_user_models return org_user_models
@router.get("/{user_id}", response_model=UserResponse) @router.get("/{user_id}", response_model=UserResponse, status_code=status.HTTP_200_OK, responses={
async def get_user_by_id(user_id: int, db: db_dependency): status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
})
async def get_user_by_id(db: db_dependency, user_id: Annotated[int, Path(gt=0,description="User database ID")]):
"""
Returns the database details associated with the provided user ID.
"""
user_model = (db.query(User).filter(User.id == user_id).first()) user_model = (db.query(User).filter(User.id == user_id).first())
if user_model is None: if user_model is None:
raise HTTPException(status_code=404, detail="User not found") raise UserNotFoundException(user_id=user_id)
return user_model return user_model
@router.get("/{user_id}/orgs", response_model=list[OrgResponse]) @router.get("/{user_id}/orgs", response_model=list[OrgResponse], status_code=status.HTTP_200_OK, responses={
async def get_organisations(db: db_dependency, user_id: int = Path(gt=0)): status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
})
async def get_organisations(db: db_dependency, user_id: Annotated[int, Path(gt=0,description="User database ID")]):
"""
Returns all organisations associated with the provided user ID.
"""
user_exists = db.query(exists().where(User.id == user_id)).scalar() user_exists = db.query(exists().where(User.id == user_id)).scalar()
if not user_exists: if not user_exists:
raise HTTPException(status_code=404, detail="User not found") raise UserNotFoundException(user_id=user_id)
org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name) org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name)
.join(OrgUsers, Organisation.id == OrgUsers.org_id) .join(OrgUsers, Organisation.id == OrgUsers.org_id)
@ -108,11 +148,17 @@ async def get_organisations(db: db_dependency, user_id: int = Path(gt=0)):
return org_user_models return org_user_models
@router.get("/{user_id}/orgs/admin", response_model=list[OrgResponse]) @router.get("/{user_id}/orgs/admin", response_model=list[OrgResponse], status_code=status.HTTP_200_OK, responses={
async def get_admin_organisations(db: db_dependency, user_id: int = Path(gt=0)): status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
})
async def get_admin_organisations(db: db_dependency, user_id: Annotated[int, Path(gt=0,description="User database ID")]):
"""
Returns the organisations for which the user with the provided user ID is an admin.
"""
user_exists = db.query(exists().where(User.id == user_id)).scalar() user_exists = db.query(exists().where(User.id == user_id)).scalar()
if not user_exists: if not user_exists:
raise HTTPException(status_code=404, detail="User not found") raise UserNotFoundException(user_id=user_id)
org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name) org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name)
.join(OrgUsers, Organisation.id == OrgUsers.org_id) .join(OrgUsers, Organisation.id == OrgUsers.org_id)
@ -124,10 +170,16 @@ async def get_admin_organisations(db: db_dependency, user_id: int = Path(gt=0)):
return org_user_models return org_user_models
@router.delete("/{user_id}") @router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT, responses={
async def delete_user_by_id(user_id: int, db: db_dependency): status.HTTP_204_NO_CONTENT: {"description": "User deleted"},
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
})
async def delete_user_by_id(user_id: Annotated[int, Path(gt=0)], db: db_dependency):
"""
Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login.
"""
user_model = (db.query(User).filter(User.id == user_id).first()) user_model = (db.query(User).filter(User.id == user_id).first())
if user_model is None: if user_model is None:
raise HTTPException(status_code=404, detail="User not found") raise UserNotFoundException(user_id=user_id)
db.delete(user_model) db.delete(user_model)
db.commit() db.commit()

View file

@ -6,7 +6,31 @@ Models:
- Models: Description - Models: Description
""" """
from src.schemas import CustomBaseModel from src.schemas import CustomBaseModel
from pydantic import Field
class OIDCClaims(CustomBaseModel):
exp: int
iat: int
auth_time: int
jti: str
iss: str
aud: str
sub: str
typ: str
azp: str
sid: str
acr: str
allowed_origins: list[str]
realm_access: dict[str, list[str]]
resource_access: dict[str, dict[str, list[str]]]
scope: str
email_verified: bool
name: str
preferred_username: str
given_name: str
family_name: str
email: str
db_id: int
class OIDCUser(CustomBaseModel): class OIDCUser(CustomBaseModel):

View file

@ -7,7 +7,8 @@ Functions:
Exports: Exports:
- add_user_to_db - add_user_to_db
""" """
from authlib.jose import JWTClaims from typing import Any
from fastapi import HTTPException from fastapi import HTTPException
from src.user.schemas import OIDCUser from src.user.schemas import OIDCUser
@ -15,7 +16,7 @@ from src.user.models import User
from src.database import get_db from src.database import get_db
async def add_user_to_db(user_claims: JWTClaims) -> int: async def add_user_to_db(user_claims: dict[str, Any]) -> int:
try: try:
valid_user = OIDCUser(first_name=user_claims["given_name"], last_name=user_claims["family_name"], email=user_claims["email"], oidc_id=user_claims["sub"]) valid_user = OIDCUser(first_name=user_claims["given_name"], last_name=user_claims["family_name"], email=user_claims["email"], oidc_id=user_claims["sub"])
except Exception as e: except Exception as e:
@ -31,5 +32,14 @@ async def add_user_to_db(user_claims: JWTClaims) -> int:
db.commit() db.commit()
return user_model.id return user_model.id
else: else:
# Verify details still match and update accordingly. change = False
if db_user.first_name != valid_user.first_name:
db_user.first_name = valid_user.first_name
change = True
if db_user.last_name != valid_user.last_name:
db_user.last_name = valid_user.last_name
change = True
if change:
db.add(db_user)
db.commit()
return db_user.id return db_user.id