"""
Module specific business logic for auth module.

Validates OIDC bearer tokens and provides FastAPI dependencies for
route-level authorisation checks.

Exports:
- claims_dependency: validated token claims plus the user's "db_id"
- org_user_dependency: caller belongs to the organisation in the path
- org_admin_dependency: caller is an admin of the organisation in the path
- super_admin_dependency: caller is a super admin
- org_or_super_admin_dependency: caller is a super admin or an org admin
"""
import asyncio
import json
import time
from typing import Annotated, Any
from urllib.request import urlopen

import requests
from fastapi import Depends, HTTPException, Path
from fastapi.security import OpenIdConnect
from joserfc import jwt
from joserfc.errors import JoseError
from joserfc.jwk import KeySet
from sqlalchemy.sql import exists

from src.auth.config import auth_settings
from src.user.service import add_user_to_db
from src.organisation.models import OrgUsers, Organisation as Org
from src.database import db_dependency

oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG)
oidc_dependency = Annotated[str, Depends(oidc)]

# Claims every access token must carry; "aud" and "iss" must also match
# these exact values.
_CLAIMS_OPTIONS = {
    "exp": {"essential": True},
    "aud": {"essential": True, "value": "account"},
    "iss": {"essential": True, "value": auth_settings.OIDC_ISSUER},
}

# Timeout applied to each HTTP call made to the identity provider.
_HTTP_TIMEOUT_SECONDS = 10.0
# How long a downloaded key set is reused before it is fetched again.
_JWKS_CACHE_TTL_SECONDS = 3600.0
# A failed signature check only triggers an early key re-fetch when the
# cached keys are at least this old; bounds how often bad tokens can make
# us call the provider.
_JWKS_MIN_REFRESH_SECONDS = 60.0

# Process-wide key-set cache shared by all requests.
_jwks_cache: dict[str, Any] = {"keys": None, "fetched_at": float("-inf")}

# Database ids of users granted super-admin rights. It is empty, so no user
# currently passes is_super_admin. TODO: load from configuration.
SUPER_ADMIN_IDS: tuple[Any, ...] = ()


def _fetch_jwks() -> KeySet:
    """Download the discovery document and the key set it points to.

    This performs blocking network I/O; ``_get_jwks`` runs it in a worker
    thread.

    Raises:
        RuntimeError: if either provider response cannot be parsed.
        requests.HTTPError: if the key-set endpoint returns an error status.
        OSError: on network failures (urllib/requests errors).
    """
    with urlopen(auth_settings.OIDC_CONFIG, timeout=_HTTP_TIMEOUT_SECONDS) as response:
        raw_config = response.read()
    try:
        jwks_uri = json.loads(raw_config)["jwks_uri"]
    except (ValueError, KeyError, TypeError) as exc:
        raise RuntimeError("OIDC discovery document has no usable jwks_uri") from exc

    key_response = requests.get(jwks_uri, timeout=_HTTP_TIMEOUT_SECONDS)
    key_response.raise_for_status()
    try:
        return KeySet.import_key_set(key_response.json())
    except (JoseError, ValueError) as exc:
        # Wrapped so a broken provider response surfaces as a server error
        # instead of being mistaken for an invalid client token.
        raise RuntimeError("OIDC provider returned an unusable key set") from exc


async def _get_jwks(*, force_refresh: bool = False) -> KeySet:
    """Return the cached key set, re-fetching it when stale or when forced."""
    age = time.monotonic() - _jwks_cache["fetched_at"]
    if force_refresh or _jwks_cache["keys"] is None or age > _JWKS_CACHE_TTL_SECONDS:
        # Blocking HTTP calls run off the event loop so other requests
        # keep being served meanwhile.
        keys = await asyncio.to_thread(_fetch_jwks)
        _jwks_cache["keys"] = keys
        _jwks_cache["fetched_at"] = time.monotonic()
    return _jwks_cache["keys"]


