diff --git a/src/auth/dependencies.py b/src/auth/dependencies.py index 0001fb8..7fbb96d 100644 --- a/src/auth/dependencies.py +++ b/src/auth/dependencies.py @@ -8,4 +8,41 @@ Classes: Functions: - List: Description - Functions: Description -""" \ No newline at end of file +""" +from typing import Annotated, Any +from fastapi import Depends, HTTPException + +from src.user.dependencies import user_model_claims_dependency + +from src.organisation.dependencies import org_model_query_dependency + + +async def org_query_user_claims(org_model: org_model_query_dependency, user_model: user_model_claims_dependency): + if user_model in org_model.user_rel: + return True + + raise HTTPException(status_code=401, detail="Not authorised") + + +org_query_user_claims_dependency = Annotated[dict[str, Any], Depends(org_query_user_claims)] + + +async def org_query_root_claims(user_model: user_model_claims_dependency, org_model: org_model_query_dependency): + if org_model.root_user_id == user_model.id: + return True + + raise HTTPException(status_code=401, detail="Not authorised") + + +org_query_root_claims_dependency = Annotated[dict[str, Any], Depends(org_query_root_claims)] + + +async def is_super_admin(user_model: user_model_claims_dependency): + super_admin_emails = [] + if user_model.email not in super_admin_emails: + raise HTTPException(status_code=401, detail="Not authorised") + + return True + + +super_admin_dependency = Annotated[dict[str, Any], Depends(is_super_admin)] diff --git a/src/auth/service.py b/src/auth/service.py index 8f27902..60bd9c3 100644 --- a/src/auth/service.py +++ b/src/auth/service.py @@ -13,16 +13,11 @@ from joserfc.errors import ExpiredTokenError from joserfc.jwk import KeySet from urllib.request import urlopen -from fastapi import Depends, HTTPException, Path +from fastapi import Depends, HTTPException from fastapi.security import OpenIdConnect -from sqlalchemy.sql import exists from src.auth.config import auth_settings from src.user.service import add_user_to_db -from src.organisation.models import OrgUsers, Organisation as Org -from src.user.models import User -from src.database import db_dependency -from src.organisation.dependencies import org_model_query_dependency oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG) @@ -65,58 +60,3 @@ async def get_current_user(oidc_auth_string: oidc_dependency) -> dict[str, Any]: claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)] - - -async def is_org_user(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)): - org_exists = db.query(exists().where(Org.id == org_id)).scalar() - if not org_exists: - raise HTTPException(status_code=404, detail="Organisation not found") - - db_id = claims.get("db_id", None) - if db_id is None: - raise HTTPException(status_code=404, detail="User not found in db") - - exists_query = (db.query(OrgUsers) - .filter(OrgUsers.org_id == org_id, - OrgUsers.user_id == db_id - ).exists() - ) - - org_user_exists = db.query(exists_query).scalar() - - if not org_user_exists: - raise HTTPException(status_code=401, detail="Not authorised") - - return org_user_exists - - -org_user_dependency = Annotated[dict[str, Any], Depends(is_org_user)] - - -async def is_org_root_query(claims: claims_dependency, db: db_dependency, org_model: org_model_query_dependency): - db_id = claims.get("db_id", None) - if db_id is None: - raise HTTPException(status_code=404, detail="User not found in db") - - if org_model.root_user_id == db_id: - return db.query(User).filter(User.id == db_id).first() - - raise HTTPException(status_code=401, detail="Not authorised") - - -root_user_query_dependency = Annotated[dict[str, Any], Depends(is_org_root_query)] - - -async def is_super_admin(claims: claims_dependency): - super_admin_ids = [] - - db_id = claims.get("db_id", None) - if db_id is None: - raise HTTPException(status_code=404, detail="User not found in db") - if db_id not in super_admin_ids: - raise HTTPException(status_code=401, detail="Not authorised") - - return True - - -super_admin_dependency = Annotated[dict[str, Any], Depends(is_super_admin)] diff --git a/src/organisation/router.py b/src/organisation/router.py index 6cf8d87..413e233 100644 --- a/src/organisation/router.py +++ b/src/organisation/router.py @@ -22,7 +22,7 @@ from src.database import db_dependency from src.contact.models import Contact from src.user.models import User from src.user.exceptions import UserNotFoundException -from src.auth.service import root_user_query_dependency, claims_dependency +from src.auth.service import claims_dependency from src.organisation.dependencies import org_model_query_dependency, org_model_body_dependency from src.organisation.constants import ContactType