1
0
Fork 0
forked from sr2/cloud-api

Compare commits

..

34 commits

Author SHA1 Message Date
7dad2e920e tests: get_testable_routes finds auth level
Checks all dependencies used on each endpoint and determines the highest level of auth applied to each endpoint.

API Key>SU>Root>User>None
2026-06-22 16:45:50 +01:00
bee0dcd4fe feat: soft deleted users access blocked 2026-06-22 16:12:03 +01:00
a9e059bf0a feat: user soft delete 2026-06-22 15:43:38 +01:00
irl
5b98be9787 ci: define context for docker 2026-06-22 15:30:29 +01:00
be46e43042 fix(db): user active default true 2026-06-22 15:28:46 +01:00
irl
8ab0390977 ci: fix branch name tag again 2026-06-22 15:26:39 +01:00
irl
cc4ae42646 ci: adds frontend ref 2026-06-22 15:24:33 +01:00
irl
44e1d4986f ci: relative repo path 2026-06-22 15:22:42 +01:00
irl
20615f438a ci: fix branch name tag 2026-06-22 15:21:31 +01:00
irl
a481be8352 ci: check out the frontend repo 2026-06-22 15:20:14 +01:00
irl
e7bd455b2d ci: run the build step somewhere 2026-06-22 15:18:28 +01:00
4b3ab92d2a fix: fastapi 0.137 router.route changes 2026-06-22 15:15:42 +01:00
irl
ee47186c5a fix(db): generator types 2026-06-22 15:12:34 +01:00
fab228bf8f minor: ruff format
Tabs -> spaces
2026-06-22 15:04:11 +01:00
b2921b73b8 fix: conftest match db changes 2026-06-22 15:02:39 +01:00
1a851859d0 fix: logging import for email 2026-06-22 15:02:04 +01:00
a343b76f63 fix: invalid toml syntax 2026-06-22 15:01:36 +01:00
irl
84ba3b6bee feat(db): db tuning options and consistency 2026-06-22 14:50:05 +01:00
40918fd8b8 feat: delete org soft deletes 2026-06-22 14:50:05 +01:00
irl
d395b01997 fix: only serve frontend if present in prod 2026-06-22 14:42:13 +01:00
irl
1384ee7bd6 feat: adds empty static directory for frontend 2026-06-22 14:40:02 +01:00
irl
df8ab32cb1 ci: build and publish OCI image 2026-06-22 14:38:23 +01:00
f41f76bcf8 Merge pull request 'feat(utils): use logging around email send' (#31) from irl/cloud-api:maillog into main
Reviewed-on: sr2/cloud-api#31
2026-06-22 13:37:15 +00:00
d07230b3b0 Merge pull request 'fix(user): simplify add_user' (#28) from irl/cloud-api:add_user into main
Reviewed-on: sr2/cloud-api#28
2026-06-22 13:34:36 +00:00
irl
9e1d6026b5 feat: adds Containerfile with frontend serving 2026-06-22 14:24:56 +01:00
c28b4dc37b feat: applied model mixins
IdMixin used on every table with an ID index (no changes needed to db)

Timestamp and Deleted mixins applied to org and user tables.

ActivatedMixin added to users.
2026-06-22 13:46:11 +01:00
7e1ab6c6ee feat: db model mixins 2026-06-22 13:46:11 +01:00
irl
0baa50d10f misc: add frontend dir to .gitignore 2026-06-22 13:30:53 +01:00
irl
53b42b24dd feat(utils): use logging around email send 2026-06-22 13:26:47 +01:00
irl
fe8f627fa5 ci: reduce min age for renovate to 7 days 2026-06-22 12:02:29 +00:00
c2777db2e3 Add renovate.json 2026-06-22 12:02:02 +00:00
irl
a9e539ef74 fix(user): simplify add_user 2026-06-22 12:23:38 +01:00
02ddf9a3ed fix: skip sending email process while running tests
Removes the need for lettermint api key in CI.
2026-06-22 12:06:43 +01:00
63e7d48c07 ci: remove non-ty checks from ty job 2026-06-22 12:04:39 +01:00
65 changed files with 3945 additions and 3726 deletions

View file

@ -34,8 +34,6 @@ jobs:
- run: uv python install # Gets Python version from pyproject.toml - run: uv python install # Gets Python version from pyproject.toml
- run: uv sync --dev - run: uv sync --dev
- run: uv run ty check - run: uv run ty check
- run: uv run ruff format
- run: uv run pytest test
env: env:
ENVIRONMENT: testing ENVIRONMENT: testing
@ -54,3 +52,35 @@ jobs:
- run: uv run pytest test - run: uv run pytest test
env: env:
ENVIRONMENT: testing ENVIRONMENT: testing
build:
needs: [ ruff, ty, tests ]
if: ${{ always() && needs.ruff.result == 'success' && needs.ty.result == 'success' && needs.tests.result == 'success' }}
runs-on: docker
container:
image: ghcr.io/catthehacker/ubuntu:act-latest
options: -v /dind/docker.sock:/var/run/docker.sock
steps:
- name: Checkout the repo
uses: actions/checkout@v4
- name: Checkout the frontend
uses: actions/checkout@v4
with:
repository: sr2/cloud-portal.git
path: frontend
ref: main
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to the registry
uses: docker/login-action@v3
with:
registry: guardianproject.dev
username: irl
password: ${{ secrets.PACKAGE_TOKEN }}
- name: Build and push
uses: docker/build-push-action@v6
with:
file: /workspace/sr2/cloud-api/Containerfile
context: /workspace/sr2/cloud-api/
push: true
tags: guardianproject.dev/${{ github.repository }}:${{ github.ref_name }}

4
.gitignore vendored
View file

@ -206,5 +206,7 @@ marimo/_static/
marimo/_lsp/ marimo/_lsp/
__marimo__/ __marimo__/
endpoints.txt endpoints.txt
# React Frontend
/frontend/

View file

@ -1 +1 @@
3.14 3.12

42
Containerfile Normal file
View file

@ -0,0 +1,42 @@
FROM node:22-slim AS react-builder
WORKDIR /app
COPY frontend/ /app/
RUN --mount=type=cache,target=/root/.npm npm ci
RUN npm run build # Outputs to /app/dist
FROM ghcr.io/astral-sh/uv:python3.12-trixie-slim AS python-builder
ENV UV_PYTHON_DOWNLOADS=0
WORKDIR /app
# Install dependencies first (layer caching)
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=uv.lock,target=uv.lock \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
uv sync --locked --no-install-project --no-editable
# Copy project source and install the project itself
COPY ./ /app/
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --locked --no-editable
FROM python:3.12-slim-trixie
WORKDIR /app
COPY alembic /app/alembic
COPY alembic.ini /app
COPY src /app/src
COPY --from=python-builder /app/.venv /app/.venv
COPY --from=react-builder /app/dist /app/static
# Ensure venv is on PATH
ENV PATH="/app/.venv/bin:$PATH" \
UV_PYTHON_DOWNLOADS=0
EXPOSE 8000
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8000"]

View file

@ -0,0 +1,32 @@
"""fix user activated default
Revision ID: ae433e1c3b20
Revises: 661202797ecd
Create Date: 2026-06-22 15:26:57.805129
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'ae433e1c3b20'
down_revision: Union[str, Sequence[str], None] = '661202797ecd'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column('user', 'active', server_default=sa.true())
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column('user', 'active', server_default=sa.false())
# ### end Alembic commands ###

View file

@ -0,0 +1,44 @@
"""model mixins
Revision ID: 661202797ecd
Revises: 869d48618a1c
Create Date: 2026-06-22 13:29:39.689067
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '661202797ecd'
down_revision: Union[str, Sequence[str], None] = '869d48618a1c'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('organisation', sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()))
op.add_column('organisation', sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()))
op.add_column('organisation', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
op.add_column('user', sa.Column('active', sa.Boolean(), nullable=False, server_default=sa.false()))
op.add_column('user', sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()))
op.add_column('user', sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()))
op.add_column('user', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('user', 'deleted_at')
op.drop_column('user', 'updated_at')
op.drop_column('user', 'created_at')
op.drop_column('user', 'active')
op.drop_column('organisation', 'deleted_at')
op.drop_column('organisation', 'updated_at')
op.drop_column('organisation', 'created_at')
# ### end Alembic commands ###

View file

@ -8,7 +8,7 @@ requires-python = ">=3.12"
dependencies = [ dependencies = [
"alembic>=1.18.4", "alembic>=1.18.4",
"email-validator>=2.3.0", "email-validator>=2.3.0",
"fastapi>=0.136.3", "fastapi>=0.138.0",
"httptools>=0.7.1", "httptools>=0.7.1",
"httpx>=0.28.1", "httpx>=0.28.1",
"itsdangerous>=2.2.0", "itsdangerous>=2.2.0",
@ -34,11 +34,11 @@ line-length = 92
[tool.ruff.format] [tool.ruff.format]
quote-style = "double" quote-style = "double"
indent-style = "tab"
[tool.uv] [tool.uv]
add-bounds = "major" add-bounds = "major"
exclude-newer = "P2W" exclude-newer = "P2W"
exclude-newer-package = { "fastapi" = "2026-06-22T00:00:00Z" }
[dependency-groups] [dependency-groups]
dev = [ dev = [

8
renovate.json Normal file
View file

@ -0,0 +1,8 @@
{
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
"extends": [
"config:recommended"
],
"minimumReleaseAge": "7 days",
"gitAuthor": "Renovate<noreply@sr2.uk>"
}

View file

@ -22,5 +22,5 @@ from fastapi import APIRouter
router = APIRouter( router = APIRouter(
tags=[""], tags=[""],
) )

View file

@ -8,6 +8,6 @@ Exports:
from fastapi import APIRouter from fastapi import APIRouter
router = APIRouter( router = APIRouter(
tags=["admin"], tags=["admin"],
prefix="/admin", prefix="/admin",
) )

View file

@ -26,15 +26,15 @@ api_router.include_router(iam_router)
class HealthCheckResponse(CustomBaseModel): class HealthCheckResponse(CustomBaseModel):
status: str status: str
@api_router.get( @api_router.get(
path="/healthcheck", path="/healthcheck",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=HealthCheckResponse, response_model=HealthCheckResponse,
include_in_schema=False, include_in_schema=False,
) )
def healthcheck(): def healthcheck():
"""Simple health check endpoint.""" """Simple health check endpoint."""
return {"status": "ok"} return {"status": "ok"}

View file

@ -9,9 +9,9 @@ from src.config import CustomBaseSettings
class AuthConfig(CustomBaseSettings): class AuthConfig(CustomBaseSettings):
OIDC_CONFIG: str = "" OIDC_CONFIG: str = ""
OIDC_ISSUER: str = "" OIDC_ISSUER: str = ""
CLIENT_ID: str = "" CLIENT_ID: str = ""
auth_settings = AuthConfig() auth_settings = AuthConfig()

View file

@ -16,92 +16,92 @@ from src.exceptions import ForbiddenException
from src.user.dependencies import user_model_claims_dependency from src.user.dependencies import user_model_claims_dependency
from src.user.models import User from src.user.models import User
from src.organisation.dependencies import ( from src.organisation.dependencies import (
org_model_query_dependency, org_model_query_dependency,
org_model_body_dependency, org_model_body_dependency,
) )
from src.organisation.models import Organisation as Org from src.organisation.models import Organisation as Org
async def org_query_user_claims( async def org_query_user_claims(
org_model: org_model_query_dependency, user_model: user_model_claims_dependency org_model: org_model_query_dependency, user_model: user_model_claims_dependency
): ):
if user_model in org_model.user_rel: if user_model in org_model.user_rel:
return True return True
raise ForbiddenException() raise ForbiddenException()
org_query_user_claims_dependency = Annotated[bool, Depends(org_query_user_claims)] org_query_user_claims_dependency = Annotated[bool, Depends(org_query_user_claims)]
def get_super_admin_list(): def get_super_admin_list():
return [] return []
def empty_su_list(): def empty_su_list():
return [] return []
def testing_su_list(): def testing_su_list():
return ["admin@test.com"] return ["admin@test.com"]
su_list_dependency = Annotated[list[str | None], Depends(get_super_admin_list)] su_list_dependency = Annotated[list[str | None], Depends(get_super_admin_list)]
async def user_model_super_admin( async def user_model_super_admin(
user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency
): ):
if user_model.email in super_admin_emails: if user_model.email in super_admin_emails:
return user_model return user_model
raise ForbiddenException(message="Must be super admin") raise ForbiddenException(message="Must be super admin")
super_admin_dependency = Annotated[User, Depends(user_model_super_admin)] super_admin_dependency = Annotated[User, Depends(user_model_super_admin)]
async def org_query_root_claims( async def org_query_root_claims(
user_model: user_model_claims_dependency, user_model: user_model_claims_dependency,
org_model: org_model_query_dependency, org_model: org_model_query_dependency,
su_emails: su_list_dependency, su_emails: su_list_dependency,
request: Request, request: Request,
): ):
try: try:
if await user_model_super_admin(user_model, su_emails): if await user_model_super_admin(user_model, su_emails):
return org_model return org_model
except ForbiddenException: except ForbiddenException:
pass pass
await org_status_check(org_model, request) await org_status_check(org_model, request)
if org_model.root_user_id == user_model.id: if org_model.root_user_id == user_model.id:
return org_model return org_model
raise ForbiddenException(message="Must be the org's root user") raise ForbiddenException(message="Must be the org's root user")
org_model_root_claim_query_dependency = Annotated[Org, Depends(org_query_root_claims)] org_model_root_claim_query_dependency = Annotated[Org, Depends(org_query_root_claims)]
async def org_body_root_claims( async def org_body_root_claims(
user_model: user_model_claims_dependency, user_model: user_model_claims_dependency,
org_model: org_model_body_dependency, org_model: org_model_body_dependency,
su_emails: su_list_dependency, su_emails: su_list_dependency,
request: Request, request: Request,
): ):
try: try:
if await user_model_super_admin(user_model, su_emails): if await user_model_super_admin(user_model, su_emails):
return org_model return org_model
except ForbiddenException: except ForbiddenException:
pass pass
await org_status_check(org_model, request) await org_status_check(org_model, request)
if org_model.root_user_id == user_model.id: if org_model.root_user_id == user_model.id:
return org_model return org_model
raise ForbiddenException(message="Must be the org's root user") raise ForbiddenException(message="Must be the org's root user")
org_model_root_claim_body_dependency = Annotated[Org, Depends(org_body_root_claims)] org_model_root_claim_body_dependency = Annotated[Org, Depends(org_body_root_claims)]

View file

@ -8,5 +8,5 @@ Exports:
from fastapi import APIRouter from fastapi import APIRouter
router = APIRouter( router = APIRouter(
tags=["auth"], tags=["auth"],
) )

View file

@ -22,7 +22,7 @@ from src.organisation.exceptions import AwaitingApprovalException
from src.organisation.models import Organisation as Org from src.organisation.models import Organisation as Org
from src.exceptions import UnauthorizedException, ForbiddenException from src.exceptions import UnauthorizedException, ForbiddenException
from src.auth.config import auth_settings from src.auth.config import auth_settings
from src.user.service import add_user_to_db from src.user.service import add_user
from src.database import DbSession from src.database import DbSession
@ -31,56 +31,56 @@ oidc_dependency = Annotated[str, Depends(oidc)]
async def get_dev_user(): async def get_dev_user():
return {"db_id": 1, "email": "chris@sr2.uk"} return {"db_id": 1, "email": "chris@sr2.uk"}
async def get_current_user( async def get_current_user(
oidc_auth_string: oidc_dependency, db: DbSession oidc_auth_string: oidc_dependency, db: DbSession
) -> dict[str, Any]: ) -> dict[str, Any]:
config_url = urlopen(auth_settings.OIDC_CONFIG) config_url = urlopen(auth_settings.OIDC_CONFIG)
config = json.loads(config_url.read()) config = json.loads(config_url.read())
jwks_uri = config["jwks_uri"] jwks_uri = config["jwks_uri"]
key_response = requests.get(jwks_uri) key_response = requests.get(jwks_uri)
jwk_keys = KeySet.import_key_set(key_response.json()) jwk_keys = KeySet.import_key_set(key_response.json())
token = jwt.decode(oidc_auth_string.replace("Bearer ", ""), jwk_keys) token = jwt.decode(oidc_auth_string.replace("Bearer ", ""), jwk_keys)
claims_requests = jwt.JWTClaimsRegistry( claims_requests = jwt.JWTClaimsRegistry(
exp={"essential": True}, iss={"essential": True, "value": auth_settings.OIDC_ISSUER} exp={"essential": True}, iss={"essential": True, "value": auth_settings.OIDC_ISSUER}
) )
try: try:
claims_requests.validate(token.claims) claims_requests.validate(token.claims)
except ExpiredTokenError: except ExpiredTokenError:
raise UnauthorizedException(message="Token is expired") raise UnauthorizedException(message="Token is expired")
db_id = await add_user_to_db(db, token.claims) db_id = await add_user(db, token.claims)
token.claims["db_id"] = db_id token.claims["db_id"] = db_id
return token.claims return token.claims
claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)] claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)]
async def org_status_check(org_model: Org, request: Request): async def org_status_check(org_model: Org, request: Request):
org_status = OrgStatus(org_model.status) org_status = OrgStatus(org_model.status)
if org_status.is_blocked: if org_status.is_blocked:
raise ForbiddenException("This organisation cannot perform this action.") raise ForbiddenException("This organisation cannot perform this action.")
root = "/api/v1" root = "/api/v1"
pre_approval_endpoints = [ pre_approval_endpoints = [
f"PATCH{root}/org/status", f"PATCH{root}/org/status",
f"PATCH{root}/org/questionnaire", f"PATCH{root}/org/questionnaire",
f"GET{root}/org", f"GET{root}/org",
f"GET{root}/org/contact", f"GET{root}/org/contact",
f"PATCH{root}/org/contact", f"PATCH{root}/org/contact",
f"DELETE{root}/org/self", f"DELETE{root}/org/self",
] ]
current_request = f"{request.method}{request.url.path}" current_request = f"{request.method}{request.url.path}"
if ( if (
current_request not in pre_approval_endpoints current_request not in pre_approval_endpoints
and org_model.status != OrgStatus.APPROVED and org_model.status != OrgStatus.APPROVED
): ):
raise AwaitingApprovalException(org_model.id) raise AwaitingApprovalException(org_model.id)

View file

@ -16,31 +16,31 @@ from src.constants import Environment
class CustomBaseSettings(BaseSettings): class CustomBaseSettings(BaseSettings):
model_config = SettingsConfigDict( model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", extra="ignore" env_file=".env", env_file_encoding="utf-8", extra="ignore"
) )
class Config(CustomBaseSettings): class Config(CustomBaseSettings):
APP_VERSION: str = "0.1" APP_VERSION: str = "0.1"
ENVIRONMENT: Environment = Environment.PRODUCTION ENVIRONMENT: Environment = Environment.PRODUCTION
SECRET_KEY: SecretStr = SecretStr("") SECRET_KEY: SecretStr = SecretStr("")
DISABLE_AUTH: bool = False DISABLE_AUTH: bool = False
CORS_ORIGINS: list[str] = ["*"] CORS_ORIGINS: list[str] = ["*"]
CORS_ORIGINS_REGEX: str | None = None CORS_ORIGINS_REGEX: str | None = None
CORS_HEADERS: list[str] = ["*"] CORS_HEADERS: list[str] = ["*"]
DATABASE_NAME: str = "fastapi-exp" DATABASE_NAME: str = "fastapi-exp"
DATABASE_PORT: str = "5432" DATABASE_PORT: str = "5432"
DATABASE_HOSTNAME: str = "localhost" DATABASE_HOSTNAME: str = "localhost"
DATABASE_CREDENTIALS: SecretStr = SecretStr(":") DATABASE_CREDENTIALS: SecretStr = SecretStr(":")
DATABASE_POOL_SIZE: int = 16 DATABASE_POOL_SIZE: int = 16
DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes
DATABASE_POOL_PRE_PING: bool = True DATABASE_POOL_PRE_PING: bool = True
LETTERMINT_API_TOKEN: SecretStr = SecretStr("") LETTERMINT_API_TOKEN: SecretStr = SecretStr("")
settings = Config() settings = Config()
@ -51,20 +51,20 @@ DATABASE_HOSTNAME = settings.DATABASE_HOSTNAME
DATABASE_CREDENTIALS = settings.DATABASE_CREDENTIALS.get_secret_value() DATABASE_CREDENTIALS = settings.DATABASE_CREDENTIALS.get_secret_value()
# this will support special chars for credentials # this will support special chars for credentials
_DATABASE_CREDENTIAL_USER, _DATABASE_CREDENTIAL_PASSWORD = str(DATABASE_CREDENTIALS).split( _DATABASE_CREDENTIAL_USER, _DATABASE_CREDENTIAL_PASSWORD = str(DATABASE_CREDENTIALS).split(
":" ":"
) )
_QUOTED_DATABASE_PASSWORD = parse.quote_plus(str(_DATABASE_CREDENTIAL_PASSWORD)) _QUOTED_DATABASE_PASSWORD = parse.quote_plus(str(_DATABASE_CREDENTIAL_PASSWORD))
SQLALCHEMY_DATABASE_URI = SecretStr( SQLALCHEMY_DATABASE_URI = SecretStr(
f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}" f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}"
) )
if settings.ENVIRONMENT == Environment.TESTING: if settings.ENVIRONMENT == Environment.TESTING:
SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:") SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:")
app_configs: dict[str, Any] = {"title": "App API"} app_configs: dict[str, Any] = {"title": "App API"}
if settings.ENVIRONMENT.is_deployed: if settings.ENVIRONMENT.is_deployed:
app_configs["root_path"] = f"/v{settings.APP_VERSION}" app_configs["root_path"] = f"/v{settings.APP_VERSION}"
if not settings.ENVIRONMENT.is_debug: if not settings.ENVIRONMENT.is_debug:
app_configs["openapi_url"] = None # hide docs app_configs["openapi_url"] = None # hide docs

View file

@ -9,29 +9,29 @@ from enum import StrEnum, auto
class Environment(StrEnum): class Environment(StrEnum):
""" """
Enumeration of environments. Enumeration of environments.
Attributes: Attributes:
LOCAL (str): Application is running locally LOCAL (str): Application is running locally
TESTING (str): Application is running in testing mode TESTING (str): Application is running in testing mode
STAGING (str): Application is running in staging mode (ie not testing) STAGING (str): Application is running in staging mode (ie not testing)
PRODUCTION (str): Application is running in production mode PRODUCTION (str): Application is running in production mode
""" """
LOCAL = auto() LOCAL = auto()
TESTING = auto() TESTING = auto()
STAGING = auto() STAGING = auto()
PRODUCTION = auto() PRODUCTION = auto()
@property @property
def is_debug(self): def is_debug(self):
return self in (self.LOCAL, self.STAGING, self.TESTING) return self in (self.LOCAL, self.STAGING, self.TESTING)
@property @property
def is_testing(self): def is_testing(self):
return self == self.TESTING return self == self.TESTING
@property @property
def is_deployed(self) -> bool: def is_deployed(self) -> bool:
return self in (self.STAGING, self.PRODUCTION) return self in (self.STAGING, self.PRODUCTION)

View file

@ -11,13 +11,13 @@ from fastapi import HTTPException, status
class ContactNotFoundException(HTTPException): class ContactNotFoundException(HTTPException):
def __init__(self, contact_id: Optional[int] = None) -> None: def __init__(self, contact_id: Optional[int] = None) -> None:
detail = ( detail = (
"Contact not found" "Contact not found"
if contact_id is None if contact_id is None
else f"Contact with ID '{contact_id}' was not found." else f"Contact with ID '{contact_id}' was not found."
) )
super().__init__( super().__init__(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=detail, detail=detail,
) )

View file

@ -6,30 +6,31 @@ Models:
street_address, street_address_line_2, post_office_box_number, address_locality, country_code, address_region, postal_code street_address, street_address_line_2, post_office_box_number, address_locality, country_code, address_region, postal_code
""" """
from src.models import IdMixin
from sqlalchemy import ForeignKey from sqlalchemy import ForeignKey
from sqlalchemy.orm import mapped_column, Mapped from sqlalchemy.orm import mapped_column, Mapped
from src.models import CustomBase from src.models import CustomBase
class Contact(CustomBase): class Contact(CustomBase, IdMixin):
__tablename__ = "contact" __tablename__ = "contact"
id: Mapped[int] = mapped_column(primary_key=True) email: Mapped[str] = mapped_column(default=None, nullable=True)
email: Mapped[str] = mapped_column(default=None, nullable=True) first_name: Mapped[str] = mapped_column(default=None, nullable=True)
first_name: Mapped[str] = mapped_column(default=None, nullable=True) last_name: Mapped[str] = mapped_column(default=None, nullable=True)
last_name: Mapped[str] = mapped_column(default=None, nullable=True) phonenumber: Mapped[str] = mapped_column(default=None, nullable=True)
phonenumber: Mapped[str] = mapped_column(default=None, nullable=True) vat_number: Mapped[str | None] = mapped_column(default=None, nullable=True)
vat_number: Mapped[str | None] = mapped_column(default=None, nullable=True)
street_address: Mapped[str] = mapped_column(default=None, nullable=True) street_address: Mapped[str] = mapped_column(default=None, nullable=True)
street_address_line_2: Mapped[str] = mapped_column(default=None, nullable=True) street_address_line_2: Mapped[str] = mapped_column(default=None, nullable=True)
post_office_box_number: Mapped[str | None] = mapped_column(default=None, nullable=True) post_office_box_number: Mapped[str | None] = mapped_column(default=None, nullable=True)
locality: Mapped[str] = mapped_column(default=None, nullable=True) # Ie City locality: Mapped[str] = mapped_column(default=None, nullable=True) # Ie City
country_code: Mapped[str] = mapped_column(default=None, nullable=True) # Eg GB country_code: Mapped[str] = mapped_column(default=None, nullable=True) # Eg GB
address_region: Mapped[str | None] = mapped_column(default=None, nullable=True) address_region: Mapped[str | None] = mapped_column(default=None, nullable=True)
postal_code: Mapped[str] = mapped_column(default=None, nullable=True) postal_code: Mapped[str] = mapped_column(default=None, nullable=True)
org_id: Mapped[int] = mapped_column( org_id: Mapped[int] = mapped_column(
ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False
) )

View file

@ -6,6 +6,6 @@ from fastapi import APIRouter
router = APIRouter( router = APIRouter(
prefix="/contact", prefix="/contact",
tags=["contact"], tags=["contact"],
) )

View file

@ -14,22 +14,22 @@ from src.schemas import CustomBaseModel
class ContactAddress(CustomBaseModel): class ContactAddress(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore") model_config = ConfigDict(from_attributes=True, extra="ignore")
post_office_box_number: Optional[str] = None post_office_box_number: Optional[str] = None
street_address: Optional[str] = None street_address: Optional[str] = None
street_address_line_2: Optional[str] = None street_address_line_2: Optional[str] = None
locality: Optional[str] = None locality: Optional[str] = None
address_region: Optional[str] = None address_region: Optional[str] = None
country_code: Optional[str] = None country_code: Optional[str] = None
postal_code: Optional[str] = None postal_code: Optional[str] = None
class ContactModel(CustomBaseModel): class ContactModel(CustomBaseModel):
email: Optional[EmailStr] = None email: Optional[EmailStr] = None
first_name: Optional[str] = None first_name: Optional[str] = None
last_name: Optional[str] = None last_name: Optional[str] = None
phonenumber: Optional[str] = None phonenumber: Optional[str] = None
vat_number: Optional[str] = None vat_number: Optional[str] = None
address: ContactAddress address: ContactAddress

View file

@ -1,6 +1,7 @@
""" """
Database connection and session utilities Database connection and session utilities
""" """
from contextlib import contextmanager from contextlib import contextmanager
from typing import Annotated, Generator from typing import Annotated, Generator
from sqlalchemy import create_engine, StaticPool, Connection from sqlalchemy import create_engine, StaticPool, Connection
@ -29,6 +30,7 @@ else:
sm = sessionmaker(autocommit=False, expire_on_commit=False, bind=engine) sm = sessionmaker(autocommit=False, expire_on_commit=False, bind=engine)
@contextmanager @contextmanager
def get_db_connection() -> Generator[Connection, None, None]: def get_db_connection() -> Generator[Connection, None, None]:
with engine.connect() as connection: with engine.connect() as connection:
@ -38,12 +40,15 @@ def get_db_connection() -> Generator[Connection, None, None]:
connection.rollback() connection.rollback()
raise raise
def _get_db_connection() -> Generator[Connection, None]:
def _get_db_connection() -> Generator[Connection, None, None]:
with get_db_connection() as connection: with get_db_connection() as connection:
yield connection yield connection
DbConnection = Annotated[Connection, Depends(_get_db_connection)] DbConnection = Annotated[Connection, Depends(_get_db_connection)]
@contextmanager @contextmanager
def get_db_session() -> Generator[Session, None, None]: def get_db_session() -> Generator[Session, None, None]:
session = sm() session = sm()
@ -56,8 +61,9 @@ def get_db_session() -> Generator[Session, None, None]:
session.close() session.close()
def _get_db_session() -> Generator[Session, None]: def _get_db_session() -> Generator[Session, None, None]:
with get_db_session() as session: with get_db_session() as session:
yield session yield session
DbSession = Annotated[Session, Depends(_get_db_session)] DbSession = Annotated[Session, Depends(_get_db_session)]

View file

@ -12,36 +12,36 @@ from fastapi import HTTPException, status
class UnprocessableContentException(HTTPException): class UnprocessableContentException(HTTPException):
def __init__(self, message: Optional[str] = None) -> None: def __init__(self, message: Optional[str] = None) -> None:
detail = "Unprocessable content" if not message else message detail = "Unprocessable content" if not message else message
super().__init__( super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail=detail, detail=detail,
) )
class ConflictException(HTTPException): class ConflictException(HTTPException):
def __init__(self, message: Optional[str] = None) -> None: def __init__(self, message: Optional[str] = None) -> None:
detail = "Conflict" if not message else message detail = "Conflict" if not message else message
super().__init__( super().__init__(
status_code=status.HTTP_409_CONFLICT, status_code=status.HTTP_409_CONFLICT,
detail=detail, detail=detail,
) )
class ForbiddenException(HTTPException): class ForbiddenException(HTTPException):
def __init__(self, message: Optional[str] = None) -> None: def __init__(self, message: Optional[str] = None) -> None:
detail = "Forbidden" if not message else message detail = "Forbidden" if not message else message
super().__init__( super().__init__(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=detail, detail=detail,
) )
class UnauthorizedException(HTTPException): class UnauthorizedException(HTTPException):
def __init__(self, message: Optional[str] = None) -> None: def __init__(self, message: Optional[str] = None) -> None:
detail = "Not authorized" if not message else message detail = "Not authorized" if not message else message
super().__init__( super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail, detail=detail,
) )

View file

@ -18,59 +18,55 @@ from src.iam.exceptions import GroupNotFoundException, PermNotFoundException
from src.iam.schemas import GroupIDMixin, PermIDMixin from src.iam.schemas import GroupIDMixin, PermIDMixin
def get_group_model_query( def get_group_model_query(db: DbSession, group_id: Annotated[int, Query(gt=0)]) -> Group:
db: DbSession, group_id: Annotated[int, Query(gt=0)] group_model = db.get(Group, group_id)
) -> Group: if group_model is None:
group_model = db.get(Group, group_id) raise GroupNotFoundException(group_id)
if group_model is None:
raise GroupNotFoundException(group_id)
return group_model return group_model
group_model_query_dependency = Annotated[Group, Depends(get_group_model_query)] group_model_query_dependency = Annotated[Group, Depends(get_group_model_query)]
def get_group_model_body( def get_group_model_body(
db: DbSession, request_model: Optional[GroupIDMixin] = None db: DbSession, request_model: Optional[GroupIDMixin] = None
) -> Group: ) -> Group:
group_id = getattr(request_model, "group_id", None) group_id = getattr(request_model, "group_id", None)
if group_id is None: if group_id is None:
raise GroupNotFoundException() raise GroupNotFoundException()
group_model = db.get(Group, group_id) group_model = db.get(Group, group_id)
if group_model is None: if group_model is None:
raise GroupNotFoundException(group_id) raise GroupNotFoundException(group_id)
return group_model return group_model
group_model_body_dependency = Annotated[Group, Depends(get_group_model_body)] group_model_body_dependency = Annotated[Group, Depends(get_group_model_body)]
def get_perm_model_body( def get_perm_model_body(
db: DbSession, request_model: Optional[PermIDMixin] = None db: DbSession, request_model: Optional[PermIDMixin] = None
) -> Permission: ) -> Permission:
perm_id = getattr(request_model, "permission_id", None) perm_id = getattr(request_model, "permission_id", None)
if perm_id is None: if perm_id is None:
raise PermNotFoundException raise PermNotFoundException
perm_model = db.get(Permission, perm_id) perm_model = db.get(Permission, perm_id)
if perm_model is None: if perm_model is None:
raise PermNotFoundException(perm_id) raise PermNotFoundException(perm_id)
return perm_model return perm_model
perm_model_body_dependency = Annotated[Permission, Depends(get_perm_model_body)] perm_model_body_dependency = Annotated[Permission, Depends(get_perm_model_body)]
def get_perm_model_query( def get_perm_model_query(db: DbSession, perm_id: Annotated[int, Query(gt=0)]) -> Permission:
db: DbSession, perm_id: Annotated[int, Query(gt=0)] perm_model = db.get(Permission, perm_id)
) -> Permission: if perm_model is None:
perm_model = db.get(Permission, perm_id) raise PermNotFoundException(perm_id)
if perm_model is None:
raise PermNotFoundException(perm_id)
return perm_model return perm_model
perm_model_query_dependency = Annotated[Permission, Depends(get_perm_model_query)] perm_model_query_dependency = Annotated[Permission, Depends(get_perm_model_query)]

View file

@ -12,26 +12,26 @@ from fastapi import HTTPException, status
class GroupNotFoundException(HTTPException): class GroupNotFoundException(HTTPException):
def __init__(self, group_id: Optional[int] = None) -> None: def __init__(self, group_id: Optional[int] = None) -> None:
detail = ( detail = (
"Group not found" "Group not found"
if group_id is None if group_id is None
else f"User with ID '{group_id}' was not found." else f"User with ID '{group_id}' was not found."
) )
super().__init__( super().__init__(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=detail, detail=detail,
) )
class PermNotFoundException(HTTPException): class PermNotFoundException(HTTPException):
def __init__(self, perm_id: Optional[int] = None) -> None: def __init__(self, perm_id: Optional[int] = None) -> None:
detail = ( detail = (
"Permission not found" "Permission not found"
if perm_id is None if perm_id is None
else f"User with ID '{perm_id}' was not found." else f"User with ID '{perm_id}' was not found."
) )
super().__init__( super().__init__(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=detail, detail=detail,
) )

View file

@ -21,95 +21,94 @@ Models:
from sqlalchemy import ForeignKey, UniqueConstraint from sqlalchemy import ForeignKey, UniqueConstraint
from sqlalchemy.orm import relationship, mapped_column, Mapped from sqlalchemy.orm import relationship, mapped_column, Mapped
from src.models import CustomBase from src.models import CustomBase, IdMixin
class Permission(CustomBase): class Permission(CustomBase, IdMixin):
__tablename__ = "permission" __tablename__ = "permission"
id: Mapped[int] = mapped_column(primary_key=True) resource: Mapped[str]
resource: Mapped[str] action: Mapped[str]
action: Mapped[str]
service_id: Mapped[int] = mapped_column(ForeignKey("service.id", ondelete="CASCADE")) service_id: Mapped[int] = mapped_column(ForeignKey("service.id", ondelete="CASCADE"))
__table_args__ = ( __table_args__ = (
UniqueConstraint( UniqueConstraint(
"service_id", "service_id",
"resource", "resource",
"action", "action",
name="uniq_permission_resource_and_action", name="uniq_permission_resource_and_action",
), ),
) )
service_rel = relationship( service_rel = relationship(
"Service", "Service",
back_populates="permission_rel", back_populates="permission_rel",
foreign_keys="Permission.service_id", foreign_keys="Permission.service_id",
) )
group_rel = relationship( group_rel = relationship(
"Group", secondary="group_permissions", back_populates="permission_rel" "Group", secondary="group_permissions", back_populates="permission_rel"
) )
org_rel = relationship( org_rel = relationship(
"Organisation", secondary="org_permissions", back_populates="permission_rel" "Organisation", secondary="org_permissions", back_populates="permission_rel"
) )
@property @property
def service_name(self): def service_name(self):
return self.service_rel.name return self.service_rel.name
class Group(CustomBase): class Group(CustomBase, IdMixin):
__tablename__ = "group" __tablename__ = "group"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str]
org_id: Mapped[int] = mapped_column(ForeignKey("organisation.id", ondelete="CASCADE")) name: Mapped[str]
__table_args__ = ( org_id: Mapped[int] = mapped_column(ForeignKey("organisation.id", ondelete="CASCADE"))
UniqueConstraint(
"name",
"org_id",
name="uniq_group_name_org_id",
),
)
user_rel = relationship("User", secondary="user_groups", back_populates="group_rel") __table_args__ = (
UniqueConstraint(
"name",
"org_id",
name="uniq_group_name_org_id",
),
)
org_rel = relationship("Organisation", back_populates="group_rel") user_rel = relationship("User", secondary="user_groups", back_populates="group_rel")
permission_rel = relationship( org_rel = relationship("Organisation", back_populates="group_rel")
"Permission", secondary="group_permissions", back_populates="group_rel"
) permission_rel = relationship(
"Permission", secondary="group_permissions", back_populates="group_rel"
)
class GroupPermissions(CustomBase): class GroupPermissions(CustomBase):
__tablename__ = "group_permissions" __tablename__ = "group_permissions"
group_id: Mapped[int] = mapped_column( group_id: Mapped[int] = mapped_column(
ForeignKey("group.id", ondelete="CASCADE"), primary_key=True ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
) )
permission_id: Mapped[int] = mapped_column( permission_id: Mapped[int] = mapped_column(
ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
) )
class UserGroups(CustomBase): class UserGroups(CustomBase):
__tablename__ = "user_groups" __tablename__ = "user_groups"
user_id: Mapped[int] = mapped_column( user_id: Mapped[int] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
) )
group_id: Mapped[int] = mapped_column( group_id: Mapped[int] = mapped_column(
ForeignKey("group.id", ondelete="CASCADE"), primary_key=True ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
) )
class OrgPermissions(CustomBase): class OrgPermissions(CustomBase):
__tablename__ = "org_permissions" __tablename__ = "org_permissions"
org_id: Mapped[int] = mapped_column( org_id: Mapped[int] = mapped_column(
ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True
) )
permission_id: Mapped[int] = mapped_column( permission_id: Mapped[int] = mapped_column(
ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
) )

File diff suppressed because it is too large Load diff

View file