async def _decode_token(raw_token: str) -> "jwt.Token":
    """Decode the token and verify its signature against the provider keys.

    If verification fails and the cached keys are at least
    ``_JWKS_MIN_REFRESH_SECONDS`` old, the keys are re-fetched once and
    verification is retried. This picks up a rotated signing key without
    waiting for the cache TTL.

    Raises:
        JoseError or ValueError: if the token cannot be decoded or verified.
    """
    keys = await _get_jwks()
    try:
        return jwt.decode(raw_token, keys)
    except (JoseError, ValueError):
        if time.monotonic() - _jwks_cache["fetched_at"] < _JWKS_MIN_REFRESH_SECONDS:
            raise
    keys = await _get_jwks(force_refresh=True)
    return jwt.decode(raw_token, keys)


def _strip_bearer_prefix(authorization: str) -> str:
    """Return the raw token from an ``Authorization`` header value.

    Only a leading "Bearer" scheme is removed, matched case-insensitively.
    Text inside the token itself is never touched. A value without that
    scheme is returned unchanged apart from surrounding whitespace.
    """
    value = authorization.strip()
    scheme, separator, credentials = value.partition(" ")
    if separator and scheme.lower() == "bearer":
        return credentials.strip()
    return value


async def get_current_user(oidc_auth_string: oidc_dependency) -> dict[str, Any]:
    """Validate the request's bearer token and return its claims.

    The token signature is checked against the provider's (cached) key set.
    The claims are then checked against ``_CLAIMS_OPTIONS``. The user is
    passed to ``add_user_to_db``, and the id it returns is stored under
    ``claims["db_id"]``.

    Args:
        oidc_auth_string: Authorization header value supplied by ``oidc``.

    Returns:
        The token claims, including the added "db_id" entry.

    Raises:
        HTTPException: 401 if the token fails decoding, signature or
            claims validation.
    """
    raw_token = _strip_bearer_prefix(oidc_auth_string)
    try:
        token = await _decode_token(raw_token)
        jwt.JWTClaimsRegistry(**_CLAIMS_OPTIONS).validate(token.claims)
    except (JoseError, ValueError) as exc:
        raise HTTPException(
            status_code=401,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc

    db_id = await add_user_to_db(token.claims)
    token.claims["db_id"] = db_id
    return token.claims


claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)]


def _require_org_membership(
    claims: dict[str, Any], db: Any, org_id: int, *, admin_only: bool
) -> bool:
    """Ensure the caller belongs to ``org_id`` (as an admin if ``admin_only``).

    Returns:
        True when the membership row exists.

    Raises:
        HTTPException: 404 if the organisation does not exist or the claims
            carry no "db_id"; 401 if the membership check fails.
    """
    # NOTE(review): db.query(...) is called without await, so this is
    # presumably a synchronous Session. The calls then block the event loop
    # inside the async dependencies below — TODO confirm.
    org_exists = db.query(exists().where(Org.id == org_id)).scalar()
    if not org_exists:
        raise HTTPException(status_code=404, detail="Organisation not found")

    db_id = claims.get("db_id")
    if db_id is None:
        raise HTTPException(status_code=404, detail="User not found in db")

    membership = db.query(OrgUsers).filter(
        OrgUsers.org_id == org_id, OrgUsers.user_id == db_id
    )
    if admin_only:
        # `== True` builds a SQL expression here; `is True` would not.
        membership = membership.filter(OrgUsers.is_admin == True)  # noqa: E712

    if not db.query(membership.exists()).scalar():
        # NOTE(review): 403 is the conventional status for an authenticated
        # caller without permission; 401 kept for existing clients.
        raise HTTPException(status_code=401, detail="Not authorised")
    return True


async def is_org_user(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)) -> bool:
    """Dependency: allow only members of the organisation in the path."""
    return _require_org_membership(claims, db, org_id, admin_only=False)


org_user_dependency = Annotated[bool, Depends(is_org_user)]


async def is_org_admin(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)) -> bool:
    """Dependency: allow only admins of the organisation in the path."""
    return _require_org_membership(claims, db, org_id, admin_only=True)


org_admin_dependency = Annotated[bool, Depends(is_org_admin)]


