cloud-api/src/auth/service.py

226 lines
6.6 KiB
Python
Raw Normal View History

2026-04-06 12:41:49 +01:00
"""
Module specific business logic for auth module
Exports:
- claims_dependency
"""
import json
from typing import Annotated
from authlib.jose import jwt
from urllib.request import urlopen
from fastapi import Depends, HTTPException, Path
from fastapi.security import OpenIdConnect
from authlib.jose.rfc7517.jwk import JsonWebKey
from authlib.jose.rfc7517.key_set import KeySet
from authlib.oauth2.rfc7523.validator import JWTBearerToken
from sqlalchemy.sql import exists
from src.auth.config import auth_settings
from src.user.service import add_user_to_db
from src.organisation.models import OrgUsers, Organisation as Org
from src.database import db_dependency
# OpenID Connect security scheme, configured with the discovery document URL
# taken from auth_settings.OIDC_CONFIG.
oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG)
# Dependency alias: injects the string produced by the `oidc` scheme
# (get_current_user strips a "Bearer " prefix from it before decoding).
oidc_dependency = Annotated[str, Depends(oidc)]
async def get_current_user(oidc_auth_string: oidc_dependency) -> JWTBearerToken:
    """
    Validate the caller's bearer token and return its claims.

    The OIDC discovery document is fetched from ``auth_settings.OIDC_CONFIG``,
    the key set is loaded from its ``jwks_uri``, and the token is decoded and
    validated against the ``exp``, ``aud`` and ``iss`` claim options below.
    The user is then passed to ``add_user_to_db`` and the returned id is
    stored on the claims under ``"db_id"``.

    Args:
        oidc_auth_string: Authorization value supplied by the ``oidc`` scheme,
            optionally prefixed with ``"Bearer "``.

    Returns:
        The validated claims, with an extra ``"db_id"`` entry.

    Raises:
        HTTPException: 401 when the token cannot be decoded or fails
            claim validation.
    """
    # Function-scope import so this block does not depend on a file-level change.
    from authlib.jose.errors import JoseError

    # NOTE(review): both fetches are blocking and happen on every request from
    # inside an async dependency; consider caching the key set.
    # The context managers make sure each HTTP response is closed.
    with urlopen(auth_settings.OIDC_CONFIG) as config_response:
        config = json.loads(config_response.read())
    with urlopen(config["jwks_uri"]) as key_response:
        jwk_keys: KeySet = JsonWebKey.import_key_set(json.loads(key_response.read()))

    claims_options = {
        "exp": {"essential": True},
        "aud": {"essential": True, "value": "account"},
        "iss": {"essential": True, "value": auth_settings.OIDC_ISSUER},
    }

    try:
        claims: JWTBearerToken = jwt.decode(
            # Only strip the leading scheme; never touch the token body.
            oidc_auth_string.removeprefix("Bearer "),
            jwk_keys,
            claims_options=claims_options,
            claims_cls=JWTBearerToken,
        )
        claims.validate()
    except (JoseError, ValueError) as exc:
        # JoseError covers authlib decode/signature/claim failures; ValueError
        # is raised by the key-set lookup when no key matches the token.
        # Report these as an authentication failure instead of a server error.
        raise HTTPException(status_code=401, detail="Invalid or expired token") from exc

    db_id = await add_user_to_db(claims)
    claims["db_id"] = db_id
    return claims
# Dependency alias: injecting this runs get_current_user and yields the
# validated token claims (including the "db_id" entry it adds).
claims_dependency = Annotated[JWTBearerToken, Depends(get_current_user)]
async def is_org_user(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)):
    """
    Check that the authenticated user belongs to the organisation ``org_id``.

    Args:
        claims: Validated token claims; the ``"db_id"`` entry identifies the user.
        db: Database session used for the lookups.
        org_id: Organisation id from the request path (must be > 0).

    Returns:
        The truthy result of the membership EXISTS query.

    Raises:
        HTTPException: 404 if the organisation does not exist or the claims
            carry no ``"db_id"``; 401 if the user is not in the organisation.
    """
    org_found = db.query(exists().where(Org.id == org_id)).scalar()
    if not org_found:
        raise HTTPException(status_code=404, detail="Organisation not found")

    user_db_id = claims.get("db_id")
    if user_db_id is None:
        raise HTTPException(status_code=404, detail="User not found in db")

    # Build the membership filter first, then wrap it in an EXISTS so only a
    # single boolean comes back from the database.
    membership = db.query(OrgUsers).filter(
        OrgUsers.org_id == org_id,
        OrgUsers.user_id == user_db_id,
    )
    is_member = db.query(membership.exists()).scalar()
    if is_member:
        return is_member
    raise HTTPException(status_code=401, detail="Not authorised")
