1
0
Fork 0
forked from sr2/cloud-api

fix(user): simplify add_user

This commit is contained in:
Iain Learmonth 2026-06-22 12:21:02 +01:00
parent 11eeddb347
commit a9e539ef74
2 changed files with 10 additions and 20 deletions

View file

@ -22,7 +22,7 @@ from src.organisation.exceptions import AwaitingApprovalException
from src.organisation.models import Organisation as Org
from src.exceptions import UnauthorizedException, ForbiddenException
from src.auth.config import auth_settings
from src.user.service import add_user_to_db
from src.user.service import add_user
from src.database import db_dependency
@ -53,7 +53,7 @@ async def get_current_user(
claims_requests.validate(token.claims)
except ExpiredTokenError:
raise UnauthorizedException(message="Token is expired")
db_id = await add_user_to_db(db, token.claims)
db_id = await add_user(db, token.claims)
token.claims["db_id"] = db_id

View file

@ -1,8 +1,5 @@
"""
Module specific business logic for user module
Exports:
- add_user_to_db: Creates a User record from OIDC claims, or updates user details
"""
from typing import Any
@ -17,7 +14,7 @@ from src.user.schemas import OIDCUser
from src.user.models import User
async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int:
async def add_user(db: Session, user_claims: dict[str, Any]) -> int:
try:
valid_user = OIDCUser(
first_name=user_claims["given_name"],
@ -26,7 +23,7 @@ async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int:
oidc_id=user_claims["sub"],
)
except Exception as e:
print(e)
logging.exception(e)
raise UnprocessableContentException("Invalid or missing OIDC data")
db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first()
@ -37,17 +34,10 @@ async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int:
user_id = user_model.id
db.commit()
return user_id
else:
user_id = db_user.id
change = False
if db_user.first_name != valid_user.first_name:
db_user.first_name = valid_user.first_name
change = True
if db_user.last_name != valid_user.last_name:
db_user.last_name = valid_user.last_name
change = True
if change:
db.add(db_user)
db.commit()
return user_id