@ -11,151 +11,151 @@ from typing import Optional, Annotated
from pydantic import EmailStr, ConfigDict, Field from pydantic import EmailStr, ConfigDict, Field
from src.schemas import ( from src.schemas import (
CustomBaseModel, CustomBaseModel,
ResourceName, ResourceName,
ServiceIDMixin, ServiceIDMixin,
OrgIDMixin, OrgIDMixin,
UserIDMixin, UserIDMixin,
PermIDMixin, PermIDMixin,
GroupIDMixin, GroupIDMixin,
GroupSummary, GroupSummary,
OrgSummary, OrgSummary,
UserSummary, UserSummary,
) )
class UserSchema(CustomBaseModel): class UserSchema(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore") model_config = ConfigDict(from_attributes=True, extra="ignore")
id: int id: int
first_name: str first_name: str
last_name: str last_name: str
email: EmailStr email: EmailStr
class PermissionSchema(CustomBaseModel): class PermissionSchema(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore") model_config = ConfigDict(from_attributes=True, extra="ignore")
id: int id: int
service_name: str service_name: str
resource: str resource: str
action: str action: str
class GroupDetails(CustomBaseModel): class GroupDetails(CustomBaseModel):
details: GroupSummary details: GroupSummary
permissions: list[PermissionSchema] permissions: list[PermissionSchema]
class IAMCAoRRequest(CustomBaseModel): class IAMCAoRRequest(CustomBaseModel):
action: str action: str
rn: ResourceName rn: ResourceName
class IAMCAoRResponse(CustomBaseModel): class IAMCAoRResponse(CustomBaseModel):
allowed: bool allowed: bool
user: UserSummary user: UserSummary
action: str action: str
rn: ResourceName rn: ResourceName
class IAMGetGroupPermissionsResponse(CustomBaseModel): class IAMGetGroupPermissionsResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
group: GroupSummary group: GroupSummary
permissions: list[PermissionSchema] permissions: list[PermissionSchema]
class IAMGetGroupUsersResponse(CustomBaseModel): class IAMGetGroupUsersResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
group: GroupSummary group: GroupSummary
users: list[UserSummary] users: list[UserSummary]
class IAMPostGroupRequest(OrgIDMixin): class IAMPostGroupRequest(OrgIDMixin):
name: str = Field(min_length=3) name: str = Field(min_length=3)
class IAMPostGroupResponse(CustomBaseModel): class IAMPostGroupResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
group: GroupSummary group: GroupSummary
class IAMPutGroupPermissionRequest(GroupIDMixin, PermIDMixin, OrgIDMixin): class IAMPutGroupPermissionRequest(GroupIDMixin, PermIDMixin, OrgIDMixin):
pass pass
class IAMPutGroupPermissionResponse(CustomBaseModel): class IAMPutGroupPermissionResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
group: GroupSummary group: GroupSummary
permissions: list[PermissionSchema] permissions: list[PermissionSchema]
class IAMPutGroupUserRequest(GroupIDMixin, UserIDMixin, OrgIDMixin): class IAMPutGroupUserRequest(GroupIDMixin, UserIDMixin, OrgIDMixin):
pass pass
class IAMPutGroupUserResponse(CustomBaseModel): class IAMPutGroupUserResponse(CustomBaseModel):
group: GroupSummary group: GroupSummary
users: list[UserSchema] users: list[UserSchema]
class IAMDeleteGroupPermissionResponse(CustomBaseModel): class IAMDeleteGroupPermissionResponse(CustomBaseModel):
group: GroupSummary group: GroupSummary
permissions: list[PermissionSchema] permissions: list[PermissionSchema]
class IAMDeleteGroupUserResponse(CustomBaseModel): class IAMDeleteGroupUserResponse(CustomBaseModel):
group: GroupSummary group: GroupSummary
users: list[UserSchema] users: list[UserSchema]
class IAMGetPermissionsResponse(CustomBaseModel): class IAMGetPermissionsResponse(CustomBaseModel):
permissions: list[PermissionSchema] permissions: list[PermissionSchema]
class IAMPostPermissionRequest(ServiceIDMixin): class IAMPostPermissionRequest(ServiceIDMixin):
resource: str resource: str
action: str action: str
class IAMPostPermissionResponse(CustomBaseModel): class IAMPostPermissionResponse(CustomBaseModel):
permission: PermissionSchema permission: PermissionSchema
class IAMGetPermissionsSearchRequest(OrgIDMixin): class IAMGetPermissionsSearchRequest(OrgIDMixin):
service_id: Annotated[int | None, Field(gt=0)] = None service_id: Annotated[int | None, Field(gt=0)] = None
resource: Optional[str] = None resource: Optional[str] = None
action: Optional[str] = None action: Optional[str] = None
class IAMGetPermissionsSearchResponse(CustomBaseModel): class IAMGetPermissionsSearchResponse(CustomBaseModel):
permissions: list[PermissionSchema] permissions: list[PermissionSchema]
class IAMPutGroupInvitationRequest(OrgIDMixin, GroupIDMixin): class IAMPutGroupInvitationRequest(OrgIDMixin, GroupIDMixin):
user_email: EmailStr user_email: EmailStr
class IAMPutGroupInvitationResponse(CustomBaseModel): class IAMPutGroupInvitationResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
group: GroupSummary group: GroupSummary
invited_email: EmailStr invited_email: EmailStr
class IAMPutGroupInvitationAcceptRequest(CustomBaseModel): class IAMPutGroupInvitationAcceptRequest(CustomBaseModel):
jwt: str jwt: str
class IAMPutGroupInvitationAcceptResponse(CustomBaseModel): class IAMPutGroupInvitationAcceptResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
user: UserSummary user: UserSummary
group: GroupDetails group: GroupDetails
class IAMPutOrgPermissionsRequest(OrgIDMixin): class IAMPutOrgPermissionsRequest(OrgIDMixin):
permissions: list[int] permissions: list[int]
class IAMPutOrgPermissionsResponse(CustomBaseModel): class IAMPutOrgPermissionsResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
permissions: list[PermissionSchema] permissions: list[PermissionSchema]

View file

@ -23,90 +23,90 @@ from src.service.schemas import HasServiceName
def valid_service_key( def valid_service_key(
db: DbSession, request: Request, request_model: HasServiceName db: DbSession, request: Request, request_model: HasServiceName
) -> bool: ) -> bool:
rn = request_model.rn rn = request_model.rn
api_key = request.headers.get("X-API-Key", None) api_key = request.headers.get("X-API-Key", None)
if not api_key: if not api_key:
raise UnauthorizedException("Missing API key") raise UnauthorizedException("Missing API key")
service = rn.service service = rn.service
result = ( result = (
db.query(Service) db.query(Service)
.filter(Service.name == service) .filter(Service.name == service)
.filter(Service.api_key == api_key) .filter(Service.api_key == api_key)
.first() .first()
) )
if result is None: if result is None:
raise UnauthorizedException("Invalid API key") raise UnauthorizedException("Invalid API key")
return True return True
service_key_dependency = Annotated[bool, Depends(valid_service_key)] service_key_dependency = Annotated[bool, Depends(valid_service_key)]
async def send_user_group_invitation( async def send_user_group_invitation(
user_email: str, org_name: str, org_id: int, group_id: int, group_name: str user_email: str, org_name: str, org_id: int, group_id: int, group_name: str
): ):
expiry_delta = timedelta(hours=24) expiry_delta = timedelta(hours=24)
expiry = datetime.now(timezone.utc) + expiry_delta expiry = datetime.now(timezone.utc) + expiry_delta
claims = { claims = {
"email": user_email, "email": user_email,
"org_id": org_id, "org_id": org_id,
"group_id": group_id, "group_id": group_id,
"group_name": group_name, "group_name": group_name,
"exp": expiry, "exp": expiry,
"type": "group_invite", "type": "group_invite",
} }
token = await generate_jwt(claims) token = await generate_jwt(claims)
subject = f"You have been invited to join a group of {org_name}" subject = f"You have been invited to join a group of {org_name}"
body = f"You have been invited to join {group_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}" body = f"You have been invited to join {group_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
await send_email( await send_email(
recipient=user_email, recipient=user_email,
subject=subject, subject=subject,
body=body, body=body,
) )
async def create_group_and_assign_perms( async def create_group_and_assign_perms(
db: Session, org_model: Org, group_name: str, perm_list: list[int] db: Session, org_model: Org, group_name: str, perm_list: list[int]
): ):
new_group = Group(name=group_name, org_id=org_model.id) new_group = Group(name=group_name, org_id=org_model.id)
db.add(new_group) db.add(new_group)
db.flush() db.flush()
for permission in perm_list: for permission in perm_list:
perm_model = db.get(Perm, permission) perm_model = db.get(Perm, permission)
if perm_model is None: if perm_model is None:
continue continue
new_group.permission_rel.append(perm_model) new_group.permission_rel.append(perm_model)
db.flush() db.flush()
return new_group return new_group
async def assign_default_group( async def assign_default_group(
db: DbSession, db: DbSession,
org_model: Org, org_model: Org,
user_model: User, user_model: User,
group_name: str, group_name: str,
perm_list: list[int], perm_list: list[int],
): ):
group_model = ( group_model = (
db.query(Group) db.query(Group)
.filter(Group.org_id == org_model.id) .filter(Group.org_id == org_model.id)
.filter(Group.name == group_name) .filter(Group.name == group_name)
.first() .first()
) )
if group_model is None: if group_model is None:
group_model = await create_group_and_assign_perms( group_model = await create_group_and_assign_perms(
db=db, group_name=group_name, org_model=org_model, perm_list=perm_list db=db, group_name=group_name, org_model=org_model, perm_list=perm_list
) )
user_model.group_rel.append(group_model) user_model.group_rel.append(group_model)
db.flush() db.flush()

View file

@ -2,6 +2,7 @@
Application root file: Inits the FastAPI application Application root file: Inits the FastAPI application
""" """
import os.path
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import AsyncGenerator from typing import AsyncGenerator
@ -19,43 +20,43 @@ from src.auth.service import get_current_user, get_dev_user
@asynccontextmanager @asynccontextmanager
async def lifespan(_application: FastAPI) -> AsyncGenerator: async def lifespan(_application: FastAPI) -> AsyncGenerator:
# Startup # Startup
yield yield
# Shutdown # Shutdown
if settings.ENVIRONMENT.is_deployed: if settings.ENVIRONMENT.is_deployed:
# Just a precaution, should be False anyway # Just a precaution, should be False anyway
settings.DISABLE_AUTH = False settings.DISABLE_AUTH = False
tags_metadata = [ tags_metadata = [
{ {
"name": "User", "name": "User",
"description": "User related operations, includes getting information about the current user", "description": "User related operations, includes getting information about the current user",
}, },
{ {
"name": "Organisation", "name": "Organisation",
"description": "Organisation related operations, includes getting lists of users etc associated with orgs", "description": "Organisation related operations, includes getting lists of users etc associated with orgs",
}, },
{ {
"name": "Service", "name": "Service",
"description": "Services related operations, includes registering services and reissuing API keys", "description": "Services related operations, includes registering services and reissuing API keys",
}, },
{ {
"name": "IAM", "name": "IAM",
"description": "Operations related to the role based identity and access management system. This includes management of groups, permissions, and related users.", "description": "Operations related to the role based identity and access management system. This includes management of groups, permissions, and related users.",
}, },
] ]
app = FastAPI( app = FastAPI(
swagger_ui_init_oauth={ swagger_ui_init_oauth={
"clientId": auth_settings.CLIENT_ID, "clientId": auth_settings.CLIENT_ID,
"usePkceWithAuthorizationCodeGrant": True, "usePkceWithAuthorizationCodeGrant": True,
"scopes": "openid profile email", "scopes": "openid profile email",
}, },
openapi_tags=tags_metadata, openapi_tags=tags_metadata,
) )
# Type inspection disabled for middleware injection. # Type inspection disabled for middleware injection.
@ -64,16 +65,19 @@ app = FastAPI(
app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY.get_secret_value()) app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY.get_secret_value())
# noinspection PyTypeChecker # noinspection PyTypeChecker
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=settings.CORS_ORIGINS, allow_origins=settings.CORS_ORIGINS,
allow_origin_regex=settings.CORS_ORIGINS_REGEX, allow_origin_regex=settings.CORS_ORIGINS_REGEX,
allow_credentials=True, allow_credentials=True,
allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"), allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"),
allow_headers=settings.CORS_HEADERS, allow_headers=settings.CORS_HEADERS,
) )
if settings.DISABLE_AUTH and (settings.ENVIRONMENT == Environment.LOCAL): if settings.DISABLE_AUTH and (settings.ENVIRONMENT == Environment.LOCAL):
app.dependency_overrides[get_current_user] = get_dev_user app.dependency_overrides[get_current_user] = get_dev_user
app.include_router(api_router) app.include_router(api_router)
if os.path.exists("/app/static"):
app.frontend("/ui", directory="/app/static", fallback="index.html")

View file

@ -5,12 +5,33 @@ Global database models
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
from sqlalchemy import DateTime, JSON from sqlalchemy import DateTime, JSON, func
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class CustomBase(DeclarativeBase): class CustomBase(DeclarativeBase):
type_annotation_map = { type_annotation_map = {
datetime: DateTime(timezone=True), datetime: DateTime(timezone=True),
dict[str, Any]: JSON, dict[str, Any]: JSON,
} }
class ActivatedMixin:
active: Mapped[bool] = mapped_column(default=True)
class DeletedTimestampMixin:
deleted_at: Mapped[datetime | None] = mapped_column(nullable=True)
class DescriptionMixin:
description: Mapped[str]
class IdMixin:
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
class TimestampMixin:
created_at: Mapped[datetime] = mapped_column(default=func.now())
updated_at: Mapped[datetime] = mapped_column(default=func.now(), onupdate=func.now())

View file

@ -10,48 +10,48 @@ from enum import StrEnum, auto
class Status(StrEnum): class Status(StrEnum):
""" """
Enumeration of organisation statuses. Enumeration of organisation statuses.
Attributes: Attributes:
PARTIAL(str): Organisation has been created but questionnaire hasn't been submitted. PARTIAL(str): Organisation has been created but questionnaire hasn't been submitted.
SUBMITTED (str): Questionnaire submitted but not approved. SUBMITTED (str): Questionnaire submitted but not approved.
REMEDIATION (str): Questionnaire submitted but requires revisions. REMEDIATION (str): Questionnaire submitted but requires revisions.
APPROVED (str): Questionnaire has been approved by an admin. APPROVED (str): Questionnaire has been approved by an admin.
REJECTED (str): Questionnaire has been rejected by an admin. REJECTED (str): Questionnaire has been rejected by an admin.
REMOVED (str): Organisation has been removed. REMOVED (str): Organisation has been removed.
""" """
PARTIAL = auto() PARTIAL = auto()
SUBMITTED = auto() SUBMITTED = auto()
REMEDIATION = auto() REMEDIATION = auto()
APPROVED = auto() APPROVED = auto()
REJECTED = auto() REJECTED = auto()
REMOVED = auto() REMOVED = auto()
@property @property
def is_pre_approval(self): def is_pre_approval(self):
return self in (self.PARTIAL, self.SUBMITTED, self.REMEDIATION) return self in (self.PARTIAL, self.SUBMITTED, self.REMEDIATION)
@property @property
def is_pre_submission(self): def is_pre_submission(self):
return self in (self.PARTIAL, self.REMEDIATION) return self in (self.PARTIAL, self.REMEDIATION)
@property @property
def is_blocked(self): def is_blocked(self):
return self in (self.REMOVED, self.REJECTED) return self in (self.REMOVED, self.REJECTED)
class ContactType(StrEnum): class ContactType(StrEnum):
""" """
Enumeration of organisation contact types. Enumeration of organisation contact types.
Attributes: Attributes:
BILLING(str): Billing contact. BILLING(str): Billing contact.
SECURITY (str): Security contact. SECURITY (str): Security contact.
OWNER (str): Owner contact. OWNER (str): Owner contact.
""" """
BILLING = auto() BILLING = auto()
SECURITY = auto() SECURITY = auto()
OWNER = auto() OWNER = auto()

View file

@ -18,25 +18,25 @@ from src.organisation.exceptions import OrgNotFoundException
def get_org_model_query(db: DbSession, org_id: Annotated[int, Query(gt=0)]) -> Org: def get_org_model_query(db: DbSession, org_id: Annotated[int, Query(gt=0)]) -> Org:
org_model = db.get(Org, org_id) org_model = db.get(Org, org_id)
if org_model is None: if org_model is None:
raise OrgNotFoundException(org_id) raise OrgNotFoundException(org_id)
return org_model return org_model
org_model_query_dependency = Annotated[Org, Depends(get_org_model_query)] org_model_query_dependency = Annotated[Org, Depends(get_org_model_query)]
def get_org_model_body(db: DbSession, request_model: OrgIDMixin) -> Org: def get_org_model_body(db: DbSession, request_model: OrgIDMixin) -> Org:
org_id: Optional[int] = getattr(request_model, "organisation_id", None) org_id: Optional[int] = getattr(request_model, "organisation_id", None)
if org_id is None: if org_id is None:
raise OrgNotFoundException() raise OrgNotFoundException()
org_model = db.get(Org, org_id) org_model = db.get(Org, org_id)
if org_model is None: if org_model is None:
raise OrgNotFoundException(org_id) raise OrgNotFoundException(org_id)
return org_model return org_model
org_model_body_dependency = Annotated[Org, Depends(get_org_model_body)] org_model_body_dependency = Annotated[Org, Depends(get_org_model_body)]

View file

@ -12,26 +12,26 @@ from fastapi import HTTPException, status
class OrgNotFoundException(HTTPException): class OrgNotFoundException(HTTPException):
def __init__(self, org_id: Optional[int] = None) -> None: def __init__(self, org_id: Optional[int] = None) -> None:
detail = ( detail = (
"Organisation not found" "Organisation not found"
if org_id is None if org_id is None
else f"Organisation with ID '{org_id}' was not found." else f"Organisation with ID '{org_id}' was not found."
) )
super().__init__( super().__init__(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=detail, detail=detail,
) )
class AwaitingApprovalException(HTTPException): class AwaitingApprovalException(HTTPException):
def __init__(self, org_id: Optional[int] = None) -> None: def __init__(self, org_id: Optional[int] = None) -> None:
detail = ( detail = (
"Organisation has not been approved." "Organisation has not been approved."
if org_id is None if org_id is None
else f"Organisation with ID '{org_id}' has not been approved." else f"Organisation with ID '{org_id}' has not been approved."
) )
super().__init__( super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail, detail=detail,
) )

View file

@ -14,61 +14,62 @@ Models:
- OrgUsers: org_id[FK][PK], user_id[FK][PK] - OrgUsers: org_id[FK][PK], user_id[FK][PK]
""" """
from src.models import IdMixin, DeletedTimestampMixin
from typing import Any from typing import Any
from sqlalchemy import ForeignKey from sqlalchemy import ForeignKey
from sqlalchemy.orm import relationship, Mapped, mapped_column from sqlalchemy.orm import relationship, Mapped, mapped_column
from src.models import CustomBase from src.models import CustomBase, TimestampMixin
class Organisation(CustomBase): class Organisation(CustomBase, IdMixin, TimestampMixin, DeletedTimestampMixin):
__tablename__ = "organisation" __tablename__ = "organisation"
id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str]
name: Mapped[str] status: Mapped[str] = mapped_column(default="partial")
status: Mapped[str] = mapped_column(default="partial") intake_questionnaire: Mapped[dict[str, Any] | None]
intake_questionnaire: Mapped[dict[str, Any] | None]
root_user_id: Mapped[int] = mapped_column(ForeignKey("user.id")) root_user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
billing_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True) billing_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True)
security_contact_id: Mapped[int] = mapped_column( security_contact_id: Mapped[int] = mapped_column(
ForeignKey("contact.id"), nullable=True ForeignKey("contact.id"), nullable=True
) )
owner_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True) owner_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True)
user_rel = relationship("User", secondary="orgusers", back_populates="organisation_rel") user_rel = relationship("User", secondary="orgusers", back_populates="organisation_rel")
group_rel = relationship( group_rel = relationship(
"Group", back_populates="org_rel", cascade="all, delete-orphan" "Group", back_populates="org_rel", cascade="all, delete-orphan"
) )
root_user_rel = relationship("User", foreign_keys="Organisation.root_user_id") root_user_rel = relationship("User", foreign_keys="Organisation.root_user_id")
billing_contact_rel = relationship( billing_contact_rel = relationship(
"Contact", foreign_keys="Organisation.billing_contact_id" "Contact", foreign_keys="Organisation.billing_contact_id"
) )
security_contact_rel = relationship( security_contact_rel = relationship(
"Contact", foreign_keys="Organisation.security_contact_id" "Contact", foreign_keys="Organisation.security_contact_id"
) )
owner_contact_rel = relationship( owner_contact_rel = relationship(
"Contact", foreign_keys="Organisation.owner_contact_id" "Contact", foreign_keys="Organisation.owner_contact_id"
) )
permission_rel = relationship( permission_rel = relationship(
"Permission", secondary="org_permissions", back_populates="org_rel" "Permission", secondary="org_permissions", back_populates="org_rel"
) )
@property @property
def root_user_email(self) -> str: def root_user_email(self) -> str:
return self.root_user_rel.email if self.root_user_rel else "" return self.root_user_rel.email if self.root_user_rel else ""
class OrgUsers(CustomBase): class OrgUsers(CustomBase):
__tablename__ = "orgusers" __tablename__ = "orgusers"
org_id: Mapped[int] = mapped_column( org_id: Mapped[int] = mapped_column(
ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True
) )
user_id: Mapped[int] = mapped_column( user_id: Mapped[int] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
) )

File diff suppressed because it is too large Load diff

View file

