fix: use dependency instead of db.next

This commit is contained in:
Chris Milne 2026-05-29 14:15:50 +01:00
parent 2d60b4fcc5
commit 1a81be210a
2 changed files with 10 additions and 7 deletions

View file

@ -19,6 +19,7 @@ from fastapi.security import OpenIdConnect
from src.auth.exceptions import UnauthorizedException from src.auth.exceptions import UnauthorizedException
from src.auth.config import auth_settings from src.auth.config import auth_settings
from src.user.service import add_user_to_db from src.user.service import add_user_to_db
from src.database import db_dependency
oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG) oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG)
@ -28,7 +29,7 @@ def get_dev_user():
return {"db_id": 1} return {"db_id": 1}
async def get_current_user(oidc_auth_string: oidc_dependency) -> dict[str, Any]: async def get_current_user(oidc_auth_string: oidc_dependency, db: db_dependency) -> dict[str, Any]:
config_url = urlopen(auth_settings.OIDC_CONFIG) config_url = urlopen(auth_settings.OIDC_CONFIG)
config = json.loads(config_url.read()) config = json.loads(config_url.read())
jwks_uri = config["jwks_uri"] jwks_uri = config["jwks_uri"]
@ -51,7 +52,7 @@ async def get_current_user(oidc_auth_string: oidc_dependency) -> dict[str, Any]:
claims_requests.validate(token.claims) claims_requests.validate(token.claims)
except ExpiredTokenError: except ExpiredTokenError:
raise UnauthorizedException(message="Token is expired") raise UnauthorizedException(message="Token is expired")
db_id = await add_user_to_db(token.claims) db_id = await add_user_to_db(db, token.claims)
token.claims["db_id"] = db_id token.claims["db_id"] = db_id

View file

@ -6,29 +6,31 @@ Exports:
""" """
from typing import Any from typing import Any
from src.database import get_db from sqlalchemy.orm import Session
from src.exceptions import UnprocessableContentException from src.exceptions import UnprocessableContentException
from src.user.schemas import OIDCUser from src.user.schemas import OIDCUser
from src.user.models import User from src.user.models import User
async def add_user_to_db(user_claims: dict[str, Any]) -> int: async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int:
try: try:
valid_user = OIDCUser(first_name=user_claims["given_name"], last_name=user_claims["family_name"], email=user_claims["email"], oidc_id=user_claims["sub"]) valid_user = OIDCUser(first_name=user_claims["given_name"], last_name=user_claims["family_name"], email=user_claims["email"], oidc_id=user_claims["sub"])
except Exception as e: except Exception as e:
print(e) print(e)
raise UnprocessableContentException("Invalid or missing OIDC data") raise UnprocessableContentException("Invalid or missing OIDC data")
db = next(get_db())
db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first() db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first()
if not db_user: if not db_user:
user_model = User(**valid_user.model_dump()) user_model = User(**valid_user.model_dump())
db.add(user_model) db.add(user_model)
user_id = user_model.id
db.commit() db.commit()
return user_model.id return user_id
else: else:
user_id = db_user.id
change = False change = False
if db_user.first_name != valid_user.first_name: if db_user.first_name != valid_user.first_name:
db_user.first_name = valid_user.first_name db_user.first_name = valid_user.first_name
@ -39,4 +41,4 @@ async def add_user_to_db(user_claims: dict[str, Any]) -> int:
if change: if change:
db.add(db_user) db.add(db_user)
db.commit() db.commit()
return db_user.id return user_id