"""
Module-specific business logic for the auth module.

Exports:
    - claims_dependency: validated OIDC token claims, plus the caller's
      database id under ``"db_id"``.
    - org_user_dependency: requires membership of the ``org_id`` organisation.
    - root_user_dependency: requires being the root user of ``org_id`` and
      yields that user's ``User`` row.
    - super_admin_dependency: requires the caller to be a super admin.
"""

import asyncio
import json
from typing import Annotated, Any
from urllib.request import urlopen

import requests
from fastapi import Depends, HTTPException, Path
from fastapi.security import OpenIdConnect
from joserfc import jwt
from joserfc.errors import ExpiredTokenError, JoseError
from joserfc.jwk import KeySet
from sqlalchemy.sql import exists

from src.auth.config import auth_settings
from src.user.service import add_user_to_db
from src.organisation.models import OrgUsers, Organisation as Org
from src.user.models import User
from src.database import db_dependency
from src.organisation.dependencies import org_model_dependency

# Seconds to wait on the identity provider. Without a timeout a stalled
# provider would hang every authenticated request indefinitely.
_IDP_TIMEOUT_SECONDS = 10

# Database ids of users granted super-admin rights.
# TODO: this is empty, so ``is_super_admin`` currently rejects everyone;
# populate it (e.g. from configuration) before relying on that dependency.
SUPER_ADMIN_IDS: frozenset[int] = frozenset()

oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG)
# The value this scheme yields is expected to carry a "Bearer " scheme
# prefix, which get_current_user strips before decoding.
oidc_dependency = Annotated[str, Depends(oidc)]


def get_dev_user() -> dict[str, Any]:
    """Return stub claims for local development, bypassing OIDC.

    Only ``db_id`` is populated. Presumably used as a dependency override in
    place of ``get_current_user`` -- confirm against callers.
    """
    return {"db_id": 1}