# Dependency alias for routes restricted to organisation members. The injected
# value is what is_org_user returns (the membership check result), not the
# token claims, so the declared type is bool.
org_user_dependency = Annotated[bool, Depends(is_org_user)]
async def is_org_admin(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)):
    """
    Check that the authenticated user is an admin of the organisation ``org_id``.

    Args:
        claims: Validated token claims; the ``"db_id"`` entry identifies the user.
        db: Database session used for the lookups.
        org_id: Organisation id from the request path (must be > 0).

    Returns:
        The truthy result of the admin-membership EXISTS query.

    Raises:
        HTTPException: 404 if the organisation does not exist or the claims
            carry no ``"db_id"``; 401 if the user is not an admin of it.
    """
    org_found = db.query(exists().where(Org.id == org_id)).scalar()
    if not org_found:
        raise HTTPException(status_code=404, detail="Organisation not found")

    user_db_id = claims.get("db_id")
    if user_db_id is None:
        raise HTTPException(status_code=404, detail="User not found in db")

    # Same membership lookup as for ordinary users, narrowed to admin rows.
    admin_membership = db.query(OrgUsers).filter(
        OrgUsers.org_id == org_id,
        OrgUsers.user_id == user_db_id,
        OrgUsers.is_admin == True,  # noqa: E712 - SQL expression, not a Python comparison
    )
    is_admin = db.query(admin_membership.exists()).scalar()
    if is_admin:
        return is_admin
    raise HTTPException(status_code=401, detail="Not authorised")
# Dependency alias for routes restricted to organisation admins. The injected
# value is what is_org_admin returns (the admin check result), not the token
# claims, so the declared type is bool.
org_admin_dependency = Annotated[bool, Depends(is_org_admin)]
async def is_super_admin(claims: claims_dependency):
    """
    Check that the authenticated user is a super admin.

    Args:
        claims: Validated token claims; the ``"db_id"`` entry identifies the user.

    Returns:
        True when the user's id is in the super-admin allow-list.

    Raises:
        HTTPException: 404 if the claims carry no ``"db_id"``; 401 if the id
            is not in the allow-list.
    """
    user_db_id = claims.get("db_id")
    if user_db_id is None:
        raise HTTPException(status_code=404, detail="User not found in db")

    # The allow-list is hard-coded and currently empty, so every caller that
    # reaches this point is rejected.
    allowed_ids = []
    if user_db_id in allowed_ids:
        return True
    raise HTTPException(status_code=401, detail="Not authorised")
# Dependency alias for routes restricted to super admins. The injected value
# is what is_super_admin returns (True on success), not the token claims, so
# the declared type is bool.
super_admin_dependency = Annotated[bool, Depends(is_super_admin)]
# Middleware version of user auth (disabled: everything below is commented out)
# import json
# import logging
#
# from threading import Timer
# from urllib.request import urlopen
# from starlette.requests import HTTPConnection, Request
#
# from authlib.jose.rfc7517.jwk import JsonWebKey
# from authlib.jose.rfc7517.key_set import KeySet
# from authlib.oauth2 import OAuth2Error, ResourceProtector
# from authlib.oauth2.rfc6749 import MissingAuthorizationError
# from authlib.oauth2.rfc7523 import JWTBearerTokenValidator
# from authlib.oauth2.rfc7523.validator import JWTBearerToken
#
# from starlette.authentication import (
# AuthCredentials,
# AuthenticationBackend,
# AuthenticationError,
# SimpleUser,
# )
#
# logger = logging.getLogger(__name__)
#
#
# class RepeatTimer(Timer):
# def __init__(self, *args, **kwargs) -> None:
# super().__init__(*args, **kwargs)
# self.daemon = True
#
# def run(self):
# while not self.finished.wait(self.interval):
# self.function(*self.args, **self.kwargs)
#
#
# class BearerTokenValidator(JWTBearerTokenValidator):
# def __init__(self, issuer: str, audience: str):
# self._issuer = issuer
# self._jwks_uri: str | None = None
# super().__init__(public_key=self.fetch_key(), issuer=issuer)
# self.claims_options = {
# "exp": {"essential": True},
# "aud": {"essential": True, "value": audience},
# "iss": {"essential": True, "value": issuer},
# }
# self._timer = RepeatTimer(3600, self.refresh)
# self._timer.start()
#
# def refresh(self):
# try:
# self.public_key = self.fetch_key()
# except Exception as exc:
# logger.warning(f"Could not update jwks public key: {exc}")
#
# def fetch_key(self) -> KeySet:
# """Fetch the jwks_uri document and return the KeySet."""
# response = urlopen(self.jwks_uri)
# logger.debug(f"OK GET {self.jwks_uri}")
# return JsonWebKey.import_key_set(json.loads(response.read()))
#
# @property
# def jwks_uri(self) -> str:
# """The jwks_uri field of the openid-configuration document."""
# if self._jwks_uri is None:
# config_url = urlopen(f"{self._issuer}/.well-known/openid-configuration")
# config = json.loads(config_url.read())
# self._jwks_uri = config["jwks_uri"]
# return self._jwks_uri
#
#
# class BearerTokenAuthBackend(AuthenticationBackend):
# def __init__(self, issuer: str, audience: str) -> None:
# rp = ResourceProtector()
# validator = BearerTokenValidator(
# issuer=issuer,
# audience=audience,
# )
# rp.register_token_validator(validator)
# self.resource_protector = rp
#
# async def authenticate(self, conn: HTTPConnection):
# if "Authorization" not in conn.headers:
# return
# request = Request(conn.scope)
# try:
# token: JWTBearerToken = self.resource_protector.validate_request(
# scopes=["openid"],
# request=request,
# )
# except (MissingAuthorizationError, OAuth2Error) as error:
# raise AuthenticationError(error.description) from error
# scope: str = token.get_scope()
# scopes = scope.split()
# scopes.append("authenticated")
# return AuthCredentials(scopes=scopes), SimpleUser(username=token["email"])