async def is_super_admin(claims: claims_dependency) -> bool:
    """Dependency: allow only users whose db id is in ``SUPER_ADMIN_IDS``.

    Raises:
        HTTPException: 404 if the claims carry no "db_id"; 401 otherwise
            when the user is not a super admin.
    """
    db_id = claims.get("db_id")
    if db_id is None:
        raise HTTPException(status_code=404, detail="User not found in db")
    if db_id not in SUPER_ADMIN_IDS:
        raise HTTPException(status_code=401, detail="Not authorised")
    return True


super_admin_dependency = Annotated[bool, Depends(is_super_admin)]


async def is_admin(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)) -> bool:
    """Dependency: allow super admins, or admins of the path organisation.

    Any failure of the organisation check, including a missing
    organisation, is reported as 401 "Not authorised".
    """
    try:
        return await is_super_admin(claims)
    except HTTPException:
        pass  # Not a super admin: fall back to the organisation-level check.

    try:
        return await is_org_admin(claims, db, org_id)
    except HTTPException as exc:
        raise HTTPException(status_code=401, detail="Not authorised") from exc


org_or_super_admin_dependency = Annotated[bool, Depends(is_admin)]

# Middleware version of user auth
# import json
# import logging
#
# from threading import Timer
# from urllib.request import urlopen
# from starlette.requests import HTTPConnection, Request
#
# from authlib.jose.rfc7517.jwk import JsonWebKey
# from authlib.jose.rfc7517.key_set import KeySet
# from authlib.oauth2
import OAuth2Error, ResourceProtector # from authlib.oauth2.rfc6749 import MissingAuthorizationError # from authlib.oauth2.rfc7523 import JWTBearerTokenValidator # from authlib.oauth2.rfc7523.validator import JWTBearerToken # # from starlette.authentication import ( # AuthCredentials, # AuthenticationBackend, # AuthenticationError, # SimpleUser, # ) # # logger = logging.getLogger(__name__) # # # class RepeatTimer(Timer): # def __init__(self, *args, **kwargs) -> None: # super().__init__(*args, **kwargs) # self.daemon = True # # def run(self): # while not self.finished.wait(self.interval): # self.function(*self.args, **self.kwargs) # # # class BearerTokenValidator(JWTBearerTokenValidator): # def __init__(self, issuer: str, audience: str): # self._issuer = issuer # self._jwks_uri: str | None = None # super().__init__(public_key=self.fetch_key(), issuer=issuer) # self.claims_options = { # "exp": {"essential": True}, # "aud": {"essential": True, "value": audience}, # "iss": {"essential": True, "value": issuer}, # } # self._timer = RepeatTimer(3600, self.refresh) # self._timer.start() # # def refresh(self): # try: # self.public_key = self.fetch_key() # except Exception as exc: # logger.warning(f"Could not update jwks public key: {exc}") # # def fetch_key(self) -> KeySet: # """Fetch the jwks_uri document and return the KeySet.""" # response = urlopen(self.jwks_uri) # logger.debug(f"OK GET {self.jwks_uri}") # return JsonWebKey.import_key_set(json.loads(response.read())) # # @property # def jwks_uri(self) -> str: # """The jwks_uri field of the openid-configuration document.""" # if self._jwks_uri is None: # config_url = urlopen(f"{self._issuer}/.well-known/openid-configuration") # config = json.loads(config_url.read()) # self._jwks_uri = config["jwks_uri"] # return self._jwks_uri # # # class BearerTokenAuthBackend(AuthenticationBackend): # def __init__(self, issuer: str, audience: str) -> None: # rp = ResourceProtector() # validator = BearerTokenValidator( # issuer=issuer, 
# audience=audience, # ) # rp.register_token_validator(validator) # self.resource_protector = rp # # async def authenticate(self, conn: HTTPConnection): # if "Authorization" not in conn.headers: # return # request = Request(conn.scope) # try: # token: JWTBearerToken = self.resource_protector.validate_request( # scopes=["openid"], # request=request, # ) # except (MissingAuthorizationError, OAuth2Error) as error: # raise AuthenticationError(error.description) from error # scope: str = token.get_scope() # scopes = scope.split() # scopes.append("authenticated") # return AuthCredentials(scopes=scopes), SimpleUser(username=token["email"])