diff --git a/src/iam/service.py b/src/iam/service.py index e3c8740..4f23969 100644 --- a/src/iam/service.py +++ b/src/iam/service.py @@ -8,11 +8,15 @@ Exports: from typing import Annotated from datetime import datetime, timedelta, timezone from fastapi import Request, Depends +from sqlalchemy.orm import Session from src.database import db_dependency from src.exceptions import UnauthorizedException from src.utils import send_email, generate_jwt from src.iam.models import Group +from src.organisation.models import Organisation as Org +from src.user.models import User +from src.iam.models import Permission as Perm from src.service.models import Service from src.service.schemas import HasServiceName @@ -66,47 +70,43 @@ async def send_user_group_invitation( ) -async def create_default_user_group(db: db_dependency, org_model): - new_group = Group(name="Default Users", org_id=org_model.id) +async def create_group_and_assign_perms( + db: Session, org_model: Org, group_name: str, perm_list: list[int] +): + new_group = Group(name=group_name, org_id=org_model.id) db.add(new_group) db.flush() - # Grant default permissions here - db.flush() + + for permission in perm_list: + perm_model = db.get(Perm, permission) + + if perm_model is None: + continue + + new_group.permission_rel.append(perm_model) + db.flush() + return new_group -async def assign_default_user_group(db: db_dependency, org_model, user_model): - group_model = None - for group in org_model.group_rel: - if group.name == "Default Users": - group_model = group - break +async def assign_default_group( + db: db_dependency, + org_model: Org, + user_model: User, + group_name: str, + perm_list: list[int], +): + group_model = ( + db.query(Group) + .filter(Group.org_id == org_model.id) + .filter(Group.name == group_name) + .first() + ) if group_model is None: - group_model = await create_default_user_group(db=db, org_model=org_model) - - user_model.group_rel.append(group_model) - db.flush() - - -async def create_default_root_group(db: db_dependency, org_model): - new_group = Group(name="Root User", org_id=org_model.id) - db.add(new_group) - db.flush() - # Grant default permissions here - db.flush() - return new_group - - -async def assign_default_root_group(db: db_dependency, org_model, user_model): - group_model = None - for group in org_model.group_rel: - if group.name == "Root User": - group_model = group - break - - if group_model is None: - group_model = await create_default_root_group(db=db, org_model=org_model) + group_model = await create_group_and_assign_perms( + db=db, group_name=group_name, org_model=org_model, perm_list=perm_list + ) user_model.group_rel.append(group_model) db.flush() diff --git a/src/organisation/router.py b/src/organisation/router.py index 641265c..129e48c 100644 --- a/src/organisation/router.py +++ b/src/organisation/router.py @@ -34,9 +34,8 @@ from src.contact.models import Contact from src.contact.schemas import ContactAddress from src.contact.exceptions import ContactNotFoundException from src.database import db_dependency -from src.iam.service import assign_default_user_group, assign_default_root_group from src.organisation.schemas_questionnaires import QuestionnaireQuestionsVersion0 -from src.organisation.service import add_default_org_permissions +from src.organisation.service import assign_defaults from src.user.dependencies import ( user_model_body_dependency, user_model_claims_dependency, @@ -47,6 +46,7 @@ from src.auth.dependencies import ( org_model_root_claim_query_dependency, org_model_root_claim_body_dependency, ) +from src.iam.models import Group from src.organisation.dependencies import ( org_model_body_dependency, @@ -189,9 +189,10 @@ async def create_org( org_model.user_rel.append(user_model) org_model.root_user_rel = user_model - # Creates default user and default root IAM groups and assigns them - await assign_default_user_group(db, org_model, user_model) - await assign_default_root_group(db, org_model, user_model) + background_tasks.add_task( + assign_defaults, db, org_id=org_model.id, user_id=user_model.id + ) + for contact_type in [ "billing_contact_id", "security_contact_id", @@ -202,7 +203,6 @@ async def create_org( db.flush() org_model.__setattr__(contact_type, contact_model.id) response = OrgPostOrgResponse(**org_model.__dict__) - background_tasks.add_task(add_default_org_permissions, db, org_model.id) db.commit() return response @@ -357,7 +357,14 @@ async def add_user_to_org( raise ConflictException(message="User already a part of this organisation") org_model.user_rel.append(user_model) db.flush() - await assign_default_user_group(db=db, org_model=org_model, user_model=user_model) + group_model = ( + db.query(Group) + .filter(Group.org_id == org_model.id) + .filter(Group.name == "Default Users") + .first() + ) + if group_model is not None: + user_model.group_rel.append(group_model) response = { "organisation": org_model, "users": [{"id": user.id, "email": user.email} for user in org_model.user_rel], diff --git a/src/organisation/service.py b/src/organisation/service.py index 01b36e9..4fc2f03 100644 --- a/src/organisation/service.py +++ b/src/organisation/service.py @@ -3,27 +3,20 @@ Reusable business logic functions for the organisation module """ from sqlalchemy.orm import Session +from typing import cast + +from src.iam.service import assign_default_group from src.organisation.models import Organisation as Org from src.iam.models import Permission as Perm +from src.user.models import User async def add_default_org_permissions( db: Session, - org_id: int, + org_model: Org, + perm_list: list[int], ): - default_org_permissions = [ - 1, # test_service res_one read - 2, # test_service res_one create - 10, # tor-bridge-service collector read - 13, # tor-bridge-service samples read - ] - - org_model = db.get(Org, org_id) - if org_model is None: - print("Org not found while adding defaults") - return - - for permission in default_org_permissions: + for permission in perm_list: perm_model = db.get(Perm, permission) if perm_model is None: @@ -36,3 +29,43 @@ async def add_default_org_permissions( db.flush() db.commit() + + +async def assign_defaults( + db: Session, + org_id: int, + user_id: int, +): + default_org_permissions = [] + + default_user_permissions = [] + + org_model = db.get(Org, org_id) + if org_model is None: + print("Org not found while adding defaults") + return + + user_model = db.get(User, user_id) + if user_model is None: + print("User not found while adding defaults") + return + + org_model = cast(Org, org_model) + user_model = cast(User, user_model) + + await add_default_org_permissions(db, org_model, default_org_permissions) + await assign_default_group( + db=db, + org_model=org_model, + user_model=user_model, + group_name="Default Users", + perm_list=default_user_permissions, + ) + await assign_default_group( + db=db, + org_model=org_model, + user_model=user_model, + group_name="Root User", + perm_list=default_org_permissions, + ) + db.commit() diff --git a/src/user/router.py b/src/user/router.py index c4b5379..7aecc12 100644 --- a/src/user/router.py +++ b/src/user/router.py @@ -10,7 +10,7 @@ Endpoints: from fastapi import APIRouter, status, BackgroundTasks -from src.iam.service import assign_default_user_group +from src.iam.models import Group from src.organisation.exceptions import OrgNotFoundException from src.user.schemas import ( UserResponse, @@ -200,7 +200,14 @@ async def accept_invitation( org_model.user_rel.append(user_model) db.flush() - await assign_default_user_group(db=db, org_model=org_model, user_model=user_model) + group_model = ( + db.query(Group) + .filter(Group.org_id == org_model.id) + .filter(Group.name == "Default Users") + .first() + ) + if group_model is not None: + user_model.group_rel.append(group_model) response = { "organisation": org_model,