1
0
Fork 0
forked from sr2/cloud-api

Compare commits

...
Sign in to create a new pull request.

34 commits

Author SHA1 Message Date
7dad2e920e tests: get_testable_routes finds auth level
Checks all dependencies used on each endpoint and determines the highest level of auth applied to each endpoint.

API Key>SU>Root>User>None
2026-06-22 16:45:50 +01:00
bee0dcd4fe feat: soft deleted users access blocked 2026-06-22 16:12:03 +01:00
a9e059bf0a feat: user soft delete 2026-06-22 15:43:38 +01:00
irl
5b98be9787 ci: define context for docker 2026-06-22 15:30:29 +01:00
be46e43042 fix(db): user active default true 2026-06-22 15:28:46 +01:00
irl
8ab0390977 ci: fix branch name tag again 2026-06-22 15:26:39 +01:00
irl
cc4ae42646 ci: adds frontend ref 2026-06-22 15:24:33 +01:00
irl
44e1d4986f ci: relative repo path 2026-06-22 15:22:42 +01:00
irl
20615f438a ci: fix branch name tag 2026-06-22 15:21:31 +01:00
irl
a481be8352 ci: check out the frontend repo 2026-06-22 15:20:14 +01:00
irl
e7bd455b2d ci: run the build step somewhere 2026-06-22 15:18:28 +01:00
4b3ab92d2a fix: fastapi 0.137 router.route changes 2026-06-22 15:15:42 +01:00
irl
ee47186c5a fix(db): generator types 2026-06-22 15:12:34 +01:00
fab228bf8f minor: ruff format
Tabs -> spaces
2026-06-22 15:04:11 +01:00
b2921b73b8 fix: conftest match db changes 2026-06-22 15:02:39 +01:00
1a851859d0 fix: logging import for email 2026-06-22 15:02:04 +01:00
a343b76f63 fix: invalid toml syntax 2026-06-22 15:01:36 +01:00
irl
84ba3b6bee feat(db): db tuning options and consistency 2026-06-22 14:50:05 +01:00
40918fd8b8 feat: delete org soft deletes 2026-06-22 14:50:05 +01:00
irl
d395b01997 fix: only serve frontend if present in prod 2026-06-22 14:42:13 +01:00
irl
1384ee7bd6 feat: adds empty static directory for frontend 2026-06-22 14:40:02 +01:00
irl
df8ab32cb1 ci: build and publish OCI image 2026-06-22 14:38:23 +01:00
f41f76bcf8 Merge pull request 'feat(utils): use logging around email send' (#31) from irl/cloud-api:maillog into main
Reviewed-on: sr2/cloud-api#31
2026-06-22 13:37:15 +00:00
d07230b3b0 Merge pull request 'fix(user): simplify add_user' (#28) from irl/cloud-api:add_user into main
Reviewed-on: sr2/cloud-api#28
2026-06-22 13:34:36 +00:00
irl
9e1d6026b5 feat: adds Containerfile with frontend serving 2026-06-22 14:24:56 +01:00
c28b4dc37b feat: applied model mixins
IdMixin used on every table with an ID index (no changes needed to db)

Timestamp and Deleted mixins applied to org and user tables.

ActivatedMixin added to users.
2026-06-22 13:46:11 +01:00
7e1ab6c6ee feat: db model mixins 2026-06-22 13:46:11 +01:00
irl
0baa50d10f misc: add frontend dir to .gitignore 2026-06-22 13:30:53 +01:00
irl
53b42b24dd feat(utils): use logging around email send 2026-06-22 13:26:47 +01:00
irl
fe8f627fa5 ci: reduce min age for renovate to 7 days 2026-06-22 12:02:29 +00:00
c2777db2e3 Add renovate.json 2026-06-22 12:02:02 +00:00
irl
a9e539ef74 fix(user): simplify add_user 2026-06-22 12:23:38 +01:00
02ddf9a3ed fix: skip sending email process while running tests
Removes the need for lettermint api key in CI.
2026-06-22 12:06:43 +01:00
63e7d48c07 ci: remove non-ty checks from ty job 2026-06-22 12:04:39 +01:00
65 changed files with 4004 additions and 3761 deletions

View file

@ -34,8 +34,6 @@ jobs:
- run: uv python install # Gets Python version from pyproject.toml
- run: uv sync --dev
- run: uv run ty check
- run: uv run ruff format
- run: uv run pytest test
env:
ENVIRONMENT: testing
@ -54,3 +52,35 @@ jobs:
- run: uv run pytest test
env:
ENVIRONMENT: testing
build:
needs: [ ruff, ty, tests ]
if: ${{ always() && needs.ruff.result == 'success' && needs.ty.result == 'success' && needs.tests.result == 'success' }}
runs-on: docker
container:
image: ghcr.io/catthehacker/ubuntu:act-latest
options: -v /dind/docker.sock:/var/run/docker.sock
steps:
- name: Checkout the repo
uses: actions/checkout@v4
- name: Checkout the frontend
uses: actions/checkout@v4
with:
repository: sr2/cloud-portal.git
path: frontend
ref: main
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to the registry
uses: docker/login-action@v3
with:
registry: guardianproject.dev
username: irl
password: ${{ secrets.PACKAGE_TOKEN }}
- name: Build and push
uses: docker/build-push-action@v6
with:
file: /workspace/sr2/cloud-api/Containerfile
context: /workspace/sr2/cloud-api/
push: true
tags: guardianproject.dev/${{ github.repository }}:${{ github.ref_name }}

4
.gitignore vendored
View file

@ -206,5 +206,7 @@ marimo/_static/
marimo/_lsp/
__marimo__/
endpoints.txt
endpoints.txt
# React Frontend
/frontend/

View file

@ -1 +1 @@
3.14
3.12

42
Containerfile Normal file
View file

@ -0,0 +1,42 @@
FROM node:22-slim AS react-builder
WORKDIR /app
COPY frontend/ /app/
RUN --mount=type=cache,target=/root/.npm npm ci
RUN npm run build # Outputs to /app/dist
FROM ghcr.io/astral-sh/uv:python3.12-trixie-slim AS python-builder
ENV UV_PYTHON_DOWNLOADS=0
WORKDIR /app
# Install dependencies first (layer caching)
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=uv.lock,target=uv.lock \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
uv sync --locked --no-install-project --no-editable
# Copy project source and install the project itself
COPY ./ /app/
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --locked --no-editable
FROM python:3.12-slim-trixie
WORKDIR /app
COPY alembic /app/alembic
COPY alembic.ini /app
COPY src /app/src
COPY --from=python-builder /app/.venv /app/.venv
COPY --from=react-builder /app/dist /app/static
# Ensure venv is on PATH
ENV PATH="/app/.venv/bin:$PATH" \
UV_PYTHON_DOWNLOADS=0
EXPOSE 8000
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8000"]

View file

@ -0,0 +1,32 @@
"""fix user activated default
Revision ID: ae433e1c3b20
Revises: 661202797ecd
Create Date: 2026-06-22 15:26:57.805129
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'ae433e1c3b20'
down_revision: Union[str, Sequence[str], None] = '661202797ecd'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column('user', 'active', server_default=sa.true())
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column('user', 'active', server_default=sa.false())
# ### end Alembic commands ###

View file

@ -0,0 +1,44 @@
"""model mixins
Revision ID: 661202797ecd
Revises: 869d48618a1c
Create Date: 2026-06-22 13:29:39.689067
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '661202797ecd'
down_revision: Union[str, Sequence[str], None] = '869d48618a1c'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('organisation', sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()))
op.add_column('organisation', sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()))
op.add_column('organisation', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
op.add_column('user', sa.Column('active', sa.Boolean(), nullable=False, server_default=sa.false()))
op.add_column('user', sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()))
op.add_column('user', sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()))
op.add_column('user', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('user', 'deleted_at')
op.drop_column('user', 'updated_at')
op.drop_column('user', 'created_at')
op.drop_column('user', 'active')
op.drop_column('organisation', 'deleted_at')
op.drop_column('organisation', 'updated_at')
op.drop_column('organisation', 'created_at')
# ### end Alembic commands ###

View file

@ -8,7 +8,7 @@ requires-python = ">=3.12"
dependencies = [
"alembic>=1.18.4",
"email-validator>=2.3.0",
"fastapi>=0.136.3",
"fastapi>=0.138.0",
"httptools>=0.7.1",
"httpx>=0.28.1",
"itsdangerous>=2.2.0",
@ -34,11 +34,11 @@ line-length = 92
[tool.ruff.format]
quote-style = "double"
indent-style = "tab"
[tool.uv]
add-bounds = "major"
exclude-newer = "P2W"
exclude-newer-package = { "fastapi" = "2026-06-22T00:00:00Z" }
[dependency-groups]
dev = [

8
renovate.json Normal file
View file

@ -0,0 +1,8 @@
{
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
"extends": [
"config:recommended"
],
"minimumReleaseAge": "7 days",
"gitAuthor": "Renovate<noreply@sr2.uk>"
}

View file

@ -22,5 +22,5 @@ from fastapi import APIRouter
router = APIRouter(
tags=[""],
tags=[""],
)

View file

@ -8,6 +8,6 @@ Exports:
from fastapi import APIRouter
router = APIRouter(
tags=["admin"],
prefix="/admin",
tags=["admin"],
prefix="/admin",
)

View file

@ -26,15 +26,15 @@ api_router.include_router(iam_router)
class HealthCheckResponse(CustomBaseModel):
status: str
status: str
@api_router.get(
path="/healthcheck",
status_code=status.HTTP_200_OK,
response_model=HealthCheckResponse,
include_in_schema=False,
path="/healthcheck",
status_code=status.HTTP_200_OK,
response_model=HealthCheckResponse,
include_in_schema=False,
)
def healthcheck():
"""Simple health check endpoint."""
return {"status": "ok"}
"""Simple health check endpoint."""
return {"status": "ok"}

View file

@ -9,9 +9,9 @@ from src.config import CustomBaseSettings
class AuthConfig(CustomBaseSettings):
OIDC_CONFIG: str = ""
OIDC_ISSUER: str = ""
CLIENT_ID: str = ""
OIDC_CONFIG: str = ""
OIDC_ISSUER: str = ""
CLIENT_ID: str = ""
auth_settings = AuthConfig()

View file

@ -16,92 +16,92 @@ from src.exceptions import ForbiddenException
from src.user.dependencies import user_model_claims_dependency
from src.user.models import User
from src.organisation.dependencies import (
org_model_query_dependency,
org_model_body_dependency,
org_model_query_dependency,
org_model_body_dependency,
)
from src.organisation.models import Organisation as Org
async def org_query_user_claims(
org_model: org_model_query_dependency, user_model: user_model_claims_dependency
org_model: org_model_query_dependency, user_model: user_model_claims_dependency
):
if user_model in org_model.user_rel:
return True
if user_model in org_model.user_rel:
return True
raise ForbiddenException()
raise ForbiddenException()
org_query_user_claims_dependency = Annotated[bool, Depends(org_query_user_claims)]
def get_super_admin_list():
return []
return []
def empty_su_list():
return []
return []
def testing_su_list():
return ["admin@test.com"]
return ["admin@test.com"]
su_list_dependency = Annotated[list[str | None], Depends(get_super_admin_list)]
async def user_model_super_admin(
user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency
user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency
):
if user_model.email in super_admin_emails:
return user_model
if user_model.email in super_admin_emails:
return user_model
raise ForbiddenException(message="Must be super admin")
raise ForbiddenException(message="Must be super admin")
super_admin_dependency = Annotated[User, Depends(user_model_super_admin)]
async def org_query_root_claims(
user_model: user_model_claims_dependency,
org_model: org_model_query_dependency,
su_emails: su_list_dependency,
request: Request,
user_model: user_model_claims_dependency,
org_model: org_model_query_dependency,
su_emails: su_list_dependency,
request: Request,
):
try:
if await user_model_super_admin(user_model, su_emails):
return org_model
except ForbiddenException:
pass
try:
if await user_model_super_admin(user_model, su_emails):
return org_model
except ForbiddenException:
pass
await org_status_check(org_model, request)
await org_status_check(org_model, request)
if org_model.root_user_id == user_model.id:
return org_model
if org_model.root_user_id == user_model.id:
return org_model
raise ForbiddenException(message="Must be the org's root user")
raise ForbiddenException(message="Must be the org's root user")
org_model_root_claim_query_dependency = Annotated[Org, Depends(org_query_root_claims)]
async def org_body_root_claims(
user_model: user_model_claims_dependency,
org_model: org_model_body_dependency,
su_emails: su_list_dependency,
request: Request,
user_model: user_model_claims_dependency,
org_model: org_model_body_dependency,
su_emails: su_list_dependency,
request: Request,
):
try:
if await user_model_super_admin(user_model, su_emails):
return org_model
except ForbiddenException:
pass
try:
if await user_model_super_admin(user_model, su_emails):
return org_model
except ForbiddenException:
pass
await org_status_check(org_model, request)
await org_status_check(org_model, request)
if org_model.root_user_id == user_model.id:
return org_model
if org_model.root_user_id == user_model.id:
return org_model
raise ForbiddenException(message="Must be the org's root user")
raise ForbiddenException(message="Must be the org's root user")
org_model_root_claim_body_dependency = Annotated[Org, Depends(org_body_root_claims)]

View file

@ -8,5 +8,5 @@ Exports:
from fastapi import APIRouter
router = APIRouter(
tags=["auth"],
tags=["auth"],
)

View file

@ -22,8 +22,8 @@ from src.organisation.exceptions import AwaitingApprovalException
from src.organisation.models import Organisation as Org
from src.exceptions import UnauthorizedException, ForbiddenException
from src.auth.config import auth_settings
from src.user.service import add_user_to_db
from src.database import db_dependency
from src.user.service import add_user
from src.database import DbSession
oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG)
@ -31,56 +31,56 @@ oidc_dependency = Annotated[str, Depends(oidc)]
async def get_dev_user():
return {"db_id": 1, "email": "chris@sr2.uk"}
return {"db_id": 1, "email": "chris@sr2.uk"}
async def get_current_user(
oidc_auth_string: oidc_dependency, db: db_dependency
oidc_auth_string: oidc_dependency, db: DbSession
) -> dict[str, Any]:
config_url = urlopen(auth_settings.OIDC_CONFIG)
config = json.loads(config_url.read())
jwks_uri = config["jwks_uri"]
key_response = requests.get(jwks_uri)
jwk_keys = KeySet.import_key_set(key_response.json())
config_url = urlopen(auth_settings.OIDC_CONFIG)
config = json.loads(config_url.read())
jwks_uri = config["jwks_uri"]
key_response = requests.get(jwks_uri)
jwk_keys = KeySet.import_key_set(key_response.json())
token = jwt.decode(oidc_auth_string.replace("Bearer ", ""), jwk_keys)
token = jwt.decode(oidc_auth_string.replace("Bearer ", ""), jwk_keys)
claims_requests = jwt.JWTClaimsRegistry(
exp={"essential": True}, iss={"essential": True, "value": auth_settings.OIDC_ISSUER}
)
claims_requests = jwt.JWTClaimsRegistry(
exp={"essential": True}, iss={"essential": True, "value": auth_settings.OIDC_ISSUER}
)
try:
claims_requests.validate(token.claims)
except ExpiredTokenError:
raise UnauthorizedException(message="Token is expired")
db_id = await add_user_to_db(db, token.claims)
try:
claims_requests.validate(token.claims)
except ExpiredTokenError:
raise UnauthorizedException(message="Token is expired")
db_id = await add_user(db, token.claims)
token.claims["db_id"] = db_id
token.claims["db_id"] = db_id
return token.claims
return token.claims
claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)]
async def org_status_check(org_model: Org, request: Request):
org_status = OrgStatus(org_model.status)
if org_status.is_blocked:
raise ForbiddenException("This organisation cannot perform this action.")
org_status = OrgStatus(org_model.status)
if org_status.is_blocked:
raise ForbiddenException("This organisation cannot perform this action.")
root = "/api/v1"
root = "/api/v1"
pre_approval_endpoints = [
f"PATCH{root}/org/status",
f"PATCH{root}/org/questionnaire",
f"GET{root}/org",
f"GET{root}/org/contact",
f"PATCH{root}/org/contact",
f"DELETE{root}/org/self",
]
current_request = f"{request.method}{request.url.path}"
if (
current_request not in pre_approval_endpoints
and org_model.status != OrgStatus.APPROVED
):
raise AwaitingApprovalException(org_model.id)
pre_approval_endpoints = [
f"PATCH{root}/org/status",
f"PATCH{root}/org/questionnaire",
f"GET{root}/org",
f"GET{root}/org/contact",
f"PATCH{root}/org/contact",
f"DELETE{root}/org/self",
]
current_request = f"{request.method}{request.url.path}"
if (
current_request not in pre_approval_endpoints
and org_model.status != OrgStatus.APPROVED
):
raise AwaitingApprovalException(org_model.id)

View file

@ -16,27 +16,31 @@ from src.constants import Environment
class CustomBaseSettings(BaseSettings):
model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", extra="ignore"
)
model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", extra="ignore"
)
class Config(CustomBaseSettings):
APP_VERSION: str = "0.1"
ENVIRONMENT: Environment = Environment.PRODUCTION
SECRET_KEY: SecretStr = SecretStr("")
DISABLE_AUTH: bool = False
APP_VERSION: str = "0.1"
ENVIRONMENT: Environment = Environment.PRODUCTION
SECRET_KEY: SecretStr = SecretStr("")
DISABLE_AUTH: bool = False
CORS_ORIGINS: list[str] = ["*"]
CORS_ORIGINS_REGEX: str | None = None
CORS_HEADERS: list[str] = ["*"]
CORS_ORIGINS: list[str] = ["*"]
CORS_ORIGINS_REGEX: str | None = None
CORS_HEADERS: list[str] = ["*"]
DATABASE_NAME: str = "fastapi-exp"
DATABASE_PORT: str = "5432"
DATABASE_HOSTNAME: str = "localhost"
DATABASE_CREDENTIALS: SecretStr = SecretStr(":")
DATABASE_NAME: str = "fastapi-exp"
DATABASE_PORT: str = "5432"
DATABASE_HOSTNAME: str = "localhost"
DATABASE_CREDENTIALS: SecretStr = SecretStr(":")
LETTERMINT_API_TOKEN: SecretStr = SecretStr("")
DATABASE_POOL_SIZE: int = 16
DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes
DATABASE_POOL_PRE_PING: bool = True
LETTERMINT_API_TOKEN: SecretStr = SecretStr("")
settings = Config()
@ -47,20 +51,20 @@ DATABASE_HOSTNAME = settings.DATABASE_HOSTNAME
DATABASE_CREDENTIALS = settings.DATABASE_CREDENTIALS.get_secret_value()
# this will support special chars for credentials
_DATABASE_CREDENTIAL_USER, _DATABASE_CREDENTIAL_PASSWORD = str(DATABASE_CREDENTIALS).split(
":"
":"
)
_QUOTED_DATABASE_PASSWORD = parse.quote_plus(str(_DATABASE_CREDENTIAL_PASSWORD))
SQLALCHEMY_DATABASE_URI = SecretStr(
f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}"
f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}"
)
if settings.ENVIRONMENT == Environment.TESTING:
SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:")
SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:")
app_configs: dict[str, Any] = {"title": "App API"}
if settings.ENVIRONMENT.is_deployed:
app_configs["root_path"] = f"/v{settings.APP_VERSION}"
app_configs["root_path"] = f"/v{settings.APP_VERSION}"
if not settings.ENVIRONMENT.is_debug:
app_configs["openapi_url"] = None # hide docs
app_configs["openapi_url"] = None # hide docs

View file

