Compare commits

..

1 commit

Author SHA1 Message Date
285b7044f0 Add renovate.json 2026-05-20 09:25:00 +00:00
14 changed files with 94 additions and 225 deletions

8
renovate.json Normal file
View file

@ -0,0 +1,8 @@
{
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
"extends": [
"config:recommended"
],
"minimumReleaseAge": "14 days",
"gitAuthor": "Renovate<noreply@sr2.uk>"
}

View file

@ -18,4 +18,3 @@ httptools
psycopg psycopg
email-validator email-validator
alembic alembic
joserfc

View file

@ -5,17 +5,20 @@ Endpoints:
- List: Description - List: Description
- Endpoints: Description - Endpoints: Description
""" """
from typing import Annotated
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from fastapi.params import Path from fastapi.params import Path
from sqlalchemy.sql import exists
from src.organisation.constants import ContactType from src.organisation.constants import ContactType
from src.organisation.schemas import OrgContactGetResponse from src.organisation.schemas import OrgContactGetResponse
from src.organisation.models import Organisation as Org from src.organisation.models import Organisation as Org
from src.contact.models import Contact from src.contact.models import Contact
from src.user.models import User
from src.user.schemas import UserResponse, OIDCUser, OrgResponse
from src.auth.service import claims_dependency, org_or_super_admin_dependency from src.organisation.models import OrgUsers, Organisation
from src.auth.service import claims_dependency, org_user_dependency, org_admin_dependency
from src.database import db_dependency from src.database import db_dependency
@ -26,7 +29,7 @@ router = APIRouter(
@router.get("/{org_id}/contact/{contact_type}", response_model=OrgContactGetResponse) @router.get("/{org_id}/contact/{contact_type}", response_model=OrgContactGetResponse)
async def get_contact(db: db_dependency, user: claims_dependency, is_admin: org_or_super_admin_dependency, contact_type: ContactType, org_id: Annotated[int, Path(gt=0)]): async def get_contact(db: db_dependency, user: claims_dependency, is_org_admin: org_user_dependency, contact_type: ContactType, org_id: int = Path(gt=0)):
org_model = db.query(Org).filter(Org.id == org_id).first() org_model = db.query(Org).filter(Org.id == org_id).first()
if org_model is None: if org_model is None:
raise HTTPException(status_code=404, detail="Organisation not found") raise HTTPException(status_code=404, detail="Organisation not found")

View file

@ -5,16 +5,16 @@ Exports:
- claims_dependency - claims_dependency
""" """
import json import json
import requests
from typing import Annotated, Any from typing import Annotated
from joserfc import jwt from authlib.jose import jwt
from joserfc.errors import ExpiredTokenError
from joserfc.jwk import KeySet
from urllib.request import urlopen from urllib.request import urlopen
from fastapi import Depends, HTTPException, Path from fastapi import Depends, HTTPException, Path
from fastapi.security import OpenIdConnect from fastapi.security import OpenIdConnect
from authlib.jose.rfc7517.jwk import JsonWebKey
from authlib.jose.rfc7517.key_set import KeySet
from authlib.oauth2.rfc7523.validator import JWTBearerToken
from sqlalchemy.sql import exists from sqlalchemy.sql import exists
from src.auth.config import auth_settings from src.auth.config import auth_settings
@ -27,12 +27,12 @@ oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG)
oidc_dependency = Annotated[str, Depends(oidc)] oidc_dependency = Annotated[str, Depends(oidc)]
async def get_current_user(oidc_auth_string: oidc_dependency) -> dict[str, Any]: async def get_current_user(oidc_auth_string: oidc_dependency) -> JWTBearerToken:
config_url = urlopen(auth_settings.OIDC_CONFIG) config_url = urlopen(auth_settings.OIDC_CONFIG)
config = json.loads(config_url.read()) config = json.loads(config_url.read())
jwks_uri = config["jwks_uri"] jwks_uri = config["jwks_uri"]
key_response = requests.get(jwks_uri) key_response = urlopen(jwks_uri)
jwk_keys = KeySet.import_key_set(key_response.json()) jwk_keys: KeySet = JsonWebKey.import_key_set(json.loads(key_response.read()))
claims_options = { claims_options = {
"exp": {"essential": True}, "exp": {"essential": True},
@ -40,26 +40,22 @@ async def get_current_user(oidc_auth_string: oidc_dependency) -> dict[str, Any]:
"iss": {"essential": True, "value": auth_settings.OIDC_ISSUER}, "iss": {"essential": True, "value": auth_settings.OIDC_ISSUER},
} }
token = jwt.decode( claims: JWTBearerToken = jwt.decode(
oidc_auth_string.replace("Bearer ", ""), oidc_auth_string.replace("Bearer ", ""),
jwk_keys jwk_keys,
claims_options=claims_options,
claims_cls=JWTBearerToken,
) )
claims_requests = jwt.JWTClaimsRegistry(**claims_options) claims.validate()
db_id = await add_user_to_db(claims)
try: claims["db_id"] = db_id
claims_requests.validate(token.claims)
except ExpiredTokenError as e:
raise HTTPException(status_code=401, detail="Token expired")
db_id = await add_user_to_db(token.claims) return claims
token.claims["db_id"] = db_id
return token.claims
claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)] claims_dependency = Annotated[JWTBearerToken, Depends(get_current_user)]
async def is_org_user(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)): async def is_org_user(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)):
@ -85,7 +81,7 @@ async def is_org_user(claims: claims_dependency, db: db_dependency, org_id: int
return org_user_exists return org_user_exists
org_user_dependency = Annotated[dict[str, Any], Depends(is_org_user)] org_user_dependency = Annotated[JWTBearerToken, Depends(is_org_user)]
async def is_org_admin(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)): async def is_org_admin(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)):
@ -112,7 +108,7 @@ async def is_org_admin(claims: claims_dependency, db: db_dependency, org_id: int
return org_admin_exists return org_admin_exists
org_admin_dependency = Annotated[dict[str, Any], Depends(is_org_admin)] org_admin_dependency = Annotated[JWTBearerToken, Depends(is_org_admin)]
async def is_super_admin(claims: claims_dependency): async def is_super_admin(claims: claims_dependency):
@ -127,22 +123,10 @@ async def is_super_admin(claims: claims_dependency):
return True return True
super_admin_dependency = Annotated[dict[str, Any], Depends(is_super_admin)] super_admin_dependency = Annotated[JWTBearerToken, Depends(is_super_admin)]
async def is_admin(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)):
try:
await is_super_admin(claims)
return True
except HTTPException as e:
pass
try:
await is_org_admin(claims, db, org_id)
return True
except HTTPException as e:
raise HTTPException(status_code=401, detail="Not authorised")
org_or_super_admin_dependency = Annotated[dict[str, Any], Depends(is_admin)]
# Middleware version of user auth # Middleware version of user auth
# import json # import json

