minor: ruff format
Tabs -> spaces
This commit is contained in:
parent
b2921b73b8
commit
fab228bf8f
56 changed files with 3629 additions and 3630 deletions
|
|
@ -22,5 +22,5 @@ from fastapi import APIRouter
|
|||
|
||||
|
||||
router = APIRouter(
|
||||
tags=[""],
|
||||
tags=[""],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,6 @@ Exports:
|
|||
from fastapi import APIRouter
|
||||
|
||||
router = APIRouter(
|
||||
tags=["admin"],
|
||||
prefix="/admin",
|
||||
tags=["admin"],
|
||||
prefix="/admin",
|
||||
)
|
||||
|
|
|
|||
14
src/api.py
14
src/api.py
|
|
@ -26,15 +26,15 @@ api_router.include_router(iam_router)
|
|||
|
||||
|
||||
class HealthCheckResponse(CustomBaseModel):
|
||||
status: str
|
||||
status: str
|
||||
|
||||
|
||||
@api_router.get(
|
||||
path="/healthcheck",
|
||||
status_code=status.HTTP_200_OK,
|
||||
response_model=HealthCheckResponse,
|
||||
include_in_schema=False,
|
||||
path="/healthcheck",
|
||||
status_code=status.HTTP_200_OK,
|
||||
response_model=HealthCheckResponse,
|
||||
include_in_schema=False,
|
||||
)
|
||||
def healthcheck():
|
||||
"""Simple health check endpoint."""
|
||||
return {"status": "ok"}
|
||||
"""Simple health check endpoint."""
|
||||
return {"status": "ok"}
|
||||
|
|
|
|||
|
|
@ -9,9 +9,9 @@ from src.config import CustomBaseSettings
|
|||
|
||||
|
||||
class AuthConfig(CustomBaseSettings):
|
||||
OIDC_CONFIG: str = ""
|
||||
OIDC_ISSUER: str = ""
|
||||
CLIENT_ID: str = ""
|
||||
OIDC_CONFIG: str = ""
|
||||
OIDC_ISSUER: str = ""
|
||||
CLIENT_ID: str = ""
|
||||
|
||||
|
||||
auth_settings = AuthConfig()
|
||||
|
|
|
|||
|
|
@ -16,92 +16,92 @@ from src.exceptions import ForbiddenException
|
|||
from src.user.dependencies import user_model_claims_dependency
|
||||
from src.user.models import User
|
||||
from src.organisation.dependencies import (
|
||||
org_model_query_dependency,
|
||||
org_model_body_dependency,
|
||||
org_model_query_dependency,
|
||||
org_model_body_dependency,
|
||||
)
|
||||
from src.organisation.models import Organisation as Org
|
||||
|
||||
|
||||
async def org_query_user_claims(
|
||||
org_model: org_model_query_dependency, user_model: user_model_claims_dependency
|
||||
org_model: org_model_query_dependency, user_model: user_model_claims_dependency
|
||||
):
|
||||
if user_model in org_model.user_rel:
|
||||
return True
|
||||
if user_model in org_model.user_rel:
|
||||
return True
|
||||
|
||||
raise ForbiddenException()
|
||||
raise ForbiddenException()
|
||||
|
||||
|
||||
org_query_user_claims_dependency = Annotated[bool, Depends(org_query_user_claims)]
|
||||
|
||||
|
||||
def get_super_admin_list():
|
||||
return []
|
||||
return []
|
||||
|
||||
|
||||
def empty_su_list():
|
||||
return []
|
||||
return []
|
||||
|
||||
|
||||
def testing_su_list():
|
||||
return ["admin@test.com"]
|
||||
return ["admin@test.com"]
|
||||
|
||||
|
||||
su_list_dependency = Annotated[list[str | None], Depends(get_super_admin_list)]
|
||||
|
||||
|
||||
async def user_model_super_admin(
|
||||
user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency
|
||||
user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency
|
||||
):
|
||||
if user_model.email in super_admin_emails:
|
||||
return user_model
|
||||
if user_model.email in super_admin_emails:
|
||||
return user_model
|
||||
|
||||
raise ForbiddenException(message="Must be super admin")
|
||||
raise ForbiddenException(message="Must be super admin")
|
||||
|
||||
|
||||
super_admin_dependency = Annotated[User, Depends(user_model_super_admin)]
|
||||
|
||||
|
||||
async def org_query_root_claims(
|
||||
user_model: user_model_claims_dependency,
|
||||
org_model: org_model_query_dependency,
|
||||
su_emails: su_list_dependency,
|
||||
request: Request,
|
||||
user_model: user_model_claims_dependency,
|
||||
org_model: org_model_query_dependency,
|
||||
su_emails: su_list_dependency,
|
||||
request: Request,
|
||||
):
|
||||
try:
|
||||
if await user_model_super_admin(user_model, su_emails):
|
||||
return org_model
|
||||
except ForbiddenException:
|
||||
pass
|
||||
try:
|
||||
if await user_model_super_admin(user_model, su_emails):
|
||||
return org_model
|
||||
except ForbiddenException:
|
||||
pass
|
||||
|
||||
await org_status_check(org_model, request)
|
||||
await org_status_check(org_model, request)
|
||||
|
||||
if org_model.root_user_id == user_model.id:
|
||||
return org_model
|
||||
if org_model.root_user_id == user_model.id:
|
||||
return org_model
|
||||
|
||||
raise ForbiddenException(message="Must be the org's root user")
|
||||
raise ForbiddenException(message="Must be the org's root user")
|
||||
|
||||
|
||||
org_model_root_claim_query_dependency = Annotated[Org, Depends(org_query_root_claims)]
|
||||
|
||||
|
||||
async def org_body_root_claims(
|
||||
user_model: user_model_claims_dependency,
|
||||
org_model: org_model_body_dependency,
|
||||
su_emails: su_list_dependency,
|
||||
request: Request,
|
||||
user_model: user_model_claims_dependency,
|
||||
org_model: org_model_body_dependency,
|
||||
su_emails: su_list_dependency,
|
||||
request: Request,
|
||||
):
|
||||
try:
|
||||
if await user_model_super_admin(user_model, su_emails):
|
||||
return org_model
|
||||
except ForbiddenException:
|
||||
pass
|
||||
try:
|
||||
if await user_model_super_admin(user_model, su_emails):
|
||||
return org_model
|
||||
except ForbiddenException:
|
||||
pass
|
||||
|
||||
await org_status_check(org_model, request)
|
||||
await org_status_check(org_model, request)
|
||||
|
||||
if org_model.root_user_id == user_model.id:
|
||||
return org_model
|
||||
if org_model.root_user_id == user_model.id:
|
||||
return org_model
|
||||
|
||||
raise ForbiddenException(message="Must be the org's root user")
|
||||
raise ForbiddenException(message="Must be the org's root user")
|
||||
|
||||
|
||||
org_model_root_claim_body_dependency = Annotated[Org, Depends(org_body_root_claims)]
|
||||
|
|
|
|||
|
|
@ -8,5 +8,5 @@ Exports:
|
|||
from fastapi import APIRouter
|
||||
|
||||
router = APIRouter(
|
||||
tags=["auth"],
|
||||
tags=["auth"],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -31,56 +31,56 @@ oidc_dependency = Annotated[str, Depends(oidc)]
|
|||
|
||||
|
||||
async def get_dev_user():
|
||||
return {"db_id": 1, "email": "chris@sr2.uk"}
|
||||
return {"db_id": 1, "email": "chris@sr2.uk"}
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
oidc_auth_string: oidc_dependency, db: DbSession
|
||||
oidc_auth_string: oidc_dependency, db: DbSession
|
||||
) -> dict[str, Any]:
|
||||
config_url = urlopen(auth_settings.OIDC_CONFIG)
|
||||
config = json.loads(config_url.read())
|
||||
jwks_uri = config["jwks_uri"]
|
||||
key_response = requests.get(jwks_uri)
|
||||
jwk_keys = KeySet.import_key_set(key_response.json())
|
||||
config_url = urlopen(auth_settings.OIDC_CONFIG)
|
||||
config = json.loads(config_url.read())
|
||||
jwks_uri = config["jwks_uri"]
|
||||
key_response = requests.get(jwks_uri)
|
||||
jwk_keys = KeySet.import_key_set(key_response.json())
|
||||
|
||||
token = jwt.decode(oidc_auth_string.replace("Bearer ", ""), jwk_keys)
|
||||
token = jwt.decode(oidc_auth_string.replace("Bearer ", ""), jwk_keys)
|
||||
|
||||
claims_requests = jwt.JWTClaimsRegistry(
|
||||
exp={"essential": True}, iss={"essential": True, "value": auth_settings.OIDC_ISSUER}
|
||||
)
|
||||
claims_requests = jwt.JWTClaimsRegistry(
|
||||
exp={"essential": True}, iss={"essential": True, "value": auth_settings.OIDC_ISSUER}
|
||||
)
|
||||
|
||||
try:
|
||||
claims_requests.validate(token.claims)
|
||||
except ExpiredTokenError:
|
||||
raise UnauthorizedException(message="Token is expired")
|
||||
db_id = await add_user(db, token.claims)
|
||||
try:
|
||||
claims_requests.validate(token.claims)
|
||||
except ExpiredTokenError:
|
||||
raise UnauthorizedException(message="Token is expired")
|
||||
db_id = await add_user(db, token.claims)
|
||||
|
||||
token.claims["db_id"] = db_id
|
||||
token.claims["db_id"] = db_id
|
||||
|
||||
return token.claims
|
||||
return token.claims
|
||||
|
||||
|
||||
claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)]
|
||||
|
||||
|
||||
async def org_status_check(org_model: Org, request: Request):
|
||||
org_status = OrgStatus(org_model.status)
|
||||
if org_status.is_blocked:
|
||||
raise ForbiddenException("This organisation cannot perform this action.")
|
||||
org_status = OrgStatus(org_model.status)
|
||||
if org_status.is_blocked:
|
||||
raise ForbiddenException("This organisation cannot perform this action.")
|
||||
|
||||
root = "/api/v1"
|
||||
root = "/api/v1"
|
||||
|
||||
pre_approval_endpoints = [
|
||||
f"PATCH{root}/org/status",
|
||||
f"PATCH{root}/org/questionnaire",
|
||||
f"GET{root}/org",
|
||||
f"GET{root}/org/contact",
|
||||
f"PATCH{root}/org/contact",
|
||||
f"DELETE{root}/org/self",
|
||||
]
|
||||
current_request = f"{request.method}{request.url.path}"
|
||||
if (
|
||||
current_request not in pre_approval_endpoints
|
||||
and org_model.status != OrgStatus.APPROVED
|
||||
):
|
||||
raise AwaitingApprovalException(org_model.id)
|
||||
pre_approval_endpoints = [
|
||||
f"PATCH{root}/org/status",
|
||||
f"PATCH{root}/org/questionnaire",
|
||||
f"GET{root}/org",
|
||||
f"GET{root}/org/contact",
|
||||
f"PATCH{root}/org/contact",
|
||||
f"DELETE{root}/org/self",
|
||||
]
|
||||
current_request = f"{request.method}{request.url.path}"
|
||||
if (
|
||||
current_request not in pre_approval_endpoints
|
||||
and org_model.status != OrgStatus.APPROVED
|
||||
):
|
||||
raise AwaitingApprovalException(org_model.id)
|
||||
|
|
|
|||
|
|
@ -16,31 +16,31 @@ from src.constants import Environment
|
|||
|
||||
|
||||
class CustomBaseSettings(BaseSettings):
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env", env_file_encoding="utf-8", extra="ignore"
|
||||
)
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env", env_file_encoding="utf-8", extra="ignore"
|
||||
)
|
||||
|
||||
|
||||
class Config(CustomBaseSettings):
|
||||
APP_VERSION: str = "0.1"
|
||||
ENVIRONMENT: Environment = Environment.PRODUCTION
|
||||
SECRET_KEY: SecretStr = SecretStr("")
|
||||
DISABLE_AUTH: bool = False
|
||||
APP_VERSION: str = "0.1"
|
||||
ENVIRONMENT: Environment = Environment.PRODUCTION
|
||||
SECRET_KEY: SecretStr = SecretStr("")
|
||||
DISABLE_AUTH: bool = False
|
||||
|
||||
CORS_ORIGINS: list[str] = ["*"]
|
||||
CORS_ORIGINS_REGEX: str | None = None
|
||||
CORS_HEADERS: list[str] = ["*"]
|
||||
CORS_ORIGINS: list[str] = ["*"]
|
||||
CORS_ORIGINS_REGEX: str | None = None
|
||||
CORS_HEADERS: list[str] = ["*"]
|
||||
|
||||
DATABASE_NAME: str = "fastapi-exp"
|
||||
DATABASE_PORT: str = "5432"
|
||||
DATABASE_HOSTNAME: str = "localhost"
|
||||
DATABASE_CREDENTIALS: SecretStr = SecretStr(":")
|
||||
DATABASE_NAME: str = "fastapi-exp"
|
||||
DATABASE_PORT: str = "5432"
|
||||
DATABASE_HOSTNAME: str = "localhost"
|
||||
DATABASE_CREDENTIALS: SecretStr = SecretStr(":")
|
||||
|
||||
DATABASE_POOL_SIZE: int = 16
|
||||
DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes
|
||||
DATABASE_POOL_PRE_PING: bool = True
|
||||
DATABASE_POOL_SIZE: int = 16
|
||||
DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes
|
||||
DATABASE_POOL_PRE_PING: bool = True
|
||||
|
||||
LETTERMINT_API_TOKEN: SecretStr = SecretStr("")
|
||||
LETTERMINT_API_TOKEN: SecretStr = SecretStr("")
|
||||
|
||||
|
||||
settings = Config()
|
||||
|
|
@ -51,20 +51,20 @@ DATABASE_HOSTNAME = settings.DATABASE_HOSTNAME
|
|||
DATABASE_CREDENTIALS = settings.DATABASE_CREDENTIALS.get_secret_value()
|
||||
# this will support special chars for credentials
|
||||
_DATABASE_CREDENTIAL_USER, _DATABASE_CREDENTIAL_PASSWORD = str(DATABASE_CREDENTIALS).split(
|
||||
":"
|
||||
":"
|
||||
)
|
||||
_QUOTED_DATABASE_PASSWORD = parse.quote_plus(str(_DATABASE_CREDENTIAL_PASSWORD))
|
||||
|
||||
SQLALCHEMY_DATABASE_URI = SecretStr(
|
||||
f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}"
|
||||
f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}"
|
||||
)
|
||||
|
||||
if settings.ENVIRONMENT == Environment.TESTING:
|
||||
SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:")
|
||||
SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:")
|
||||
|
||||
app_configs: dict[str, Any] = {"title": "App API"}
|
||||
if settings.ENVIRONMENT.is_deployed:
|
||||
app_configs["root_path"] = f"/v{settings.APP_VERSION}"
|
||||
app_configs["root_path"] = f"/v{settings.APP_VERSION}"
|
||||
|
||||
if not settings.ENVIRONMENT.is_debug:
|
||||
app_configs["openapi_url"] = None # hide docs
|
||||
app_configs["openapi_url"] = None # hide docs
|
||||
|
|
|
|||
|
|
@ -9,29 +9,29 @@ from enum import StrEnum, auto
|
|||
|
||||
|
||||
class Environment(StrEnum):
|
||||
"""
|
||||
Enumeration of environments.
|
||||
"""
|
||||
Enumeration of environments.
|
||||
|
||||
Attributes:
|
||||
LOCAL (str): Application is running locally
|
||||
TESTING (str): Application is running in testing mode
|
||||
STAGING (str): Application is running in staging mode (ie not testing)
|
||||
PRODUCTION (str): Application is running in production mode
|
||||
"""
|
||||
Attributes:
|
||||
LOCAL (str): Application is running locally
|
||||
TESTING (str): Application is running in testing mode
|
||||
STAGING (str): Application is running in staging mode (ie not testing)
|
||||
PRODUCTION (str): Application is running in production mode
|
||||
"""
|
||||
|
||||
LOCAL = auto()
|
||||
TESTING = auto()
|
||||
STAGING = auto()
|
||||
PRODUCTION = auto()
|
||||
LOCAL = auto()
|
||||
TESTING = auto()
|
||||
STAGING = auto()
|
||||
PRODUCTION = auto()
|
||||
|
||||
@property
|
||||
def is_debug(self):
|
||||
return self in (self.LOCAL, self.STAGING, self.TESTING)
|
||||
@property
|
||||
def is_debug(self):
|
||||
return self in (self.LOCAL, self.STAGING, self.TESTING)
|
||||
|
||||
@property
|
||||
def is_testing(self):
|
||||
return self == self.TESTING
|
||||
@property
|
||||
def is_testing(self):
|
||||
return self == self.TESTING
|
||||
|
||||
@property
|
||||
def is_deployed(self) -> bool:
|
||||
return self in (self.STAGING, self.PRODUCTION)
|
||||
@property
|
||||
def is_deployed(self) -> bool:
|
||||
return self in (self.STAGING, self.PRODUCTION)
|
||||
|
|
|
|||
|
|
@ -11,13 +11,13 @@ from fastapi import HTTPException, status
|
|||
|
||||
|
||||
class ContactNotFoundException(HTTPException):
|
||||
def __init__(self, contact_id: Optional[int] = None) -> None:
|
||||
detail = (
|
||||
"Contact not found"
|
||||
if contact_id is None
|
||||
else f"Contact with ID '{contact_id}' was not found."
|
||||
)
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=detail,
|
||||
)
|
||||
def __init__(self, contact_id: Optional[int] = None) -> None:
|
||||
detail = (
|
||||
"Contact not found"
|
||||
if contact_id is None
|
||||
else f"Contact with ID '{contact_id}' was not found."
|
||||
)
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=detail,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,22 +15,22 @@ from src.models import CustomBase
|
|||
|
||||
|
||||
class Contact(CustomBase, IdMixin):
|
||||
__tablename__ = "contact"
|
||||
__tablename__ = "contact"
|
||||
|
||||
email: Mapped[str] = mapped_column(default=None, nullable=True)
|
||||
first_name: Mapped[str] = mapped_column(default=None, nullable=True)
|
||||
last_name: Mapped[str] = mapped_column(default=None, nullable=True)
|
||||
phonenumber: Mapped[str] = mapped_column(default=None, nullable=True)
|
||||
vat_number: Mapped[str | None] = mapped_column(default=None, nullable=True)
|
||||
email: Mapped[str] = mapped_column(default=None, nullable=True)
|
||||
first_name: Mapped[str] = mapped_column(default=None, nullable=True)
|
||||
last_name: Mapped[str] = mapped_column(default=None, nullable=True)
|
||||
phonenumber: Mapped[str] = mapped_column(default=None, nullable=True)
|
||||
vat_number: Mapped[str | None] = mapped_column(default=None, nullable=True)
|
||||
|
||||
street_address: Mapped[str] = mapped_column(default=None, nullable=True)
|
||||
street_address_line_2: Mapped[str] = mapped_column(default=None, nullable=True)
|
||||
post_office_box_number: Mapped[str | None] = mapped_column(default=None, nullable=True)
|
||||
locality: Mapped[str] = mapped_column(default=None, nullable=True) # Ie City
|
||||
country_code: Mapped[str] = mapped_column(default=None, nullable=True) # Eg GB
|
||||
address_region: Mapped[str | None] = mapped_column(default=None, nullable=True)
|
||||
postal_code: Mapped[str] = mapped_column(default=None, nullable=True)
|
||||
street_address: Mapped[str] = mapped_column(default=None, nullable=True)
|
||||
street_address_line_2: Mapped[str] = mapped_column(default=None, nullable=True)
|
||||
post_office_box_number: Mapped[str | None] = mapped_column(default=None, nullable=True)
|
||||
locality: Mapped[str] = mapped_column(default=None, nullable=True) # Ie City
|
||||
country_code: Mapped[str] = mapped_column(default=None, nullable=True) # Eg GB
|
||||
address_region: Mapped[str | None] = mapped_column(default=None, nullable=True)
|
||||
postal_code: Mapped[str] = mapped_column(default=None, nullable=True)
|
||||
|
||||
org_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
org_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,6 @@ from fastapi import APIRouter
|
|||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/contact",
|
||||
tags=["contact"],
|
||||
prefix="/contact",
|
||||
tags=["contact"],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,22 +14,22 @@ from src.schemas import CustomBaseModel
|
|||
|
||||
|
||||
class ContactAddress(CustomBaseModel):
|
||||
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
||||
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
||||
|
||||
post_office_box_number: Optional[str] = None
|
||||
street_address: Optional[str] = None
|
||||
street_address_line_2: Optional[str] = None
|
||||
locality: Optional[str] = None
|
||||
address_region: Optional[str] = None
|
||||
country_code: Optional[str] = None
|
||||
postal_code: Optional[str] = None
|
||||
post_office_box_number: Optional[str] = None
|
||||
street_address: Optional[str] = None
|
||||
street_address_line_2: Optional[str] = None
|
||||
locality: Optional[str] = None
|
||||
address_region: Optional[str] = None
|
||||
country_code: Optional[str] = None
|
||||
postal_code: Optional[str] = None
|
||||
|
||||
|
||||
class ContactModel(CustomBaseModel):
|
||||
email: Optional[EmailStr] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
phonenumber: Optional[str] = None
|
||||
vat_number: Optional[str] = None
|
||||
email: Optional[EmailStr] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
phonenumber: Optional[str] = None
|
||||
vat_number: Optional[str] = None
|
||||
|
||||
address: ContactAddress
|
||||
address: ContactAddress
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Database connection and session utilities
|
||||
"""
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Annotated, Generator
|
||||
from sqlalchemy import create_engine, StaticPool, Connection
|
||||
|
|
@ -29,6 +30,7 @@ else:
|
|||
|
||||
sm = sessionmaker(autocommit=False, expire_on_commit=False, bind=engine)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_connection() -> Generator[Connection, None, None]:
|
||||
with engine.connect() as connection:
|
||||
|
|
@ -38,12 +40,15 @@ def get_db_connection() -> Generator[Connection, None, None]:
|
|||
connection.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def _get_db_connection() -> Generator[Connection, None]:
|
||||
with get_db_connection() as connection:
|
||||
yield connection
|
||||
|
||||
|
||||
DbConnection = Annotated[Connection, Depends(_get_db_connection)]
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_session() -> Generator[Session, None, None]:
|
||||
session = sm()
|
||||
|
|
@ -60,4 +65,5 @@ def _get_db_session() -> Generator[Session, None]:
|
|||
with get_db_session() as session:
|
||||
yield session
|
||||
|
||||
|
||||
DbSession = Annotated[Session, Depends(_get_db_session)]
|
||||
|
|
|
|||
|
|
@ -12,36 +12,36 @@ from fastapi import HTTPException, status
|
|||
|
||||
|
||||
class UnprocessableContentException(HTTPException):
|
||||
def __init__(self, message: Optional[str] = None) -> None:
|
||||
detail = "Unprocessable content" if not message else message
|
||||
super().__init__(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
|
||||
detail=detail,
|
||||
)
|
||||
def __init__(self, message: Optional[str] = None) -> None:
|
||||
detail = "Unprocessable content" if not message else message
|
||||
super().__init__(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
class ConflictException(HTTPException):
|
||||
def __init__(self, message: Optional[str] = None) -> None:
|
||||
detail = "Conflict" if not message else message
|
||||
super().__init__(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=detail,
|
||||
)
|
||||
def __init__(self, message: Optional[str] = None) -> None:
|
||||
detail = "Conflict" if not message else message
|
||||
super().__init__(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
class ForbiddenException(HTTPException):
|
||||
def __init__(self, message: Optional[str] = None) -> None:
|
||||
detail = "Forbidden" if not message else message
|
||||
super().__init__(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=detail,
|
||||
)
|
||||
def __init__(self, message: Optional[str] = None) -> None:
|
||||
detail = "Forbidden" if not message else message
|
||||
super().__init__(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
class UnauthorizedException(HTTPException):
|
||||
def __init__(self, message: Optional[str] = None) -> None:
|
||||
detail = "Not authorized" if not message else message
|
||||
super().__init__(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=detail,
|
||||
)
|
||||
def __init__(self, message: Optional[str] = None) -> None:
|
||||
detail = "Not authorized" if not message else message
|
||||
super().__init__(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=detail,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -18,59 +18,55 @@ from src.iam.exceptions import GroupNotFoundException, PermNotFoundException
|
|||
from src.iam.schemas import GroupIDMixin, PermIDMixin
|
||||
|
||||
|
||||
def get_group_model_query(
|
||||
db: DbSession, group_id: Annotated[int, Query(gt=0)]
|
||||
) -> Group:
|
||||
group_model = db.get(Group, group_id)
|
||||
if group_model is None:
|
||||
raise GroupNotFoundException(group_id)
|
||||
def get_group_model_query(db: DbSession, group_id: Annotated[int, Query(gt=0)]) -> Group:
|
||||
group_model = db.get(Group, group_id)
|
||||
if group_model is None:
|
||||
raise GroupNotFoundException(group_id)
|
||||
|
||||
return group_model
|
||||
return group_model
|
||||
|
||||
|
||||
group_model_query_dependency = Annotated[Group, Depends(get_group_model_query)]
|
||||
|
||||
|
||||
def get_group_model_body(
|
||||
db: DbSession, request_model: Optional[GroupIDMixin] = None
|
||||
db: DbSession, request_model: Optional[GroupIDMixin] = None
|
||||
) -> Group:
|
||||
group_id = getattr(request_model, "group_id", None)
|
||||
if group_id is None:
|
||||
raise GroupNotFoundException()
|
||||
group_model = db.get(Group, group_id)
|
||||
if group_model is None:
|
||||
raise GroupNotFoundException(group_id)
|
||||
group_id = getattr(request_model, "group_id", None)
|
||||
if group_id is None:
|
||||
raise GroupNotFoundException()
|
||||
group_model = db.get(Group, group_id)
|
||||
if group_model is None:
|
||||
raise GroupNotFoundException(group_id)
|
||||
|
||||
return group_model
|
||||
return group_model
|
||||
|
||||
|
||||
group_model_body_dependency = Annotated[Group, Depends(get_group_model_body)]
|
||||
|
||||
|
||||
def get_perm_model_body(
|
||||
db: DbSession, request_model: Optional[PermIDMixin] = None
|
||||
db: DbSession, request_model: Optional[PermIDMixin] = None
|
||||
) -> Permission:
|
||||
perm_id = getattr(request_model, "permission_id", None)
|
||||
if perm_id is None:
|
||||
raise PermNotFoundException
|
||||
perm_model = db.get(Permission, perm_id)
|
||||
if perm_model is None:
|
||||
raise PermNotFoundException(perm_id)
|
||||
perm_id = getattr(request_model, "permission_id", None)
|
||||
if perm_id is None:
|
||||
raise PermNotFoundException
|
||||
perm_model = db.get(Permission, perm_id)
|
||||
if perm_model is None:
|
||||
raise PermNotFoundException(perm_id)
|
||||
|
||||
return perm_model
|
||||
return perm_model
|
||||
|
||||
|
||||
perm_model_body_dependency = Annotated[Permission, Depends(get_perm_model_body)]
|
||||
|
||||
|
||||
def get_perm_model_query(
|
||||
db: DbSession, perm_id: Annotated[int, Query(gt=0)]
|
||||
) -> Permission:
|
||||
perm_model = db.get(Permission, perm_id)
|
||||
if perm_model is None:
|
||||
raise PermNotFoundException(perm_id)
|
||||
def get_perm_model_query(db: DbSession, perm_id: Annotated[int, Query(gt=0)]) -> Permission:
|
||||
perm_model = db.get(Permission, perm_id)
|
||||
if perm_model is None:
|
||||
raise PermNotFoundException(perm_id)
|
||||
|
||||
return perm_model
|
||||
return perm_model
|
||||
|
||||
|
||||
perm_model_query_dependency = Annotated[Permission, Depends(get_perm_model_query)]
|
||||
|
|
|
|||
|
|
@ -12,26 +12,26 @@ from fastapi import HTTPException, status
|
|||
|
||||
|
||||
class GroupNotFoundException(HTTPException):
|
||||
def __init__(self, group_id: Optional[int] = None) -> None:
|
||||
detail = (
|
||||
"Group not found"
|
||||
if group_id is None
|
||||
else f"User with ID '{group_id}' was not found."
|
||||
)
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=detail,
|
||||
)
|
||||
def __init__(self, group_id: Optional[int] = None) -> None:
|
||||
detail = (
|
||||
"Group not found"
|
||||
if group_id is None
|
||||
else f"User with ID '{group_id}' was not found."
|
||||
)
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
class PermNotFoundException(HTTPException):
|
||||
def __init__(self, perm_id: Optional[int] = None) -> None:
|
||||
detail = (
|
||||
"Permission not found"
|
||||
if perm_id is None
|
||||
else f"User with ID '{perm_id}' was not found."
|
||||
)
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=detail,
|
||||
)
|
||||
def __init__(self, perm_id: Optional[int] = None) -> None:
|
||||
detail = (
|
||||
"Permission not found"
|
||||
if perm_id is None
|
||||
else f"User with ID '{perm_id}' was not found."
|
||||
)
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=detail,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -25,90 +25,90 @@ from src.models import CustomBase, IdMixin
|
|||
|
||||
|
||||
class Permission(CustomBase, IdMixin):
|
||||
__tablename__ = "permission"
|
||||
__tablename__ = "permission"
|
||||
|
||||
resource: Mapped[str]
|
||||
action: Mapped[str]
|
||||
resource: Mapped[str]
|
||||
action: Mapped[str]
|
||||
|
||||
service_id: Mapped[int] = mapped_column(ForeignKey("service.id", ondelete="CASCADE"))
|
||||
service_id: Mapped[int] = mapped_column(ForeignKey("service.id", ondelete="CASCADE"))
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"service_id",
|
||||
"resource",
|
||||
"action",
|
||||
name="uniq_permission_resource_and_action",
|
||||
),
|
||||
)
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"service_id",
|
||||
"resource",
|
||||
"action",
|
||||
name="uniq_permission_resource_and_action",
|
||||
),
|
||||
)
|
||||
|
||||
service_rel = relationship(
|
||||
"Service",
|
||||
back_populates="permission_rel",
|
||||
foreign_keys="Permission.service_id",
|
||||
)
|
||||
service_rel = relationship(
|
||||
"Service",
|
||||
back_populates="permission_rel",
|
||||
foreign_keys="Permission.service_id",
|
||||
)
|
||||
|
||||
group_rel = relationship(
|
||||
"Group", secondary="group_permissions", back_populates="permission_rel"
|
||||
)
|
||||
group_rel = relationship(
|
||||
"Group", secondary="group_permissions", back_populates="permission_rel"
|
||||
)
|
||||
|
||||
org_rel = relationship(
|
||||
"Organisation", secondary="org_permissions", back_populates="permission_rel"
|
||||
)
|
||||
org_rel = relationship(
|
||||
"Organisation", secondary="org_permissions", back_populates="permission_rel"
|
||||
)
|
||||
|
||||
@property
|
||||
def service_name(self):
|
||||
return self.service_rel.name
|
||||
@property
|
||||
def service_name(self):
|
||||
return self.service_rel.name
|
||||
|
||||
|
||||
class Group(CustomBase, IdMixin):
|
||||
__tablename__ = "group"
|
||||
__tablename__ = "group"
|
||||
|
||||
name: Mapped[str]
|
||||
name: Mapped[str]
|
||||
|
||||
org_id: Mapped[int] = mapped_column(ForeignKey("organisation.id", ondelete="CASCADE"))
|
||||
org_id: Mapped[int] = mapped_column(ForeignKey("organisation.id", ondelete="CASCADE"))
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"name",
|
||||
"org_id",
|
||||
name="uniq_group_name_org_id",
|
||||
),
|
||||
)
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"name",
|
||||
"org_id",
|
||||
name="uniq_group_name_org_id",
|
||||
),
|
||||
)
|
||||
|
||||
user_rel = relationship("User", secondary="user_groups", back_populates="group_rel")
|
||||
user_rel = relationship("User", secondary="user_groups", back_populates="group_rel")
|
||||
|
||||
org_rel = relationship("Organisation", back_populates="group_rel")
|
||||
org_rel = relationship("Organisation", back_populates="group_rel")
|
||||
|
||||
permission_rel = relationship(
|
||||
"Permission", secondary="group_permissions", back_populates="group_rel"
|
||||
)
|
||||
permission_rel = relationship(
|
||||
"Permission", secondary="group_permissions", back_populates="group_rel"
|
||||
)
|
||||
|
||||
|
||||
class GroupPermissions(CustomBase):
|
||||
__tablename__ = "group_permissions"
|
||||
group_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
permission_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
__tablename__ = "group_permissions"
|
||||
group_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
permission_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
class UserGroups(CustomBase):
|
||||
__tablename__ = "user_groups"
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
group_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
__tablename__ = "user_groups"
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
group_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
class OrgPermissions(CustomBase):
|
||||
__tablename__ = "org_permissions"
|
||||
org_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
permission_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
__tablename__ = "org_permissions"
|
||||
org_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
permission_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
|
|
|
|||
1050
src/iam/router.py
1050
src/iam/router.py
File diff suppressed because it is too large
Load diff
|
|
@ -11,151 +11,151 @@ from typing import Optional, Annotated
|
|||
from pydantic import EmailStr, ConfigDict, Field
|
||||
|
||||
from src.schemas import (
|
||||
CustomBaseModel,
|
||||
ResourceName,
|
||||
ServiceIDMixin,
|
||||
OrgIDMixin,
|
||||
UserIDMixin,
|
||||
PermIDMixin,
|
||||
GroupIDMixin,
|
||||
GroupSummary,
|
||||
OrgSummary,
|
||||
UserSummary,
|
||||
CustomBaseModel,
|
||||
ResourceName,
|
||||
ServiceIDMixin,
|
||||
OrgIDMixin,
|
||||
UserIDMixin,
|
||||
PermIDMixin,
|
||||
GroupIDMixin,
|
||||
GroupSummary,
|
||||
OrgSummary,
|
||||
UserSummary,
|
||||
)
|
||||
|
||||
|
||||
class UserSchema(CustomBaseModel):
|
||||
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
||||
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
||||
|
||||
id: int
|
||||
first_name: str
|
||||
last_name: str
|
||||
email: EmailStr
|
||||
id: int
|
||||
first_name: str
|
||||
last_name: str
|
||||
email: EmailStr
|
||||
|
||||
|
||||
class PermissionSchema(CustomBaseModel):
|
||||
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
||||
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
||||
|
||||
id: int
|
||||
service_name: str
|
||||
resource: str
|
||||
action: str
|
||||
id: int
|
||||
service_name: str
|
||||
resource: str
|
||||
action: str
|
||||
|
||||
|
||||
class GroupDetails(CustomBaseModel):
|
||||
details: GroupSummary
|
||||
permissions: list[PermissionSchema]
|
||||
details: GroupSummary
|
||||
permissions: list[PermissionSchema]
|
||||
|
||||
|
||||
class IAMCAoRRequest(CustomBaseModel):
|
||||
action: str
|
||||
rn: ResourceName
|
||||
action: str
|
||||
rn: ResourceName
|
||||
|
||||
|
||||
class IAMCAoRResponse(CustomBaseModel):
|
||||
allowed: bool
|
||||
user: UserSummary
|
||||
action: str
|
||||
rn: ResourceName
|
||||
allowed: bool
|
||||
user: UserSummary
|
||||
action: str
|
||||
rn: ResourceName
|
||||
|
||||
|
||||
class IAMGetGroupPermissionsResponse(CustomBaseModel):
|
||||
organisation: OrgSummary
|
||||
group: GroupSummary
|
||||
permissions: list[PermissionSchema]
|
||||
organisation: OrgSummary
|
||||
group: GroupSummary
|
||||
permissions: list[PermissionSchema]
|
||||
|
||||
|
||||
class IAMGetGroupUsersResponse(CustomBaseModel):
|
||||
organisation: OrgSummary
|
||||
group: GroupSummary
|
||||
users: list[UserSummary]
|
||||
organisation: OrgSummary
|
||||
group: GroupSummary
|
||||
users: list[UserSummary]
|
||||
|
||||
|
||||
class IAMPostGroupRequest(OrgIDMixin):
|
||||
name: str = Field(min_length=3)
|
||||
name: str = Field(min_length=3)
|
||||
|
||||
|
||||
class IAMPostGroupResponse(CustomBaseModel):
|
||||
organisation: OrgSummary
|
||||
group: GroupSummary
|
||||
organisation: OrgSummary
|
||||
group: GroupSummary
|
||||
|
||||
|
||||
class IAMPutGroupPermissionRequest(GroupIDMixin, PermIDMixin, OrgIDMixin):
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class IAMPutGroupPermissionResponse(CustomBaseModel):
|
||||
organisation: OrgSummary
|
||||
group: GroupSummary
|
||||
permissions: list[PermissionSchema]
|
||||
organisation: OrgSummary
|
||||
group: GroupSummary
|
||||
permissions: list[PermissionSchema]
|
||||
|
||||
|
||||
class IAMPutGroupUserRequest(GroupIDMixin, UserIDMixin, OrgIDMixin):
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class IAMPutGroupUserResponse(CustomBaseModel):
|
||||
group: GroupSummary
|
||||
users: list[UserSchema]
|
||||
group: GroupSummary
|
||||
users: list[UserSchema]
|
||||
|
||||
|
||||
class IAMDeleteGroupPermissionResponse(CustomBaseModel):
|
||||
group: GroupSummary
|
||||
permissions: list[PermissionSchema]
|
||||
group: GroupSummary
|
||||
permissions: list[PermissionSchema]
|
||||
|
||||
|
||||
class IAMDeleteGroupUserResponse(CustomBaseModel):
|
||||
group: GroupSummary
|
||||
users: list[UserSchema]
|
||||
group: GroupSummary
|
||||
users: list[UserSchema]
|
||||
|
||||
|
||||
class IAMGetPermissionsResponse(CustomBaseModel):
|
||||
permissions: list[PermissionSchema]
|
||||
permissions: list[PermissionSchema]
|
||||
|
||||
|
||||
class IAMPostPermissionRequest(ServiceIDMixin):
|
||||
resource: str
|
||||
action: str
|
||||
resource: str
|
||||
action: str
|
||||
|
||||
|
||||
class IAMPostPermissionResponse(CustomBaseModel):
|
||||
permission: PermissionSchema
|
||||
permission: PermissionSchema
|
||||
|
||||
|
||||
class IAMGetPermissionsSearchRequest(OrgIDMixin):
|
||||
service_id: Annotated[int | None, Field(gt=0)] = None
|
||||
resource: Optional[str] = None
|
||||
action: Optional[str] = None
|
||||
service_id: Annotated[int | None, Field(gt=0)] = None
|
||||
resource: Optional[str] = None
|
||||
action: Optional[str] = None
|
||||
|
||||
|
||||
class IAMGetPermissionsSearchResponse(CustomBaseModel):
|
||||
permissions: list[PermissionSchema]
|
||||
permissions: list[PermissionSchema]
|
||||
|
||||
|
||||
class IAMPutGroupInvitationRequest(OrgIDMixin, GroupIDMixin):
|
||||
user_email: EmailStr
|
||||
user_email: EmailStr
|
||||
|
||||
|
||||
class IAMPutGroupInvitationResponse(CustomBaseModel):
|
||||
organisation: OrgSummary
|
||||
group: GroupSummary
|
||||
invited_email: EmailStr
|
||||
organisation: OrgSummary
|
||||
group: GroupSummary
|
||||
invited_email: EmailStr
|
||||
|
||||
|
||||
class IAMPutGroupInvitationAcceptRequest(CustomBaseModel):
|
||||
jwt: str
|
||||
jwt: str
|
||||
|
||||
|
||||
class IAMPutGroupInvitationAcceptResponse(CustomBaseModel):
|
||||
organisation: OrgSummary
|
||||
user: UserSummary
|
||||
group: GroupDetails
|
||||
organisation: OrgSummary
|
||||
user: UserSummary
|
||||
group: GroupDetails
|
||||
|
||||
|
||||
class IAMPutOrgPermissionsRequest(OrgIDMixin):
|
||||
permissions: list[int]
|
||||
permissions: list[int]
|
||||
|
||||
|
||||
class IAMPutOrgPermissionsResponse(CustomBaseModel):
|
||||
organisation: OrgSummary
|
||||
permissions: list[PermissionSchema]
|
||||
organisation: OrgSummary
|
||||
permissions: list[PermissionSchema]
|
||||
|
|
|
|||
|
|
@ -23,90 +23,90 @@ from src.service.schemas import HasServiceName
|
|||
|
||||
|
||||
def valid_service_key(
|
||||
db: DbSession, request: Request, request_model: HasServiceName
|
||||
db: DbSession, request: Request, request_model: HasServiceName
|
||||
) -> bool:
|
||||
rn = request_model.rn
|
||||
api_key = request.headers.get("X-API-Key", None)
|
||||
if not api_key:
|
||||
raise UnauthorizedException("Missing API key")
|
||||
service = rn.service
|
||||
result = (
|
||||
db.query(Service)
|
||||
.filter(Service.name == service)
|
||||
.filter(Service.api_key == api_key)
|
||||
.first()
|
||||
)
|
||||
if result is None:
|
||||
raise UnauthorizedException("Invalid API key")
|
||||
rn = request_model.rn
|
||||
api_key = request.headers.get("X-API-Key", None)
|
||||
if not api_key:
|
||||
raise UnauthorizedException("Missing API key")
|
||||
service = rn.service
|
||||
result = (
|
||||
db.query(Service)
|
||||
.filter(Service.name == service)
|
||||
.filter(Service.api_key == api_key)
|
||||
.first()
|
||||
)
|
||||
if result is None:
|
||||
raise UnauthorizedException("Invalid API key")
|
||||
|
||||
return True
|
||||
return True
|
||||
|
||||
|
||||
service_key_dependency = Annotated[bool, Depends(valid_service_key)]
|
||||
|
||||
|
||||
async def send_user_group_invitation(
|
||||
user_email: str, org_name: str, org_id: int, group_id: int, group_name: str
|
||||
user_email: str, org_name: str, org_id: int, group_id: int, group_name: str
|
||||
):
|
||||
expiry_delta = timedelta(hours=24)
|
||||
expiry = datetime.now(timezone.utc) + expiry_delta
|
||||
claims = {
|
||||
"email": user_email,
|
||||
"org_id": org_id,
|
||||
"group_id": group_id,
|
||||
"group_name": group_name,
|
||||
"exp": expiry,
|
||||
"type": "group_invite",
|
||||
}
|
||||
expiry_delta = timedelta(hours=24)
|
||||
expiry = datetime.now(timezone.utc) + expiry_delta
|
||||
claims = {
|
||||
"email": user_email,
|
||||
"org_id": org_id,
|
||||
"group_id": group_id,
|
||||
"group_name": group_name,
|
||||
"exp": expiry,
|
||||
"type": "group_invite",
|
||||
}
|
||||
|
||||
token = await generate_jwt(claims)
|
||||
subject = f"You have been invited to join a group of {org_name}"
|
||||
body = f"You have been invited to join {group_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
|
||||
token = await generate_jwt(claims)
|
||||
subject = f"You have been invited to join a group of {org_name}"
|
||||
body = f"You have been invited to join {group_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
|
||||
|
||||
await send_email(
|
||||
recipient=user_email,
|
||||
subject=subject,
|
||||
body=body,
|
||||
)
|
||||
await send_email(
|
||||
recipient=user_email,
|
||||
subject=subject,
|
||||
body=body,
|
||||
)
|
||||
|
||||
|
||||
async def create_group_and_assign_perms(
|
||||
db: Session, org_model: Org, group_name: str, perm_list: list[int]
|
||||
db: Session, org_model: Org, group_name: str, perm_list: list[int]
|
||||
):
|
||||
new_group = Group(name=group_name, org_id=org_model.id)
|
||||
db.add(new_group)
|
||||
db.flush()
|
||||
new_group = Group(name=group_name, org_id=org_model.id)
|
||||
db.add(new_group)
|
||||
db.flush()
|
||||
|
||||
for permission in perm_list:
|
||||
perm_model = db.get(Perm, permission)
|
||||
for permission in perm_list:
|
||||
perm_model = db.get(Perm, permission)
|
||||
|
||||
if perm_model is None:
|
||||
continue
|
||||
if perm_model is None:
|
||||
continue
|
||||
|
||||
new_group.permission_rel.append(perm_model)
|
||||
db.flush()
|
||||
new_group.permission_rel.append(perm_model)
|
||||
db.flush()
|
||||
|
||||
return new_group
|
||||
return new_group
|
||||
|
||||
|
||||
async def assign_default_group(
|
||||
db: DbSession,
|
||||
org_model: Org,
|
||||
user_model: User,
|
||||
group_name: str,
|
||||
perm_list: list[int],
|
||||
db: DbSession,
|
||||
org_model: Org,
|
||||
user_model: User,
|
||||
group_name: str,
|
||||
perm_list: list[int],
|
||||
):
|
||||
group_model = (
|
||||
db.query(Group)
|
||||
.filter(Group.org_id == org_model.id)
|
||||
.filter(Group.name == group_name)
|
||||
.first()
|
||||
)
|
||||
group_model = (
|
||||
db.query(Group)
|
||||
.filter(Group.org_id == org_model.id)
|
||||
.filter(Group.name == group_name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if group_model is None:
|
||||
group_model = await create_group_and_assign_perms(
|
||||
db=db, group_name=group_name, org_model=org_model, perm_list=perm_list
|
||||
)
|
||||
if group_model is None:
|
||||
group_model = await create_group_and_assign_perms(
|
||||
db=db, group_name=group_name, org_model=org_model, perm_list=perm_list
|
||||
)
|
||||
|
||||
user_model.group_rel.append(group_model)
|
||||
db.flush()
|
||||
user_model.group_rel.append(group_model)
|
||||
db.flush()
|
||||
|
|
|
|||
71
src/main.py
71
src/main.py
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Application root file: Inits the FastAPI application
|
||||
"""
|
||||
|
||||
import os.path
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
|
@ -19,43 +20,43 @@ from src.auth.service import get_current_user, get_dev_user
|
|||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_application: FastAPI) -> AsyncGenerator:
|
||||
# Startup
|
||||
yield
|
||||
# Shutdown
|
||||
# Startup
|
||||
yield
|
||||
# Shutdown
|
||||
|
||||
|
||||
if settings.ENVIRONMENT.is_deployed:
|
||||
# Just a precaution, should be False anyway
|
||||
settings.DISABLE_AUTH = False
|
||||
# Just a precaution, should be False anyway
|
||||
settings.DISABLE_AUTH = False
|
||||
|
||||
|
||||
tags_metadata = [
|
||||
{
|
||||
"name": "User",
|
||||
"description": "User related operations, includes getting information about the current user",
|
||||
},
|
||||
{
|
||||
"name": "Organisation",
|
||||
"description": "Organisation related operations, includes getting lists of users etc associated with orgs",
|
||||
},
|
||||
{
|
||||
"name": "Service",
|
||||
"description": "Services related operations, includes registering services and reissuing API keys",
|
||||
},
|
||||
{
|
||||
"name": "IAM",
|
||||
"description": "Operations related to the role based identity and access management system. This includes management of groups, permissions, and related users.",
|
||||
},
|
||||
{
|
||||
"name": "User",
|
||||
"description": "User related operations, includes getting information about the current user",
|
||||
},
|
||||
{
|
||||
"name": "Organisation",
|
||||
"description": "Organisation related operations, includes getting lists of users etc associated with orgs",
|
||||
},
|
||||
{
|
||||
"name": "Service",
|
||||
"description": "Services related operations, includes registering services and reissuing API keys",
|
||||
},
|
||||
{
|
||||
"name": "IAM",
|
||||
"description": "Operations related to the role based identity and access management system. This includes management of groups, permissions, and related users.",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
swagger_ui_init_oauth={
|
||||
"clientId": auth_settings.CLIENT_ID,
|
||||
"usePkceWithAuthorizationCodeGrant": True,
|
||||
"scopes": "openid profile email",
|
||||
},
|
||||
openapi_tags=tags_metadata,
|
||||
swagger_ui_init_oauth={
|
||||
"clientId": auth_settings.CLIENT_ID,
|
||||
"usePkceWithAuthorizationCodeGrant": True,
|
||||
"scopes": "openid profile email",
|
||||
},
|
||||
openapi_tags=tags_metadata,
|
||||
)
|
||||
|
||||
# Type inspection disabled for middleware injection.
|
||||
|
|
@ -64,19 +65,19 @@ app = FastAPI(
|
|||
app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY.get_secret_value())
|
||||
# noinspection PyTypeChecker
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.CORS_ORIGINS,
|
||||
allow_origin_regex=settings.CORS_ORIGINS_REGEX,
|
||||
allow_credentials=True,
|
||||
allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"),
|
||||
allow_headers=settings.CORS_HEADERS,
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.CORS_ORIGINS,
|
||||
allow_origin_regex=settings.CORS_ORIGINS_REGEX,
|
||||
allow_credentials=True,
|
||||
allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"),
|
||||
allow_headers=settings.CORS_HEADERS,
|
||||
)
|
||||
|
||||
if settings.DISABLE_AUTH and (settings.ENVIRONMENT == Environment.LOCAL):
|
||||
app.dependency_overrides[get_current_user] = get_dev_user
|
||||
app.dependency_overrides[get_current_user] = get_dev_user
|
||||
|
||||
|
||||
app.include_router(api_router)
|
||||
|
||||
if os.path.exists("/app/static"):
|
||||
app.frontend("/ui", directory="/app/static", fallback="index.html")
|
||||
app.frontend("/ui", directory="/app/static", fallback="index.html")
|
||||
|
|
|
|||
|
|
@ -10,28 +10,28 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
|||
|
||||
|
||||
class CustomBase(DeclarativeBase):
|
||||
type_annotation_map = {
|
||||
datetime: DateTime(timezone=True),
|
||||
dict[str, Any]: JSON,
|
||||
}
|
||||
type_annotation_map = {
|
||||
datetime: DateTime(timezone=True),
|
||||
dict[str, Any]: JSON,
|
||||
}
|
||||
|
||||
|
||||
class ActivatedMixin:
|
||||
active: Mapped[bool] = mapped_column(default=True)
|
||||
active: Mapped[bool] = mapped_column(default=True)
|
||||
|
||||
|
||||
class DeletedTimestampMixin:
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(nullable=True)
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(nullable=True)
|
||||
|
||||
|
||||
class DescriptionMixin:
|
||||
description: Mapped[str]
|
||||
description: Mapped[str]
|
||||
|
||||
|
||||
class IdMixin:
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
|
||||
|
||||
class TimestampMixin:
|
||||
created_at: Mapped[datetime] = mapped_column(default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(default=func.now(), onupdate=func.now())
|
||||
created_at: Mapped[datetime] = mapped_column(default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(default=func.now(), onupdate=func.now())
|
||||
|
|
|
|||
|
|
@ -10,48 +10,48 @@ from enum import StrEnum, auto
|
|||
|
||||
|
||||
class Status(StrEnum):
|
||||
"""
|
||||
Enumeration of organisation statuses.
|
||||
"""
|
||||
Enumeration of organisation statuses.
|
||||
|
||||
Attributes:
|
||||
PARTIAL(str): Organisation has been created but questionnaire hasn't been submitted.
|
||||
SUBMITTED (str): Questionnaire submitted but not approved.
|
||||
REMEDIATION (str): Questionnaire submitted but requires revisions.
|
||||
APPROVED (str): Questionnaire has been approved by an admin.
|
||||
REJECTED (str): Questionnaire has been rejected by an admin.
|
||||
REMOVED (str): Organisation has been removed.
|
||||
"""
|
||||
Attributes:
|
||||
PARTIAL(str): Organisation has been created but questionnaire hasn't been submitted.
|
||||
SUBMITTED (str): Questionnaire submitted but not approved.
|
||||
REMEDIATION (str): Questionnaire submitted but requires revisions.
|
||||
APPROVED (str): Questionnaire has been approved by an admin.
|
||||
REJECTED (str): Questionnaire has been rejected by an admin.
|
||||
REMOVED (str): Organisation has been removed.
|
||||
"""
|
||||
|
||||
PARTIAL = auto()
|
||||
SUBMITTED = auto()
|
||||
REMEDIATION = auto()
|
||||
APPROVED = auto()
|
||||
REJECTED = auto()
|
||||
REMOVED = auto()
|
||||
PARTIAL = auto()
|
||||
SUBMITTED = auto()
|
||||
REMEDIATION = auto()
|
||||
APPROVED = auto()
|
||||
REJECTED = auto()
|
||||
REMOVED = auto()
|
||||
|
||||
@property
|
||||
def is_pre_approval(self):
|
||||
return self in (self.PARTIAL, self.SUBMITTED, self.REMEDIATION)
|
||||
@property
|
||||
def is_pre_approval(self):
|
||||
return self in (self.PARTIAL, self.SUBMITTED, self.REMEDIATION)
|
||||
|
||||
@property
|
||||
def is_pre_submission(self):
|
||||
return self in (self.PARTIAL, self.REMEDIATION)
|
||||
@property
|
||||
def is_pre_submission(self):
|
||||
return self in (self.PARTIAL, self.REMEDIATION)
|
||||
|
||||
@property
|
||||
def is_blocked(self):
|
||||
return self in (self.REMOVED, self.REJECTED)
|
||||
@property
|
||||
def is_blocked(self):
|
||||
return self in (self.REMOVED, self.REJECTED)
|
||||
|
||||
|
||||
class ContactType(StrEnum):
|
||||
"""
|
||||
Enumeration of organisation contact types.
|
||||
"""
|
||||
Enumeration of organisation contact types.
|
||||
|
||||
Attributes:
|
||||
BILLING(str): Billing contact.
|
||||
SECURITY (str): Security contact.
|
||||
OWNER (str): Owner contact.
|
||||
"""
|
||||
Attributes:
|
||||
BILLING(str): Billing contact.
|
||||
SECURITY (str): Security contact.
|
||||
OWNER (str): Owner contact.
|
||||
"""
|
||||
|
||||
BILLING = auto()
|
||||
SECURITY = auto()
|
||||
OWNER = auto()
|
||||
BILLING = auto()
|
||||
SECURITY = auto()
|
||||
OWNER = auto()
|
||||
|
|
|
|||
|
|
@ -18,25 +18,25 @@ from src.organisation.exceptions import OrgNotFoundException
|
|||
|
||||
|
||||
def get_org_model_query(db: DbSession, org_id: Annotated[int, Query(gt=0)]) -> Org:
|
||||
org_model = db.get(Org, org_id)
|
||||
if org_model is None:
|
||||
raise OrgNotFoundException(org_id)
|
||||
return org_model
|
||||
org_model = db.get(Org, org_id)
|
||||
if org_model is None:
|
||||
raise OrgNotFoundException(org_id)
|
||||
return org_model
|
||||
|
||||
|
||||
org_model_query_dependency = Annotated[Org, Depends(get_org_model_query)]
|
||||
|
||||
|
||||
def get_org_model_body(db: DbSession, request_model: OrgIDMixin) -> Org:
|
||||
org_id: Optional[int] = getattr(request_model, "organisation_id", None)
|
||||
if org_id is None:
|
||||
raise OrgNotFoundException()
|
||||
org_id: Optional[int] = getattr(request_model, "organisation_id", None)
|
||||
if org_id is None:
|
||||
raise OrgNotFoundException()
|
||||
|
||||
org_model = db.get(Org, org_id)
|
||||
if org_model is None:
|
||||
raise OrgNotFoundException(org_id)
|
||||
org_model = db.get(Org, org_id)
|
||||
if org_model is None:
|
||||
raise OrgNotFoundException(org_id)
|
||||
|
||||
return org_model
|
||||
return org_model
|
||||
|
||||
|
||||
org_model_body_dependency = Annotated[Org, Depends(get_org_model_body)]
|
||||
|
|
|
|||
|
|
@ -12,26 +12,26 @@ from fastapi import HTTPException, status
|
|||
|
||||
|
||||
class OrgNotFoundException(HTTPException):
|
||||
def __init__(self, org_id: Optional[int] = None) -> None:
|
||||
detail = (
|
||||
"Organisation not found"
|
||||
if org_id is None
|
||||
else f"Organisation with ID '{org_id}' was not found."
|
||||
)
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=detail,
|
||||
)
|
||||
def __init__(self, org_id: Optional[int] = None) -> None:
|
||||
detail = (
|
||||
"Organisation not found"
|
||||
if org_id is None
|
||||
else f"Organisation with ID '{org_id}' was not found."
|
||||
)
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
class AwaitingApprovalException(HTTPException):
|
||||
def __init__(self, org_id: Optional[int] = None) -> None:
|
||||
detail = (
|
||||
"Organisation has not been approved."
|
||||
if org_id is None
|
||||
else f"Organisation with ID '{org_id}' has not been approved."
|
||||
)
|
||||
super().__init__(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=detail,
|
||||
)
|
||||
def __init__(self, org_id: Optional[int] = None) -> None:
|
||||
detail = (
|
||||
"Organisation has not been approved."
|
||||
if org_id is None
|
||||
else f"Organisation with ID '{org_id}' has not been approved."
|
||||
)
|
||||
super().__init__(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=detail,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -25,51 +25,51 @@ from src.models import CustomBase, TimestampMixin
|
|||
|
||||
|
||||
class Organisation(CustomBase, IdMixin, TimestampMixin, DeletedTimestampMixin):
|
||||
__tablename__ = "organisation"
|
||||
__tablename__ = "organisation"
|
||||
|
||||
name: Mapped[str]
|
||||
status: Mapped[str] = mapped_column(default="partial")
|
||||
intake_questionnaire: Mapped[dict[str, Any] | None]
|
||||
name: Mapped[str]
|
||||
status: Mapped[str] = mapped_column(default="partial")
|
||||
intake_questionnaire: Mapped[dict[str, Any] | None]
|
||||
|
||||
root_user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
|
||||
billing_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True)
|
||||
security_contact_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("contact.id"), nullable=True
|
||||
)
|
||||
owner_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True)
|
||||
root_user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
|
||||
billing_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True)
|
||||
security_contact_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("contact.id"), nullable=True
|
||||
)
|
||||
owner_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True)
|
||||
|
||||
user_rel = relationship("User", secondary="orgusers", back_populates="organisation_rel")
|
||||
user_rel = relationship("User", secondary="orgusers", back_populates="organisation_rel")
|
||||
|
||||
group_rel = relationship(
|
||||
"Group", back_populates="org_rel", cascade="all, delete-orphan"
|
||||
)
|
||||
root_user_rel = relationship("User", foreign_keys="Organisation.root_user_id")
|
||||
group_rel = relationship(
|
||||
"Group", back_populates="org_rel", cascade="all, delete-orphan"
|
||||
)
|
||||
root_user_rel = relationship("User", foreign_keys="Organisation.root_user_id")
|
||||
|
||||
billing_contact_rel = relationship(
|
||||
"Contact", foreign_keys="Organisation.billing_contact_id"
|
||||
)
|
||||
security_contact_rel = relationship(
|
||||
"Contact", foreign_keys="Organisation.security_contact_id"
|
||||
)
|
||||
owner_contact_rel = relationship(
|
||||
"Contact", foreign_keys="Organisation.owner_contact_id"
|
||||
)
|
||||
billing_contact_rel = relationship(
|
||||
"Contact", foreign_keys="Organisation.billing_contact_id"
|
||||
)
|
||||
security_contact_rel = relationship(
|
||||
"Contact", foreign_keys="Organisation.security_contact_id"
|
||||
)
|
||||
owner_contact_rel = relationship(
|
||||
"Contact", foreign_keys="Organisation.owner_contact_id"
|
||||
)
|
||||
|
||||
permission_rel = relationship(
|
||||
"Permission", secondary="org_permissions", back_populates="org_rel"
|
||||
)
|
||||
permission_rel = relationship(
|
||||
"Permission", secondary="org_permissions", back_populates="org_rel"
|
||||
)
|
||||
|
||||
@property
|
||||
def root_user_email(self) -> str:
|
||||
return self.root_user_rel.email if self.root_user_rel else ""
|
||||
@property
|
||||
def root_user_email(self) -> str:
|
||||
return self.root_user_rel.email if self.root_user_rel else ""
|
||||
|
||||
|
||||
class OrgUsers(CustomBase):
|
||||
__tablename__ = "orgusers"
|
||||
__tablename__ = "orgusers"
|
||||
|
||||
org_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
org_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -12,139 +12,139 @@ from datetime import datetime
|
|||
from pydantic import EmailStr, ConfigDict, Field
|
||||
|
||||
from src.schemas import (
|
||||
CustomBaseModel,
|
||||
OrgIDMixin,
|
||||
UserIDMixin,
|
||||
GroupSummary,
|
||||
OrgSummary,
|
||||
UserSummary,
|
||||
CustomBaseModel,
|
||||
OrgIDMixin,
|
||||
UserIDMixin,
|
||||
GroupSummary,
|
||||
OrgSummary,
|
||||
UserSummary,
|
||||
)
|
||||
from src.contact.schemas import ContactModel
|
||||
|
||||
from src.organisation.constants import Status, ContactType
|
||||
from src.organisation.schemas_questionnaires import (
|
||||
QuestionnaireQuestionsVersion0 as CurrentQuestions,
|
||||
questionnaire_union,
|
||||
QuestionnaireQuestionsVersion0 as CurrentQuestions,
|
||||
questionnaire_union,
|
||||
)
|
||||
|
||||
|
||||
class QuestionnaireMetadata(CustomBaseModel):
|
||||
version: int
|
||||
submission_date: Optional[datetime] = None
|
||||
version: int
|
||||
submission_date: Optional[datetime] = None
|
||||
|
||||
|
||||
class Questionnaire(CustomBaseModel):
|
||||
metadata: QuestionnaireMetadata
|
||||
questions: questionnaire_union
|
||||
metadata: QuestionnaireMetadata
|
||||
questions: questionnaire_union
|
||||
|
||||
|
||||
class ContactSummary(CustomBaseModel):
|
||||
id: int
|
||||
email: Optional[EmailStr] = None
|
||||
id: int
|
||||
email: Optional[EmailStr] = None
|
||||
|
||||
|
||||
class OrgSchema(OrgIDMixin):
|
||||
name: str
|
||||
status: Status
|
||||
root_user_email: EmailStr
|
||||
intake_questionnaire: Optional[Questionnaire] = None
|
||||
name: str
|
||||
status: Status
|
||||
root_user_email: EmailStr
|
||||
intake_questionnaire: Optional[Questionnaire] = None
|
||||
|
||||
billing_contact: ContactSummary
|
||||
owner_contact: ContactSummary
|
||||
security_contact: ContactSummary
|
||||
billing_contact: ContactSummary
|
||||
owner_contact: ContactSummary
|
||||
security_contact: ContactSummary
|
||||
|
||||
|
||||
class OrgPostOrgRequest(CustomBaseModel):
|
||||
name: str = Field(min_length=3)
|
||||
intake_questionnaire: Optional[CurrentQuestions] = None
|
||||
name: str = Field(min_length=3)
|
||||
intake_questionnaire: Optional[CurrentQuestions] = None
|
||||
|
||||
|
||||
class OrgPostOrgResponse(CustomBaseModel):
|
||||
id: int
|
||||
name: str
|
||||
status: Status
|
||||
id: int
|
||||
name: str
|
||||
status: Status
|
||||
|
||||
|
||||
class OrgPatchQuestionnaireRequest(OrgIDMixin):
|
||||
intake_questionnaire: CurrentQuestions
|
||||
partial: bool
|
||||
intake_questionnaire: CurrentQuestions
|
||||
partial: bool
|
||||
|
||||
|
||||
class OrgPatchQuestionnaireResponse(CustomBaseModel):
|
||||
id: int
|
||||
name: str
|
||||
intake_questionnaire: Questionnaire
|
||||
status: Status
|
||||
id: int
|
||||
name: str
|
||||
intake_questionnaire: Questionnaire
|
||||
status: Status
|
||||
|
||||
|
||||
class OrgPatchStatusRequest(OrgIDMixin):
|
||||
status: Status
|
||||
status: Status
|
||||
|
||||
|
||||
class OrgPatchStatusResponse(CustomBaseModel):
|
||||
id: int
|
||||
name: str
|
||||
status: Status
|
||||
id: int
|
||||
name: str
|
||||
status: Status
|
||||
|
||||
|
||||
class OrgPatchContactRequest(OrgIDMixin):
|
||||
contact_type: ContactType
|
||||
contact_type: ContactType
|
||||
|
||||
email: Optional[EmailStr] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
phonenumber: Optional[str] = None
|
||||
vat_number: Optional[str] = None
|
||||
post_office_box_number: Optional[str] = None
|
||||
street_address: Optional[str] = None
|
||||
street_address_line_2: Optional[str] = None
|
||||
locality: Optional[str] = None
|
||||
address_region: Optional[str] = None
|
||||
country_code: Optional[str] = None
|
||||
postal_code: Optional[str] = None
|
||||
email: Optional[EmailStr] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
phonenumber: Optional[str] = None
|
||||
vat_number: Optional[str] = None
|
||||
post_office_box_number: Optional[str] = None
|
||||
street_address: Optional[str] = None
|
||||
street_address_line_2: Optional[str] = None
|
||||
locality: Optional[str] = None
|
||||
address_region: Optional[str] = None
|
||||
country_code: Optional[str] = None
|
||||
postal_code: Optional[str] = None
|
||||
|
||||
|
||||
class OrgPostUserRequest(OrgIDMixin, UserIDMixin):
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class OrgPostUserResponse(CustomBaseModel):
|
||||
organisation: OrgSummary
|
||||
users: list[UserSummary]
|
||||
organisation: OrgSummary
|
||||
users: list[UserSummary]
|
||||
|
||||
|
||||
class OrgPatchRootRequest(OrgIDMixin, UserIDMixin):
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class OrgPatchRootResponse(CustomBaseModel):
|
||||
name: str
|
||||
root_user_email: str
|
||||
name: str
|
||||
root_user_email: str
|
||||
|
||||
|
||||
class OrgGetUserResponse(CustomBaseModel):
|
||||
users: list[dict[str, str | int]]
|
||||
organisation: OrgSummary
|
||||
users: list[dict[str, str | int]]
|
||||
organisation: OrgSummary
|
||||
|
||||
|
||||
class OrgGetGroupResponse(CustomBaseModel):
|
||||
organisation: OrgSummary
|
||||
groups: list[GroupSummary]
|
||||
organisation: OrgSummary
|
||||
groups: list[GroupSummary]
|
||||
|
||||
|
||||
class OrgGetContactResponse(CustomBaseModel):
|
||||
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
||||
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
||||
|
||||
contact: ContactModel
|
||||
organisation: OrgSummary
|
||||
contact: ContactModel
|
||||
organisation: OrgSummary
|
||||
|
||||
|
||||
class OrgPatchContactResponse(CustomBaseModel):
|
||||
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
||||
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
||||
|
||||
contact: ContactModel
|
||||
organisation: OrgSummary
|
||||
contact: ContactModel
|
||||
organisation: OrgSummary
|
||||
|
||||
|
||||
class OrgGetOrgResponse(CustomBaseModel):
|
||||
organisations: list[OrgSchema]
|
||||
organisations: list[OrgSchema]
|
||||
|
|
|
|||
|
|
@ -4,13 +4,13 @@ from src.schemas import CustomBaseModel
|
|||
|
||||
|
||||
class QuestionnaireQuestions(CustomBaseModel):
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class QuestionnaireQuestionsVersion0(QuestionnaireQuestions):
|
||||
question_one: Optional[str] = None
|
||||
question_two: Optional[str] = None
|
||||
question_three: Optional[str] = None
|
||||
question_one: Optional[str] = None
|
||||
question_two: Optional[str] = None
|
||||
question_three: Optional[str] = None
|
||||
|
||||
|
||||
questionnaire_union = QuestionnaireQuestionsVersion0 # | QuestionnaireQuestionsVersion1
|
||||
|
|
|
|||
|
|
@ -11,57 +11,57 @@ from src.user.models import User
|
|||
|
||||
|
||||
async def add_default_org_permissions(
|
||||
db: Session,
|
||||
org_model: Org,
|
||||
perm_list: list[int],
|
||||
db: Session,
|
||||
org_model: Org,
|
||||
perm_list: list[int],
|
||||
):
|
||||
for permission in perm_list:
|
||||
perm_model = db.get(Perm, permission)
|
||||
for permission in perm_list:
|
||||
perm_model = db.get(Perm, permission)
|
||||
|
||||
if perm_model is None:
|
||||
continue
|
||||
if perm_model is None:
|
||||
continue
|
||||
|
||||
if perm_model in org_model.permission_rel:
|
||||
continue
|
||||
if perm_model in org_model.permission_rel:
|
||||
continue
|
||||
|
||||
org_model.permission_rel.append(perm_model)
|
||||
db.flush()
|
||||
org_model.permission_rel.append(perm_model)
|
||||
db.flush()
|
||||
|
||||
db.commit()
|
||||
db.commit()
|
||||
|
||||
|
||||
async def assign_defaults(
|
||||
db: Session,
|
||||
org_id: int,
|
||||
user_id: int,
|
||||
db: Session,
|
||||
org_id: int,
|
||||
user_id: int,
|
||||
):
|
||||
default_org_permissions = []
|
||||
default_org_permissions = []
|
||||
|
||||
default_user_permissions = []
|
||||
default_user_permissions = []
|
||||
|
||||
org_model = db.get(Org, org_id)
|
||||
if org_model is None:
|
||||
print("Org not found while adding defaults")
|
||||
return
|
||||
org_model = db.get(Org, org_id)
|
||||
if org_model is None:
|
||||
print("Org not found while adding defaults")
|
||||
return
|
||||
|
||||
user_model = db.get(User, user_id)
|
||||
if user_model is None:
|
||||
print("User not found while adding defaults")
|
||||
return
|
||||
user_model = db.get(User, user_id)
|
||||
if user_model is None:
|
||||
print("User not found while adding defaults")
|
||||
return
|
||||
|
||||
await add_default_org_permissions(db, org_model, default_org_permissions)
|
||||
await assign_default_group(
|
||||
db=db,
|
||||
org_model=org_model,
|
||||
user_model=user_model,
|
||||
group_name="Default Users",
|
||||
perm_list=default_user_permissions,
|
||||
)
|
||||
await assign_default_group(
|
||||
db=db,
|
||||
org_model=org_model,
|
||||
user_model=user_model,
|
||||
group_name="Root User",
|
||||
perm_list=default_org_permissions,
|
||||
)
|
||||
db.commit()
|
||||
await add_default_org_permissions(db, org_model, default_org_permissions)
|
||||
await assign_default_group(
|
||||
db=db,
|
||||
org_model=org_model,
|
||||
user_model=user_model,
|
||||
group_name="Default Users",
|
||||
perm_list=default_user_permissions,
|
||||
)
|
||||
await assign_default_group(
|
||||
db=db,
|
||||
org_model=org_model,
|
||||
user_model=user_model,
|
||||
group_name="Root User",
|
||||
perm_list=default_org_permissions,
|
||||
)
|
||||
db.commit()
|
||||
|
|
|
|||
|
|
@ -11,54 +11,54 @@ from typing import Optional
|
|||
|
||||
|
||||
class CustomBaseModel(BaseModel):
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
### Mixins ###
|
||||
class OrgIDMixin(CustomBaseModel):
|
||||
organisation_id: int = Field(gt=0)
|
||||
organisation_id: int = Field(gt=0)
|
||||
|
||||
|
||||
class GroupIDMixin(CustomBaseModel):
|
||||
group_id: int = Field(gt=0)
|
||||
group_id: int = Field(gt=0)
|
||||
|
||||
|
||||
class PermIDMixin(CustomBaseModel):
|
||||
permission_id: int = Field(gt=0)
|
||||
permission_id: int = Field(gt=0)
|
||||
|
||||
|
||||
class ServiceIDMixin(CustomBaseModel):
|
||||
service_id: int = Field(gt=0)
|
||||
service_id: int = Field(gt=0)
|
||||
|
||||
|
||||
class UserIDMixin(CustomBaseModel):
|
||||
user_id: int = Field(gt=0)
|
||||
user_id: int = Field(gt=0)
|
||||
|
||||
|
||||
class ServiceNameMixin(CustomBaseModel):
|
||||
service: str
|
||||
service: str
|
||||
|
||||
|
||||
class OrgSummary(CustomBaseModel):
|
||||
id: int
|
||||
name: str
|
||||
id: int
|
||||
name: str
|
||||
|
||||
|
||||
class GroupSummary(CustomBaseModel):
|
||||
id: int
|
||||
name: str
|
||||
id: int
|
||||
name: str
|
||||
|
||||
|
||||
class UserSummary(CustomBaseModel):
|
||||
id: int
|
||||
email: str
|
||||
id: int
|
||||
email: str
|
||||
|
||||
|
||||
class ServiceSummary(CustomBaseModel):
|
||||
id: int
|
||||
name: str
|
||||
id: int
|
||||
name: str
|
||||
|
||||
|
||||
class ResourceName(ServiceNameMixin, OrgIDMixin):
|
||||
resource: str
|
||||
instance: Optional[str] = None
|
||||
resource: str
|
||||
instance: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -16,25 +16,23 @@ from src.service.models import Service
|
|||
from src.service.schemas import ServiceIDMixin
|
||||
|
||||
|
||||
async def get_service_model_query(
|
||||
db: DbSession, service_id: Annotated[int, Query(gt=0)]
|
||||
):
|
||||
service_model = db.get(Service, service_id)
|
||||
if service_model is None:
|
||||
raise ServiceNotFoundException(service_id=service_id)
|
||||
async def get_service_model_query(db: DbSession, service_id: Annotated[int, Query(gt=0)]):
|
||||
service_model = db.get(Service, service_id)
|
||||
if service_model is None:
|
||||
raise ServiceNotFoundException(service_id=service_id)
|
||||
|
||||
return service_model
|
||||
return service_model
|
||||
|
||||
|
||||
service_model_query_dependency = Annotated[Service, Depends(get_service_model_query)]
|
||||
|
||||
|
||||
async def get_service_model_body(db: DbSession, request_model: ServiceIDMixin):
|
||||
service_model = db.get(Service, request_model.service_id)
|
||||
if service_model is None:
|
||||
raise ServiceNotFoundException(service_id=request_model.service_id)
|
||||
service_model = db.get(Service, request_model.service_id)
|
||||
if service_model is None:
|
||||
raise ServiceNotFoundException(service_id=request_model.service_id)
|
||||
|
||||
return service_model
|
||||
return service_model
|
||||
|
||||
|
||||
service_model_body_dependency = Annotated[Service, Depends(get_service_model_body)]
|
||||
|
|
|
|||
|
|
@ -11,13 +11,13 @@ from fastapi import HTTPException, status
|
|||
|
||||
|
||||
class ServiceNotFoundException(HTTPException):
|
||||
def __init__(self, service_id: Optional[int] = None) -> None:
|
||||
detail = (
|
||||
"Service not found"
|
||||
if service_id is None
|
||||
else f"Service with ID '{service_id}' was not found."
|
||||
)
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=detail,
|
||||
)
|
||||
def __init__(self, service_id: Optional[int] = None) -> None:
|
||||
detail = (
|
||||
"Service not found"
|
||||
if service_id is None
|
||||
else f"Service with ID '{service_id}' was not found."
|
||||
)
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=detail,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -12,11 +12,11 @@ from src.models import CustomBase, IdMixin
|
|||
|
||||
|
||||
class Service(CustomBase, IdMixin):
|
||||
__tablename__ = "service"
|
||||
__tablename__ = "service"
|
||||
|
||||
name: Mapped[str] = mapped_column(unique=True)
|
||||
api_key: Mapped[str]
|
||||
name: Mapped[str] = mapped_column(unique=True)
|
||||
api_key: Mapped[str]
|
||||
|
||||
permission_rel = relationship(
|
||||
"Permission", back_populates="service_rel", cascade="all, delete-orphan"
|
||||
)
|
||||
permission_rel = relationship(
|
||||
"Permission", back_populates="service_rel", cascade="all, delete-orphan"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,8 +15,8 @@ from psycopg.errors import UniqueViolation
|
|||
from src.exceptions import ConflictException
|
||||
from src.database import DbSession
|
||||
from src.auth.dependencies import (
|
||||
super_admin_dependency,
|
||||
org_model_root_claim_query_dependency,
|
||||
super_admin_dependency,
|
||||
org_model_root_claim_query_dependency,
|
||||
)
|
||||
from src.iam.service import service_key_dependency
|
||||
from src.iam.models import Permission as Perm
|
||||
|
|
@ -25,212 +25,210 @@ from src.service.exceptions import ServiceNotFoundException
|
|||
from src.service.models import Service
|
||||
from src.service.utils import generate_api_key
|
||||
from src.service.dependencies import (
|
||||
service_model_body_dependency,
|
||||
service_model_query_dependency,
|
||||
service_model_body_dependency,
|
||||
service_model_query_dependency,
|
||||
)
|
||||
from src.service.schemas import (
|
||||
ServiceGetServiceResponse,
|
||||
ServicePostServiceRequest,
|
||||
ServicePostServiceResponse,
|
||||
ServiceWithKeySchema,
|
||||
ServicePatchKeyResponse,
|
||||
ServicePatchKeyRequest,
|
||||
ServicePostPermissionsResponse,
|
||||
ServicePostPermissionsRequest,
|
||||
ServiceGetServiceResponse,
|
||||
ServicePostServiceRequest,
|
||||
ServicePostServiceResponse,
|
||||
ServiceWithKeySchema,
|
||||
ServicePatchKeyResponse,
|
||||
ServicePatchKeyRequest,
|
||||
ServicePostPermissionsResponse,
|
||||
ServicePostPermissionsRequest,
|
||||
)
|
||||
|
||||
router = APIRouter(
|
||||
tags=["Service"],
|
||||
prefix="/service",
|
||||
tags=["Service"],
|
||||
prefix="/service",
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
summary="Get all services",
|
||||
status_code=status.HTTP_200_OK,
|
||||
response_model=ServiceGetServiceResponse,
|
||||
responses={
|
||||
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
||||
status.HTTP_401_UNAUTHORIZED: {
|
||||
"description": "Unauthorized",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"examples": {
|
||||
"awaiting_approval": {
|
||||
"summary": "Organisation has not yet been approved."
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
status.HTTP_403_FORBIDDEN: {
|
||||
"description": "Forbidden",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"examples": {
|
||||
"not_root": {"summary": "Not authorised. Must be root user."},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
"",
|
||||
summary="Get all services",
|
||||
status_code=status.HTTP_200_OK,
|
||||
response_model=ServiceGetServiceResponse,
|
||||
responses={
|
||||
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
||||
status.HTTP_401_UNAUTHORIZED: {
|
||||
"description": "Unauthorized",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"examples": {
|
||||
"awaiting_approval": {
|
||||
"summary": "Organisation has not yet been approved."
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
status.HTTP_403_FORBIDDEN: {
|
||||
"description": "Forbidden",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"examples": {
|
||||
"not_root": {"summary": "Not authorised. Must be root user."},
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
async def get_all_services(
|
||||
db: DbSession, org_model: org_model_root_claim_query_dependency
|
||||
):
|
||||
"""
|
||||
Returns the ID and name of all services registered to the hub.
|
||||
"""
|
||||
permission_models = db.query(Service).all()
|
||||
async def get_all_services(db: DbSession, org_model: org_model_root_claim_query_dependency):
|
||||
"""
|
||||
Returns the ID and name of all services registered to the hub.
|
||||
"""
|
||||
permission_models = db.query(Service).all()
|
||||
|
||||
return {"services": permission_models}
|
||||
return {"services": permission_models}
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
summary="Register a new service.",
|
||||
status_code=status.HTTP_200_OK,
|
||||
response_model=ServicePostServiceResponse,
|
||||
responses={
|
||||
status.HTTP_200_OK: {"description": "Successfully registered a new service"},
|
||||
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
|
||||
status.HTTP_409_CONFLICT: {"description": "Service with this name already exists"},
|
||||
},
|
||||
"",
|
||||
summary="Register a new service.",
|
||||
status_code=status.HTTP_200_OK,
|
||||
response_model=ServicePostServiceResponse,
|
||||
responses={
|
||||
status.HTTP_200_OK: {"description": "Successfully registered a new service"},
|
||||
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
|
||||
status.HTTP_409_CONFLICT: {"description": "Service with this name already exists"},
|
||||
},
|
||||
)
|
||||
async def register_service(
|
||||
db: DbSession,
|
||||
su: super_admin_dependency,
|
||||
request_model: ServicePostServiceRequest,
|
||||
db: DbSession,
|
||||
su: super_admin_dependency,
|
||||
request_model: ServicePostServiceRequest,
|
||||
):
|
||||
"""
|
||||
Registers a new service to the hub, generating and returning an API key for it.
|
||||
"""
|
||||
key = generate_api_key()
|
||||
service_model = Service(name=request_model.name, api_key=key)
|
||||
"""
|
||||
Registers a new service to the hub, generating and returning an API key for it.
|
||||
"""
|
||||
key = generate_api_key()
|
||||
service_model = Service(name=request_model.name, api_key=key)
|
||||
|
||||
db.add(service_model)
|
||||
try:
|
||||
db.flush()
|
||||
except IntegrityError as e:
|
||||
if (
|
||||
isinstance(e.orig, UniqueViolation) # Postgres unique violation
|
||||
or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation
|
||||
):
|
||||
raise ConflictException(message="Service with this name already exists")
|
||||
raise
|
||||
response = ServiceWithKeySchema(**service_model.__dict__)
|
||||
db.commit()
|
||||
return {"service": response}
|
||||
db.add(service_model)
|
||||
try:
|
||||
db.flush()
|
||||
except IntegrityError as e:
|
||||
if (
|
||||
isinstance(e.orig, UniqueViolation) # Postgres unique violation
|
||||
or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation
|
||||
):
|
||||
raise ConflictException(message="Service with this name already exists")
|
||||
raise
|
||||
response = ServiceWithKeySchema(**service_model.__dict__)
|
||||
db.commit()
|
||||
return {"service": response}
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/key",
|
||||
summary="Regenerate service API key.",
|
||||
status_code=status.HTTP_200_OK,
|
||||
response_model=ServicePatchKeyResponse,
|
||||
responses={
|
||||
status.HTTP_200_OK: {"description": "Successful update of API key"},
|
||||
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
|
||||
},
|
||||
"/key",
|
||||
summary="Regenerate service API key.",
|
||||
status_code=status.HTTP_200_OK,
|
||||
response_model=ServicePatchKeyResponse,
|
||||
responses={
|
||||
status.HTTP_200_OK: {"description": "Successful update of API key"},
|
||||
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
|
||||
},
|
||||
)
|
||||
async def regenerate_api_key(
|
||||
db: DbSession,
|
||||
su: super_admin_dependency,
|
||||
service_model: service_model_body_dependency,
|
||||
request_model: ServicePatchKeyRequest,
|
||||
db: DbSession,
|
||||
su: super_admin_dependency,
|
||||
service_model: service_model_body_dependency,
|
||||
request_model: ServicePatchKeyRequest,
|
||||
):
|
||||
"""
|
||||
Generates and returns a new API key for the service to access the hub.
|
||||
"""
|
||||
key = generate_api_key()
|
||||
service_model.api_key = key
|
||||
"""
|
||||
Generates and returns a new API key for the service to access the hub.
|
||||
"""
|
||||
key = generate_api_key()
|
||||
service_model.api_key = key
|
||||
|
||||
db.flush()
|
||||
response = ServiceWithKeySchema(**service_model.__dict__)
|
||||
db.commit()
|
||||
return {"service": response}
|
||||
db.flush()
|
||||
response = ServiceWithKeySchema(**service_model.__dict__)
|
||||
db.commit()
|
||||
return {"service": response}
|
||||
|
||||
|
||||
@router.delete(
|
||||
"",
|
||||
summary="Remove a service.",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
responses={
|
||||
status.HTTP_204_NO_CONTENT: {"description": "Successfully removed service from db"},
|
||||
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
|
||||
},
|
||||
"",
|
||||
summary="Remove a service.",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
responses={
|
||||
status.HTTP_204_NO_CONTENT: {"description": "Successfully removed service from db"},
|
||||
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
|
||||
},
|
||||
)
|
||||
async def remove_service(
|
||||
db: DbSession,
|
||||
service_model: service_model_query_dependency,
|
||||
su: super_admin_dependency,
|
||||
db: DbSession,
|
||||
service_model: service_model_query_dependency,
|
||||
su: super_admin_dependency,
|
||||
):
|
||||
"""
|
||||
Removes a service from the hub.
|
||||
"""
|
||||
db.delete(service_model)
|
||||
db.commit()
|
||||
"""
|
||||
Removes a service from the hub.
|
||||
"""
|
||||
db.delete(service_model)
|
||||
db.commit()
|
||||
|
||||
|
||||
@router.post(
|
||||
path="/permissions",
|
||||
summary="Service endpoint for creating its own permissions.",
|
||||
status_code=status.HTTP_200_OK,
|
||||
response_model=ServicePostPermissionsResponse,
|
||||
responses={
|
||||
status.HTTP_401_UNAUTHORIZED: {
|
||||
"description": "API Key missing or invalid | Issue verifying user OIDC claims"
|
||||
},
|
||||
},
|
||||
path="/permissions",
|
||||
summary="Service endpoint for creating its own permissions.",
|
||||
status_code=status.HTTP_200_OK,
|
||||
response_model=ServicePostPermissionsResponse,
|
||||
responses={
|
||||
status.HTTP_401_UNAUTHORIZED: {
|
||||
"description": "API Key missing or invalid | Issue verifying user OIDC claims"
|
||||
},
|
||||
},
|
||||
)
|
||||
async def service_create_new_permissions(
|
||||
db: DbSession,
|
||||
request_model: ServicePostPermissionsRequest,
|
||||
valid_key: service_key_dependency,
|
||||
db: DbSession,
|
||||
request_model: ServicePostPermissionsRequest,
|
||||
valid_key: service_key_dependency,
|
||||
):
|
||||
"""
|
||||
Allows a service to register its own set of permissions.
|
||||
"""
|
||||
service_model = (
|
||||
db.query(Service).filter(Service.name == request_model.rn.service).first()
|
||||
)
|
||||
if service_model is None:
|
||||
raise ServiceNotFoundException()
|
||||
else:
|
||||
service_id = service_model.id
|
||||
response_list = []
|
||||
for new_permission in request_model.permissions:
|
||||
perm_model = (
|
||||
db.query(Perm)
|
||||
.filter(Perm.service_id == service_id)
|
||||
.filter(Perm.resource == new_permission.resource)
|
||||
.filter(Perm.action == new_permission.action)
|
||||
.first()
|
||||
)
|
||||
if perm_model is not None:
|
||||
response_code = 409
|
||||
response = {
|
||||
"id": perm_model.id,
|
||||
"service_name": perm_model.service_name,
|
||||
"resource": perm_model.resource,
|
||||
"action": perm_model.action,
|
||||
}
|
||||
response_list.append((response, response_code))
|
||||
continue
|
||||
"""
|
||||
Allows a service to register its own set of permissions.
|
||||
"""
|
||||
service_model = (
|
||||
db.query(Service).filter(Service.name == request_model.rn.service).first()
|
||||
)
|
||||
if service_model is None:
|
||||
raise ServiceNotFoundException()
|
||||
else:
|
||||
service_id = service_model.id
|
||||
response_list = []
|
||||
for new_permission in request_model.permissions:
|
||||
perm_model = (
|
||||
db.query(Perm)
|
||||
.filter(Perm.service_id == service_id)
|
||||
.filter(Perm.resource == new_permission.resource)
|
||||
.filter(Perm.action == new_permission.action)
|
||||
.first()
|
||||
)
|
||||
if perm_model is not None:
|
||||
response_code = 409
|
||||
response = {
|
||||
"id": perm_model.id,
|
||||
"service_name": perm_model.service_name,
|
||||
"resource": perm_model.resource,
|
||||
"action": perm_model.action,
|
||||
}
|
||||
response_list.append((response, response_code))
|
||||
continue
|
||||
|
||||
new_perm_model = Perm(**new_permission.__dict__)
|
||||
new_perm_model.service_id = service_id
|
||||
db.add(new_perm_model)
|
||||
db.flush()
|
||||
response_code = 201
|
||||
response = {
|
||||
"id": new_perm_model.id,
|
||||
"service_name": new_perm_model.service_name,
|
||||
"resource": new_perm_model.resource,
|
||||
"action": new_perm_model.action,
|
||||
}
|
||||
response_list.append((response, response_code))
|
||||
new_perm_model = Perm(**new_permission.__dict__)
|
||||
new_perm_model.service_id = service_id
|
||||
db.add(new_perm_model)
|
||||
db.flush()
|
||||
response_code = 201
|
||||
response = {
|
||||
"id": new_perm_model.id,
|
||||
"service_name": new_perm_model.service_name,
|
||||
"resource": new_perm_model.resource,
|
||||
"action": new_perm_model.action,
|
||||
}
|
||||
response_list.append((response, response_code))
|
||||
|
||||
db.commit()
|
||||
return {"permissions": response_list}
|
||||
db.commit()
|
||||
return {"permissions": response_list}
|
||||
|
|
|
|||
|
|
@ -10,10 +10,10 @@ from typing import Generic, TypeVar
|
|||
from pydantic import Field, ConfigDict
|
||||
|
||||
from src.schemas import (
|
||||
CustomBaseModel,
|
||||
ServiceIDMixin,
|
||||
ServiceSummary,
|
||||
ServiceNameMixin,
|
||||
CustomBaseModel,
|
||||
ServiceIDMixin,
|
||||
ServiceSummary,
|
||||
ServiceNameMixin,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -21,51 +21,51 @@ T = TypeVar("T", bound=ServiceNameMixin)
|
|||
|
||||
|
||||
class HasServiceName(CustomBaseModel, Generic[T]):
|
||||
rn: T
|
||||
rn: T
|
||||
|
||||
|
||||
class PermissionResponseSchema(CustomBaseModel):
|
||||
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
||||
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
||||
|
||||
id: int
|
||||
service_name: str
|
||||
resource: str
|
||||
action: str
|
||||
id: int
|
||||
service_name: str
|
||||
resource: str
|
||||
action: str
|
||||
|
||||
|
||||
class PermissionRequestSchema(CustomBaseModel):
|
||||
resource: str
|
||||
action: str
|
||||
resource: str
|
||||
action: str
|
||||
|
||||
|
||||
class ServiceWithKeySchema(ServiceSummary):
|
||||
api_key: str
|
||||
api_key: str
|
||||
|
||||
|
||||
class ServiceGetServiceResponse(CustomBaseModel):
|
||||
services: list[ServiceSummary]
|
||||
services: list[ServiceSummary]
|
||||
|
||||
|
||||
class ServicePostServiceRequest(CustomBaseModel):
|
||||
name: str = Field(min_length=3)
|
||||
name: str = Field(min_length=3)
|
||||
|
||||
|
||||
class ServicePostServiceResponse(CustomBaseModel):
|
||||
service: ServiceWithKeySchema
|
||||
service: ServiceWithKeySchema
|
||||
|
||||
|
||||
class ServicePatchKeyRequest(ServiceIDMixin):
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
class ServicePatchKeyResponse(CustomBaseModel):
|
||||
service: ServiceWithKeySchema
|
||||
service: ServiceWithKeySchema
|
||||
|
||||
|
||||
class ServicePostPermissionsRequest(CustomBaseModel):
|
||||
rn: ServiceNameMixin
|
||||
permissions: list[PermissionRequestSchema]
|
||||
rn: ServiceNameMixin
|
||||
permissions: list[PermissionRequestSchema]
|
||||
|
||||
|
||||
class ServicePostPermissionsResponse(CustomBaseModel):
|
||||
permissions: list[tuple[PermissionResponseSchema, int]]
|
||||
permissions: list[tuple[PermissionResponseSchema, int]]
|
||||
|
|
|
|||
|
|
@ -9,4 +9,4 @@ import uuid
|
|||
|
||||
|
||||
def generate_api_key() -> str:
|
||||
return str(uuid.uuid4())
|
||||
return str(uuid.uuid4())
|
||||
|
|
|
|||
|
|
@ -19,37 +19,37 @@ from src.schemas import UserIDMixin
|
|||
|
||||
|
||||
async def get_user_model_claims(claims: claims_dependency, db: DbSession):
|
||||
user_id = claims.get("db_id", None)
|
||||
if user_id is None:
|
||||
raise UserNotFoundException()
|
||||
user_id = claims.get("db_id", None)
|
||||
if user_id is None:
|
||||
raise UserNotFoundException()
|
||||
|
||||
user_model = db.get(User, user_id)
|
||||
if user_model is None:
|
||||
raise UserNotFoundException(user_id=user_id)
|
||||
user_model = db.get(User, user_id)
|
||||
if user_model is None:
|
||||
raise UserNotFoundException(user_id=user_id)
|
||||
|
||||
return user_model
|
||||
return user_model
|
||||
|
||||
|
||||
user_model_claims_dependency = Annotated[User, Depends(get_user_model_claims)]
|
||||
|
||||
|
||||
async def get_user_model_query(db: DbSession, user_id: Annotated[int, Query(gt=0)]):
|
||||
user_model = db.get(User, user_id)
|
||||
if user_model is None:
|
||||
raise UserNotFoundException(user_id=user_id)
|
||||
user_model = db.get(User, user_id)
|
||||
if user_model is None:
|
||||
raise UserNotFoundException(user_id=user_id)
|
||||
|
||||
return user_model
|
||||
return user_model
|
||||
|
||||
|
||||
user_model_query_dependency = Annotated[User, Depends(get_user_model_query)]
|
||||
|
||||
|
||||
async def get_user_model_body(db: DbSession, request_model: UserIDMixin):
|
||||
user_model = db.get(User, request_model.user_id)
|
||||
if user_model is None:
|
||||
raise UserNotFoundException(user_id=request_model.user_id)
|
||||
user_model = db.get(User, request_model.user_id)
|
||||
if user_model is None:
|
||||
raise UserNotFoundException(user_id=request_model.user_id)
|
||||
|
||||
return user_model
|
||||
return user_model
|
||||
|
||||
|
||||
user_model_body_dependency = Annotated[User, Depends(get_user_model_body)]
|
||||
|
|
|
|||
|
|
@ -11,13 +11,13 @@ from fastapi import HTTPException, status
|
|||
|
||||
|
||||
class UserNotFoundException(HTTPException):
|
||||
def __init__(self, user_id: Optional[int] = None) -> None:
|
||||
detail = (
|
||||
"User not found"
|
||||
if user_id is None
|
||||
else f"User with ID '{user_id}' was not found."
|
||||
)
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=detail,
|
||||
)
|
||||
def __init__(self, user_id: Optional[int] = None) -> None:
|
||||
detail = (
|
||||
"User not found"
|
||||
if user_id is None
|
||||
else f"User with ID '{user_id}' was not found."
|
||||
)
|
||||
super().__init__(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=detail,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -20,26 +20,26 @@ from src.models import CustomBase
|
|||
|
||||
|
||||
class User(CustomBase, IdMixin, ActivatedMixin, TimestampMixin, DeletedTimestampMixin):
|
||||
__tablename__ = "user"
|
||||
__tablename__ = "user"
|
||||
|
||||
email: Mapped[str]
|
||||
first_name: Mapped[str]
|
||||
last_name: Mapped[str]
|
||||
oidc_id: Mapped[str] = mapped_column(index=True, unique=True)
|
||||
email: Mapped[str]
|
||||
first_name: Mapped[str]
|
||||
last_name: Mapped[str]
|
||||
oidc_id: Mapped[str] = mapped_column(index=True, unique=True)
|
||||
|
||||
organisation_rel = relationship(
|
||||
"Organisation", secondary="orgusers", back_populates="user_rel"
|
||||
)
|
||||
organisation_rel = relationship(
|
||||
"Organisation", secondary="orgusers", back_populates="user_rel"
|
||||
)
|
||||
|
||||
group_rel = relationship("Group", secondary="user_groups", back_populates="user_rel")
|
||||
group_rel = relationship("Group", secondary="user_groups", back_populates="user_rel")
|
||||
|
||||
@property
|
||||
def organisations(self):
|
||||
return [{"name": org.name, "id": org.id} for org in self.organisation_rel]
|
||||
@property
|
||||
def organisations(self):
|
||||
return [{"name": org.name, "id": org.id} for org in self.organisation_rel]
|
||||
|
||||
@property
|
||||
def groups(self):
|
||||
result = defaultdict(list)
|
||||
for group in self.group_rel:
|
||||
result[group.org_rel.name].append({"name": group.name, "id": group.id})
|
||||
return dict(result)
|
||||
@property
|
||||
def groups(self):
|
||||
result = defaultdict(list)
|
||||
for group in self.group_rel:
|
||||
result[group.org_rel.name].append({"name": group.name, "id": group.id})
|
||||
return dict(result)
|
||||
|
|
|
|||
|
|
@ -13,205 +13,205 @@ from fastapi import APIRouter, status, BackgroundTasks
|
|||
from src.iam.models import Group
|
||||
from src.organisation.exceptions import OrgNotFoundException
|
||||
from src.user.schemas import (
|
||||
UserResponse,
|
||||
OIDCClaims,
|
||||
UserPostInvitationRequest,
|
||||
UserPostInvitationAcceptRequest,
|
||||
UserGetSelfOrgsResponse,
|
||||
UserPostInvitationResponse,
|
||||
UserPostInvitationAcceptResponse,
|
||||
UserResponse,
|
||||
OIDCClaims,
|
||||
UserPostInvitationRequest,
|
||||
UserPostInvitationAcceptRequest,
|
||||
UserGetSelfOrgsResponse,
|
||||
UserPostInvitationResponse,
|
||||
UserPostInvitationAcceptResponse,
|
||||
)
|
||||
from src.user.dependencies import (
|
||||
user_model_claims_dependency,
|
||||
user_model_query_dependency,
|
||||
user_model_claims_dependency,
|
||||
user_model_query_dependency,
|
||||
)
|
||||
from src.user.service import send_invitation
|
||||
from src.organisation.models import Organisation as Org
|
||||
|
||||
from src.auth.dependencies import (
|
||||
super_admin_dependency,
|
||||
org_model_root_claim_body_dependency,
|
||||
super_admin_dependency,
|
||||
org_model_root_claim_body_dependency,
|
||||
)
|
||||
from src.auth.service import claims_dependency
|
||||
from src.database import DbSession
|
||||
from src.utils import verify_email_token
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/user",
|
||||
tags=["User"],
|
||||
prefix="/user",
|
||||
tags=["User"],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/self/claims",
|
||||
summary="Get current user OIDC claims.",
|
||||
response_model=OIDCClaims,
|
||||
status_code=status.HTTP_200_OK,
|
||||
responses={
|
||||
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
||||
},
|
||||
"/self/claims",
|
||||
summary="Get current user OIDC claims.",
|
||||
response_model=OIDCClaims,
|
||||
status_code=status.HTTP_200_OK,
|
||||
responses={
|
||||
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
||||
},
|
||||
)
|
||||
async def current_user_claims(user: claims_dependency):
|
||||
"""
|
||||
Returns the full OIDC claims associated with the currently logged-in user.
|
||||
"""
|
||||
user["allowed_origins"] = user.get("allowed-origins", [])
|
||||
return user
|
||||
"""
|
||||
Returns the full OIDC claims associated with the currently logged-in user.
|
||||
"""
|
||||
user["allowed_origins"] = user.get("allowed-origins", [])
|
||||
return user
|
||||
|
||||
|
||||
@router.get(
|
||||
"/self/db",
|
||||
summary="Get current user hub details.",
|
||||
response_model=UserResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
responses={
|
||||
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
|
||||
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
||||
},
|
||||
"/self/db",
|
||||
summary="Get current user hub details.",
|
||||
response_model=UserResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
responses={
|
||||
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
|
||||
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
||||
},
|
||||
)
|
||||
async def current_user(user_model: user_model_claims_dependency):
|
||||
"""
|
||||
Returns the database details associated with the currently logged-in user.
|
||||
"""
|
||||
return user_model
|
||||
"""
|
||||
Returns the database details associated with the currently logged-in user.
|
||||
"""
|
||||
return user_model
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
summary="Get user hub details by ID.",
|
||||
response_model=UserResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
responses={
|
||||
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
|
||||
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
||||
},
|
||||
"",
|
||||
summary="Get user hub details by ID.",
|
||||
response_model=UserResponse,
|
||||
status_code=status.HTTP_200_OK,
|
||||
responses={
|
||||
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
|
||||
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
||||
},
|
||||
)
|
||||
async def get_user_by_id(
|
||||
user_model: user_model_query_dependency, su: super_admin_dependency
|
||||
user_model: user_model_query_dependency, su: super_admin_dependency
|
||||
):
|
||||
"""
|
||||
Returns the database details associated with the provided user ID.
|
||||
"""
|
||||
return user_model
|
||||
"""
|
||||
Returns the database details associated with the provided user ID.
|
||||
"""
|
||||
return user_model
|
||||
|
||||
|
||||
@router.delete(
|
||||
"",
|
||||
summary="Delete user from hub by ID.",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
responses={
|
||||
status.HTTP_204_NO_CONTENT: {"description": "User deleted"},
|
||||
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
|
||||
},
|
||||
"",
|
||||
summary="Delete user from hub by ID.",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
responses={
|
||||
status.HTTP_204_NO_CONTENT: {"description": "User deleted"},
|
||||
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
|
||||
},
|
||||
)
|
||||
async def delete_user_by_id(
|
||||
db: DbSession,
|
||||
user_model: user_model_query_dependency,
|
||||
su: super_admin_dependency,
|
||||
db: DbSession,
|
||||
user_model: user_model_query_dependency,
|
||||
su: super_admin_dependency,
|
||||
):
|
||||
"""
|
||||
Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login.
|
||||
"""
|
||||
db.delete(user_model)
|
||||
db.commit()
|
||||
"""
|
||||
Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login.
|
||||
"""
|
||||
db.delete(user_model)
|
||||
db.commit()
|
||||
|
||||
|
||||
@router.get(
|
||||
"/self/orgs",
|
||||
summary="Get all orgs the current user is a member of",
|
||||
status_code=status.HTTP_200_OK,
|
||||
response_model=UserGetSelfOrgsResponse,
|
||||
responses={},
|
||||
"/self/orgs",
|
||||
summary="Get all orgs the current user is a member of",
|
||||
status_code=status.HTTP_200_OK,
|
||||
response_model=UserGetSelfOrgsResponse,
|
||||
responses={},
|
||||
)
|
||||
async def get_user_orgs(user_model: user_model_claims_dependency):
|
||||
user_orgs = user_model.organisation_rel
|
||||
response = []
|
||||
for org in user_orgs:
|
||||
response.append(
|
||||
{
|
||||
"organisation_id": org.id,
|
||||
"name": org.name,
|
||||
"status": org.status,
|
||||
"intake_questionnaire": org.intake_questionnaire,
|
||||
"root_user_email": org.root_user_email,
|
||||
"billing_contact": {
|
||||
"id": org.billing_contact_id,
|
||||
"email": org.billing_contact_rel.email,
|
||||
},
|
||||
"owner_contact": {
|
||||
"id": org.owner_contact_id,
|
||||
"email": org.owner_contact_rel.email,
|
||||
},
|
||||
"security_contact": {
|
||||
"id": org.security_contact_id,
|
||||
"email": org.security_contact_rel.email,
|
||||
},
|
||||
}
|
||||
)
|
||||
user_orgs = user_model.organisation_rel
|
||||
response = []
|
||||
for org in user_orgs:
|
||||
response.append(
|
||||
{
|
||||
"organisation_id": org.id,
|
||||
"name": org.name,
|
||||
"status": org.status,
|
||||
"intake_questionnaire": org.intake_questionnaire,
|
||||
"root_user_email": org.root_user_email,
|
||||
"billing_contact": {
|
||||
"id": org.billing_contact_id,
|
||||
"email": org.billing_contact_rel.email,
|
||||
},
|
||||
"owner_contact": {
|
||||
"id": org.owner_contact_id,
|
||||
"email": org.owner_contact_rel.email,
|
||||
},
|
||||
"security_contact": {
|
||||
"id": org.security_contact_id,
|
||||
"email": org.security_contact_rel.email,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return {"organisations": response}
|
||||
return {"organisations": response}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invitation",
|
||||
summary="Send an email invitation for a user to join an org",
|
||||
status_code=status.HTTP_200_OK,
|
||||
response_model=UserPostInvitationResponse,
|
||||
"/invitation",
|
||||
summary="Send an email invitation for a user to join an org",
|
||||
status_code=status.HTTP_200_OK,
|
||||
response_model=UserPostInvitationResponse,
|
||||
)
|
||||
async def invitation(
|
||||
background_tasks: BackgroundTasks,
|
||||
org_model: org_model_root_claim_body_dependency,
|
||||
request_model: UserPostInvitationRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
org_model: org_model_root_claim_body_dependency,
|
||||
request_model: UserPostInvitationRequest,
|
||||
):
|
||||
org_id = org_model.id
|
||||
org_name = org_model.name
|
||||
user_email = request_model.user_email
|
||||
org_id = org_model.id
|
||||
org_name = org_model.name
|
||||
user_email = request_model.user_email
|
||||
|
||||
background_tasks.add_task(
|
||||
send_invitation, org_id=org_id, org_name=org_name, user_email=user_email
|
||||
)
|
||||
background_tasks.add_task(
|
||||
send_invitation, org_id=org_id, org_name=org_name, user_email=user_email
|
||||
)
|
||||
|
||||
response = {
|
||||
"organisation": org_model,
|
||||
"invited_email": user_email,
|
||||
}
|
||||
response = {
|
||||
"organisation": org_model,
|
||||
"invited_email": user_email,
|
||||
}
|
||||
|
||||
return response
|
||||
return response
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invitation/accept",
|
||||
summary="Accept email invitation to join an org",
|
||||
status_code=status.HTTP_200_OK,
|
||||
response_model=UserPostInvitationAcceptResponse,
|
||||
"/invitation/accept",
|
||||
summary="Accept email invitation to join an org",
|
||||
status_code=status.HTTP_200_OK,
|
||||
response_model=UserPostInvitationAcceptResponse,
|
||||
)
|
||||
async def accept_invitation(
|
||||
db: DbSession,
|
||||
user_model: user_model_claims_dependency,
|
||||
request_model: UserPostInvitationAcceptRequest,
|
||||
db: DbSession,
|
||||
user_model: user_model_claims_dependency,
|
||||
request_model: UserPostInvitationAcceptRequest,
|
||||
):
|
||||
email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model)
|
||||
email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model)
|
||||
|
||||
org_model = db.get(Org, email_claims["org_id"])
|
||||
if org_model is None:
|
||||
raise OrgNotFoundException()
|
||||
org_model = db.get(Org, email_claims["org_id"])
|
||||
if org_model is None:
|
||||
raise OrgNotFoundException()
|
||||
|
||||
org_model.user_rel.append(user_model)
|
||||
db.flush()
|
||||
group_model = (
|
||||
db.query(Group)
|
||||
.filter(Group.org_id == org_model.id)
|
||||
.filter(Group.name == "Default Users")
|
||||
.first()
|
||||
)
|
||||
if group_model is not None:
|
||||
user_model.group_rel.append(group_model)
|
||||
org_model.user_rel.append(user_model)
|
||||
db.flush()
|
||||
group_model = (
|
||||
db.query(Group)
|
||||
.filter(Group.org_id == org_model.id)
|
||||
.filter(Group.name == "Default Users")
|
||||
.first()
|
||||
)
|
||||
if group_model is not None:
|
||||
user_model.group_rel.append(group_model)
|
||||
|
||||
response = {
|
||||
"organisation": org_model,
|
||||
"user": user_model,
|
||||
}
|
||||
response = {
|
||||
"organisation": org_model,
|
||||
"user": user_model,
|
||||
}
|
||||
|
||||
db.commit()
|
||||
db.commit()
|
||||
|
||||
return response
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -10,63 +10,63 @@ from src.schemas import CustomBaseModel, OrgIDMixin, OrgSummary, UserSummary
|
|||
|
||||
|
||||
class OIDCClaims(CustomBaseModel):
|
||||
exp: int
|
||||
iat: int
|
||||
auth_time: int
|
||||
jti: str
|
||||
iss: str
|
||||
aud: str
|
||||
sub: str
|
||||
typ: str
|
||||
azp: str
|
||||
sid: str
|
||||
acr: str
|
||||
allowed_origins: list[str]
|
||||
realm_access: dict[str, list[str]]
|
||||
resource_access: dict[str, dict[str, list[str]]]
|
||||
scope: str
|
||||
email_verified: bool
|
||||
name: str
|
||||
preferred_username: str
|
||||
given_name: str
|
||||
family_name: str
|
||||
email: str
|
||||
db_id: int
|
||||
exp: int
|
||||
iat: int
|
||||
auth_time: int
|
||||
jti: str
|
||||
iss: str
|
||||
aud: str
|
||||
sub: str
|
||||
typ: str
|
||||
azp: str
|
||||
sid: str
|
||||
acr: str
|
||||
allowed_origins: list[str]
|
||||
realm_access: dict[str, list[str]]
|
||||
resource_access: dict[str, dict[str, list[str]]]
|
||||
scope: str
|
||||
email_verified: bool
|
||||
name: str
|
||||
preferred_username: str
|
||||
given_name: str
|
||||
family_name: str
|
||||
email: str
|
||||
db_id: int
|
||||
|
||||
|
||||
class OIDCUser(CustomBaseModel):
|
||||
first_name: str
|
||||
last_name: str
|
||||
email: str
|
||||
oidc_id: str
|
||||
first_name: str
|
||||
last_name: str
|
||||
email: str
|
||||
oidc_id: str
|
||||
|
||||
|
||||
class UserResponse(CustomBaseModel):
|
||||
id: int
|
||||
first_name: str
|
||||
last_name: str
|
||||
email: str
|
||||
organisations: list[Optional[dict[str, str | int]]]
|
||||
groups: Optional[dict[str, list[dict[str, str | int]]]] = None
|
||||
id: int
|
||||
first_name: str
|
||||
last_name: str
|
||||
email: str
|
||||
organisations: list[Optional[dict[str, str | int]]]
|
||||
groups: Optional[dict[str, list[dict[str, str | int]]]] = None
|
||||
|
||||
|
||||
class UserPostInvitationRequest(OrgIDMixin):
|
||||
user_email: EmailStr
|
||||
user_email: EmailStr
|
||||
|
||||
|
||||
class UserPostInvitationAcceptRequest(CustomBaseModel):
|
||||
jwt: str
|
||||
jwt: str
|
||||
|
||||
|
||||
class UserGetSelfOrgsResponse(CustomBaseModel):
|
||||
organisations: list[OrgSchema]
|
||||
organisations: list[OrgSchema]
|
||||
|
||||
|
||||
class UserPostInvitationResponse(CustomBaseModel):
|
||||
organisation: OrgSummary
|
||||
invited_email: EmailStr
|
||||
organisation: OrgSummary
|
||||
invited_email: EmailStr
|
||||
|
||||
|
||||
class UserPostInvitationAcceptResponse(CustomBaseModel):
|
||||
organisation: OrgSummary
|
||||
user: UserSummary
|
||||
organisation: OrgSummary
|
||||
user: UserSummary
|
||||
|
|
|
|||
|
|
@ -16,49 +16,49 @@ from src.user.models import User
|
|||
|
||||
|
||||
async def add_user(db: Session, user_claims: dict[str, Any]) -> int:
|
||||
try:
|
||||
valid_user = OIDCUser(
|
||||
first_name=user_claims["given_name"],
|
||||
last_name=user_claims["family_name"],
|
||||
email=user_claims["email"],
|
||||
oidc_id=user_claims["sub"],
|
||||
)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
raise UnprocessableContentException("Invalid or missing OIDC data")
|
||||
try:
|
||||
valid_user = OIDCUser(
|
||||
first_name=user_claims["given_name"],
|
||||
last_name=user_claims["family_name"],
|
||||
email=user_claims["email"],
|
||||
oidc_id=user_claims["sub"],
|
||||
)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
raise UnprocessableContentException("Invalid or missing OIDC data")
|
||||
|
||||
db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first()
|
||||
db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first()
|
||||
|
||||
if not db_user:
|
||||
user_model = User(**valid_user.model_dump())
|
||||
db.add(user_model)
|
||||
user_id = user_model.id
|
||||
db.commit()
|
||||
return user_id
|
||||
if not db_user:
|
||||
user_model = User(**valid_user.model_dump())
|
||||
db.add(user_model)
|
||||
user_id = user_model.id
|
||||
db.commit()
|
||||
return user_id
|
||||
|
||||
user_id = db_user.id
|
||||
db_user.first_name = valid_user.first_name
|
||||
db_user.last_name = valid_user.last_name
|
||||
db.commit()
|
||||
return user_id
|
||||
user_id = db_user.id
|
||||
db_user.first_name = valid_user.first_name
|
||||
db_user.last_name = valid_user.last_name
|
||||
db.commit()
|
||||
return user_id
|
||||
|
||||
|
||||
async def send_invitation(user_email: str, org_name: str, org_id: int):
|
||||
expiry_delta = timedelta(hours=24)
|
||||
expiry = datetime.now(timezone.utc) + expiry_delta
|
||||
claims = {
|
||||
"email": user_email,
|
||||
"org_id": org_id,
|
||||
"exp": expiry,
|
||||
"type": "org_invite",
|
||||
}
|
||||
expiry_delta = timedelta(hours=24)
|
||||
expiry = datetime.now(timezone.utc) + expiry_delta
|
||||
claims = {
|
||||
"email": user_email,
|
||||
"org_id": org_id,
|
||||
"exp": expiry,
|
||||
"type": "org_invite",
|
||||
}
|
||||
|
||||
token = await generate_jwt(claims)
|
||||
subject = f"You have been invited to join {org_name}"
|
||||
body = f"You have been invited to join {org_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
|
||||
token = await generate_jwt(claims)
|
||||
subject = f"You have been invited to join {org_name}"
|
||||
body = f"You have been invited to join {org_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
|
||||
|
||||
await send_email(
|
||||
recipient=user_email,
|
||||
subject=subject,
|
||||
body=body,
|
||||
)
|
||||
await send_email(
|
||||
recipient=user_email,
|
||||
subject=subject,
|
||||
body=body,
|
||||
)
|
||||
|
|
|
|||
66
src/utils.py
66
src/utils.py
|
|
@ -11,52 +11,56 @@ KEY = jwk.import_key(settings.SECRET_KEY.get_secret_value(), "oct")
|
|||
|
||||
|
||||
async def generate_jwt(claims):
|
||||
jwt_token = jwt.encode(header={"alg": "HS256"}, key=KEY, claims=claims)
|
||||
jwt_token = jwt.encode(header={"alg": "HS256"}, key=KEY, claims=claims)
|
||||
|
||||
return jwt_token
|
||||
return jwt_token
|
||||
|
||||
|
||||
async def decode_jwt(encoded):
|
||||
try:
|
||||
token = jwt.decode(encoded, key=KEY)
|
||||
return token.claims
|
||||
except errors.DecodeError:
|
||||
raise UnauthorizedException("Invalid JWS")
|
||||
try:
|
||||
token = jwt.decode(encoded, key=KEY)
|
||||
return token.claims
|
||||
except errors.DecodeError:
|
||||
raise UnauthorizedException("Invalid JWS")
|
||||
|
||||
|
||||
async def verify_email_token(user_model, token):
|
||||
email_claims = await decode_jwt(token)
|
||||
email_claims = await decode_jwt(token)
|
||||
|
||||
claimed_email = email_claims["email"]
|
||||
claimed_email = email_claims["email"]
|
||||
|
||||
expiry = datetime.fromtimestamp(email_claims["exp"], timezone.utc)
|
||||
expiry = datetime.fromtimestamp(email_claims["exp"], timezone.utc)
|
||||
|
||||
if expiry < datetime.now(timezone.utc):
|
||||
raise UnauthorizedException("Invitation expired.")
|
||||
if expiry < datetime.now(timezone.utc):
|
||||
raise UnauthorizedException("Invitation expired.")
|
||||
|
||||
if user_model.email != claimed_email:
|
||||
raise ForbiddenException("The logged in user and email do not match.")
|
||||
if user_model.email != claimed_email:
|
||||
raise ForbiddenException("The logged in user and email do not match.")
|
||||
|
||||
return email_claims
|
||||
return email_claims
|
||||
|
||||
|
||||
async def send_email(recipient: str, subject: str, body: str):
|
||||
if settings.ENVIRONMENT.is_testing:
|
||||
return
|
||||
if settings.ENVIRONMENT.is_testing:
|
||||
return
|
||||
|
||||
lettermint = Lettermint(api_token=settings.LETTERMINT_API_TOKEN.get_secret_value())
|
||||
lettermint = Lettermint(api_token=settings.LETTERMINT_API_TOKEN.get_secret_value())
|
||||
|
||||
if settings.ENVIRONMENT == "local":
|
||||
recipient = "ok@testing.lettermint.co"
|
||||
if settings.ENVIRONMENT == "local":
|
||||
recipient = "ok@testing.lettermint.co"
|
||||
|
||||
try:
|
||||
response = (
|
||||
lettermint.email.from_("noreply@sr2.uk")
|
||||
.to(recipient)
|
||||
.subject(subject)
|
||||
.text(body)
|
||||
.send()
|
||||
)
|
||||
logging.info("Email sent to {} with subject {} (Status: {})".format(recipient, subject, response.status_code))
|
||||
except ValidationError as e:
|
||||
logging.exception(e)
|
||||
try:
|
||||
response = (
|
||||
lettermint.email.from_("noreply@sr2.uk")
|
||||
.to(recipient)
|
||||
.subject(subject)
|
||||
.text(body)
|
||||
.send()
|
||||
)
|
||||
logging.info(
|
||||
"Email sent to {} with subject {} (Status: {})".format(
|
||||
recipient, subject, response.status_code
|
||||
)
|
||||
)
|
||||
except ValidationError as e:
|
||||
logging.exception(e)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue