"""Module-specific business logic for the auth module.

Exports:
    - oidc_dependency: FastAPI dependency yielding the raw Authorization
      header value accepted by the OpenID Connect security scheme.
    - claims_dependency: dict[str, Any] containing the validated OIDC token
      claims plus the user's database ID under the ``"db_id"`` key.
    - org_status_check: guard that rejects blocked organisations and limits
      not-yet-approved organisations to a small set of endpoints.
"""

import asyncio
import json
from typing import Annotated, Any
from urllib.request import urlopen

import requests
from fastapi import Depends, Request
from fastapi.security import OpenIdConnect
from joserfc import jwt
from joserfc.errors import ExpiredTokenError, JoseError
from joserfc.jwk import KeySet

from src.auth.config import auth_settings
from src.database import db_dependency
from src.exceptions import ForbiddenException, UnauthorizedException
from src.organisation.constants import Status as OrgStatus
from src.organisation.exceptions import AwaitingApprovalException
from src.organisation.models import Organisation as Org
from src.user.service import add_user_to_db

oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG)
oidc_dependency = Annotated[str, Depends(oidc)]

# Upper bound (seconds) for each call to the identity provider, so a hung
# provider cannot stall request handling indefinitely.
_HTTP_TIMEOUT_SECONDS = 10

_API_ROOT = "/api/v1"

# "<METHOD><path>" pairs an organisation may call before it is approved.
# A frozenset gives O(1) membership tests and is built once at import time.
_PRE_APPROVAL_ENDPOINTS = frozenset(
    {
        f"PATCH{_API_ROOT}/org/status",
        f"PATCH{_API_ROOT}/org/questionnaire",
        f"GET{_API_ROOT}/org",
        f"GET{_API_ROOT}/org/contact",
        f"PATCH{_API_ROOT}/org/contact",
        f"DELETE{_API_ROOT}/org/self",
    }
)


async def get_dev_user() -> dict[str, Any]:
    """Return a fixed, hard-coded claims dict (db_id 1).

    NOTE(review): presumably a stand-in for ``get_current_user`` during local
    development; confirm it is never wired into production routes.
    """
    return {"db_id": 1, "email": "chris@sr2.uk"}


def _fetch_jwks() -> KeySet:
    """Download the OIDC discovery document, then the key set it points to.

    This does blocking network I/O; call it through ``asyncio.to_thread``
    from async code.
    """
    # ``with`` closes the HTTP response; the original code leaked it.
    with urlopen(auth_settings.OIDC_CONFIG, timeout=_HTTP_TIMEOUT_SECONDS) as resp:
        config = json.load(resp)
    key_response = requests.get(config["jwks_uri"], timeout=_HTTP_TIMEOUT_SECONDS)
    key_response.raise_for_status()
    return KeySet.import_key_set(key_response.json())


def _strip_bearer_prefix(auth_value: str) -> str:
    """Return the credentials part of an ``Authorization`` header value.

    The auth scheme is matched case-insensitively and only as a leading
    prefix. A value without a ``Bearer`` prefix is returned unchanged.
    """
    scheme, sep, credentials = auth_value.partition(" ")
    if sep and scheme.lower() == "bearer":
        return credentials.strip()
    return auth_value


async def get_current_user(
    oidc_auth_string: oidc_dependency, db: db_dependency
) -> dict[str, Any]:
    """Validate the bearer token and return its claims with the user's DB id.

    Args:
        oidc_auth_string: Authorization header value supplied by the OIDC
            security scheme.
        db: Database dependency passed through to ``add_user_to_db``.

    Returns:
        The token's claims, with ``"db_id"`` set to the value returned by
        ``add_user_to_db``.

    Raises:
        UnauthorizedException: if the token is expired, malformed, has a bad
            signature, or fails the ``exp``/``iss`` claim checks.
    """
    # The key set is re-fetched on every request. Run the blocking HTTP calls
    # in a worker thread so they do not block the event loop.
    # TODO(review): consider caching the JWKS with a TTL to avoid two
    # round-trips to the identity provider per request.
    jwk_keys = await asyncio.to_thread(_fetch_jwks)

    claims_options = {
        "exp": {"essential": True},
        "iss": {"essential": True, "value": auth_settings.OIDC_ISSUER},
    }
    claims_registry = jwt.JWTClaimsRegistry(**claims_options)

    try:
        token = jwt.decode(_strip_bearer_prefix(oidc_auth_string), jwk_keys)
        claims_registry.validate(token.claims)
    except ExpiredTokenError as err:
        raise UnauthorizedException(message="Token is expired") from err
    except (JoseError, ValueError) as err:
        # Any other decode/signature/claim failure is a client error (401),
        # not an unhandled server error (500).
        raise UnauthorizedException(message="Invalid token") from err

    db_id = await add_user_to_db(db, token.claims)
    token.claims["db_id"] = db_id
    return token.claims


claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)]


async def org_status_check(org_model: Org, request: Request) -> None:
    """Reject requests the organisation is not currently allowed to make.

    Args:
        org_model: The organisation making the request.
        request: The incoming request; its method and URL path are matched
            against the pre-approval allow-list.

    Raises:
        ForbiddenException: if the organisation's status reports
            ``is_blocked``.
        AwaitingApprovalException: if the organisation is not approved and
            the request is not in the pre-approval allow-list.
    """
    org_status = OrgStatus(org_model.status)
    if org_status.is_blocked:
        raise ForbiddenException("This organisation cannot perform this action.")

    current_request = f"{request.method}{request.url.path}"
    # Compare the normalised enum, not the raw stored value. If the model
    # stores the enum's underlying value, a raw comparison would never equal
    # OrgStatus.APPROVED.
    if (
        current_request not in _PRE_APPROVAL_ENDPOINTS
        and org_status != OrgStatus.APPROVED
    ):
        raise AwaitingApprovalException(org_model.id)