Compare commits

..

3 commits

5 changed files with 26 additions and 12 deletions

View file

@ -19,6 +19,7 @@ from fastapi.security import OpenIdConnect
from src.auth.exceptions import UnauthorizedException from src.auth.exceptions import UnauthorizedException
from src.auth.config import auth_settings from src.auth.config import auth_settings
from src.user.service import add_user_to_db from src.user.service import add_user_to_db
from src.database import db_dependency
oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG) oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG)
@ -28,7 +29,7 @@ def get_dev_user():
return {"db_id": 1} return {"db_id": 1}
async def get_current_user(oidc_auth_string: oidc_dependency) -> dict[str, Any]: async def get_current_user(oidc_auth_string: oidc_dependency, db: db_dependency) -> dict[str, Any]:
config_url = urlopen(auth_settings.OIDC_CONFIG) config_url = urlopen(auth_settings.OIDC_CONFIG)
config = json.loads(config_url.read()) config = json.loads(config_url.read())
jwks_uri = config["jwks_uri"] jwks_uri = config["jwks_uri"]
@ -51,7 +52,7 @@ async def get_current_user(oidc_auth_string: oidc_dependency) -> dict[str, Any]:
claims_requests.validate(token.claims) claims_requests.validate(token.claims)
except ExpiredTokenError: except ExpiredTokenError:
raise UnauthorizedException(message="Token is expired") raise UnauthorizedException(message="Token is expired")
db_id = await add_user_to_db(token.claims) db_id = await add_user_to_db(db, token.claims)
token.claims["db_id"] = db_id token.claims["db_id"] = db_id

View file

@ -48,6 +48,9 @@ _QUOTED_DATABASE_PASSWORD = parse.quote_plus(str(_DATABASE_CREDENTIAL_PASSWORD))
SQLALCHEMY_DATABASE_URI = SecretStr(f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}") SQLALCHEMY_DATABASE_URI = SecretStr(f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}")
if settings.ENVIRONMENT == Environment.TESTING:
SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:")
app_configs: dict[str, Any] = {"title": "App API"} app_configs: dict[str, Any] = {"title": "App API"}
if settings.ENVIRONMENT.is_deployed: if settings.ENVIRONMENT.is_deployed:
app_configs["root_path"] = f"/v{settings.APP_VERSION}" app_configs["root_path"] = f"/v{settings.APP_VERSION}"

View file

@ -7,13 +7,19 @@ Exports:
""" """
from typing import Annotated from typing import Annotated
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import declarative_base, sessionmaker, Session from sqlalchemy.orm import DeclarativeBase, sessionmaker, Session
from fastapi import Depends from fastapi import Depends
from src.config import SQLALCHEMY_DATABASE_URI from src.constants import Environment
from src.config import SQLALCHEMY_DATABASE_URI, settings as global_settings
engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value()) if global_settings.ENVIRONMENT == Environment.TESTING:
connect_args = {"check_same_thread": False}
else:
connect_args = {}
engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value(), connect_args=connect_args)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@ -30,4 +36,5 @@ def get_db():
db_dependency = Annotated[Session, Depends(get_db)] db_dependency = Annotated[Session, Depends(get_db)]
Base = declarative_base() class Base(DeclarativeBase):
pass

View file

@ -8,6 +8,7 @@ from fastapi import FastAPI
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from starlette.middleware.cors import CORSMiddleware from starlette.middleware.cors import CORSMiddleware
from src.constants import Environment
from src.config import settings from src.config import settings
from src.api import api_router from src.api import api_router
@ -66,7 +67,7 @@ app.add_middleware(
allow_headers=settings.CORS_HEADERS, allow_headers=settings.CORS_HEADERS,
) )
if settings.ENVIRONMENT == "local" and settings.DISABLE_AUTH: if settings.DISABLE_AUTH and (settings.ENVIRONMENT == Environment.LOCAL or settings.ENVIRONMENT == Environment.TESTING):
app.dependency_overrides[get_current_user] = get_dev_user app.dependency_overrides[get_current_user] = get_dev_user

View file

@ -6,29 +6,31 @@ Exports:
""" """
from typing import Any from typing import Any
from src.database import get_db from sqlalchemy.orm import Session
from src.exceptions import UnprocessableContentException from src.exceptions import UnprocessableContentException
from src.user.schemas import OIDCUser from src.user.schemas import OIDCUser
from src.user.models import User from src.user.models import User
async def add_user_to_db(user_claims: dict[str, Any]) -> int: async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int:
try: try:
valid_user = OIDCUser(first_name=user_claims["given_name"], last_name=user_claims["family_name"], email=user_claims["email"], oidc_id=user_claims["sub"]) valid_user = OIDCUser(first_name=user_claims["given_name"], last_name=user_claims["family_name"], email=user_claims["email"], oidc_id=user_claims["sub"])
except Exception as e: except Exception as e:
print(e) print(e)
raise UnprocessableContentException("Invalid or missing OIDC data") raise UnprocessableContentException("Invalid or missing OIDC data")
db = next(get_db())
db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first() db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first()
if not db_user: if not db_user:
user_model = User(**valid_user.model_dump()) user_model = User(**valid_user.model_dump())
db.add(user_model) db.add(user_model)
user_id = user_model.id
db.commit() db.commit()
return user_model.id return user_id
else: else:
user_id = db_user.id
change = False change = False
if db_user.first_name != valid_user.first_name: if db_user.first_name != valid_user.first_name:
db_user.first_name = valid_user.first_name db_user.first_name = valid_user.first_name
@ -39,4 +41,4 @@ async def add_user_to_db(user_claims: dict[str, Any]) -> int:
if change: if change:
db.add(db_user) db.add(db_user)
db.commit() db.commit()
return db_user.id return user_id