def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 response carrying the RFC 6750 ``Bearer`` challenge."""
    return HTTPException(
        status_code=401,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


def _strip_bearer(auth_header: str) -> str:
    """Return the raw token from an ``Authorization`` header value.

    A leading ``Bearer`` scheme is removed case-insensitively (auth schemes
    are case-insensitive). A value without that scheme is returned as-is,
    minus surrounding whitespace.
    """
    value = auth_header.strip()
    scheme, sep, credentials = value.partition(" ")
    if sep and scheme.lower() == "bearer":
        return credentials.strip()
    return value


def _fetch_jwks() -> KeySet:
    """Download the identity provider's signing keys (blocking network I/O).

    Reads ``jwks_uri`` from the OIDC discovery document at
    ``auth_settings.OIDC_CONFIG`` and imports the key set published there.
    """
    # ``with`` closes the HTTP response; urlopen raises on non-2xx statuses.
    with urlopen(auth_settings.OIDC_CONFIG, timeout=_IDP_TIMEOUT_SECONDS) as resp:
        config = json.load(resp)
    key_response = requests.get(config["jwks_uri"], timeout=_IDP_TIMEOUT_SECONDS)
    key_response.raise_for_status()
    return KeySet.import_key_set(key_response.json())


async def get_current_user(oidc_auth_string: oidc_dependency) -> dict[str, Any]:
    """Verify the caller's bearer token and return its claims.

    The token signature is checked against the provider's JWKS. ``exp`` must
    be present and unexpired, ``aud`` must be ``"account"``, and ``iss`` must
    equal ``auth_settings.OIDC_ISSUER``. The user is then registered or looked
    up via ``add_user_to_db`` and the resulting id is added to the claims as
    ``"db_id"``.

    Raises:
        HTTPException: 401 if the token cannot be decoded or verified, has
            expired, or fails a claim check.
    """
    # The key fetch is blocking I/O; run it in a worker thread so the event
    # loop keeps serving other requests meanwhile.
    # NOTE(review): this still hits the provider on every request; a
    # short-lived (TTL) cache of the key set would cut latency considerably.
    jwk_keys = await asyncio.to_thread(_fetch_jwks)

    try:
        token = jwt.decode(_strip_bearer(oidc_auth_string), jwk_keys)
    except (JoseError, ValueError) as err:
        # Malformed token, unknown key id or bad signature: the client's
        # credentials are invalid, which is a 401 rather than a server fault.
        raise _unauthorized("Invalid token") from err

    claims_registry = jwt.JWTClaimsRegistry(
        exp={"essential": True},
        aud={"essential": True, "value": "account"},
        iss={"essential": True, "value": auth_settings.OIDC_ISSUER},
    )
    try:
        claims_registry.validate(token.claims)
    except ExpiredTokenError as err:
        raise _unauthorized("Token expired") from err
    except JoseError as err:
        # Missing or mismatching essential claims (aud / iss / exp).
        raise _unauthorized("Invalid token claims") from err

    db_id = await add_user_to_db(token.claims)
    token.claims["db_id"] = db_id
    return token.claims


claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)]


async def is_org_user(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)) -> bool:
    """Require the current user to be a member of organisation ``org_id``.

    Returns:
        The truthy result of the membership ``EXISTS`` query.

    Raises:
        HTTPException: 404 if the organisation does not exist or the claims
            carry no ``db_id``; 401 if the user is not a member.
    """
    # NOTE(review): if ``db`` is a synchronous Session, these queries block the
    # event loop inside this ``async def``; a plain ``def`` dependency would
    # run in FastAPI's threadpool instead.
    org_exists = db.query(exists().where(Org.id == org_id)).scalar()
    if not org_exists:
        raise HTTPException(status_code=404, detail="Organisation not found")

    db_id = claims.get("db_id")
    if db_id is None:
        raise HTTPException(status_code=404, detail="User not found in db")

    membership = (
        db.query(OrgUsers)
        .filter(OrgUsers.org_id == org_id, OrgUsers.user_id == db_id)
        .exists()
    )
    is_member = db.query(membership).scalar()
    if not is_member:
        # NOTE(review): 403 Forbidden fits "authenticated but not permitted"
        # better than 401; kept as 401 for compatibility with existing clients.
        raise HTTPException(status_code=401, detail="Not authorised")
    return is_member


org_user_dependency = Annotated[bool, Depends(is_org_user)]


async def is_org_root(claims: claims_dependency, db: db_dependency, org_model: org_model_dependency, org_id: int = Path(gt=0)) -> User:
    """Require the current user to be the root user of the organisation.

    ``org_id`` is not read here. It is kept so the path parameter stays
    declared and validated (> 0) on this dependency.

    Returns:
        The current user's ``User`` row.

    Raises:
        HTTPException: 404 if the claims carry no ``db_id`` or no ``User``
            row matches it; 401 if the user is not the organisation's root.
    """
    db_id = claims.get("db_id")
    if db_id is None:
        raise HTTPException(status_code=404, detail="User not found in db")

    if org_model.root_user_id != db_id:
        raise HTTPException(status_code=401, detail="Not authorised")

    user = db.query(User).filter(User.id == db_id).first()
    if user is None:
        # Never hand dependants a ``None`` user.
        raise HTTPException(status_code=404, detail="User not found in db")
    return user


root_user_dependency = Annotated[User, Depends(is_org_root)]


async def is_super_admin(claims: claims_dependency) -> bool:
    """Require the current user's ``db_id`` to be listed in ``SUPER_ADMIN_IDS``.

    Returns:
        ``True`` when the user is a super admin.

    Raises:
        HTTPException: 404 if the claims carry no ``db_id``; 401 if the user
            is not a super admin.
    """
    db_id = claims.get("db_id")
    if db_id is None:
        raise HTTPException(status_code=404, detail="User not found in db")
    if db_id not in SUPER_ADMIN_IDS:
        raise HTTPException(status_code=401, detail="Not authorised")
    return True


super_admin_dependency = Annotated[bool, Depends(is_super_admin)]