diff --git a/src/auth/service.py b/src/auth/service.py index 40a696b..6f9cc97 100644 --- a/src/auth/service.py +++ b/src/auth/service.py @@ -20,7 +20,9 @@ from sqlalchemy.sql import exists from src.auth.config import auth_settings from src.user.service import add_user_to_db from src.organisation.models import OrgUsers, Organisation as Org +from src.user.models import User from src.database import db_dependency +from src.organisation.dependencies import org_model_dependency oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG) @@ -88,6 +90,20 @@ async def is_org_user(claims: claims_dependency, db: db_dependency, org_id: int org_user_dependency = Annotated[dict[str, Any], Depends(is_org_user)] +async def is_org_root(claims: claims_dependency, db: db_dependency, org_model: org_model_dependency, org_id: int = Path(gt=0)): + db_id = claims.get("db_id", None) + if db_id is None: + raise HTTPException(status_code=404, detail="User not found in db") + + if org_model.root_user_id == db_id: + return db.query(User).filter(User.id == db_id).first() + + raise HTTPException(status_code=401, detail="Not authorised") + + +root_user_dependency = Annotated[dict[str, Any], Depends(is_org_root)] + + async def is_super_admin(claims: claims_dependency): super_admin_ids = []