@ -12,139 +12,139 @@ from datetime import datetime
from pydantic import EmailStr, ConfigDict, Field from pydantic import EmailStr, ConfigDict, Field
from src.schemas import ( from src.schemas import (
CustomBaseModel, CustomBaseModel,
OrgIDMixin, OrgIDMixin,
UserIDMixin, UserIDMixin,
GroupSummary, GroupSummary,
OrgSummary, OrgSummary,
UserSummary, UserSummary,
) )
from src.contact.schemas import ContactModel from src.contact.schemas import ContactModel
from src.organisation.constants import Status, ContactType from src.organisation.constants import Status, ContactType
from src.organisation.schemas_questionnaires import ( from src.organisation.schemas_questionnaires import (
QuestionnaireQuestionsVersion0 as CurrentQuestions, QuestionnaireQuestionsVersion0 as CurrentQuestions,
questionnaire_union, questionnaire_union,
) )
class QuestionnaireMetadata(CustomBaseModel): class QuestionnaireMetadata(CustomBaseModel):
version: int version: int
submission_date: Optional[datetime] = None submission_date: Optional[datetime] = None
class Questionnaire(CustomBaseModel): class Questionnaire(CustomBaseModel):
metadata: QuestionnaireMetadata metadata: QuestionnaireMetadata
questions: questionnaire_union questions: questionnaire_union
class ContactSummary(CustomBaseModel): class ContactSummary(CustomBaseModel):
id: int id: int
email: Optional[EmailStr] = None email: Optional[EmailStr] = None
class OrgSchema(OrgIDMixin): class OrgSchema(OrgIDMixin):
name: str name: str
status: Status status: Status
root_user_email: EmailStr root_user_email: EmailStr
intake_questionnaire: Optional[Questionnaire] = None intake_questionnaire: Optional[Questionnaire] = None
billing_contact: ContactSummary billing_contact: ContactSummary
owner_contact: ContactSummary owner_contact: ContactSummary
security_contact: ContactSummary security_contact: ContactSummary
class OrgPostOrgRequest(CustomBaseModel): class OrgPostOrgRequest(CustomBaseModel):
name: str = Field(min_length=3) name: str = Field(min_length=3)
intake_questionnaire: Optional[CurrentQuestions] = None intake_questionnaire: Optional[CurrentQuestions] = None
class OrgPostOrgResponse(CustomBaseModel): class OrgPostOrgResponse(CustomBaseModel):
id: int id: int
name: str name: str
status: Status status: Status
class OrgPatchQuestionnaireRequest(OrgIDMixin): class OrgPatchQuestionnaireRequest(OrgIDMixin):
intake_questionnaire: CurrentQuestions intake_questionnaire: CurrentQuestions
partial: bool partial: bool
class OrgPatchQuestionnaireResponse(CustomBaseModel): class OrgPatchQuestionnaireResponse(CustomBaseModel):
id: int id: int
name: str name: str
intake_questionnaire: Questionnaire intake_questionnaire: Questionnaire
status: Status status: Status
class OrgPatchStatusRequest(OrgIDMixin): class OrgPatchStatusRequest(OrgIDMixin):
status: Status status: Status
class OrgPatchStatusResponse(CustomBaseModel): class OrgPatchStatusResponse(CustomBaseModel):
id: int id: int
name: str name: str
status: Status status: Status
class OrgPatchContactRequest(OrgIDMixin): class OrgPatchContactRequest(OrgIDMixin):
contact_type: ContactType contact_type: ContactType
email: Optional[EmailStr] = None email: Optional[EmailStr] = None
first_name: Optional[str] = None first_name: Optional[str] = None
last_name: Optional[str] = None last_name: Optional[str] = None
phonenumber: Optional[str] = None phonenumber: Optional[str] = None
vat_number: Optional[str] = None vat_number: Optional[str] = None
post_office_box_number: Optional[str] = None post_office_box_number: Optional[str] = None
street_address: Optional[str] = None street_address: Optional[str] = None
street_address_line_2: Optional[str] = None street_address_line_2: Optional[str] = None
locality: Optional[str] = None locality: Optional[str] = None
address_region: Optional[str] = None address_region: Optional[str] = None
country_code: Optional[str] = None country_code: Optional[str] = None
postal_code: Optional[str] = None postal_code: Optional[str] = None
class OrgPostUserRequest(OrgIDMixin, UserIDMixin): class OrgPostUserRequest(OrgIDMixin, UserIDMixin):
pass pass
class OrgPostUserResponse(CustomBaseModel): class OrgPostUserResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
users: list[UserSummary] users: list[UserSummary]
class OrgPatchRootRequest(OrgIDMixin, UserIDMixin): class OrgPatchRootRequest(OrgIDMixin, UserIDMixin):
pass pass
class OrgPatchRootResponse(CustomBaseModel): class OrgPatchRootResponse(CustomBaseModel):
name: str name: str
root_user_email: str root_user_email: str
class OrgGetUserResponse(CustomBaseModel): class OrgGetUserResponse(CustomBaseModel):
users: list[dict[str, str | int]] users: list[dict[str, str | int]]
organisation: OrgSummary organisation: OrgSummary
class OrgGetGroupResponse(CustomBaseModel): class OrgGetGroupResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
groups: list[GroupSummary] groups: list[GroupSummary]
class OrgGetContactResponse(CustomBaseModel): class OrgGetContactResponse(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore") model_config = ConfigDict(from_attributes=True, extra="ignore")
contact: ContactModel contact: ContactModel
organisation: OrgSummary organisation: OrgSummary
class OrgPatchContactResponse(CustomBaseModel): class OrgPatchContactResponse(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore") model_config = ConfigDict(from_attributes=True, extra="ignore")
contact: ContactModel contact: ContactModel
organisation: OrgSummary organisation: OrgSummary
class OrgGetOrgResponse(CustomBaseModel): class OrgGetOrgResponse(CustomBaseModel):
organisations: list[OrgSchema] organisations: list[OrgSchema]

View file

@ -4,13 +4,13 @@ from src.schemas import CustomBaseModel
class QuestionnaireQuestions(CustomBaseModel): class QuestionnaireQuestions(CustomBaseModel):
pass pass
class QuestionnaireQuestionsVersion0(QuestionnaireQuestions): class QuestionnaireQuestionsVersion0(QuestionnaireQuestions):
question_one: Optional[str] = None question_one: Optional[str] = None
question_two: Optional[str] = None question_two: Optional[str] = None
question_three: Optional[str] = None question_three: Optional[str] = None
questionnaire_union = QuestionnaireQuestionsVersion0 # | QuestionnaireQuestionsVersion1 questionnaire_union = QuestionnaireQuestionsVersion0 # | QuestionnaireQuestionsVersion1

View file

@ -11,57 +11,57 @@ from src.user.models import User
async def add_default_org_permissions( async def add_default_org_permissions(
db: Session, db: Session,
org_model: Org, org_model: Org,
perm_list: list[int], perm_list: list[int],
): ):
for permission in perm_list: for permission in perm_list:
perm_model = db.get(Perm, permission) perm_model = db.get(Perm, permission)
if perm_model is None: if perm_model is None:
continue continue
if perm_model in org_model.permission_rel: if perm_model in org_model.permission_rel:
continue continue
org_model.permission_rel.append(perm_model) org_model.permission_rel.append(perm_model)
db.flush() db.flush()
db.commit() db.commit()
async def assign_defaults( async def assign_defaults(
db: Session, db: Session,
org_id: int, org_id: int,
user_id: int, user_id: int,
): ):
default_org_permissions = [] default_org_permissions = []
default_user_permissions = [] default_user_permissions = []
org_model = db.get(Org, org_id) org_model = db.get(Org, org_id)
if org_model is None: if org_model is None:
print("Org not found while adding defaults") print("Org not found while adding defaults")
return return
user_model = db.get(User, user_id) user_model = db.get(User, user_id)
if user_model is None: if user_model is None:
print("User not found while adding defaults") print("User not found while adding defaults")
return return
await add_default_org_permissions(db, org_model, default_org_permissions) await add_default_org_permissions(db, org_model, default_org_permissions)
await assign_default_group( await assign_default_group(
db=db, db=db,
org_model=org_model, org_model=org_model,
user_model=user_model, user_model=user_model,
group_name="Default Users", group_name="Default Users",
perm_list=default_user_permissions, perm_list=default_user_permissions,
) )
await assign_default_group( await assign_default_group(
db=db, db=db,
org_model=org_model, org_model=org_model,
user_model=user_model, user_model=user_model,
group_name="Root User", group_name="Root User",
perm_list=default_org_permissions, perm_list=default_org_permissions,
) )
db.commit() db.commit()

View file

@ -11,54 +11,54 @@ from typing import Optional
class CustomBaseModel(BaseModel): class CustomBaseModel(BaseModel):
pass pass
### Mixins ### ### Mixins ###
class OrgIDMixin(CustomBaseModel): class OrgIDMixin(CustomBaseModel):
organisation_id: int = Field(gt=0) organisation_id: int = Field(gt=0)
class GroupIDMixin(CustomBaseModel): class GroupIDMixin(CustomBaseModel):
group_id: int = Field(gt=0) group_id: int = Field(gt=0)
class PermIDMixin(CustomBaseModel): class PermIDMixin(CustomBaseModel):
permission_id: int = Field(gt=0) permission_id: int = Field(gt=0)
class ServiceIDMixin(CustomBaseModel): class ServiceIDMixin(CustomBaseModel):
service_id: int = Field(gt=0) service_id: int = Field(gt=0)
class UserIDMixin(CustomBaseModel): class UserIDMixin(CustomBaseModel):
user_id: int = Field(gt=0) user_id: int = Field(gt=0)
class ServiceNameMixin(CustomBaseModel): class ServiceNameMixin(CustomBaseModel):
service: str service: str
class OrgSummary(CustomBaseModel): class OrgSummary(CustomBaseModel):
id: int id: int
name: str name: str
class GroupSummary(CustomBaseModel): class GroupSummary(CustomBaseModel):
id: int id: int
name: str name: str
class UserSummary(CustomBaseModel): class UserSummary(CustomBaseModel):
id: int id: int
email: str email: str
class ServiceSummary(CustomBaseModel): class ServiceSummary(CustomBaseModel):
id: int id: int
name: str name: str
class ResourceName(ServiceNameMixin, OrgIDMixin): class ResourceName(ServiceNameMixin, OrgIDMixin):
resource: str resource: str
instance: Optional[str] = None instance: Optional[str] = None

View file

@ -16,25 +16,23 @@ from src.service.models import Service
from src.service.schemas import ServiceIDMixin from src.service.schemas import ServiceIDMixin
async def get_service_model_query( async def get_service_model_query(db: DbSession, service_id: Annotated[int, Query(gt=0)]):
db: DbSession, service_id: Annotated[int, Query(gt=0)] service_model = db.get(Service, service_id)
): if service_model is None:
service_model = db.get(Service, service_id) raise ServiceNotFoundException(service_id=service_id)
if service_model is None:
raise ServiceNotFoundException(service_id=service_id)
return service_model return service_model
service_model_query_dependency = Annotated[Service, Depends(get_service_model_query)] service_model_query_dependency = Annotated[Service, Depends(get_service_model_query)]
async def get_service_model_body(db: DbSession, request_model: ServiceIDMixin): async def get_service_model_body(db: DbSession, request_model: ServiceIDMixin):
service_model = db.get(Service, request_model.service_id) service_model = db.get(Service, request_model.service_id)
if service_model is None: if service_model is None:
raise ServiceNotFoundException(service_id=request_model.service_id) raise ServiceNotFoundException(service_id=request_model.service_id)
return service_model return service_model
service_model_body_dependency = Annotated[Service, Depends(get_service_model_body)] service_model_body_dependency = Annotated[Service, Depends(get_service_model_body)]

View file

@ -11,13 +11,13 @@ from fastapi import HTTPException, status
class ServiceNotFoundException(HTTPException): class ServiceNotFoundException(HTTPException):
def __init__(self, service_id: Optional[int] = None) -> None: def __init__(self, service_id: Optional[int] = None) -> None:
detail = ( detail = (
"Service not found" "Service not found"
if service_id is None if service_id is None
else f"Service with ID '{service_id}' was not found." else f"Service with ID '{service_id}' was not found."
) )
super().__init__( super().__init__(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=detail, detail=detail,
) )

View file

@ -8,16 +8,15 @@ Models:
from sqlalchemy.orm import relationship, mapped_column, Mapped from sqlalchemy.orm import relationship, mapped_column, Mapped
from src.models import CustomBase from src.models import CustomBase, IdMixin
class Service(CustomBase): class Service(CustomBase, IdMixin):
__tablename__ = "service" __tablename__ = "service"
id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(unique=True)
name: Mapped[str] = mapped_column(unique=True) api_key: Mapped[str]
api_key: Mapped[str]
permission_rel = relationship( permission_rel = relationship(
"Permission", back_populates="service_rel", cascade="all, delete-orphan" "Permission", back_populates="service_rel", cascade="all, delete-orphan"
) )

View file

@ -15,8 +15,8 @@ from psycopg.errors import UniqueViolation
from src.exceptions import ConflictException from src.exceptions import ConflictException
from src.database import DbSession from src.database import DbSession
from src.auth.dependencies import ( from src.auth.dependencies import (
super_admin_dependency, super_admin_dependency,
org_model_root_claim_query_dependency, org_model_root_claim_query_dependency,
) )
from src.iam.service import service_key_dependency from src.iam.service import service_key_dependency
from src.iam.models import Permission as Perm from src.iam.models import Permission as Perm
@ -25,212 +25,210 @@ from src.service.exceptions import ServiceNotFoundException
from src.service.models import Service from src.service.models import Service
from src.service.utils import generate_api_key from src.service.utils import generate_api_key
from src.service.dependencies import ( from src.service.dependencies import (
service_model_body_dependency, service_model_body_dependency,
service_model_query_dependency, service_model_query_dependency,
) )
from src.service.schemas import ( from src.service.schemas import (
ServiceGetServiceResponse, ServiceGetServiceResponse,
ServicePostServiceRequest, ServicePostServiceRequest,
ServicePostServiceResponse, ServicePostServiceResponse,
ServiceWithKeySchema, ServiceWithKeySchema,
ServicePatchKeyResponse, ServicePatchKeyResponse,
ServicePatchKeyRequest, ServicePatchKeyRequest,
ServicePostPermissionsResponse, ServicePostPermissionsResponse,
ServicePostPermissionsRequest, ServicePostPermissionsRequest,
) )
router = APIRouter( router = APIRouter(
tags=["Service"], tags=["Service"],
prefix="/service", prefix="/service",
) )
@router.get( @router.get(
"", "",
summary="Get all services", summary="Get all services",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=ServiceGetServiceResponse, response_model=ServiceGetServiceResponse,
responses={ responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"}, status.HTTP_200_OK: {"description": "Successful retrieval from database"},
status.HTTP_401_UNAUTHORIZED: { status.HTTP_401_UNAUTHORIZED: {
"description": "Unauthorized", "description": "Unauthorized",
"content": { "content": {
"application/json": { "application/json": {
"examples": { "examples": {
"awaiting_approval": { "awaiting_approval": {
"summary": "Organisation has not yet been approved." "summary": "Organisation has not yet been approved."
}, },
} }
} }
}, },
}, },
status.HTTP_403_FORBIDDEN: { status.HTTP_403_FORBIDDEN: {
"description": "Forbidden", "description": "Forbidden",
"content": { "content": {
"application/json": { "application/json": {
"examples": { "examples": {
"not_root": {"summary": "Not authorised. Must be root user."}, "not_root": {"summary": "Not authorised. Must be root user."},
} }
} }
}, },
}, },
}, },
) )
async def get_all_services( async def get_all_services(db: DbSession, org_model: org_model_root_claim_query_dependency):
db: DbSession, org_model: org_model_root_claim_query_dependency """
): Returns the ID and name of all services registered to the hub.
""" """
Returns the ID and name of all services registered to the hub. permission_models = db.query(Service).all()
"""
permission_models = db.query(Service).all()
return {"services": permission_models} return {"services": permission_models}
@router.post( @router.post(
"", "",
summary="Register a new service.", summary="Register a new service.",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=ServicePostServiceResponse, response_model=ServicePostServiceResponse,
responses={ responses={
status.HTTP_200_OK: {"description": "Successfully registered a new service"}, status.HTTP_200_OK: {"description": "Successfully registered a new service"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
status.HTTP_409_CONFLICT: {"description": "Service with this name already exists"}, status.HTTP_409_CONFLICT: {"description": "Service with this name already exists"},
}, },
) )
async def register_service( async def register_service(
db: DbSession, db: DbSession,
su: super_admin_dependency, su: super_admin_dependency,
request_model: ServicePostServiceRequest, request_model: ServicePostServiceRequest,
): ):
""" """
Registers a new service to the hub, generating and returning an API key for it. Registers a new service to the hub, generating and returning an API key for it.
""" """
key = generate_api_key() key = generate_api_key()
service_model = Service(name=request_model.name, api_key=key) service_model = Service(name=request_model.name, api_key=key)
db.add(service_model) db.add(service_model)
try: try:
db.flush() db.flush()
except IntegrityError as e: except IntegrityError as e:
if ( if (
isinstance(e.orig, UniqueViolation) # Postgres unique violation isinstance(e.orig, UniqueViolation) # Postgres unique violation
or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation
): ):
raise ConflictException(message="Service with this name already exists") raise ConflictException(message="Service with this name already exists")
raise raise
response = ServiceWithKeySchema(**service_model.__dict__) response = ServiceWithKeySchema(**service_model.__dict__)
db.commit() db.commit()
return {"service": response} return {"service": response}
@router.patch( @router.patch(
"/key", "/key",
summary="Regenerate service API key.", summary="Regenerate service API key.",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=ServicePatchKeyResponse, response_model=ServicePatchKeyResponse,
responses={ responses={
status.HTTP_200_OK: {"description": "Successful update of API key"}, status.HTTP_200_OK: {"description": "Successful update of API key"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
}, },
) )
async def regenerate_api_key( async def regenerate_api_key(
db: DbSession, db: DbSession,
su: super_admin_dependency, su: super_admin_dependency,
service_model: service_model_body_dependency, service_model: service_model_body_dependency,
request_model: ServicePatchKeyRequest, request_model: ServicePatchKeyRequest,
): ):
""" """
Generates and returns a new API key for the service to access the hub. Generates and returns a new API key for the service to access the hub.
""" """
key = generate_api_key() key = generate_api_key()
service_model.api_key = key service_model.api_key = key
db.flush() db.flush()
response = ServiceWithKeySchema(**service_model.__dict__) response = ServiceWithKeySchema(**service_model.__dict__)
db.commit() db.commit()
return {"service": response} return {"service": response}
@router.delete( @router.delete(
"", "",
summary="Remove a service.", summary="Remove a service.",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
responses={ responses={
status.HTTP_204_NO_CONTENT: {"description": "Successfully removed service from db"}, status.HTTP_204_NO_CONTENT: {"description": "Successfully removed service from db"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
}, },
) )
async def remove_service( async def remove_service(
db: DbSession, db: DbSession,
service_model: service_model_query_dependency, service_model: service_model_query_dependency,
su: super_admin_dependency, su: super_admin_dependency,
): ):
""" """
Removes a service from the hub. Removes a service from the hub.
""" """
db.delete(service_model) db.delete(service_model)
db.commit() db.commit()
@router.post( @router.post(
path="/permissions", path="/permissions",
summary="Service endpoint for creating its own permissions.", summary="Service endpoint for creating its own permissions.",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=ServicePostPermissionsResponse, response_model=ServicePostPermissionsResponse,
responses={ responses={
status.HTTP_401_UNAUTHORIZED: { status.HTTP_401_UNAUTHORIZED: {
"description": "API Key missing or invalid | Issue verifying user OIDC claims" "description": "API Key missing or invalid | Issue verifying user OIDC claims"
}, },
}, },
) )
async def service_create_new_permissions( async def service_create_new_permissions(
db: DbSession, db: DbSession,
request_model: ServicePostPermissionsRequest, request_model: ServicePostPermissionsRequest,
valid_key: service_key_dependency, valid_key: service_key_dependency,
): ):
""" """
Allows a service to register its own set of permissions. Allows a service to register its own set of permissions.
""" """
service_model = ( service_model = (
db.query(Service).filter(Service.name == request_model.rn.service).first() db.query(Service).filter(Service.name == request_model.rn.service).first()
) )
if service_model is None: if service_model is None:
raise ServiceNotFoundException() raise ServiceNotFoundException()
else: else:
service_id = service_model.id service_id = service_model.id
response_list = [] response_list = []
for new_permission in request_model.permissions: for new_permission in request_model.permissions:
perm_model = ( perm_model = (
db.query(Perm) db.query(Perm)
.filter(Perm.service_id == service_id) .filter(Perm.service_id == service_id)
.filter(Perm.resource == new_permission.resource) .filter(Perm.resource == new_permission.resource)
.filter(Perm.action == new_permission.action) .filter(Perm.action == new_permission.action)
.first() .first()
) )
if perm_model is not None: if perm_model is not None:
response_code = 409 response_code = 409
response = { response = {
"id": perm_model.id, "id": perm_model.id,
"service_name": perm_model.service_name, "service_name": perm_model.service_name,
"resource": perm_model.resource, "resource": perm_model.resource,
"action": perm_model.action, "action": perm_model.action,
} }
response_list.append((response, response_code)) response_list.append((response, response_code))
continue continue
new_perm_model = Perm(**new_permission.__dict__) new_perm_model = Perm(**new_permission.__dict__)
new_perm_model.service_id = service_id new_perm_model.service_id = service_id
db.add(new_perm_model) db.add(new_perm_model)
db.flush() db.flush()
response_code = 201 response_code = 201
response = { response = {
"id": new_perm_model.id, "id": new_perm_model.id,
"service_name": new_perm_model.service_name, "service_name": new_perm_model.service_name,
"resource": new_perm_model.resource, "resource": new_perm_model.resource,
"action": new_perm_model.action, "action": new_perm_model.action,
} }
response_list.append((response, response_code)) response_list.append((response, response_code))
db.commit() db.commit()
return {"permissions": response_list} return {"permissions": response_list}

View file

@ -10,10 +10,10 @@ from typing import Generic, TypeVar
from pydantic import Field, ConfigDict from pydantic import Field, ConfigDict
from src.schemas import ( from src.schemas import (
CustomBaseModel, CustomBaseModel,
ServiceIDMixin, ServiceIDMixin,
ServiceSummary, ServiceSummary,
ServiceNameMixin, ServiceNameMixin,
) )
@ -21,51 +21,51 @@ T = TypeVar("T", bound=ServiceNameMixin)
class HasServiceName(CustomBaseModel, Generic[T]): class HasServiceName(CustomBaseModel, Generic[T]):
rn: T rn: T
class PermissionResponseSchema(CustomBaseModel): class PermissionResponseSchema(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore") model_config = ConfigDict(from_attributes=True, extra="ignore")
id: int id: int
service_name: str service_name: str
resource: str resource: str
action: str action: str
class PermissionRequestSchema(CustomBaseModel): class PermissionRequestSchema(CustomBaseModel):
resource: str resource: str
action: str action: str
class ServiceWithKeySchema(ServiceSummary): class ServiceWithKeySchema(ServiceSummary):
api_key: str api_key: str
class ServiceGetServiceResponse(CustomBaseModel): class ServiceGetServiceResponse(CustomBaseModel):
services: list[ServiceSummary] services: list[ServiceSummary]
class ServicePostServiceRequest(CustomBaseModel): class ServicePostServiceRequest(CustomBaseModel):
name: str = Field(min_length=3) name: str = Field(min_length=3)
class ServicePostServiceResponse(CustomBaseModel): class ServicePostServiceResponse(CustomBaseModel):
service: ServiceWithKeySchema service: ServiceWithKeySchema
class ServicePatchKeyRequest(ServiceIDMixin): class ServicePatchKeyRequest(ServiceIDMixin):
pass pass
class ServicePatchKeyResponse(CustomBaseModel): class ServicePatchKeyResponse(CustomBaseModel):
service: ServiceWithKeySchema service: ServiceWithKeySchema
class ServicePostPermissionsRequest(CustomBaseModel): class ServicePostPermissionsRequest(CustomBaseModel):
rn: ServiceNameMixin rn: ServiceNameMixin
permissions: list[PermissionRequestSchema] permissions: list[PermissionRequestSchema]
class ServicePostPermissionsResponse(CustomBaseModel): class ServicePostPermissionsResponse(CustomBaseModel):
permissions: list[tuple[PermissionResponseSchema, int]] permissions: list[tuple[PermissionResponseSchema, int]]

View file

@ -9,4 +9,4 @@ import uuid
def generate_api_key() -> str: def generate_api_key() -> str:
return str(uuid.uuid4()) return str(uuid.uuid4())

View file

@ -10,46 +10,50 @@ Exports:
from typing import Annotated from typing import Annotated
from fastapi import Depends, Query from fastapi import Depends, Query
from src.user.exceptions import UserNotFoundException
from src.user.models import User
from src.auth.service import claims_dependency from src.auth.service import claims_dependency
from src.database import DbSession from src.database import DbSession
from src.schemas import UserIDMixin from src.schemas import UserIDMixin
from src.exceptions import ForbiddenException
from src.user.exceptions import UserNotFoundException
from src.user.models import User
async def get_user_model_claims(claims: claims_dependency, db: DbSession): async def get_user_model_claims(claims: claims_dependency, db: DbSession):
user_id = claims.get("db_id", None) user_id = claims.get("db_id", None)
if user_id is None: if user_id is None:
raise UserNotFoundException() raise UserNotFoundException()
user_model = db.get(User, user_id) user_model = db.get(User, user_id)
if user_model is None: if user_model is None:
raise UserNotFoundException(user_id=user_id) raise UserNotFoundException(user_id=user_id)
return user_model if not user_model.active:
raise ForbiddenException("User account is not active")
return user_model
user_model_claims_dependency = Annotated[User, Depends(get_user_model_claims)] user_model_claims_dependency = Annotated[User, Depends(get_user_model_claims)]
async def get_user_model_query(db: DbSession, user_id: Annotated[int, Query(gt=0)]): async def get_user_model_query(db: DbSession, user_id: Annotated[int, Query(gt=0)]):
user_model = db.get(User, user_id) user_model = db.get(User, user_id)
if user_model is None: if user_model is None:
raise UserNotFoundException(user_id=user_id) raise UserNotFoundException(user_id=user_id)
return user_model return user_model
user_model_query_dependency = Annotated[User, Depends(get_user_model_query)] user_model_query_dependency = Annotated[User, Depends(get_user_model_query)]
async def get_user_model_body(db: DbSession, request_model: UserIDMixin): async def get_user_model_body(db: DbSession, request_model: UserIDMixin):
user_model = db.get(User, request_model.user_id) user_model = db.get(User, request_model.user_id)
if user_model is None: if user_model is None:
raise UserNotFoundException(user_id=request_model.user_id) raise UserNotFoundException(user_id=request_model.user_id)
return user_model return user_model
user_model_body_dependency = Annotated[User, Depends(get_user_model_body)] user_model_body_dependency = Annotated[User, Depends(get_user_model_body)]

View file

@ -11,13 +11,13 @@ from fastapi import HTTPException, status
class UserNotFoundException(HTTPException): class UserNotFoundException(HTTPException):
def __init__(self, user_id: Optional[int] = None) -> None: def __init__(self, user_id: Optional[int] = None) -> None:
detail = ( detail = (
"User not found" "User not found"
if user_id is None if user_id is None
else f"User with ID '{user_id}' was not found." else f"User with ID '{user_id}' was not found."
) )
super().__init__( super().__init__(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=detail, detail=detail,
) )

View file

@ -10,6 +10,8 @@ Models:
- groups: Calc property dict of {group_rel.org_rel.name: group_rel.name} - groups: Calc property dict of {group_rel.org_rel.name: group_rel.name}
""" """
from src.models import IdMixin, ActivatedMixin, TimestampMixin, DeletedTimestampMixin
from collections import defaultdict from collections import defaultdict
from sqlalchemy.orm import relationship, mapped_column, Mapped from sqlalchemy.orm import relationship, mapped_column, Mapped
@ -17,28 +19,27 @@ from sqlalchemy.orm import relationship, mapped_column, Mapped
from src.models import CustomBase from src.models import CustomBase
class User(CustomBase): class User(CustomBase, IdMixin, ActivatedMixin, TimestampMixin, DeletedTimestampMixin):
__tablename__ = "user" __tablename__ = "user"
id: Mapped[int] = mapped_column(primary_key=True) email: Mapped[str]
email: Mapped[str] first_name: Mapped[str]
first_name: Mapped[str] last_name: Mapped[str]
last_name: Mapped[str] oidc_id: Mapped[str] = mapped_column(index=True, unique=True)
oidc_id: Mapped[str] = mapped_column(index=True, unique=True)
organisation_rel = relationship( organisation_rel = relationship(
"Organisation", secondary="orgusers", back_populates="user_rel" "Organisation", secondary="orgusers", back_populates="user_rel"
) )
group_rel = relationship("Group", secondary="user_groups", back_populates="user_rel") group_rel = relationship("Group", secondary="user_groups", back_populates="user_rel")
@property @property
def organisations(self): def organisations(self):
return [{"name": org.name, "id": org.id} for org in self.organisation_rel] return [{"name": org.name, "id": org.id} for org in self.organisation_rel]
@property @property
def groups(self): def groups(self):
result = defaultdict(list) result = defaultdict(list)
for group in self.group_rel: for group in self.group_rel:
result[group.org_rel.name].append({"name": group.name, "id": group.id}) result[group.org_rel.name].append({"name": group.name, "id": group.id})
return dict(result) return dict(result)

View file

@ -8,210 +8,213 @@ Endpoints:
- [DELETE](/user/): [super admin]: Removes a User(id) from the hub database. - [DELETE](/user/): [super admin]: Removes a User(id) from the hub database.
""" """
from datetime import datetime, timezone
from fastapi import APIRouter, status, BackgroundTasks from fastapi import APIRouter, status, BackgroundTasks
from src.iam.models import Group from src.iam.models import Group
from src.organisation.exceptions import OrgNotFoundException from src.organisation.exceptions import OrgNotFoundException
from src.user.schemas import ( from src.user.schemas import (
UserResponse, UserResponse,
OIDCClaims, OIDCClaims,
UserPostInvitationRequest, UserPostInvitationRequest,
UserPostInvitationAcceptRequest, UserPostInvitationAcceptRequest,
UserGetSelfOrgsResponse, UserGetSelfOrgsResponse,
UserPostInvitationResponse, UserPostInvitationResponse,
UserPostInvitationAcceptResponse, UserPostInvitationAcceptResponse,
) )
from src.user.dependencies import ( from src.user.dependencies import (
user_model_claims_dependency, user_model_claims_dependency,
user_model_query_dependency, user_model_query_dependency,
) )
from src.user.service import send_invitation from src.user.service import send_invitation
from src.organisation.models import Organisation as Org from src.organisation.models import Organisation as Org
from src.auth.dependencies import ( from src.auth.dependencies import (
super_admin_dependency, super_admin_dependency,
org_model_root_claim_body_dependency, org_model_root_claim_body_dependency,
) )
from src.auth.service import claims_dependency from src.auth.service import claims_dependency
from src.database import DbSession from src.database import DbSession
from src.utils import verify_email_token from src.utils import verify_email_token
router = APIRouter( router = APIRouter(
prefix="/user", prefix="/user",
tags=["User"], tags=["User"],
) )
@router.get( @router.get(
"/self/claims", "/self/claims",
summary="Get current user OIDC claims.", summary="Get current user OIDC claims.",
response_model=OIDCClaims, response_model=OIDCClaims,
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
responses={ responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"}, status.HTTP_200_OK: {"description": "Successful retrieval from database"},
}, },
) )
async def current_user_claims(user: claims_dependency): async def current_user_claims(user: claims_dependency):
""" """
Returns the full OIDC claims associated with the currently logged-in user. Returns the full OIDC claims associated with the currently logged-in user.
""" """
user["allowed_origins"] = user.get("allowed-origins", []) user["allowed_origins"] = user.get("allowed-origins", [])
return user return user
@router.get( @router.get(
"/self/db", "/self/db",
summary="Get current user hub details.", summary="Get current user hub details.",
response_model=UserResponse, response_model=UserResponse,
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
responses={ responses={
status.HTTP_404_NOT_FOUND: {"description": "User not found"}, status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"}, status.HTTP_200_OK: {"description": "Successful retrieval from database"},
}, },
) )
async def current_user(user_model: user_model_claims_dependency): async def current_user(user_model: user_model_claims_dependency):
""" """
Returns the database details associated with the currently logged-in user. Returns the database details associated with the currently logged-in user.
""" """
return user_model return user_model
@router.get( @router.get(
"", "",
summary="Get user hub details by ID.", summary="Get user hub details by ID.",
response_model=UserResponse, response_model=UserResponse,
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
responses={ responses={
status.HTTP_404_NOT_FOUND: {"description": "User not found"}, status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"}, status.HTTP_200_OK: {"description": "Successful retrieval from database"},
}, },
) )
async def get_user_by_id( async def get_user_by_id(
user_model: user_model_query_dependency, su: super_admin_dependency user_model: user_model_query_dependency, su: super_admin_dependency
): ):
""" """
Returns the database details associated with the provided user ID. Returns the database details associated with the provided user ID.
""" """
return user_model return user_model
@router.delete( @router.delete(
"", "",
summary="Delete user from hub by ID.", summary="Delete user from hub by ID.",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
responses={ responses={
status.HTTP_204_NO_CONTENT: {"description": "User deleted"}, status.HTTP_204_NO_CONTENT: {"description": "User deleted"},
status.HTTP_404_NOT_FOUND: {"description": "User not found"}, status.HTTP_404_NOT_FOUND: {"description": "User not found"},
}, },
) )
async def delete_user_by_id( async def soft_delete_user_by_id(
db: DbSession, db: DbSession,
user_model: user_model_query_dependency, user_model: user_model_query_dependency,
su: super_admin_dependency, su: super_admin_dependency,
): ):
""" """
Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login. Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login.
""" """
db.delete(user_model) user_model.active = False
db.commit() user_model.deleted_at = datetime.now(tz=timezone.utc)
db.commit()
@router.get( @router.get(
"/self/orgs", "/self/orgs",
summary="Get all orgs the current user is a member of", summary="Get all orgs the current user is a member of",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=UserGetSelfOrgsResponse, response_model=UserGetSelfOrgsResponse,
responses={}, responses={},
) )
async def get_user_orgs(user_model: user_model_claims_dependency): async def get_user_orgs(user_model: user_model_claims_dependency):
user_orgs = user_model.organisation_rel user_orgs = user_model.organisation_rel
response = [] response = []
for org in user_orgs: for org in user_orgs:
response.append( response.append(
{ {
"organisation_id": org.id, "organisation_id": org.id,
"name": org.name, "name": org.name,
"status": org.status, "status": org.status,
"intake_questionnaire": org.intake_questionnaire, "intake_questionnaire": org.intake_questionnaire,
"root_user_email": org.root_user_email, "root_user_email": org.root_user_email,
"billing_contact": { "billing_contact": {
"id": org.billing_contact_id, "id": org.billing_contact_id,
"email": org.billing_contact_rel.email, "email": org.billing_contact_rel.email,
}, },
"owner_contact": { "owner_contact": {
"id": org.owner_contact_id, "id": org.owner_contact_id,
"email": org.owner_contact_rel.email, "email": org.owner_contact_rel.email,
}, },
"security_contact": { "security_contact": {
"id": org.security_contact_id, "id": org.security_contact_id,
"email": org.security_contact_rel.email, "email": org.security_contact_rel.email,
}, },
} }
) )
return {"organisations": response} return {"organisations": response}
@router.post( @router.post(
"/invitation", "/invitation",
summary="Send an email invitation for a user to join an org", summary="Send an email invitation for a user to join an org",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=UserPostInvitationResponse, response_model=UserPostInvitationResponse,
) )
async def invitation( async def invitation(
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
org_model: org_model_root_claim_body_dependency, org_model: org_model_root_claim_body_dependency,
request_model: UserPostInvitationRequest, request_model: UserPostInvitationRequest,
): ):
org_id = org_model.id org_id = org_model.id
org_name = org_model.name org_name = org_model.name
user_email = request_model.user_email user_email = request_model.user_email
background_tasks.add_task( background_tasks.add_task(
send_invitation, org_id=org_id, org_name=org_name, user_email=user_email send_invitation, org_id=org_id, org_name=org_name, user_email=user_email
) )
response = { response = {
"organisation": org_model, "organisation": org_model,
"invited_email": user_email, "invited_email": user_email,
} }
return response return response
@router.post( @router.post(
"/invitation/accept", "/invitation/accept",
summary="Accept email invitation to join an org", summary="Accept email invitation to join an org",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=UserPostInvitationAcceptResponse, response_model=UserPostInvitationAcceptResponse,
) )
async def accept_invitation( async def accept_invitation(
db: DbSession, db: DbSession,
user_model: user_model_claims_dependency, user_model: user_model_claims_dependency,
request_model: UserPostInvitationAcceptRequest, request_model: UserPostInvitationAcceptRequest,
): ):
email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model) email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model)
org_model = db.get(Org, email_claims["org_id"]) org_model = db.get(Org, email_claims["org_id"])
if org_model is None: if org_model is None:
raise OrgNotFoundException() raise OrgNotFoundException()
org_model.user_rel.append(user_model) org_model.user_rel.append(user_model)
db.flush() db.flush()
group_model = ( group_model = (
db.query(Group) db.query(Group)
.filter(Group.org_id == org_model.id) .filter(Group.org_id == org_model.id)
.filter(Group.name == "Default Users") .filter(Group.name == "Default Users")
.first() .first()
) )
if group_model is not None: if group_model is not None:
user_model.group_rel.append(group_model) user_model.group_rel.append(group_model)
response = { response = {
"organisation": org_model, "organisation": org_model,
"user": user_model, "user": user_model,
} }
db.commit() db.commit()
return response return response

View file

@ -10,63 +10,63 @@ from src.schemas import CustomBaseModel, OrgIDMixin, OrgSummary, UserSummary
class OIDCClaims(CustomBaseModel): class OIDCClaims(CustomBaseModel):
exp: int exp: int
iat: int iat: int
auth_time: int auth_time: int
jti: str jti: str
iss: str iss: str
aud: str aud: str
sub: str sub: str
typ: str typ: str
azp: str azp: str
sid: str sid: str
acr: str acr: str
allowed_origins: list[str] allowed_origins: list[str]
realm_access: dict[str, list[str]] realm_access: dict[str, list[str]]
resource_access: dict[str, dict[str, list[str]]] resource_access: dict[str, dict[str, list[str]]]
scope: str scope: str
email_verified: bool email_verified: bool
name: str name: str
preferred_username: str preferred_username: str
given_name: str given_name: str
family_name: str family_name: str
email: str email: str
db_id: int db_id: int
class OIDCUser(CustomBaseModel): class OIDCUser(CustomBaseModel):
first_name: str first_name: str
last_name: str last_name: str
email: str email: str
oidc_id: str oidc_id: str
class UserResponse(CustomBaseModel): class UserResponse(CustomBaseModel):
id: int id: int
first_name: str first_name: str
last_name: str last_name: str
email: str email: str
organisations: list[Optional[dict[str, str | int]]] organisations: list[Optional[dict[str, str | int]]]
groups: Optional[dict[str, list[dict[str, str | int]]]] = None groups: Optional[dict[str, list[dict[str, str | int]]]] = None
class UserPostInvitationRequest(OrgIDMixin): class UserPostInvitationRequest(OrgIDMixin):
user_email: EmailStr user_email: EmailStr
class UserPostInvitationAcceptRequest(CustomBaseModel): class UserPostInvitationAcceptRequest(CustomBaseModel):
jwt: str jwt: str
class UserGetSelfOrgsResponse(CustomBaseModel): class UserGetSelfOrgsResponse(CustomBaseModel):
organisations: list[OrgSchema] organisations: list[OrgSchema]
class UserPostInvitationResponse(CustomBaseModel): class UserPostInvitationResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
invited_email: EmailStr invited_email: EmailStr
class UserPostInvitationAcceptResponse(CustomBaseModel): class UserPostInvitationAcceptResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
user: UserSummary user: UserSummary

View file

@ -1,13 +1,11 @@
""" """
Module specific business logic for user module Module specific business logic for user module
Exports:
- add_user_to_db: Creates a User record from OIDC claims, or updates user details
""" """
from typing import Any from typing import Any
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
import logging
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from src.exceptions import UnprocessableContentException from src.exceptions import UnprocessableContentException
@ -17,57 +15,50 @@ from src.user.schemas import OIDCUser
from src.user.models import User from src.user.models import User
async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int: async def add_user(db: Session, user_claims: dict[str, Any]) -> int:
try: try:
valid_user = OIDCUser( valid_user = OIDCUser(
first_name=user_claims["given_name"], first_name=user_claims["given_name"],
last_name=user_claims["family_name"], last_name=user_claims["family_name"],
email=user_claims["email"], email=user_claims["email"],
oidc_id=user_claims["sub"], oidc_id=user_claims["sub"],
) )
except Exception as e: except Exception as e:
print(e) logging.exception(e)
raise UnprocessableContentException("Invalid or missing OIDC data") raise UnprocessableContentException("Invalid or missing OIDC data")
db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first() db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first()
if not db_user: if not db_user:
user_model = User(**valid_user.model_dump()) user_model = User(**valid_user.model_dump())
db.add(user_model) db.add(user_model)
user_id = user_model.id user_id = user_model.id
db.commit() db.commit()
return user_id return user_id
else:
user_id = db_user.id user_id = db_user.id
change = False db_user.first_name = valid_user.first_name
if db_user.first_name != valid_user.first_name: db_user.last_name = valid_user.last_name
db_user.first_name = valid_user.first_name db.commit()
change = True return user_id
if db_user.last_name != valid_user.last_name:
db_user.last_name = valid_user.last_name
change = True
if change:
db.add(db_user)
db.commit()
return user_id
async def send_invitation(user_email: str, org_name: str, org_id: int): async def send_invitation(user_email: str, org_name: str, org_id: int):
expiry_delta = timedelta(hours=24) expiry_delta = timedelta(hours=24)
expiry = datetime.now(timezone.utc) + expiry_delta expiry = datetime.now(timezone.utc) + expiry_delta
claims = { claims = {
"email": user_email, "email": user_email,
"org_id": org_id, "org_id": org_id,
"exp": expiry, "exp": expiry,
"type": "org_invite", "type": "org_invite",
} }
token = await generate_jwt(claims) token = await generate_jwt(claims)
subject = f"You have been invited to join {org_name}" subject = f"You have been invited to join {org_name}"
body = f"You have been invited to join {org_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}" body = f"You have been invited to join {org_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
await send_email( await send_email(
recipient=user_email, recipient=user_email,
subject=subject, subject=subject,
body=body, body=body,
) )

View file

@ -1,3 +1,5 @@
import logging
from lettermint import Lettermint, ValidationError from lettermint import Lettermint, ValidationError
from datetime import datetime, timezone from datetime import datetime, timezone
from joserfc import jwt, jwk, errors from joserfc import jwt, jwk, errors
@ -9,51 +11,56 @@ KEY = jwk.import_key(settings.SECRET_KEY.get_secret_value(), "oct")
async def generate_jwt(claims): async def generate_jwt(claims):
jwt_token = jwt.encode(header={"alg": "HS256"}, key=KEY, claims=claims) jwt_token = jwt.encode(header={"alg": "HS256"}, key=KEY, claims=claims)
return jwt_token return jwt_token
async def decode_jwt(encoded): async def decode_jwt(encoded):
try: try:
token = jwt.decode(encoded, key=KEY) token = jwt.decode(encoded, key=KEY)
return token.claims return token.claims
except errors.DecodeError: except errors.DecodeError:
raise UnauthorizedException("Invalid JWS") raise UnauthorizedException("Invalid JWS")
async def verify_email_token(user_model, token): async def verify_email_token(user_model, token):
email_claims = await decode_jwt(token) email_claims = await decode_jwt(token)
claimed_email = email_claims["email"] claimed_email = email_claims["email"]
expiry = datetime.fromtimestamp(email_claims["exp"], timezone.utc) expiry = datetime.fromtimestamp(email_claims["exp"], timezone.utc)
if expiry < datetime.now(timezone.utc): if expiry < datetime.now(timezone.utc):
raise UnauthorizedException("Invitation expired.") raise UnauthorizedException("Invitation expired.")
if user_model.email != claimed_email: if user_model.email != claimed_email:
raise ForbiddenException("The logged in user and email do not match.") raise ForbiddenException("The logged in user and email do not match.")
return email_claims return email_claims
async def send_email(recipient: str, subject: str, body: str): async def send_email(recipient: str, subject: str, body: str):
lettermint = Lettermint(api_token=settings.LETTERMINT_API_TOKEN.get_secret_value()) if settings.ENVIRONMENT.is_testing:
return
if settings.ENVIRONMENT.is_testing or settings.ENVIRONMENT == "local": lettermint = Lettermint(api_token=settings.LETTERMINT_API_TOKEN.get_secret_value())
recipient = "ok@testing.lettermint.co"
try: if settings.ENVIRONMENT == "local":
response = ( recipient = "ok@testing.lettermint.co"
lettermint.email.from_("noreply@sr2.uk")
.to(recipient)
.subject(subject)
.text(body)
.send()
)
print(response.status_code) try:
except ValidationError: response = (
# Error thrown if domain not approved for project lettermint.email.from_("noreply@sr2.uk")
print("Lettermint validation error") .to(recipient)
.subject(subject)
.text(body)
.send()
)
logging.info(
"Email sent to {} with subject {} (Status: {})".format(
recipient, subject, response.status_code
)
)
except ValidationError as e:
logging.exception(e)

View file

@ -1,8 +1,9 @@
from fastapi.dependencies.models import Dependant
import pytest import pytest
from typing import AsyncGenerator from typing import AsyncGenerator
from itertools import combinations from itertools import combinations
from fastapi.routing import APIRoute from fastapi.routing import APIRoute, iter_route_contexts
from httpx import AsyncClient, ASGITransport from httpx import AsyncClient, ASGITransport
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
@ -14,7 +15,7 @@ from src.iam.models import Group, Permission, OrgPermissions
from src.auth.service import get_current_user, get_dev_user from src.auth.service import get_current_user, get_dev_user
from src.auth.dependencies import empty_su_list, get_super_admin_list, testing_su_list from src.auth.dependencies import empty_su_list, get_super_admin_list, testing_su_list
from src.main import app # inited FastAPI app from src.main import app # inited FastAPI app
from src.database import engine, get_db from src.database import engine, get_db_session
from src.models import CustomBase from src.models import CustomBase
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@ -22,269 +23,295 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@pytest.fixture() @pytest.fixture()
def db_session(): def db_session():
CustomBase.metadata.drop_all(bind=engine) CustomBase.metadata.drop_all(bind=engine)
CustomBase.metadata.create_all(bind=engine) CustomBase.metadata.create_all(bind=engine)
db = SessionLocal() db = SessionLocal()
try: try:
_seed(db) # extracted seeding logic into a plain function _seed(db) # extracted seeding logic into a plain function
yield db yield db
finally: finally:
db.rollback() db.rollback()
db.close() db.close()
@pytest.fixture @pytest.fixture
async def default_client(db_session) -> AsyncGenerator[AsyncClient, None]: async def default_client(db_session) -> AsyncGenerator[AsyncClient, None]:
def get_db_override(): def get_db_override():
return db_session return db_session
app.dependency_overrides[get_db] = get_db_override app.dependency_overrides[get_db_session] = get_db_override
app.dependency_overrides[get_current_user] = get_dev_user app.dependency_overrides[get_current_user] = get_dev_user
app.dependency_overrides[get_super_admin_list] = testing_su_list app.dependency_overrides[get_super_admin_list] = testing_su_list
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
async with AsyncClient( async with AsyncClient(
transport=transport, base_url="http://localhost:8000/api/v1" transport=transport, base_url="http://localhost:8000/api/v1"
) as ac: ) as ac:
yield ac yield ac
app.dependency_overrides.clear() app.dependency_overrides.clear()
@pytest.fixture @pytest.fixture
async def no_user_client(db_session) -> AsyncGenerator[AsyncClient, None]: async def no_user_client(db_session) -> AsyncGenerator[AsyncClient, None]:
def get_db_override(): def get_db_override():
return db_session return db_session
app.dependency_overrides[get_db] = get_db_override app.dependency_overrides[get_db_session] = get_db_override
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
async with AsyncClient( async with AsyncClient(
transport=transport, base_url="http://localhost:8000/api/v1" transport=transport, base_url="http://localhost:8000/api/v1"
) as ac: ) as ac:
yield ac yield ac
app.dependency_overrides.clear() app.dependency_overrides.clear()
@pytest.fixture @pytest.fixture
async def no_su_client(db_session) -> AsyncGenerator[AsyncClient, None]: async def no_su_client(db_session) -> AsyncGenerator[AsyncClient, None]:
def get_db_override(): def get_db_override():
return db_session return db_session
app.dependency_overrides[get_db] = get_db_override app.dependency_overrides[get_db_session] = get_db_override
app.dependency_overrides[get_current_user] = get_dev_user app.dependency_overrides[get_current_user] = get_dev_user
app.dependency_overrides[get_super_admin_list] = empty_su_list app.dependency_overrides[get_super_admin_list] = empty_su_list
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
async with AsyncClient( async with AsyncClient(
transport=transport, base_url="http://localhost:8000/api/v1" transport=transport, base_url="http://localhost:8000/api/v1"
) as ac: ) as ac:
yield ac yield ac
app.dependency_overrides.clear() app.dependency_overrides.clear()
def _seed(db): def _seed(db):
db.add( db.add(
User( User(
email="admin@test.com", email="admin@test.com",
first_name="Admin", first_name="Admin",
last_name="Test", last_name="Test",
oidc_id="abcd-efgh-ijkl-mnop", oidc_id="abcd-efgh-ijkl-mnop",
) )
) )
db.add( db.add(
User( User(
email="user@orgone.com", email="user@orgone.com",
first_name="User", first_name="User",
last_name="Test", last_name="Test",
oidc_id="abcd-efgh-ijkl-qwer", oidc_id="abcd-efgh-ijkl-qwer",
) )
) )
db.add( db.add(
User( User(
email="root@orgtwo.com", email="root@orgtwo.com",
first_name="Root", first_name="Root",
last_name="Test", last_name="Test",
oidc_id="abcd-efgh-ijkl-hjkl", oidc_id="abcd-efgh-ijkl-hjkl",
) )
) )
db.add(Contact(org_id=1, email="billing@orgone.com", phonenumber="07521539927")) db.add(Contact(org_id=1, email="billing@orgone.com", phonenumber="07521539927"))
db.add(Contact(org_id=1, email="owner@orgone.com", phonenumber="07521539927")) db.add(Contact(org_id=1, email="owner@orgone.com", phonenumber="07521539927"))
db.add(Contact(org_id=1, email="security@orgone.com", phonenumber="07521539927")) db.add(Contact(org_id=1, email="security@orgone.com", phonenumber="07521539927"))
db.add(Contact(org_id=2, email="billing@orgtwo.com", phonenumber="07521539927")) db.add(Contact(org_id=2, email="billing@orgtwo.com", phonenumber="07521539927"))
db.add(Contact(org_id=2, email="owner@orgtwo.com", phonenumber="07521539927")) db.add(Contact(org_id=2, email="owner@orgtwo.com", phonenumber="07521539927"))
db.add(Contact(org_id=2, email="security@orgtwo.com", phonenumber="07521539927")) db.add(Contact(org_id=2, email="security@orgtwo.com", phonenumber="07521539927"))
db.add(Contact(org_id=3, email="billing@orgthree.com", phonenumber="07521539927")) db.add(Contact(org_id=3, email="billing@orgthree.com", phonenumber="07521539927"))
db.add(Contact(org_id=3, email="owner@orgthree.com", phonenumber="07521539927")) db.add(Contact(org_id=3, email="owner@orgthree.com", phonenumber="07521539927"))
db.add(Contact(org_id=3, email="security@orgthree.com", phonenumber="07521539927")) db.add(Contact(org_id=3, email="security@orgthree.com", phonenumber="07521539927"))
db.flush() db.flush()
db.add( db.add(
Org( Org(
name="Org One", name="Org One",
root_user_id=1, root_user_id=1,
billing_contact_id=1, billing_contact_id=1,
owner_contact_id=2, owner_contact_id=2,
security_contact_id=3, security_contact_id=3,
status="approved", status="approved",
intake_questionnaire={ intake_questionnaire={
"metadata": {"version": 0, "submission_date": None}, "metadata": {"version": 0, "submission_date": None},
"questions": {"question_two": "answer two"}, "questions": {"question_two": "answer two"},
}, },
) )
) )
db.add( db.add(
Org( Org(
name="Org Two", name="Org Two",
root_user_id=3, root_user_id=3,
billing_contact_id=4, billing_contact_id=4,
owner_contact_id=5, owner_contact_id=5,
security_contact_id=6, security_contact_id=6,
status="approved", status="approved",
intake_questionnaire={ intake_questionnaire={
"metadata": {"version": 0, "submission_date": None}, "metadata": {"version": 0, "submission_date": None},
"questions": {"question_two": "answer two"}, "questions": {"question_two": "answer two"},
}, },
) )
) )
db.add( db.add(
Org( Org(
name="Org Three", name="Org Three",
root_user_id=1, root_user_id=1,
billing_contact_id=7, billing_contact_id=7,
owner_contact_id=8, owner_contact_id=8,
security_contact_id=9, security_contact_id=9,
status="partial", status="partial",
intake_questionnaire={ intake_questionnaire={
"metadata": {"version": 0, "submission_date": None}, "metadata": {"version": 0, "submission_date": None},
"questions": {"question_two": "answer two"}, "questions": {"question_two": "answer two"},
}, },
) )
) )
db.add(OrgUsers(org_id=1, user_id=2)) db.add(OrgUsers(org_id=1, user_id=2))
db.add(Service(name="Test Service", api_key="123456789")) db.add(Service(name="Test Service", api_key="123456789"))
db.add(Permission(service_id=1, resource="test_resource", action="read")) db.add(Permission(service_id=1, resource="test_resource", action="read"))
db.add(Permission(service_id=1, resource="test_resource", action="move")) db.add(Permission(service_id=1, resource="test_resource", action="move"))
db.add(Permission(service_id=1, resource="test_resource", action="delete")) db.add(Permission(service_id=1, resource="test_resource", action="delete"))
db.add(OrgPermissions(org_id=1, permission_id=1)) db.add(OrgPermissions(org_id=1, permission_id=1))
db.add(OrgPermissions(org_id=1, permission_id=2)) db.add(OrgPermissions(org_id=1, permission_id=2))
db.add(Group(name="Org One Group", org_id=1)) db.add(Group(name="Org One Group", org_id=1))
db.add(Group(name="Org Two Group", org_id=2)) db.add(Group(name="Org Two Group", org_id=2))
db.add(Group(name="Org One Group Two", org_id=1)) db.add(Group(name="Org One Group Two", org_id=1))
db.flush() db.flush()
group_model = db.get(Group, 1) group_model = db.get(Group, 1)
perm_model = db.get(Permission, 1) perm_model = db.get(Permission, 1)
group_model.permission_rel.append(perm_model) group_model.permission_rel.append(perm_model)
user_model = db.get(User, 1) user_model = db.get(User, 1)
org_model = db.get(Org, 1) org_model = db.get(Org, 1)
org_model.user_rel.append(user_model) org_model.user_rel.append(user_model)
org_model.group_rel.append(group_model) org_model.group_rel.append(group_model)
db.flush() db.flush()
group_model.user_rel.append(user_model) group_model.user_rel.append(user_model)
db.commit() db.commit()
def generate_query_and_status(params) -> list[tuple[str, int]]: def generate_query_and_status(params) -> list[tuple[str, int]]:
possible_values = [0, -1, 42, "banana", ""] possible_values = [0, -1, 42, "banana", ""]
defaults = [f"{param}=1" for param in params] defaults = [f"{param}=1" for param in params]
# Missing params # Missing params
query_list = [ query_list = [
"&".join(combo) "&".join(combo)
for r in range(len(defaults) + 1) for r in range(len(defaults) + 1)
for combo in combinations(defaults, r) for combo in combinations(defaults, r)
] ]
# Complete query as default for invalid checks # Complete query as default for invalid checks
default_query = query_list.pop(-1) default_query = query_list.pop(-1)
# Checks for each param being invalid # Checks for each param being invalid
for param in params: for param in params:
for value in possible_values: for value in possible_values:
new_value = f"&{param}={value}" new_value = f"&{param}={value}"
query_list.append(default_query.replace(f"{param}=1", new_value)) query_list.append(default_query.replace(f"{param}=1", new_value))
query_and_status = [] query_and_status = []
# Assign expected status # Assign expected status
for query in query_list: for query in query_list:
# ID 42 is used to represent a non-existent entry. So it should 404. # ID 42 is used to represent a non-existent entry. So it should 404.
status = 404 if "42" in query else 422 status = 404 if "42" in query else 422
# Remove leading "&" if present # Remove leading "&" if present
query = query if len(query) > 1 and query[0] != "&" else query[1:] query = query if len(query) > 1 and query[0] != "&" else query[1:]
query_and_status.append((query, status)) query_and_status.append((query, status))
return query_and_status return query_and_status
def generate_body_and_status(params: dict[str, str]) -> list[tuple[dict, int]]: def generate_body_and_status(params: dict[str, str]) -> list[tuple[dict, int]]:
possible_values_int = [0, -1, 42, "banana", ""] possible_values_int = [0, -1, 42, "banana", ""]
possible_values_str = [0, "", "a"] possible_values_str = [0, "", "a"]
defaults = [{param: 1 for param in params.keys()}] defaults = [{param: 1 for param in params.keys()}]
# Missing params # Missing params
body_list = [ body_list = [
{key: ("valid string" if params[key] == "str" else 1) for key in combo} {key: ("valid string" if params[key] == "str" else 1) for key in combo}
for r in range(len(defaults[0].keys()) + 1) for r in range(len(defaults[0].keys()) + 1)
for combo in combinations(defaults[0].keys(), r) for combo in combinations(defaults[0].keys(), r)
] ]
# Complete body as default for generating invalid checks # Complete body as default for generating invalid checks
default_body = body_list.pop(-1) default_body = body_list.pop(-1)
# Generates checks for each param being invalid # Generates checks for each param being invalid
for param, typ in params.items(): for param, typ in params.items():
if typ == "int": if typ == "int":
possible_values = possible_values_int possible_values = possible_values_int
elif typ == "str": elif typ == "str":
possible_values = possible_values_str possible_values = possible_values_str
else: else:
raise TypeError(f"Unknown type {typ}") raise TypeError(f"Unknown type {typ}")
for value in possible_values: for value in possible_values:
new_record = default_body.copy() new_record = default_body.copy()
new_record[param] = value new_record[param] = value
body_list.append(new_record) body_list.append(new_record)
body_and_status = [] body_and_status = []
# Assign expected status # Assign expected status
for body in body_list: for body in body_list:
# ID 42 is used to represent a non-existent entry. So it should 404. # ID 42 is used to represent a non-existent entry. So it should 404.
status = 404 if 42 in body.values() else 422 status = 404 if 42 in body.values() else 422
body_and_status.append((body, status)) body_and_status.append((body, status))
return body_and_status return body_and_status
def get_testable_routes(): def get_testable_routes():
routes = [] routes = []
for route in app.routes: contexts = list(iter_route_contexts(app.routes))
if not isinstance(route, APIRoute):
continue
for method in route.methods: for route in contexts:
if method in {"HEAD", "OPTIONS"}: if not route.methods:
continue continue
if not isinstance(route.route, APIRoute):
continue
routes.append( dep_func_names = set()
(
method,
route.path,
route.status_code,
route.response_model,
route.summary,
)
)
return routes unchecked = []
unchecked.append(route.route.dependant)
while unchecked:
dependant = unchecked.pop(0)
ck = dependant.cache_key[0]
if hasattr(ck, "__name__"):
dep_func_names.add(ck.__name__)
unchecked += [
dep for dep in dependant.dependencies if isinstance(dep, Dependant)
]
auth_level = None
if "get_current_user" in dep_func_names:
auth_level = "User"
if (
"org_body_root_claims" in dep_func_names
or "org_query_root_claims" in dep_func_names
):
auth_level = "Root User"
if "user_model_super_admin" in dep_func_names:
auth_level = "Super Admin"
if "valid_service_key" in dep_func_names:
auth_level = "API Key"
for method in route.methods:
if method in {"HEAD", "OPTIONS"}:
continue
routes.append(
(
method,
route.route.path,
route.route.status_code,
route.route.response_model,
route.route.summary,
auth_level,
)
)
return routes
# with open("endpoints.txt", "w") as f:
# for ep in get_testable_routes():
# f.write(f"[{ep[0]}]({ep[1]}) -> {ep[2]}: {ep[3]}\n")
#
#
### Docstring formatted output ### ### Docstring formatted output ###
# with open("endpoints.txt", "w") as f: with open("endpoints.txt", "w") as f:
# for ep in get_testable_routes(): for ep in get_testable_routes():
# f.write(f"- [{ep[0]}]({ep[1]}): []: {ep[4]}\n") f.write(f"- [{ep[0]}]({ep[1]}): [{ep[5]}]: {ep[4]}\n")

View file

@ -8,181 +8,181 @@ import pytest
from httpx import AsyncClient from httpx import AsyncClient
pytestmark = [ pytestmark = [
pytest.mark.auth, pytest.mark.auth,
pytest.mark.preapproval, pytest.mark.preapproval,
] ]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_auth_approval(no_su_client: AsyncClient): async def test_get_org_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/org?org_id=3") resp = await no_su_client.get("/org?org_id=3")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 200 assert resp.status_code == 200
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_questionnaire_auth_approval(no_su_client: AsyncClient): async def test_patch_org_questionnaire_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.patch( resp = await no_su_client.patch(
"/org/questionnaire", "/org/questionnaire",
json={ json={
"organisation_id": 3, "organisation_id": 3,
"intake_questionnaire": { "intake_questionnaire": {
"question_one": "new answer one", "question_one": "new answer one",
"question_two": None, "question_two": None,
"question_three": None, "question_three": None,
}, },
"partial": True, "partial": True,
}, },
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 200 assert resp.status_code == 200
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_users_auth_approval(no_su_client: AsyncClient): async def test_get_org_users_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/users?org_id=3") resp = await no_su_client.get("/org/users?org_id=3")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_groups_auth_approval(no_su_client: AsyncClient): async def test_get_org_groups_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/groups?org_id=3") resp = await no_su_client.get("/org/groups?org_id=3")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_contact_auth_approval(no_su_client: AsyncClient): async def test_get_org_contact_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/contact?org_id=3&contact_type=billing") resp = await no_su_client.get("/org/contact?org_id=3&contact_type=billing")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 200 assert resp.status_code == 200
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_contact_auth_approval(no_su_client: AsyncClient): async def test_patch_org_contact_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.patch( resp = await no_su_client.patch(
"/org/contact", "/org/contact",
json={ json={
"organisation_id": 3, "organisation_id": 3,
"contact_type": "billing", "contact_type": "billing",
"email": "user@example.com", "email": "user@example.com",
}, },
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 200 assert resp.status_code == 200
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_service_auth_approval(no_su_client: AsyncClient): async def test_get_service_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/service?org_id=3") resp = await no_su_client.get("/service?org_id=3")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_iam_group_permissions_auth_approval(no_su_client: AsyncClient): async def test_get_iam_group_permissions_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/group/permissions?org_id=3&group_id=1") resp = await no_su_client.get("/iam/group/permissions?org_id=3&group_id=1")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_iam_group_users_auth_approval(no_su_client: AsyncClient): async def test_get_iam_group_users_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/group/users?org_id=3&group_id=1") resp = await no_su_client.get("/iam/group/users?org_id=3&group_id=1")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_iam_group_auth_approval(no_su_client: AsyncClient): async def test_post_iam_group_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.post( resp = await no_su_client.post(
"/iam/group", json={"name": "New Group", "organisation_id": 3} "/iam/group", json={"name": "New Group", "organisation_id": 3}
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_iam_group_permission_auth_approval(no_su_client: AsyncClient): async def test_put_iam_group_permission_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.put( resp = await no_su_client.put(
"/iam/group/permission", "/iam/group/permission",
json={"permission_id": 1, "group_id": 2, "organisation_id": 3}, json={"permission_id": 1, "group_id": 2, "organisation_id": 3},
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_iam_group_user_auth_approval(no_su_client: AsyncClient): async def test_put_iam_group_user_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.put( resp = await no_su_client.put(
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 3} "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 3}
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_iam_permissions_auth_approval(no_su_client: AsyncClient): async def test_get_iam_permissions_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/permissions?org_id=3") resp = await no_su_client.get("/iam/permissions?org_id=3")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_iam_permissions_search_auth_approval(no_su_client: AsyncClient): async def test_post_iam_permissions_search_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.post( resp = await no_su_client.post(
"/iam/permissions/search", json={"organisation_id": 3, "action": "read"} "/iam/permissions/search", json={"organisation_id": 3, "action": "read"}
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_delete_org_user_auth_approval(no_su_client: AsyncClient): async def test_delete_org_user_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.delete("/org/user?org_id=3&user_id=1") resp = await no_su_client.delete("/org/user?org_id=3&user_id=1")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_delete_preapproval_auth_approval(no_su_client: AsyncClient): async def test_delete_preapproval_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.delete("/org/self?org_id=3") resp = await no_su_client.delete("/org/self?org_id=3")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 204 assert resp.status_code == 204
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_user_invitation_auth_approval(no_su_client: AsyncClient): async def test_post_user_invitation_auth_approval(no_su_client: AsyncClient):
body = {"user_email": "admin@test.com", "organisation_id": 3} body = {"user_email": "admin@test.com", "organisation_id": 3}
resp = await no_su_client.post("/user/invitation", json=body) resp = await no_su_client.post("/user/invitation", json=body)
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_delete_group_permissions_auth_approval(no_su_client: AsyncClient): async def test_delete_group_permissions_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.delete("/iam/group/permission?org_id=3&group_id=1&perm_id=1") resp = await no_su_client.delete("/iam/group/permission?org_id=3&group_id=1&perm_id=1")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_delete_group_users_success(no_su_client: AsyncClient): async def test_delete_group_users_success(no_su_client: AsyncClient):
resp = await no_su_client.delete("/iam/group/user?org_id=3&group_id=1&user_id=1") resp = await no_su_client.delete("/iam/group/user?org_id=3&group_id=1&user_id=1")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_group_user_invitation_success(no_su_client: AsyncClient): async def test_put_group_user_invitation_success(no_su_client: AsyncClient):
body = {"user_email": "admin@test.com", "organisation_id": 3, "group_id": 1} body = {"user_email": "admin@test.com", "organisation_id": 3, "group_id": 1}
resp = await no_su_client.put("/iam/group/user/invitation", json=body) resp = await no_su_client.put("/iam/group/user/invitation", json=body)
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]

View file

@ -5,14 +5,14 @@ from httpx import AsyncClient
pytestmark = [ pytestmark = [
pytest.mark.auth, pytest.mark.auth,
] ]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_auth_root_su(default_client: AsyncClient): async def test_get_org_auth_root_su(default_client: AsyncClient):
# If a super admin can access a resource when not the root user # If a super admin can access a resource when not the root user
resp = await default_client.get("/org?org_id=2") resp = await default_client.get("/org?org_id=2")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 200 assert resp.status_code == 200
assert resp.json()["organisations"][0]["name"] == "Org Two" assert resp.json()["organisations"][0]["name"] == "Org Two"

View file

@ -7,147 +7,147 @@ import pytest
from httpx import AsyncClient from httpx import AsyncClient
pytestmark = [ pytestmark = [
pytest.mark.auth, pytest.mark.auth,
pytest.mark.root_user, pytest.mark.root_user,
] ]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_auth_root(no_su_client: AsyncClient): async def test_get_org_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/org?org_id=2") resp = await no_su_client.get("/org?org_id=2")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_questionnaire_auth_root(no_su_client: AsyncClient): async def test_patch_org_questionnaire_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.patch( resp = await no_su_client.patch(
"/org/questionnaire", "/org/questionnaire",
json={ json={
"organisation_id": 2, "organisation_id": 2,
"intake_questionnaire": { "intake_questionnaire": {
"question_one": "new answer one", "question_one": "new answer one",
"question_two": None, "question_two": None,
"question_three": None, "question_three": None,
}, },
"partial": True, "partial": True,
}, },
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_users_auth_root(no_su_client: AsyncClient): async def test_get_org_users_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/users?org_id=2") resp = await no_su_client.get("/org/users?org_id=2")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_groups_auth_root(no_su_client: AsyncClient): async def test_get_org_groups_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/groups?org_id=2") resp = await no_su_client.get("/org/groups?org_id=2")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_contact_auth_root(no_su_client: AsyncClient): async def test_get_org_contact_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/contact?org_id=2&contact_type=billing") resp = await no_su_client.get("/org/contact?org_id=2&contact_type=billing")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_contact_auth_root(no_su_client: AsyncClient): async def test_patch_org_contact_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.patch( resp = await no_su_client.patch(
"/org/contact", "/org/contact",
json={ json={
"organisation_id": 2, "organisation_id": 2,
"contact_type": "billing", "contact_type": "billing",
"email": "user@example.com", "email": "user@example.com",
}, },
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_service_auth_root(no_su_client: AsyncClient): async def test_get_service_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/service?org_id=2") resp = await no_su_client.get("/service?org_id=2")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_iam_group_permissions_auth_root(no_su_client: AsyncClient): async def test_get_iam_group_permissions_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/group/permissions?org_id=2&group_id=1") resp = await no_su_client.get("/iam/group/permissions?org_id=2&group_id=1")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_iam_group_users_auth_root(no_su_client: AsyncClient): async def test_get_iam_group_users_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/group/users?org_id=2&group_id=1") resp = await no_su_client.get("/iam/group/users?org_id=2&group_id=1")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_iam_group_auth_root(no_su_client: AsyncClient): async def test_post_iam_group_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.post( resp = await no_su_client.post(
"/iam/group", json={"name": "New Group", "organisation_id": 2} "/iam/group", json={"name": "New Group", "organisation_id": 2}
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_iam_group_permission_auth_root(no_su_client: AsyncClient): async def test_put_iam_group_permission_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.put( resp = await no_su_client.put(
"/iam/group/permission", "/iam/group/permission",
json={"permission_id": 1, "group_id": 2, "organisation_id": 2}, json={"permission_id": 1, "group_id": 2, "organisation_id": 2},
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_iam_group_user_auth_root( async def test_put_iam_group_user_auth_root(
no_su_client: AsyncClient, no_su_client: AsyncClient,
): ):
resp = await no_su_client.put( resp = await no_su_client.put(
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2} "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2}
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_iam_permissions_auth_root(no_su_client: AsyncClient): async def test_get_iam_permissions_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/permissions?org_id=2") resp = await no_su_client.get("/iam/permissions?org_id=2")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_iam_permissions_search_auth_root(no_su_client: AsyncClient): async def test_post_iam_permissions_search_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.post( resp = await no_su_client.post(
"/iam/permissions/search", json={"organisation_id": 2, "action": "read"} "/iam/permissions/search", json={"organisation_id": 2, "action": "read"}
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]

View file

@ -7,69 +7,69 @@ import pytest
from httpx import AsyncClient from httpx import AsyncClient
pytestmark = [ pytestmark = [
pytest.mark.auth, pytest.mark.auth,
pytest.mark.super_admin, pytest.mark.super_admin,
] ]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_user_auth_su(no_su_client: AsyncClient): async def test_get_user_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.get("/user?user_id=1") resp = await no_su_client.get("/user?user_id=1")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin" assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_status_auth_su(no_su_client: AsyncClient): async def test_patch_org_status_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.patch( resp = await no_su_client.patch(
"/org/status", json={"organisation_id": 1, "status": "submitted"} "/org/status", json={"organisation_id": 1, "status": "submitted"}
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin" assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_root_user_auth_su(no_su_client: AsyncClient): async def test_patch_org_root_user_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.patch( resp = await no_su_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 2} "/org/root_user", json={"organisation_id": 1, "user_id": 2}
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin" assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_service_key_auth_su(no_su_client: AsyncClient): async def test_patch_service_key_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.patch("/service/key", json={"service_id": 1}) resp = await no_su_client.patch("/service/key", json={"service_id": 1})
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin" assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_service_auth_su(no_su_client: AsyncClient): async def test_post_service_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.post("/service", json={"name": "New Test Service"}) resp = await no_su_client.post("/service", json={"name": "New Test Service"})
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin" assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_perm_auth_su(no_su_client: AsyncClient): async def test_post_perm_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.post( resp = await no_su_client.post(
"/iam/permission", "/iam/permission",
json={"service_id": 1, "resource": "test_resource", "action": "create"}, json={"service_id": 1, "resource": "test_resource", "action": "create"},
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin" assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_user_auth_su(no_su_client: AsyncClient): async def test_post_org_user_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.post("/org/user", json={"organisation_id": 1, "user_id": 2}) resp = await no_su_client.post("/org/user", json={"organisation_id": 1, "user_id": 2})
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be super admin" in resp.json()["detail"] assert "Must be super admin" in resp.json()["detail"]

View file

@ -7,22 +7,22 @@ from httpx import AsyncClient
pytestmark = [ pytestmark = [
pytest.mark.auth, pytest.mark.auth,
pytest.mark.user, pytest.mark.user,
] ]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_self_db_auth_user(no_user_client: AsyncClient): async def test_get_self_db_auth_user(no_user_client: AsyncClient):
resp = await no_user_client.get("/user/self/db") resp = await no_user_client.get("/user/self/db")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 401 assert resp.status_code == 401
assert resp.json()["detail"] == "Not authenticated" assert resp.json()["detail"] == "Not authenticated"
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_success_auth_user(no_user_client: AsyncClient): async def test_post_org_success_auth_user(no_user_client: AsyncClient):
resp = await no_user_client.post("/org", json={"name": "New Test Org"}) resp = await no_user_client.post("/org", json={"name": "New Test Org"})
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 401 assert resp.status_code == 401
assert resp.json()["detail"] == "Not authenticated" assert resp.json()["detail"] == "Not authenticated"

View file

@ -4,7 +4,7 @@ from httpx import AsyncClient
@pytest.mark.anyio @pytest.mark.anyio
async def test_healthcheck(default_client: AsyncClient): async def test_healthcheck(default_client: AsyncClient):
resp = await default_client.get("/healthcheck") resp = await default_client.get("/healthcheck")
assert resp.status_code == 200 assert resp.status_code == 200
assert resp.json() == {"status": "ok"} assert resp.json() == {"status": "ok"}

File diff suppressed because it is too large Load diff

View file

@ -9,506 +9,506 @@ from .conftest import generate_query_and_status
pytestmark = [ pytestmark = [
pytest.mark.org_module, pytest.mark.org_module,
] ]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_success(default_client: AsyncClient): async def test_get_org_success(default_client: AsyncClient):
resp = await default_client.get("/org?org_id=1") resp = await default_client.get("/org?org_id=1")
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
org = data["organisations"][0] org = data["organisations"][0]
assert isinstance(org, dict) assert isinstance(org, dict)
assert org["organisation_id"] == 1 assert org["organisation_id"] == 1
assert org["name"] == "Org One" assert org["name"] == "Org One"
assert org["status"] == "approved" assert org["status"] == "approved"
assert org["root_user_email"] == "admin@test.com" assert org["root_user_email"] == "admin@test.com"
assert "intake_questionnaire" in org assert "intake_questionnaire" in org
assert isinstance(org["intake_questionnaire"], dict) assert isinstance(org["intake_questionnaire"], dict)
assert org["billing_contact"]["email"] == "billing@orgone.com" assert org["billing_contact"]["email"] == "billing@orgone.com"
assert org["owner_contact"]["email"] == "owner@orgone.com" assert org["owner_contact"]["email"] == "owner@orgone.com"
assert org["security_contact"]["email"] == "security@orgone.com" assert org["security_contact"]["email"] == "security@orgone.com"
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_status_checks( async def test_get_org_status_checks(
default_client: AsyncClient, query: str, expected_status: int default_client: AsyncClient, query: str, expected_status: int
): ):
resp = await default_client.get(f"/org?{query}") resp = await default_client.get(f"/org?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_success(default_client: AsyncClient): async def test_post_org_success(default_client: AsyncClient):
resp = await default_client.post("/org", json={"name": "New Test Org"}) resp = await default_client.post("/org", json={"name": "New Test Org"})
data = resp.json() data = resp.json()
assert resp.status_code == 201 assert resp.status_code == 201
assert data["name"] == "New Test Org" assert data["name"] == "New Test Org"
assert data["status"] == "partial" assert data["status"] == "partial"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
({"name": 42}, 422), ({"name": 42}, 422),
({}, 422), ({}, 422),
({"name": "New Test Org", "intake_questionnaire": {"question_one": 42}}, 422), ({"name": "New Test Org", "intake_questionnaire": {"question_one": 42}}, 422),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_status_checks( async def test_post_org_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int default_client: AsyncClient, body: dict[str, str], expected_status: int
): ):
resp = await default_client.post("/org", json=body) resp = await default_client.post("/org", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_questionnaire_partial_success(default_client: AsyncClient): async def test_patch_org_questionnaire_partial_success(default_client: AsyncClient):
resp = await default_client.patch( resp = await default_client.patch(
"/org/questionnaire", "/org/questionnaire",
json={ json={
"organisation_id": 3, "organisation_id": 3,
"intake_questionnaire": { "intake_questionnaire": {
"question_one": "new answer one", "question_one": "new answer one",
"question_two": None, "question_two": None,
"question_three": None, "question_three": None,
}, },
"partial": True, "partial": True,
}, },
) )
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
assert data["name"] == "Org Three" assert data["name"] == "Org Three"
assert data["status"] == "partial" assert data["status"] == "partial"
assert "intake_questionnaire" in data assert "intake_questionnaire" in data
assert isinstance(data["intake_questionnaire"], dict) assert isinstance(data["intake_questionnaire"], dict)
metadata = data["intake_questionnaire"]["metadata"] metadata = data["intake_questionnaire"]["metadata"]
assert metadata["version"] == 0 assert metadata["version"] == 0
assert metadata["submission_date"] is None assert metadata["submission_date"] is None
questions = data["intake_questionnaire"]["questions"] questions = data["intake_questionnaire"]["questions"]
assert questions["question_one"] == "new answer one" assert questions["question_one"] == "new answer one"
assert questions["question_two"] == "answer two" assert questions["question_two"] == "answer two"
assert questions["question_three"] is None assert questions["question_three"] is None
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
({"organisation_id": 42}, 404), ({"organisation_id": 42}, 404),
({"organisation_id": "Org One"}, 422), ({"organisation_id": "Org One"}, 422),
({"organisation_id": ""}, 422), ({"organisation_id": ""}, 422),
({}, 422), ({}, 422),
( (
{ {
"organisation_id": "1", "organisation_id": "1",
"intake_questionnaire": {"question_one": 42}, "intake_questionnaire": {"question_one": 42},
"partial": True, "partial": True,
}, },
422, 422,
), ),
( (
{"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}}, {"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}},
422, 422,
), ),
( (
{ {
"organisation_id": "1", "organisation_id": "1",
"intake_questionnaire": {"question_one": "valid"}, "intake_questionnaire": {"question_one": "valid"},
"partial": 42, "partial": 42,
}, },
422, 422,
), ),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_questionnaire_partial_status_checks( async def test_patch_questionnaire_partial_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int default_client: AsyncClient, body: dict[str, str], expected_status: int
): ):
resp = await default_client.patch("/org/questionnaire", json=body) resp = await default_client.patch("/org/questionnaire", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_questionnaire_submit_success(default_client: AsyncClient): async def test_patch_org_questionnaire_submit_success(default_client: AsyncClient):
resp = await default_client.patch( resp = await default_client.patch(
"/org/questionnaire", "/org/questionnaire",
json={ json={
"organisation_id": 3, "organisation_id": 3,
"intake_questionnaire": { "intake_questionnaire": {
"question_one": "new answer one", "question_one": "new answer one",
"question_two": None, "question_two": None,
"question_three": None, "question_three": None,
}, },
"partial": False, "partial": False,
}, },
) )
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
assert data["name"] == "Org Three" assert data["name"] == "Org Three"
assert data["status"] == "submitted" assert data["status"] == "submitted"
assert "intake_questionnaire" in data assert "intake_questionnaire" in data
assert isinstance(data["intake_questionnaire"], dict) assert isinstance(data["intake_questionnaire"], dict)
metadata = data["intake_questionnaire"]["metadata"] metadata = data["intake_questionnaire"]["metadata"]
assert metadata["version"] == 0 assert metadata["version"] == 0
assert metadata["submission_date"] is not None assert metadata["submission_date"] is not None
questions = data["intake_questionnaire"]["questions"] questions = data["intake_questionnaire"]["questions"]
assert questions["question_one"] == "new answer one" assert questions["question_one"] == "new answer one"
assert questions["question_two"] == "answer two" assert questions["question_two"] == "answer two"
assert questions["question_three"] is None assert questions["question_three"] is None
@pytest.mark.parametrize( @pytest.mark.parametrize(
"status", ["partial", "submitted", "remediation", "approved", "rejected", "removed"] "status", ["partial", "submitted", "remediation", "approved", "rejected", "removed"]
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_status_success(default_client: AsyncClient, status: str): async def test_patch_org_status_success(default_client: AsyncClient, status: str):
resp = await default_client.patch( resp = await default_client.patch(
"/org/status", json={"organisation_id": 1, "status": status} "/org/status", json={"organisation_id": 1, "status": status}
) )
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
assert data["name"] == "Org One" assert data["name"] == "Org One"
assert data["status"] == status assert data["status"] == status
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
({"organisation_id": 42}, 404), ({"organisation_id": 42}, 404),
({"organisation_id": "Org One"}, 422), ({"organisation_id": "Org One"}, 422),
({"organisation_id": ""}, 422), ({"organisation_id": ""}, 422),
({}, 422), ({}, 422),
({"organisation_id": "1", "status": True}, 422), ({"organisation_id": "1", "status": True}, 422),
({"organisation_id": "1", "status": 42}, 422), ({"organisation_id": "1", "status": 42}, 422),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_status_status_checks( async def test_patch_org_status_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int default_client: AsyncClient, body: dict[str, str], expected_status: int
): ):
resp = await default_client.patch("/org/status", json=body) resp = await default_client.patch("/org/status", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_users_success(default_client: AsyncClient): async def test_get_org_users_success(default_client: AsyncClient):
resp = await default_client.get("/org/users?org_id=1") resp = await default_client.get("/org/users?org_id=1")
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
assert "users" in data assert "users" in data
assert isinstance(data["users"], list) assert isinstance(data["users"], list)
assert len(data["users"]) == 2 assert len(data["users"]) == 2
user = data["users"][0] user = data["users"][0]
assert isinstance(user, dict) assert isinstance(user, dict)
assert user["email"] == "admin@test.com" assert user["email"] == "admin@test.com"
assert user["id"] == 1 assert user["id"] == 1
assert "organisation" in data assert "organisation" in data
assert data["organisation"]["name"] == "Org One" assert data["organisation"]["name"] == "Org One"
assert data["organisation"]["id"] == 1 assert data["organisation"]["id"] == 1
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_users_status_checks( async def test_get_org_users_status_checks(
default_client: AsyncClient, query: str, expected_status: int default_client: AsyncClient, query: str, expected_status: int
): ):
resp = await default_client.get(f"/org/users?{query}") resp = await default_client.get(f"/org/users?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_user_success(default_client: AsyncClient): async def test_post_org_user_success(default_client: AsyncClient):
resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 3}) resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 3})
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert "organisation" in data assert "organisation" in data
assert isinstance(data["organisation"], dict) assert isinstance(data["organisation"], dict)
assert data["organisation"]["id"] == 1 assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One" assert data["organisation"]["name"] == "Org One"
assert "users" in data assert "users" in data
assert isinstance(data["users"], list) assert isinstance(data["users"], list)
assert len([user for user in data["users"] if user["email"] == "root@orgtwo.com"]) == 1 assert len([user for user in data["users"] if user["email"] == "root@orgtwo.com"]) == 1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
({"organisation_id": 42}, 404), ({"organisation_id": 42}, 404),
({}, 422), ({}, 422),
({"organisation_id": 1, "user_id": "id"}, 422), ({"organisation_id": 1, "user_id": "id"}, 422),
({"user_id": 2}, 422), ({"user_id": 2}, 422),
({"organisation_id": 1, "user_id": 42}, 404), ({"organisation_id": 1, "user_id": 42}, 404),
({"organisation_id": 1, "user_id": 1}, 409), ({"organisation_id": 1, "user_id": 1}, 409),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_user_status_checks( async def test_post_org_user_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int default_client: AsyncClient, body: dict[str, str], expected_status: int
): ):
resp = await default_client.post("/org/user", json=body) resp = await default_client.post("/org/user", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_root_user_success(default_client: AsyncClient): async def test_patch_org_root_user_success(default_client: AsyncClient):
resp = await default_client.patch( resp = await default_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 2} "/org/root_user", json={"organisation_id": 1, "user_id": 2}
) )
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert data["name"] == "Org One" assert data["name"] == "Org One"
assert data["root_user_email"] == "user@orgone.com" assert data["root_user_email"] == "user@orgone.com"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
({"organisation_id": 42, "user_id": 2}, 404), ({"organisation_id": 42, "user_id": 2}, 404),
({"organisation_id": "Org One", "user_id": 2}, 422), ({"organisation_id": "Org One", "user_id": 2}, 422),
({"organisation_id": "", "user_id": 2}, 422), ({"organisation_id": "", "user_id": 2}, 422),
({}, 422), ({}, 422),
({"user_id": 2}, 422), ({"user_id": 2}, 422),
({"user_id": 42}, 404), ({"user_id": 42}, 404),
({"organisation_id": 1, "user_id": "Test User"}, 422), ({"organisation_id": 1, "user_id": "Test User"}, 422),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_root_user_status_checks( async def test_patch_root_user_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int default_client: AsyncClient, body: dict[str, str], expected_status: int
): ):
resp = await default_client.patch("/org/root_user", json=body) resp = await default_client.patch("/org/root_user", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_root_user_non_member(default_client: AsyncClient): async def test_patch_org_root_user_non_member(default_client: AsyncClient):
resp = await default_client.patch( resp = await default_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 3} "/org/root_user", json={"organisation_id": 1, "user_id": 3}
) )
data = resp.json() data = resp.json()
assert resp.status_code == 422 assert resp.status_code == 422
assert data["detail"] == "This user does not belong to your organisation." assert data["detail"] == "This user does not belong to your organisation."
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_groups_success(default_client: AsyncClient): async def test_get_org_groups_success(default_client: AsyncClient):
resp = await default_client.get("/org/groups?org_id=1") resp = await default_client.get("/org/groups?org_id=1")
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert "organisation" in data assert "organisation" in data
assert isinstance(data["organisation"], dict) assert isinstance(data["organisation"], dict)
assert data["organisation"]["id"] == 1 assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One" assert data["organisation"]["name"] == "Org One"
assert "groups" in data assert "groups" in data
assert isinstance(data["groups"], list) assert isinstance(data["groups"], list)
group = data["groups"][0] group = data["groups"][0]
assert isinstance(group, dict) assert isinstance(group, dict)
assert group["id"] == 1 assert group["id"] == 1
assert group["name"] == "Org One Group" assert group["name"] == "Org One Group"
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_groups_status_checks( async def test_get_org_groups_status_checks(
default_client: AsyncClient, query: str, expected_status: int default_client: AsyncClient, query: str, expected_status: int
): ):
resp = await default_client.get(f"/org/groups?{query}") resp = await default_client.get(f"/org/groups?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.parametrize("contact_type", ["billing", "security", "owner"]) @pytest.mark.parametrize("contact_type", ["billing", "security", "owner"])
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_contact_success(default_client: AsyncClient, contact_type: str): async def test_get_org_contact_success(default_client: AsyncClient, contact_type: str):
resp = await default_client.get(f"/org/contact?org_id=1&contact_type={contact_type}") resp = await default_client.get(f"/org/contact?org_id=1&contact_type={contact_type}")
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
assert "organisation" in data assert "organisation" in data
assert data["organisation"]["id"] == 1 assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One" assert data["organisation"]["name"] == "Org One"
attributes = [ attributes = [
"email", "email",
"first_name", "first_name",
"last_name", "last_name",
"phonenumber", "phonenumber",
"vat_number", "vat_number",
"address", "address",
] ]
for attribute in attributes: for attribute in attributes:
assert attribute in data["contact"] assert attribute in data["contact"]
address_attributes = [ address_attributes = [
"post_office_box_number", "post_office_box_number",
"street_address", "street_address",
"street_address_line_2", "street_address_line_2",
"locality", "locality",
"address_region", "address_region",
"country_code", "country_code",
"postal_code", "postal_code",
] ]
for attribute in address_attributes: for attribute in address_attributes:
assert attribute in data["contact"]["address"] assert attribute in data["contact"]["address"]
@pytest.mark.parametrize( @pytest.mark.parametrize(
"query, expected_status", "query, expected_status",
[ [
("org_id=42&contact_type=billing", 404), ("org_id=42&contact_type=billing", 404),
("org_id=banana&contact_type=billing", 422), ("org_id=banana&contact_type=billing", 422),
("", 422), ("", 422),
("org_id=1&contact_type=contact", 422), ("org_id=1&contact_type=contact", 422),
("contact_type=billing", 422), ("contact_type=billing", 422),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_contact_status_checks( async def test_get_org_contact_status_checks(
default_client: AsyncClient, query: str, expected_status: int default_client: AsyncClient, query: str, expected_status: int
): ):
resp = await default_client.get(f"/org/contact?{query}") resp = await default_client.get(f"/org/contact?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.parametrize( @pytest.mark.parametrize(
"key, value", "key, value",
[ [
("email", "user@example.com"), ("email", "user@example.com"),
("first_name", "John"), ("first_name", "John"),
("last_name", "Doe"), ("last_name", "Doe"),
("phonenumber", "+441234567890"), ("phonenumber", "+441234567890"),
("vat_number", "GB123456789"), ("vat_number", "GB123456789"),
("post_office_box_number", "PO Box 123"), ("post_office_box_number", "PO Box 123"),
("street_address", "123 Example Street"), ("street_address", "123 Example Street"),
("street_address_line_2", "Suite 4B"), ("street_address_line_2", "Suite 4B"),
("locality", "Glasgow"), ("locality", "Glasgow"),
("address_region", "Glasgow City"), ("address_region", "Glasgow City"),
("country_code", "GB"), ("country_code", "GB"),
("postal_code", "G1 1AA"), ("postal_code", "G1 1AA"),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_contact_success(default_client: AsyncClient, key: str, value: str): async def test_patch_org_contact_success(default_client: AsyncClient, key: str, value: str):
resp = await default_client.patch( resp = await default_client.patch(
"/org/contact", "/org/contact",
json={"organisation_id": 1, "contact_type": "billing", key: value}, json={"organisation_id": 1, "contact_type": "billing", key: value},
) )
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert "organisation" in data assert "organisation" in data
assert data["organisation"]["id"] == 1 assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One" assert data["organisation"]["name"] == "Org One"
attributes = [ attributes = [
"email", "email",
"first_name", "first_name",
"last_name", "last_name",
"phonenumber", "phonenumber",
"vat_number", "vat_number",
"address", "address",
] ]
for attribute in attributes: for attribute in attributes:
assert attribute in data["contact"] assert attribute in data["contact"]
address_attributes = [ address_attributes = [
"post_office_box_number", "post_office_box_number",
"street_address", "street_address",
"street_address_line_2", "street_address_line_2",
"locality", "locality",
"address_region", "address_region",
"country_code", "country_code",
"postal_code", "postal_code",
] ]
for attribute in address_attributes: for attribute in address_attributes:
assert attribute in data["contact"]["address"] assert attribute in data["contact"]["address"]
if key in data["contact"]: if key in data["contact"]:
assert data["contact"][key] == value assert data["contact"][key] == value
elif key in data["contact"]["address"]: elif key in data["contact"]["address"]:
assert data["contact"]["address"][key] == value assert data["contact"]["address"][key] == value
else: else:
pytest.fail(f"Invalid contact key: {key}") pytest.fail(f"Invalid contact key: {key}")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
({"organisation_id": 42, "contact_type": "billing"}, 404), ({"organisation_id": 42, "contact_type": "billing"}, 404),
({"organisation_id": 1, "contact_type": "security"}, 200), ({"organisation_id": 1, "contact_type": "security"}, 200),
({"organisation_id": 1, "contact_type": "owner"}, 200), ({"organisation_id": 1, "contact_type": "owner"}, 200),
({"organisation_id": "Org One", "contact_type": "billing"}, 422), ({"organisation_id": "Org One", "contact_type": "billing"}, 422),
({"organisation_id": "", "contact_type": "billing"}, 422), ({"organisation_id": "", "contact_type": "billing"}, 422),
({}, 422), ({}, 422),
({"organisation_id": 1, "contact_type": "not_real"}, 422), ({"organisation_id": 1, "contact_type": "not_real"}, 422),
({"organisation_id": 1, "contact_type": 42}, 422), ({"organisation_id": 1, "contact_type": 42}, 422),
({"organisation_id": 1, "contact_type": ""}, 422), ({"organisation_id": 1, "contact_type": ""}, 422),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_contact_status_checks( async def test_patch_org_contact_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int default_client: AsyncClient, body: dict[str, str], expected_status: int
): ):
resp = await default_client.patch("/org/contact", json=body) resp = await default_client.patch("/org/contact", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_delete_org_success(default_client: AsyncClient): async def test_delete_org_success(default_client: AsyncClient):
resp = await default_client.delete("/org?org_id=1") resp = await default_client.delete("/org?org_id=1")
assert resp.status_code == 204 assert resp.status_code == 204
@pytest.mark.anyio @pytest.mark.anyio
async def test_delete_org_users_success(default_client: AsyncClient): async def test_delete_org_users_success(default_client: AsyncClient):
resp = await default_client.delete("/org/user?org_id=1&user_id=2") resp = await default_client.delete("/org/user?org_id=1&user_id=2")
assert resp.status_code == 204 assert resp.status_code == 204
@pytest.mark.anyio @pytest.mark.anyio
async def test_delete_preapproval_org_success(default_client: AsyncClient): async def test_delete_preapproval_org_success(default_client: AsyncClient):
resp = await default_client.delete("/org/self?org_id=3") resp = await default_client.delete("/org/self?org_id=3")
assert resp.status_code == 204 assert resp.status_code == 204

View file

@ -8,90 +8,90 @@ from httpx import AsyncClient
from .conftest import generate_query_and_status, generate_body_and_status from .conftest import generate_query_and_status, generate_body_and_status
pytestmark = [ pytestmark = [
pytest.mark.service_module, pytest.mark.service_module,
] ]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_services_success(default_client: AsyncClient): async def test_get_services_success(default_client: AsyncClient):
resp = await default_client.get("/service?org_id=1") resp = await default_client.get("/service?org_id=1")
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
assert "services" in data assert "services" in data
assert isinstance(data["services"], list) assert isinstance(data["services"], list)
assert data["services"][0]["id"] == 1 assert data["services"][0]["id"] == 1
assert data["services"][0]["name"] == "Test Service" assert data["services"][0]["name"] == "Test Service"
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_services_status_checks( async def test_get_services_status_checks(
default_client: AsyncClient, query: str, expected_status: int default_client: AsyncClient, query: str, expected_status: int
): ):
resp = await default_client.get(f"/service?{query}") resp = await default_client.get(f"/service?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_service_success(default_client: AsyncClient): async def test_post_service_success(default_client: AsyncClient):
resp = await default_client.post("/service", json={"name": "New Test Service"}) resp = await default_client.post("/service", json={"name": "New Test Service"})
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
assert "service" in data assert "service" in data
assert isinstance(data["service"], dict) assert isinstance(data["service"], dict)
assert data["service"]["name"] == "New Test Service" assert data["service"]["name"] == "New Test Service"
assert data["service"]["id"] == 2 assert data["service"]["id"] == 2
assert isinstance(data["service"]["api_key"], str) assert isinstance(data["service"]["api_key"], str)
@pytest.mark.parametrize("body, expected_status", generate_body_and_status({"name": "str"})) @pytest.mark.parametrize("body, expected_status", generate_body_and_status({"name": "str"}))
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_service_status_checks( async def test_post_service_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int default_client: AsyncClient, body: dict[str, str], expected_status: int
): ):
resp = await default_client.post("/service", json=body) resp = await default_client.post("/service", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_service_conflict(default_client: AsyncClient): async def test_post_service_conflict(default_client: AsyncClient):
resp = await default_client.post("/service", json={"name": "Test Service"}) resp = await default_client.post("/service", json={"name": "Test Service"})
assert resp.status_code == 409 assert resp.status_code == 409
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_service_success(default_client: AsyncClient): async def test_patch_service_success(default_client: AsyncClient):
resp = await default_client.patch("/service/key", json={"service_id": 1}) resp = await default_client.patch("/service/key", json={"service_id": 1})
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
assert "service" in data assert "service" in data
assert isinstance(data["service"], dict) assert isinstance(data["service"], dict)
assert data["service"]["name"] == "Test Service" assert data["service"]["name"] == "Test Service"
assert data["service"]["id"] == 1 assert data["service"]["id"] == 1
assert isinstance(data["service"]["api_key"], str) assert isinstance(data["service"]["api_key"], str)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
generate_body_and_status({"service_id": "int"}), generate_body_and_status({"service_id": "int"}),
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_services_status_checks( async def test_patch_services_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int default_client: AsyncClient, body: dict[str, str], expected_status: int
): ):
resp = await default_client.patch("/service/key", json=body) resp = await default_client.patch("/service/key", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_delete_service_success(default_client: AsyncClient): async def test_delete_service_success(default_client: AsyncClient):
resp = await default_client.delete("/service?service_id=1") resp = await default_client.delete("/service?service_id=1")
assert resp.status_code == 204 assert resp.status_code == 204

View file

@ -5,197 +5,202 @@
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
from fastapi.routing import APIRoute from fastapi.routing import APIRoute, iter_route_contexts
from .conftest import generate_query_and_status from .conftest import generate_query_and_status
pytestmark = [ pytestmark = [
pytest.mark.user_module, pytest.mark.user_module,
] ]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_self_db_success(default_client: AsyncClient): async def test_get_self_db_success(default_client: AsyncClient):
resp = await default_client.get("/user/self/db") resp = await default_client.get("/user/self/db")
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
assert data["first_name"] == "Admin" assert data["first_name"] == "Admin"
assert data["last_name"] == "Test" assert data["last_name"] == "Test"
assert data["email"] == "admin@test.com" assert data["email"] == "admin@test.com"
assert "organisations" in data assert "organisations" in data
assert isinstance(data["organisations"], list) assert isinstance(data["organisations"], list)
assert "groups" in data assert "groups" in data
assert isinstance(data["groups"], dict) assert isinstance(data["groups"], dict)
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_user_success(default_client: AsyncClient): async def test_get_user_success(default_client: AsyncClient):
resp = await default_client.get("/user?user_id=1") resp = await default_client.get("/user?user_id=1")
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
assert data["first_name"] == "Admin" assert data["first_name"] == "Admin"
assert data["last_name"] == "Test" assert data["last_name"] == "Test"
assert data["email"] == "admin@test.com" assert data["email"] == "admin@test.com"
assert "organisations" in data assert "organisations" in data
assert isinstance(data["organisations"], list) assert isinstance(data["organisations"], list)
assert "groups" in data assert "groups" in data
assert isinstance(data["groups"], dict) assert isinstance(data["groups"], dict)
@pytest.mark.anyio @pytest.mark.anyio
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["user_id"])) @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["user_id"]))
async def test_get_user_status_checks( async def test_get_user_status_checks(
default_client: AsyncClient, query: str, expected_status: int default_client: AsyncClient, query: str, expected_status: int
): ):
resp = await default_client.get(f"/user?{query}") resp = await default_client.get(f"/user?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_delete_user_success(default_client: AsyncClient): async def test_delete_user_success(default_client: AsyncClient):
resp = await default_client.delete("/user?user_id=1") resp = await default_client.delete("/user?user_id=1")
assert resp.status_code == 204 assert resp.status_code == 204
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_user_invitation_success(default_client: AsyncClient): async def test_post_user_invitation_success(default_client: AsyncClient):
body = {"user_email": "admin@test.com", "organisation_id": 1} body = {"user_email": "admin@test.com", "organisation_id": 1}
resp = await default_client.post("/user/invitation", json=body) resp = await default_client.post("/user/invitation", json=body)
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert "organisation" in data assert "organisation" in data
assert isinstance(data["organisation"], dict) assert isinstance(data["organisation"], dict)
assert data["organisation"]["id"] == 1 assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One" assert data["organisation"]["name"] == "Org One"
assert "invited_email" in data assert "invited_email" in data
assert isinstance(data["invited_email"], str) assert isinstance(data["invited_email"], str)
assert data["invited_email"] == "admin@test.com" assert data["invited_email"] == "admin@test.com"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
({"organisation_id": 42, "user_email": "admin@test.com"}, 404), ({"organisation_id": 42, "user_email": "admin@test.com"}, 404),
({"organisation_id": "Test Org", "user_email": "admin@test.com"}, 422), ({"organisation_id": "Test Org", "user_email": "admin@test.com"}, 422),
({"organisation_id": "", "user_email": "admin@test.com"}, 422), ({"organisation_id": "", "user_email": "admin@test.com"}, 422),
({}, 422), ({}, 422),
({"user_email": 42}, 422), ({"user_email": 42}, 422),
({"organisation_id": 1, "user_email": "Test User"}, 422), ({"organisation_id": 1, "user_email": "Test User"}, 422),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_user_invitation_status_checks( async def test_post_user_invitation_status_checks(
default_client: AsyncClient, body, expected_status default_client: AsyncClient, body, expected_status
): ):
resp = await default_client.post("/user/invitation", json=body) resp = await default_client.post("/user/invitation", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
({"jwt": "invalid"}, 401), ({"jwt": "invalid"}, 401),
({"jwt": ""}, 401), ({"jwt": ""}, 401),
({"jwt": None}, 422), ({"jwt": None}, 422),
({"jwt": 42}, 422), ({"jwt": 42}, 422),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_user_invitation_accept_status_checks( async def test_post_user_invitation_accept_status_checks(
default_client: AsyncClient, body, expected_status default_client: AsyncClient, body, expected_status
): ):
resp = await default_client.post("/user/invitation/accept", json=body) resp = await default_client.post("/user/invitation/accept", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
if resp.status_code == 401: if resp.status_code == 401:
assert resp.json()["detail"] == "Invalid JWS" assert resp.json()["detail"] == "Invalid JWS"
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_self_orgs_success(default_client: AsyncClient): async def test_get_self_orgs_success(default_client: AsyncClient):
resp = await default_client.get("/user/self/orgs") resp = await default_client.get("/user/self/orgs")
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert "organisations" in data assert "organisations" in data
assert isinstance(data["organisations"], list) assert isinstance(data["organisations"], list)
assert len(data["organisations"]) > 0 assert len(data["organisations"]) > 0
org = data["organisations"][0] org = data["organisations"][0]
assert org["organisation_id"] == 1 assert org["organisation_id"] == 1
assert org["name"] == "Org One" assert org["name"] == "Org One"
assert org["status"] == "approved" assert org["status"] == "approved"
assert org["root_user_email"] == "admin@test.com" assert org["root_user_email"] == "admin@test.com"
assert "intake_questionnaire" in org assert "intake_questionnaire" in org
assert isinstance(org["intake_questionnaire"], dict) assert isinstance(org["intake_questionnaire"], dict)
assert isinstance(org["billing_contact"], dict) assert isinstance(org["billing_contact"], dict)
assert org["billing_contact"]["email"] == "billing@orgone.com" assert org["billing_contact"]["email"] == "billing@orgone.com"
assert org["billing_contact"]["id"] == 1 assert org["billing_contact"]["id"] == 1
assert isinstance(org["owner_contact"], dict) assert isinstance(org["owner_contact"], dict)
assert org["owner_contact"]["email"] == "owner@orgone.com" assert org["owner_contact"]["email"] == "owner@orgone.com"
assert org["owner_contact"]["id"] == 2 assert org["owner_contact"]["id"] == 2
assert isinstance(org["security_contact"], dict) assert isinstance(org["security_contact"], dict)
assert org["security_contact"]["email"] == "security@orgone.com" assert org["security_contact"]["email"] == "security@orgone.com"
assert org["security_contact"]["id"] == 3 assert org["security_contact"]["id"] == 3
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_self_orgs_dynamic(default_client: AsyncClient): async def test_get_self_orgs_dynamic(default_client: AsyncClient):
method = "GET" method = "GET"
path = "/user/self/orgs" path = "/user/self/orgs"
expected_data = { expected_data = {
"organisations": [ "organisations": [
{ {
"organisation_id": 1, "organisation_id": 1,
"name": "Org One", "name": "Org One",
"status": "approved", "status": "approved",
"root_user_email": "admin@test.com", "root_user_email": "admin@test.com",
"owner_contact": {"email": "owner@orgone.com", "id": 2}, "owner_contact": {"email": "owner@orgone.com", "id": 2},
"security_contact": {"email": "security@orgone.com", "id": 3}, "security_contact": {"email": "security@orgone.com", "id": 3},
"billing_contact": {"email": "billing@orgone.com", "id": 1}, "billing_contact": {"email": "billing@orgone.com", "id": 1},
"intake_questionnaire": { "intake_questionnaire": {
"questions": { "questions": {
"question_one": None, "question_one": None,
"question_three": None, "question_three": None,
"question_two": "answer two", "question_two": "answer two",
}, },
"metadata": {"version": 0, "submission_date": None}, "metadata": {"version": 0, "submission_date": None},
}, },
} }
] ]
} }
resp = await default_client.get(path) resp = await default_client.get(path)
route = next( contexts = list(iter_route_contexts(default_client._transport.app.routes)) # ty:ignore[unresolved-attribute]
route
for route in default_client._transport.app.routes # ty:ignore[unresolved-attribute]
if isinstance(route, APIRoute) and path in route.path and method in route.methods
)
assert resp.status_code == route.status_code route = next(
if route.status_code == 204: route.route
return for route in contexts
if isinstance(route.route, APIRoute)
and path in route.route.path
and isinstance(route.methods, set)
and method in route.methods
)
expected_response_schema = route.response_model assert resp.status_code == route.status_code
data = resp.json() if route.status_code == 204:
return
response_model = expected_response_schema(**data) expected_response_schema = route.response_model
assert isinstance(response_model, expected_response_schema) data = resp.json()
expected_response_model = expected_response_schema(**expected_data) response_model = expected_response_schema(**data)
assert isinstance(response_model, expected_response_schema)
assert response_model == expected_response_model expected_response_model = expected_response_schema(**expected_data)
assert response_model == expected_response_model

11
uv.lock generated
View file

@ -6,6 +6,9 @@ requires-python = ">=3.12"
exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values. exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values.
exclude-newer-span = "P2W" exclude-newer-span = "P2W"
[options.exclude-newer-package]
fastapi = "2026-06-22T00:00:00Z"
[[package]] [[package]]
name = "alembic" name = "alembic"
version = "1.18.4" version = "1.18.4"
@ -238,7 +241,7 @@ dev = [
requires-dist = [ requires-dist = [
{ name = "alembic", specifier = ">=1.18.4" }, { name = "alembic", specifier = ">=1.18.4" },
{ name = "email-validator", specifier = ">=2.3.0" }, { name = "email-validator", specifier = ">=2.3.0" },
{ name = "fastapi", specifier = ">=0.136.3" }, { name = "fastapi", specifier = ">=0.138.0" },
{ name = "httptools", specifier = ">=0.7.1" }, { name = "httptools", specifier = ">=0.7.1" },
{ name = "httpx", specifier = ">=0.28.1" }, { name = "httpx", specifier = ">=0.28.1" },
{ name = "itsdangerous", specifier = ">=2.2.0" }, { name = "itsdangerous", specifier = ">=2.2.0" },
@ -349,7 +352,7 @@ wheels = [
[[package]] [[package]]
name = "fastapi" name = "fastapi"
version = "0.136.3" version = "0.138.0"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "annotated-doc" }, { name = "annotated-doc" },
@ -358,9 +361,9 @@ dependencies = [
{ name = "typing-extensions" }, { name = "typing-extensions" },
{ name = "typing-inspection" }, { name = "typing-inspection" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/81/2d/ff8d91d7b564d464629a0fd50a4489c97fcb836ac230bf3a7269232a9b1f/fastapi-0.136.3.tar.gz", hash = "sha256:e487fae93ad408e6f47641ee4dfe389864fd7bec92e547ea8498fc13f43e83ab", size = 396410, upload-time = "2026-05-23T18:53:15.192Z" } sdist = { url = "https://files.pythonhosted.org/packages/5b/58/ff455d9fe47c60abadb34b9e05a304b1f05f5ab8000ac01565156b6f5e43/fastapi-0.138.0.tar.gz", hash = "sha256:d445a4877636ad191e7053e08c9bf98cb921a6756776848400bb773d1740c061", size = 419240, upload-time = "2026-06-20T01:18:05.259Z" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/e0/82/45359b62a067409bd929ae8a56b8ed13e5a8c8a61194b3c236920999ab83/fastapi-0.136.3-py3-none-any.whl", hash = "sha256:3d2a69bdf04b7e9f3afa292c3bc7a98816bbfafa10bc9b45f3f3700d2f761620", size = 117481, upload-time = "2026-05-23T18:53:16.924Z" }, { url = "https://files.pythonhosted.org/packages/6c/ff/8496d9847a5fedae775eb49460722d3efaa80487854273e9647ae876218c/fastapi-0.138.0-py3-none-any.whl", hash = "sha256:b6f54fd1bd72c80b0f899f172c61a600f6f7af9b43d4d772a018f35624048cb0", size = 126779, upload-time = "2026-06-20T01:18:03.483Z" },
] ]
[[package]] [[package]]