fix(user): simplify add_user

This commit is contained in:
Iain Learmonth 2026-06-22 12:21:02 +01:00
parent 11eeddb347
commit 78508ae59f
2 changed files with 10 additions and 17 deletions

View file

@ -22,7 +22,7 @@ from src.organisation.exceptions import AwaitingApprovalException
from src.organisation.models import Organisation as Org
from src.exceptions import UnauthorizedException, ForbiddenException
from src.auth.config import auth_settings
from src.user.service import add_user_to_db
from src.user.service import add_user
from src.database import db_dependency
@ -53,7 +53,7 @@ async def get_current_user(
claims_requests.validate(token.claims)
except ExpiredTokenError:
raise UnauthorizedException(message="Token is expired")
db_id = await add_user_to_db(db, token.claims)
db_id = await add_user(db, token.claims)
token.claims["db_id"] = db_id

View file

@ -17,7 +17,7 @@ from src.user.schemas import OIDCUser
from src.user.models import User
async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int:
async def add_user(db: Session, user_claims: dict[str, Any]) -> int:
try:
valid_user = OIDCUser(
first_name=user_claims["given_name"],
@ -26,7 +26,7 @@ async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int:
oidc_id=user_claims["sub"],
)
except Exception as e:
print(e)
logging.exception(e)
raise UnprocessableContentException("Invalid or missing OIDC data")
db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first()
@ -37,19 +37,12 @@ async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int:
user_id = user_model.id
db.commit()
return user_id
else:
user_id = db_user.id
change = False
if db_user.first_name != valid_user.first_name:
db_user.first_name = valid_user.first_name
change = True
if db_user.last_name != valid_user.last_name:
db_user.last_name = valid_user.last_name
change = True
if change:
db.add(db_user)
db.commit()
return user_id
user_id = db_user.id
db_user.first_name = valid_user.first_name
db_user.last_name = valid_user.last_name
db.commit()
return user_id
async def send_invitation(user_email: str, org_name: str, org_id: int):