View file

@ -9,8 +9,6 @@ Endpoints:
- [patch]/{contact_id} - Updates the details of an existing contact - [patch]/{contact_id} - Updates the details of an existing contact
- [delete]/{contact_id} - Deletes a contact by ID - [delete]/{contact_id} - Deletes a contact by ID
""" """
from typing import Annotated
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from fastapi.params import Path from fastapi.params import Path
@ -57,7 +55,7 @@ async def create_contact(db: db_dependency, contact_request: ContactContactPostR
@router.patch("/{contact_id}") @router.patch("/{contact_id}")
async def update_contact(db: db_dependency, contact_request: ContactUpdateRequest, contact_id: Annotated[int, Path(gt=0)]): async def update_contact(db: db_dependency, contact_request: ContactUpdateRequest, contact_id: int = Path(gt=0)):
contact_model = (db.query(Contact).filter(Contact.id == contact_id).first()) contact_model = (db.query(Contact).filter(Contact.id == contact_id).first())
if contact_model is None: if contact_model is None:
raise HTTPException(status_code=404, detail="Contact not found") raise HTTPException(status_code=404, detail="Contact not found")
@ -74,7 +72,7 @@ async def update_contact(db: db_dependency, contact_request: ContactUpdateReques
@router.delete("/{contact_id}") @router.delete("/{contact_id}")
async def delete_contact(db: db_dependency, contact_id: Annotated[int, Path(gt=0)]): async def delete_contact(db: db_dependency, contact_id: int = Path(gt=0)):
contact_model = (db.query(Contact).filter(Contact.id == contact_id).first()) contact_model = (db.query(Contact).filter(Contact.id == contact_id).first())
if contact_model is None: if contact_model is None:
raise HTTPException(status_code=404, detail="Contact not found") raise HTTPException(status_code=404, detail="Contact not found")
@ -84,7 +82,7 @@ async def delete_contact(db: db_dependency, contact_id: Annotated[int, Path(gt=0
@router.get("/{contact_id}/orgs", response_model=list[ContactOrgGetResponse]) @router.get("/{contact_id}/orgs", response_model=list[ContactOrgGetResponse])
async def get_contact_orgs(db: db_dependency, contact_id: Annotated[int, Path(gt=0)]): async def get_contact_orgs(db: db_dependency, contact_id: int = Path(gt=0)):
contact_model = (db.query(Contact).filter(Contact.id == contact_id).first()) contact_model = (db.query(Contact).filter(Contact.id == contact_id).first())
if contact_model is None: if contact_model is None:
raise HTTPException(status_code=404, detail="Contact not found") raise HTTPException(status_code=404, detail="Contact not found")

View file

@ -7,7 +7,7 @@ Models:
""" """
from typing import Optional from typing import Optional
from pydantic import EmailStr from pydantic import Field, EmailStr
from src.organisation.constants import ContactType from src.organisation.constants import ContactType
from src.schemas import CustomBaseModel from src.schemas import CustomBaseModel

View file

@ -26,28 +26,15 @@ if settings.ENVIRONMENT.is_deployed:
pass pass
tags_metadata = [
{
"name": "User",
"description": "User related operations, includes getting information about the current user",
}
]
app = FastAPI( app = FastAPI(
swagger_ui_init_oauth={ swagger_ui_init_oauth={
"clientId": auth_settings.CLIENT_ID, "clientId": auth_settings.CLIENT_ID,
"usePkceWithAuthorizationCodeGrant": True, "usePkceWithAuthorizationCodeGrant": True,
"scopes": "openid profile email", "scopes": "openid profile email",
}, }
openapi_tags=tags_metadata,
) )
# Type inspection disabled for middleware injection.
# Known bug in FastAPI type checking: https://github.com/astral-sh/ty/issues/1635
# noinspection PyTypeChecker
app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY.get_secret_value()) app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY.get_secret_value())
# noinspection PyTypeChecker
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=settings.CORS_ORIGINS, allow_origins=settings.CORS_ORIGINS,

View file

@ -9,6 +9,8 @@ Models:
from sqlalchemy import Column, Integer, String, Boolean, ForeignKey, JSON, false from sqlalchemy import Column, Integer, String, Boolean, ForeignKey, JSON, false
from src.database import Base from src.database import Base
from src.contact.models import Contact
from src.user.models import User
class Organisation(Base): class Organisation(Base):

View file

@ -14,13 +14,12 @@ Endpoints:
- [delete]/{org_id} - Deletes an organisation by ID - [delete]/{org_id} - Deletes an organisation by ID
- [get]/{org_id}/contact/{contact_type} - Retrieves the contact of a specific type (owner, billing, security) for an organisation - [get]/{org_id}/contact/{contact_type} - Retrieves the contact of a specific type (owner, billing, security) for an organisation
""" """
from typing import Annotated
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from fastapi.params import Path from fastapi.params import Path
from sqlalchemy.sql import exists from sqlalchemy.sql import exists
from src.auth.service import super_admin_dependency
from src.database import db_dependency from src.database import db_dependency
from src.contact.models import Contact from src.contact.models import Contact
@ -37,7 +36,7 @@ router = APIRouter(
@router.get("/id/{org_id}", response_model=OrgOrgGetResponse) @router.get("/id/{org_id}", response_model=OrgOrgGetResponse)
async def get_org_by_id(db: db_dependency, org_id: Annotated[int, Path(gt=0)]): async def get_org_by_id(db: db_dependency, org_id: int = Path(gt=0)):
org_model = (db.query(Org).filter(Org.id == org_id).first()) org_model = (db.query(Org).filter(Org.id == org_id).first())
if org_model is None: if org_model is None:
raise HTTPException(status_code=404, detail="Organisation not found") raise HTTPException(status_code=404, detail="Organisation not found")
@ -64,7 +63,7 @@ async def create_org(db: db_dependency, org_request: OrgOrgPostRequest):
@router.patch("/{org_id}/questionnaire") @router.patch("/{org_id}/questionnaire")
async def update_questionnaire(db: db_dependency, q_request: OrgQuestionnairePatchRequest, org_id: Annotated[int, Path(gt=0)]): async def update_questionnaire(db: db_dependency, q_request: OrgQuestionnairePatchRequest, org_id: int = Path(gt=0)):
""" """
Route for updating questionnaire. Route for updating questionnaire.
The partial bool allows for submission of partially completed questionnaire and/or The partial bool allows for submission of partially completed questionnaire and/or
@ -85,7 +84,7 @@ async def update_questionnaire(db: db_dependency, q_request: OrgQuestionnairePat
@router.patch("/{org_id}/status") @router.patch("/{org_id}/status")
async def update_status(db: db_dependency, status_request: OrgStatusPatchRequest, org_id: Annotated[int, Path(gt=0)]): async def update_status(db: db_dependency, status_request: OrgStatusPatchRequest, org_id: int = Path(gt=0)):
org_model = db.query(Org).filter(Org.id == org_id).first() org_model = db.query(Org).filter(Org.id == org_id).first()
if org_model is None: if org_model is None:
raise HTTPException(status_code=404, detail="Organisation not found") raise HTTPException(status_code=404, detail="Organisation not found")
@ -97,7 +96,7 @@ async def update_status(db: db_dependency, status_request: OrgStatusPatchRequest
@router.patch("/{org_id}/contact") @router.patch("/{org_id}/contact")
async def update_contact(db: db_dependency, contact_request: OrgContactPatchRequest, org_id: Annotated[int, Path(gt=0)]): async def update_contact(db: db_dependency, contact_request: OrgContactPatchRequest, org_id: int = Path(gt=0)):
org_model = db.query(Org).filter(Org.id == org_id).first() org_model = db.query(Org).filter(Org.id == org_id).first()
if org_model is None: if org_model is None:
raise HTTPException(status_code=404, detail="Organisation not found") raise HTTPException(status_code=404, detail="Organisation not found")
@ -117,7 +116,7 @@ async def update_contact(db: db_dependency, contact_request: OrgContactPatchRequ
@router.get("/{org_id}/users", response_model=list[OrgUserGetResponse]) @router.get("/{org_id}/users", response_model=list[OrgUserGetResponse])
async def get_users(db: db_dependency, org_id: Annotated[int, Path(gt=0)]): async def get_users(db: db_dependency, org_id: int = Path(gt=0)):
org_exists = db.query(exists().where(Org.id == org_id)).scalar() org_exists = db.query(exists().where(Org.id == org_id)).scalar()
if not org_exists: if not org_exists:
raise HTTPException(status_code=404, detail="Organisation not found") raise HTTPException(status_code=404, detail="Organisation not found")
@ -128,7 +127,7 @@ async def get_users(db: db_dependency, org_id: Annotated[int, Path(gt=0)]):
@router.get("/{org_id}/users/admins", response_model=list[OrgUserGetResponse]) @router.get("/{org_id}/users/admins", response_model=list[OrgUserGetResponse])
async def get_admin_users(db: db_dependency, org_id: Annotated[int, Path(gt=0)]): async def get_admin_users(db: db_dependency, org_id: int = Path(gt=0)):
org_exists = db.query(exists().where(Org.id == org_id)).scalar() org_exists = db.query(exists().where(Org.id == org_id)).scalar()
if not org_exists: if not org_exists:
raise HTTPException(status_code=404, detail="Organisation not found") raise HTTPException(status_code=404, detail="Organisation not found")
@ -139,11 +138,7 @@ async def get_admin_users(db: db_dependency, org_id: Annotated[int, Path(gt=0)])
@router.post("/{org_id}/users") @router.post("/{org_id}/users")
async def add_user_to_org(db: db_dependency, user_request: OrgUserPostRequest, org_id: Annotated[int, Path(gt=0)]): async def add_user_to_org(db: db_dependency, user_request: OrgUserPostRequest, org_id: int = Path(gt=0)):
org_model = (db.query(Org).filter(Org.id == org_id).first())
if org_model is None:
raise HTTPException(status_code=404, detail="Organisation not found")
org_user_model = OrgUsers(**user_request.model_dump(), org_id=org_id) org_user_model = OrgUsers(**user_request.model_dump(), org_id=org_id)
db.add(org_user_model) db.add(org_user_model)
@ -151,14 +146,11 @@ async def add_user_to_org(db: db_dependency, user_request: OrgUserPostRequest, o
@router.patch("/{org_id}/users") @router.patch("/{org_id}/users")
async def update_user_details(db: db_dependency, user_request: OrgUserPostRequest, org_id: Annotated[int, Path(gt=0)]): async def update_user_details(db: db_dependency, user_request: OrgUserPostRequest, org_id: int = Path(gt=0)):
""" """
Currently used only to update user admin status for organisation. Currently used only to update user admin status for organisation.
""" """
org_model = (db.query(Org).filter(Org.id == org_id).first()) # TODO: Check if org exists
if org_model is None:
raise HTTPException(status_code=404, detail="Organisation not found")
org_user_model = db.query(OrgUsers).filter(OrgUsers.org_id == org_id).filter(OrgUsers.user_id == user_request.user_id).first() org_user_model = db.query(OrgUsers).filter(OrgUsers.org_id == org_id).filter(OrgUsers.user_id == user_request.user_id).first()
if org_user_model is None: if org_user_model is None:
@ -172,7 +164,7 @@ async def update_user_details(db: db_dependency, user_request: OrgUserPostReques
@router.delete("/{org_id}") @router.delete("/{org_id}")
async def delete_organisation_by_id(db: db_dependency, org_id: Annotated[int, Path(gt=0)]): async def delete_organisation_by_id(db: db_dependency, org_id: int = Path(gt=0)):
org_model = (db.query(Org).filter(Org.id == org_id).first()) org_model = (db.query(Org).filter(Org.id == org_id).first())
if org_model is None: if org_model is None:
raise HTTPException(status_code=404, detail="Organisation not found") raise HTTPException(status_code=404, detail="Organisation not found")
@ -181,11 +173,10 @@ async def delete_organisation_by_id(db: db_dependency, org_id: Annotated[int, Pa
@router.get("/{org_id}/contact/{contact_type}", response_model=OrgContactGetResponse) @router.get("/{org_id}/contact/{contact_type}", response_model=OrgContactGetResponse)
async def get_contact(db: db_dependency, contact_type: ContactType, org_id: Annotated[int, Path(gt=0)]): async def get_contact(db: db_dependency, contact_type: ContactType, org_id: int = Path(gt=0)):
org_model = db.query(Org).filter(Org.id == org_id).first() org_model = db.query(Org).filter(Org.id == org_id).first()
if org_model is None: if org_model is None:
raise HTTPException(status_code=404, detail="Organisation not found") raise HTTPException(status_code=404, detail="Organisation not found")
match contact_type: match contact_type:
case "billing": case "billing":
contact_id = org_model.billing_contact_id contact_id = org_model.billing_contact_id

View file

@ -6,27 +6,22 @@ Models:
- Models: Description - Models: Description
""" """
from typing import Optional from typing import Optional
from pydantic import Json
from src.schemas import CustomBaseModel from src.schemas import CustomBaseModel
from src.organisation.constants import Status, ContactType from src.organisation.constants import Status, ContactType
class OrgQuestionnaire(CustomBaseModel):
question_one: str
question_two: str
question_three: str
class OrgOrgPostRequest(CustomBaseModel): class OrgOrgPostRequest(CustomBaseModel):
name: str name: str
intake_questionnaire: Optional[OrgQuestionnaire] = None intake_questionnaire: Optional[Json] = None
billing_contact_id: Optional[int] = None billing_contact_id: Optional[int] = None
security_contact_id: Optional[int] = None security_contact_id: Optional[int] = None
owner_contact_id: Optional[int] = None owner_contact_id: Optional[int] = None
class OrgQuestionnairePatchRequest(CustomBaseModel): class OrgQuestionnairePatchRequest(CustomBaseModel):
intake_questionnaire: OrgQuestionnaire intake_questionnaire: Json
partial: bool partial: bool
class OrgStatusPatchRequest(CustomBaseModel): class OrgStatusPatchRequest(CustomBaseModel):

View file

@ -4,16 +4,4 @@ Module specific exceptions for user module
Exceptions: Exceptions:
- List: Description - List: Description
- Exceptions: Description - Exceptions: Description
""" """
from typing import Optional
from fastapi import HTTPException, status
class UserNotFoundException(HTTPException):
def __init__(self, user_id: Optional[int] = None) -> None:
detail = "User not found" if user_id is None else f"User with ID '{user_id}' was not found."
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)

View file

@ -2,25 +2,21 @@
Router endpoints for user module Router endpoints for user module
Endpoints: Endpoints:
- [get]/self/claims - Retrieves user's OIDC claims - [get]/me/claims - Retrieves user's OIDC claims
- [get]/self/db - Retrieves the user data from the db that corresponds to the current OIDC user - [get]/me/db - Retrieves the user data from the db that corresponds to the current OIDC user
- [get]/self/orgs - Retrieves all organisations associated with the current user - [get]/me/orgs - Retrieves all organisations associated with the current user
- [get]/self/orgs/admin - Retrieves only admin organisations for the current user - [get]/me/orgs/admin - Retrieves only admin organisations for the current user
- [get]/{user_id} - Retrieves a specific user by their ID - [get]/{user_id} - Retrieves a specific user by their ID
- [get]/{user_id}/orgs - Retrieves all organisations associated with a specific user - [get]/{user_id}/orgs - Retrieves all organisations associated with a specific user
- [get]/{user_id}/orgs/admin - Retrieves only admin organisations for a specific user - [get]/{user_id}/orgs/admin - Retrieves only admin organisations for a specific user
- [delete]/{user_id} - Deletes a user from the db by their db ID - [delete]/{user_id} - Deletes a user from the db by their db ID
""" """
from typing import Annotated from fastapi import APIRouter, HTTPException
from fastapi import APIRouter
from fastapi.params import Path from fastapi.params import Path
from sqlalchemy.sql import exists from sqlalchemy.sql import exists
from starlette import status
from src.user.models import User from src.user.models import User
from src.user.schemas import UserResponse, OrgResponse, OIDCClaims from src.user.schemas import UserResponse, OIDCUser, OrgResponse
from src.user.exceptions import UserNotFoundException
from src.organisation.models import OrgUsers, Organisation from src.organisation.models import OrgUsers, Organisation
@ -29,54 +25,36 @@ from src.database import db_dependency
router = APIRouter( router = APIRouter(
prefix="/user", prefix="/user",
tags=["User"], tags=["user"],
) )
@router.get("/self/claims", response_model=OIDCClaims, status_code=status.HTTP_200_OK, responses={ @router.get("/me/claims")
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
})
async def current_user_claims(user: claims_dependency): async def current_user_claims(user: claims_dependency):
"""
Returns the full OIDC claims associated with the currently logged-in user.
"""
user["allowed_origins"] = user.get("allowed-origins", [])
return user return user
@router.get("/self/db", response_model=UserResponse, status_code=status.HTTP_200_OK, responses={ @router.get("/me/db", response_model=OIDCUser)
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
})
async def current_user(user: claims_dependency, db: db_dependency): async def current_user(user: claims_dependency, db: db_dependency):
""" db_id = user.get("db_id", None)
Returns the database details associated with the currently logged-in user. if db_id is None:
""" raise HTTPException(status_code=404, detail="User not found in db")
user_id = user.get("db_id", None)
if user_id is None:
raise UserNotFoundException()
user_model = (db.query(User).filter(User.id == user_id).first()) user_model = (db.query(User).filter(User.id == db_id).first())
if user_model is None: if user_model is None:
raise UserNotFoundException(user_id=user_id) raise HTTPException(status_code=404, detail="User not found")
return user_model return user_model
@router.get("/self/orgs", response_model=list[OrgResponse], status_code=status.HTTP_200_OK, responses={ @router.get("/me/orgs", response_model=list[OrgResponse])
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
})
async def get_current_organisations(db: db_dependency, user: claims_dependency): async def get_current_organisations(db: db_dependency, user: claims_dependency):
"""
Returns all organisations associated with the currently logged-in user.
"""
user_id = user.get("db_id", None) user_id = user.get("db_id", None)
if user_id is None: if user_id is None:
raise UserNotFoundException() raise HTTPException(status_code=404, detail="User not found")
user_exists = db.query(exists().where(User.id == user_id)).scalar() user_exists = db.query(exists().where(User.id == user_id)).scalar()
if not user_exists: if not user_exists:
raise UserNotFoundException(user_id=user_id) raise HTTPException(status_code=404, detail="User not found")
org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name) org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name)
.join(OrgUsers, Organisation.id == OrgUsers.org_id) .join(OrgUsers, Organisation.id == OrgUsers.org_id)
@ -87,20 +65,14 @@ async def get_current_organisations(db: db_dependency, user: claims_dependency):
return org_user_models return org_user_models
@router.get("/self/orgs/admin", response_model=list[OrgResponse], status_code=status.HTTP_200_OK, responses={ @router.get("/me/orgs/admin", response_model=list[OrgResponse])
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
})
async def get_current_admin_organisations(db: db_dependency, user: claims_dependency): async def get_current_admin_organisations(db: db_dependency, user: claims_dependency):
"""
Returns the organisations for which the currently logged-in user is an admin.
"""
user_id = user.get("db_id", None) user_id = user.get("db_id", None)
if user_id is None: if user_id is None:
raise UserNotFoundException() raise HTTPException(status_code=404, detail="User not found")
user_exists = db.query(exists().where(User.id == user_id)).scalar() user_exists = db.query(exists().where(User.id == user_id)).scalar()
if not user_exists: if not user_exists:
raise UserNotFoundException(user_id=user_id) raise HTTPException(status_code=404, detail="User not found")
org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name) org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name)
.join(OrgUsers, Organisation.id == OrgUsers.org_id) .join(OrgUsers, Organisation.id == OrgUsers.org_id)
@ -112,32 +84,20 @@ async def get_current_admin_organisations(db: db_dependency, user: claims_depend
return org_user_models return org_user_models
@router.get("/{user_id}", response_model=UserResponse, status_code=status.HTTP_200_OK, responses={ @router.get("/{user_id}", response_model=UserResponse)
status.HTTP_404_NOT_FOUND: {"description": "User not found"}, async def get_user_by_id(user_id: int, db: db_dependency):
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
})
async def get_user_by_id(db: db_dependency, user_id: Annotated[int, Path(gt=0,description="User database ID")]):
"""
Returns the database details associated with the provided user ID.
"""
user_model = (db.query(User).filter(User.id == user_id).first()) user_model = (db.query(User).filter(User.id == user_id).first())
if user_model is None: if user_model is None:
raise UserNotFoundException(user_id=user_id) raise HTTPException(status_code=404, detail="User not found")
return user_model return user_model
@router.get("/{user_id}/orgs", response_model=list[OrgResponse], status_code=status.HTTP_200_OK, responses={ @router.get("/{user_id}/orgs", response_model=list[OrgResponse])
status.HTTP_404_NOT_FOUND: {"description": "User not found"}, async def get_organisations(db: db_dependency, user_id: int = Path(gt=0)):
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
})
async def get_organisations(db: db_dependency, user_id: Annotated[int, Path(gt=0,description="User database ID")]):
"""
Returns all organisations associated with the provided user ID.
"""
user_exists = db.query(exists().where(User.id == user_id)).scalar() user_exists = db.query(exists().where(User.id == user_id)).scalar()
if not user_exists: if not user_exists:
raise UserNotFoundException(user_id=user_id) raise HTTPException(status_code=404, detail="User not found")
org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name) org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name)
.join(OrgUsers, Organisation.id == OrgUsers.org_id) .join(OrgUsers, Organisation.id == OrgUsers.org_id)
@ -148,17 +108,11 @@ async def get_organisations(db: db_dependency, user_id: Annotated[int, Path(gt=0
return org_user_models return org_user_models
@router.get("/{user_id}/orgs/admin", response_model=list[OrgResponse], status_code=status.HTTP_200_OK, responses={ @router.get("/{user_id}/orgs/admin", response_model=list[OrgResponse])
status.HTTP_404_NOT_FOUND: {"description": "User not found"}, async def get_admin_organisations(db: db_dependency, user_id: int = Path(gt=0)):
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
})
async def get_admin_organisations(db: db_dependency, user_id: Annotated[int, Path(gt=0,description="User database ID")]):
"""
Returns the organisations for which the user with the provided user ID is an admin.
"""
user_exists = db.query(exists().where(User.id == user_id)).scalar() user_exists = db.query(exists().where(User.id == user_id)).scalar()
if not user_exists: if not user_exists:
raise UserNotFoundException(user_id=user_id) raise HTTPException(status_code=404, detail="User not found")
org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name) org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name)
.join(OrgUsers, Organisation.id == OrgUsers.org_id) .join(OrgUsers, Organisation.id == OrgUsers.org_id)
@ -170,16 +124,10 @@ async def get_admin_organisations(db: db_dependency, user_id: Annotated[int, Pat
return org_user_models return org_user_models
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT, responses={ @router.delete("/{user_id}")
status.HTTP_204_NO_CONTENT: {"description": "User deleted"}, async def delete_user_by_id(user_id: int, db: db_dependency):
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
})
async def delete_user_by_id(user_id: Annotated[int, Path(gt=0)], db: db_dependency):
"""
Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login.
"""
user_model = (db.query(User).filter(User.id == user_id).first()) user_model = (db.query(User).filter(User.id == user_id).first())
if user_model is None: if user_model is None:
raise UserNotFoundException(user_id=user_id) raise HTTPException(status_code=404, detail="User not found")
db.delete(user_model) db.delete(user_model)
db.commit() db.commit()

View file

@ -6,31 +6,7 @@ Models:
- Models: Description - Models: Description
""" """
from src.schemas import CustomBaseModel from src.schemas import CustomBaseModel
from pydantic import Field
class OIDCClaims(CustomBaseModel):
exp: int
iat: int
auth_time: int
jti: str
iss: str
aud: str
sub: str
typ: str
azp: str
sid: str
acr: str
allowed_origins: list[str]
realm_access: dict[str, list[str]]
resource_access: dict[str, dict[str, list[str]]]
scope: str
email_verified: bool
name: str
preferred_username: str
given_name: str
family_name: str
email: str
db_id: int
class OIDCUser(CustomBaseModel): class OIDCUser(CustomBaseModel):

View file

@ -7,8 +7,7 @@ Functions:
Exports: Exports:
- add_user_to_db - add_user_to_db
""" """
from typing import Any from authlib.jose import JWTClaims
from fastapi import HTTPException from fastapi import HTTPException
from src.user.schemas import OIDCUser from src.user.schemas import OIDCUser
@ -16,7 +15,7 @@ from src.user.models import User
from src.database import get_db from src.database import get_db
async def add_user_to_db(user_claims: dict[str, Any]) -> int: async def add_user_to_db(user_claims: JWTClaims) -> int:
try: try:
valid_user = OIDCUser(first_name=user_claims["given_name"], last_name=user_claims["family_name"], email=user_claims["email"], oidc_id=user_claims["sub"]) valid_user = OIDCUser(first_name=user_claims["given_name"], last_name=user_claims["family_name"], email=user_claims["email"], oidc_id=user_claims["sub"])
except Exception as e: except Exception as e:
@ -32,14 +31,5 @@ async def add_user_to_db(user_claims: dict[str, Any]) -> int:
db.commit() db.commit()
return user_model.id return user_model.id
else: else:
change = False # Verify details still match and update accordingly.
if db_user.first_name != valid_user.first_name:
db_user.first_name = valid_user.first_name
change = True
if db_user.last_name != valid_user.last_name:
db_user.last_name = valid_user.last_name
change = True
if change:
db.add(db_user)
db.commit()
return db_user.id return db_user.id