From a9e539ef747c1775681f31dcabc4b8a1598473af Mon Sep 17 00:00:00 2001 From: irl Date: Mon, 22 Jun 2026 12:21:02 +0100 Subject: [PATCH] fix(user): simplify add_user --- src/auth/service.py | 4 ++-- src/user/service.py | 26 ++++++++------------------ 2 files changed, 10 insertions(+), 20 deletions(-) diff --git a/src/auth/service.py b/src/auth/service.py index 1b90b8c..2675dc2 100644 --- a/src/auth/service.py +++ b/src/auth/service.py @@ -22,7 +22,7 @@ from src.organisation.exceptions import AwaitingApprovalException from src.organisation.models import Organisation as Org from src.exceptions import UnauthorizedException, ForbiddenException from src.auth.config import auth_settings -from src.user.service import add_user_to_db +from src.user.service import add_user from src.database import db_dependency @@ -53,7 +53,7 @@ async def get_current_user( claims_requests.validate(token.claims) except ExpiredTokenError: raise UnauthorizedException(message="Token is expired") - db_id = await add_user_to_db(db, token.claims) + db_id = await add_user(db, token.claims) token.claims["db_id"] = db_id diff --git a/src/user/service.py b/src/user/service.py index ff1da8b..b70056f 100644 --- a/src/user/service.py +++ b/src/user/service.py @@ -1,8 +1,5 @@ """ Module specific business logic for user module - -Exports: - - add_user_to_db: Creates a User record from OIDC claims, or updates user details """ from typing import Any @@ -17,7 +14,7 @@ from src.user.schemas import OIDCUser from src.user.models import User -async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int: +async def add_user(db: Session, user_claims: dict[str, Any]) -> int: try: valid_user = OIDCUser( first_name=user_claims["given_name"], @@ -26,7 +23,7 @@ async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int: oidc_id=user_claims["sub"], ) except Exception as e: - print(e) + logging.exception(e) raise UnprocessableContentException("Invalid or missing OIDC data") db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first() @@ -37,19 +34,12 @@ async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int: user_id = user_model.id db.commit() return user_id - else: - user_id = db_user.id - change = False - if db_user.first_name != valid_user.first_name: - db_user.first_name = valid_user.first_name - change = True - if db_user.last_name != valid_user.last_name: - db_user.last_name = valid_user.last_name - change = True - if change: - db.add(db_user) - db.commit() - return user_id + + user_id = db_user.id + db_user.first_name = valid_user.first_name + db_user.last_name = valid_user.last_name + db.commit() + return user_id async def send_invitation(user_email: str, org_name: str, org_id: int):