Compare commits
10 commits
renovate/c
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 83a24a91f4 | |||
| 6871fcd75d | |||
| 34413b3fc5 | |||
| 26db93b769 | |||
| 7d84f33bfa | |||
| f54876eac6 | |||
| d89c926a38 | |||
| 6f4556a44b | |||
| 2b8296d622 | |||
| 34cb4414c9 |
13 changed files with 225 additions and 86 deletions
|
|
@ -18,3 +18,4 @@ httptools
|
|||
psycopg
|
||||
email-validator
|
||||
alembic
|
||||
joserfc
|
||||
|
|
|
|||
|
|
@ -5,20 +5,17 @@ Endpoints:
|
|||
- List: Description
|
||||
- Endpoints: Description
|
||||
"""
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.params import Path
|
||||
from sqlalchemy.sql import exists
|
||||
|
||||
from src.organisation.constants import ContactType
|
||||
from src.organisation.schemas import OrgContactGetResponse
|
||||
from src.organisation.models import Organisation as Org
|
||||
from src.contact.models import Contact
|
||||
from src.user.models import User
|
||||
from src.user.schemas import UserResponse, OIDCUser, OrgResponse
|
||||
|
||||
from src.organisation.models import OrgUsers, Organisation
|
||||
|
||||
from src.auth.service import claims_dependency, org_user_dependency, org_admin_dependency
|
||||
from src.auth.service import claims_dependency, org_or_super_admin_dependency
|
||||
from src.database import db_dependency
|
||||
|
||||
|
||||
|
|
@ -29,7 +26,7 @@ router = APIRouter(
|
|||
|
||||
|
||||
@router.get("/{org_id}/contact/{contact_type}", response_model=OrgContactGetResponse)
|
||||
async def get_contact(db: db_dependency, user: claims_dependency, is_org_admin: org_user_dependency, contact_type: ContactType, org_id: int = Path(gt=0)):
|
||||
async def get_contact(db: db_dependency, user: claims_dependency, is_admin: org_or_super_admin_dependency, contact_type: ContactType, org_id: Annotated[int, Path(gt=0)]):
|
||||
org_model = db.query(Org).filter(Org.id == org_id).first()
|
||||
if org_model is None:
|
||||
raise HTTPException(status_code=404, detail="Organisation not found")
|
||||
|
|
|
|||
|
|
@ -5,16 +5,16 @@ Exports:
|
|||
- claims_dependency
|
||||
"""
|
||||
import json
|
||||
import requests
|
||||
|
||||
from typing import Annotated
|
||||
from authlib.jose import jwt
|
||||
from typing import Annotated, Any
|
||||
from joserfc import jwt
|
||||
from joserfc.errors import ExpiredTokenError
|
||||
from joserfc.jwk import KeySet
|
||||
from urllib.request import urlopen
|
||||
|
||||
from fastapi import Depends, HTTPException, Path
|
||||
from fastapi.security import OpenIdConnect
|
||||
from authlib.jose.rfc7517.jwk import JsonWebKey
|
||||
from authlib.jose.rfc7517.key_set import KeySet
|
||||
from authlib.oauth2.rfc7523.validator import JWTBearerToken
|
||||
from sqlalchemy.sql import exists
|
||||
|
||||
from src.auth.config import auth_settings
|
||||
|
|
@ -27,12 +27,12 @@ oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG)
|
|||
oidc_dependency = Annotated[str, Depends(oidc)]
|
||||
|
||||
|
||||
async def get_current_user(oidc_auth_string: oidc_dependency) -> JWTBearerToken:
|
||||
async def get_current_user(oidc_auth_string: oidc_dependency) -> dict[str, Any]:
|
||||
config_url = urlopen(auth_settings.OIDC_CONFIG)
|
||||
config = json.loads(config_url.read())
|
||||
jwks_uri = config["jwks_uri"]
|
||||
key_response = urlopen(jwks_uri)
|
||||
jwk_keys: KeySet = JsonWebKey.import_key_set(json.loads(key_response.read()))
|
||||
key_response = requests.get(jwks_uri)
|
||||
jwk_keys = KeySet.import_key_set(key_response.json())
|
||||
|
||||
claims_options = {
|
||||
"exp": {"essential": True},
|
||||
|
|
@ -40,22 +40,26 @@ async def get_current_user(oidc_auth_string: oidc_dependency) -> JWTBearerToken:
|
|||
"iss": {"essential": True, "value": auth_settings.OIDC_ISSUER},
|
||||
}
|
||||
|
||||
claims: JWTBearerToken = jwt.decode(
|
||||
token = jwt.decode(
|
||||
oidc_auth_string.replace("Bearer ", ""),
|
||||
jwk_keys,
|
||||
claims_options=claims_options,
|
||||
claims_cls=JWTBearerToken,
|
||||
jwk_keys
|
||||
)
|
||||
|
||||
claims.validate()
|
||||
db_id = await add_user_to_db(claims)
|
||||
claims_requests = jwt.JWTClaimsRegistry(**claims_options)
|
||||
|
||||
claims["db_id"] = db_id
|
||||
try:
|
||||
claims_requests.validate(token.claims)
|
||||
except ExpiredTokenError as e:
|
||||
raise HTTPException(status_code=401, detail="Token expired")
|
||||
|
||||
return claims
|
||||
db_id = await add_user_to_db(token.claims)
|
||||
|
||||
token.claims["db_id"] = db_id
|
||||
|
||||
return token.claims
|
||||
|
||||
|
||||
claims_dependency = Annotated[JWTBearerToken, Depends(get_current_user)]
|
||||
claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)]
|
||||
|
||||
|
||||
async def is_org_user(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)):
|
||||
|
|
@ -81,7 +85,7 @@ async def is_org_user(claims: claims_dependency, db: db_dependency, org_id: int
|
|||
return org_user_exists
|
||||
|
||||
|
||||
org_user_dependency = Annotated[JWTBearerToken, Depends(is_org_user)]
|
||||
org_user_dependency = Annotated[dict[str, Any], Depends(is_org_user)]
|
||||
|
||||
|
||||
async def is_org_admin(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)):
|
||||
|
|
@ -108,7 +112,7 @@ async def is_org_admin(claims: claims_dependency, db: db_dependency, org_id: int
|
|||
return org_admin_exists
|
||||
|
||||
|
||||
org_admin_dependency = Annotated[JWTBearerToken, Depends(is_org_admin)]
|
||||
org_admin_dependency = Annotated[dict[str, Any], Depends(is_org_admin)]
|
||||
|
||||
|
||||
async def is_super_admin(claims: claims_dependency):
|
||||
|
|
@ -123,10 +127,22 @@ async def is_super_admin(claims: claims_dependency):
|
|||
return True
|
||||
|
||||
|
||||
super_admin_dependency = Annotated[JWTBearerToken, Depends(is_super_admin)]
|
||||
super_admin_dependency = Annotated[dict[str, Any], Depends(is_super_admin)]
|
||||
|
||||
|
||||
async def is_admin(claims: claims_dependency, db: db_dependency, org_id: int = Path(gt=0)):
|
||||
try:
|
||||
await is_super_admin(claims)
|
||||
return True
|
||||
except HTTPException as e:
|
||||
pass
|
||||
try:
|
||||
await is_org_admin(claims, db, org_id)
|
||||
return True
|
||||
except HTTPException as e:
|
||||
raise HTTPException(status_code=401, detail="Not authorised")
|
||||
|
||||
org_or_super_admin_dependency = Annotated[dict[str, Any], Depends(is_admin)]
|
||||
|
||||
# Middleware version of user auth
|
||||
# import json
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ Endpoints:
|
|||
- [patch]/{contact_id} - Updates the details of an existing contact
|
||||
- [delete]/{contact_id} - Deletes a contact by ID
|
||||
"""
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.params import Path
|
||||
|
||||
|
|
@ -55,7 +57,7 @@ async def create_contact(db: db_dependency, contact_request: ContactContactPostR
|
|||
|
||||
|
||||
@router.patch("/{contact_id}")
|
||||
async def update_contact(db: db_dependency, contact_request: ContactUpdateRequest, contact_id: int = Path(gt=0)):
|
||||
async def update_contact(db: db_dependency, contact_request: ContactUpdateRequest, contact_id: Annotated[int, Path(gt=0)]):
|
||||
contact_model = (db.query(Contact).filter(Contact.id == contact_id).first())
|
||||
if contact_model is None:
|
||||
raise HTTPException(status_code=404, detail="Contact not found")
|
||||
|
|
@ -72,7 +74,7 @@ async def update_contact(db: db_dependency, contact_request: ContactUpdateReques
|
|||
|
||||
|
||||
@router.delete("/{contact_id}")
|
||||
async def delete_contact(db: db_dependency, contact_id: int = Path(gt=0)):
|
||||
async def delete_contact(db: db_dependency, contact_id: Annotated[int, Path(gt=0)]):
|
||||
contact_model = (db.query(Contact).filter(Contact.id == contact_id).first())
|
||||
if contact_model is None:
|
||||
raise HTTPException(status_code=404, detail="Contact not found")
|
||||
|
|
@ -82,7 +84,7 @@ async def delete_contact(db: db_dependency, contact_id: int = Path(gt=0)):
|
|||
|
||||
|
||||
@router.get("/{contact_id}/orgs", response_model=list[ContactOrgGetResponse])
|
||||
async def get_contact_orgs(db: db_dependency, contact_id: int = Path(gt=0)):
|
||||
async def get_contact_orgs(db: db_dependency, contact_id: Annotated[int, Path(gt=0)]):
|
||||
contact_model = (db.query(Contact).filter(Contact.id == contact_id).first())
|
||||
if contact_model is None:
|
||||
raise HTTPException(status_code=404, detail="Contact not found")
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ Models:
|
|||
"""
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, EmailStr
|
||||
from pydantic import EmailStr
|
||||
|
||||
from src.organisation.constants import ContactType
|
||||
from src.schemas import CustomBaseModel
|
||||
|
|
|
|||
15
src/main.py
15
src/main.py
|
|
@ -26,15 +26,28 @@ if settings.ENVIRONMENT.is_deployed:
|
|||
pass
|
||||
|
||||
|
||||
tags_metadata = [
|
||||
{
|
||||
"name": "User",
|
||||
"description": "User related operations, includes getting information about the current user",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
swagger_ui_init_oauth={
|
||||
"clientId": auth_settings.CLIENT_ID,
|
||||
"usePkceWithAuthorizationCodeGrant": True,
|
||||
"scopes": "openid profile email",
|
||||
}
|
||||
},
|
||||
openapi_tags=tags_metadata,
|
||||
)
|
||||
|
||||
# Type inspection disabled for middleware injection.
|
||||
# Known bug in FastAPI type checking: https://github.com/astral-sh/ty/issues/1635
|
||||
# noinspection PyTypeChecker
|
||||
app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY.get_secret_value())
|
||||
# noinspection PyTypeChecker
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.CORS_ORIGINS,
|
||||
|
|
|
|||
|
|
@ -9,8 +9,6 @@ Models:
|
|||
from sqlalchemy import Column, Integer, String, Boolean, ForeignKey, JSON, false
|
||||
|
||||
from src.database import Base
|
||||
from src.contact.models import Contact
|
||||
from src.user.models import User
|
||||
|
||||
|
||||
class Organisation(Base):
|
||||
|
|
|
|||
|
|
@ -14,12 +14,13 @@ Endpoints:
|
|||
- [delete]/{org_id} - Deletes an organisation by ID
|
||||
- [get]/{org_id}/contact/{contact_type} - Retrieves the contact of a specific type (owner, billing, security) for an organisation
|
||||
"""
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.params import Path
|
||||
|
||||
from sqlalchemy.sql import exists
|
||||
|
||||
from src.auth.service import super_admin_dependency
|
||||
from src.database import db_dependency
|
||||
from src.contact.models import Contact
|
||||
|
||||
|
|
@ -36,7 +37,7 @@ router = APIRouter(
|
|||
|
||||
|
||||
@router.get("/id/{org_id}", response_model=OrgOrgGetResponse)
|
||||
async def get_org_by_id(db: db_dependency, org_id: int = Path(gt=0)):
|
||||
async def get_org_by_id(db: db_dependency, org_id: Annotated[int, Path(gt=0)]):
|
||||
org_model = (db.query(Org).filter(Org.id == org_id).first())
|
||||
if org_model is None:
|
||||
raise HTTPException(status_code=404, detail="Organisation not found")
|
||||
|
|
@ -63,7 +64,7 @@ async def create_org(db: db_dependency, org_request: OrgOrgPostRequest):
|
|||
|
||||
|
||||
@router.patch("/{org_id}/questionnaire")
|
||||
async def update_questionnaire(db: db_dependency, q_request: OrgQuestionnairePatchRequest, org_id: int = Path(gt=0)):
|
||||
async def update_questionnaire(db: db_dependency, q_request: OrgQuestionnairePatchRequest, org_id: Annotated[int, Path(gt=0)]):
|
||||
"""
|
||||
Route for updating questionnaire.
|
||||
The partial bool allows for submission of partially completed questionnaire and/or
|
||||
|
|
@ -84,7 +85,7 @@ async def update_questionnaire(db: db_dependency, q_request: OrgQuestionnairePat
|
|||
|
||||
|
||||
@router.patch("/{org_id}/status")
|
||||
async def update_status(db: db_dependency, status_request: OrgStatusPatchRequest, org_id: int = Path(gt=0)):
|
||||
async def update_status(db: db_dependency, status_request: OrgStatusPatchRequest, org_id: Annotated[int, Path(gt=0)]):
|
||||
org_model = db.query(Org).filter(Org.id == org_id).first()
|
||||
if org_model is None:
|
||||
raise HTTPException(status_code=404, detail="Organisation not found")
|
||||
|
|
@ -96,7 +97,7 @@ async def update_status(db: db_dependency, status_request: OrgStatusPatchRequest
|
|||
|
||||
|
||||
@router.patch("/{org_id}/contact")
|
||||
async def update_contact(db: db_dependency, contact_request: OrgContactPatchRequest, org_id: int = Path(gt=0)):
|
||||
async def update_contact(db: db_dependency, contact_request: OrgContactPatchRequest, org_id: Annotated[int, Path(gt=0)]):
|
||||
org_model = db.query(Org).filter(Org.id == org_id).first()
|
||||
if org_model is None:
|
||||
raise HTTPException(status_code=404, detail="Organisation not found")
|
||||
|
|
@ -116,7 +117,7 @@ async def update_contact(db: db_dependency, contact_request: OrgContactPatchRequ
|
|||
|
||||
|
||||
@router.get("/{org_id}/users", response_model=list[OrgUserGetResponse])
|
||||
async def get_users(db: db_dependency, org_id: int = Path(gt=0)):
|
||||
async def get_users(db: db_dependency, org_id: Annotated[int, Path(gt=0)]):
|
||||
org_exists = db.query(exists().where(Org.id == org_id)).scalar()
|
||||
if not org_exists:
|
||||
raise HTTPException(status_code=404, detail="Organisation not found")
|
||||
|
|
@ -127,7 +128,7 @@ async def get_users(db: db_dependency, org_id: int = Path(gt=0)):
|
|||
|
||||
|
||||
@router.get("/{org_id}/users/admins", response_model=list[OrgUserGetResponse])
|
||||
async def get_admin_users(db: db_dependency, org_id: int = Path(gt=0)):
|
||||
async def get_admin_users(db: db_dependency, org_id: Annotated[int, Path(gt=0)]):
|
||||
org_exists = db.query(exists().where(Org.id == org_id)).scalar()
|
||||
if not org_exists:
|
||||
raise HTTPException(status_code=404, detail="Organisation not found")
|
||||
|
|
@ -138,7 +139,11 @@ async def get_admin_users(db: db_dependency, org_id: int = Path(gt=0)):
|
|||
|
||||
|
||||
@router.post("/{org_id}/users")
|
||||
async def add_user_to_org(db: db_dependency, user_request: OrgUserPostRequest, org_id: int = Path(gt=0)):
|
||||
async def add_user_to_org(db: db_dependency, user_request: OrgUserPostRequest, org_id: Annotated[int, Path(gt=0)]):
|
||||
org_model = (db.query(Org).filter(Org.id == org_id).first())
|
||||
if org_model is None:
|
||||
raise HTTPException(status_code=404, detail="Organisation not found")
|
||||
|
||||
org_user_model = OrgUsers(**user_request.model_dump(), org_id=org_id)
|
||||
|
||||
db.add(org_user_model)
|
||||
|
|
@ -146,11 +151,14 @@ async def add_user_to_org(db: db_dependency, user_request: OrgUserPostRequest, o
|
|||
|
||||
|
||||
@router.patch("/{org_id}/users")
|
||||
async def update_user_details(db: db_dependency, user_request: OrgUserPostRequest, org_id: int = Path(gt=0)):
|
||||
async def update_user_details(db: db_dependency, user_request: OrgUserPostRequest, org_id: Annotated[int, Path(gt=0)]):
|
||||
"""
|
||||
Currently used only to update user admin status for organisation.
|
||||
"""
|
||||
# TODO: Check if org exists
|
||||
org_model = (db.query(Org).filter(Org.id == org_id).first())
|
||||
if org_model is None:
|
||||
raise HTTPException(status_code=404, detail="Organisation not found")
|
||||
|
||||
org_user_model = db.query(OrgUsers).filter(OrgUsers.org_id == org_id).filter(OrgUsers.user_id == user_request.user_id).first()
|
||||
|
||||
if org_user_model is None:
|
||||
|
|
@ -164,7 +172,7 @@ async def update_user_details(db: db_dependency, user_request: OrgUserPostReques
|
|||
|
||||
|
||||
@router.delete("/{org_id}")
|
||||
async def delete_organisation_by_id(db: db_dependency, org_id: int = Path(gt=0)):
|
||||
async def delete_organisation_by_id(db: db_dependency, org_id: Annotated[int, Path(gt=0)]):
|
||||
org_model = (db.query(Org).filter(Org.id == org_id).first())
|
||||
if org_model is None:
|
||||
raise HTTPException(status_code=404, detail="Organisation not found")
|
||||
|
|
@ -173,10 +181,11 @@ async def delete_organisation_by_id(db: db_dependency, org_id: int = Path(gt=0))
|
|||
|
||||
|
||||
@router.get("/{org_id}/contact/{contact_type}", response_model=OrgContactGetResponse)
|
||||
async def get_contact(db: db_dependency, contact_type: ContactType, org_id: int = Path(gt=0)):
|
||||
async def get_contact(db: db_dependency, contact_type: ContactType, org_id: Annotated[int, Path(gt=0)]):
|
||||
org_model = db.query(Org).filter(Org.id == org_id).first()
|
||||
if org_model is None:
|
||||
raise HTTPException(status_code=404, detail="Organisation not found")
|
||||
|
||||
match contact_type:
|
||||
case "billing":
|
||||
contact_id = org_model.billing_contact_id
|
||||
|
|
|
|||
|
|
@ -6,22 +6,27 @@ Models:
|
|||
- Models: Description
|
||||
"""
|
||||
from typing import Optional
|
||||
from pydantic import Json
|
||||
|
||||
from src.schemas import CustomBaseModel
|
||||
from src.organisation.constants import Status, ContactType
|
||||
|
||||
|
||||
class OrgQuestionnaire(CustomBaseModel):
|
||||
question_one: str
|
||||
question_two: str
|
||||
question_three: str
|
||||
|
||||
|
||||
class OrgOrgPostRequest(CustomBaseModel):
|
||||
name: str
|
||||
intake_questionnaire: Optional[Json] = None
|
||||
intake_questionnaire: Optional[OrgQuestionnaire] = None
|
||||
|
||||
billing_contact_id: Optional[int] = None
|
||||
security_contact_id: Optional[int] = None
|
||||
owner_contact_id: Optional[int] = None
|
||||
|
||||
class OrgQuestionnairePatchRequest(CustomBaseModel):
|
||||
intake_questionnaire: Json
|
||||
intake_questionnaire: OrgQuestionnaire
|
||||
partial: bool
|
||||
|
||||
class OrgStatusPatchRequest(CustomBaseModel):
|
||||
|
|
|
|||
|
|
@ -4,4 +4,16 @@ Module specific exceptions for user module
|
|||
Exceptions:
|
||||
- List: Description
|
||||
- Exceptions: Description
|
||||
"""
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
class UserNotFoundException(HTTPException):
|
||||
def __init__(self, user_id: Optional[int] = None) -> None:
|
||||
detail = "User not found" if user_id is None else f"User with ID '{user_id}' was not found."
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=detail,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,21 +2,25 @@
|
|||
Router endpoints for user module
|
||||
|
||||
Endpoints:
|
||||
- [get]/me/claims - Retrieves user's OIDC claims
|
||||
- [get]/me/db - Retrieves the user data from the db that corresponds to the current OIDC user
|
||||
- [get]/me/orgs - Retrieves all organisations associated with the current user
|
||||
- [get]/me/orgs/admin - Retrieves only admin organisations for the current user
|
||||
- [get]/self/claims - Retrieves user's OIDC claims
|
||||
- [get]/self/db - Retrieves the user data from the db that corresponds to the current OIDC user
|
||||
- [get]/self/orgs - Retrieves all organisations associated with the current user
|
||||
- [get]/self/orgs/admin - Retrieves only admin organisations for the current user
|
||||
- [get]/{user_id} - Retrieves a specific user by their ID
|
||||
- [get]/{user_id}/orgs - Retrieves all organisations associated with a specific user
|
||||
- [get]/{user_id}/orgs/admin - Retrieves only admin organisations for a specific user
|
||||
- [delete]/{user_id} - Deletes a user from the db by their db ID
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi.params import Path
|
||||
from sqlalchemy.sql import exists
|
||||
from starlette import status
|
||||
|
||||
from src.user.models import User
|
||||
from src.user.schemas import UserResponse, OIDCUser, OrgResponse
|
||||
from src.user.schemas import UserResponse, OrgResponse, OIDCClaims
|
||||
from src.user.exceptions import UserNotFoundException
|
||||
|
||||
from src.organisation.models import OrgUsers, Organisation
|
||||
|
||||
|
|
@ -25,36 +29,54 @@ from src.database import db_dependency
|
|||
|
||||
router = APIRouter(
|
||||
prefix="/user",
|
||||
tags=["user"],
|
||||
tags=["User"],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me/claims")
|
||||
@router.get("/self/claims", response_model=OIDCClaims, status_code=status.HTTP_200_OK, responses={
|
||||
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
||||
})
|
||||
async def current_user_claims(user: claims_dependency):
|
||||
"""
|
||||
Returns the full OIDC claims associated with the currently logged-in user.
|
||||
"""
|
||||
user["allowed_origins"] = user.get("allowed-origins", [])
|
||||
return user
|
||||
|
||||
|
||||
@router.get("/me/db", response_model=OIDCUser)
|
||||
@router.get("/self/db", response_model=UserResponse, status_code=status.HTTP_200_OK, responses={
|
||||
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
|
||||
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
||||
})
|
||||
async def current_user(user: claims_dependency, db: db_dependency):
|
||||
db_id = user.get("db_id", None)
|
||||
if db_id is None:
|
||||
raise HTTPException(status_code=404, detail="User not found in db")
|
||||
"""
|
||||
Returns the database details associated with the currently logged-in user.
|
||||
"""
|
||||
user_id = user.get("db_id", None)
|
||||
if user_id is None:
|
||||
raise UserNotFoundException()
|
||||
|
||||
user_model = (db.query(User).filter(User.id == db_id).first())
|
||||
user_model = (db.query(User).filter(User.id == user_id).first())
|
||||
if user_model is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
raise UserNotFoundException(user_id=user_id)
|
||||
|
||||
return user_model
|
||||
|
||||
|
||||
@router.get("/me/orgs", response_model=list[OrgResponse])
|
||||
@router.get("/self/orgs", response_model=list[OrgResponse], status_code=status.HTTP_200_OK, responses={
|
||||
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
|
||||
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
||||
})
|
||||
async def get_current_organisations(db: db_dependency, user: claims_dependency):
|
||||
"""
|
||||
Returns all organisations associated with the currently logged-in user.
|
||||
"""
|
||||
user_id = user.get("db_id", None)
|
||||
if user_id is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
raise UserNotFoundException()
|
||||
user_exists = db.query(exists().where(User.id == user_id)).scalar()
|
||||
if not user_exists:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
raise UserNotFoundException(user_id=user_id)
|
||||
|
||||
org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name)
|
||||
.join(OrgUsers, Organisation.id == OrgUsers.org_id)
|
||||
|
|
@ -65,14 +87,20 @@ async def get_current_organisations(db: db_dependency, user: claims_dependency):
|
|||
return org_user_models
|
||||
|
||||
|
||||
@router.get("/me/orgs/admin", response_model=list[OrgResponse])
|
||||
@router.get("/self/orgs/admin", response_model=list[OrgResponse], status_code=status.HTTP_200_OK, responses={
|
||||
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
|
||||
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
||||
})
|
||||
async def get_current_admin_organisations(db: db_dependency, user: claims_dependency):
|
||||
"""
|
||||
Returns the organisations for which the currently logged-in user is an admin.
|
||||
"""
|
||||
user_id = user.get("db_id", None)
|
||||
if user_id is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
raise UserNotFoundException()
|
||||
user_exists = db.query(exists().where(User.id == user_id)).scalar()
|
||||
if not user_exists:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
raise UserNotFoundException(user_id=user_id)
|
||||
|
||||
org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name)
|
||||
.join(OrgUsers, Organisation.id == OrgUsers.org_id)
|
||||
|
|
@ -84,20 +112,32 @@ async def get_current_admin_organisations(db: db_dependency, user: claims_depend
|
|||
return org_user_models
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=UserResponse)
|
||||
async def get_user_by_id(user_id: int, db: db_dependency):
|
||||
@router.get("/{user_id}", response_model=UserResponse, status_code=status.HTTP_200_OK, responses={
|
||||
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
|
||||
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
||||
})
|
||||
async def get_user_by_id(db: db_dependency, user_id: Annotated[int, Path(gt=0,description="User database ID")]):
|
||||
"""
|
||||
Returns the database details associated with the provided user ID.
|
||||
"""
|
||||
user_model = (db.query(User).filter(User.id == user_id).first())
|
||||
if user_model is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
raise UserNotFoundException(user_id=user_id)
|
||||
|
||||
return user_model
|
||||
|
||||
|
||||
@router.get("/{user_id}/orgs", response_model=list[OrgResponse])
|
||||
async def get_organisations(db: db_dependency, user_id: int = Path(gt=0)):
|
||||
@router.get("/{user_id}/orgs", response_model=list[OrgResponse], status_code=status.HTTP_200_OK, responses={
|
||||
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
|
||||
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
||||
})
|
||||
async def get_organisations(db: db_dependency, user_id: Annotated[int, Path(gt=0,description="User database ID")]):
|
||||
"""
|
||||
Returns all organisations associated with the provided user ID.
|
||||
"""
|
||||
user_exists = db.query(exists().where(User.id == user_id)).scalar()
|
||||
if not user_exists:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
raise UserNotFoundException(user_id=user_id)
|
||||
|
||||
org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name)
|
||||
.join(OrgUsers, Organisation.id == OrgUsers.org_id)
|
||||
|
|
@ -108,11 +148,17 @@ async def get_organisations(db: db_dependency, user_id: int = Path(gt=0)):
|
|||
return org_user_models
|
||||
|
||||
|
||||
@router.get("/{user_id}/orgs/admin", response_model=list[OrgResponse])
|
||||
async def get_admin_organisations(db: db_dependency, user_id: int = Path(gt=0)):
|
||||
@router.get("/{user_id}/orgs/admin", response_model=list[OrgResponse], status_code=status.HTTP_200_OK, responses={
|
||||
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
|
||||
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
||||
})
|
||||
async def get_admin_organisations(db: db_dependency, user_id: Annotated[int, Path(gt=0,description="User database ID")]):
|
||||
"""
|
||||
Returns the organisations for which the user with the provided user ID is an admin.
|
||||
"""
|
||||
user_exists = db.query(exists().where(User.id == user_id)).scalar()
|
||||
if not user_exists:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
raise UserNotFoundException(user_id=user_id)
|
||||
|
||||
org_user_models = (db.query(OrgUsers.org_id, OrgUsers.is_admin, Organisation.name)
|
||||
.join(OrgUsers, Organisation.id == OrgUsers.org_id)
|
||||
|
|
@ -124,10 +170,16 @@ async def get_admin_organisations(db: db_dependency, user_id: int = Path(gt=0)):
|
|||
return org_user_models
|
||||
|
||||
|
||||
@router.delete("/{user_id}")
|
||||
async def delete_user_by_id(user_id: int, db: db_dependency):
|
||||
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT, responses={
|
||||
status.HTTP_204_NO_CONTENT: {"description": "User deleted"},
|
||||
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
|
||||
})
|
||||
async def delete_user_by_id(user_id: Annotated[int, Path(gt=0)], db: db_dependency):
|
||||
"""
|
||||
Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login.
|
||||
"""
|
||||
user_model = (db.query(User).filter(User.id == user_id).first())
|
||||
if user_model is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
raise UserNotFoundException(user_id=user_id)
|
||||
db.delete(user_model)
|
||||
db.commit()
|
||||
|
|
|
|||
|
|
@ -6,7 +6,31 @@ Models:
|
|||
- Models: Description
|
||||
"""
|
||||
from src.schemas import CustomBaseModel
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class OIDCClaims(CustomBaseModel):
|
||||
exp: int
|
||||
iat: int
|
||||
auth_time: int
|
||||
jti: str
|
||||
iss: str
|
||||
aud: str
|
||||
sub: str
|
||||
typ: str
|
||||
azp: str
|
||||
sid: str
|
||||
acr: str
|
||||
allowed_origins: list[str]
|
||||
realm_access: dict[str, list[str]]
|
||||
resource_access: dict[str, dict[str, list[str]]]
|
||||
scope: str
|
||||
email_verified: bool
|
||||
name: str
|
||||
preferred_username: str
|
||||
given_name: str
|
||||
family_name: str
|
||||
email: str
|
||||
db_id: int
|
||||
|
||||
|
||||
class OIDCUser(CustomBaseModel):
|
||||
|
|
|
|||
|
|
@ -7,7 +7,8 @@ Functions:
|
|||
Exports:
|
||||
- add_user_to_db
|
||||
"""
|
||||
from authlib.jose import JWTClaims
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from src.user.schemas import OIDCUser
|
||||
|
|
@ -15,7 +16,7 @@ from src.user.models import User
|
|||
from src.database import get_db
|
||||
|
||||
|
||||
async def add_user_to_db(user_claims: JWTClaims) -> int:
|
||||
async def add_user_to_db(user_claims: dict[str, Any]) -> int:
|
||||
try:
|
||||
valid_user = OIDCUser(first_name=user_claims["given_name"], last_name=user_claims["family_name"], email=user_claims["email"], oidc_id=user_claims["sub"])
|
||||
except Exception as e:
|
||||
|
|
@ -31,5 +32,14 @@ async def add_user_to_db(user_claims: JWTClaims) -> int:
|
|||
db.commit()
|
||||
return user_model.id
|
||||
else:
|
||||
# Verify details still match and update accordingly.
|
||||
change = False
|
||||
if db_user.first_name != valid_user.first_name:
|
||||
db_user.first_name = valid_user.first_name
|
||||
change = True
|
||||
if db_user.last_name != valid_user.last_name:
|
||||
db_user.last_name = valid_user.last_name
|
||||
change = True
|
||||
if change:
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
return db_user.id
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue