From 78508ae59fd94b682322618211060e51fd162380 Mon Sep 17 00:00:00 2001 From: irl Date: Mon, 22 Jun 2026 12:21:02 +0100 Subject: [PATCH] fix(user): simplify add_user --- src/auth/service.py | 4 ++-- src/user/service.py | 23 ++++++++--------------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/src/auth/service.py b/src/auth/service.py index 1b90b8c..2675dc2 100644 --- a/src/auth/service.py +++ b/src/auth/service.py @@ -22,7 +22,7 @@ from src.organisation.exceptions import AwaitingApprovalException from src.organisation.models import Organisation as Org from src.exceptions import UnauthorizedException, ForbiddenException from src.auth.config import auth_settings -from src.user.service import add_user_to_db +from src.user.service import add_user from src.database import db_dependency @@ -53,7 +53,7 @@ async def get_current_user( claims_requests.validate(token.claims) except ExpiredTokenError: raise UnauthorizedException(message="Token is expired") - db_id = await add_user_to_db(db, token.claims) + db_id = await add_user(db, token.claims) token.claims["db_id"] = db_id diff --git a/src/user/service.py b/src/user/service.py index ff1da8b..2174471 100644 --- a/src/user/service.py +++ b/src/user/service.py @@ -17,7 +17,7 @@ from src.user.schemas import OIDCUser from src.user.models import User -async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int: +async def add_user(db: Session, user_claims: dict[str, Any]) -> int: try: valid_user = OIDCUser( first_name=user_claims["given_name"], @@ -26,7 +26,7 @@ async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int: oidc_id=user_claims["sub"], ) except Exception as e: - print(e) + logging.exception(e) raise UnprocessableContentException("Invalid or missing OIDC data") db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first() @@ -37,19 +37,12 @@ async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int: user_id = user_model.id db.commit() return user_id - else: - user_id = db_user.id - change = False - if db_user.first_name != valid_user.first_name: - db_user.first_name = valid_user.first_name - change = True - if db_user.last_name != valid_user.last_name: - db_user.last_name = valid_user.last_name - change = True - if change: - db.add(db_user) - db.commit() - return user_id + + user_id = db_user.id + db_user.first_name = valid_user.first_name + db_user.last_name = valid_user.last_name + db.commit() + return user_id async def send_invitation(user_email: str, org_name: str, org_id: int):