"""Module-specific business logic for the auth module.

Exports:
    claims_dependency: FastAPI dependency alias that resolves to a
        ``dict[str, Any]`` of verified OIDC token claims, extended with the
        user's local database ID under the ``"db_id"`` key.
"""

import asyncio
import json
from typing import Annotated, Any
from urllib.request import urlopen

import requests
from fastapi import Depends
from fastapi.security import OpenIdConnect
from joserfc import jwt
from joserfc.errors import ExpiredTokenError, JoseError
from joserfc.jwk import KeySet

from src.auth.config import auth_settings
from src.auth.exceptions import UnauthorizedException
from src.database import db_dependency
from src.user.service import add_user_to_db

# Upper bound (seconds) for each outbound request to the identity provider,
# so a slow or unreachable provider cannot hang a worker thread indefinitely.
_HTTP_TIMEOUT_SECONDS = 10

oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG)
oidc_dependency = Annotated[str, Depends(oidc)]


def get_dev_user():
    """Return a hard-coded user dict whose ``db_id`` is 1.

    NOTE(review): the name suggests a development stand-in for
    ``get_current_user``; confirm where it is wired in.
    """
    return {"db_id": 1}


def _strip_bearer(auth_header: str) -> str:
    """Return the credentials part of an ``Authorization`` header value.

    Only a leading ``Bearer`` scheme is removed; auth-scheme names are
    case-insensitive (RFC 7235 section 2.1). A value without that scheme is
    returned unchanged.
    """
    scheme, sep, credentials = auth_header.strip().partition(" ")
    if sep and scheme.lower() == "bearer":
        return credentials.strip()
    return auth_header


def _fetch_jwks() -> KeySet:
    """Download the provider's signing keys via its OIDC discovery document.

    Performs blocking network I/O; call it through ``asyncio.to_thread`` from
    async code.

    Raises:
        urllib.error.URLError / requests.RequestException: If either the
            discovery document or the JWKS cannot be fetched.
    """
    # Context manager guarantees the HTTP response is closed.
    with urlopen(auth_settings.OIDC_CONFIG, timeout=_HTTP_TIMEOUT_SECONDS) as config_response:
        config = json.load(config_response)

    key_response = requests.get(config["jwks_uri"], timeout=_HTTP_TIMEOUT_SECONDS)
    # Fail loudly on an HTTP error instead of trying to parse an error body as keys.
    key_response.raise_for_status()
    return KeySet.import_key_set(key_response.json())


async def get_current_user(oidc_auth_string: oidc_dependency, db: db_dependency) -> dict[str, Any]:
    """Verify the request's OIDC token and return its claims plus the user's DB ID.

    The token signature is checked against the provider's JWKS (looked up via
    ``auth_settings.OIDC_CONFIG``). ``exp`` must be present and unexpired, and
    ``iss`` must be present and equal ``auth_settings.OIDC_ISSUER``. The claims
    are then passed to ``add_user_to_db``, and its return value is stored
    under ``"db_id"``.

    Args:
        oidc_auth_string: Raw ``Authorization`` header value supplied by the
            ``OpenIdConnect`` dependency, e.g. ``"Bearer <jwt>"``.
        db: Database handle from ``db_dependency``.

    Returns:
        The decoded token claims with an added ``"db_id"`` entry.

    Raises:
        UnauthorizedException: If the token is expired ("Token is expired")
            or is malformed, badly signed, or fails claim validation
            ("Invalid token").
    """
    # The HTTP fetches are blocking; run them in a worker thread so they do not
    # stall the event loop for every other in-flight request.
    # NOTE(review): keys are fetched on every request; consider caching the
    # JWKS (refreshing on an unknown ``kid``) if request volume warrants.
    jwk_keys = await asyncio.to_thread(_fetch_jwks)

    claims_registry = jwt.JWTClaimsRegistry(
        exp={"essential": True},
        iss={"essential": True, "value": auth_settings.OIDC_ISSUER},
    )

    try:
        token = jwt.decode(_strip_bearer(oidc_auth_string), jwk_keys)
        claims_registry.validate(token.claims)
    except ExpiredTokenError as err:
        raise UnauthorizedException(message="Token is expired") from err
    except (JoseError, ValueError) as err:
        # Any other token problem is the client's fault: answer 401 rather
        # than letting it surface as a 500. ValueError is included because
        # key-lookup failures may raise it -- confirm against the installed
        # joserfc version.
        raise UnauthorizedException(message="Invalid token") from err

    db_id = await add_user_to_db(db, token.claims)
    token.claims["db_id"] = db_id
    return token.claims


claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)]