@ -9,29 +9,29 @@ from enum import StrEnum, auto
class Environment(StrEnum):
"""
Enumeration of environments.
"""
Enumeration of environments.
Attributes:
LOCAL (str): Application is running locally
TESTING (str): Application is running in testing mode
STAGING (str): Application is running in staging mode (ie not testing)
PRODUCTION (str): Application is running in production mode
"""
Attributes:
LOCAL (str): Application is running locally
TESTING (str): Application is running in testing mode
STAGING (str): Application is running in staging mode (ie not testing)
PRODUCTION (str): Application is running in production mode
"""
LOCAL = auto()
TESTING = auto()
STAGING = auto()
PRODUCTION = auto()
LOCAL = auto()
TESTING = auto()
STAGING = auto()
PRODUCTION = auto()
@property
def is_debug(self):
return self in (self.LOCAL, self.STAGING, self.TESTING)
@property
def is_debug(self):
return self in (self.LOCAL, self.STAGING, self.TESTING)
@property
def is_testing(self):
return self == self.TESTING
@property
def is_testing(self):
return self == self.TESTING
@property
def is_deployed(self) -> bool:
return self in (self.STAGING, self.PRODUCTION)
@property
def is_deployed(self) -> bool:
return self in (self.STAGING, self.PRODUCTION)

View file

