1
0
Fork 0
forked from sr2/cloud-api
cloud-api/src/auth/service.py
luxferre 092e12a892 feat: org status check moved
Accessing endpoints as super admin no longer requires the org to be approved.
2026-06-12 14:50:32 +01:00

89 lines
2.5 KiB
Python

"""
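Business logic specific to the auth module.

Exports:
- claims_dependency: dict[str, Any] of the validated OIDC token claims, plus
  the user's database ID under the "db_id" key
- org_status_check: raises if an organisation's status does not permit the
  requested endpoint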
"""
import asyncio
import json
from typing import Annotated, Any
from urllib.request import urlopen

import requests
from fastapi import Depends, Request
from fastapi.security import OpenIdConnect
from joserfc import jwt
from joserfc.errors import ExpiredTokenError, JoseError
from joserfc.jwk import KeySet

from src.organisation.constants import Status as OrgStatus
from src.organisation.exceptions import AwaitingApprovalException
from src.organisation.models import Organisation as Org
from src.exceptions import UnauthorizedException, ForbiddenException
from src.auth.config import auth_settings
from src.user.service import add_user_to_db
from src.database import db_dependency
# OpenID Connect security scheme, pointed at the configured discovery document
# URL (auth_settings.OIDC_CONFIG).
oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG)
# Dependency alias that injects the string produced by the `oidc` scheme.
# Presumably this is the raw Authorization header value; get_current_user
# strips a leading "Bearer " from it (NOTE(review): confirm).
oidc_dependency = Annotated[str, Depends(oidc)]
async def get_dev_user():
    """Return a hard-coded development user record.

    No token, identity provider or database is consulted. The result is a
    fresh dict on every call, holding ``db_id`` 1 and a fixed email address.
    """
    dev_user = dict(db_id=1, email="chris@sr2.uk")
    return dev_user
# Timeout (seconds) for each HTTP round-trip to the OIDC provider, so a slow
# or unreachable identity provider cannot hang request handling indefinitely.
_OIDC_HTTP_TIMEOUT = 10


def _fetch_jwks() -> KeySet:
    """Fetch the OIDC discovery document and return the provider's key set.

    This is synchronous, blocking network I/O. Async code should call it
    through ``asyncio.to_thread`` so the event loop is not stalled.

    Raises:
        urllib.error.URLError: if the discovery document cannot be fetched.
        requests.RequestException: if the JWKS request fails or returns an
            HTTP error status.
    """
    # Context manager closes the response (the original leaked it).
    with urlopen(auth_settings.OIDC_CONFIG, timeout=_OIDC_HTTP_TIMEOUT) as config_response:
        config = json.load(config_response)
    key_response = requests.get(config["jwks_uri"], timeout=_OIDC_HTTP_TIMEOUT)
    key_response.raise_for_status()
    return KeySet.import_key_set(key_response.json())


async def get_current_user(
    oidc_auth_string: oidc_dependency, db: db_dependency
) -> dict[str, Any]:
    """Validate the caller's OIDC token and return its claims.

    The token's signature is checked against the provider's JWKS. The ``exp``
    and ``iss`` claims are then required, and ``iss`` must equal
    ``auth_settings.OIDC_ISSUER``. Finally, the user is recorded via
    ``add_user_to_db`` and the resulting ID is stored under ``"db_id"``.

    NOTE(review): the audience (``aud``) claim is not checked. Confirm that
    tokens issued to other clients of the same provider should be accepted.
    NOTE(review): the discovery document and JWKS are re-fetched on every
    call. Consider a TTL cache that also tolerates key rotation.

    Args:
        oidc_auth_string: value supplied by the OIDC security scheme,
            optionally prefixed with ``"Bearer "``.
        db: database handle passed through to ``add_user_to_db``.

    Returns:
        The decoded token claims, with ``"db_id"`` added.

    Raises:
        UnauthorizedException: if the token cannot be decoded or verified,
            is expired, or fails claim validation.
    """
    # Strip only the leading scheme; a plain .replace() would remove every
    # occurrence of the substring.
    raw_token = oidc_auth_string.removeprefix("Bearer ")

    # Run the blocking HTTP calls in a worker thread so other requests keep
    # being served while we wait on the identity provider.
    jwk_keys = await asyncio.to_thread(_fetch_jwks)

    try:
        token = jwt.decode(raw_token, jwk_keys)
    except (JoseError, ValueError) as err:
        # Bad signature or malformed token: a client error (401), not a 500.
        raise UnauthorizedException(message="Token is invalid") from err

    claims_registry = jwt.JWTClaimsRegistry(
        exp={"essential": True},
        iss={"essential": True, "value": auth_settings.OIDC_ISSUER},
    )
    try:
        claims_registry.validate(token.claims)
    except ExpiredTokenError as err:
        raise UnauthorizedException(message="Token is expired") from err
    except JoseError as err:
        # Missing exp/iss or a wrong issuer previously escaped as a 500.
        raise UnauthorizedException(message="Token claims are invalid") from err

    db_id = await add_user_to_db(db, token.claims)
    token.claims["db_id"] = db_id
    return token.claims
# Dependency alias that injects the claims dict returned by get_current_user.
# The dict holds the validated token claims plus the user's database ID
# under "db_id".
claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)]
_API_ROOT = "/api/v1"

# (HTTP method, path) pairs that an organisation may call before it has been
# approved. Built once at import time; a frozenset gives O(1) membership.
_PRE_APPROVAL_ENDPOINTS = frozenset(
    {
        ("PATCH", f"{_API_ROOT}/org/status"),
        ("PATCH", f"{_API_ROOT}/org/questionnaire"),
        ("GET", f"{_API_ROOT}/org"),
        ("GET", f"{_API_ROOT}/org/contact"),
        ("PATCH", f"{_API_ROOT}/org/contact"),
        ("DELETE", f"{_API_ROOT}/org/self"),
    }
)


async def org_status_check(org_model: Org, request: Request):
    """Reject the request if the organisation's status does not permit it.

    A blocked organisation is refused outright. An organisation that is not
    yet approved may only reach the endpoints in ``_PRE_APPROVAL_ENDPOINTS``.

    Args:
        org_model: the organisation on whose behalf the request is made.
        request: the incoming request; its method and URL path are matched
            against the pre-approval allow-list.

    Raises:
        ForbiddenException: if the organisation's status reports
            ``is_blocked``.
        AwaitingApprovalException: if the organisation is not approved and
            the endpoint is not on the pre-approval allow-list.
    """
    # Normalise once and use the enum for every comparison. Comparing the raw
    # model value against OrgStatus.APPROVED only works if OrgStatus is a
    # str-valued enum.
    org_status = OrgStatus(org_model.status)
    if org_status.is_blocked:
        raise ForbiddenException("This organisation cannot perform this action.")

    if org_status == OrgStatus.APPROVED:
        return

    if (request.method, request.url.path) not in _PRE_APPROVAL_ENDPOINTS:
        raise AwaitingApprovalException(org_model.id)