"""
Module specific business logic for auth module

Exports:
    - oidc_dependency: raw Authorization header value from the OpenID Connect scheme
    - claims_dependency: validated JWT claims, augmented with the user's ``db_id``
    - org_user_dependency: asserts the caller is a member of the path's organisation
    - super_admin_dependency: asserts the caller is a super admin
"""
import asyncio
import json
from typing import Annotated, Any
from urllib.request import urlopen

import requests
from fastapi import Depends, HTTPException, Path
from fastapi.security import OpenIdConnect
from joserfc import jwt
from joserfc.errors import ExpiredTokenError, JoseError
from joserfc.jwk import KeySet
from sqlalchemy.sql import exists

from src.auth.config import auth_settings
from src.database import db_dependency
from src.organisation.models import OrgUsers, Organisation as Org
from src.user.service import add_user_to_db

oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG)
oidc_dependency = Annotated[str, Depends(oidc)]

# Upper bound (seconds) for each call to the identity provider, so a slow or
# unreachable provider cannot hang a request indefinitely.
_HTTP_TIMEOUT_SECONDS = 10

# User ids (as returned by add_user_to_db) allowed through is_super_admin.
# TODO(review): this is empty, so is_super_admin currently rejects every user;
# populate it (or load it from settings) before relying on that dependency.
_SUPER_ADMIN_IDS: tuple[Any, ...] = ()


def _fetch_jwks() -> KeySet:
    """Download the OIDC discovery document and the JSON Web Key Set it points to.

    Blocking: performs two HTTP requests. Call it from a worker thread when
    running inside the event loop.

    Raises:
        urllib.error.URLError / requests.RequestException: if the provider is
            unreachable, times out, or (for the JWKS call) returns an HTTP error.
        KeyError: if the discovery document has no ``jwks_uri``.
    """
    # Context manager closes the HTTP response (the original leaked it).
    with urlopen(auth_settings.OIDC_CONFIG, timeout=_HTTP_TIMEOUT_SECONDS) as config_response:
        config = json.loads(config_response.read())
    key_response = requests.get(config["jwks_uri"], timeout=_HTTP_TIMEOUT_SECONDS)
    key_response.raise_for_status()
    return KeySet.import_key_set(key_response.json())


async def get_current_user(oidc_auth_string: oidc_dependency) -> dict[str, Any]:
    """Verify the caller's bearer token and return its claims.

    The signing keys are fetched from the provider named in
    ``auth_settings.OIDC_CONFIG`` on every call. The token must carry ``exp``,
    an ``aud`` of ``"account"``, and an ``iss`` equal to
    ``auth_settings.OIDC_ISSUER``. The validated claims are then passed to
    ``add_user_to_db``, and its return value is stored in the claims under
    ``"db_id"``.

    Args:
        oidc_auth_string: Authorization header value, expected to be of the
            form ``"Bearer <jwt>"``; a bare token is also accepted.

    Returns:
        The token's claims dict, including the added ``"db_id"`` key.

    Raises:
        HTTPException(401): if the token is expired, has a bad signature, is
            malformed, or fails the claim checks above.
    """
    # The key fetch uses blocking HTTP clients; run it in a worker thread so
    # the event loop keeps serving other requests meanwhile.
    jwk_keys = await asyncio.to_thread(_fetch_jwks)

    claims_registry = jwt.JWTClaimsRegistry(
        exp={"essential": True},
        aud={"essential": True, "value": "account"},
        iss={"essential": True, "value": auth_settings.OIDC_ISSUER},
    )
    try:
        # removeprefix strips only the leading scheme, unlike str.replace.
        token = jwt.decode(oidc_auth_string.removeprefix("Bearer "), jwk_keys)
        claims_registry.validate(token.claims)
    except ExpiredTokenError as exc:
        raise HTTPException(status_code=401, detail="Token expired") from exc
    except (JoseError, ValueError) as exc:
        # Bad signature, malformed token, unknown key id, or a failed
        # aud/iss/missing-claim check: the client's credentials are invalid,
        # so answer 401 rather than letting it surface as a 500.
        raise HTTPException(status_code=401, detail="Invalid token") from exc

    # presumably creates or looks up the user row; verify against src.user.service
    db_id = await add_user_to_db(token.claims)
    token.claims["db_id"] = db_id
    return token.claims


claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)]


async def is_org_user(
    claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)
) -> bool:
    """Ensure the authenticated user belongs to the organisation in the path.

    Args:
        claims: Validated token claims; must contain ``"db_id"``.
        db: Database session used for the existence queries.
        org_id: Organisation id from the path; must be > 0.

    Returns:
        True when a matching ``OrgUsers`` row exists.

    Raises:
        HTTPException(404): if the organisation does not exist, or the claims
            carry no ``"db_id"``.
        HTTPException(401): if the user is not a member of the organisation.
    """
    org_exists = db.query(exists().where(Org.id == org_id)).scalar()
    if not org_exists:
        raise HTTPException(status_code=404, detail="Organisation not found")

    db_id = claims.get("db_id")
    if db_id is None:
        raise HTTPException(status_code=404, detail="User not found in db")

    membership_query = (
        db.query(OrgUsers)
        .filter(OrgUsers.org_id == org_id, OrgUsers.user_id == db_id)
        .exists()
    )
    org_user_exists = db.query(membership_query).scalar()
    if not org_user_exists:
        # NOTE(review): 403 is the conventional code for "authenticated but not
        # permitted"; 401 is kept so existing clients see unchanged behaviour.
        raise HTTPException(status_code=401, detail="Not authorised")
    return org_user_exists


# is_org_user returns a bool, not a claims dict.
org_user_dependency = Annotated[bool, Depends(is_org_user)]


async def is_super_admin(claims: claims_dependency) -> bool:
    """Ensure the authenticated user's ``db_id`` is in ``_SUPER_ADMIN_IDS``.

    Returns:
        True when the user is a super admin.

    Raises:
        HTTPException(404): if the claims carry no ``"db_id"``.
        HTTPException(401): if the user is not listed as a super admin (with
            the current empty list, this is every user).
    """
    db_id = claims.get("db_id")
    if db_id is None:
        raise HTTPException(status_code=404, detail="User not found in db")
    if db_id not in _SUPER_ADMIN_IDS:
        raise HTTPException(status_code=401, detail="Not authorised")
    return True


# is_super_admin returns a bool, not a claims dict.
super_admin_dependency = Annotated[bool, Depends(is_super_admin)]