@ -11,13 +11,13 @@ from fastapi import HTTPException, status
class ContactNotFoundException(HTTPException):
def __init__(self, contact_id: Optional[int] = None) -> None:
detail = (
"Contact not found"
if contact_id is None
else f"Contact with ID '{contact_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
def __init__(self, contact_id: Optional[int] = None) -> None:
detail = (
"Contact not found"
if contact_id is None
else f"Contact with ID '{contact_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)

View file

@ -6,30 +6,31 @@ Models:
street_address, street_address_line_2, post_office_box_number, address_locality, country_code, address_region, postal_code
"""
from src.models import IdMixin
from sqlalchemy import ForeignKey
from sqlalchemy.orm import mapped_column, Mapped
from src.models import CustomBase
class Contact(CustomBase):
__tablename__ = "contact"
class Contact(CustomBase, IdMixin):
__tablename__ = "contact"
id: Mapped[int] = mapped_column(primary_key=True)
email: Mapped[str] = mapped_column(default=None, nullable=True)
first_name: Mapped[str] = mapped_column(default=None, nullable=True)
last_name: Mapped[str] = mapped_column(default=None, nullable=True)
phonenumber: Mapped[str] = mapped_column(default=None, nullable=True)
vat_number: Mapped[str | None] = mapped_column(default=None, nullable=True)
email: Mapped[str] = mapped_column(default=None, nullable=True)
first_name: Mapped[str] = mapped_column(default=None, nullable=True)
last_name: Mapped[str] = mapped_column(default=None, nullable=True)
phonenumber: Mapped[str] = mapped_column(default=None, nullable=True)
vat_number: Mapped[str | None] = mapped_column(default=None, nullable=True)
street_address: Mapped[str] = mapped_column(default=None, nullable=True)
street_address_line_2: Mapped[str] = mapped_column(default=None, nullable=True)
post_office_box_number: Mapped[str | None] = mapped_column(default=None, nullable=True)
locality: Mapped[str] = mapped_column(default=None, nullable=True) # Ie City
country_code: Mapped[str] = mapped_column(default=None, nullable=True) # Eg GB
address_region: Mapped[str | None] = mapped_column(default=None, nullable=True)
postal_code: Mapped[str] = mapped_column(default=None, nullable=True)
street_address: Mapped[str] = mapped_column(default=None, nullable=True)
street_address_line_2: Mapped[str] = mapped_column(default=None, nullable=True)
post_office_box_number: Mapped[str | None] = mapped_column(default=None, nullable=True)
locality: Mapped[str] = mapped_column(default=None, nullable=True) # Ie City
country_code: Mapped[str] = mapped_column(default=None, nullable=True) # Eg GB
address_region: Mapped[str | None] = mapped_column(default=None, nullable=True)
postal_code: Mapped[str] = mapped_column(default=None, nullable=True)
org_id: Mapped[int] = mapped_column(
ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False
)
org_id: Mapped[int] = mapped_column(
ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False
)

View file

@ -6,6 +6,6 @@ from fastapi import APIRouter
router = APIRouter(
prefix="/contact",
tags=["contact"],
prefix="/contact",
tags=["contact"],
)

View file

@ -14,22 +14,22 @@ from src.schemas import CustomBaseModel
class ContactAddress(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore")
model_config = ConfigDict(from_attributes=True, extra="ignore")
post_office_box_number: Optional[str] = None
street_address: Optional[str] = None
street_address_line_2: Optional[str] = None
locality: Optional[str] = None
address_region: Optional[str] = None
country_code: Optional[str] = None
postal_code: Optional[str] = None
post_office_box_number: Optional[str] = None
street_address: Optional[str] = None
street_address_line_2: Optional[str] = None
locality: Optional[str] = None
address_region: Optional[str] = None
country_code: Optional[str] = None
postal_code: Optional[str] = None
class ContactModel(CustomBaseModel):
email: Optional[EmailStr] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
phonenumber: Optional[str] = None
vat_number: Optional[str] = None
email: Optional[EmailStr] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
phonenumber: Optional[str] = None
vat_number: Optional[str] = None
address: ContactAddress
address: ContactAddress

View file

@ -1,13 +1,10 @@
"""
Database connections and init
Exports:
- db_dependency
- Base (sqlalchemy base model)
Database connection and session utilities
"""
from typing import Annotated
from sqlalchemy import create_engine, StaticPool
from contextlib import contextmanager
from typing import Annotated, Generator
from sqlalchemy import create_engine, StaticPool, Connection
from sqlalchemy.orm import sessionmaker, Session
from fastapi import Depends
@ -16,28 +13,57 @@ from src.constants import Environment
from src.config import SQLALCHEMY_DATABASE_URI, settings as global_settings
if global_settings.ENVIRONMENT == Environment.TESTING:
connect_args = {"check_same_thread": False}
engine = create_engine(
SQLALCHEMY_DATABASE_URI.get_secret_value(),
connect_args=connect_args,
poolclass=StaticPool,
)
connect_args = {"check_same_thread": False}
engine = create_engine(
SQLALCHEMY_DATABASE_URI.get_secret_value(),
connect_args=connect_args,
poolclass=StaticPool,
)
else:
engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value())
engine = create_engine(
SQLALCHEMY_DATABASE_URI.get_secret_value(),
pool_size=global_settings.DATABASE_POOL_SIZE,
pool_recycle=global_settings.DATABASE_POOL_TTL,
pool_pre_ping=global_settings.DATABASE_POOL_PRE_PING,
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
sm = sessionmaker(autocommit=False, expire_on_commit=False, bind=engine)
def get_db():
db = SessionLocal()
try:
yield db
except:
db.rollback()
raise
finally:
db.close()
@contextmanager
def get_db_connection() -> Generator[Connection, None, None]:
with engine.connect() as connection:
try:
yield connection
except Exception:
connection.rollback()
raise
db_dependency = Annotated[Session, Depends(get_db)]
def _get_db_connection() -> Generator[Connection, None, None]:
with get_db_connection() as connection:
yield connection
DbConnection = Annotated[Connection, Depends(_get_db_connection)]
@contextmanager
def get_db_session() -> Generator[Session, None, None]:
session = sm()
try:
yield session
except Exception:
session.rollback()
raise
finally:
session.close()
def _get_db_session() -> Generator[Session, None, None]:
with get_db_session() as session:
yield session
DbSession = Annotated[Session, Depends(_get_db_session)]

View file

@ -12,36 +12,36 @@ from fastapi import HTTPException, status
class UnprocessableContentException(HTTPException):
def __init__(self, message: Optional[str] = None) -> None:
detail = "Unprocessable content" if not message else message
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail=detail,
)
def __init__(self, message: Optional[str] = None) -> None:
detail = "Unprocessable content" if not message else message
super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail=detail,
)
class ConflictException(HTTPException):
def __init__(self, message: Optional[str] = None) -> None:
detail = "Conflict" if not message else message
super().__init__(
status_code=status.HTTP_409_CONFLICT,
detail=detail,
)
def __init__(self, message: Optional[str] = None) -> None:
detail = "Conflict" if not message else message
super().__init__(
status_code=status.HTTP_409_CONFLICT,
detail=detail,
)
class ForbiddenException(HTTPException):
def __init__(self, message: Optional[str] = None) -> None:
detail = "Forbidden" if not message else message
super().__init__(
status_code=status.HTTP_403_FORBIDDEN,
detail=detail,
)
def __init__(self, message: Optional[str] = None) -> None:
detail = "Forbidden" if not message else message
super().__init__(
status_code=status.HTTP_403_FORBIDDEN,
detail=detail,
)
class UnauthorizedException(HTTPException):
def __init__(self, message: Optional[str] = None) -> None:
detail = "Not authorized" if not message else message
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,
)
def __init__(self, message: Optional[str] = None) -> None:
detail = "Not authorized" if not message else message
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,
)

View file

@ -11,66 +11,62 @@ from typing import Annotated, Optional
from fastapi import Depends, Query
from src.database import db_dependency
from src.database import DbSession
from src.iam.models import Group, Permission
from src.iam.exceptions import GroupNotFoundException, PermNotFoundException
from src.iam.schemas import GroupIDMixin, PermIDMixin
def get_group_model_query(
db: db_dependency, group_id: Annotated[int, Query(gt=0)]
) -> Group:
group_model = db.get(Group, group_id)
if group_model is None:
raise GroupNotFoundException(group_id)
def get_group_model_query(db: DbSession, group_id: Annotated[int, Query(gt=0)]) -> Group:
group_model = db.get(Group, group_id)
if group_model is None:
raise GroupNotFoundException(group_id)
return group_model
return group_model
group_model_query_dependency = Annotated[Group, Depends(get_group_model_query)]
def get_group_model_body(
db: db_dependency, request_model: Optional[GroupIDMixin] = None
db: DbSession, request_model: Optional[GroupIDMixin] = None
) -> Group:
group_id = getattr(request_model, "group_id", None)
if group_id is None:
raise GroupNotFoundException()
group_model = db.get(Group, group_id)
if group_model is None:
raise GroupNotFoundException(group_id)
group_id = getattr(request_model, "group_id", None)
if group_id is None:
raise GroupNotFoundException()
group_model = db.get(Group, group_id)
if group_model is None:
raise GroupNotFoundException(group_id)
return group_model
return group_model
group_model_body_dependency = Annotated[Group, Depends(get_group_model_body)]
def get_perm_model_body(
db: db_dependency, request_model: Optional[PermIDMixin] = None
db: DbSession, request_model: Optional[PermIDMixin] = None
) -> Permission:
perm_id = getattr(request_model, "permission_id", None)
if perm_id is None:
raise PermNotFoundException
perm_model = db.get(Permission, perm_id)
if perm_model is None:
raise PermNotFoundException(perm_id)
perm_id = getattr(request_model, "permission_id", None)
if perm_id is None:
raise PermNotFoundException
perm_model = db.get(Permission, perm_id)
if perm_model is None:
raise PermNotFoundException(perm_id)
return perm_model
return perm_model
perm_model_body_dependency = Annotated[Permission, Depends(get_perm_model_body)]
def get_perm_model_query(
db: db_dependency, perm_id: Annotated[int, Query(gt=0)]
) -> Permission:
perm_model = db.get(Permission, perm_id)
if perm_model is None:
raise PermNotFoundException(perm_id)
def get_perm_model_query(db: DbSession, perm_id: Annotated[int, Query(gt=0)]) -> Permission:
perm_model = db.get(Permission, perm_id)
if perm_model is None:
raise PermNotFoundException(perm_id)
return perm_model
return perm_model
perm_model_query_dependency = Annotated[Permission, Depends(get_perm_model_query)]

View file

@ -12,26 +12,26 @@ from fastapi import HTTPException, status
class GroupNotFoundException(HTTPException):
def __init__(self, group_id: Optional[int] = None) -> None:
detail = (
"Group not found"
if group_id is None
else f"User with ID '{group_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
def __init__(self, group_id: Optional[int] = None) -> None:
detail = (
"Group not found"
if group_id is None
else f"User with ID '{group_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
class PermNotFoundException(HTTPException):
def __init__(self, perm_id: Optional[int] = None) -> None:
detail = (
"Permission not found"
if perm_id is None
else f"User with ID '{perm_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
def __init__(self, perm_id: Optional[int] = None) -> None:
detail = (
"Permission not found"
if perm_id is None
else f"User with ID '{perm_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)

View file

@ -21,95 +21,94 @@ Models:
from sqlalchemy import ForeignKey, UniqueConstraint
from sqlalchemy.orm import relationship, mapped_column, Mapped
from src.models import CustomBase
from src.models import CustomBase, IdMixin
class Permission(CustomBase):
__tablename__ = "permission"
class Permission(CustomBase, IdMixin):
__tablename__ = "permission"
id: Mapped[int] = mapped_column(primary_key=True)
resource: Mapped[str]
action: Mapped[str]
resource: Mapped[str]
action: Mapped[str]
service_id: Mapped[int] = mapped_column(ForeignKey("service.id", ondelete="CASCADE"))
service_id: Mapped[int] = mapped_column(ForeignKey("service.id", ondelete="CASCADE"))
__table_args__ = (
UniqueConstraint(
"service_id",
"resource",
"action",
name="uniq_permission_resource_and_action",
),
)
__table_args__ = (
UniqueConstraint(
"service_id",
"resource",
"action",
name="uniq_permission_resource_and_action",
),
)
service_rel = relationship(
"Service",
back_populates="permission_rel",
foreign_keys="Permission.service_id",
)
service_rel = relationship(
"Service",
back_populates="permission_rel",
foreign_keys="Permission.service_id",
)
group_rel = relationship(
"Group", secondary="group_permissions", back_populates="permission_rel"
)
group_rel = relationship(
"Group", secondary="group_permissions", back_populates="permission_rel"
)
org_rel = relationship(
"Organisation", secondary="org_permissions", back_populates="permission_rel"
)
org_rel = relationship(
"Organisation", secondary="org_permissions", back_populates="permission_rel"
)
@property
def service_name(self):
return self.service_rel.name
@property
def service_name(self):
return self.service_rel.name
class Group(CustomBase):
__tablename__ = "group"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str]
class Group(CustomBase, IdMixin):
__tablename__ = "group"
org_id: Mapped[int] = mapped_column(ForeignKey("organisation.id", ondelete="CASCADE"))
name: Mapped[str]
__table_args__ = (
UniqueConstraint(
"name",
"org_id",
name="uniq_group_name_org_id",
),
)
org_id: Mapped[int] = mapped_column(ForeignKey("organisation.id", ondelete="CASCADE"))
user_rel = relationship("User", secondary="user_groups", back_populates="group_rel")
__table_args__ = (
UniqueConstraint(
"name",
"org_id",
name="uniq_group_name_org_id",
),
)
org_rel = relationship("Organisation", back_populates="group_rel")
user_rel = relationship("User", secondary="user_groups", back_populates="group_rel")
permission_rel = relationship(
"Permission", secondary="group_permissions", back_populates="group_rel"
)
org_rel = relationship("Organisation", back_populates="group_rel")
permission_rel = relationship(
"Permission", secondary="group_permissions", back_populates="group_rel"
)
class GroupPermissions(CustomBase):
__tablename__ = "group_permissions"
group_id: Mapped[int] = mapped_column(
ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
)
permission_id: Mapped[int] = mapped_column(
ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
)
__tablename__ = "group_permissions"
group_id: Mapped[int] = mapped_column(
ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
)
permission_id: Mapped[int] = mapped_column(
ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
)
class UserGroups(CustomBase):
__tablename__ = "user_groups"
user_id: Mapped[int] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
)
group_id: Mapped[int] = mapped_column(
ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
)
__tablename__ = "user_groups"
user_id: Mapped[int] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
)
group_id: Mapped[int] = mapped_column(
ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
)
class OrgPermissions(CustomBase):
__tablename__ = "org_permissions"
org_id: Mapped[int] = mapped_column(
ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True
)
permission_id: Mapped[int] = mapped_column(
ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
)
__tablename__ = "org_permissions"
org_id: Mapped[int] = mapped_column(
ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True
)
permission_id: Mapped[int] = mapped_column(
ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
)

File diff suppressed because it is too large Load diff

View file

@ -11,151 +11,151 @@ from typing import Optional, Annotated
from pydantic import EmailStr, ConfigDict, Field
from src.schemas import (
CustomBaseModel,
ResourceName,
ServiceIDMixin,
OrgIDMixin,
UserIDMixin,
PermIDMixin,
GroupIDMixin,
GroupSummary,
OrgSummary,
UserSummary,
CustomBaseModel,
ResourceName,
ServiceIDMixin,
OrgIDMixin,
UserIDMixin,
PermIDMixin,
GroupIDMixin,
GroupSummary,
OrgSummary,
UserSummary,
)
class UserSchema(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore")
model_config = ConfigDict(from_attributes=True, extra="ignore")
id: int
first_name: str
last_name: str
email: EmailStr
id: int
first_name: str
last_name: str
email: EmailStr
class PermissionSchema(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore")
model_config = ConfigDict(from_attributes=True, extra="ignore")
id: int
service_name: str
resource: str
action: str
id: int
service_name: str
resource: str
action: str
class GroupDetails(CustomBaseModel):
details: GroupSummary
permissions: list[PermissionSchema]
details: GroupSummary
permissions: list[PermissionSchema]
class IAMCAoRRequest(CustomBaseModel):
action: str
rn: ResourceName
action: str
rn: ResourceName
class IAMCAoRResponse(CustomBaseModel):
allowed: bool
user: UserSummary
action: str
rn: ResourceName
allowed: bool
user: UserSummary
action: str
rn: ResourceName
class IAMGetGroupPermissionsResponse(CustomBaseModel):
organisation: OrgSummary
group: GroupSummary
permissions: list[PermissionSchema]
organisation: OrgSummary
group: GroupSummary
permissions: list[PermissionSchema]
class IAMGetGroupUsersResponse(CustomBaseModel):
organisation: OrgSummary
group: GroupSummary
users: list[UserSummary]
organisation: OrgSummary
group: GroupSummary
users: list[UserSummary]
class IAMPostGroupRequest(OrgIDMixin):
name: str = Field(min_length=3)
name: str = Field(min_length=3)
class IAMPostGroupResponse(CustomBaseModel):
organisation: OrgSummary
group: GroupSummary
organisation: OrgSummary
group: GroupSummary
class IAMPutGroupPermissionRequest(GroupIDMixin, PermIDMixin, OrgIDMixin):
pass
pass
class IAMPutGroupPermissionResponse(CustomBaseModel):
organisation: OrgSummary
group: GroupSummary
permissions: list[PermissionSchema]
organisation: OrgSummary
group: GroupSummary
permissions: list[PermissionSchema]
class IAMPutGroupUserRequest(GroupIDMixin, UserIDMixin, OrgIDMixin):
pass
pass
class IAMPutGroupUserResponse(CustomBaseModel):
group: GroupSummary
users: list[UserSchema]
group: GroupSummary
users: list[UserSchema]
class IAMDeleteGroupPermissionResponse(CustomBaseModel):
group: GroupSummary
permissions: list[PermissionSchema]
group: GroupSummary
permissions: list[PermissionSchema]
class IAMDeleteGroupUserResponse(CustomBaseModel):
group: GroupSummary
users: list[UserSchema]
group: GroupSummary
users: list[UserSchema]
class IAMGetPermissionsResponse(CustomBaseModel):
permissions: list[PermissionSchema]
permissions: list[PermissionSchema]
class IAMPostPermissionRequest(ServiceIDMixin):
resource: str
action: str
resource: str
action: str
class IAMPostPermissionResponse(CustomBaseModel):
permission: PermissionSchema
permission: PermissionSchema
class IAMGetPermissionsSearchRequest(OrgIDMixin):
service_id: Annotated[int | None, Field(gt=0)] = None
resource: Optional[str] = None
action: Optional[str] = None
service_id: Annotated[int | None, Field(gt=0)] = None
resource: Optional[str] = None
action: Optional[str] = None
class IAMGetPermissionsSearchResponse(CustomBaseModel):
permissions: list[PermissionSchema]
permissions: list[PermissionSchema]
class IAMPutGroupInvitationRequest(OrgIDMixin, GroupIDMixin):
user_email: EmailStr
user_email: EmailStr
class IAMPutGroupInvitationResponse(CustomBaseModel):
organisation: OrgSummary
group: GroupSummary
invited_email: EmailStr
organisation: OrgSummary
group: GroupSummary
invited_email: EmailStr
class IAMPutGroupInvitationAcceptRequest(CustomBaseModel):
jwt: str
jwt: str
class IAMPutGroupInvitationAcceptResponse(CustomBaseModel):
organisation: OrgSummary
user: UserSummary
group: GroupDetails
organisation: OrgSummary
user: UserSummary
group: GroupDetails
class IAMPutOrgPermissionsRequest(OrgIDMixin):
permissions: list[int]
permissions: list[int]
class IAMPutOrgPermissionsResponse(CustomBaseModel):
organisation: OrgSummary
permissions: list[PermissionSchema]
organisation: OrgSummary
permissions: list[PermissionSchema]

View file

@ -10,7 +10,7 @@ from datetime import datetime, timedelta, timezone
from fastapi import Request, Depends
from sqlalchemy.orm import Session
from src.database import db_dependency
from src.database import DbSession
from src.exceptions import UnauthorizedException
from src.utils import send_email, generate_jwt
from src.iam.models import Group
@ -23,90 +23,90 @@ from src.service.schemas import HasServiceName
def valid_service_key(
db: db_dependency, request: Request, request_model: HasServiceName
db: DbSession, request: Request, request_model: HasServiceName
) -> bool:
rn = request_model.rn
api_key = request.headers.get("X-API-Key", None)
if not api_key:
raise UnauthorizedException("Missing API key")
service = rn.service
result = (
db.query(Service)
.filter(Service.name == service)
.filter(Service.api_key == api_key)
.first()
)
if result is None:
raise UnauthorizedException("Invalid API key")
rn = request_model.rn
api_key = request.headers.get("X-API-Key", None)
if not api_key:
raise UnauthorizedException("Missing API key")
service = rn.service
result = (
db.query(Service)
.filter(Service.name == service)
.filter(Service.api_key == api_key)
.first()
)
if result is None:
raise UnauthorizedException("Invalid API key")
return True
return True
service_key_dependency = Annotated[bool, Depends(valid_service_key)]
async def send_user_group_invitation(
user_email: str, org_name: str, org_id: int, group_id: int, group_name: str
user_email: str, org_name: str, org_id: int, group_id: int, group_name: str
):
expiry_delta = timedelta(hours=24)
expiry = datetime.now(timezone.utc) + expiry_delta
claims = {
"email": user_email,
"org_id": org_id,
"group_id": group_id,
"group_name": group_name,
"exp": expiry,
"type": "group_invite",
}
expiry_delta = timedelta(hours=24)
expiry = datetime.now(timezone.utc) + expiry_delta
claims = {
"email": user_email,
"org_id": org_id,
"group_id": group_id,
"group_name": group_name,
"exp": expiry,
"type": "group_invite",
}
token = await generate_jwt(claims)
subject = f"You have been invited to join a group of {org_name}"
body = f"You have been invited to join {group_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
token = await generate_jwt(claims)
subject = f"You have been invited to join a group of {org_name}"
body = f"You have been invited to join {group_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
await send_email(
recipient=user_email,
subject=subject,
body=body,
)
await send_email(
recipient=user_email,
subject=subject,
body=body,
)
async def create_group_and_assign_perms(
db: Session, org_model: Org, group_name: str, perm_list: list[int]
db: Session, org_model: Org, group_name: str, perm_list: list[int]
):
new_group = Group(name=group_name, org_id=org_model.id)
db.add(new_group)
db.flush()
new_group = Group(name=group_name, org_id=org_model.id)
db.add(new_group)
db.flush()
for permission in perm_list:
perm_model = db.get(Perm, permission)
for permission in perm_list:
perm_model = db.get(Perm, permission)
if perm_model is None:
continue
if perm_model is None:
continue
new_group.permission_rel.append(perm_model)
db.flush()
new_group.permission_rel.append(perm_model)
db.flush()
return new_group
return new_group
async def assign_default_group(
db: db_dependency,
org_model: Org,
user_model: User,
group_name: str,
perm_list: list[int],
db: DbSession,
org_model: Org,
user_model: User,
group_name: str,
perm_list: list[int],
):
group_model = (
db.query(Group)
.filter(Group.org_id == org_model.id)
.filter(Group.name == group_name)
.first()
)
group_model = (
db.query(Group)
.filter(Group.org_id == org_model.id)
.filter(Group.name == group_name)
.first()
)
if group_model is None:
group_model = await create_group_and_assign_perms(
db=db, group_name=group_name, org_model=org_model, perm_list=perm_list
)
if group_model is None:
group_model = await create_group_and_assign_perms(
db=db, group_name=group_name, org_model=org_model, perm_list=perm_list
)
user_model.group_rel.append(group_model)
db.flush()
user_model.group_rel.append(group_model)
db.flush()

View file

@ -2,6 +2,7 @@
Application root file: Inits the FastAPI application
"""
import os.path
from contextlib import asynccontextmanager
from typing import AsyncGenerator
@ -19,43 +20,43 @@ from src.auth.service import get_current_user, get_dev_user
@asynccontextmanager
async def lifespan(_application: FastAPI) -> AsyncGenerator:
# Startup
yield
# Shutdown
# Startup
yield
# Shutdown
if settings.ENVIRONMENT.is_deployed:
# Just a precaution, should be False anyway
settings.DISABLE_AUTH = False
# Just a precaution, should be False anyway
settings.DISABLE_AUTH = False
tags_metadata = [
{
"name": "User",
"description": "User related operations, includes getting information about the current user",
},
{
"name": "Organisation",
"description": "Organisation related operations, includes getting lists of users etc associated with orgs",
},
{
"name": "Service",
"description": "Services related operations, includes registering services and reissuing API keys",
},
{
"name": "IAM",
"description": "Operations related to the role based identity and access management system. This includes management of groups, permissions, and related users.",
},
{
"name": "User",
"description": "User related operations, includes getting information about the current user",
},
{
"name": "Organisation",
"description": "Organisation related operations, includes getting lists of users etc associated with orgs",
},
{
"name": "Service",
"description": "Services related operations, includes registering services and reissuing API keys",
},
{
"name": "IAM",
"description": "Operations related to the role based identity and access management system. This includes management of groups, permissions, and related users.",
},
]
app = FastAPI(
swagger_ui_init_oauth={
"clientId": auth_settings.CLIENT_ID,
"usePkceWithAuthorizationCodeGrant": True,
"scopes": "openid profile email",
},
openapi_tags=tags_metadata,
swagger_ui_init_oauth={
"clientId": auth_settings.CLIENT_ID,
"usePkceWithAuthorizationCodeGrant": True,
"scopes": "openid profile email",
},
openapi_tags=tags_metadata,
)
# Type inspection disabled for middleware injection.
@ -64,16 +65,19 @@ app = FastAPI(
app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY.get_secret_value())
# noinspection PyTypeChecker
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_origin_regex=settings.CORS_ORIGINS_REGEX,
allow_credentials=True,
allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"),
allow_headers=settings.CORS_HEADERS,
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_origin_regex=settings.CORS_ORIGINS_REGEX,
allow_credentials=True,
allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"),
allow_headers=settings.CORS_HEADERS,
)
if settings.DISABLE_AUTH and (settings.ENVIRONMENT == Environment.LOCAL):
app.dependency_overrides[get_current_user] = get_dev_user
app.dependency_overrides[get_current_user] = get_dev_user
app.include_router(api_router)
if os.path.exists("/app/static"):
app.frontend("/ui", directory="/app/static", fallback="index.html")

View file

@ -5,12 +5,33 @@ Global database models
from datetime import datetime
from typing import Any
from sqlalchemy import DateTime, JSON
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy import DateTime, JSON, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class CustomBase(DeclarativeBase):
type_annotation_map = {
datetime: DateTime(timezone=True),
dict[str, Any]: JSON,
}
type_annotation_map = {
datetime: DateTime(timezone=True),
dict[str, Any]: JSON,
}
class ActivatedMixin:
active: Mapped[bool] = mapped_column(default=True)
class DeletedTimestampMixin:
deleted_at: Mapped[datetime | None] = mapped_column(nullable=True)
class DescriptionMixin:
description: Mapped[str]
class IdMixin:
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
class TimestampMixin:
created_at: Mapped[datetime] = mapped_column(default=func.now())
updated_at: Mapped[datetime] = mapped_column(default=func.now(), onupdate=func.now())

View file

@ -10,48 +10,48 @@ from enum import StrEnum, auto
class Status(StrEnum):
"""
Enumeration of organisation statuses.
"""
Enumeration of organisation statuses.
Attributes:
PARTIAL(str): Organisation has been created but questionnaire hasn't been submitted.
SUBMITTED (str): Questionnaire submitted but not approved.
REMEDIATION (str): Questionnaire submitted but requires revisions.
APPROVED (str): Questionnaire has been approved by an admin.
REJECTED (str): Questionnaire has been rejected by an admin.
REMOVED (str): Organisation has been removed.
"""
Attributes:
PARTIAL(str): Organisation has been created but questionnaire hasn't been submitted.
SUBMITTED (str): Questionnaire submitted but not approved.
REMEDIATION (str): Questionnaire submitted but requires revisions.
APPROVED (str): Questionnaire has been approved by an admin.
REJECTED (str): Questionnaire has been rejected by an admin.
REMOVED (str): Organisation has been removed.
"""
PARTIAL = auto()
SUBMITTED = auto()
REMEDIATION = auto()
APPROVED = auto()
REJECTED = auto()
REMOVED = auto()
PARTIAL = auto()
SUBMITTED = auto()
REMEDIATION = auto()
APPROVED = auto()
REJECTED = auto()
REMOVED = auto()
@property
def is_pre_approval(self):
return self in (self.PARTIAL, self.SUBMITTED, self.REMEDIATION)
@property
def is_pre_approval(self):
return self in (self.PARTIAL, self.SUBMITTED, self.REMEDIATION)
@property
def is_pre_submission(self):
return self in (self.PARTIAL, self.REMEDIATION)
@property
def is_pre_submission(self):
return self in (self.PARTIAL, self.REMEDIATION)
@property
def is_blocked(self):
return self in (self.REMOVED, self.REJECTED)
@property
def is_blocked(self):
return self in (self.REMOVED, self.REJECTED)
class ContactType(StrEnum):
"""
Enumeration of organisation contact types.
"""
Enumeration of organisation contact types.
Attributes:
BILLING(str): Billing contact.
SECURITY (str): Security contact.
OWNER (str): Owner contact.
"""
Attributes:
BILLING(str): Billing contact.
SECURITY (str): Security contact.
OWNER (str): Owner contact.
"""
BILLING = auto()
SECURITY = auto()
OWNER = auto()
BILLING = auto()
SECURITY = auto()
OWNER = auto()

View file

@ -10,33 +10,33 @@ from typing import Annotated, Optional
from fastapi import Depends, Query
from src.database import db_dependency
from src.database import DbSession
from src.organisation.schemas import OrgIDMixin
from src.organisation.models import Organisation as Org
from src.organisation.exceptions import OrgNotFoundException
def get_org_model_query(db: db_dependency, org_id: Annotated[int, Query(gt=0)]) -> Org:
org_model = db.get(Org, org_id)
if org_model is None:
raise OrgNotFoundException(org_id)
return org_model
def get_org_model_query(db: DbSession, org_id: Annotated[int, Query(gt=0)]) -> Org:
org_model = db.get(Org, org_id)
if org_model is None:
raise OrgNotFoundException(org_id)
return org_model
org_model_query_dependency = Annotated[Org, Depends(get_org_model_query)]
def get_org_model_body(db: db_dependency, request_model: OrgIDMixin) -> Org:
org_id: Optional[int] = getattr(request_model, "organisation_id", None)
if org_id is None:
raise OrgNotFoundException()
def get_org_model_body(db: DbSession, request_model: OrgIDMixin) -> Org:
org_id: Optional[int] = getattr(request_model, "organisation_id", None)
if org_id is None:
raise OrgNotFoundException()
org_model = db.get(Org, org_id)
if org_model is None:
raise OrgNotFoundException(org_id)
org_model = db.get(Org, org_id)
if org_model is None:
raise OrgNotFoundException(org_id)
return org_model
return org_model
org_model_body_dependency = Annotated[Org, Depends(get_org_model_body)]

View file

@ -12,26 +12,26 @@ from fastapi import HTTPException, status
class OrgNotFoundException(HTTPException):
def __init__(self, org_id: Optional[int] = None) -> None:
detail = (
"Organisation not found"
if org_id is None
else f"Organisation with ID '{org_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
def __init__(self, org_id: Optional[int] = None) -> None:
detail = (
"Organisation not found"
if org_id is None
else f"Organisation with ID '{org_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
class AwaitingApprovalException(HTTPException):
def __init__(self, org_id: Optional[int] = None) -> None:
detail = (
"Organisation has not been approved."
if org_id is None
else f"Organisation with ID '{org_id}' has not been approved."
)
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,
)
def __init__(self, org_id: Optional[int] = None) -> None:
detail = (
"Organisation has not been approved."
if org_id is None
else f"Organisation with ID '{org_id}' has not been approved."
)
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,
)

View file

@ -14,61 +14,62 @@ Models:
- OrgUsers: org_id[FK][PK], user_id[FK][PK]
"""
from src.models import IdMixin, DeletedTimestampMixin
from typing import Any
from sqlalchemy import ForeignKey
from sqlalchemy.orm import relationship, Mapped, mapped_column
from src.models import CustomBase
from src.models import CustomBase, TimestampMixin
class Organisation(CustomBase):
__tablename__ = "organisation"
class Organisation(CustomBase, IdMixin, TimestampMixin, DeletedTimestampMixin):
__tablename__ = "organisation"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str]
status: Mapped[str] = mapped_column(default="partial")
intake_questionnaire: Mapped[dict[str, Any] | None]
name: Mapped[str]
status: Mapped[str] = mapped_column(default="partial")
intake_questionnaire: Mapped[dict[str, Any] | None]
root_user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
billing_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True)
security_contact_id: Mapped[int] = mapped_column(
ForeignKey("contact.id"), nullable=True
)
owner_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True)
root_user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
billing_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True)
security_contact_id: Mapped[int] = mapped_column(
ForeignKey("contact.id"), nullable=True
)
owner_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True)
user_rel = relationship("User", secondary="orgusers", back_populates="organisation_rel")
user_rel = relationship("User", secondary="orgusers", back_populates="organisation_rel")
group_rel = relationship(
"Group", back_populates="org_rel", cascade="all, delete-orphan"
)
root_user_rel = relationship("User", foreign_keys="Organisation.root_user_id")
group_rel = relationship(
"Group", back_populates="org_rel", cascade="all, delete-orphan"
)
root_user_rel = relationship("User", foreign_keys="Organisation.root_user_id")
billing_contact_rel = relationship(
"Contact", foreign_keys="Organisation.billing_contact_id"
)
security_contact_rel = relationship(
"Contact", foreign_keys="Organisation.security_contact_id"
)
owner_contact_rel = relationship(
"Contact", foreign_keys="Organisation.owner_contact_id"
)
billing_contact_rel = relationship(
"Contact", foreign_keys="Organisation.billing_contact_id"
)
security_contact_rel = relationship(
"Contact", foreign_keys="Organisation.security_contact_id"
)
owner_contact_rel = relationship(
"Contact", foreign_keys="Organisation.owner_contact_id"
)
permission_rel = relationship(
"Permission", secondary="org_permissions", back_populates="org_rel"
)
permission_rel = relationship(
"Permission", secondary="org_permissions", back_populates="org_rel"
)
@property
def root_user_email(self) -> str:
return self.root_user_rel.email if self.root_user_rel else ""
@property
def root_user_email(self) -> str:
return self.root_user_rel.email if self.root_user_rel else ""
class OrgUsers(CustomBase):
__tablename__ = "orgusers"
__tablename__ = "orgusers"
org_id: Mapped[int] = mapped_column(
ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True
)
user_id: Mapped[int] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
)
org_id: Mapped[int] = mapped_column(
ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True
)
user_id: Mapped[int] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
)

File diff suppressed because it is too large Load diff

View file

@ -12,139 +12,139 @@ from datetime import datetime
from pydantic import EmailStr, ConfigDict, Field
from src.schemas import (
CustomBaseModel,
OrgIDMixin,
UserIDMixin,
GroupSummary,
OrgSummary,
UserSummary,
CustomBaseModel,
OrgIDMixin,
UserIDMixin,
GroupSummary,
OrgSummary,
UserSummary,
)
from src.contact.schemas import ContactModel
from src.organisation.constants import Status, ContactType
from src.organisation.schemas_questionnaires import (
QuestionnaireQuestionsVersion0 as CurrentQuestions,
questionnaire_union,
QuestionnaireQuestionsVersion0 as CurrentQuestions,
questionnaire_union,
)
class QuestionnaireMetadata(CustomBaseModel):
version: int
submission_date: Optional[datetime] = None
version: int
submission_date: Optional[datetime] = None
class Questionnaire(CustomBaseModel):
metadata: QuestionnaireMetadata
questions: questionnaire_union
metadata: QuestionnaireMetadata
questions: questionnaire_union
class ContactSummary(CustomBaseModel):
id: int
email: Optional[EmailStr] = None
id: int
email: Optional[EmailStr] = None
class OrgSchema(OrgIDMixin):
name: str
status: Status
root_user_email: EmailStr
intake_questionnaire: Optional[Questionnaire] = None
name: str
status: Status
root_user_email: EmailStr
intake_questionnaire: Optional[Questionnaire] = None
billing_contact: ContactSummary
owner_contact: ContactSummary
security_contact: ContactSummary
billing_contact: ContactSummary
owner_contact: ContactSummary
security_contact: ContactSummary
class OrgPostOrgRequest(CustomBaseModel):
name: str = Field(min_length=3)
intake_questionnaire: Optional[CurrentQuestions] = None
name: str = Field(min_length=3)
intake_questionnaire: Optional[CurrentQuestions] = None
class OrgPostOrgResponse(CustomBaseModel):
id: int
name: str
status: Status
id: int
name: str
status: Status
class OrgPatchQuestionnaireRequest(OrgIDMixin):
intake_questionnaire: CurrentQuestions
partial: bool
intake_questionnaire: CurrentQuestions
partial: bool
class OrgPatchQuestionnaireResponse(CustomBaseModel):
id: int
name: str
intake_questionnaire: Questionnaire
status: Status
id: int
name: str
intake_questionnaire: Questionnaire
status: Status
class OrgPatchStatusRequest(OrgIDMixin):
status: Status
status: Status
class OrgPatchStatusResponse(CustomBaseModel):
id: int
name: str
status: Status
id: int
name: str
status: Status
class OrgPatchContactRequest(OrgIDMixin):
contact_type: ContactType
contact_type: ContactType
email: Optional[EmailStr] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
phonenumber: Optional[str] = None
vat_number: Optional[str] = None
post_office_box_number: Optional[str] = None
street_address: Optional[str] = None
street_address_line_2: Optional[str] = None
locality: Optional[str] = None
address_region: Optional[str] = None
country_code: Optional[str] = None
postal_code: Optional[str] = None
email: Optional[EmailStr] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
phonenumber: Optional[str] = None
vat_number: Optional[str] = None
post_office_box_number: Optional[str] = None
street_address: Optional[str] = None
street_address_line_2: Optional[str] = None
locality: Optional[str] = None
address_region: Optional[str] = None
country_code: Optional[str] = None
postal_code: Optional[str] = None
class OrgPostUserRequest(OrgIDMixin, UserIDMixin):
pass
pass
class OrgPostUserResponse(CustomBaseModel):
organisation: OrgSummary
users: list[UserSummary]
organisation: OrgSummary
users: list[UserSummary]
class OrgPatchRootRequest(OrgIDMixin, UserIDMixin):
pass
pass
class OrgPatchRootResponse(CustomBaseModel):
name: str
root_user_email: str
name: str
root_user_email: str
class OrgGetUserResponse(CustomBaseModel):
users: list[dict[str, str | int]]
organisation: OrgSummary
users: list[dict[str, str | int]]
organisation: OrgSummary
class OrgGetGroupResponse(CustomBaseModel):
organisation: OrgSummary
groups: list[GroupSummary]
organisation: OrgSummary
groups: list[GroupSummary]
class OrgGetContactResponse(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore")
model_config = ConfigDict(from_attributes=True, extra="ignore")
contact: ContactModel
organisation: OrgSummary
contact: ContactModel
organisation: OrgSummary
class OrgPatchContactResponse(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore")
model_config = ConfigDict(from_attributes=True, extra="ignore")
contact: ContactModel
organisation: OrgSummary
contact: ContactModel
organisation: OrgSummary
class OrgGetOrgResponse(CustomBaseModel):
organisations: list[OrgSchema]
organisations: list[OrgSchema]

View file

@ -4,13 +4,13 @@ from src.schemas import CustomBaseModel
class QuestionnaireQuestions(CustomBaseModel):
pass
pass
class QuestionnaireQuestionsVersion0(QuestionnaireQuestions):
question_one: Optional[str] = None
question_two: Optional[str] = None
question_three: Optional[str] = None
question_one: Optional[str] = None
question_two: Optional[str] = None
question_three: Optional[str] = None
questionnaire_union = QuestionnaireQuestionsVersion0 # | QuestionnaireQuestionsVersion1

View file

@ -11,57 +11,57 @@ from src.user.models import User
async def add_default_org_permissions(
db: Session,
org_model: Org,
perm_list: list[int],
db: Session,
org_model: Org,
perm_list: list[int],
):
for permission in perm_list:
perm_model = db.get(Perm, permission)
for permission in perm_list:
perm_model = db.get(Perm, permission)
if perm_model is None:
continue
if perm_model is None:
continue
if perm_model in org_model.permission_rel:
continue
if perm_model in org_model.permission_rel:
continue
org_model.permission_rel.append(perm_model)
db.flush()
org_model.permission_rel.append(perm_model)
db.flush()
db.commit()
db.commit()
async def assign_defaults(
db: Session,
org_id: int,
user_id: int,
db: Session,
org_id: int,
user_id: int,
):
default_org_permissions = []
default_org_permissions = []
default_user_permissions = []
default_user_permissions = []
org_model = db.get(Org, org_id)
if org_model is None:
print("Org not found while adding defaults")
return
org_model = db.get(Org, org_id)
if org_model is None:
print("Org not found while adding defaults")
return
user_model = db.get(User, user_id)
if user_model is None:
print("User not found while adding defaults")
return
user_model = db.get(User, user_id)
if user_model is None:
print("User not found while adding defaults")
return
await add_default_org_permissions(db, org_model, default_org_permissions)
await assign_default_group(
db=db,
org_model=org_model,
user_model=user_model,
group_name="Default Users",
perm_list=default_user_permissions,
)
await assign_default_group(
db=db,
org_model=org_model,
user_model=user_model,
group_name="Root User",
perm_list=default_org_permissions,
)
db.commit()
await add_default_org_permissions(db, org_model, default_org_permissions)
await assign_default_group(
db=db,
org_model=org_model,
user_model=user_model,
group_name="Default Users",
perm_list=default_user_permissions,
)
await assign_default_group(
db=db,
org_model=org_model,
user_model=user_model,
group_name="Root User",
perm_list=default_org_permissions,
)
db.commit()

View file

@ -11,54 +11,54 @@ from typing import Optional
class CustomBaseModel(BaseModel):
pass
pass
### Mixins ###
class OrgIDMixin(CustomBaseModel):
organisation_id: int = Field(gt=0)
organisation_id: int = Field(gt=0)
class GroupIDMixin(CustomBaseModel):
group_id: int = Field(gt=0)
group_id: int = Field(gt=0)
class PermIDMixin(CustomBaseModel):
permission_id: int = Field(gt=0)
permission_id: int = Field(gt=0)
class ServiceIDMixin(CustomBaseModel):
service_id: int = Field(gt=0)
service_id: int = Field(gt=0)
class UserIDMixin(CustomBaseModel):
user_id: int = Field(gt=0)
user_id: int = Field(gt=0)
class ServiceNameMixin(CustomBaseModel):
service: str
service: str
class OrgSummary(CustomBaseModel):
id: int
name: str
id: int
name: str
class GroupSummary(CustomBaseModel):
id: int
name: str
id: int
name: str
class UserSummary(CustomBaseModel):
id: int
email: str
id: int
email: str
class ServiceSummary(CustomBaseModel):
id: int
name: str
id: int
name: str
class ResourceName(ServiceNameMixin, OrgIDMixin):
resource: str
instance: Optional[str] = None
resource: str
instance: Optional[str] = None

View file

@ -9,32 +9,30 @@ Exports:
from typing import Annotated
from fastapi import Depends, Query
from src.database import db_dependency
from src.database import DbSession
from src.service.exceptions import ServiceNotFoundException
from src.service.models import Service
from src.service.schemas import ServiceIDMixin
async def get_service_model_query(
db: db_dependency, service_id: Annotated[int, Query(gt=0)]
):
service_model = db.get(Service, service_id)
if service_model is None:
raise ServiceNotFoundException(service_id=service_id)
async def get_service_model_query(db: DbSession, service_id: Annotated[int, Query(gt=0)]):
service_model = db.get(Service, service_id)
if service_model is None:
raise ServiceNotFoundException(service_id=service_id)
return service_model
return service_model
service_model_query_dependency = Annotated[Service, Depends(get_service_model_query)]
async def get_service_model_body(db: db_dependency, request_model: ServiceIDMixin):
service_model = db.get(Service, request_model.service_id)
if service_model is None:
raise ServiceNotFoundException(service_id=request_model.service_id)
async def get_service_model_body(db: DbSession, request_model: ServiceIDMixin):
service_model = db.get(Service, request_model.service_id)
if service_model is None:
raise ServiceNotFoundException(service_id=request_model.service_id)
return service_model
return service_model
service_model_body_dependency = Annotated[Service, Depends(get_service_model_body)]

View file

@ -11,13 +11,13 @@ from fastapi import HTTPException, status
class ServiceNotFoundException(HTTPException):
def __init__(self, service_id: Optional[int] = None) -> None:
detail = (
"Service not found"
if service_id is None
else f"Service with ID '{service_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
def __init__(self, service_id: Optional[int] = None) -> None:
detail = (
"Service not found"
if service_id is None
else f"Service with ID '{service_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)

View file

@ -8,16 +8,15 @@ Models:
from sqlalchemy.orm import relationship, mapped_column, Mapped
from src.models import CustomBase
from src.models import CustomBase, IdMixin
class Service(CustomBase):
__tablename__ = "service"
class Service(CustomBase, IdMixin):
__tablename__ = "service"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(unique=True)
api_key: Mapped[str]
name: Mapped[str] = mapped_column(unique=True)
api_key: Mapped[str]
permission_rel = relationship(
"Permission", back_populates="service_rel", cascade="all, delete-orphan"
)
permission_rel = relationship(
"Permission", back_populates="service_rel", cascade="all, delete-orphan"
)

View file

@ -13,10 +13,10 @@ from sqlalchemy.exc import IntegrityError
from psycopg.errors import UniqueViolation
from src.exceptions import ConflictException
from src.database import db_dependency
from src.database import DbSession
from src.auth.dependencies import (
super_admin_dependency,
org_model_root_claim_query_dependency,
super_admin_dependency,
org_model_root_claim_query_dependency,
)
from src.iam.service import service_key_dependency
from src.iam.models import Permission as Perm
@ -25,212 +25,210 @@ from src.service.exceptions import ServiceNotFoundException
from src.service.models import Service
from src.service.utils import generate_api_key
from src.service.dependencies import (
service_model_body_dependency,
service_model_query_dependency,
service_model_body_dependency,
service_model_query_dependency,
)
from src.service.schemas import (
ServiceGetServiceResponse,
ServicePostServiceRequest,
ServicePostServiceResponse,
ServiceWithKeySchema,
ServicePatchKeyResponse,
ServicePatchKeyRequest,
ServicePostPermissionsResponse,
ServicePostPermissionsRequest,
ServiceGetServiceResponse,
ServicePostServiceRequest,
ServicePostServiceResponse,
ServiceWithKeySchema,
ServicePatchKeyResponse,
ServicePatchKeyRequest,
ServicePostPermissionsResponse,
ServicePostPermissionsRequest,
)
router = APIRouter(
tags=["Service"],
prefix="/service",
tags=["Service"],
prefix="/service",
)
@router.get(
"",
summary="Get all services",
status_code=status.HTTP_200_OK,
response_model=ServiceGetServiceResponse,
responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
status.HTTP_401_UNAUTHORIZED: {
"description": "Unauthorized",
"content": {
"application/json": {
"examples": {
"awaiting_approval": {
"summary": "Organisation has not yet been approved."
},
}
}
},
},
status.HTTP_403_FORBIDDEN: {
"description": "Forbidden",
"content": {
"application/json": {
"examples": {
"not_root": {"summary": "Not authorised. Must be root user."},
}
}
},
},
},
"",
summary="Get all services",
status_code=status.HTTP_200_OK,
response_model=ServiceGetServiceResponse,
responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
status.HTTP_401_UNAUTHORIZED: {
"description": "Unauthorized",
"content": {
"application/json": {
"examples": {
"awaiting_approval": {
"summary": "Organisation has not yet been approved."
},
}
}
},
},
status.HTTP_403_FORBIDDEN: {
"description": "Forbidden",
"content": {
"application/json": {
"examples": {
"not_root": {"summary": "Not authorised. Must be root user."},
}
}
},
},
},
)
async def get_all_services(
db: db_dependency, org_model: org_model_root_claim_query_dependency
):
"""
Returns the ID and name of all services registered to the hub.
"""
permission_models = db.query(Service).all()
async def get_all_services(db: DbSession, org_model: org_model_root_claim_query_dependency):
"""
Returns the ID and name of all services registered to the hub.
"""
permission_models = db.query(Service).all()
return {"services": permission_models}
return {"services": permission_models}
@router.post(
"",
summary="Register a new service.",
status_code=status.HTTP_200_OK,
response_model=ServicePostServiceResponse,
responses={
status.HTTP_200_OK: {"description": "Successfully registered a new service"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
status.HTTP_409_CONFLICT: {"description": "Service with this name already exists"},
},
"",
summary="Register a new service.",
status_code=status.HTTP_200_OK,
response_model=ServicePostServiceResponse,
responses={
status.HTTP_200_OK: {"description": "Successfully registered a new service"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
status.HTTP_409_CONFLICT: {"description": "Service with this name already exists"},
},
)
async def register_service(
db: db_dependency,
su: super_admin_dependency,
request_model: ServicePostServiceRequest,
db: DbSession,
su: super_admin_dependency,
request_model: ServicePostServiceRequest,
):
"""
Registers a new service to the hub, generating and returning an API key for it.
"""
key = generate_api_key()
service_model = Service(name=request_model.name, api_key=key)
"""
Registers a new service to the hub, generating and returning an API key for it.
"""
key = generate_api_key()
service_model = Service(name=request_model.name, api_key=key)
db.add(service_model)
try:
db.flush()
except IntegrityError as e:
if (
isinstance(e.orig, UniqueViolation) # Postgres unique violation
or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation
):
raise ConflictException(message="Service with this name already exists")
raise
response = ServiceWithKeySchema(**service_model.__dict__)
db.commit()
return {"service": response}
db.add(service_model)
try:
db.flush()
except IntegrityError as e:
if (
isinstance(e.orig, UniqueViolation) # Postgres unique violation
or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation
):
raise ConflictException(message="Service with this name already exists")
raise
response = ServiceWithKeySchema(**service_model.__dict__)
db.commit()
return {"service": response}
@router.patch(
"/key",
summary="Regenerate service API key.",
status_code=status.HTTP_200_OK,
response_model=ServicePatchKeyResponse,
responses={
status.HTTP_200_OK: {"description": "Successful update of API key"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
},
"/key",
summary="Regenerate service API key.",
status_code=status.HTTP_200_OK,
response_model=ServicePatchKeyResponse,
responses={
status.HTTP_200_OK: {"description": "Successful update of API key"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
},
)
async def regenerate_api_key(
db: db_dependency,
su: super_admin_dependency,
service_model: service_model_body_dependency,
request_model: ServicePatchKeyRequest,
db: DbSession,
su: super_admin_dependency,
service_model: service_model_body_dependency,
request_model: ServicePatchKeyRequest,
):
"""
Generates and returns a new API key for the service to access the hub.
"""
key = generate_api_key()
service_model.api_key = key
"""
Generates and returns a new API key for the service to access the hub.
"""
key = generate_api_key()
service_model.api_key = key
db.flush()
response = ServiceWithKeySchema(**service_model.__dict__)
db.commit()
return {"service": response}
db.flush()
response = ServiceWithKeySchema(**service_model.__dict__)
db.commit()
return {"service": response}
@router.delete(
"",
summary="Remove a service.",
status_code=status.HTTP_204_NO_CONTENT,
responses={
status.HTTP_204_NO_CONTENT: {"description": "Successfully removed service from db"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
},
"",
summary="Remove a service.",
status_code=status.HTTP_204_NO_CONTENT,
responses={
status.HTTP_204_NO_CONTENT: {"description": "Successfully removed service from db"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
},
)
async def remove_service(
db: db_dependency,
service_model: service_model_query_dependency,
su: super_admin_dependency,
db: DbSession,
service_model: service_model_query_dependency,
su: super_admin_dependency,
):
"""
Removes a service from the hub.
"""
db.delete(service_model)
db.commit()
"""
Removes a service from the hub.
"""
db.delete(service_model)
db.commit()
@router.post(
path="/permissions",
summary="Service endpoint for creating its own permissions.",
status_code=status.HTTP_200_OK,
response_model=ServicePostPermissionsResponse,
responses={
status.HTTP_401_UNAUTHORIZED: {
"description": "API Key missing or invalid | Issue verifying user OIDC claims"
},
},
path="/permissions",
summary="Service endpoint for creating its own permissions.",
status_code=status.HTTP_200_OK,
response_model=ServicePostPermissionsResponse,
responses={
status.HTTP_401_UNAUTHORIZED: {
"description": "API Key missing or invalid | Issue verifying user OIDC claims"
},
},
)
async def service_create_new_permissions(
db: db_dependency,
request_model: ServicePostPermissionsRequest,
valid_key: service_key_dependency,
db: DbSession,
request_model: ServicePostPermissionsRequest,
valid_key: service_key_dependency,
):
"""
Allows a service to register its own set of permissions.
"""
service_model = (
db.query(Service).filter(Service.name == request_model.rn.service).first()
)
if service_model is None:
raise ServiceNotFoundException()
else:
service_id = service_model.id
response_list = []
for new_permission in request_model.permissions:
perm_model = (
db.query(Perm)
.filter(Perm.service_id == service_id)
.filter(Perm.resource == new_permission.resource)
.filter(Perm.action == new_permission.action)
.first()
)
if perm_model is not None:
response_code = 409
response = {
"id": perm_model.id,
"service_name": perm_model.service_name,
"resource": perm_model.resource,
"action": perm_model.action,
}
response_list.append((response, response_code))
continue
"""
Allows a service to register its own set of permissions.
"""
service_model = (
db.query(Service).filter(Service.name == request_model.rn.service).first()
)
if service_model is None:
raise ServiceNotFoundException()
else:
service_id = service_model.id
response_list = []
for new_permission in request_model.permissions:
perm_model = (
db.query(Perm)
.filter(Perm.service_id == service_id)
.filter(Perm.resource == new_permission.resource)
.filter(Perm.action == new_permission.action)
.first()
)
if perm_model is not None:
response_code = 409
response = {
"id": perm_model.id,
"service_name": perm_model.service_name,
"resource": perm_model.resource,
"action": perm_model.action,
}
response_list.append((response, response_code))
continue
new_perm_model = Perm(**new_permission.__dict__)
new_perm_model.service_id = service_id
db.add(new_perm_model)
db.flush()
response_code = 201
response = {
"id": new_perm_model.id,
"service_name": new_perm_model.service_name,
"resource": new_perm_model.resource,
"action": new_perm_model.action,
}
response_list.append((response, response_code))
new_perm_model = Perm(**new_permission.__dict__)
new_perm_model.service_id = service_id
db.add(new_perm_model)
db.flush()
response_code = 201
response = {
"id": new_perm_model.id,
"service_name": new_perm_model.service_name,
"resource": new_perm_model.resource,
"action": new_perm_model.action,
}
response_list.append((response, response_code))
db.commit()
return {"permissions": response_list}
db.commit()
return {"permissions": response_list}

View file

@ -10,10 +10,10 @@ from typing import Generic, TypeVar
from pydantic import Field, ConfigDict
from src.schemas import (
CustomBaseModel,
ServiceIDMixin,
ServiceSummary,
ServiceNameMixin,
CustomBaseModel,
ServiceIDMixin,
ServiceSummary,
ServiceNameMixin,
)
@ -21,51 +21,51 @@ T = TypeVar("T", bound=ServiceNameMixin)
class HasServiceName(CustomBaseModel, Generic[T]):
rn: T
rn: T
class PermissionResponseSchema(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore")
model_config = ConfigDict(from_attributes=True, extra="ignore")
id: int
service_name: str
resource: str
action: str
id: int
service_name: str
resource: str
action: str
class PermissionRequestSchema(CustomBaseModel):
resource: str
action: str
resource: str
action: str
class ServiceWithKeySchema(ServiceSummary):
api_key: str
api_key: str
class ServiceGetServiceResponse(CustomBaseModel):
services: list[ServiceSummary]
services: list[ServiceSummary]
class ServicePostServiceRequest(CustomBaseModel):
name: str = Field(min_length=3)
name: str = Field(min_length=3)
class ServicePostServiceResponse(CustomBaseModel):
service: ServiceWithKeySchema
service: ServiceWithKeySchema
class ServicePatchKeyRequest(ServiceIDMixin):
pass
pass
class ServicePatchKeyResponse(CustomBaseModel):
service: ServiceWithKeySchema
service: ServiceWithKeySchema
class ServicePostPermissionsRequest(CustomBaseModel):
rn: ServiceNameMixin
permissions: list[PermissionRequestSchema]
rn: ServiceNameMixin
permissions: list[PermissionRequestSchema]
class ServicePostPermissionsResponse(CustomBaseModel):
permissions: list[tuple[PermissionResponseSchema, int]]
permissions: list[tuple[PermissionResponseSchema, int]]

View file

@ -9,4 +9,4 @@ import uuid
def generate_api_key() -> str:
return str(uuid.uuid4())
return str(uuid.uuid4())

View file

@ -10,46 +10,50 @@ Exports:
from typing import Annotated
from fastapi import Depends, Query
from src.auth.service import claims_dependency
from src.database import DbSession
from src.schemas import UserIDMixin
from src.exceptions import ForbiddenException
from src.user.exceptions import UserNotFoundException
from src.user.models import User
from src.auth.service import claims_dependency
from src.database import db_dependency
from src.schemas import UserIDMixin
async def get_user_model_claims(claims: claims_dependency, db: DbSession):
user_id = claims.get("db_id", None)
if user_id is None:
raise UserNotFoundException()
async def get_user_model_claims(claims: claims_dependency, db: db_dependency):
user_id = claims.get("db_id", None)
if user_id is None:
raise UserNotFoundException()
user_model = db.get(User, user_id)
if user_model is None:
raise UserNotFoundException(user_id=user_id)
user_model = db.get(User, user_id)
if user_model is None:
raise UserNotFoundException(user_id=user_id)
if not user_model.active:
raise ForbiddenException("User account is not active")
return user_model
return user_model
user_model_claims_dependency = Annotated[User, Depends(get_user_model_claims)]
async def get_user_model_query(db: db_dependency, user_id: Annotated[int, Query(gt=0)]):
user_model = db.get(User, user_id)
if user_model is None:
raise UserNotFoundException(user_id=user_id)
async def get_user_model_query(db: DbSession, user_id: Annotated[int, Query(gt=0)]):
user_model = db.get(User, user_id)
if user_model is None:
raise UserNotFoundException(user_id=user_id)
return user_model
return user_model
user_model_query_dependency = Annotated[User, Depends(get_user_model_query)]
async def get_user_model_body(db: db_dependency, request_model: UserIDMixin):
user_model = db.get(User, request_model.user_id)
if user_model is None:
raise UserNotFoundException(user_id=request_model.user_id)
async def get_user_model_body(db: DbSession, request_model: UserIDMixin):
user_model = db.get(User, request_model.user_id)
if user_model is None:
raise UserNotFoundException(user_id=request_model.user_id)
return user_model
return user_model
user_model_body_dependency = Annotated[User, Depends(get_user_model_body)]

View file

@ -11,13 +11,13 @@ from fastapi import HTTPException, status
class UserNotFoundException(HTTPException):
def __init__(self, user_id: Optional[int] = None) -> None:
detail = (
"User not found"
if user_id is None
else f"User with ID '{user_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
def __init__(self, user_id: Optional[int] = None) -> None:
detail = (
"User not found"
if user_id is None
else f"User with ID '{user_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)

View file

@ -10,6 +10,8 @@ Models:
- groups: Calc property dict of {group_rel.org_rel.name: group_rel.name}
"""
from src.models import IdMixin, ActivatedMixin, TimestampMixin, DeletedTimestampMixin
from collections import defaultdict
from sqlalchemy.orm import relationship, mapped_column, Mapped
@ -17,28 +19,27 @@ from sqlalchemy.orm import relationship, mapped_column, Mapped
from src.models import CustomBase
class User(CustomBase):
__tablename__ = "user"
class User(CustomBase, IdMixin, ActivatedMixin, TimestampMixin, DeletedTimestampMixin):
__tablename__ = "user"
id: Mapped[int] = mapped_column(primary_key=True)
email: Mapped[str]
first_name: Mapped[str]
last_name: Mapped[str]
oidc_id: Mapped[str] = mapped_column(index=True, unique=True)
email: Mapped[str]
first_name: Mapped[str]
last_name: Mapped[str]
oidc_id: Mapped[str] = mapped_column(index=True, unique=True)
organisation_rel = relationship(
"Organisation", secondary="orgusers", back_populates="user_rel"
)
organisation_rel = relationship(
"Organisation", secondary="orgusers", back_populates="user_rel"
)
group_rel = relationship("Group", secondary="user_groups", back_populates="user_rel")
group_rel = relationship("Group", secondary="user_groups", back_populates="user_rel")
@property
def organisations(self):
return [{"name": org.name, "id": org.id} for org in self.organisation_rel]
@property
def organisations(self):
return [{"name": org.name, "id": org.id} for org in self.organisation_rel]
@property
def groups(self):
result = defaultdict(list)
for group in self.group_rel:
result[group.org_rel.name].append({"name": group.name, "id": group.id})
return dict(result)
@property
def groups(self):
result = defaultdict(list)
for group in self.group_rel:
result[group.org_rel.name].append({"name": group.name, "id": group.id})
return dict(result)

View file

@ -8,210 +8,213 @@ Endpoints:
- [DELETE](/user/): [super admin]: Removes a User(id) from the hub database.
"""
from datetime import datetime, timezone
from fastapi import APIRouter, status, BackgroundTasks
from src.iam.models import Group
from src.organisation.exceptions import OrgNotFoundException
from src.user.schemas import (
UserResponse,
OIDCClaims,
UserPostInvitationRequest,
UserPostInvitationAcceptRequest,
UserGetSelfOrgsResponse,
UserPostInvitationResponse,
UserPostInvitationAcceptResponse,
UserResponse,
OIDCClaims,
UserPostInvitationRequest,
UserPostInvitationAcceptRequest,
UserGetSelfOrgsResponse,
UserPostInvitationResponse,
UserPostInvitationAcceptResponse,
)
from src.user.dependencies import (
user_model_claims_dependency,
user_model_query_dependency,
user_model_claims_dependency,
user_model_query_dependency,
)
from src.user.service import send_invitation
from src.organisation.models import Organisation as Org
from src.auth.dependencies import (
super_admin_dependency,
org_model_root_claim_body_dependency,
super_admin_dependency,
org_model_root_claim_body_dependency,
)
from src.auth.service import claims_dependency
from src.database import db_dependency
from src.database import DbSession
from src.utils import verify_email_token
router = APIRouter(
prefix="/user",
tags=["User"],
prefix="/user",
tags=["User"],
)
@router.get(
"/self/claims",
summary="Get current user OIDC claims.",
response_model=OIDCClaims,
status_code=status.HTTP_200_OK,
responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
},
"/self/claims",
summary="Get current user OIDC claims.",
response_model=OIDCClaims,
status_code=status.HTTP_200_OK,
responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
},
)
async def current_user_claims(user: claims_dependency):
"""
Returns the full OIDC claims associated with the currently logged-in user.
"""
user["allowed_origins"] = user.get("allowed-origins", [])
return user
"""
Returns the full OIDC claims associated with the currently logged-in user.
"""
user["allowed_origins"] = user.get("allowed-origins", [])
return user
@router.get(
"/self/db",
summary="Get current user hub details.",
response_model=UserResponse,
status_code=status.HTTP_200_OK,
responses={
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
},
"/self/db",
summary="Get current user hub details.",
response_model=UserResponse,
status_code=status.HTTP_200_OK,
responses={
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
},
)
async def current_user(user_model: user_model_claims_dependency):
"""
Returns the database details associated with the currently logged-in user.
"""
return user_model
"""
Returns the database details associated with the currently logged-in user.
"""
return user_model
@router.get(
"",
summary="Get user hub details by ID.",
response_model=UserResponse,
status_code=status.HTTP_200_OK,
responses={
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
},
"",
summary="Get user hub details by ID.",
response_model=UserResponse,
status_code=status.HTTP_200_OK,
responses={
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
},
)
async def get_user_by_id(
user_model: user_model_query_dependency, su: super_admin_dependency
user_model: user_model_query_dependency, su: super_admin_dependency
):
"""
Returns the database details associated with the provided user ID.
"""
return user_model
"""
Returns the database details associated with the provided user ID.
"""
return user_model
@router.delete(
"",
summary="Delete user from hub by ID.",
status_code=status.HTTP_204_NO_CONTENT,
responses={
status.HTTP_204_NO_CONTENT: {"description": "User deleted"},
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
},
"",
summary="Delete user from hub by ID.",
status_code=status.HTTP_204_NO_CONTENT,
responses={
status.HTTP_204_NO_CONTENT: {"description": "User deleted"},
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
},
)
async def delete_user_by_id(
db: db_dependency,
user_model: user_model_query_dependency,
su: super_admin_dependency,
async def soft_delete_user_by_id(
db: DbSession,
user_model: user_model_query_dependency,
su: super_admin_dependency,
):
"""
Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login.
"""
db.delete(user_model)
db.commit()
"""
Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login.
"""
user_model.active = False
user_model.deleted_at = datetime.now(tz=timezone.utc)
db.commit()
@router.get(
"/self/orgs",
summary="Get all orgs the current user is a member of",
status_code=status.HTTP_200_OK,
response_model=UserGetSelfOrgsResponse,
responses={},
"/self/orgs",
summary="Get all orgs the current user is a member of",
status_code=status.HTTP_200_OK,
response_model=UserGetSelfOrgsResponse,
responses={},
)
async def get_user_orgs(user_model: user_model_claims_dependency):
user_orgs = user_model.organisation_rel
response = []
for org in user_orgs:
response.append(
{
"organisation_id": org.id,
"name": org.name,
"status": org.status,
"intake_questionnaire": org.intake_questionnaire,
"root_user_email": org.root_user_email,
"billing_contact": {
"id": org.billing_contact_id,
"email": org.billing_contact_rel.email,
},
"owner_contact": {
"id": org.owner_contact_id,
"email": org.owner_contact_rel.email,
},
"security_contact": {
"id": org.security_contact_id,
"email": org.security_contact_rel.email,
},
}
)
user_orgs = user_model.organisation_rel
response = []
for org in user_orgs:
response.append(
{
"organisation_id": org.id,
"name": org.name,
"status": org.status,
"intake_questionnaire": org.intake_questionnaire,
"root_user_email": org.root_user_email,
"billing_contact": {
"id": org.billing_contact_id,
"email": org.billing_contact_rel.email,
},
"owner_contact": {
"id": org.owner_contact_id,
"email": org.owner_contact_rel.email,
},
"security_contact": {
"id": org.security_contact_id,
"email": org.security_contact_rel.email,
},
}
)
return {"organisations": response}
return {"organisations": response}
@router.post(
"/invitation",
summary="Send an email invitation for a user to join an org",
status_code=status.HTTP_200_OK,
response_model=UserPostInvitationResponse,
"/invitation",
summary="Send an email invitation for a user to join an org",
status_code=status.HTTP_200_OK,
response_model=UserPostInvitationResponse,
)
async def invitation(
background_tasks: BackgroundTasks,
org_model: org_model_root_claim_body_dependency,
request_model: UserPostInvitationRequest,
background_tasks: BackgroundTasks,
org_model: org_model_root_claim_body_dependency,
request_model: UserPostInvitationRequest,
):
org_id = org_model.id
org_name = org_model.name
user_email = request_model.user_email
org_id = org_model.id
org_name = org_model.name
user_email = request_model.user_email
background_tasks.add_task(
send_invitation, org_id=org_id, org_name=org_name, user_email=user_email
)
background_tasks.add_task(
send_invitation, org_id=org_id, org_name=org_name, user_email=user_email
)
response = {
"organisation": org_model,
"invited_email": user_email,
}
response = {
"organisation": org_model,
"invited_email": user_email,
}
return response
return response
@router.post(
"/invitation/accept",
summary="Accept email invitation to join an org",
status_code=status.HTTP_200_OK,
response_model=UserPostInvitationAcceptResponse,
"/invitation/accept",
summary="Accept email invitation to join an org",
status_code=status.HTTP_200_OK,
response_model=UserPostInvitationAcceptResponse,
)
async def accept_invitation(
db: db_dependency,
user_model: user_model_claims_dependency,
request_model: UserPostInvitationAcceptRequest,
db: DbSession,
user_model: user_model_claims_dependency,
request_model: UserPostInvitationAcceptRequest,
):
email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model)
email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model)
org_model = db.get(Org, email_claims["org_id"])
if org_model is None:
raise OrgNotFoundException()
org_model = db.get(Org, email_claims["org_id"])
if org_model is None:
raise OrgNotFoundException()
org_model.user_rel.append(user_model)
db.flush()
group_model = (
db.query(Group)
.filter(Group.org_id == org_model.id)
.filter(Group.name == "Default Users")
.first()
)
if group_model is not None:
user_model.group_rel.append(group_model)
org_model.user_rel.append(user_model)
db.flush()
group_model = (
db.query(Group)
.filter(Group.org_id == org_model.id)
.filter(Group.name == "Default Users")
.first()
)
if group_model is not None:
user_model.group_rel.append(group_model)
response = {
"organisation": org_model,
"user": user_model,
}
response = {
"organisation": org_model,
"user": user_model,
}
db.commit()
db.commit()
return response
return response

View file

@ -10,63 +10,63 @@ from src.schemas import CustomBaseModel, OrgIDMixin, OrgSummary, UserSummary
class OIDCClaims(CustomBaseModel):
exp: int
iat: int
auth_time: int
jti: str
iss: str
aud: str
sub: str
typ: str
azp: str
sid: str
acr: str
allowed_origins: list[str]
realm_access: dict[str, list[str]]
resource_access: dict[str, dict[str, list[str]]]
scope: str
email_verified: bool
name: str
preferred_username: str
given_name: str
family_name: str
email: str
db_id: int
exp: int
iat: int
auth_time: int
jti: str
iss: str
aud: str
sub: str
typ: str
azp: str
sid: str
acr: str
allowed_origins: list[str]
realm_access: dict[str, list[str]]
resource_access: dict[str, dict[str, list[str]]]
scope: str
email_verified: bool
name: str
preferred_username: str
given_name: str
family_name: str
email: str
db_id: int
class OIDCUser(CustomBaseModel):
first_name: str
last_name: str
email: str
oidc_id: str
first_name: str
last_name: str
email: str
oidc_id: str
class UserResponse(CustomBaseModel):
id: int
first_name: str
last_name: str
email: str
organisations: list[Optional[dict[str, str | int]]]
groups: Optional[dict[str, list[dict[str, str | int]]]] = None
id: int
first_name: str
last_name: str
email: str
organisations: list[Optional[dict[str, str | int]]]
groups: Optional[dict[str, list[dict[str, str | int]]]] = None
class UserPostInvitationRequest(OrgIDMixin):
user_email: EmailStr
user_email: EmailStr
class UserPostInvitationAcceptRequest(CustomBaseModel):
jwt: str
jwt: str
class UserGetSelfOrgsResponse(CustomBaseModel):
organisations: list[OrgSchema]
organisations: list[OrgSchema]
class UserPostInvitationResponse(CustomBaseModel):
organisation: OrgSummary
invited_email: EmailStr
organisation: OrgSummary
invited_email: EmailStr
class UserPostInvitationAcceptResponse(CustomBaseModel):
organisation: OrgSummary
user: UserSummary
organisation: OrgSummary
user: UserSummary

View file

@ -1,13 +1,11 @@
"""
Module specific business logic for user module
Exports:
- add_user_to_db: Creates a User record from OIDC claims, or updates user details
"""
from typing import Any
from datetime import datetime, timedelta, timezone
import logging
from sqlalchemy.orm import Session
from src.exceptions import UnprocessableContentException
@ -17,57 +15,50 @@ from src.user.schemas import OIDCUser
from src.user.models import User
async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int:
try:
valid_user = OIDCUser(
first_name=user_claims["given_name"],
last_name=user_claims["family_name"],
email=user_claims["email"],
oidc_id=user_claims["sub"],
)
except Exception as e:
print(e)
raise UnprocessableContentException("Invalid or missing OIDC data")
async def add_user(db: Session, user_claims: dict[str, Any]) -> int:
try:
valid_user = OIDCUser(
first_name=user_claims["given_name"],
last_name=user_claims["family_name"],
email=user_claims["email"],
oidc_id=user_claims["sub"],
)
except Exception as e:
logging.exception(e)
raise UnprocessableContentException("Invalid or missing OIDC data")
db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first()
db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first()
if not db_user:
user_model = User(**valid_user.model_dump())
db.add(user_model)
user_id = user_model.id
db.commit()
return user_id
else:
user_id = db_user.id
change = False
if db_user.first_name != valid_user.first_name:
db_user.first_name = valid_user.first_name
change = True
if db_user.last_name != valid_user.last_name:
db_user.last_name = valid_user.last_name
change = True
if change:
db.add(db_user)
db.commit()
return user_id
if not db_user:
user_model = User(**valid_user.model_dump())
db.add(user_model)
user_id = user_model.id
db.commit()
return user_id
user_id = db_user.id
db_user.first_name = valid_user.first_name
db_user.last_name = valid_user.last_name
db.commit()
return user_id
async def send_invitation(user_email: str, org_name: str, org_id: int):
expiry_delta = timedelta(hours=24)
expiry = datetime.now(timezone.utc) + expiry_delta
claims = {
"email": user_email,
"org_id": org_id,
"exp": expiry,
"type": "org_invite",
}
expiry_delta = timedelta(hours=24)
expiry = datetime.now(timezone.utc) + expiry_delta
claims = {
"email": user_email,
"org_id": org_id,
"exp": expiry,
"type": "org_invite",
}
token = await generate_jwt(claims)
subject = f"You have been invited to join {org_name}"
body = f"You have been invited to join {org_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
token = await generate_jwt(claims)
subject = f"You have been invited to join {org_name}"
body = f"You have been invited to join {org_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
await send_email(
recipient=user_email,
subject=subject,
body=body,
)
await send_email(
recipient=user_email,
subject=subject,
body=body,
)

View file

@ -1,3 +1,5 @@
import logging
from lettermint import Lettermint, ValidationError
from datetime import datetime, timezone
from joserfc import jwt, jwk, errors
@ -9,51 +11,56 @@ KEY = jwk.import_key(settings.SECRET_KEY.get_secret_value(), "oct")
async def generate_jwt(claims):
jwt_token = jwt.encode(header={"alg": "HS256"}, key=KEY, claims=claims)
jwt_token = jwt.encode(header={"alg": "HS256"}, key=KEY, claims=claims)
return jwt_token
return jwt_token
async def decode_jwt(encoded):
try:
token = jwt.decode(encoded, key=KEY)
return token.claims
except errors.DecodeError:
raise UnauthorizedException("Invalid JWS")
try:
token = jwt.decode(encoded, key=KEY)
return token.claims
except errors.DecodeError:
raise UnauthorizedException("Invalid JWS")
async def verify_email_token(user_model, token):
email_claims = await decode_jwt(token)
email_claims = await decode_jwt(token)
claimed_email = email_claims["email"]
claimed_email = email_claims["email"]
expiry = datetime.fromtimestamp(email_claims["exp"], timezone.utc)
expiry = datetime.fromtimestamp(email_claims["exp"], timezone.utc)
if expiry < datetime.now(timezone.utc):
raise UnauthorizedException("Invitation expired.")
if expiry < datetime.now(timezone.utc):
raise UnauthorizedException("Invitation expired.")
if user_model.email != claimed_email:
raise ForbiddenException("The logged in user and email do not match.")
if user_model.email != claimed_email:
raise ForbiddenException("The logged in user and email do not match.")
return email_claims
return email_claims
async def send_email(recipient: str, subject: str, body: str):
lettermint = Lettermint(api_token=settings.LETTERMINT_API_TOKEN.get_secret_value())
if settings.ENVIRONMENT.is_testing:
return
if settings.ENVIRONMENT.is_testing or settings.ENVIRONMENT == "local":
recipient = "ok@testing.lettermint.co"
lettermint = Lettermint(api_token=settings.LETTERMINT_API_TOKEN.get_secret_value())
try:
response = (
lettermint.email.from_("noreply@sr2.uk")
.to(recipient)
.subject(subject)
.text(body)
.send()
)
if settings.ENVIRONMENT == "local":
recipient = "ok@testing.lettermint.co"
print(response.status_code)
except ValidationError:
# Error thrown if domain not approved for project
print("Lettermint validation error")
try:
response = (
lettermint.email.from_("noreply@sr2.uk")
.to(recipient)
.subject(subject)
.text(body)
.send()
)
logging.info(
"Email sent to {} with subject {} (Status: {})".format(
recipient, subject, response.status_code
)
)
except ValidationError as e:
logging.exception(e)

View file

@ -1,8 +1,9 @@
from fastapi.dependencies.models import Dependant
import pytest
from typing import AsyncGenerator
from itertools import combinations
from fastapi.routing import APIRoute
from fastapi.routing import APIRoute, iter_route_contexts
from httpx import AsyncClient, ASGITransport
from sqlalchemy.orm import sessionmaker
@ -14,7 +15,7 @@ from src.iam.models import Group, Permission, OrgPermissions
from src.auth.service import get_current_user, get_dev_user
from src.auth.dependencies import empty_su_list, get_super_admin_list, testing_su_list
from src.main import app # inited FastAPI app
from src.database import engine, get_db
from src.database import engine, get_db_session
from src.models import CustomBase
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@ -22,269 +23,295 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@pytest.fixture()
def db_session():
CustomBase.metadata.drop_all(bind=engine)
CustomBase.metadata.create_all(bind=engine)
db = SessionLocal()
try:
_seed(db) # extracted seeding logic into a plain function
yield db
finally:
db.rollback()
db.close()
CustomBase.metadata.drop_all(bind=engine)
CustomBase.metadata.create_all(bind=engine)
db = SessionLocal()
try:
_seed(db) # extracted seeding logic into a plain function
yield db
finally:
db.rollback()
db.close()
@pytest.fixture
async def default_client(db_session) -> AsyncGenerator[AsyncClient, None]:
def get_db_override():
return db_session
def get_db_override():
return db_session
app.dependency_overrides[get_db] = get_db_override
app.dependency_overrides[get_current_user] = get_dev_user
app.dependency_overrides[get_super_admin_list] = testing_su_list
transport = ASGITransport(app=app)
async with AsyncClient(
transport=transport, base_url="http://localhost:8000/api/v1"
) as ac:
yield ac
app.dependency_overrides[get_db_session] = get_db_override
app.dependency_overrides[get_current_user] = get_dev_user
app.dependency_overrides[get_super_admin_list] = testing_su_list
transport = ASGITransport(app=app)
async with AsyncClient(
transport=transport, base_url="http://localhost:8000/api/v1"
) as ac:
yield ac
app.dependency_overrides.clear()
app.dependency_overrides.clear()
@pytest.fixture
async def no_user_client(db_session) -> AsyncGenerator[AsyncClient, None]:
def get_db_override():
return db_session
def get_db_override():
return db_session
app.dependency_overrides[get_db] = get_db_override
transport = ASGITransport(app=app)
async with AsyncClient(
transport=transport, base_url="http://localhost:8000/api/v1"
) as ac:
yield ac
app.dependency_overrides[get_db_session] = get_db_override
transport = ASGITransport(app=app)
async with AsyncClient(
transport=transport, base_url="http://localhost:8000/api/v1"
) as ac:
yield ac
app.dependency_overrides.clear()
app.dependency_overrides.clear()
@pytest.fixture
async def no_su_client(db_session) -> AsyncGenerator[AsyncClient, None]:
def get_db_override():
return db_session
def get_db_override():
return db_session
app.dependency_overrides[get_db] = get_db_override
app.dependency_overrides[get_current_user] = get_dev_user
app.dependency_overrides[get_super_admin_list] = empty_su_list
transport = ASGITransport(app=app)
async with AsyncClient(
transport=transport, base_url="http://localhost:8000/api/v1"
) as ac:
yield ac
app.dependency_overrides[get_db_session] = get_db_override
app.dependency_overrides[get_current_user] = get_dev_user
app.dependency_overrides[get_super_admin_list] = empty_su_list
transport = ASGITransport(app=app)
async with AsyncClient(
transport=transport, base_url="http://localhost:8000/api/v1"
) as ac:
yield ac
app.dependency_overrides.clear()
app.dependency_overrides.clear()
def _seed(db):
db.add(
User(
email="admin@test.com",
first_name="Admin",
last_name="Test",
oidc_id="abcd-efgh-ijkl-mnop",
)
)
db.add(
User(
email="user@orgone.com",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-qwer",
)
)
db.add(
User(
email="root@orgtwo.com",
first_name="Root",
last_name="Test",
oidc_id="abcd-efgh-ijkl-hjkl",
)
)
db.add(Contact(org_id=1, email="billing@orgone.com", phonenumber="07521539927"))
db.add(Contact(org_id=1, email="owner@orgone.com", phonenumber="07521539927"))
db.add(Contact(org_id=1, email="security@orgone.com", phonenumber="07521539927"))
db.add(Contact(org_id=2, email="billing@orgtwo.com", phonenumber="07521539927"))
db.add(Contact(org_id=2, email="owner@orgtwo.com", phonenumber="07521539927"))
db.add(Contact(org_id=2, email="security@orgtwo.com", phonenumber="07521539927"))
db.add(Contact(org_id=3, email="billing@orgthree.com", phonenumber="07521539927"))
db.add(Contact(org_id=3, email="owner@orgthree.com", phonenumber="07521539927"))
db.add(Contact(org_id=3, email="security@orgthree.com", phonenumber="07521539927"))
db.flush()
db.add(
Org(
name="Org One",
root_user_id=1,
billing_contact_id=1,
owner_contact_id=2,
security_contact_id=3,
status="approved",
intake_questionnaire={
"metadata": {"version": 0, "submission_date": None},
"questions": {"question_two": "answer two"},
},
)
)
db.add(
Org(
name="Org Two",
root_user_id=3,
billing_contact_id=4,
owner_contact_id=5,
security_contact_id=6,
status="approved",
intake_questionnaire={
"metadata": {"version": 0, "submission_date": None},
"questions": {"question_two": "answer two"},
},
)
)
db.add(
Org(
name="Org Three",
root_user_id=1,
billing_contact_id=7,
owner_contact_id=8,
security_contact_id=9,
status="partial",
intake_questionnaire={
"metadata": {"version": 0, "submission_date": None},
"questions": {"question_two": "answer two"},
},
)
)
db.add(OrgUsers(org_id=1, user_id=2))
db.add(Service(name="Test Service", api_key="123456789"))
db.add(Permission(service_id=1, resource="test_resource", action="read"))
db.add(Permission(service_id=1, resource="test_resource", action="move"))
db.add(Permission(service_id=1, resource="test_resource", action="delete"))
db.add(OrgPermissions(org_id=1, permission_id=1))
db.add(OrgPermissions(org_id=1, permission_id=2))
db.add(Group(name="Org One Group", org_id=1))
db.add(Group(name="Org Two Group", org_id=2))
db.add(Group(name="Org One Group Two", org_id=1))
db.flush()
group_model = db.get(Group, 1)
perm_model = db.get(Permission, 1)
group_model.permission_rel.append(perm_model)
user_model = db.get(User, 1)
org_model = db.get(Org, 1)
org_model.user_rel.append(user_model)
org_model.group_rel.append(group_model)
db.flush()
group_model.user_rel.append(user_model)
db.commit()
db.add(
User(
email="admin@test.com",
first_name="Admin",
last_name="Test",
oidc_id="abcd-efgh-ijkl-mnop",
)
)
db.add(
User(
email="user@orgone.com",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-qwer",
)
)
db.add(
User(
email="root@orgtwo.com",
first_name="Root",
last_name="Test",
oidc_id="abcd-efgh-ijkl-hjkl",
)
)
db.add(Contact(org_id=1, email="billing@orgone.com", phonenumber="07521539927"))
db.add(Contact(org_id=1, email="owner@orgone.com", phonenumber="07521539927"))
db.add(Contact(org_id=1, email="security@orgone.com", phonenumber="07521539927"))
db.add(Contact(org_id=2, email="billing@orgtwo.com", phonenumber="07521539927"))
db.add(Contact(org_id=2, email="owner@orgtwo.com", phonenumber="07521539927"))
db.add(Contact(org_id=2, email="security@orgtwo.com", phonenumber="07521539927"))
db.add(Contact(org_id=3, email="billing@orgthree.com", phonenumber="07521539927"))
db.add(Contact(org_id=3, email="owner@orgthree.com", phonenumber="07521539927"))
db.add(Contact(org_id=3, email="security@orgthree.com", phonenumber="07521539927"))
db.flush()
db.add(
Org(
name="Org One",
root_user_id=1,
billing_contact_id=1,
owner_contact_id=2,
security_contact_id=3,
status="approved",
intake_questionnaire={
"metadata": {"version": 0, "submission_date": None},
"questions": {"question_two": "answer two"},
},
)
)
db.add(
Org(
name="Org Two",
root_user_id=3,
billing_contact_id=4,
owner_contact_id=5,
security_contact_id=6,
status="approved",
intake_questionnaire={
"metadata": {"version": 0, "submission_date": None},
"questions": {"question_two": "answer two"},
},
)
)
db.add(
Org(
name="Org Three",
root_user_id=1,
billing_contact_id=7,
owner_contact_id=8,
security_contact_id=9,
status="partial",
intake_questionnaire={
"metadata": {"version": 0, "submission_date": None},
"questions": {"question_two": "answer two"},
},
)
)
db.add(OrgUsers(org_id=1, user_id=2))
db.add(Service(name="Test Service", api_key="123456789"))
db.add(Permission(service_id=1, resource="test_resource", action="read"))
db.add(Permission(service_id=1, resource="test_resource", action="move"))
db.add(Permission(service_id=1, resource="test_resource", action="delete"))
db.add(OrgPermissions(org_id=1, permission_id=1))
db.add(OrgPermissions(org_id=1, permission_id=2))
db.add(Group(name="Org One Group", org_id=1))
db.add(Group(name="Org Two Group", org_id=2))
db.add(Group(name="Org One Group Two", org_id=1))
db.flush()
group_model = db.get(Group, 1)
perm_model = db.get(Permission, 1)
group_model.permission_rel.append(perm_model)
user_model = db.get(User, 1)
org_model = db.get(Org, 1)
org_model.user_rel.append(user_model)
org_model.group_rel.append(group_model)
db.flush()
group_model.user_rel.append(user_model)
db.commit()
def generate_query_and_status(params) -> list[tuple[str, int]]:
possible_values = [0, -1, 42, "banana", ""]
possible_values = [0, -1, 42, "banana", ""]
defaults = [f"{param}=1" for param in params]
defaults = [f"{param}=1" for param in params]
# Missing params
query_list = [
"&".join(combo)
for r in range(len(defaults) + 1)
for combo in combinations(defaults, r)
]
# Missing params
query_list = [
"&".join(combo)
for r in range(len(defaults) + 1)
for combo in combinations(defaults, r)
]
# Complete query as default for invalid checks
default_query = query_list.pop(-1)
# Complete query as default for invalid checks
default_query = query_list.pop(-1)
# Checks for each param being invalid
for param in params:
for value in possible_values:
new_value = f"&{param}={value}"
query_list.append(default_query.replace(f"{param}=1", new_value))
# Checks for each param being invalid
for param in params:
for value in possible_values:
new_value = f"&{param}={value}"
query_list.append(default_query.replace(f"{param}=1", new_value))
query_and_status = []
query_and_status = []
# Assign expected status
for query in query_list:
# ID 42 is used to represent a non-existent entry. So it should 404.
status = 404 if "42" in query else 422
# Remove leading "&" if present
query = query if len(query) > 1 and query[0] != "&" else query[1:]
query_and_status.append((query, status))
# Assign expected status
for query in query_list:
# ID 42 is used to represent a non-existent entry. So it should 404.
status = 404 if "42" in query else 422
# Remove leading "&" if present
query = query if len(query) > 1 and query[0] != "&" else query[1:]
query_and_status.append((query, status))
return query_and_status
return query_and_status
def generate_body_and_status(params: dict[str, str]) -> list[tuple[dict, int]]:
possible_values_int = [0, -1, 42, "banana", ""]
possible_values_str = [0, "", "a"]
possible_values_int = [0, -1, 42, "banana", ""]
possible_values_str = [0, "", "a"]
defaults = [{param: 1 for param in params.keys()}]
defaults = [{param: 1 for param in params.keys()}]
# Missing params
body_list = [
{key: ("valid string" if params[key] == "str" else 1) for key in combo}
for r in range(len(defaults[0].keys()) + 1)
for combo in combinations(defaults[0].keys(), r)
]
# Missing params
body_list = [
{key: ("valid string" if params[key] == "str" else 1) for key in combo}
for r in range(len(defaults[0].keys()) + 1)
for combo in combinations(defaults[0].keys(), r)
]
# Complete body as default for generating invalid checks
default_body = body_list.pop(-1)
# Complete body as default for generating invalid checks
default_body = body_list.pop(-1)
# Generates checks for each param being invalid
for param, typ in params.items():
if typ == "int":
possible_values = possible_values_int
elif typ == "str":
possible_values = possible_values_str
else:
raise TypeError(f"Unknown type {typ}")
for value in possible_values:
new_record = default_body.copy()
new_record[param] = value
body_list.append(new_record)
# Generates checks for each param being invalid
for param, typ in params.items():
if typ == "int":
possible_values = possible_values_int
elif typ == "str":
possible_values = possible_values_str
else:
raise TypeError(f"Unknown type {typ}")
for value in possible_values:
new_record = default_body.copy()
new_record[param] = value
body_list.append(new_record)
body_and_status = []
body_and_status = []
# Assign expected status
for body in body_list:
# ID 42 is used to represent a non-existent entry. So it should 404.
status = 404 if 42 in body.values() else 422
body_and_status.append((body, status))
return body_and_status
# Assign expected status
for body in body_list:
# ID 42 is used to represent a non-existent entry. So it should 404.
status = 404 if 42 in body.values() else 422
body_and_status.append((body, status))
return body_and_status
def get_testable_routes():
routes = []
routes = []
for route in app.routes:
if not isinstance(route, APIRoute):
continue
contexts = list(iter_route_contexts(app.routes))
for method in route.methods:
if method in {"HEAD", "OPTIONS"}:
continue
for route in contexts:
if not route.methods:
continue
if not isinstance(route.route, APIRoute):
continue
routes.append(
(
method,
route.path,
route.status_code,
route.response_model,
route.summary,
)
)
dep_func_names = set()
return routes
unchecked = []
unchecked.append(route.route.dependant)
while unchecked:
dependant = unchecked.pop(0)
ck = dependant.cache_key[0]
if hasattr(ck, "__name__"):
dep_func_names.add(ck.__name__)
unchecked += [
dep for dep in dependant.dependencies if isinstance(dep, Dependant)
]
auth_level = None
if "get_current_user" in dep_func_names:
auth_level = "User"
if (
"org_body_root_claims" in dep_func_names
or "org_query_root_claims" in dep_func_names
):
auth_level = "Root User"
if "user_model_super_admin" in dep_func_names:
auth_level = "Super Admin"
if "valid_service_key" in dep_func_names:
auth_level = "API Key"
for method in route.methods:
if method in {"HEAD", "OPTIONS"}:
continue
routes.append(
(
method,
route.route.path,
route.route.status_code,
route.route.response_model,
route.route.summary,
auth_level,
)
)
return routes
# with open("endpoints.txt", "w") as f:
# for ep in get_testable_routes():
# f.write(f"[{ep[0]}]({ep[1]}) -> {ep[2]}: {ep[3]}\n")
#
#
### Docstring formatted output ###
# with open("endpoints.txt", "w") as f:
# for ep in get_testable_routes():
# f.write(f"- [{ep[0]}]({ep[1]}): []: {ep[4]}\n")
with open("endpoints.txt", "w") as f:
for ep in get_testable_routes():
f.write(f"- [{ep[0]}]({ep[1]}): [{ep[5]}]: {ep[4]}\n")

View file

@ -8,181 +8,181 @@ import pytest
from httpx import AsyncClient
pytestmark = [
pytest.mark.auth,
pytest.mark.preapproval,
pytest.mark.auth,
pytest.mark.preapproval,
]
@pytest.mark.anyio
async def test_get_org_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/org?org_id=3")
assert resp.status_code != 422
assert resp.status_code == 200
resp = await no_su_client.get("/org?org_id=3")
assert resp.status_code != 422
assert resp.status_code == 200
@pytest.mark.anyio
async def test_patch_org_questionnaire_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.patch(
"/org/questionnaire",
json={
"organisation_id": 3,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": True,
},
)
assert resp.status_code != 422
assert resp.status_code == 200
resp = await no_su_client.patch(
"/org/questionnaire",
json={
"organisation_id": 3,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": True,
},
)
assert resp.status_code != 422
assert resp.status_code == 200
@pytest.mark.anyio
async def test_get_org_users_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/users?org_id=3")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
resp = await no_su_client.get("/org/users?org_id=3")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_org_groups_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/groups?org_id=3")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
resp = await no_su_client.get("/org/groups?org_id=3")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_org_contact_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/contact?org_id=3&contact_type=billing")
assert resp.status_code != 422
assert resp.status_code == 200
resp = await no_su_client.get("/org/contact?org_id=3&contact_type=billing")
assert resp.status_code != 422
assert resp.status_code == 200
@pytest.mark.anyio
async def test_patch_org_contact_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.patch(
"/org/contact",
json={
"organisation_id": 3,
"contact_type": "billing",
"email": "user@example.com",
},
)
assert resp.status_code != 422
assert resp.status_code == 200
resp = await no_su_client.patch(
"/org/contact",
json={
"organisation_id": 3,
"contact_type": "billing",
"email": "user@example.com",
},
)
assert resp.status_code != 422
assert resp.status_code == 200
@pytest.mark.anyio
async def test_get_service_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/service?org_id=3")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
resp = await no_su_client.get("/service?org_id=3")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_iam_group_permissions_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/group/permissions?org_id=3&group_id=1")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
resp = await no_su_client.get("/iam/group/permissions?org_id=3&group_id=1")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_iam_group_users_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/group/users?org_id=3&group_id=1")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
resp = await no_su_client.get("/iam/group/users?org_id=3&group_id=1")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_post_iam_group_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.post(
"/iam/group", json={"name": "New Group", "organisation_id": 3}
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
resp = await no_su_client.post(
"/iam/group", json={"name": "New Group", "organisation_id": 3}
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_put_iam_group_permission_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.put(
"/iam/group/permission",
json={"permission_id": 1, "group_id": 2, "organisation_id": 3},
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
resp = await no_su_client.put(
"/iam/group/permission",
json={"permission_id": 1, "group_id": 2, "organisation_id": 3},
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_put_iam_group_user_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.put(
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 3}
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
resp = await no_su_client.put(
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 3}
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_iam_permissions_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/permissions?org_id=3")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
resp = await no_su_client.get("/iam/permissions?org_id=3")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_post_iam_permissions_search_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.post(
"/iam/permissions/search", json={"organisation_id": 3, "action": "read"}
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
resp = await no_su_client.post(
"/iam/permissions/search", json={"organisation_id": 3, "action": "read"}
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_delete_org_user_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.delete("/org/user?org_id=3&user_id=1")
resp = await no_su_client.delete("/org/user?org_id=3&user_id=1")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_delete_preapproval_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.delete("/org/self?org_id=3")
resp = await no_su_client.delete("/org/self?org_id=3")
assert resp.status_code != 422
assert resp.status_code == 204
assert resp.status_code != 422
assert resp.status_code == 204
@pytest.mark.anyio
async def test_post_user_invitation_auth_approval(no_su_client: AsyncClient):
body = {"user_email": "admin@test.com", "organisation_id": 3}
resp = await no_su_client.post("/user/invitation", json=body)
body = {"user_email": "admin@test.com", "organisation_id": 3}
resp = await no_su_client.post("/user/invitation", json=body)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_delete_group_permissions_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.delete("/iam/group/permission?org_id=3&group_id=1&perm_id=1")
resp = await no_su_client.delete("/iam/group/permission?org_id=3&group_id=1&perm_id=1")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_delete_group_users_success(no_su_client: AsyncClient):
resp = await no_su_client.delete("/iam/group/user?org_id=3&group_id=1&user_id=1")
resp = await no_su_client.delete("/iam/group/user?org_id=3&group_id=1&user_id=1")
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_put_group_user_invitation_success(no_su_client: AsyncClient):
body = {"user_email": "admin@test.com", "organisation_id": 3, "group_id": 1}
resp = await no_su_client.put("/iam/group/user/invitation", json=body)
body = {"user_email": "admin@test.com", "organisation_id": 3, "group_id": 1}
resp = await no_su_client.put("/iam/group/user/invitation", json=body)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]

View file

@ -5,14 +5,14 @@ from httpx import AsyncClient
pytestmark = [
pytest.mark.auth,
pytest.mark.auth,
]
@pytest.mark.anyio
async def test_get_org_auth_root_su(default_client: AsyncClient):
# If a super admin can access a resource when not the root user
resp = await default_client.get("/org?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 200
assert resp.json()["organisations"][0]["name"] == "Org Two"
# If a super admin can access a resource when not the root user
resp = await default_client.get("/org?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 200
assert resp.json()["organisations"][0]["name"] == "Org Two"

View file

@ -7,147 +7,147 @@ import pytest
from httpx import AsyncClient
pytestmark = [
pytest.mark.auth,
pytest.mark.root_user,
pytest.mark.auth,
pytest.mark.root_user,
]
@pytest.mark.anyio
async def test_get_org_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/org?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.get("/org?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_patch_org_questionnaire_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.patch(
"/org/questionnaire",
json={
"organisation_id": 2,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": True,
},
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.patch(
"/org/questionnaire",
json={
"organisation_id": 2,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": True,
},
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_org_users_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/users?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.get("/org/users?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_org_groups_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/groups?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.get("/org/groups?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_org_contact_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/contact?org_id=2&contact_type=billing")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.get("/org/contact?org_id=2&contact_type=billing")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_patch_org_contact_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.patch(
"/org/contact",
json={
"organisation_id": 2,
"contact_type": "billing",
"email": "user@example.com",
},
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.patch(
"/org/contact",
json={
"organisation_id": 2,
"contact_type": "billing",
"email": "user@example.com",
},
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_service_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/service?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.get("/service?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_iam_group_permissions_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/group/permissions?org_id=2&group_id=1")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.get("/iam/group/permissions?org_id=2&group_id=1")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_iam_group_users_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/group/users?org_id=2&group_id=1")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.get("/iam/group/users?org_id=2&group_id=1")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_post_iam_group_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.post(
"/iam/group", json={"name": "New Group", "organisation_id": 2}
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.post(
"/iam/group", json={"name": "New Group", "organisation_id": 2}
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_put_iam_group_permission_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.put(
"/iam/group/permission",
json={"permission_id": 1, "group_id": 2, "organisation_id": 2},
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.put(
"/iam/group/permission",
json={"permission_id": 1, "group_id": 2, "organisation_id": 2},
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_put_iam_group_user_auth_root(
no_su_client: AsyncClient,
no_su_client: AsyncClient,
):
resp = await no_su_client.put(
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2}
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.put(
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2}
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_get_iam_permissions_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/permissions?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.get("/iam/permissions?org_id=2")
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_post_iam_permissions_search_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.post(
"/iam/permissions/search", json={"organisation_id": 2, "action": "read"}
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]
resp = await no_su_client.post(
"/iam/permissions/search", json={"organisation_id": 2, "action": "read"}
)
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"]

View file

@ -7,69 +7,69 @@ import pytest
from httpx import AsyncClient
pytestmark = [
pytest.mark.auth,
pytest.mark.super_admin,
pytest.mark.auth,
pytest.mark.super_admin,
]
@pytest.mark.anyio
async def test_get_user_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.get("/user?user_id=1")
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
resp = await no_su_client.get("/user?user_id=1")
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio
async def test_patch_org_status_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.patch(
"/org/status", json={"organisation_id": 1, "status": "submitted"}
)
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
resp = await no_su_client.patch(
"/org/status", json={"organisation_id": 1, "status": "submitted"}
)
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio
async def test_patch_org_root_user_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 2}
)
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
resp = await no_su_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 2}
)
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio
async def test_patch_service_key_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.patch("/service/key", json={"service_id": 1})
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
resp = await no_su_client.patch("/service/key", json={"service_id": 1})
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio
async def test_post_service_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.post("/service", json={"name": "New Test Service"})
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
resp = await no_su_client.post("/service", json={"name": "New Test Service"})
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio
async def test_post_perm_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.post(
"/iam/permission",
json={"service_id": 1, "resource": "test_resource", "action": "create"},
)
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
resp = await no_su_client.post(
"/iam/permission",
json={"service_id": 1, "resource": "test_resource", "action": "create"},
)
assert resp.status_code != 422
assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio
async def test_post_org_user_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.post("/org/user", json={"organisation_id": 1, "user_id": 2})
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be super admin" in resp.json()["detail"]
resp = await no_su_client.post("/org/user", json={"organisation_id": 1, "user_id": 2})
assert resp.status_code != 422
assert resp.status_code == 403
assert "Must be super admin" in resp.json()["detail"]

View file

@ -7,22 +7,22 @@ from httpx import AsyncClient
pytestmark = [
pytest.mark.auth,
pytest.mark.user,
pytest.mark.auth,
pytest.mark.user,
]
@pytest.mark.anyio
async def test_get_self_db_auth_user(no_user_client: AsyncClient):
resp = await no_user_client.get("/user/self/db")
assert resp.status_code != 422
assert resp.status_code == 401
assert resp.json()["detail"] == "Not authenticated"
resp = await no_user_client.get("/user/self/db")
assert resp.status_code != 422
assert resp.status_code == 401
assert resp.json()["detail"] == "Not authenticated"
@pytest.mark.anyio
async def test_post_org_success_auth_user(no_user_client: AsyncClient):
resp = await no_user_client.post("/org", json={"name": "New Test Org"})
assert resp.status_code != 422
assert resp.status_code == 401
assert resp.json()["detail"] == "Not authenticated"
resp = await no_user_client.post("/org", json={"name": "New Test Org"})
assert resp.status_code != 422
assert resp.status_code == 401
assert resp.json()["detail"] == "Not authenticated"

View file

@ -4,7 +4,7 @@ from httpx import AsyncClient
@pytest.mark.anyio
async def test_healthcheck(default_client: AsyncClient):
resp = await default_client.get("/healthcheck")
resp = await default_client.get("/healthcheck")
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}

File diff suppressed because it is too large Load diff

View file

@ -9,506 +9,506 @@ from .conftest import generate_query_and_status
pytestmark = [
pytest.mark.org_module,
pytest.mark.org_module,
]
@pytest.mark.anyio
async def test_get_org_success(default_client: AsyncClient):
resp = await default_client.get("/org?org_id=1")
data = resp.json()
resp = await default_client.get("/org?org_id=1")
data = resp.json()
assert resp.status_code == 200
assert resp.status_code == 200
org = data["organisations"][0]
org = data["organisations"][0]
assert isinstance(org, dict)
assert org["organisation_id"] == 1
assert org["name"] == "Org One"
assert org["status"] == "approved"
assert org["root_user_email"] == "admin@test.com"
assert "intake_questionnaire" in org
assert isinstance(org["intake_questionnaire"], dict)
assert isinstance(org, dict)
assert org["organisation_id"] == 1
assert org["name"] == "Org One"
assert org["status"] == "approved"
assert org["root_user_email"] == "admin@test.com"
assert "intake_questionnaire" in org
assert isinstance(org["intake_questionnaire"], dict)
assert org["billing_contact"]["email"] == "billing@orgone.com"
assert org["owner_contact"]["email"] == "owner@orgone.com"
assert org["security_contact"]["email"] == "security@orgone.com"
assert org["billing_contact"]["email"] == "billing@orgone.com"
assert org["owner_contact"]["email"] == "owner@orgone.com"
assert org["security_contact"]["email"] == "security@orgone.com"
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
@pytest.mark.anyio
async def test_get_org_status_checks(
default_client: AsyncClient, query: str, expected_status: int
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/org?{query}")
resp = await default_client.get(f"/org?{query}")
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_post_org_success(default_client: AsyncClient):
resp = await default_client.post("/org", json={"name": "New Test Org"})
data = resp.json()
resp = await default_client.post("/org", json={"name": "New Test Org"})
data = resp.json()
assert resp.status_code == 201
assert data["name"] == "New Test Org"
assert data["status"] == "partial"
assert resp.status_code == 201
assert data["name"] == "New Test Org"
assert data["status"] == "partial"
@pytest.mark.parametrize(
"body, expected_status",
[
({"name": 42}, 422),
({}, 422),
({"name": "New Test Org", "intake_questionnaire": {"question_one": 42}}, 422),
],
"body, expected_status",
[
({"name": 42}, 422),
({}, 422),
({"name": "New Test Org", "intake_questionnaire": {"question_one": 42}}, 422),
],
)
@pytest.mark.anyio
async def test_post_org_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.post("/org", json=body)
resp = await default_client.post("/org", json=body)
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_patch_org_questionnaire_partial_success(default_client: AsyncClient):
resp = await default_client.patch(
"/org/questionnaire",
json={
"organisation_id": 3,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": True,
},
)
data = resp.json()
resp = await default_client.patch(
"/org/questionnaire",
json={
"organisation_id": 3,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": True,
},
)
data = resp.json()
assert resp.status_code == 200
assert data["name"] == "Org Three"
assert data["status"] == "partial"
assert "intake_questionnaire" in data
assert isinstance(data["intake_questionnaire"], dict)
metadata = data["intake_questionnaire"]["metadata"]
assert metadata["version"] == 0
assert metadata["submission_date"] is None
questions = data["intake_questionnaire"]["questions"]
assert questions["question_one"] == "new answer one"
assert questions["question_two"] == "answer two"
assert questions["question_three"] is None
assert resp.status_code == 200
assert data["name"] == "Org Three"
assert data["status"] == "partial"
assert "intake_questionnaire" in data
assert isinstance(data["intake_questionnaire"], dict)
metadata = data["intake_questionnaire"]["metadata"]
assert metadata["version"] == 0
assert metadata["submission_date"] is None
questions = data["intake_questionnaire"]["questions"]
assert questions["question_one"] == "new answer one"
assert questions["question_two"] == "answer two"
assert questions["question_three"] is None
@pytest.mark.parametrize(
"body, expected_status",
[
({"organisation_id": 42}, 404),
({"organisation_id": "Org One"}, 422),
({"organisation_id": ""}, 422),
({}, 422),
(
{
"organisation_id": "1",
"intake_questionnaire": {"question_one": 42},
"partial": True,
},
422,
),
(
{"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}},
422,
),
(
{
"organisation_id": "1",
"intake_questionnaire": {"question_one": "valid"},
"partial": 42,
},
422,
),
],
"body, expected_status",
[
({"organisation_id": 42}, 404),
({"organisation_id": "Org One"}, 422),
({"organisation_id": ""}, 422),
({}, 422),
(
{
"organisation_id": "1",
"intake_questionnaire": {"question_one": 42},
"partial": True,
},
422,
),
(
{"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}},
422,
),
(
{
"organisation_id": "1",
"intake_questionnaire": {"question_one": "valid"},
"partial": 42,
},
422,
),
],
)
@pytest.mark.anyio
async def test_patch_questionnaire_partial_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.patch("/org/questionnaire", json=body)
resp = await default_client.patch("/org/questionnaire", json=body)
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_patch_org_questionnaire_submit_success(default_client: AsyncClient):
resp = await default_client.patch(
"/org/questionnaire",
json={
"organisation_id": 3,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": False,
},
)
data = resp.json()
resp = await default_client.patch(
"/org/questionnaire",
json={
"organisation_id": 3,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": False,
},
)
data = resp.json()
assert resp.status_code == 200
assert data["name"] == "Org Three"
assert data["status"] == "submitted"
assert "intake_questionnaire" in data
assert isinstance(data["intake_questionnaire"], dict)
metadata = data["intake_questionnaire"]["metadata"]
assert metadata["version"] == 0
assert metadata["submission_date"] is not None
questions = data["intake_questionnaire"]["questions"]
assert questions["question_one"] == "new answer one"
assert questions["question_two"] == "answer two"
assert questions["question_three"] is None
assert resp.status_code == 200
assert data["name"] == "Org Three"
assert data["status"] == "submitted"
assert "intake_questionnaire" in data
assert isinstance(data["intake_questionnaire"], dict)
metadata = data["intake_questionnaire"]["metadata"]
assert metadata["version"] == 0
assert metadata["submission_date"] is not None
questions = data["intake_questionnaire"]["questions"]
assert questions["question_one"] == "new answer one"
assert questions["question_two"] == "answer two"
assert questions["question_three"] is None
@pytest.mark.parametrize(
"status", ["partial", "submitted", "remediation", "approved", "rejected", "removed"]
"status", ["partial", "submitted", "remediation", "approved", "rejected", "removed"]
)
@pytest.mark.anyio
async def test_patch_org_status_success(default_client: AsyncClient, status: str):
resp = await default_client.patch(
"/org/status", json={"organisation_id": 1, "status": status}
)
data = resp.json()
resp = await default_client.patch(
"/org/status", json={"organisation_id": 1, "status": status}
)
data = resp.json()
assert resp.status_code == 200
assert data["name"] == "Org One"
assert data["status"] == status
assert resp.status_code == 200
assert data["name"] == "Org One"
assert data["status"] == status
@pytest.mark.parametrize(
"body, expected_status",
[
({"organisation_id": 42}, 404),
({"organisation_id": "Org One"}, 422),
({"organisation_id": ""}, 422),
({}, 422),
({"organisation_id": "1", "status": True}, 422),
({"organisation_id": "1", "status": 42}, 422),
],
"body, expected_status",
[
({"organisation_id": 42}, 404),
({"organisation_id": "Org One"}, 422),
({"organisation_id": ""}, 422),
({}, 422),
({"organisation_id": "1", "status": True}, 422),
({"organisation_id": "1", "status": 42}, 422),
],
)
@pytest.mark.anyio
async def test_patch_org_status_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.patch("/org/status", json=body)
resp = await default_client.patch("/org/status", json=body)
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_get_org_users_success(default_client: AsyncClient):
resp = await default_client.get("/org/users?org_id=1")
data = resp.json()
resp = await default_client.get("/org/users?org_id=1")
data = resp.json()
assert resp.status_code == 200
assert resp.status_code == 200
assert "users" in data
assert isinstance(data["users"], list)
assert len(data["users"]) == 2
assert "users" in data
assert isinstance(data["users"], list)
assert len(data["users"]) == 2
user = data["users"][0]
assert isinstance(user, dict)
assert user["email"] == "admin@test.com"
assert user["id"] == 1
user = data["users"][0]
assert isinstance(user, dict)
assert user["email"] == "admin@test.com"
assert user["id"] == 1
assert "organisation" in data
assert data["organisation"]["name"] == "Org One"
assert data["organisation"]["id"] == 1
assert "organisation" in data
assert data["organisation"]["name"] == "Org One"
assert data["organisation"]["id"] == 1
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
@pytest.mark.anyio
async def test_get_org_users_status_checks(
default_client: AsyncClient, query: str, expected_status: int
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/org/users?{query}")
resp = await default_client.get(f"/org/users?{query}")
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_post_org_user_success(default_client: AsyncClient):
resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 3})
resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 3})
assert resp.status_code == 200
assert resp.status_code == 200
data = resp.json()
data = resp.json()
assert "organisation" in data
assert isinstance(data["organisation"], dict)
assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One"
assert "organisation" in data
assert isinstance(data["organisation"], dict)
assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One"
assert "users" in data
assert isinstance(data["users"], list)
assert len([user for user in data["users"] if user["email"] == "root@orgtwo.com"]) == 1
assert "users" in data
assert isinstance(data["users"], list)
assert len([user for user in data["users"] if user["email"] == "root@orgtwo.com"]) == 1
@pytest.mark.parametrize(
"body, expected_status",
[
({"organisation_id": 42}, 404),
({}, 422),
({"organisation_id": 1, "user_id": "id"}, 422),
({"user_id": 2}, 422),
({"organisation_id": 1, "user_id": 42}, 404),
({"organisation_id": 1, "user_id": 1}, 409),
],
"body, expected_status",
[
({"organisation_id": 42}, 404),
({}, 422),
({"organisation_id": 1, "user_id": "id"}, 422),
({"user_id": 2}, 422),
({"organisation_id": 1, "user_id": 42}, 404),
({"organisation_id": 1, "user_id": 1}, 409),
],
)
@pytest.mark.anyio
async def test_post_org_user_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.post("/org/user", json=body)
resp = await default_client.post("/org/user", json=body)
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_patch_org_root_user_success(default_client: AsyncClient):
resp = await default_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 2}
)
assert resp.status_code == 200
resp = await default_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 2}
)
assert resp.status_code == 200
data = resp.json()
data = resp.json()
assert data["name"] == "Org One"
assert data["root_user_email"] == "user@orgone.com"
assert data["name"] == "Org One"
assert data["root_user_email"] == "user@orgone.com"
@pytest.mark.parametrize(
"body, expected_status",
[
({"organisation_id": 42, "user_id": 2}, 404),
({"organisation_id": "Org One", "user_id": 2}, 422),
({"organisation_id": "", "user_id": 2}, 422),
({}, 422),
({"user_id": 2}, 422),
({"user_id": 42}, 404),
({"organisation_id": 1, "user_id": "Test User"}, 422),
],
"body, expected_status",
[
({"organisation_id": 42, "user_id": 2}, 404),
({"organisation_id": "Org One", "user_id": 2}, 422),
({"organisation_id": "", "user_id": 2}, 422),
({}, 422),
({"user_id": 2}, 422),
({"user_id": 42}, 404),
({"organisation_id": 1, "user_id": "Test User"}, 422),
],
)
@pytest.mark.anyio
async def test_patch_root_user_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.patch("/org/root_user", json=body)
resp = await default_client.patch("/org/root_user", json=body)
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_patch_org_root_user_non_member(default_client: AsyncClient):
resp = await default_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 3}
)
data = resp.json()
resp = await default_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 3}
)
data = resp.json()
assert resp.status_code == 422
assert data["detail"] == "This user does not belong to your organisation."
assert resp.status_code == 422
assert data["detail"] == "This user does not belong to your organisation."
@pytest.mark.anyio
async def test_get_org_groups_success(default_client: AsyncClient):
resp = await default_client.get("/org/groups?org_id=1")
assert resp.status_code == 200
resp = await default_client.get("/org/groups?org_id=1")
assert resp.status_code == 200
data = resp.json()
data = resp.json()
assert "organisation" in data
assert isinstance(data["organisation"], dict)
assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One"
assert "organisation" in data
assert isinstance(data["organisation"], dict)
assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One"
assert "groups" in data
assert isinstance(data["groups"], list)
group = data["groups"][0]
assert isinstance(group, dict)
assert group["id"] == 1
assert group["name"] == "Org One Group"
assert "groups" in data
assert isinstance(data["groups"], list)
group = data["groups"][0]
assert isinstance(group, dict)
assert group["id"] == 1
assert group["name"] == "Org One Group"
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
@pytest.mark.anyio
async def test_get_org_groups_status_checks(
default_client: AsyncClient, query: str, expected_status: int
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/org/groups?{query}")
resp = await default_client.get(f"/org/groups?{query}")
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.parametrize("contact_type", ["billing", "security", "owner"])
@pytest.mark.anyio
async def test_get_org_contact_success(default_client: AsyncClient, contact_type: str):
resp = await default_client.get(f"/org/contact?org_id=1&contact_type={contact_type}")
data = resp.json()
resp = await default_client.get(f"/org/contact?org_id=1&contact_type={contact_type}")
data = resp.json()
assert resp.status_code == 200
assert resp.status_code == 200
assert "organisation" in data
assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One"
assert "organisation" in data
assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One"
attributes = [
"email",
"first_name",
"last_name",
"phonenumber",
"vat_number",
"address",
]
attributes = [
"email",
"first_name",
"last_name",
"phonenumber",
"vat_number",
"address",
]
for attribute in attributes:
assert attribute in data["contact"]
for attribute in attributes:
assert attribute in data["contact"]
address_attributes = [
"post_office_box_number",
"street_address",
"street_address_line_2",
"locality",
"address_region",
"country_code",
"postal_code",
]
address_attributes = [
"post_office_box_number",
"street_address",
"street_address_line_2",
"locality",
"address_region",
"country_code",
"postal_code",
]
for attribute in address_attributes:
assert attribute in data["contact"]["address"]
for attribute in address_attributes:
assert attribute in data["contact"]["address"]
@pytest.mark.parametrize(
"query, expected_status",
[
("org_id=42&contact_type=billing", 404),
("org_id=banana&contact_type=billing", 422),
("", 422),
("org_id=1&contact_type=contact", 422),
("contact_type=billing", 422),
],
"query, expected_status",
[
("org_id=42&contact_type=billing", 404),
("org_id=banana&contact_type=billing", 422),
("", 422),
("org_id=1&contact_type=contact", 422),
("contact_type=billing", 422),
],
)
@pytest.mark.anyio
async def test_get_org_contact_status_checks(
default_client: AsyncClient, query: str, expected_status: int
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/org/contact?{query}")
resp = await default_client.get(f"/org/contact?{query}")
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.parametrize(
"key, value",
[
("email", "user@example.com"),
("first_name", "John"),
("last_name", "Doe"),
("phonenumber", "+441234567890"),
("vat_number", "GB123456789"),
("post_office_box_number", "PO Box 123"),
("street_address", "123 Example Street"),
("street_address_line_2", "Suite 4B"),
("locality", "Glasgow"),
("address_region", "Glasgow City"),
("country_code", "GB"),
("postal_code", "G1 1AA"),
],
"key, value",
[
("email", "user@example.com"),
("first_name", "John"),
("last_name", "Doe"),
("phonenumber", "+441234567890"),
("vat_number", "GB123456789"),
("post_office_box_number", "PO Box 123"),
("street_address", "123 Example Street"),
("street_address_line_2", "Suite 4B"),
("locality", "Glasgow"),
("address_region", "Glasgow City"),
("country_code", "GB"),
("postal_code", "G1 1AA"),
],
)
@pytest.mark.anyio
async def test_patch_org_contact_success(default_client: AsyncClient, key: str, value: str):
resp = await default_client.patch(
"/org/contact",
json={"organisation_id": 1, "contact_type": "billing", key: value},
)
assert resp.status_code == 200
resp = await default_client.patch(
"/org/contact",
json={"organisation_id": 1, "contact_type": "billing", key: value},
)
assert resp.status_code == 200
data = resp.json()
data = resp.json()
assert "organisation" in data
assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One"
assert "organisation" in data
assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One"
attributes = [
"email",
"first_name",
"last_name",
"phonenumber",
"vat_number",
"address",
]
attributes = [
"email",
"first_name",
"last_name",
"phonenumber",
"vat_number",
"address",
]
for attribute in attributes:
assert attribute in data["contact"]
for attribute in attributes:
assert attribute in data["contact"]
address_attributes = [
"post_office_box_number",
"street_address",
"street_address_line_2",
"locality",
"address_region",
"country_code",
"postal_code",
]
address_attributes = [
"post_office_box_number",
"street_address",
"street_address_line_2",
"locality",
"address_region",
"country_code",
"postal_code",
]
for attribute in address_attributes:
assert attribute in data["contact"]["address"]
for attribute in address_attributes:
assert attribute in data["contact"]["address"]
if key in data["contact"]:
assert data["contact"][key] == value
elif key in data["contact"]["address"]:
assert data["contact"]["address"][key] == value
else:
pytest.fail(f"Invalid contact key: {key}")
if key in data["contact"]:
assert data["contact"][key] == value
elif key in data["contact"]["address"]:
assert data["contact"]["address"][key] == value
else:
pytest.fail(f"Invalid contact key: {key}")
@pytest.mark.parametrize(
"body, expected_status",
[
({"organisation_id": 42, "contact_type": "billing"}, 404),
({"organisation_id": 1, "contact_type": "security"}, 200),
({"organisation_id": 1, "contact_type": "owner"}, 200),
({"organisation_id": "Org One", "contact_type": "billing"}, 422),
({"organisation_id": "", "contact_type": "billing"}, 422),
({}, 422),
({"organisation_id": 1, "contact_type": "not_real"}, 422),
({"organisation_id": 1, "contact_type": 42}, 422),
({"organisation_id": 1, "contact_type": ""}, 422),
],
"body, expected_status",
[
({"organisation_id": 42, "contact_type": "billing"}, 404),
({"organisation_id": 1, "contact_type": "security"}, 200),
({"organisation_id": 1, "contact_type": "owner"}, 200),
({"organisation_id": "Org One", "contact_type": "billing"}, 422),
({"organisation_id": "", "contact_type": "billing"}, 422),
({}, 422),
({"organisation_id": 1, "contact_type": "not_real"}, 422),
({"organisation_id": 1, "contact_type": 42}, 422),
({"organisation_id": 1, "contact_type": ""}, 422),
],
)
@pytest.mark.anyio
async def test_patch_org_contact_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.patch("/org/contact", json=body)
resp = await default_client.patch("/org/contact", json=body)
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_delete_org_success(default_client: AsyncClient):
resp = await default_client.delete("/org?org_id=1")
resp = await default_client.delete("/org?org_id=1")
assert resp.status_code == 204
assert resp.status_code == 204
@pytest.mark.anyio
async def test_delete_org_users_success(default_client: AsyncClient):
resp = await default_client.delete("/org/user?org_id=1&user_id=2")
resp = await default_client.delete("/org/user?org_id=1&user_id=2")
assert resp.status_code == 204
assert resp.status_code == 204
@pytest.mark.anyio
async def test_delete_preapproval_org_success(default_client: AsyncClient):
resp = await default_client.delete("/org/self?org_id=3")
resp = await default_client.delete("/org/self?org_id=3")
assert resp.status_code == 204
assert resp.status_code == 204

View file

@ -8,90 +8,90 @@ from httpx import AsyncClient
from .conftest import generate_query_and_status, generate_body_and_status
pytestmark = [
pytest.mark.service_module,
pytest.mark.service_module,
]
@pytest.mark.anyio
async def test_get_services_success(default_client: AsyncClient):
resp = await default_client.get("/service?org_id=1")
data = resp.json()
resp = await default_client.get("/service?org_id=1")
data = resp.json()
assert resp.status_code == 200
assert "services" in data
assert isinstance(data["services"], list)
assert data["services"][0]["id"] == 1
assert data["services"][0]["name"] == "Test Service"
assert resp.status_code == 200
assert "services" in data
assert isinstance(data["services"], list)
assert data["services"][0]["id"] == 1
assert data["services"][0]["name"] == "Test Service"
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
@pytest.mark.anyio
async def test_get_services_status_checks(
default_client: AsyncClient, query: str, expected_status: int
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/service?{query}")
resp = await default_client.get(f"/service?{query}")
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_post_service_success(default_client: AsyncClient):
resp = await default_client.post("/service", json={"name": "New Test Service"})
data = resp.json()
resp = await default_client.post("/service", json={"name": "New Test Service"})
data = resp.json()
assert resp.status_code == 200
assert "service" in data
assert isinstance(data["service"], dict)
assert data["service"]["name"] == "New Test Service"
assert data["service"]["id"] == 2
assert isinstance(data["service"]["api_key"], str)
assert resp.status_code == 200
assert "service" in data
assert isinstance(data["service"], dict)
assert data["service"]["name"] == "New Test Service"
assert data["service"]["id"] == 2
assert isinstance(data["service"]["api_key"], str)
@pytest.mark.parametrize("body, expected_status", generate_body_and_status({"name": "str"}))
@pytest.mark.anyio
async def test_post_service_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.post("/service", json=body)
resp = await default_client.post("/service", json=body)
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_post_service_conflict(default_client: AsyncClient):
resp = await default_client.post("/service", json={"name": "Test Service"})
resp = await default_client.post("/service", json={"name": "Test Service"})
assert resp.status_code == 409
assert resp.status_code == 409
@pytest.mark.anyio
async def test_patch_service_success(default_client: AsyncClient):
resp = await default_client.patch("/service/key", json={"service_id": 1})
data = resp.json()
resp = await default_client.patch("/service/key", json={"service_id": 1})
data = resp.json()
assert resp.status_code == 200
assert "service" in data
assert isinstance(data["service"], dict)
assert data["service"]["name"] == "Test Service"
assert data["service"]["id"] == 1
assert isinstance(data["service"]["api_key"], str)
assert resp.status_code == 200
assert "service" in data
assert isinstance(data["service"], dict)
assert data["service"]["name"] == "Test Service"
assert data["service"]["id"] == 1
assert isinstance(data["service"]["api_key"], str)
@pytest.mark.parametrize(
"body, expected_status",
generate_body_and_status({"service_id": "int"}),
"body, expected_status",
generate_body_and_status({"service_id": "int"}),
)
@pytest.mark.anyio
async def test_patch_services_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.patch("/service/key", json=body)
resp = await default_client.patch("/service/key", json=body)
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_delete_service_success(default_client: AsyncClient):
resp = await default_client.delete("/service?service_id=1")
resp = await default_client.delete("/service?service_id=1")
assert resp.status_code == 204
assert resp.status_code == 204

View file

@ -5,197 +5,202 @@
import pytest
from httpx import AsyncClient
from fastapi.routing import APIRoute
from fastapi.routing import APIRoute, iter_route_contexts
from .conftest import generate_query_and_status
pytestmark = [
pytest.mark.user_module,
pytest.mark.user_module,
]
@pytest.mark.anyio
async def test_get_self_db_success(default_client: AsyncClient):
resp = await default_client.get("/user/self/db")
data = resp.json()
resp = await default_client.get("/user/self/db")
data = resp.json()
assert resp.status_code == 200
assert data["first_name"] == "Admin"
assert data["last_name"] == "Test"
assert data["email"] == "admin@test.com"
assert "organisations" in data
assert isinstance(data["organisations"], list)
assert "groups" in data
assert isinstance(data["groups"], dict)
assert resp.status_code == 200
assert data["first_name"] == "Admin"
assert data["last_name"] == "Test"
assert data["email"] == "admin@test.com"
assert "organisations" in data
assert isinstance(data["organisations"], list)
assert "groups" in data
assert isinstance(data["groups"], dict)
@pytest.mark.anyio
async def test_get_user_success(default_client: AsyncClient):
resp = await default_client.get("/user?user_id=1")
data = resp.json()
resp = await default_client.get("/user?user_id=1")
data = resp.json()
assert resp.status_code == 200
assert data["first_name"] == "Admin"
assert data["last_name"] == "Test"
assert data["email"] == "admin@test.com"
assert "organisations" in data
assert isinstance(data["organisations"], list)
assert "groups" in data
assert isinstance(data["groups"], dict)
assert resp.status_code == 200
assert data["first_name"] == "Admin"
assert data["last_name"] == "Test"
assert data["email"] == "admin@test.com"
assert "organisations" in data
assert isinstance(data["organisations"], list)
assert "groups" in data
assert isinstance(data["groups"], dict)
@pytest.mark.anyio
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["user_id"]))
async def test_get_user_status_checks(
default_client: AsyncClient, query: str, expected_status: int
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/user?{query}")
resp = await default_client.get(f"/user?{query}")
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_delete_user_success(default_client: AsyncClient):
resp = await default_client.delete("/user?user_id=1")
resp = await default_client.delete("/user?user_id=1")
assert resp.status_code == 204
assert resp.status_code == 204
@pytest.mark.anyio
async def test_post_user_invitation_success(default_client: AsyncClient):
body = {"user_email": "admin@test.com", "organisation_id": 1}
resp = await default_client.post("/user/invitation", json=body)
body = {"user_email": "admin@test.com", "organisation_id": 1}
resp = await default_client.post("/user/invitation", json=body)
assert resp.status_code == 200
data = resp.json()
assert "organisation" in data
assert isinstance(data["organisation"], dict)
assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One"
assert resp.status_code == 200
data = resp.json()
assert "organisation" in data
assert isinstance(data["organisation"], dict)
assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One"
assert "invited_email" in data
assert isinstance(data["invited_email"], str)
assert data["invited_email"] == "admin@test.com"
assert "invited_email" in data
assert isinstance(data["invited_email"], str)
assert data["invited_email"] == "admin@test.com"
@pytest.mark.parametrize(
"body, expected_status",
[
({"organisation_id": 42, "user_email": "admin@test.com"}, 404),
({"organisation_id": "Test Org", "user_email": "admin@test.com"}, 422),
({"organisation_id": "", "user_email": "admin@test.com"}, 422),
({}, 422),
({"user_email": 42}, 422),
({"organisation_id": 1, "user_email": "Test User"}, 422),
],
"body, expected_status",
[
({"organisation_id": 42, "user_email": "admin@test.com"}, 404),
({"organisation_id": "Test Org", "user_email": "admin@test.com"}, 422),
({"organisation_id": "", "user_email": "admin@test.com"}, 422),
({}, 422),
({"user_email": 42}, 422),
({"organisation_id": 1, "user_email": "Test User"}, 422),
],
)
@pytest.mark.anyio
async def test_post_user_invitation_status_checks(
default_client: AsyncClient, body, expected_status
default_client: AsyncClient, body, expected_status
):
resp = await default_client.post("/user/invitation", json=body)
resp = await default_client.post("/user/invitation", json=body)
assert resp.status_code == expected_status
assert resp.status_code == expected_status
@pytest.mark.parametrize(
"body, expected_status",
[
({"jwt": "invalid"}, 401),
({"jwt": ""}, 401),
({"jwt": None}, 422),
({"jwt": 42}, 422),
],
"body, expected_status",
[
({"jwt": "invalid"}, 401),
({"jwt": ""}, 401),
({"jwt": None}, 422),
({"jwt": 42}, 422),
],
)
@pytest.mark.anyio
async def test_post_user_invitation_accept_status_checks(
default_client: AsyncClient, body, expected_status
default_client: AsyncClient, body, expected_status
):
resp = await default_client.post("/user/invitation/accept", json=body)
resp = await default_client.post("/user/invitation/accept", json=body)
assert resp.status_code == expected_status
assert resp.status_code == expected_status
if resp.status_code == 401:
assert resp.json()["detail"] == "Invalid JWS"
if resp.status_code == 401:
assert resp.json()["detail"] == "Invalid JWS"
@pytest.mark.anyio
async def test_get_self_orgs_success(default_client: AsyncClient):
resp = await default_client.get("/user/self/orgs")
assert resp.status_code == 200
resp = await default_client.get("/user/self/orgs")
assert resp.status_code == 200
data = resp.json()
data = resp.json()
assert "organisations" in data
assert isinstance(data["organisations"], list)
assert len(data["organisations"]) > 0
assert "organisations" in data
assert isinstance(data["organisations"], list)
assert len(data["organisations"]) > 0
org = data["organisations"][0]
assert org["organisation_id"] == 1
assert org["name"] == "Org One"
assert org["status"] == "approved"
assert org["root_user_email"] == "admin@test.com"
assert "intake_questionnaire" in org
assert isinstance(org["intake_questionnaire"], dict)
org = data["organisations"][0]
assert org["organisation_id"] == 1
assert org["name"] == "Org One"
assert org["status"] == "approved"
assert org["root_user_email"] == "admin@test.com"
assert "intake_questionnaire" in org
assert isinstance(org["intake_questionnaire"], dict)
assert isinstance(org["billing_contact"], dict)
assert org["billing_contact"]["email"] == "billing@orgone.com"
assert org["billing_contact"]["id"] == 1
assert isinstance(org["billing_contact"], dict)
assert org["billing_contact"]["email"] == "billing@orgone.com"
assert org["billing_contact"]["id"] == 1
assert isinstance(org["owner_contact"], dict)
assert org["owner_contact"]["email"] == "owner@orgone.com"
assert org["owner_contact"]["id"] == 2
assert isinstance(org["owner_contact"], dict)
assert org["owner_contact"]["email"] == "owner@orgone.com"
assert org["owner_contact"]["id"] == 2
assert isinstance(org["security_contact"], dict)
assert org["security_contact"]["email"] == "security@orgone.com"
assert org["security_contact"]["id"] == 3
assert isinstance(org["security_contact"], dict)
assert org["security_contact"]["email"] == "security@orgone.com"
assert org["security_contact"]["id"] == 3
@pytest.mark.anyio
async def test_get_self_orgs_dynamic(default_client: AsyncClient):
method = "GET"
path = "/user/self/orgs"
expected_data = {
"organisations": [
{
"organisation_id": 1,
"name": "Org One",
"status": "approved",
"root_user_email": "admin@test.com",
"owner_contact": {"email": "owner@orgone.com", "id": 2},
"security_contact": {"email": "security@orgone.com", "id": 3},
"billing_contact": {"email": "billing@orgone.com", "id": 1},
"intake_questionnaire": {
"questions": {
"question_one": None,
"question_three": None,
"question_two": "answer two",
},
"metadata": {"version": 0, "submission_date": None},
},
}
]
}
method = "GET"
path = "/user/self/orgs"
expected_data = {
"organisations": [
{
"organisation_id": 1,
"name": "Org One",
"status": "approved",
"root_user_email": "admin@test.com",
"owner_contact": {"email": "owner@orgone.com", "id": 2},
"security_contact": {"email": "security@orgone.com", "id": 3},
"billing_contact": {"email": "billing@orgone.com", "id": 1},
"intake_questionnaire": {
"questions": {
"question_one": None,
"question_three": None,
"question_two": "answer two",
},
"metadata": {"version": 0, "submission_date": None},
},
}
]
}
resp = await default_client.get(path)
resp = await default_client.get(path)
route = next(
route
for route in default_client._transport.app.routes # ty:ignore[unresolved-attribute]
if isinstance(route, APIRoute) and path in route.path and method in route.methods
)
contexts = list(iter_route_contexts(default_client._transport.app.routes)) # ty:ignore[unresolved-attribute]
assert resp.status_code == route.status_code
if route.status_code == 204:
return
route = next(
route.route
for route in contexts
if isinstance(route.route, APIRoute)
and path in route.route.path
and isinstance(route.methods, set)
and method in route.methods
)
expected_response_schema = route.response_model
data = resp.json()
assert resp.status_code == route.status_code
if route.status_code == 204:
return
response_model = expected_response_schema(**data)
assert isinstance(response_model, expected_response_schema)
expected_response_schema = route.response_model
data = resp.json()
expected_response_model = expected_response_schema(**expected_data)
response_model = expected_response_schema(**data)
assert isinstance(response_model, expected_response_schema)
assert response_model == expected_response_model
expected_response_model = expected_response_schema(**expected_data)
assert response_model == expected_response_model

11
uv.lock generated
View file

@ -6,6 +6,9 @@ requires-python = ">=3.12"
exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values.
exclude-newer-span = "P2W"
[options.exclude-newer-package]
fastapi = "2026-06-22T00:00:00Z"
[[package]]
name = "alembic"
version = "1.18.4"
@ -238,7 +241,7 @@ dev = [
requires-dist = [
{ name = "alembic", specifier = ">=1.18.4" },
{ name = "email-validator", specifier = ">=2.3.0" },
{ name = "fastapi", specifier = ">=0.136.3" },
{ name = "fastapi", specifier = ">=0.138.0" },
{ name = "httptools", specifier = ">=0.7.1" },
{ name = "httpx", specifier = ">=0.28.1" },
{ name = "itsdangerous", specifier = ">=2.2.0" },
@ -349,7 +352,7 @@ wheels = [
[[package]]
name = "fastapi"
version = "0.136.3"
version = "0.138.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "annotated-doc" },
@ -358,9 +361,9 @@ dependencies = [
{ name = "typing-extensions" },
{ name = "typing-inspection" },
]
sdist = { url = "https://files.pythonhosted.org/packages/81/2d/ff8d91d7b564d464629a0fd50a4489c97fcb836ac230bf3a7269232a9b1f/fastapi-0.136.3.tar.gz", hash = "sha256:e487fae93ad408e6f47641ee4dfe389864fd7bec92e547ea8498fc13f43e83ab", size = 396410, upload-time = "2026-05-23T18:53:15.192Z" }
sdist = { url = "https://files.pythonhosted.org/packages/5b/58/ff455d9fe47c60abadb34b9e05a304b1f05f5ab8000ac01565156b6f5e43/fastapi-0.138.0.tar.gz", hash = "sha256:d445a4877636ad191e7053e08c9bf98cb921a6756776848400bb773d1740c061", size = 419240, upload-time = "2026-06-20T01:18:05.259Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/e0/82/45359b62a067409bd929ae8a56b8ed13e5a8c8a61194b3c236920999ab83/fastapi-0.136.3-py3-none-any.whl", hash = "sha256:3d2a69bdf04b7e9f3afa292c3bc7a98816bbfafa10bc9b45f3f3700d2f761620", size = 117481, upload-time = "2026-05-23T18:53:16.924Z" },
{ url = "https://files.pythonhosted.org/packages/6c/ff/8496d9847a5fedae775eb49460722d3efaa80487854273e9647ae876218c/fastapi-0.138.0-py3-none-any.whl", hash = "sha256:b6f54fd1bd72c80b0f899f172c61a600f6f7af9b43d4d772a018f35624048cb0", size = 126779, upload-time = "2026-06-20T01:18:03.483Z" },
]
[[package]]