Compare commits
1 commit
| Author | SHA1 | Date | |
|---|---|---|---|
| 2a1d28bc54 |
65 changed files with 3727 additions and 3946 deletions
|
|
@ -34,6 +34,8 @@ jobs:
|
||||||
- run: uv python install # Gets Python version from pyproject.toml
|
- run: uv python install # Gets Python version from pyproject.toml
|
||||||
- run: uv sync --dev
|
- run: uv sync --dev
|
||||||
- run: uv run ty check
|
- run: uv run ty check
|
||||||
|
- run: uv run ruff format
|
||||||
|
- run: uv run pytest test
|
||||||
env:
|
env:
|
||||||
ENVIRONMENT: testing
|
ENVIRONMENT: testing
|
||||||
|
|
||||||
|
|
@ -52,35 +54,3 @@ jobs:
|
||||||
- run: uv run pytest test
|
- run: uv run pytest test
|
||||||
env:
|
env:
|
||||||
ENVIRONMENT: testing
|
ENVIRONMENT: testing
|
||||||
|
|
||||||
build:
|
|
||||||
needs: [ ruff, ty, tests ]
|
|
||||||
if: ${{ always() && needs.ruff.result == 'success' && needs.ty.result == 'success' && needs.tests.result == 'success' }}
|
|
||||||
runs-on: docker
|
|
||||||
container:
|
|
||||||
image: ghcr.io/catthehacker/ubuntu:act-latest
|
|
||||||
options: -v /dind/docker.sock:/var/run/docker.sock
|
|
||||||
steps:
|
|
||||||
- name: Checkout the repo
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
- name: Checkout the frontend
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
repository: sr2/cloud-portal.git
|
|
||||||
path: frontend
|
|
||||||
ref: main
|
|
||||||
- name: Set up Docker Buildx
|
|
||||||
uses: docker/setup-buildx-action@v3
|
|
||||||
- name: Login to the registry
|
|
||||||
uses: docker/login-action@v3
|
|
||||||
with:
|
|
||||||
registry: guardianproject.dev
|
|
||||||
username: irl
|
|
||||||
password: ${{ secrets.PACKAGE_TOKEN }}
|
|
||||||
- name: Build and push
|
|
||||||
uses: docker/build-push-action@v6
|
|
||||||
with:
|
|
||||||
file: /workspace/sr2/cloud-api/Containerfile
|
|
||||||
context: /workspace/sr2/cloud-api/
|
|
||||||
push: true
|
|
||||||
tags: guardianproject.dev/${{ github.repository }}:${{ github.ref_name }}
|
|
||||||
|
|
|
||||||
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -206,7 +206,5 @@ marimo/_static/
|
||||||
marimo/_lsp/
|
marimo/_lsp/
|
||||||
__marimo__/
|
__marimo__/
|
||||||
|
|
||||||
endpoints.txt
|
|
||||||
|
|
||||||
# React Frontend
|
endpoints.txt
|
||||||
/frontend/
|
|
||||||
|
|
@ -1 +1 @@
|
||||||
3.12
|
3.14
|
||||||
|
|
|
||||||
|
|
@ -1,42 +0,0 @@
|
||||||
FROM node:22-slim AS react-builder
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
COPY frontend/ /app/
|
|
||||||
RUN --mount=type=cache,target=/root/.npm npm ci
|
|
||||||
RUN npm run build # Outputs to /app/dist
|
|
||||||
|
|
||||||
FROM ghcr.io/astral-sh/uv:python3.12-trixie-slim AS python-builder
|
|
||||||
|
|
||||||
ENV UV_PYTHON_DOWNLOADS=0
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
# Install dependencies first (layer caching)
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
|
||||||
--mount=type=bind,source=uv.lock,target=uv.lock \
|
|
||||||
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
|
|
||||||
uv sync --locked --no-install-project --no-editable
|
|
||||||
|
|
||||||
# Copy project source and install the project itself
|
|
||||||
COPY ./ /app/
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
|
||||||
uv sync --locked --no-editable
|
|
||||||
|
|
||||||
|
|
||||||
FROM python:3.12-slim-trixie
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
COPY alembic /app/alembic
|
|
||||||
COPY alembic.ini /app
|
|
||||||
COPY src /app/src
|
|
||||||
COPY --from=python-builder /app/.venv /app/.venv
|
|
||||||
COPY --from=react-builder /app/dist /app/static
|
|
||||||
|
|
||||||
# Ensure venv is on PATH
|
|
||||||
ENV PATH="/app/.venv/bin:$PATH" \
|
|
||||||
UV_PYTHON_DOWNLOADS=0
|
|
||||||
|
|
||||||
EXPOSE 8000
|
|
||||||
|
|
||||||
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
|
||||||
|
|
@ -1,32 +0,0 @@
|
||||||
"""fix user activated default
|
|
||||||
|
|
||||||
Revision ID: ae433e1c3b20
|
|
||||||
Revises: 661202797ecd
|
|
||||||
Create Date: 2026-06-22 15:26:57.805129
|
|
||||||
|
|
||||||
"""
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = 'ae433e1c3b20'
|
|
||||||
down_revision: Union[str, Sequence[str], None] = '661202797ecd'
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
"""Upgrade schema."""
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.alter_column('user', 'active', server_default=sa.true())
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
"""Downgrade schema."""
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.alter_column('user', 'active', server_default=sa.false())
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
@ -1,44 +0,0 @@
|
||||||
"""model mixins
|
|
||||||
|
|
||||||
Revision ID: 661202797ecd
|
|
||||||
Revises: 869d48618a1c
|
|
||||||
Create Date: 2026-06-22 13:29:39.689067
|
|
||||||
|
|
||||||
"""
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = '661202797ecd'
|
|
||||||
down_revision: Union[str, Sequence[str], None] = '869d48618a1c'
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
"""Upgrade schema."""
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.add_column('organisation', sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()))
|
|
||||||
op.add_column('organisation', sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()))
|
|
||||||
op.add_column('organisation', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
|
|
||||||
op.add_column('user', sa.Column('active', sa.Boolean(), nullable=False, server_default=sa.false()))
|
|
||||||
op.add_column('user', sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()))
|
|
||||||
op.add_column('user', sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()))
|
|
||||||
op.add_column('user', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
"""Downgrade schema."""
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.drop_column('user', 'deleted_at')
|
|
||||||
op.drop_column('user', 'updated_at')
|
|
||||||
op.drop_column('user', 'created_at')
|
|
||||||
op.drop_column('user', 'active')
|
|
||||||
op.drop_column('organisation', 'deleted_at')
|
|
||||||
op.drop_column('organisation', 'updated_at')
|
|
||||||
op.drop_column('organisation', 'created_at')
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
@ -8,7 +8,7 @@ requires-python = ">=3.12"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"alembic>=1.18.4",
|
"alembic>=1.18.4",
|
||||||
"email-validator>=2.3.0",
|
"email-validator>=2.3.0",
|
||||||
"fastapi>=0.138.0",
|
"fastapi>=0.136.3",
|
||||||
"httptools>=0.7.1",
|
"httptools>=0.7.1",
|
||||||
"httpx>=0.28.1",
|
"httpx>=0.28.1",
|
||||||
"itsdangerous>=2.2.0",
|
"itsdangerous>=2.2.0",
|
||||||
|
|
@ -34,11 +34,11 @@ line-length = 92
|
||||||
|
|
||||||
[tool.ruff.format]
|
[tool.ruff.format]
|
||||||
quote-style = "double"
|
quote-style = "double"
|
||||||
|
indent-style = "tab"
|
||||||
|
|
||||||
[tool.uv]
|
[tool.uv]
|
||||||
add-bounds = "major"
|
add-bounds = "major"
|
||||||
exclude-newer = "P2W"
|
exclude-newer = "P2W"
|
||||||
exclude-newer-package = { "fastapi" = "2026-06-22T00:00:00Z" }
|
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = [
|
||||||
|
|
|
||||||
|
|
@ -1,8 +0,0 @@
|
||||||
{
|
|
||||||
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
|
|
||||||
"extends": [
|
|
||||||
"config:recommended"
|
|
||||||
],
|
|
||||||
"minimumReleaseAge": "7 days",
|
|
||||||
"gitAuthor": "Renovate<noreply@sr2.uk>"
|
|
||||||
}
|
|
||||||
|
|
@ -22,5 +22,5 @@ from fastapi import APIRouter
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
tags=[""],
|
tags=[""],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,6 @@ Exports:
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
tags=["admin"],
|
tags=["admin"],
|
||||||
prefix="/admin",
|
prefix="/admin",
|
||||||
)
|
)
|
||||||
|
|
|
||||||
14
src/api.py
14
src/api.py
|
|
@ -26,15 +26,15 @@ api_router.include_router(iam_router)
|
||||||
|
|
||||||
|
|
||||||
class HealthCheckResponse(CustomBaseModel):
|
class HealthCheckResponse(CustomBaseModel):
|
||||||
status: str
|
status: str
|
||||||
|
|
||||||
|
|
||||||
@api_router.get(
|
@api_router.get(
|
||||||
path="/healthcheck",
|
path="/healthcheck",
|
||||||
status_code=status.HTTP_200_OK,
|
status_code=status.HTTP_200_OK,
|
||||||
response_model=HealthCheckResponse,
|
response_model=HealthCheckResponse,
|
||||||
include_in_schema=False,
|
include_in_schema=False,
|
||||||
)
|
)
|
||||||
def healthcheck():
|
def healthcheck():
|
||||||
"""Simple health check endpoint."""
|
"""Simple health check endpoint."""
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
|
||||||
|
|
@ -9,9 +9,9 @@ from src.config import CustomBaseSettings
|
||||||
|
|
||||||
|
|
||||||
class AuthConfig(CustomBaseSettings):
|
class AuthConfig(CustomBaseSettings):
|
||||||
OIDC_CONFIG: str = ""
|
OIDC_CONFIG: str = ""
|
||||||
OIDC_ISSUER: str = ""
|
OIDC_ISSUER: str = ""
|
||||||
CLIENT_ID: str = ""
|
CLIENT_ID: str = ""
|
||||||
|
|
||||||
|
|
||||||
auth_settings = AuthConfig()
|
auth_settings = AuthConfig()
|
||||||
|
|
|
||||||
|
|
@ -16,92 +16,92 @@ from src.exceptions import ForbiddenException
|
||||||
from src.user.dependencies import user_model_claims_dependency
|
from src.user.dependencies import user_model_claims_dependency
|
||||||
from src.user.models import User
|
from src.user.models import User
|
||||||
from src.organisation.dependencies import (
|
from src.organisation.dependencies import (
|
||||||
org_model_query_dependency,
|
org_model_query_dependency,
|
||||||
org_model_body_dependency,
|
org_model_body_dependency,
|
||||||
)
|
)
|
||||||
from src.organisation.models import Organisation as Org
|
from src.organisation.models import Organisation as Org
|
||||||
|
|
||||||
|
|
||||||
async def org_query_user_claims(
|
async def org_query_user_claims(
|
||||||
org_model: org_model_query_dependency, user_model: user_model_claims_dependency
|
org_model: org_model_query_dependency, user_model: user_model_claims_dependency
|
||||||
):
|
):
|
||||||
if user_model in org_model.user_rel:
|
if user_model in org_model.user_rel:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
raise ForbiddenException()
|
raise ForbiddenException()
|
||||||
|
|
||||||
|
|
||||||
org_query_user_claims_dependency = Annotated[bool, Depends(org_query_user_claims)]
|
org_query_user_claims_dependency = Annotated[bool, Depends(org_query_user_claims)]
|
||||||
|
|
||||||
|
|
||||||
def get_super_admin_list():
|
def get_super_admin_list():
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def empty_su_list():
|
def empty_su_list():
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def testing_su_list():
|
def testing_su_list():
|
||||||
return ["admin@test.com"]
|
return ["admin@test.com"]
|
||||||
|
|
||||||
|
|
||||||
su_list_dependency = Annotated[list[str | None], Depends(get_super_admin_list)]
|
su_list_dependency = Annotated[list[str | None], Depends(get_super_admin_list)]
|
||||||
|
|
||||||
|
|
||||||
async def user_model_super_admin(
|
async def user_model_super_admin(
|
||||||
user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency
|
user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency
|
||||||
):
|
):
|
||||||
if user_model.email in super_admin_emails:
|
if user_model.email in super_admin_emails:
|
||||||
return user_model
|
return user_model
|
||||||
|
|
||||||
raise ForbiddenException(message="Must be super admin")
|
raise ForbiddenException(message="Must be super admin")
|
||||||
|
|
||||||
|
|
||||||
super_admin_dependency = Annotated[User, Depends(user_model_super_admin)]
|
super_admin_dependency = Annotated[User, Depends(user_model_super_admin)]
|
||||||
|
|
||||||
|
|
||||||
async def org_query_root_claims(
|
async def org_query_root_claims(
|
||||||
user_model: user_model_claims_dependency,
|
user_model: user_model_claims_dependency,
|
||||||
org_model: org_model_query_dependency,
|
org_model: org_model_query_dependency,
|
||||||
su_emails: su_list_dependency,
|
su_emails: su_list_dependency,
|
||||||
request: Request,
|
request: Request,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
if await user_model_super_admin(user_model, su_emails):
|
if await user_model_super_admin(user_model, su_emails):
|
||||||
return org_model
|
return org_model
|
||||||
except ForbiddenException:
|
except ForbiddenException:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
await org_status_check(org_model, request)
|
await org_status_check(org_model, request)
|
||||||
|
|
||||||
if org_model.root_user_id == user_model.id:
|
if org_model.root_user_id == user_model.id:
|
||||||
return org_model
|
return org_model
|
||||||
|
|
||||||
raise ForbiddenException(message="Must be the org's root user")
|
raise ForbiddenException(message="Must be the org's root user")
|
||||||
|
|
||||||
|
|
||||||
org_model_root_claim_query_dependency = Annotated[Org, Depends(org_query_root_claims)]
|
org_model_root_claim_query_dependency = Annotated[Org, Depends(org_query_root_claims)]
|
||||||
|
|
||||||
|
|
||||||
async def org_body_root_claims(
|
async def org_body_root_claims(
|
||||||
user_model: user_model_claims_dependency,
|
user_model: user_model_claims_dependency,
|
||||||
org_model: org_model_body_dependency,
|
org_model: org_model_body_dependency,
|
||||||
su_emails: su_list_dependency,
|
su_emails: su_list_dependency,
|
||||||
request: Request,
|
request: Request,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
if await user_model_super_admin(user_model, su_emails):
|
if await user_model_super_admin(user_model, su_emails):
|
||||||
return org_model
|
return org_model
|
||||||
except ForbiddenException:
|
except ForbiddenException:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
await org_status_check(org_model, request)
|
await org_status_check(org_model, request)
|
||||||
|
|
||||||
if org_model.root_user_id == user_model.id:
|
if org_model.root_user_id == user_model.id:
|
||||||
return org_model
|
return org_model
|
||||||
|
|
||||||
raise ForbiddenException(message="Must be the org's root user")
|
raise ForbiddenException(message="Must be the org's root user")
|
||||||
|
|
||||||
|
|
||||||
org_model_root_claim_body_dependency = Annotated[Org, Depends(org_body_root_claims)]
|
org_model_root_claim_body_dependency = Annotated[Org, Depends(org_body_root_claims)]
|
||||||
|
|
|
||||||
|
|
@ -8,5 +8,5 @@ Exports:
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
tags=["auth"],
|
tags=["auth"],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ from src.organisation.exceptions import AwaitingApprovalException
|
||||||
from src.organisation.models import Organisation as Org
|
from src.organisation.models import Organisation as Org
|
||||||
from src.exceptions import UnauthorizedException, ForbiddenException
|
from src.exceptions import UnauthorizedException, ForbiddenException
|
||||||
from src.auth.config import auth_settings
|
from src.auth.config import auth_settings
|
||||||
from src.user.service import add_user
|
from src.user.service import add_user_to_db
|
||||||
from src.database import DbSession
|
from src.database import DbSession
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -31,56 +31,56 @@ oidc_dependency = Annotated[str, Depends(oidc)]
|
||||||
|
|
||||||
|
|
||||||
async def get_dev_user():
|
async def get_dev_user():
|
||||||
return {"db_id": 1, "email": "chris@sr2.uk"}
|
return {"db_id": 1, "email": "chris@sr2.uk"}
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
oidc_auth_string: oidc_dependency, db: DbSession
|
oidc_auth_string: oidc_dependency, db: DbSession
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
config_url = urlopen(auth_settings.OIDC_CONFIG)
|
config_url = urlopen(auth_settings.OIDC_CONFIG)
|
||||||
config = json.loads(config_url.read())
|
config = json.loads(config_url.read())
|
||||||
jwks_uri = config["jwks_uri"]
|
jwks_uri = config["jwks_uri"]
|
||||||
key_response = requests.get(jwks_uri)
|
key_response = requests.get(jwks_uri)
|
||||||
jwk_keys = KeySet.import_key_set(key_response.json())
|
jwk_keys = KeySet.import_key_set(key_response.json())
|
||||||
|
|
||||||
token = jwt.decode(oidc_auth_string.replace("Bearer ", ""), jwk_keys)
|
token = jwt.decode(oidc_auth_string.replace("Bearer ", ""), jwk_keys)
|
||||||
|
|
||||||
claims_requests = jwt.JWTClaimsRegistry(
|
claims_requests = jwt.JWTClaimsRegistry(
|
||||||
exp={"essential": True}, iss={"essential": True, "value": auth_settings.OIDC_ISSUER}
|
exp={"essential": True}, iss={"essential": True, "value": auth_settings.OIDC_ISSUER}
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
claims_requests.validate(token.claims)
|
claims_requests.validate(token.claims)
|
||||||
except ExpiredTokenError:
|
except ExpiredTokenError:
|
||||||
raise UnauthorizedException(message="Token is expired")
|
raise UnauthorizedException(message="Token is expired")
|
||||||
db_id = await add_user(db, token.claims)
|
db_id = await add_user_to_db(db, token.claims)
|
||||||
|
|
||||||
token.claims["db_id"] = db_id
|
token.claims["db_id"] = db_id
|
||||||
|
|
||||||
return token.claims
|
return token.claims
|
||||||
|
|
||||||
|
|
||||||
claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)]
|
claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)]
|
||||||
|
|
||||||
|
|
||||||
async def org_status_check(org_model: Org, request: Request):
|
async def org_status_check(org_model: Org, request: Request):
|
||||||
org_status = OrgStatus(org_model.status)
|
org_status = OrgStatus(org_model.status)
|
||||||
if org_status.is_blocked:
|
if org_status.is_blocked:
|
||||||
raise ForbiddenException("This organisation cannot perform this action.")
|
raise ForbiddenException("This organisation cannot perform this action.")
|
||||||
|
|
||||||
root = "/api/v1"
|
root = "/api/v1"
|
||||||
|
|
||||||
pre_approval_endpoints = [
|
pre_approval_endpoints = [
|
||||||
f"PATCH{root}/org/status",
|
f"PATCH{root}/org/status",
|
||||||
f"PATCH{root}/org/questionnaire",
|
f"PATCH{root}/org/questionnaire",
|
||||||
f"GET{root}/org",
|
f"GET{root}/org",
|
||||||
f"GET{root}/org/contact",
|
f"GET{root}/org/contact",
|
||||||
f"PATCH{root}/org/contact",
|
f"PATCH{root}/org/contact",
|
||||||
f"DELETE{root}/org/self",
|
f"DELETE{root}/org/self",
|
||||||
]
|
]
|
||||||
current_request = f"{request.method}{request.url.path}"
|
current_request = f"{request.method}{request.url.path}"
|
||||||
if (
|
if (
|
||||||
current_request not in pre_approval_endpoints
|
current_request not in pre_approval_endpoints
|
||||||
and org_model.status != OrgStatus.APPROVED
|
and org_model.status != OrgStatus.APPROVED
|
||||||
):
|
):
|
||||||
raise AwaitingApprovalException(org_model.id)
|
raise AwaitingApprovalException(org_model.id)
|
||||||
|
|
|
||||||
|
|
@ -16,31 +16,31 @@ from src.constants import Environment
|
||||||
|
|
||||||
|
|
||||||
class CustomBaseSettings(BaseSettings):
|
class CustomBaseSettings(BaseSettings):
|
||||||
model_config = SettingsConfigDict(
|
model_config = SettingsConfigDict(
|
||||||
env_file=".env", env_file_encoding="utf-8", extra="ignore"
|
env_file=".env", env_file_encoding="utf-8", extra="ignore"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Config(CustomBaseSettings):
|
class Config(CustomBaseSettings):
|
||||||
APP_VERSION: str = "0.1"
|
APP_VERSION: str = "0.1"
|
||||||
ENVIRONMENT: Environment = Environment.PRODUCTION
|
ENVIRONMENT: Environment = Environment.PRODUCTION
|
||||||
SECRET_KEY: SecretStr = SecretStr("")
|
SECRET_KEY: SecretStr = SecretStr("")
|
||||||
DISABLE_AUTH: bool = False
|
DISABLE_AUTH: bool = False
|
||||||
|
|
||||||
CORS_ORIGINS: list[str] = ["*"]
|
CORS_ORIGINS: list[str] = ["*"]
|
||||||
CORS_ORIGINS_REGEX: str | None = None
|
CORS_ORIGINS_REGEX: str | None = None
|
||||||
CORS_HEADERS: list[str] = ["*"]
|
CORS_HEADERS: list[str] = ["*"]
|
||||||
|
|
||||||
DATABASE_NAME: str = "fastapi-exp"
|
DATABASE_NAME: str = "fastapi-exp"
|
||||||
DATABASE_PORT: str = "5432"
|
DATABASE_PORT: str = "5432"
|
||||||
DATABASE_HOSTNAME: str = "localhost"
|
DATABASE_HOSTNAME: str = "localhost"
|
||||||
DATABASE_CREDENTIALS: SecretStr = SecretStr(":")
|
DATABASE_CREDENTIALS: SecretStr = SecretStr(":")
|
||||||
|
|
||||||
DATABASE_POOL_SIZE: int = 16
|
DATABASE_POOL_SIZE: int = 16
|
||||||
DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes
|
DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes
|
||||||
DATABASE_POOL_PRE_PING: bool = True
|
DATABASE_POOL_PRE_PING: bool = True
|
||||||
|
|
||||||
LETTERMINT_API_TOKEN: SecretStr = SecretStr("")
|
LETTERMINT_API_TOKEN: SecretStr = SecretStr("")
|
||||||
|
|
||||||
|
|
||||||
settings = Config()
|
settings = Config()
|
||||||
|
|
@ -51,20 +51,20 @@ DATABASE_HOSTNAME = settings.DATABASE_HOSTNAME
|
||||||
DATABASE_CREDENTIALS = settings.DATABASE_CREDENTIALS.get_secret_value()
|
DATABASE_CREDENTIALS = settings.DATABASE_CREDENTIALS.get_secret_value()
|
||||||
# this will support special chars for credentials
|
# this will support special chars for credentials
|
||||||
_DATABASE_CREDENTIAL_USER, _DATABASE_CREDENTIAL_PASSWORD = str(DATABASE_CREDENTIALS).split(
|
_DATABASE_CREDENTIAL_USER, _DATABASE_CREDENTIAL_PASSWORD = str(DATABASE_CREDENTIALS).split(
|
||||||
":"
|
":"
|
||||||
)
|
)
|
||||||
_QUOTED_DATABASE_PASSWORD = parse.quote_plus(str(_DATABASE_CREDENTIAL_PASSWORD))
|
_QUOTED_DATABASE_PASSWORD = parse.quote_plus(str(_DATABASE_CREDENTIAL_PASSWORD))
|
||||||
|
|
||||||
SQLALCHEMY_DATABASE_URI = SecretStr(
|
SQLALCHEMY_DATABASE_URI = SecretStr(
|
||||||
f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}"
|
f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if settings.ENVIRONMENT == Environment.TESTING:
|
if settings.ENVIRONMENT == Environment.TESTING:
|
||||||
SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:")
|
SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:")
|
||||||
|
|
||||||
app_configs: dict[str, Any] = {"title": "App API"}
|
app_configs: dict[str, Any] = {"title": "App API"}
|
||||||
if settings.ENVIRONMENT.is_deployed:
|
if settings.ENVIRONMENT.is_deployed:
|
||||||
app_configs["root_path"] = f"/v{settings.APP_VERSION}"
|
app_configs["root_path"] = f"/v{settings.APP_VERSION}"
|
||||||
|
|
||||||
if not settings.ENVIRONMENT.is_debug:
|
if not settings.ENVIRONMENT.is_debug:
|
||||||
app_configs["openapi_url"] = None # hide docs
|
app_configs["openapi_url"] = None # hide docs
|
||||||
|
|
|
||||||
|
|
@ -9,29 +9,29 @@ from enum import StrEnum, auto
|
||||||
|
|
||||||
|
|
||||||
class Environment(StrEnum):
|
class Environment(StrEnum):
|
||||||
"""
|
"""
|
||||||
Enumeration of environments.
|
Enumeration of environments.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
LOCAL (str): Application is running locally
|
LOCAL (str): Application is running locally
|
||||||
TESTING (str): Application is running in testing mode
|
TESTING (str): Application is running in testing mode
|
||||||
STAGING (str): Application is running in staging mode (ie not testing)
|
STAGING (str): Application is running in staging mode (ie not testing)
|
||||||
PRODUCTION (str): Application is running in production mode
|
PRODUCTION (str): Application is running in production mode
|
||||||
"""
|
"""
|
||||||
|
|
||||||
LOCAL = auto()
|
LOCAL = auto()
|
||||||
TESTING = auto()
|
TESTING = auto()
|
||||||
STAGING = auto()
|
STAGING = auto()
|
||||||
PRODUCTION = auto()
|
PRODUCTION = auto()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_debug(self):
|
def is_debug(self):
|
||||||
return self in (self.LOCAL, self.STAGING, self.TESTING)
|
return self in (self.LOCAL, self.STAGING, self.TESTING)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_testing(self):
|
def is_testing(self):
|
||||||
return self == self.TESTING
|
return self == self.TESTING
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_deployed(self) -> bool:
|
def is_deployed(self) -> bool:
|
||||||
return self in (self.STAGING, self.PRODUCTION)
|
return self in (self.STAGING, self.PRODUCTION)
|
||||||
|
|
|
||||||
|
|
@ -11,13 +11,13 @@ from fastapi import HTTPException, status
|
||||||
|
|
||||||
|
|
||||||
class ContactNotFoundException(HTTPException):
|
class ContactNotFoundException(HTTPException):
|
||||||
def __init__(self, contact_id: Optional[int] = None) -> None:
|
def __init__(self, contact_id: Optional[int] = None) -> None:
|
||||||
detail = (
|
detail = (
|
||||||
"Contact not found"
|
"Contact not found"
|
||||||
if contact_id is None
|
if contact_id is None
|
||||||
else f"Contact with ID '{contact_id}' was not found."
|
else f"Contact with ID '{contact_id}' was not found."
|
||||||
)
|
)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail=detail,
|
detail=detail,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -6,31 +6,30 @@ Models:
|
||||||
street_address, street_address_line_2, post_office_box_number, address_locality, country_code, address_region, postal_code
|
street_address, street_address_line_2, post_office_box_number, address_locality, country_code, address_region, postal_code
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from src.models import IdMixin
|
|
||||||
|
|
||||||
from sqlalchemy import ForeignKey
|
from sqlalchemy import ForeignKey
|
||||||
from sqlalchemy.orm import mapped_column, Mapped
|
from sqlalchemy.orm import mapped_column, Mapped
|
||||||
|
|
||||||
from src.models import CustomBase
|
from src.models import CustomBase
|
||||||
|
|
||||||
|
|
||||||
class Contact(CustomBase, IdMixin):
|
class Contact(CustomBase):
|
||||||
__tablename__ = "contact"
|
__tablename__ = "contact"
|
||||||
|
|
||||||
email: Mapped[str] = mapped_column(default=None, nullable=True)
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
first_name: Mapped[str] = mapped_column(default=None, nullable=True)
|
email: Mapped[str] = mapped_column(default=None, nullable=True)
|
||||||
last_name: Mapped[str] = mapped_column(default=None, nullable=True)
|
first_name: Mapped[str] = mapped_column(default=None, nullable=True)
|
||||||
phonenumber: Mapped[str] = mapped_column(default=None, nullable=True)
|
last_name: Mapped[str] = mapped_column(default=None, nullable=True)
|
||||||
vat_number: Mapped[str | None] = mapped_column(default=None, nullable=True)
|
phonenumber: Mapped[str] = mapped_column(default=None, nullable=True)
|
||||||
|
vat_number: Mapped[str | None] = mapped_column(default=None, nullable=True)
|
||||||
|
|
||||||
street_address: Mapped[str] = mapped_column(default=None, nullable=True)
|
street_address: Mapped[str] = mapped_column(default=None, nullable=True)
|
||||||
street_address_line_2: Mapped[str] = mapped_column(default=None, nullable=True)
|
street_address_line_2: Mapped[str] = mapped_column(default=None, nullable=True)
|
||||||
post_office_box_number: Mapped[str | None] = mapped_column(default=None, nullable=True)
|
post_office_box_number: Mapped[str | None] = mapped_column(default=None, nullable=True)
|
||||||
locality: Mapped[str] = mapped_column(default=None, nullable=True) # Ie City
|
locality: Mapped[str] = mapped_column(default=None, nullable=True) # Ie City
|
||||||
country_code: Mapped[str] = mapped_column(default=None, nullable=True) # Eg GB
|
country_code: Mapped[str] = mapped_column(default=None, nullable=True) # Eg GB
|
||||||
address_region: Mapped[str | None] = mapped_column(default=None, nullable=True)
|
address_region: Mapped[str | None] = mapped_column(default=None, nullable=True)
|
||||||
postal_code: Mapped[str] = mapped_column(default=None, nullable=True)
|
postal_code: Mapped[str] = mapped_column(default=None, nullable=True)
|
||||||
|
|
||||||
org_id: Mapped[int] = mapped_column(
|
org_id: Mapped[int] = mapped_column(
|
||||||
ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False
|
ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,6 @@ from fastapi import APIRouter
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/contact",
|
prefix="/contact",
|
||||||
tags=["contact"],
|
tags=["contact"],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -14,22 +14,22 @@ from src.schemas import CustomBaseModel
|
||||||
|
|
||||||
|
|
||||||
class ContactAddress(CustomBaseModel):
|
class ContactAddress(CustomBaseModel):
|
||||||
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
||||||
|
|
||||||
post_office_box_number: Optional[str] = None
|
post_office_box_number: Optional[str] = None
|
||||||
street_address: Optional[str] = None
|
street_address: Optional[str] = None
|
||||||
street_address_line_2: Optional[str] = None
|
street_address_line_2: Optional[str] = None
|
||||||
locality: Optional[str] = None
|
locality: Optional[str] = None
|
||||||
address_region: Optional[str] = None
|
address_region: Optional[str] = None
|
||||||
country_code: Optional[str] = None
|
country_code: Optional[str] = None
|
||||||
postal_code: Optional[str] = None
|
postal_code: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class ContactModel(CustomBaseModel):
|
class ContactModel(CustomBaseModel):
|
||||||
email: Optional[EmailStr] = None
|
email: Optional[EmailStr] = None
|
||||||
first_name: Optional[str] = None
|
first_name: Optional[str] = None
|
||||||
last_name: Optional[str] = None
|
last_name: Optional[str] = None
|
||||||
phonenumber: Optional[str] = None
|
phonenumber: Optional[str] = None
|
||||||
vat_number: Optional[str] = None
|
vat_number: Optional[str] = None
|
||||||
|
|
||||||
address: ContactAddress
|
address: ContactAddress
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
"""
|
"""
|
||||||
Database connection and session utilities
|
Database connection and session utilities
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Annotated, Generator
|
from typing import Annotated, Generator
|
||||||
from sqlalchemy import create_engine, StaticPool, Connection
|
from sqlalchemy import create_engine, StaticPool, Connection
|
||||||
|
|
@ -30,7 +29,6 @@ else:
|
||||||
|
|
||||||
sm = sessionmaker(autocommit=False, expire_on_commit=False, bind=engine)
|
sm = sessionmaker(autocommit=False, expire_on_commit=False, bind=engine)
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def get_db_connection() -> Generator[Connection, None, None]:
|
def get_db_connection() -> Generator[Connection, None, None]:
|
||||||
with engine.connect() as connection:
|
with engine.connect() as connection:
|
||||||
|
|
@ -40,15 +38,12 @@ def get_db_connection() -> Generator[Connection, None, None]:
|
||||||
connection.rollback()
|
connection.rollback()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def _get_db_connection() -> Generator[Connection, None]:
|
||||||
def _get_db_connection() -> Generator[Connection, None, None]:
|
|
||||||
with get_db_connection() as connection:
|
with get_db_connection() as connection:
|
||||||
yield connection
|
yield connection
|
||||||
|
|
||||||
|
|
||||||
DbConnection = Annotated[Connection, Depends(_get_db_connection)]
|
DbConnection = Annotated[Connection, Depends(_get_db_connection)]
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def get_db_session() -> Generator[Session, None, None]:
|
def get_db_session() -> Generator[Session, None, None]:
|
||||||
session = sm()
|
session = sm()
|
||||||
|
|
@ -61,9 +56,8 @@ def get_db_session() -> Generator[Session, None, None]:
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
|
|
||||||
def _get_db_session() -> Generator[Session, None, None]:
|
def _get_db_session() -> Generator[Session, None]:
|
||||||
with get_db_session() as session:
|
with get_db_session() as session:
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
|
||||||
DbSession = Annotated[Session, Depends(_get_db_session)]
|
DbSession = Annotated[Session, Depends(_get_db_session)]
|
||||||
|
|
|
||||||
|
|
@ -12,36 +12,36 @@ from fastapi import HTTPException, status
|
||||||
|
|
||||||
|
|
||||||
class UnprocessableContentException(HTTPException):
|
class UnprocessableContentException(HTTPException):
|
||||||
def __init__(self, message: Optional[str] = None) -> None:
|
def __init__(self, message: Optional[str] = None) -> None:
|
||||||
detail = "Unprocessable content" if not message else message
|
detail = "Unprocessable content" if not message else message
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
|
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
|
||||||
detail=detail,
|
detail=detail,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConflictException(HTTPException):
|
class ConflictException(HTTPException):
|
||||||
def __init__(self, message: Optional[str] = None) -> None:
|
def __init__(self, message: Optional[str] = None) -> None:
|
||||||
detail = "Conflict" if not message else message
|
detail = "Conflict" if not message else message
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=status.HTTP_409_CONFLICT,
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
detail=detail,
|
detail=detail,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ForbiddenException(HTTPException):
|
class ForbiddenException(HTTPException):
|
||||||
def __init__(self, message: Optional[str] = None) -> None:
|
def __init__(self, message: Optional[str] = None) -> None:
|
||||||
detail = "Forbidden" if not message else message
|
detail = "Forbidden" if not message else message
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail=detail,
|
detail=detail,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class UnauthorizedException(HTTPException):
|
class UnauthorizedException(HTTPException):
|
||||||
def __init__(self, message: Optional[str] = None) -> None:
|
def __init__(self, message: Optional[str] = None) -> None:
|
||||||
detail = "Not authorized" if not message else message
|
detail = "Not authorized" if not message else message
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail=detail,
|
detail=detail,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -18,55 +18,59 @@ from src.iam.exceptions import GroupNotFoundException, PermNotFoundException
|
||||||
from src.iam.schemas import GroupIDMixin, PermIDMixin
|
from src.iam.schemas import GroupIDMixin, PermIDMixin
|
||||||
|
|
||||||
|
|
||||||
def get_group_model_query(db: DbSession, group_id: Annotated[int, Query(gt=0)]) -> Group:
|
def get_group_model_query(
|
||||||
group_model = db.get(Group, group_id)
|
db: DbSession, group_id: Annotated[int, Query(gt=0)]
|
||||||
if group_model is None:
|
) -> Group:
|
||||||
raise GroupNotFoundException(group_id)
|
group_model = db.get(Group, group_id)
|
||||||
|
if group_model is None:
|
||||||
|
raise GroupNotFoundException(group_id)
|
||||||
|
|
||||||
return group_model
|
return group_model
|
||||||
|
|
||||||
|
|
||||||
group_model_query_dependency = Annotated[Group, Depends(get_group_model_query)]
|
group_model_query_dependency = Annotated[Group, Depends(get_group_model_query)]
|
||||||
|
|
||||||
|
|
||||||
def get_group_model_body(
|
def get_group_model_body(
|
||||||
db: DbSession, request_model: Optional[GroupIDMixin] = None
|
db: DbSession, request_model: Optional[GroupIDMixin] = None
|
||||||
) -> Group:
|
) -> Group:
|
||||||
group_id = getattr(request_model, "group_id", None)
|
group_id = getattr(request_model, "group_id", None)
|
||||||
if group_id is None:
|
if group_id is None:
|
||||||
raise GroupNotFoundException()
|
raise GroupNotFoundException()
|
||||||
group_model = db.get(Group, group_id)
|
group_model = db.get(Group, group_id)
|
||||||
if group_model is None:
|
if group_model is None:
|
||||||
raise GroupNotFoundException(group_id)
|
raise GroupNotFoundException(group_id)
|
||||||
|
|
||||||
return group_model
|
return group_model
|
||||||
|
|
||||||
|
|
||||||
group_model_body_dependency = Annotated[Group, Depends(get_group_model_body)]
|
group_model_body_dependency = Annotated[Group, Depends(get_group_model_body)]
|
||||||
|
|
||||||
|
|
||||||
def get_perm_model_body(
|
def get_perm_model_body(
|
||||||
db: DbSession, request_model: Optional[PermIDMixin] = None
|
db: DbSession, request_model: Optional[PermIDMixin] = None
|
||||||
) -> Permission:
|
) -> Permission:
|
||||||
perm_id = getattr(request_model, "permission_id", None)
|
perm_id = getattr(request_model, "permission_id", None)
|
||||||
if perm_id is None:
|
if perm_id is None:
|
||||||
raise PermNotFoundException
|
raise PermNotFoundException
|
||||||
perm_model = db.get(Permission, perm_id)
|
perm_model = db.get(Permission, perm_id)
|
||||||
if perm_model is None:
|
if perm_model is None:
|
||||||
raise PermNotFoundException(perm_id)
|
raise PermNotFoundException(perm_id)
|
||||||
|
|
||||||
return perm_model
|
return perm_model
|
||||||
|
|
||||||
|
|
||||||
perm_model_body_dependency = Annotated[Permission, Depends(get_perm_model_body)]
|
perm_model_body_dependency = Annotated[Permission, Depends(get_perm_model_body)]
|
||||||
|
|
||||||
|
|
||||||
def get_perm_model_query(db: DbSession, perm_id: Annotated[int, Query(gt=0)]) -> Permission:
|
def get_perm_model_query(
|
||||||
perm_model = db.get(Permission, perm_id)
|
db: DbSession, perm_id: Annotated[int, Query(gt=0)]
|
||||||
if perm_model is None:
|
) -> Permission:
|
||||||
raise PermNotFoundException(perm_id)
|
perm_model = db.get(Permission, perm_id)
|
||||||
|
if perm_model is None:
|
||||||
|
raise PermNotFoundException(perm_id)
|
||||||
|
|
||||||
return perm_model
|
return perm_model
|
||||||
|
|
||||||
|
|
||||||
perm_model_query_dependency = Annotated[Permission, Depends(get_perm_model_query)]
|
perm_model_query_dependency = Annotated[Permission, Depends(get_perm_model_query)]
|
||||||
|
|
|
||||||
|
|
@ -12,26 +12,26 @@ from fastapi import HTTPException, status
|
||||||
|
|
||||||
|
|
||||||
class GroupNotFoundException(HTTPException):
|
class GroupNotFoundException(HTTPException):
|
||||||
def __init__(self, group_id: Optional[int] = None) -> None:
|
def __init__(self, group_id: Optional[int] = None) -> None:
|
||||||
detail = (
|
detail = (
|
||||||
"Group not found"
|
"Group not found"
|
||||||
if group_id is None
|
if group_id is None
|
||||||
else f"User with ID '{group_id}' was not found."
|
else f"User with ID '{group_id}' was not found."
|
||||||
)
|
)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail=detail,
|
detail=detail,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class PermNotFoundException(HTTPException):
|
class PermNotFoundException(HTTPException):
|
||||||
def __init__(self, perm_id: Optional[int] = None) -> None:
|
def __init__(self, perm_id: Optional[int] = None) -> None:
|
||||||
detail = (
|
detail = (
|
||||||
"Permission not found"
|
"Permission not found"
|
||||||
if perm_id is None
|
if perm_id is None
|
||||||
else f"User with ID '{perm_id}' was not found."
|
else f"User with ID '{perm_id}' was not found."
|
||||||
)
|
)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail=detail,
|
detail=detail,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -21,94 +21,95 @@ Models:
|
||||||
from sqlalchemy import ForeignKey, UniqueConstraint
|
from sqlalchemy import ForeignKey, UniqueConstraint
|
||||||
from sqlalchemy.orm import relationship, mapped_column, Mapped
|
from sqlalchemy.orm import relationship, mapped_column, Mapped
|
||||||
|
|
||||||
from src.models import CustomBase, IdMixin
|
from src.models import CustomBase
|
||||||
|
|
||||||
|
|
||||||
class Permission(CustomBase, IdMixin):
|
class Permission(CustomBase):
|
||||||
__tablename__ = "permission"
|
__tablename__ = "permission"
|
||||||
|
|
||||||
resource: Mapped[str]
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
action: Mapped[str]
|
resource: Mapped[str]
|
||||||
|
action: Mapped[str]
|
||||||
|
|
||||||
service_id: Mapped[int] = mapped_column(ForeignKey("service.id", ondelete="CASCADE"))
|
service_id: Mapped[int] = mapped_column(ForeignKey("service.id", ondelete="CASCADE"))
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
UniqueConstraint(
|
UniqueConstraint(
|
||||||
"service_id",
|
"service_id",
|
||||||
"resource",
|
"resource",
|
||||||
"action",
|
"action",
|
||||||
name="uniq_permission_resource_and_action",
|
name="uniq_permission_resource_and_action",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
service_rel = relationship(
|
service_rel = relationship(
|
||||||
"Service",
|
"Service",
|
||||||
back_populates="permission_rel",
|
back_populates="permission_rel",
|
||||||
foreign_keys="Permission.service_id",
|
foreign_keys="Permission.service_id",
|
||||||
)
|
)
|
||||||
|
|
||||||
group_rel = relationship(
|
group_rel = relationship(
|
||||||
"Group", secondary="group_permissions", back_populates="permission_rel"
|
"Group", secondary="group_permissions", back_populates="permission_rel"
|
||||||
)
|
)
|
||||||
|
|
||||||
org_rel = relationship(
|
org_rel = relationship(
|
||||||
"Organisation", secondary="org_permissions", back_populates="permission_rel"
|
"Organisation", secondary="org_permissions", back_populates="permission_rel"
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def service_name(self):
|
def service_name(self):
|
||||||
return self.service_rel.name
|
return self.service_rel.name
|
||||||
|
|
||||||
|
|
||||||
class Group(CustomBase, IdMixin):
|
class Group(CustomBase):
|
||||||
__tablename__ = "group"
|
__tablename__ = "group"
|
||||||
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
|
name: Mapped[str]
|
||||||
|
|
||||||
name: Mapped[str]
|
org_id: Mapped[int] = mapped_column(ForeignKey("organisation.id", ondelete="CASCADE"))
|
||||||
|
|
||||||
org_id: Mapped[int] = mapped_column(ForeignKey("organisation.id", ondelete="CASCADE"))
|
__table_args__ = (
|
||||||
|
UniqueConstraint(
|
||||||
|
"name",
|
||||||
|
"org_id",
|
||||||
|
name="uniq_group_name_org_id",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
__table_args__ = (
|
user_rel = relationship("User", secondary="user_groups", back_populates="group_rel")
|
||||||
UniqueConstraint(
|
|
||||||
"name",
|
|
||||||
"org_id",
|
|
||||||
name="uniq_group_name_org_id",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
user_rel = relationship("User", secondary="user_groups", back_populates="group_rel")
|
org_rel = relationship("Organisation", back_populates="group_rel")
|
||||||
|
|
||||||
org_rel = relationship("Organisation", back_populates="group_rel")
|
permission_rel = relationship(
|
||||||
|
"Permission", secondary="group_permissions", back_populates="group_rel"
|
||||||
permission_rel = relationship(
|
)
|
||||||
"Permission", secondary="group_permissions", back_populates="group_rel"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class GroupPermissions(CustomBase):
|
class GroupPermissions(CustomBase):
|
||||||
__tablename__ = "group_permissions"
|
__tablename__ = "group_permissions"
|
||||||
group_id: Mapped[int] = mapped_column(
|
group_id: Mapped[int] = mapped_column(
|
||||||
ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
|
ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
|
||||||
)
|
)
|
||||||
permission_id: Mapped[int] = mapped_column(
|
permission_id: Mapped[int] = mapped_column(
|
||||||
ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
|
ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class UserGroups(CustomBase):
|
class UserGroups(CustomBase):
|
||||||
__tablename__ = "user_groups"
|
__tablename__ = "user_groups"
|
||||||
user_id: Mapped[int] = mapped_column(
|
user_id: Mapped[int] = mapped_column(
|
||||||
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
|
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
|
||||||
)
|
)
|
||||||
group_id: Mapped[int] = mapped_column(
|
group_id: Mapped[int] = mapped_column(
|
||||||
ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
|
ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class OrgPermissions(CustomBase):
|
class OrgPermissions(CustomBase):
|
||||||
__tablename__ = "org_permissions"
|
__tablename__ = "org_permissions"
|
||||||
org_id: Mapped[int] = mapped_column(
|
org_id: Mapped[int] = mapped_column(
|
||||||
ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True
|
ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True
|
||||||
)
|
)
|
||||||
permission_id: Mapped[int] = mapped_column(
|
permission_id: Mapped[int] = mapped_column(
|
||||||
ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
|
ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
|
||||||
)
|
)
|
||||||
|
|
|
||||||
1050
src/iam/router.py
1050
src/iam/router.py
File diff suppressed because it is too large
Load diff
|
|
@ -11,151 +11,151 @@ from typing import Optional, Annotated
|
||||||
from pydantic import EmailStr, ConfigDict, Field
|
from pydantic import EmailStr, ConfigDict, Field
|
||||||
|
|
||||||
from src.schemas import (
|
from src.schemas import (
|
||||||
CustomBaseModel,
|
CustomBaseModel,
|
||||||
ResourceName,
|
ResourceName,
|
||||||
ServiceIDMixin,
|
ServiceIDMixin,
|
||||||
OrgIDMixin,
|
OrgIDMixin,
|
||||||
UserIDMixin,
|
UserIDMixin,
|
||||||
PermIDMixin,
|
PermIDMixin,
|
||||||
GroupIDMixin,
|
GroupIDMixin,
|
||||||
GroupSummary,
|
GroupSummary,
|
||||||
OrgSummary,
|
OrgSummary,
|
||||||
UserSummary,
|
UserSummary,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class UserSchema(CustomBaseModel):
|
class UserSchema(CustomBaseModel):
|
||||||
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
first_name: str
|
first_name: str
|
||||||
last_name: str
|
last_name: str
|
||||||
email: EmailStr
|
email: EmailStr
|
||||||
|
|
||||||
|
|
||||||
class PermissionSchema(CustomBaseModel):
|
class PermissionSchema(CustomBaseModel):
|
||||||
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
service_name: str
|
service_name: str
|
||||||
resource: str
|
resource: str
|
||||||
action: str
|
action: str
|
||||||
|
|
||||||
|
|
||||||
class GroupDetails(CustomBaseModel):
|
class GroupDetails(CustomBaseModel):
|
||||||
details: GroupSummary
|
details: GroupSummary
|
||||||
permissions: list[PermissionSchema]
|
permissions: list[PermissionSchema]
|
||||||
|
|
||||||
|
|
||||||
class IAMCAoRRequest(CustomBaseModel):
|
class IAMCAoRRequest(CustomBaseModel):
|
||||||
action: str
|
action: str
|
||||||
rn: ResourceName
|
rn: ResourceName
|
||||||
|
|
||||||
|
|
||||||
class IAMCAoRResponse(CustomBaseModel):
|
class IAMCAoRResponse(CustomBaseModel):
|
||||||
allowed: bool
|
allowed: bool
|
||||||
user: UserSummary
|
user: UserSummary
|
||||||
action: str
|
action: str
|
||||||
rn: ResourceName
|
rn: ResourceName
|
||||||
|
|
||||||
|
|
||||||
class IAMGetGroupPermissionsResponse(CustomBaseModel):
|
class IAMGetGroupPermissionsResponse(CustomBaseModel):
|
||||||
organisation: OrgSummary
|
organisation: OrgSummary
|
||||||
group: GroupSummary
|
group: GroupSummary
|
||||||
permissions: list[PermissionSchema]
|
permissions: list[PermissionSchema]
|
||||||
|
|
||||||
|
|
||||||
class IAMGetGroupUsersResponse(CustomBaseModel):
|
class IAMGetGroupUsersResponse(CustomBaseModel):
|
||||||
organisation: OrgSummary
|
organisation: OrgSummary
|
||||||
group: GroupSummary
|
group: GroupSummary
|
||||||
users: list[UserSummary]
|
users: list[UserSummary]
|
||||||
|
|
||||||
|
|
||||||
class IAMPostGroupRequest(OrgIDMixin):
|
class IAMPostGroupRequest(OrgIDMixin):
|
||||||
name: str = Field(min_length=3)
|
name: str = Field(min_length=3)
|
||||||
|
|
||||||
|
|
||||||
class IAMPostGroupResponse(CustomBaseModel):
|
class IAMPostGroupResponse(CustomBaseModel):
|
||||||
organisation: OrgSummary
|
organisation: OrgSummary
|
||||||
group: GroupSummary
|
group: GroupSummary
|
||||||
|
|
||||||
|
|
||||||
class IAMPutGroupPermissionRequest(GroupIDMixin, PermIDMixin, OrgIDMixin):
|
class IAMPutGroupPermissionRequest(GroupIDMixin, PermIDMixin, OrgIDMixin):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class IAMPutGroupPermissionResponse(CustomBaseModel):
|
class IAMPutGroupPermissionResponse(CustomBaseModel):
|
||||||
organisation: OrgSummary
|
organisation: OrgSummary
|
||||||
group: GroupSummary
|
group: GroupSummary
|
||||||
permissions: list[PermissionSchema]
|
permissions: list[PermissionSchema]
|
||||||
|
|
||||||
|
|
||||||
class IAMPutGroupUserRequest(GroupIDMixin, UserIDMixin, OrgIDMixin):
|
class IAMPutGroupUserRequest(GroupIDMixin, UserIDMixin, OrgIDMixin):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class IAMPutGroupUserResponse(CustomBaseModel):
|
class IAMPutGroupUserResponse(CustomBaseModel):
|
||||||
group: GroupSummary
|
group: GroupSummary
|
||||||
users: list[UserSchema]
|
users: list[UserSchema]
|
||||||
|
|
||||||
|
|
||||||
class IAMDeleteGroupPermissionResponse(CustomBaseModel):
|
class IAMDeleteGroupPermissionResponse(CustomBaseModel):
|
||||||
group: GroupSummary
|
group: GroupSummary
|
||||||
permissions: list[PermissionSchema]
|
permissions: list[PermissionSchema]
|
||||||
|
|
||||||
|
|
||||||
class IAMDeleteGroupUserResponse(CustomBaseModel):
|
class IAMDeleteGroupUserResponse(CustomBaseModel):
|
||||||
group: GroupSummary
|
group: GroupSummary
|
||||||
users: list[UserSchema]
|
users: list[UserSchema]
|
||||||
|
|
||||||
|
|
||||||
class IAMGetPermissionsResponse(CustomBaseModel):
|
class IAMGetPermissionsResponse(CustomBaseModel):
|
||||||
permissions: list[PermissionSchema]
|
permissions: list[PermissionSchema]
|
||||||
|
|
||||||
|
|
||||||
class IAMPostPermissionRequest(ServiceIDMixin):
|
class IAMPostPermissionRequest(ServiceIDMixin):
|
||||||
resource: str
|
resource: str
|
||||||
action: str
|
action: str
|
||||||
|
|
||||||
|
|
||||||
class IAMPostPermissionResponse(CustomBaseModel):
|
class IAMPostPermissionResponse(CustomBaseModel):
|
||||||
permission: PermissionSchema
|
permission: PermissionSchema
|
||||||
|
|
||||||
|
|
||||||
class IAMGetPermissionsSearchRequest(OrgIDMixin):
|
class IAMGetPermissionsSearchRequest(OrgIDMixin):
|
||||||
service_id: Annotated[int | None, Field(gt=0)] = None
|
service_id: Annotated[int | None, Field(gt=0)] = None
|
||||||
resource: Optional[str] = None
|
resource: Optional[str] = None
|
||||||
action: Optional[str] = None
|
action: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class IAMGetPermissionsSearchResponse(CustomBaseModel):
|
class IAMGetPermissionsSearchResponse(CustomBaseModel):
|
||||||
permissions: list[PermissionSchema]
|
permissions: list[PermissionSchema]
|
||||||
|
|
||||||
|
|
||||||
class IAMPutGroupInvitationRequest(OrgIDMixin, GroupIDMixin):
|
class IAMPutGroupInvitationRequest(OrgIDMixin, GroupIDMixin):
|
||||||
user_email: EmailStr
|
user_email: EmailStr
|
||||||
|
|
||||||
|
|
||||||
class IAMPutGroupInvitationResponse(CustomBaseModel):
|
class IAMPutGroupInvitationResponse(CustomBaseModel):
|
||||||
organisation: OrgSummary
|
organisation: OrgSummary
|
||||||
group: GroupSummary
|
group: GroupSummary
|
||||||
invited_email: EmailStr
|
invited_email: EmailStr
|
||||||
|
|
||||||
|
|
||||||
class IAMPutGroupInvitationAcceptRequest(CustomBaseModel):
|
class IAMPutGroupInvitationAcceptRequest(CustomBaseModel):
|
||||||
jwt: str
|
jwt: str
|
||||||
|
|
||||||
|
|
||||||
class IAMPutGroupInvitationAcceptResponse(CustomBaseModel):
|
class IAMPutGroupInvitationAcceptResponse(CustomBaseModel):
|
||||||
organisation: OrgSummary
|
organisation: OrgSummary
|
||||||
user: UserSummary
|
user: UserSummary
|
||||||
group: GroupDetails
|
group: GroupDetails
|
||||||
|
|
||||||
|
|
||||||
class IAMPutOrgPermissionsRequest(OrgIDMixin):
|
class IAMPutOrgPermissionsRequest(OrgIDMixin):
|
||||||
permissions: list[int]
|
permissions: list[int]
|
||||||
|
|
||||||
|
|
||||||
class IAMPutOrgPermissionsResponse(CustomBaseModel):
|
class IAMPutOrgPermissionsResponse(CustomBaseModel):
|
||||||
organisation: OrgSummary
|
organisation: OrgSummary
|
||||||
permissions: list[PermissionSchema]
|
permissions: list[PermissionSchema]
|
||||||
|
|
|
||||||
|
|
@ -23,90 +23,90 @@ from src.service.schemas import HasServiceName
|
||||||
|
|
||||||
|
|
||||||
def valid_service_key(
|
def valid_service_key(
|
||||||
db: DbSession, request: Request, request_model: HasServiceName
|
db: DbSession, request: Request, request_model: HasServiceName
|
||||||
) -> bool:
|
) -> bool:
|
||||||
rn = request_model.rn
|
rn = request_model.rn
|
||||||
api_key = request.headers.get("X-API-Key", None)
|
api_key = request.headers.get("X-API-Key", None)
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise UnauthorizedException("Missing API key")
|
raise UnauthorizedException("Missing API key")
|
||||||
service = rn.service
|
service = rn.service
|
||||||
result = (
|
result = (
|
||||||
db.query(Service)
|
db.query(Service)
|
||||||
.filter(Service.name == service)
|
.filter(Service.name == service)
|
||||||
.filter(Service.api_key == api_key)
|
.filter(Service.api_key == api_key)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if result is None:
|
if result is None:
|
||||||
raise UnauthorizedException("Invalid API key")
|
raise UnauthorizedException("Invalid API key")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
service_key_dependency = Annotated[bool, Depends(valid_service_key)]
|
service_key_dependency = Annotated[bool, Depends(valid_service_key)]
|
||||||
|
|
||||||
|
|
||||||
async def send_user_group_invitation(
|
async def send_user_group_invitation(
|
||||||
user_email: str, org_name: str, org_id: int, group_id: int, group_name: str
|
user_email: str, org_name: str, org_id: int, group_id: int, group_name: str
|
||||||
):
|
):
|
||||||
expiry_delta = timedelta(hours=24)
|
expiry_delta = timedelta(hours=24)
|
||||||
expiry = datetime.now(timezone.utc) + expiry_delta
|
expiry = datetime.now(timezone.utc) + expiry_delta
|
||||||
claims = {
|
claims = {
|
||||||
"email": user_email,
|
"email": user_email,
|
||||||
"org_id": org_id,
|
"org_id": org_id,
|
||||||
"group_id": group_id,
|
"group_id": group_id,
|
||||||
"group_name": group_name,
|
"group_name": group_name,
|
||||||
"exp": expiry,
|
"exp": expiry,
|
||||||
"type": "group_invite",
|
"type": "group_invite",
|
||||||
}
|
}
|
||||||
|
|
||||||
token = await generate_jwt(claims)
|
token = await generate_jwt(claims)
|
||||||
subject = f"You have been invited to join a group of {org_name}"
|
subject = f"You have been invited to join a group of {org_name}"
|
||||||
body = f"You have been invited to join {group_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
|
body = f"You have been invited to join {group_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
|
||||||
|
|
||||||
await send_email(
|
await send_email(
|
||||||
recipient=user_email,
|
recipient=user_email,
|
||||||
subject=subject,
|
subject=subject,
|
||||||
body=body,
|
body=body,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def create_group_and_assign_perms(
|
async def create_group_and_assign_perms(
|
||||||
db: Session, org_model: Org, group_name: str, perm_list: list[int]
|
db: Session, org_model: Org, group_name: str, perm_list: list[int]
|
||||||
):
|
):
|
||||||
new_group = Group(name=group_name, org_id=org_model.id)
|
new_group = Group(name=group_name, org_id=org_model.id)
|
||||||
db.add(new_group)
|
db.add(new_group)
|
||||||
db.flush()
|
db.flush()
|
||||||
|
|
||||||
for permission in perm_list:
|
for permission in perm_list:
|
||||||
perm_model = db.get(Perm, permission)
|
perm_model = db.get(Perm, permission)
|
||||||
|
|
||||||
if perm_model is None:
|
if perm_model is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
new_group.permission_rel.append(perm_model)
|
new_group.permission_rel.append(perm_model)
|
||||||
db.flush()
|
db.flush()
|
||||||
|
|
||||||
return new_group
|
return new_group
|
||||||
|
|
||||||
|
|
||||||
async def assign_default_group(
|
async def assign_default_group(
|
||||||
db: DbSession,
|
db: DbSession,
|
||||||
org_model: Org,
|
org_model: Org,
|
||||||
user_model: User,
|
user_model: User,
|
||||||
group_name: str,
|
group_name: str,
|
||||||
perm_list: list[int],
|
perm_list: list[int],
|
||||||
):
|
):
|
||||||
group_model = (
|
group_model = (
|
||||||
db.query(Group)
|
db.query(Group)
|
||||||
.filter(Group.org_id == org_model.id)
|
.filter(Group.org_id == org_model.id)
|
||||||
.filter(Group.name == group_name)
|
.filter(Group.name == group_name)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
if group_model is None:
|
if group_model is None:
|
||||||
group_model = await create_group_and_assign_perms(
|
group_model = await create_group_and_assign_perms(
|
||||||
db=db, group_name=group_name, org_model=org_model, perm_list=perm_list
|
db=db, group_name=group_name, org_model=org_model, perm_list=perm_list
|
||||||
)
|
)
|
||||||
|
|
||||||
user_model.group_rel.append(group_model)
|
user_model.group_rel.append(group_model)
|
||||||
db.flush()
|
db.flush()
|
||||||
|
|
|
||||||
72
src/main.py
72
src/main.py
|
|
@ -2,7 +2,6 @@
|
||||||
Application root file: Inits the FastAPI application
|
Application root file: Inits the FastAPI application
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os.path
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
|
@ -20,43 +19,43 @@ from src.auth.service import get_current_user, get_dev_user
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(_application: FastAPI) -> AsyncGenerator:
|
async def lifespan(_application: FastAPI) -> AsyncGenerator:
|
||||||
# Startup
|
# Startup
|
||||||
yield
|
yield
|
||||||
# Shutdown
|
# Shutdown
|
||||||
|
|
||||||
|
|
||||||
if settings.ENVIRONMENT.is_deployed:
|
if settings.ENVIRONMENT.is_deployed:
|
||||||
# Just a precaution, should be False anyway
|
# Just a precaution, should be False anyway
|
||||||
settings.DISABLE_AUTH = False
|
settings.DISABLE_AUTH = False
|
||||||
|
|
||||||
|
|
||||||
tags_metadata = [
|
tags_metadata = [
|
||||||
{
|
{
|
||||||
"name": "User",
|
"name": "User",
|
||||||
"description": "User related operations, includes getting information about the current user",
|
"description": "User related operations, includes getting information about the current user",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Organisation",
|
"name": "Organisation",
|
||||||
"description": "Organisation related operations, includes getting lists of users etc associated with orgs",
|
"description": "Organisation related operations, includes getting lists of users etc associated with orgs",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Service",
|
"name": "Service",
|
||||||
"description": "Services related operations, includes registering services and reissuing API keys",
|
"description": "Services related operations, includes registering services and reissuing API keys",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "IAM",
|
"name": "IAM",
|
||||||
"description": "Operations related to the role based identity and access management system. This includes management of groups, permissions, and related users.",
|
"description": "Operations related to the role based identity and access management system. This includes management of groups, permissions, and related users.",
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
swagger_ui_init_oauth={
|
swagger_ui_init_oauth={
|
||||||
"clientId": auth_settings.CLIENT_ID,
|
"clientId": auth_settings.CLIENT_ID,
|
||||||
"usePkceWithAuthorizationCodeGrant": True,
|
"usePkceWithAuthorizationCodeGrant": True,
|
||||||
"scopes": "openid profile email",
|
"scopes": "openid profile email",
|
||||||
},
|
},
|
||||||
openapi_tags=tags_metadata,
|
openapi_tags=tags_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Type inspection disabled for middleware injection.
|
# Type inspection disabled for middleware injection.
|
||||||
|
|
@ -65,19 +64,16 @@ app = FastAPI(
|
||||||
app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY.get_secret_value())
|
app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY.get_secret_value())
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=settings.CORS_ORIGINS,
|
allow_origins=settings.CORS_ORIGINS,
|
||||||
allow_origin_regex=settings.CORS_ORIGINS_REGEX,
|
allow_origin_regex=settings.CORS_ORIGINS_REGEX,
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"),
|
allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"),
|
||||||
allow_headers=settings.CORS_HEADERS,
|
allow_headers=settings.CORS_HEADERS,
|
||||||
)
|
)
|
||||||
|
|
||||||
if settings.DISABLE_AUTH and (settings.ENVIRONMENT == Environment.LOCAL):
|
if settings.DISABLE_AUTH and (settings.ENVIRONMENT == Environment.LOCAL):
|
||||||
app.dependency_overrides[get_current_user] = get_dev_user
|
app.dependency_overrides[get_current_user] = get_dev_user
|
||||||
|
|
||||||
|
|
||||||
app.include_router(api_router)
|
app.include_router(api_router)
|
||||||
|
|
||||||
if os.path.exists("/app/static"):
|
|
||||||
app.frontend("/ui", directory="/app/static", fallback="index.html")
|
|
||||||
|
|
|
||||||
|
|
@ -5,33 +5,12 @@ Global database models
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import DateTime, JSON, func
|
from sqlalchemy import DateTime, JSON
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
|
||||||
class CustomBase(DeclarativeBase):
|
class CustomBase(DeclarativeBase):
|
||||||
type_annotation_map = {
|
type_annotation_map = {
|
||||||
datetime: DateTime(timezone=True),
|
datetime: DateTime(timezone=True),
|
||||||
dict[str, Any]: JSON,
|
dict[str, Any]: JSON,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ActivatedMixin:
|
|
||||||
active: Mapped[bool] = mapped_column(default=True)
|
|
||||||
|
|
||||||
|
|
||||||
class DeletedTimestampMixin:
|
|
||||||
deleted_at: Mapped[datetime | None] = mapped_column(nullable=True)
|
|
||||||
|
|
||||||
|
|
||||||
class DescriptionMixin:
|
|
||||||
description: Mapped[str]
|
|
||||||
|
|
||||||
|
|
||||||
class IdMixin:
|
|
||||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
|
||||||
|
|
||||||
|
|
||||||
class TimestampMixin:
|
|
||||||
created_at: Mapped[datetime] = mapped_column(default=func.now())
|
|
||||||
updated_at: Mapped[datetime] = mapped_column(default=func.now(), onupdate=func.now())
|
|
||||||
|
|
|
||||||
|
|
@ -10,48 +10,48 @@ from enum import StrEnum, auto
|
||||||
|
|
||||||
|
|
||||||
class Status(StrEnum):
|
class Status(StrEnum):
|
||||||
"""
|
"""
|
||||||
Enumeration of organisation statuses.
|
Enumeration of organisation statuses.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
PARTIAL(str): Organisation has been created but questionnaire hasn't been submitted.
|
PARTIAL(str): Organisation has been created but questionnaire hasn't been submitted.
|
||||||
SUBMITTED (str): Questionnaire submitted but not approved.
|
SUBMITTED (str): Questionnaire submitted but not approved.
|
||||||
REMEDIATION (str): Questionnaire submitted but requires revisions.
|
REMEDIATION (str): Questionnaire submitted but requires revisions.
|
||||||
APPROVED (str): Questionnaire has been approved by an admin.
|
APPROVED (str): Questionnaire has been approved by an admin.
|
||||||
REJECTED (str): Questionnaire has been rejected by an admin.
|
REJECTED (str): Questionnaire has been rejected by an admin.
|
||||||
REMOVED (str): Organisation has been removed.
|
REMOVED (str): Organisation has been removed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
PARTIAL = auto()
|
PARTIAL = auto()
|
||||||
SUBMITTED = auto()
|
SUBMITTED = auto()
|
||||||
REMEDIATION = auto()
|
REMEDIATION = auto()
|
||||||
APPROVED = auto()
|
APPROVED = auto()
|
||||||
REJECTED = auto()
|
REJECTED = auto()
|
||||||
REMOVED = auto()
|
REMOVED = auto()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_pre_approval(self):
|
def is_pre_approval(self):
|
||||||
return self in (self.PARTIAL, self.SUBMITTED, self.REMEDIATION)
|
return self in (self.PARTIAL, self.SUBMITTED, self.REMEDIATION)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_pre_submission(self):
|
def is_pre_submission(self):
|
||||||
return self in (self.PARTIAL, self.REMEDIATION)
|
return self in (self.PARTIAL, self.REMEDIATION)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_blocked(self):
|
def is_blocked(self):
|
||||||
return self in (self.REMOVED, self.REJECTED)
|
return self in (self.REMOVED, self.REJECTED)
|
||||||
|
|
||||||
|
|
||||||
class ContactType(StrEnum):
|
class ContactType(StrEnum):
|
||||||
"""
|
"""
|
||||||
Enumeration of organisation contact types.
|
Enumeration of organisation contact types.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
BILLING(str): Billing contact.
|
BILLING(str): Billing contact.
|
||||||
SECURITY (str): Security contact.
|
SECURITY (str): Security contact.
|
||||||
OWNER (str): Owner contact.
|
OWNER (str): Owner contact.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
BILLING = auto()
|
BILLING = auto()
|
||||||
SECURITY = auto()
|
SECURITY = auto()
|
||||||
OWNER = auto()
|
OWNER = auto()
|
||||||
|
|
|
||||||
|
|
@ -18,25 +18,25 @@ from src.organisation.exceptions import OrgNotFoundException
|
||||||
|
|
||||||
|
|
||||||
def get_org_model_query(db: DbSession, org_id: Annotated[int, Query(gt=0)]) -> Org:
|
def get_org_model_query(db: DbSession, org_id: Annotated[int, Query(gt=0)]) -> Org:
|
||||||
org_model = db.get(Org, org_id)
|
org_model = db.get(Org, org_id)
|
||||||
if org_model is None:
|
if org_model is None:
|
||||||
raise OrgNotFoundException(org_id)
|
raise OrgNotFoundException(org_id)
|
||||||
return org_model
|
return org_model
|
||||||
|
|
||||||
|
|
||||||
org_model_query_dependency = Annotated[Org, Depends(get_org_model_query)]
|
org_model_query_dependency = Annotated[Org, Depends(get_org_model_query)]
|
||||||
|
|
||||||
|
|
||||||
def get_org_model_body(db: DbSession, request_model: OrgIDMixin) -> Org:
|
def get_org_model_body(db: DbSession, request_model: OrgIDMixin) -> Org:
|
||||||
org_id: Optional[int] = getattr(request_model, "organisation_id", None)
|
org_id: Optional[int] = getattr(request_model, "organisation_id", None)
|
||||||
if org_id is None:
|
if org_id is None:
|
||||||
raise OrgNotFoundException()
|
raise OrgNotFoundException()
|
||||||
|
|
||||||
org_model = db.get(Org, org_id)
|
org_model = db.get(Org, org_id)
|
||||||
if org_model is None:
|
if org_model is None:
|
||||||
raise OrgNotFoundException(org_id)
|
raise OrgNotFoundException(org_id)
|
||||||
|
|
||||||
return org_model
|
return org_model
|
||||||
|
|
||||||
|
|
||||||
org_model_body_dependency = Annotated[Org, Depends(get_org_model_body)]
|
org_model_body_dependency = Annotated[Org, Depends(get_org_model_body)]
|
||||||
|
|
|
||||||
|
|
@ -12,26 +12,26 @@ from fastapi import HTTPException, status
|
||||||
|
|
||||||
|
|
||||||
class OrgNotFoundException(HTTPException):
|
class OrgNotFoundException(HTTPException):
|
||||||
def __init__(self, org_id: Optional[int] = None) -> None:
|
def __init__(self, org_id: Optional[int] = None) -> None:
|
||||||
detail = (
|
detail = (
|
||||||
"Organisation not found"
|
"Organisation not found"
|
||||||
if org_id is None
|
if org_id is None
|
||||||
else f"Organisation with ID '{org_id}' was not found."
|
else f"Organisation with ID '{org_id}' was not found."
|
||||||
)
|
)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail=detail,
|
detail=detail,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AwaitingApprovalException(HTTPException):
|
class AwaitingApprovalException(HTTPException):
|
||||||
def __init__(self, org_id: Optional[int] = None) -> None:
|
def __init__(self, org_id: Optional[int] = None) -> None:
|
||||||
detail = (
|
detail = (
|
||||||
"Organisation has not been approved."
|
"Organisation has not been approved."
|
||||||
if org_id is None
|
if org_id is None
|
||||||
else f"Organisation with ID '{org_id}' has not been approved."
|
else f"Organisation with ID '{org_id}' has not been approved."
|
||||||
)
|
)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail=detail,
|
detail=detail,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -14,62 +14,61 @@ Models:
|
||||||
- OrgUsers: org_id[FK][PK], user_id[FK][PK]
|
- OrgUsers: org_id[FK][PK], user_id[FK][PK]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from src.models import IdMixin, DeletedTimestampMixin
|
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import ForeignKey
|
from sqlalchemy import ForeignKey
|
||||||
from sqlalchemy.orm import relationship, Mapped, mapped_column
|
from sqlalchemy.orm import relationship, Mapped, mapped_column
|
||||||
|
|
||||||
from src.models import CustomBase, TimestampMixin
|
from src.models import CustomBase
|
||||||
|
|
||||||
|
|
||||||
class Organisation(CustomBase, IdMixin, TimestampMixin, DeletedTimestampMixin):
|
class Organisation(CustomBase):
|
||||||
__tablename__ = "organisation"
|
__tablename__ = "organisation"
|
||||||
|
|
||||||
name: Mapped[str]
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
status: Mapped[str] = mapped_column(default="partial")
|
name: Mapped[str]
|
||||||
intake_questionnaire: Mapped[dict[str, Any] | None]
|
status: Mapped[str] = mapped_column(default="partial")
|
||||||
|
intake_questionnaire: Mapped[dict[str, Any] | None]
|
||||||
|
|
||||||
root_user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
|
root_user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
|
||||||
billing_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True)
|
billing_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True)
|
||||||
security_contact_id: Mapped[int] = mapped_column(
|
security_contact_id: Mapped[int] = mapped_column(
|
||||||
ForeignKey("contact.id"), nullable=True
|
ForeignKey("contact.id"), nullable=True
|
||||||
)
|
)
|
||||||
owner_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True)
|
owner_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True)
|
||||||
|
|
||||||
user_rel = relationship("User", secondary="orgusers", back_populates="organisation_rel")
|
user_rel = relationship("User", secondary="orgusers", back_populates="organisation_rel")
|
||||||
|
|
||||||
group_rel = relationship(
|
group_rel = relationship(
|
||||||
"Group", back_populates="org_rel", cascade="all, delete-orphan"
|
"Group", back_populates="org_rel", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
root_user_rel = relationship("User", foreign_keys="Organisation.root_user_id")
|
root_user_rel = relationship("User", foreign_keys="Organisation.root_user_id")
|
||||||
|
|
||||||
billing_contact_rel = relationship(
|
billing_contact_rel = relationship(
|
||||||
"Contact", foreign_keys="Organisation.billing_contact_id"
|
"Contact", foreign_keys="Organisation.billing_contact_id"
|
||||||
)
|
)
|
||||||
security_contact_rel = relationship(
|
security_contact_rel = relationship(
|
||||||
"Contact", foreign_keys="Organisation.security_contact_id"
|
"Contact", foreign_keys="Organisation.security_contact_id"
|
||||||
)
|
)
|
||||||
owner_contact_rel = relationship(
|
owner_contact_rel = relationship(
|
||||||
"Contact", foreign_keys="Organisation.owner_contact_id"
|
"Contact", foreign_keys="Organisation.owner_contact_id"
|
||||||
)
|
)
|
||||||
|
|
||||||
permission_rel = relationship(
|
permission_rel = relationship(
|
||||||
"Permission", secondary="org_permissions", back_populates="org_rel"
|
"Permission", secondary="org_permissions", back_populates="org_rel"
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def root_user_email(self) -> str:
|
def root_user_email(self) -> str:
|
||||||
return self.root_user_rel.email if self.root_user_rel else ""
|
return self.root_user_rel.email if self.root_user_rel else ""
|
||||||
|
|
||||||
|
|
||||||
class OrgUsers(CustomBase):
|
class OrgUsers(CustomBase):
|
||||||
__tablename__ = "orgusers"
|
__tablename__ = "orgusers"
|
||||||
|
|
||||||
org_id: Mapped[int] = mapped_column(
|
org_id: Mapped[int] = mapped_column(
|
||||||
ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True
|
ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True
|
||||||
)
|
)
|
||||||
user_id: Mapped[int] = mapped_column(
|
user_id: Mapped[int] = mapped_column(
|
||||||
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
|
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
|
||||||
)
|
)
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -12,139 +12,139 @@ from datetime import datetime
|
||||||
from pydantic import EmailStr, ConfigDict, Field
|
from pydantic import EmailStr, ConfigDict, Field
|
||||||
|
|
||||||
from src.schemas import (
|
from src.schemas import (
|
||||||
CustomBaseModel,
|
CustomBaseModel,
|
||||||
OrgIDMixin,
|
OrgIDMixin,
|
||||||
UserIDMixin,
|
UserIDMixin,
|
||||||
GroupSummary,
|
GroupSummary,
|
||||||
OrgSummary,
|
OrgSummary,
|
||||||
UserSummary,
|
UserSummary,
|
||||||
)
|
)
|
||||||
from src.contact.schemas import ContactModel
|
from src.contact.schemas import ContactModel
|
||||||
|
|
||||||
from src.organisation.constants import Status, ContactType
|
from src.organisation.constants import Status, ContactType
|
||||||
from src.organisation.schemas_questionnaires import (
|
from src.organisation.schemas_questionnaires import (
|
||||||
QuestionnaireQuestionsVersion0 as CurrentQuestions,
|
QuestionnaireQuestionsVersion0 as CurrentQuestions,
|
||||||
questionnaire_union,
|
questionnaire_union,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class QuestionnaireMetadata(CustomBaseModel):
|
class QuestionnaireMetadata(CustomBaseModel):
|
||||||
version: int
|
version: int
|
||||||
submission_date: Optional[datetime] = None
|
submission_date: Optional[datetime] = None
|
||||||
|
|
||||||
|
|
||||||
class Questionnaire(CustomBaseModel):
|
class Questionnaire(CustomBaseModel):
|
||||||
metadata: QuestionnaireMetadata
|
metadata: QuestionnaireMetadata
|
||||||
questions: questionnaire_union
|
questions: questionnaire_union
|
||||||
|
|
||||||
|
|
||||||
class ContactSummary(CustomBaseModel):
|
class ContactSummary(CustomBaseModel):
|
||||||
id: int
|
id: int
|
||||||
email: Optional[EmailStr] = None
|
email: Optional[EmailStr] = None
|
||||||
|
|
||||||
|
|
||||||
class OrgSchema(OrgIDMixin):
|
class OrgSchema(OrgIDMixin):
|
||||||
name: str
|
name: str
|
||||||
status: Status
|
status: Status
|
||||||
root_user_email: EmailStr
|
root_user_email: EmailStr
|
||||||
intake_questionnaire: Optional[Questionnaire] = None
|
intake_questionnaire: Optional[Questionnaire] = None
|
||||||
|
|
||||||
billing_contact: ContactSummary
|
billing_contact: ContactSummary
|
||||||
owner_contact: ContactSummary
|
owner_contact: ContactSummary
|
||||||
security_contact: ContactSummary
|
security_contact: ContactSummary
|
||||||
|
|
||||||
|
|
||||||
class OrgPostOrgRequest(CustomBaseModel):
|
class OrgPostOrgRequest(CustomBaseModel):
|
||||||
name: str = Field(min_length=3)
|
name: str = Field(min_length=3)
|
||||||
intake_questionnaire: Optional[CurrentQuestions] = None
|
intake_questionnaire: Optional[CurrentQuestions] = None
|
||||||
|
|
||||||
|
|
||||||
class OrgPostOrgResponse(CustomBaseModel):
|
class OrgPostOrgResponse(CustomBaseModel):
|
||||||
id: int
|
id: int
|
||||||
name: str
|
name: str
|
||||||
status: Status
|
status: Status
|
||||||
|
|
||||||
|
|
||||||
class OrgPatchQuestionnaireRequest(OrgIDMixin):
|
class OrgPatchQuestionnaireRequest(OrgIDMixin):
|
||||||
intake_questionnaire: CurrentQuestions
|
intake_questionnaire: CurrentQuestions
|
||||||
partial: bool
|
partial: bool
|
||||||
|
|
||||||
|
|
||||||
class OrgPatchQuestionnaireResponse(CustomBaseModel):
|
class OrgPatchQuestionnaireResponse(CustomBaseModel):
|
||||||
id: int
|
id: int
|
||||||
name: str
|
name: str
|
||||||
intake_questionnaire: Questionnaire
|
intake_questionnaire: Questionnaire
|
||||||
status: Status
|
status: Status
|
||||||
|
|
||||||
|
|
||||||
class OrgPatchStatusRequest(OrgIDMixin):
|
class OrgPatchStatusRequest(OrgIDMixin):
|
||||||
status: Status
|
status: Status
|
||||||
|
|
||||||
|
|
||||||
class OrgPatchStatusResponse(CustomBaseModel):
|
class OrgPatchStatusResponse(CustomBaseModel):
|
||||||
id: int
|
id: int
|
||||||
name: str
|
name: str
|
||||||
status: Status
|
status: Status
|
||||||
|
|
||||||
|
|
||||||
class OrgPatchContactRequest(OrgIDMixin):
|
class OrgPatchContactRequest(OrgIDMixin):
|
||||||
contact_type: ContactType
|
contact_type: ContactType
|
||||||
|
|
||||||
email: Optional[EmailStr] = None
|
email: Optional[EmailStr] = None
|
||||||
first_name: Optional[str] = None
|
first_name: Optional[str] = None
|
||||||
last_name: Optional[str] = None
|
last_name: Optional[str] = None
|
||||||
phonenumber: Optional[str] = None
|
phonenumber: Optional[str] = None
|
||||||
vat_number: Optional[str] = None
|
vat_number: Optional[str] = None
|
||||||
post_office_box_number: Optional[str] = None
|
post_office_box_number: Optional[str] = None
|
||||||
street_address: Optional[str] = None
|
street_address: Optional[str] = None
|
||||||
street_address_line_2: Optional[str] = None
|
street_address_line_2: Optional[str] = None
|
||||||
locality: Optional[str] = None
|
locality: Optional[str] = None
|
||||||
address_region: Optional[str] = None
|
address_region: Optional[str] = None
|
||||||
country_code: Optional[str] = None
|
country_code: Optional[str] = None
|
||||||
postal_code: Optional[str] = None
|
postal_code: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class OrgPostUserRequest(OrgIDMixin, UserIDMixin):
|
class OrgPostUserRequest(OrgIDMixin, UserIDMixin):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class OrgPostUserResponse(CustomBaseModel):
|
class OrgPostUserResponse(CustomBaseModel):
|
||||||
organisation: OrgSummary
|
organisation: OrgSummary
|
||||||
users: list[UserSummary]
|
users: list[UserSummary]
|
||||||
|
|
||||||
|
|
||||||
class OrgPatchRootRequest(OrgIDMixin, UserIDMixin):
|
class OrgPatchRootRequest(OrgIDMixin, UserIDMixin):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class OrgPatchRootResponse(CustomBaseModel):
|
class OrgPatchRootResponse(CustomBaseModel):
|
||||||
name: str
|
name: str
|
||||||
root_user_email: str
|
root_user_email: str
|
||||||
|
|
||||||
|
|
||||||
class OrgGetUserResponse(CustomBaseModel):
|
class OrgGetUserResponse(CustomBaseModel):
|
||||||
users: list[dict[str, str | int]]
|
users: list[dict[str, str | int]]
|
||||||
organisation: OrgSummary
|
organisation: OrgSummary
|
||||||
|
|
||||||
|
|
||||||
class OrgGetGroupResponse(CustomBaseModel):
|
class OrgGetGroupResponse(CustomBaseModel):
|
||||||
organisation: OrgSummary
|
organisation: OrgSummary
|
||||||
groups: list[GroupSummary]
|
groups: list[GroupSummary]
|
||||||
|
|
||||||
|
|
||||||
class OrgGetContactResponse(CustomBaseModel):
|
class OrgGetContactResponse(CustomBaseModel):
|
||||||
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
||||||
|
|
||||||
contact: ContactModel
|
contact: ContactModel
|
||||||
organisation: OrgSummary
|
organisation: OrgSummary
|
||||||
|
|
||||||
|
|
||||||
class OrgPatchContactResponse(CustomBaseModel):
|
class OrgPatchContactResponse(CustomBaseModel):
|
||||||
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
||||||
|
|
||||||
contact: ContactModel
|
contact: ContactModel
|
||||||
organisation: OrgSummary
|
organisation: OrgSummary
|
||||||
|
|
||||||
|
|
||||||
class OrgGetOrgResponse(CustomBaseModel):
|
class OrgGetOrgResponse(CustomBaseModel):
|
||||||
organisations: list[OrgSchema]
|
organisations: list[OrgSchema]
|
||||||
|
|
|
||||||
|
|
@ -4,13 +4,13 @@ from src.schemas import CustomBaseModel
|
||||||
|
|
||||||
|
|
||||||
class QuestionnaireQuestions(CustomBaseModel):
|
class QuestionnaireQuestions(CustomBaseModel):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class QuestionnaireQuestionsVersion0(QuestionnaireQuestions):
|
class QuestionnaireQuestionsVersion0(QuestionnaireQuestions):
|
||||||
question_one: Optional[str] = None
|
question_one: Optional[str] = None
|
||||||
question_two: Optional[str] = None
|
question_two: Optional[str] = None
|
||||||
question_three: Optional[str] = None
|
question_three: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
questionnaire_union = QuestionnaireQuestionsVersion0 # | QuestionnaireQuestionsVersion1
|
questionnaire_union = QuestionnaireQuestionsVersion0 # | QuestionnaireQuestionsVersion1
|
||||||
|
|
|
||||||
|
|
@ -11,57 +11,57 @@ from src.user.models import User
|
||||||
|
|
||||||
|
|
||||||
async def add_default_org_permissions(
|
async def add_default_org_permissions(
|
||||||
db: Session,
|
db: Session,
|
||||||
org_model: Org,
|
org_model: Org,
|
||||||
perm_list: list[int],
|
perm_list: list[int],
|
||||||
):
|
):
|
||||||
for permission in perm_list:
|
for permission in perm_list:
|
||||||
perm_model = db.get(Perm, permission)
|
perm_model = db.get(Perm, permission)
|
||||||
|
|
||||||
if perm_model is None:
|
if perm_model is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if perm_model in org_model.permission_rel:
|
if perm_model in org_model.permission_rel:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
org_model.permission_rel.append(perm_model)
|
org_model.permission_rel.append(perm_model)
|
||||||
db.flush()
|
db.flush()
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
|
||||||
async def assign_defaults(
|
async def assign_defaults(
|
||||||
db: Session,
|
db: Session,
|
||||||
org_id: int,
|
org_id: int,
|
||||||
user_id: int,
|
user_id: int,
|
||||||
):
|
):
|
||||||
default_org_permissions = []
|
default_org_permissions = []
|
||||||
|
|
||||||
default_user_permissions = []
|
default_user_permissions = []
|
||||||
|
|
||||||
org_model = db.get(Org, org_id)
|
org_model = db.get(Org, org_id)
|
||||||
if org_model is None:
|
if org_model is None:
|
||||||
print("Org not found while adding defaults")
|
print("Org not found while adding defaults")
|
||||||
return
|
return
|
||||||
|
|
||||||
user_model = db.get(User, user_id)
|
user_model = db.get(User, user_id)
|
||||||
if user_model is None:
|
if user_model is None:
|
||||||
print("User not found while adding defaults")
|
print("User not found while adding defaults")
|
||||||
return
|
return
|
||||||
|
|
||||||
await add_default_org_permissions(db, org_model, default_org_permissions)
|
await add_default_org_permissions(db, org_model, default_org_permissions)
|
||||||
await assign_default_group(
|
await assign_default_group(
|
||||||
db=db,
|
db=db,
|
||||||
org_model=org_model,
|
org_model=org_model,
|
||||||
user_model=user_model,
|
user_model=user_model,
|
||||||
group_name="Default Users",
|
group_name="Default Users",
|
||||||
perm_list=default_user_permissions,
|
perm_list=default_user_permissions,
|
||||||
)
|
)
|
||||||
await assign_default_group(
|
await assign_default_group(
|
||||||
db=db,
|
db=db,
|
||||||
org_model=org_model,
|
org_model=org_model,
|
||||||
user_model=user_model,
|
user_model=user_model,
|
||||||
group_name="Root User",
|
group_name="Root User",
|
||||||
perm_list=default_org_permissions,
|
perm_list=default_org_permissions,
|
||||||
)
|
)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
|
||||||
|
|
@ -11,54 +11,54 @@ from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class CustomBaseModel(BaseModel):
|
class CustomBaseModel(BaseModel):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
### Mixins ###
|
### Mixins ###
|
||||||
class OrgIDMixin(CustomBaseModel):
|
class OrgIDMixin(CustomBaseModel):
|
||||||
organisation_id: int = Field(gt=0)
|
organisation_id: int = Field(gt=0)
|
||||||
|
|
||||||
|
|
||||||
class GroupIDMixin(CustomBaseModel):
|
class GroupIDMixin(CustomBaseModel):
|
||||||
group_id: int = Field(gt=0)
|
group_id: int = Field(gt=0)
|
||||||
|
|
||||||
|
|
||||||
class PermIDMixin(CustomBaseModel):
|
class PermIDMixin(CustomBaseModel):
|
||||||
permission_id: int = Field(gt=0)
|
permission_id: int = Field(gt=0)
|
||||||
|
|
||||||
|
|
||||||
class ServiceIDMixin(CustomBaseModel):
|
class ServiceIDMixin(CustomBaseModel):
|
||||||
service_id: int = Field(gt=0)
|
service_id: int = Field(gt=0)
|
||||||
|
|
||||||
|
|
||||||
class UserIDMixin(CustomBaseModel):
|
class UserIDMixin(CustomBaseModel):
|
||||||
user_id: int = Field(gt=0)
|
user_id: int = Field(gt=0)
|
||||||
|
|
||||||
|
|
||||||
class ServiceNameMixin(CustomBaseModel):
|
class ServiceNameMixin(CustomBaseModel):
|
||||||
service: str
|
service: str
|
||||||
|
|
||||||
|
|
||||||
class OrgSummary(CustomBaseModel):
|
class OrgSummary(CustomBaseModel):
|
||||||
id: int
|
id: int
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
class GroupSummary(CustomBaseModel):
|
class GroupSummary(CustomBaseModel):
|
||||||
id: int
|
id: int
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
class UserSummary(CustomBaseModel):
|
class UserSummary(CustomBaseModel):
|
||||||
id: int
|
id: int
|
||||||
email: str
|
email: str
|
||||||
|
|
||||||
|
|
||||||
class ServiceSummary(CustomBaseModel):
|
class ServiceSummary(CustomBaseModel):
|
||||||
id: int
|
id: int
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
class ResourceName(ServiceNameMixin, OrgIDMixin):
|
class ResourceName(ServiceNameMixin, OrgIDMixin):
|
||||||
resource: str
|
resource: str
|
||||||
instance: Optional[str] = None
|
instance: Optional[str] = None
|
||||||
|
|
|
||||||
|
|
@ -16,23 +16,25 @@ from src.service.models import Service
|
||||||
from src.service.schemas import ServiceIDMixin
|
from src.service.schemas import ServiceIDMixin
|
||||||
|
|
||||||
|
|
||||||
async def get_service_model_query(db: DbSession, service_id: Annotated[int, Query(gt=0)]):
|
async def get_service_model_query(
|
||||||
service_model = db.get(Service, service_id)
|
db: DbSession, service_id: Annotated[int, Query(gt=0)]
|
||||||
if service_model is None:
|
):
|
||||||
raise ServiceNotFoundException(service_id=service_id)
|
service_model = db.get(Service, service_id)
|
||||||
|
if service_model is None:
|
||||||
|
raise ServiceNotFoundException(service_id=service_id)
|
||||||
|
|
||||||
return service_model
|
return service_model
|
||||||
|
|
||||||
|
|
||||||
service_model_query_dependency = Annotated[Service, Depends(get_service_model_query)]
|
service_model_query_dependency = Annotated[Service, Depends(get_service_model_query)]
|
||||||
|
|
||||||
|
|
||||||
async def get_service_model_body(db: DbSession, request_model: ServiceIDMixin):
|
async def get_service_model_body(db: DbSession, request_model: ServiceIDMixin):
|
||||||
service_model = db.get(Service, request_model.service_id)
|
service_model = db.get(Service, request_model.service_id)
|
||||||
if service_model is None:
|
if service_model is None:
|
||||||
raise ServiceNotFoundException(service_id=request_model.service_id)
|
raise ServiceNotFoundException(service_id=request_model.service_id)
|
||||||
|
|
||||||
return service_model
|
return service_model
|
||||||
|
|
||||||
|
|
||||||
service_model_body_dependency = Annotated[Service, Depends(get_service_model_body)]
|
service_model_body_dependency = Annotated[Service, Depends(get_service_model_body)]
|
||||||
|
|
|
||||||
|
|
@ -11,13 +11,13 @@ from fastapi import HTTPException, status
|
||||||
|
|
||||||
|
|
||||||
class ServiceNotFoundException(HTTPException):
|
class ServiceNotFoundException(HTTPException):
|
||||||
def __init__(self, service_id: Optional[int] = None) -> None:
|
def __init__(self, service_id: Optional[int] = None) -> None:
|
||||||
detail = (
|
detail = (
|
||||||
"Service not found"
|
"Service not found"
|
||||||
if service_id is None
|
if service_id is None
|
||||||
else f"Service with ID '{service_id}' was not found."
|
else f"Service with ID '{service_id}' was not found."
|
||||||
)
|
)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail=detail,
|
detail=detail,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -8,15 +8,16 @@ Models:
|
||||||
|
|
||||||
from sqlalchemy.orm import relationship, mapped_column, Mapped
|
from sqlalchemy.orm import relationship, mapped_column, Mapped
|
||||||
|
|
||||||
from src.models import CustomBase, IdMixin
|
from src.models import CustomBase
|
||||||
|
|
||||||
|
|
||||||
class Service(CustomBase, IdMixin):
|
class Service(CustomBase):
|
||||||
__tablename__ = "service"
|
__tablename__ = "service"
|
||||||
|
|
||||||
name: Mapped[str] = mapped_column(unique=True)
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
api_key: Mapped[str]
|
name: Mapped[str] = mapped_column(unique=True)
|
||||||
|
api_key: Mapped[str]
|
||||||
|
|
||||||
permission_rel = relationship(
|
permission_rel = relationship(
|
||||||
"Permission", back_populates="service_rel", cascade="all, delete-orphan"
|
"Permission", back_populates="service_rel", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,8 @@ from psycopg.errors import UniqueViolation
|
||||||
from src.exceptions import ConflictException
|
from src.exceptions import ConflictException
|
||||||
from src.database import DbSession
|
from src.database import DbSession
|
||||||
from src.auth.dependencies import (
|
from src.auth.dependencies import (
|
||||||
super_admin_dependency,
|
super_admin_dependency,
|
||||||
org_model_root_claim_query_dependency,
|
org_model_root_claim_query_dependency,
|
||||||
)
|
)
|
||||||
from src.iam.service import service_key_dependency
|
from src.iam.service import service_key_dependency
|
||||||
from src.iam.models import Permission as Perm
|
from src.iam.models import Permission as Perm
|
||||||
|
|
@ -25,210 +25,212 @@ from src.service.exceptions import ServiceNotFoundException
|
||||||
from src.service.models import Service
|
from src.service.models import Service
|
||||||
from src.service.utils import generate_api_key
|
from src.service.utils import generate_api_key
|
||||||
from src.service.dependencies import (
|
from src.service.dependencies import (
|
||||||
service_model_body_dependency,
|
service_model_body_dependency,
|
||||||
service_model_query_dependency,
|
service_model_query_dependency,
|
||||||
)
|
)
|
||||||
from src.service.schemas import (
|
from src.service.schemas import (
|
||||||
ServiceGetServiceResponse,
|
ServiceGetServiceResponse,
|
||||||
ServicePostServiceRequest,
|
ServicePostServiceRequest,
|
||||||
ServicePostServiceResponse,
|
ServicePostServiceResponse,
|
||||||
ServiceWithKeySchema,
|
ServiceWithKeySchema,
|
||||||
ServicePatchKeyResponse,
|
ServicePatchKeyResponse,
|
||||||
ServicePatchKeyRequest,
|
ServicePatchKeyRequest,
|
||||||
ServicePostPermissionsResponse,
|
ServicePostPermissionsResponse,
|
||||||
ServicePostPermissionsRequest,
|
ServicePostPermissionsRequest,
|
||||||
)
|
)
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
tags=["Service"],
|
tags=["Service"],
|
||||||
prefix="/service",
|
prefix="/service",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"",
|
"",
|
||||||
summary="Get all services",
|
summary="Get all services",
|
||||||
status_code=status.HTTP_200_OK,
|
status_code=status.HTTP_200_OK,
|
||||||
response_model=ServiceGetServiceResponse,
|
response_model=ServiceGetServiceResponse,
|
||||||
responses={
|
responses={
|
||||||
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
||||||
status.HTTP_401_UNAUTHORIZED: {
|
status.HTTP_401_UNAUTHORIZED: {
|
||||||
"description": "Unauthorized",
|
"description": "Unauthorized",
|
||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"examples": {
|
"examples": {
|
||||||
"awaiting_approval": {
|
"awaiting_approval": {
|
||||||
"summary": "Organisation has not yet been approved."
|
"summary": "Organisation has not yet been approved."
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
status.HTTP_403_FORBIDDEN: {
|
status.HTTP_403_FORBIDDEN: {
|
||||||
"description": "Forbidden",
|
"description": "Forbidden",
|
||||||
"content": {
|
"content": {
|
||||||
"application/json": {
|
"application/json": {
|
||||||
"examples": {
|
"examples": {
|
||||||
"not_root": {"summary": "Not authorised. Must be root user."},
|
"not_root": {"summary": "Not authorised. Must be root user."},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def get_all_services(db: DbSession, org_model: org_model_root_claim_query_dependency):
|
async def get_all_services(
|
||||||
"""
|
db: DbSession, org_model: org_model_root_claim_query_dependency
|
||||||
Returns the ID and name of all services registered to the hub.
|
):
|
||||||
"""
|
"""
|
||||||
permission_models = db.query(Service).all()
|
Returns the ID and name of all services registered to the hub.
|
||||||
|
"""
|
||||||
|
permission_models = db.query(Service).all()
|
||||||
|
|
||||||
return {"services": permission_models}
|
return {"services": permission_models}
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"",
|
"",
|
||||||
summary="Register a new service.",
|
summary="Register a new service.",
|
||||||
status_code=status.HTTP_200_OK,
|
status_code=status.HTTP_200_OK,
|
||||||
response_model=ServicePostServiceResponse,
|
response_model=ServicePostServiceResponse,
|
||||||
responses={
|
responses={
|
||||||
status.HTTP_200_OK: {"description": "Successfully registered a new service"},
|
status.HTTP_200_OK: {"description": "Successfully registered a new service"},
|
||||||
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
|
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
|
||||||
status.HTTP_409_CONFLICT: {"description": "Service with this name already exists"},
|
status.HTTP_409_CONFLICT: {"description": "Service with this name already exists"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def register_service(
|
async def register_service(
|
||||||
db: DbSession,
|
db: DbSession,
|
||||||
su: super_admin_dependency,
|
su: super_admin_dependency,
|
||||||
request_model: ServicePostServiceRequest,
|
request_model: ServicePostServiceRequest,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Registers a new service to the hub, generating and returning an API key for it.
|
Registers a new service to the hub, generating and returning an API key for it.
|
||||||
"""
|
"""
|
||||||
key = generate_api_key()
|
key = generate_api_key()
|
||||||
service_model = Service(name=request_model.name, api_key=key)
|
service_model = Service(name=request_model.name, api_key=key)
|
||||||
|
|
||||||
db.add(service_model)
|
db.add(service_model)
|
||||||
try:
|
try:
|
||||||
db.flush()
|
db.flush()
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
if (
|
if (
|
||||||
isinstance(e.orig, UniqueViolation) # Postgres unique violation
|
isinstance(e.orig, UniqueViolation) # Postgres unique violation
|
||||||
or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation
|
or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation
|
||||||
):
|
):
|
||||||
raise ConflictException(message="Service with this name already exists")
|
raise ConflictException(message="Service with this name already exists")
|
||||||
raise
|
raise
|
||||||
response = ServiceWithKeySchema(**service_model.__dict__)
|
response = ServiceWithKeySchema(**service_model.__dict__)
|
||||||
db.commit()
|
db.commit()
|
||||||
return {"service": response}
|
return {"service": response}
|
||||||
|
|
||||||
|
|
||||||
@router.patch(
|
@router.patch(
|
||||||
"/key",
|
"/key",
|
||||||
summary="Regenerate service API key.",
|
summary="Regenerate service API key.",
|
||||||
status_code=status.HTTP_200_OK,
|
status_code=status.HTTP_200_OK,
|
||||||
response_model=ServicePatchKeyResponse,
|
response_model=ServicePatchKeyResponse,
|
||||||
responses={
|
responses={
|
||||||
status.HTTP_200_OK: {"description": "Successful update of API key"},
|
status.HTTP_200_OK: {"description": "Successful update of API key"},
|
||||||
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
|
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def regenerate_api_key(
|
async def regenerate_api_key(
|
||||||
db: DbSession,
|
db: DbSession,
|
||||||
su: super_admin_dependency,
|
su: super_admin_dependency,
|
||||||
service_model: service_model_body_dependency,
|
service_model: service_model_body_dependency,
|
||||||
request_model: ServicePatchKeyRequest,
|
request_model: ServicePatchKeyRequest,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Generates and returns a new API key for the service to access the hub.
|
Generates and returns a new API key for the service to access the hub.
|
||||||
"""
|
"""
|
||||||
key = generate_api_key()
|
key = generate_api_key()
|
||||||
service_model.api_key = key
|
service_model.api_key = key
|
||||||
|
|
||||||
db.flush()
|
db.flush()
|
||||||
response = ServiceWithKeySchema(**service_model.__dict__)
|
response = ServiceWithKeySchema(**service_model.__dict__)
|
||||||
db.commit()
|
db.commit()
|
||||||
return {"service": response}
|
return {"service": response}
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
"",
|
"",
|
||||||
summary="Remove a service.",
|
summary="Remove a service.",
|
||||||
status_code=status.HTTP_204_NO_CONTENT,
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
responses={
|
responses={
|
||||||
status.HTTP_204_NO_CONTENT: {"description": "Successfully removed service from db"},
|
status.HTTP_204_NO_CONTENT: {"description": "Successfully removed service from db"},
|
||||||
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
|
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def remove_service(
|
async def remove_service(
|
||||||
db: DbSession,
|
db: DbSession,
|
||||||
service_model: service_model_query_dependency,
|
service_model: service_model_query_dependency,
|
||||||
su: super_admin_dependency,
|
su: super_admin_dependency,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Removes a service from the hub.
|
Removes a service from the hub.
|
||||||
"""
|
"""
|
||||||
db.delete(service_model)
|
db.delete(service_model)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
path="/permissions",
|
path="/permissions",
|
||||||
summary="Service endpoint for creating its own permissions.",
|
summary="Service endpoint for creating its own permissions.",
|
||||||
status_code=status.HTTP_200_OK,
|
status_code=status.HTTP_200_OK,
|
||||||
response_model=ServicePostPermissionsResponse,
|
response_model=ServicePostPermissionsResponse,
|
||||||
responses={
|
responses={
|
||||||
status.HTTP_401_UNAUTHORIZED: {
|
status.HTTP_401_UNAUTHORIZED: {
|
||||||
"description": "API Key missing or invalid | Issue verifying user OIDC claims"
|
"description": "API Key missing or invalid | Issue verifying user OIDC claims"
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def service_create_new_permissions(
|
async def service_create_new_permissions(
|
||||||
db: DbSession,
|
db: DbSession,
|
||||||
request_model: ServicePostPermissionsRequest,
|
request_model: ServicePostPermissionsRequest,
|
||||||
valid_key: service_key_dependency,
|
valid_key: service_key_dependency,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Allows a service to register its own set of permissions.
|
Allows a service to register its own set of permissions.
|
||||||
"""
|
"""
|
||||||
service_model = (
|
service_model = (
|
||||||
db.query(Service).filter(Service.name == request_model.rn.service).first()
|
db.query(Service).filter(Service.name == request_model.rn.service).first()
|
||||||
)
|
)
|
||||||
if service_model is None:
|
if service_model is None:
|
||||||
raise ServiceNotFoundException()
|
raise ServiceNotFoundException()
|
||||||
else:
|
else:
|
||||||
service_id = service_model.id
|
service_id = service_model.id
|
||||||
response_list = []
|
response_list = []
|
||||||
for new_permission in request_model.permissions:
|
for new_permission in request_model.permissions:
|
||||||
perm_model = (
|
perm_model = (
|
||||||
db.query(Perm)
|
db.query(Perm)
|
||||||
.filter(Perm.service_id == service_id)
|
.filter(Perm.service_id == service_id)
|
||||||
.filter(Perm.resource == new_permission.resource)
|
.filter(Perm.resource == new_permission.resource)
|
||||||
.filter(Perm.action == new_permission.action)
|
.filter(Perm.action == new_permission.action)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if perm_model is not None:
|
if perm_model is not None:
|
||||||
response_code = 409
|
response_code = 409
|
||||||
response = {
|
response = {
|
||||||
"id": perm_model.id,
|
"id": perm_model.id,
|
||||||
"service_name": perm_model.service_name,
|
"service_name": perm_model.service_name,
|
||||||
"resource": perm_model.resource,
|
"resource": perm_model.resource,
|
||||||
"action": perm_model.action,
|
"action": perm_model.action,
|
||||||
}
|
}
|
||||||
response_list.append((response, response_code))
|
response_list.append((response, response_code))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
new_perm_model = Perm(**new_permission.__dict__)
|
new_perm_model = Perm(**new_permission.__dict__)
|
||||||
new_perm_model.service_id = service_id
|
new_perm_model.service_id = service_id
|
||||||
db.add(new_perm_model)
|
db.add(new_perm_model)
|
||||||
db.flush()
|
db.flush()
|
||||||
response_code = 201
|
response_code = 201
|
||||||
response = {
|
response = {
|
||||||
"id": new_perm_model.id,
|
"id": new_perm_model.id,
|
||||||
"service_name": new_perm_model.service_name,
|
"service_name": new_perm_model.service_name,
|
||||||
"resource": new_perm_model.resource,
|
"resource": new_perm_model.resource,
|
||||||
"action": new_perm_model.action,
|
"action": new_perm_model.action,
|
||||||
}
|
}
|
||||||
response_list.append((response, response_code))
|
response_list.append((response, response_code))
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
return {"permissions": response_list}
|
return {"permissions": response_list}
|
||||||
|
|
|
||||||
|
|
@ -10,10 +10,10 @@ from typing import Generic, TypeVar
|
||||||
from pydantic import Field, ConfigDict
|
from pydantic import Field, ConfigDict
|
||||||
|
|
||||||
from src.schemas import (
|
from src.schemas import (
|
||||||
CustomBaseModel,
|
CustomBaseModel,
|
||||||
ServiceIDMixin,
|
ServiceIDMixin,
|
||||||
ServiceSummary,
|
ServiceSummary,
|
||||||
ServiceNameMixin,
|
ServiceNameMixin,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -21,51 +21,51 @@ T = TypeVar("T", bound=ServiceNameMixin)
|
||||||
|
|
||||||
|
|
||||||
class HasServiceName(CustomBaseModel, Generic[T]):
|
class HasServiceName(CustomBaseModel, Generic[T]):
|
||||||
rn: T
|
rn: T
|
||||||
|
|
||||||
|
|
||||||
class PermissionResponseSchema(CustomBaseModel):
|
class PermissionResponseSchema(CustomBaseModel):
|
||||||
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
model_config = ConfigDict(from_attributes=True, extra="ignore")
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
service_name: str
|
service_name: str
|
||||||
resource: str
|
resource: str
|
||||||
action: str
|
action: str
|
||||||
|
|
||||||
|
|
||||||
class PermissionRequestSchema(CustomBaseModel):
|
class PermissionRequestSchema(CustomBaseModel):
|
||||||
resource: str
|
resource: str
|
||||||
action: str
|
action: str
|
||||||
|
|
||||||
|
|
||||||
class ServiceWithKeySchema(ServiceSummary):
|
class ServiceWithKeySchema(ServiceSummary):
|
||||||
api_key: str
|
api_key: str
|
||||||
|
|
||||||
|
|
||||||
class ServiceGetServiceResponse(CustomBaseModel):
|
class ServiceGetServiceResponse(CustomBaseModel):
|
||||||
services: list[ServiceSummary]
|
services: list[ServiceSummary]
|
||||||
|
|
||||||
|
|
||||||
class ServicePostServiceRequest(CustomBaseModel):
|
class ServicePostServiceRequest(CustomBaseModel):
|
||||||
name: str = Field(min_length=3)
|
name: str = Field(min_length=3)
|
||||||
|
|
||||||
|
|
||||||
class ServicePostServiceResponse(CustomBaseModel):
|
class ServicePostServiceResponse(CustomBaseModel):
|
||||||
service: ServiceWithKeySchema
|
service: ServiceWithKeySchema
|
||||||
|
|
||||||
|
|
||||||
class ServicePatchKeyRequest(ServiceIDMixin):
|
class ServicePatchKeyRequest(ServiceIDMixin):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ServicePatchKeyResponse(CustomBaseModel):
|
class ServicePatchKeyResponse(CustomBaseModel):
|
||||||
service: ServiceWithKeySchema
|
service: ServiceWithKeySchema
|
||||||
|
|
||||||
|
|
||||||
class ServicePostPermissionsRequest(CustomBaseModel):
|
class ServicePostPermissionsRequest(CustomBaseModel):
|
||||||
rn: ServiceNameMixin
|
rn: ServiceNameMixin
|
||||||
permissions: list[PermissionRequestSchema]
|
permissions: list[PermissionRequestSchema]
|
||||||
|
|
||||||
|
|
||||||
class ServicePostPermissionsResponse(CustomBaseModel):
|
class ServicePostPermissionsResponse(CustomBaseModel):
|
||||||
permissions: list[tuple[PermissionResponseSchema, int]]
|
permissions: list[tuple[PermissionResponseSchema, int]]
|
||||||
|
|
|
||||||
|
|
@ -9,4 +9,4 @@ import uuid
|
||||||
|
|
||||||
|
|
||||||
def generate_api_key() -> str:
|
def generate_api_key() -> str:
|
||||||
return str(uuid.uuid4())
|
return str(uuid.uuid4())
|
||||||
|
|
|
||||||
|
|
@ -10,50 +10,46 @@ Exports:
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from fastapi import Depends, Query
|
from fastapi import Depends, Query
|
||||||
|
|
||||||
from src.auth.service import claims_dependency
|
|
||||||
from src.database import DbSession
|
|
||||||
from src.schemas import UserIDMixin
|
|
||||||
from src.exceptions import ForbiddenException
|
|
||||||
|
|
||||||
from src.user.exceptions import UserNotFoundException
|
from src.user.exceptions import UserNotFoundException
|
||||||
from src.user.models import User
|
from src.user.models import User
|
||||||
|
|
||||||
|
from src.auth.service import claims_dependency
|
||||||
|
from src.database import DbSession
|
||||||
|
from src.schemas import UserIDMixin
|
||||||
|
|
||||||
|
|
||||||
async def get_user_model_claims(claims: claims_dependency, db: DbSession):
|
async def get_user_model_claims(claims: claims_dependency, db: DbSession):
|
||||||
user_id = claims.get("db_id", None)
|
user_id = claims.get("db_id", None)
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
raise UserNotFoundException()
|
raise UserNotFoundException()
|
||||||
|
|
||||||
user_model = db.get(User, user_id)
|
user_model = db.get(User, user_id)
|
||||||
if user_model is None:
|
if user_model is None:
|
||||||
raise UserNotFoundException(user_id=user_id)
|
raise UserNotFoundException(user_id=user_id)
|
||||||
|
|
||||||
if not user_model.active:
|
return user_model
|
||||||
raise ForbiddenException("User account is not active")
|
|
||||||
|
|
||||||
return user_model
|
|
||||||
|
|
||||||
|
|
||||||
user_model_claims_dependency = Annotated[User, Depends(get_user_model_claims)]
|
user_model_claims_dependency = Annotated[User, Depends(get_user_model_claims)]
|
||||||
|
|
||||||
|
|
||||||
async def get_user_model_query(db: DbSession, user_id: Annotated[int, Query(gt=0)]):
|
async def get_user_model_query(db: DbSession, user_id: Annotated[int, Query(gt=0)]):
|
||||||
user_model = db.get(User, user_id)
|
user_model = db.get(User, user_id)
|
||||||
if user_model is None:
|
if user_model is None:
|
||||||
raise UserNotFoundException(user_id=user_id)
|
raise UserNotFoundException(user_id=user_id)
|
||||||
|
|
||||||
return user_model
|
return user_model
|
||||||
|
|
||||||
|
|
||||||
user_model_query_dependency = Annotated[User, Depends(get_user_model_query)]
|
user_model_query_dependency = Annotated[User, Depends(get_user_model_query)]
|
||||||
|
|
||||||
|
|
||||||
async def get_user_model_body(db: DbSession, request_model: UserIDMixin):
|
async def get_user_model_body(db: DbSession, request_model: UserIDMixin):
|
||||||
user_model = db.get(User, request_model.user_id)
|
user_model = db.get(User, request_model.user_id)
|
||||||
if user_model is None:
|
if user_model is None:
|
||||||
raise UserNotFoundException(user_id=request_model.user_id)
|
raise UserNotFoundException(user_id=request_model.user_id)
|
||||||
|
|
||||||
return user_model
|
return user_model
|
||||||
|
|
||||||
|
|
||||||
user_model_body_dependency = Annotated[User, Depends(get_user_model_body)]
|
user_model_body_dependency = Annotated[User, Depends(get_user_model_body)]
|
||||||
|
|
|
||||||
|
|
@ -11,13 +11,13 @@ from fastapi import HTTPException, status
|
||||||
|
|
||||||
|
|
||||||
class UserNotFoundException(HTTPException):
|
class UserNotFoundException(HTTPException):
|
||||||
def __init__(self, user_id: Optional[int] = None) -> None:
|
def __init__(self, user_id: Optional[int] = None) -> None:
|
||||||
detail = (
|
detail = (
|
||||||
"User not found"
|
"User not found"
|
||||||
if user_id is None
|
if user_id is None
|
||||||
else f"User with ID '{user_id}' was not found."
|
else f"User with ID '{user_id}' was not found."
|
||||||
)
|
)
|
||||||
super().__init__(
|
super().__init__(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail=detail,
|
detail=detail,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -10,8 +10,6 @@ Models:
|
||||||
- groups: Calc property dict of {group_rel.org_rel.name: group_rel.name}
|
- groups: Calc property dict of {group_rel.org_rel.name: group_rel.name}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from src.models import IdMixin, ActivatedMixin, TimestampMixin, DeletedTimestampMixin
|
|
||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
from sqlalchemy.orm import relationship, mapped_column, Mapped
|
from sqlalchemy.orm import relationship, mapped_column, Mapped
|
||||||
|
|
@ -19,27 +17,28 @@ from sqlalchemy.orm import relationship, mapped_column, Mapped
|
||||||
from src.models import CustomBase
|
from src.models import CustomBase
|
||||||
|
|
||||||
|
|
||||||
class User(CustomBase, IdMixin, ActivatedMixin, TimestampMixin, DeletedTimestampMixin):
|
class User(CustomBase):
|
||||||
__tablename__ = "user"
|
__tablename__ = "user"
|
||||||
|
|
||||||
email: Mapped[str]
|
id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
first_name: Mapped[str]
|
email: Mapped[str]
|
||||||
last_name: Mapped[str]
|
first_name: Mapped[str]
|
||||||
oidc_id: Mapped[str] = mapped_column(index=True, unique=True)
|
last_name: Mapped[str]
|
||||||
|
oidc_id: Mapped[str] = mapped_column(index=True, unique=True)
|
||||||
|
|
||||||
organisation_rel = relationship(
|
organisation_rel = relationship(
|
||||||
"Organisation", secondary="orgusers", back_populates="user_rel"
|
"Organisation", secondary="orgusers", back_populates="user_rel"
|
||||||
)
|
)
|
||||||
|
|
||||||
group_rel = relationship("Group", secondary="user_groups", back_populates="user_rel")
|
group_rel = relationship("Group", secondary="user_groups", back_populates="user_rel")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def organisations(self):
|
def organisations(self):
|
||||||
return [{"name": org.name, "id": org.id} for org in self.organisation_rel]
|
return [{"name": org.name, "id": org.id} for org in self.organisation_rel]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def groups(self):
|
def groups(self):
|
||||||
result = defaultdict(list)
|
result = defaultdict(list)
|
||||||
for group in self.group_rel:
|
for group in self.group_rel:
|
||||||
result[group.org_rel.name].append({"name": group.name, "id": group.id})
|
result[group.org_rel.name].append({"name": group.name, "id": group.id})
|
||||||
return dict(result)
|
return dict(result)
|
||||||
|
|
|
||||||
|
|
@ -8,213 +8,210 @@ Endpoints:
|
||||||
- [DELETE](/user/): [super admin]: Removes a User(id) from the hub database.
|
- [DELETE](/user/): [super admin]: Removes a User(id) from the hub database.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
from fastapi import APIRouter, status, BackgroundTasks
|
from fastapi import APIRouter, status, BackgroundTasks
|
||||||
|
|
||||||
from src.iam.models import Group
|
from src.iam.models import Group
|
||||||
from src.organisation.exceptions import OrgNotFoundException
|
from src.organisation.exceptions import OrgNotFoundException
|
||||||
from src.user.schemas import (
|
from src.user.schemas import (
|
||||||
UserResponse,
|
UserResponse,
|
||||||
OIDCClaims,
|
OIDCClaims,
|
||||||
UserPostInvitationRequest,
|
UserPostInvitationRequest,
|
||||||
UserPostInvitationAcceptRequest,
|
UserPostInvitationAcceptRequest,
|
||||||
UserGetSelfOrgsResponse,
|
UserGetSelfOrgsResponse,
|
||||||
UserPostInvitationResponse,
|
UserPostInvitationResponse,
|
||||||
UserPostInvitationAcceptResponse,
|
UserPostInvitationAcceptResponse,
|
||||||
)
|
)
|
||||||
from src.user.dependencies import (
|
from src.user.dependencies import (
|
||||||
user_model_claims_dependency,
|
user_model_claims_dependency,
|
||||||
user_model_query_dependency,
|
user_model_query_dependency,
|
||||||
)
|
)
|
||||||
from src.user.service import send_invitation
|
from src.user.service import send_invitation
|
||||||
from src.organisation.models import Organisation as Org
|
from src.organisation.models import Organisation as Org
|
||||||
|
|
||||||
from src.auth.dependencies import (
|
from src.auth.dependencies import (
|
||||||
super_admin_dependency,
|
super_admin_dependency,
|
||||||
org_model_root_claim_body_dependency,
|
org_model_root_claim_body_dependency,
|
||||||
)
|
)
|
||||||
from src.auth.service import claims_dependency
|
from src.auth.service import claims_dependency
|
||||||
from src.database import DbSession
|
from src.database import DbSession
|
||||||
from src.utils import verify_email_token
|
from src.utils import verify_email_token
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/user",
|
prefix="/user",
|
||||||
tags=["User"],
|
tags=["User"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/self/claims",
|
"/self/claims",
|
||||||
summary="Get current user OIDC claims.",
|
summary="Get current user OIDC claims.",
|
||||||
response_model=OIDCClaims,
|
response_model=OIDCClaims,
|
||||||
status_code=status.HTTP_200_OK,
|
status_code=status.HTTP_200_OK,
|
||||||
responses={
|
responses={
|
||||||
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def current_user_claims(user: claims_dependency):
|
async def current_user_claims(user: claims_dependency):
|
||||||
"""
|
"""
|
||||||
Returns the full OIDC claims associated with the currently logged-in user.
|
Returns the full OIDC claims associated with the currently logged-in user.
|
||||||
"""
|
"""
|
||||||
user["allowed_origins"] = user.get("allowed-origins", [])
|
user["allowed_origins"] = user.get("allowed-origins", [])
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/self/db",
|
"/self/db",
|
||||||
summary="Get current user hub details.",
|
summary="Get current user hub details.",
|
||||||
response_model=UserResponse,
|
response_model=UserResponse,
|
||||||
status_code=status.HTTP_200_OK,
|
status_code=status.HTTP_200_OK,
|
||||||
responses={
|
responses={
|
||||||
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
|
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
|
||||||
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def current_user(user_model: user_model_claims_dependency):
|
async def current_user(user_model: user_model_claims_dependency):
|
||||||
"""
|
"""
|
||||||
Returns the database details associated with the currently logged-in user.
|
Returns the database details associated with the currently logged-in user.
|
||||||
"""
|
"""
|
||||||
return user_model
|
return user_model
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"",
|
"",
|
||||||
summary="Get user hub details by ID.",
|
summary="Get user hub details by ID.",
|
||||||
response_model=UserResponse,
|
response_model=UserResponse,
|
||||||
status_code=status.HTTP_200_OK,
|
status_code=status.HTTP_200_OK,
|
||||||
responses={
|
responses={
|
||||||
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
|
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
|
||||||
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def get_user_by_id(
|
async def get_user_by_id(
|
||||||
user_model: user_model_query_dependency, su: super_admin_dependency
|
user_model: user_model_query_dependency, su: super_admin_dependency
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Returns the database details associated with the provided user ID.
|
Returns the database details associated with the provided user ID.
|
||||||
"""
|
"""
|
||||||
return user_model
|
return user_model
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete(
|
||||||
"",
|
"",
|
||||||
summary="Delete user from hub by ID.",
|
summary="Delete user from hub by ID.",
|
||||||
status_code=status.HTTP_204_NO_CONTENT,
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
responses={
|
responses={
|
||||||
status.HTTP_204_NO_CONTENT: {"description": "User deleted"},
|
status.HTTP_204_NO_CONTENT: {"description": "User deleted"},
|
||||||
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
|
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def soft_delete_user_by_id(
|
async def delete_user_by_id(
|
||||||
db: DbSession,
|
db: DbSession,
|
||||||
user_model: user_model_query_dependency,
|
user_model: user_model_query_dependency,
|
||||||
su: super_admin_dependency,
|
su: super_admin_dependency,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login.
|
Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login.
|
||||||
"""
|
"""
|
||||||
user_model.active = False
|
db.delete(user_model)
|
||||||
user_model.deleted_at = datetime.now(tz=timezone.utc)
|
db.commit()
|
||||||
db.commit()
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get(
|
||||||
"/self/orgs",
|
"/self/orgs",
|
||||||
summary="Get all orgs the current user is a member of",
|
summary="Get all orgs the current user is a member of",
|
||||||
status_code=status.HTTP_200_OK,
|
status_code=status.HTTP_200_OK,
|
||||||
response_model=UserGetSelfOrgsResponse,
|
response_model=UserGetSelfOrgsResponse,
|
||||||
responses={},
|
responses={},
|
||||||
)
|
)
|
||||||
async def get_user_orgs(user_model: user_model_claims_dependency):
|
async def get_user_orgs(user_model: user_model_claims_dependency):
|
||||||
user_orgs = user_model.organisation_rel
|
user_orgs = user_model.organisation_rel
|
||||||
response = []
|
response = []
|
||||||
for org in user_orgs:
|
for org in user_orgs:
|
||||||
response.append(
|
response.append(
|
||||||
{
|
{
|
||||||
"organisation_id": org.id,
|
"organisation_id": org.id,
|
||||||
"name": org.name,
|
"name": org.name,
|
||||||
"status": org.status,
|
"status": org.status,
|
||||||
"intake_questionnaire": org.intake_questionnaire,
|
"intake_questionnaire": org.intake_questionnaire,
|
||||||
"root_user_email": org.root_user_email,
|
"root_user_email": org.root_user_email,
|
||||||
"billing_contact": {
|
"billing_contact": {
|
||||||
"id": org.billing_contact_id,
|
"id": org.billing_contact_id,
|
||||||
"email": org.billing_contact_rel.email,
|
"email": org.billing_contact_rel.email,
|
||||||
},
|
},
|
||||||
"owner_contact": {
|
"owner_contact": {
|
||||||
"id": org.owner_contact_id,
|
"id": org.owner_contact_id,
|
||||||
"email": org.owner_contact_rel.email,
|
"email": org.owner_contact_rel.email,
|
||||||
},
|
},
|
||||||
"security_contact": {
|
"security_contact": {
|
||||||
"id": org.security_contact_id,
|
"id": org.security_contact_id,
|
||||||
"email": org.security_contact_rel.email,
|
"email": org.security_contact_rel.email,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"organisations": response}
|
return {"organisations": response}
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/invitation",
|
"/invitation",
|
||||||
summary="Send an email invitation for a user to join an org",
|
summary="Send an email invitation for a user to join an org",
|
||||||
status_code=status.HTTP_200_OK,
|
status_code=status.HTTP_200_OK,
|
||||||
response_model=UserPostInvitationResponse,
|
response_model=UserPostInvitationResponse,
|
||||||
)
|
)
|
||||||
async def invitation(
|
async def invitation(
|
||||||
background_tasks: BackgroundTasks,
|
background_tasks: BackgroundTasks,
|
||||||
org_model: org_model_root_claim_body_dependency,
|
org_model: org_model_root_claim_body_dependency,
|
||||||
request_model: UserPostInvitationRequest,
|
request_model: UserPostInvitationRequest,
|
||||||
):
|
):
|
||||||
org_id = org_model.id
|
org_id = org_model.id
|
||||||
org_name = org_model.name
|
org_name = org_model.name
|
||||||
user_email = request_model.user_email
|
user_email = request_model.user_email
|
||||||
|
|
||||||
background_tasks.add_task(
|
background_tasks.add_task(
|
||||||
send_invitation, org_id=org_id, org_name=org_name, user_email=user_email
|
send_invitation, org_id=org_id, org_name=org_name, user_email=user_email
|
||||||
)
|
)
|
||||||
|
|
||||||
response = {
|
response = {
|
||||||
"organisation": org_model,
|
"organisation": org_model,
|
||||||
"invited_email": user_email,
|
"invited_email": user_email,
|
||||||
}
|
}
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/invitation/accept",
|
"/invitation/accept",
|
||||||
summary="Accept email invitation to join an org",
|
summary="Accept email invitation to join an org",
|
||||||
status_code=status.HTTP_200_OK,
|
status_code=status.HTTP_200_OK,
|
||||||
response_model=UserPostInvitationAcceptResponse,
|
response_model=UserPostInvitationAcceptResponse,
|
||||||
)
|
)
|
||||||
async def accept_invitation(
|
async def accept_invitation(
|
||||||
db: DbSession,
|
db: DbSession,
|
||||||
user_model: user_model_claims_dependency,
|
user_model: user_model_claims_dependency,
|
||||||
request_model: UserPostInvitationAcceptRequest,
|
request_model: UserPostInvitationAcceptRequest,
|
||||||
):
|
):
|
||||||
email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model)
|
email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model)
|
||||||
|
|
||||||
org_model = db.get(Org, email_claims["org_id"])
|
org_model = db.get(Org, email_claims["org_id"])
|
||||||
if org_model is None:
|
if org_model is None:
|
||||||
raise OrgNotFoundException()
|
raise OrgNotFoundException()
|
||||||
|
|
||||||
org_model.user_rel.append(user_model)
|
org_model.user_rel.append(user_model)
|
||||||
db.flush()
|
db.flush()
|
||||||
group_model = (
|
group_model = (
|
||||||
db.query(Group)
|
db.query(Group)
|
||||||
.filter(Group.org_id == org_model.id)
|
.filter(Group.org_id == org_model.id)
|
||||||
.filter(Group.name == "Default Users")
|
.filter(Group.name == "Default Users")
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if group_model is not None:
|
if group_model is not None:
|
||||||
user_model.group_rel.append(group_model)
|
user_model.group_rel.append(group_model)
|
||||||
|
|
||||||
response = {
|
response = {
|
||||||
"organisation": org_model,
|
"organisation": org_model,
|
||||||
"user": user_model,
|
"user": user_model,
|
||||||
}
|
}
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
|
||||||
|
|
@ -10,63 +10,63 @@ from src.schemas import CustomBaseModel, OrgIDMixin, OrgSummary, UserSummary
|
||||||
|
|
||||||
|
|
||||||
class OIDCClaims(CustomBaseModel):
|
class OIDCClaims(CustomBaseModel):
|
||||||
exp: int
|
exp: int
|
||||||
iat: int
|
iat: int
|
||||||
auth_time: int
|
auth_time: int
|
||||||
jti: str
|
jti: str
|
||||||
iss: str
|
iss: str
|
||||||
aud: str
|
aud: str
|
||||||
sub: str
|
sub: str
|
||||||
typ: str
|
typ: str
|
||||||
azp: str
|
azp: str
|
||||||
sid: str
|
sid: str
|
||||||
acr: str
|
acr: str
|
||||||
allowed_origins: list[str]
|
allowed_origins: list[str]
|
||||||
realm_access: dict[str, list[str]]
|
realm_access: dict[str, list[str]]
|
||||||
resource_access: dict[str, dict[str, list[str]]]
|
resource_access: dict[str, dict[str, list[str]]]
|
||||||
scope: str
|
scope: str
|
||||||
email_verified: bool
|
email_verified: bool
|
||||||
name: str
|
name: str
|
||||||
preferred_username: str
|
preferred_username: str
|
||||||
given_name: str
|
given_name: str
|
||||||
family_name: str
|
family_name: str
|
||||||
email: str
|
email: str
|
||||||
db_id: int
|
db_id: int
|
||||||
|
|
||||||
|
|
||||||
class OIDCUser(CustomBaseModel):
|
class OIDCUser(CustomBaseModel):
|
||||||
first_name: str
|
first_name: str
|
||||||
last_name: str
|
last_name: str
|
||||||
email: str
|
email: str
|
||||||
oidc_id: str
|
oidc_id: str
|
||||||
|
|
||||||
|
|
||||||
class UserResponse(CustomBaseModel):
|
class UserResponse(CustomBaseModel):
|
||||||
id: int
|
id: int
|
||||||
first_name: str
|
first_name: str
|
||||||
last_name: str
|
last_name: str
|
||||||
email: str
|
email: str
|
||||||
organisations: list[Optional[dict[str, str | int]]]
|
organisations: list[Optional[dict[str, str | int]]]
|
||||||
groups: Optional[dict[str, list[dict[str, str | int]]]] = None
|
groups: Optional[dict[str, list[dict[str, str | int]]]] = None
|
||||||
|
|
||||||
|
|
||||||
class UserPostInvitationRequest(OrgIDMixin):
|
class UserPostInvitationRequest(OrgIDMixin):
|
||||||
user_email: EmailStr
|
user_email: EmailStr
|
||||||
|
|
||||||
|
|
||||||
class UserPostInvitationAcceptRequest(CustomBaseModel):
|
class UserPostInvitationAcceptRequest(CustomBaseModel):
|
||||||
jwt: str
|
jwt: str
|
||||||
|
|
||||||
|
|
||||||
class UserGetSelfOrgsResponse(CustomBaseModel):
|
class UserGetSelfOrgsResponse(CustomBaseModel):
|
||||||
organisations: list[OrgSchema]
|
organisations: list[OrgSchema]
|
||||||
|
|
||||||
|
|
||||||
class UserPostInvitationResponse(CustomBaseModel):
|
class UserPostInvitationResponse(CustomBaseModel):
|
||||||
organisation: OrgSummary
|
organisation: OrgSummary
|
||||||
invited_email: EmailStr
|
invited_email: EmailStr
|
||||||
|
|
||||||
|
|
||||||
class UserPostInvitationAcceptResponse(CustomBaseModel):
|
class UserPostInvitationAcceptResponse(CustomBaseModel):
|
||||||
organisation: OrgSummary
|
organisation: OrgSummary
|
||||||
user: UserSummary
|
user: UserSummary
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,13 @@
|
||||||
"""
|
"""
|
||||||
Module specific business logic for user module
|
Module specific business logic for user module
|
||||||
|
|
||||||
|
Exports:
|
||||||
|
- add_user_to_db: Creates a User record from OIDC claims, or updates user details
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
import logging
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from src.exceptions import UnprocessableContentException
|
from src.exceptions import UnprocessableContentException
|
||||||
|
|
@ -15,50 +17,57 @@ from src.user.schemas import OIDCUser
|
||||||
from src.user.models import User
|
from src.user.models import User
|
||||||
|
|
||||||
|
|
||||||
async def add_user(db: Session, user_claims: dict[str, Any]) -> int:
|
async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int:
|
||||||
try:
|
try:
|
||||||
valid_user = OIDCUser(
|
valid_user = OIDCUser(
|
||||||
first_name=user_claims["given_name"],
|
first_name=user_claims["given_name"],
|
||||||
last_name=user_claims["family_name"],
|
last_name=user_claims["family_name"],
|
||||||
email=user_claims["email"],
|
email=user_claims["email"],
|
||||||
oidc_id=user_claims["sub"],
|
oidc_id=user_claims["sub"],
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception(e)
|
print(e)
|
||||||
raise UnprocessableContentException("Invalid or missing OIDC data")
|
raise UnprocessableContentException("Invalid or missing OIDC data")
|
||||||
|
|
||||||
db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first()
|
db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first()
|
||||||
|
|
||||||
if not db_user:
|
if not db_user:
|
||||||
user_model = User(**valid_user.model_dump())
|
user_model = User(**valid_user.model_dump())
|
||||||
db.add(user_model)
|
db.add(user_model)
|
||||||
user_id = user_model.id
|
user_id = user_model.id
|
||||||
db.commit()
|
db.commit()
|
||||||
return user_id
|
return user_id
|
||||||
|
else:
|
||||||
user_id = db_user.id
|
user_id = db_user.id
|
||||||
db_user.first_name = valid_user.first_name
|
change = False
|
||||||
db_user.last_name = valid_user.last_name
|
if db_user.first_name != valid_user.first_name:
|
||||||
db.commit()
|
db_user.first_name = valid_user.first_name
|
||||||
return user_id
|
change = True
|
||||||
|
if db_user.last_name != valid_user.last_name:
|
||||||
|
db_user.last_name = valid_user.last_name
|
||||||
|
change = True
|
||||||
|
if change:
|
||||||
|
db.add(db_user)
|
||||||
|
db.commit()
|
||||||
|
return user_id
|
||||||
|
|
||||||
|
|
||||||
async def send_invitation(user_email: str, org_name: str, org_id: int):
|
async def send_invitation(user_email: str, org_name: str, org_id: int):
|
||||||
expiry_delta = timedelta(hours=24)
|
expiry_delta = timedelta(hours=24)
|
||||||
expiry = datetime.now(timezone.utc) + expiry_delta
|
expiry = datetime.now(timezone.utc) + expiry_delta
|
||||||
claims = {
|
claims = {
|
||||||
"email": user_email,
|
"email": user_email,
|
||||||
"org_id": org_id,
|
"org_id": org_id,
|
||||||
"exp": expiry,
|
"exp": expiry,
|
||||||
"type": "org_invite",
|
"type": "org_invite",
|
||||||
}
|
}
|
||||||
|
|
||||||
token = await generate_jwt(claims)
|
token = await generate_jwt(claims)
|
||||||
subject = f"You have been invited to join {org_name}"
|
subject = f"You have been invited to join {org_name}"
|
||||||
body = f"You have been invited to join {org_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
|
body = f"You have been invited to join {org_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
|
||||||
|
|
||||||
await send_email(
|
await send_email(
|
||||||
recipient=user_email,
|
recipient=user_email,
|
||||||
subject=subject,
|
subject=subject,
|
||||||
body=body,
|
body=body,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
67
src/utils.py
67
src/utils.py
|
|
@ -1,5 +1,3 @@
|
||||||
import logging
|
|
||||||
|
|
||||||
from lettermint import Lettermint, ValidationError
|
from lettermint import Lettermint, ValidationError
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from joserfc import jwt, jwk, errors
|
from joserfc import jwt, jwk, errors
|
||||||
|
|
@ -11,56 +9,51 @@ KEY = jwk.import_key(settings.SECRET_KEY.get_secret_value(), "oct")
|
||||||
|
|
||||||
|
|
||||||
async def generate_jwt(claims):
|
async def generate_jwt(claims):
|
||||||
jwt_token = jwt.encode(header={"alg": "HS256"}, key=KEY, claims=claims)
|
jwt_token = jwt.encode(header={"alg": "HS256"}, key=KEY, claims=claims)
|
||||||
|
|
||||||
return jwt_token
|
return jwt_token
|
||||||
|
|
||||||
|
|
||||||
async def decode_jwt(encoded):
|
async def decode_jwt(encoded):
|
||||||
try:
|
try:
|
||||||
token = jwt.decode(encoded, key=KEY)
|
token = jwt.decode(encoded, key=KEY)
|
||||||
return token.claims
|
return token.claims
|
||||||
except errors.DecodeError:
|
except errors.DecodeError:
|
||||||
raise UnauthorizedException("Invalid JWS")
|
raise UnauthorizedException("Invalid JWS")
|
||||||
|
|
||||||
|
|
||||||
async def verify_email_token(user_model, token):
|
async def verify_email_token(user_model, token):
|
||||||
email_claims = await decode_jwt(token)
|
email_claims = await decode_jwt(token)
|
||||||
|
|
||||||
claimed_email = email_claims["email"]
|
claimed_email = email_claims["email"]
|
||||||
|
|
||||||
expiry = datetime.fromtimestamp(email_claims["exp"], timezone.utc)
|
expiry = datetime.fromtimestamp(email_claims["exp"], timezone.utc)
|
||||||
|
|
||||||
if expiry < datetime.now(timezone.utc):
|
if expiry < datetime.now(timezone.utc):
|
||||||
raise UnauthorizedException("Invitation expired.")
|
raise UnauthorizedException("Invitation expired.")
|
||||||
|
|
||||||
if user_model.email != claimed_email:
|
if user_model.email != claimed_email:
|
||||||
raise ForbiddenException("The logged in user and email do not match.")
|
raise ForbiddenException("The logged in user and email do not match.")
|
||||||
|
|
||||||
return email_claims
|
return email_claims
|
||||||
|
|
||||||
|
|
||||||
async def send_email(recipient: str, subject: str, body: str):
|
async def send_email(recipient: str, subject: str, body: str):
|
||||||
if settings.ENVIRONMENT.is_testing:
|
lettermint = Lettermint(api_token=settings.LETTERMINT_API_TOKEN.get_secret_value())
|
||||||
return
|
|
||||||
|
|
||||||
lettermint = Lettermint(api_token=settings.LETTERMINT_API_TOKEN.get_secret_value())
|
if settings.ENVIRONMENT.is_testing or settings.ENVIRONMENT == "local":
|
||||||
|
recipient = "ok@testing.lettermint.co"
|
||||||
|
|
||||||
if settings.ENVIRONMENT == "local":
|
try:
|
||||||
recipient = "ok@testing.lettermint.co"
|
response = (
|
||||||
|
lettermint.email.from_("noreply@sr2.uk")
|
||||||
|
.to(recipient)
|
||||||
|
.subject(subject)
|
||||||
|
.text(body)
|
||||||
|
.send()
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
print(response.status_code)
|
||||||
response = (
|
except ValidationError:
|
||||||
lettermint.email.from_("noreply@sr2.uk")
|
# Error thrown if domain not approved for project
|
||||||
.to(recipient)
|
print("Lettermint validation error")
|
||||||
.subject(subject)
|
|
||||||
.text(body)
|
|
||||||
.send()
|
|
||||||
)
|
|
||||||
logging.info(
|
|
||||||
"Email sent to {} with subject {} (Status: {})".format(
|
|
||||||
recipient, subject, response.status_code
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except ValidationError as e:
|
|
||||||
logging.exception(e)
|
|
||||||
|
|
|
||||||
463
test/conftest.py
463
test/conftest.py
|
|
@ -1,9 +1,8 @@
|
||||||
from fastapi.dependencies.models import Dependant
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
from itertools import combinations
|
from itertools import combinations
|
||||||
from fastapi.routing import APIRoute, iter_route_contexts
|
from fastapi.routing import APIRoute
|
||||||
from httpx import AsyncClient, ASGITransport
|
from httpx import AsyncClient, ASGITransport
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
|
@ -15,7 +14,7 @@ from src.iam.models import Group, Permission, OrgPermissions
|
||||||
from src.auth.service import get_current_user, get_dev_user
|
from src.auth.service import get_current_user, get_dev_user
|
||||||
from src.auth.dependencies import empty_su_list, get_super_admin_list, testing_su_list
|
from src.auth.dependencies import empty_su_list, get_super_admin_list, testing_su_list
|
||||||
from src.main import app # inited FastAPI app
|
from src.main import app # inited FastAPI app
|
||||||
from src.database import engine, get_db_session
|
from src.database import engine, get_db
|
||||||
from src.models import CustomBase
|
from src.models import CustomBase
|
||||||
|
|
||||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
@ -23,295 +22,269 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def db_session():
|
def db_session():
|
||||||
CustomBase.metadata.drop_all(bind=engine)
|
CustomBase.metadata.drop_all(bind=engine)
|
||||||
CustomBase.metadata.create_all(bind=engine)
|
CustomBase.metadata.create_all(bind=engine)
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
_seed(db) # extracted seeding logic into a plain function
|
_seed(db) # extracted seeding logic into a plain function
|
||||||
yield db
|
yield db
|
||||||
finally:
|
finally:
|
||||||
db.rollback()
|
db.rollback()
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def default_client(db_session) -> AsyncGenerator[AsyncClient, None]:
|
async def default_client(db_session) -> AsyncGenerator[AsyncClient, None]:
|
||||||
def get_db_override():
|
def get_db_override():
|
||||||
return db_session
|
return db_session
|
||||||
|
|
||||||
app.dependency_overrides[get_db_session] = get_db_override
|
app.dependency_overrides[get_db] = get_db_override
|
||||||
app.dependency_overrides[get_current_user] = get_dev_user
|
app.dependency_overrides[get_current_user] = get_dev_user
|
||||||
app.dependency_overrides[get_super_admin_list] = testing_su_list
|
app.dependency_overrides[get_super_admin_list] = testing_su_list
|
||||||
transport = ASGITransport(app=app)
|
transport = ASGITransport(app=app)
|
||||||
async with AsyncClient(
|
async with AsyncClient(
|
||||||
transport=transport, base_url="http://localhost:8000/api/v1"
|
transport=transport, base_url="http://localhost:8000/api/v1"
|
||||||
) as ac:
|
) as ac:
|
||||||
yield ac
|
yield ac
|
||||||
|
|
||||||
app.dependency_overrides.clear()
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def no_user_client(db_session) -> AsyncGenerator[AsyncClient, None]:
|
async def no_user_client(db_session) -> AsyncGenerator[AsyncClient, None]:
|
||||||
def get_db_override():
|
def get_db_override():
|
||||||
return db_session
|
return db_session
|
||||||
|
|
||||||
app.dependency_overrides[get_db_session] = get_db_override
|
app.dependency_overrides[get_db] = get_db_override
|
||||||
transport = ASGITransport(app=app)
|
transport = ASGITransport(app=app)
|
||||||
async with AsyncClient(
|
async with AsyncClient(
|
||||||
transport=transport, base_url="http://localhost:8000/api/v1"
|
transport=transport, base_url="http://localhost:8000/api/v1"
|
||||||
) as ac:
|
) as ac:
|
||||||
yield ac
|
yield ac
|
||||||
|
|
||||||
app.dependency_overrides.clear()
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def no_su_client(db_session) -> AsyncGenerator[AsyncClient, None]:
|
async def no_su_client(db_session) -> AsyncGenerator[AsyncClient, None]:
|
||||||
def get_db_override():
|
def get_db_override():
|
||||||
return db_session
|
return db_session
|
||||||
|
|
||||||
app.dependency_overrides[get_db_session] = get_db_override
|
app.dependency_overrides[get_db] = get_db_override
|
||||||
app.dependency_overrides[get_current_user] = get_dev_user
|
app.dependency_overrides[get_current_user] = get_dev_user
|
||||||
app.dependency_overrides[get_super_admin_list] = empty_su_list
|
app.dependency_overrides[get_super_admin_list] = empty_su_list
|
||||||
transport = ASGITransport(app=app)
|
transport = ASGITransport(app=app)
|
||||||
async with AsyncClient(
|
async with AsyncClient(
|
||||||
transport=transport, base_url="http://localhost:8000/api/v1"
|
transport=transport, base_url="http://localhost:8000/api/v1"
|
||||||
) as ac:
|
) as ac:
|
||||||
yield ac
|
yield ac
|
||||||
|
|
||||||
app.dependency_overrides.clear()
|
app.dependency_overrides.clear()
|
||||||
|
|
||||||
|
|
||||||
def _seed(db):
|
def _seed(db):
|
||||||
db.add(
|
db.add(
|
||||||
User(
|
User(
|
||||||
email="admin@test.com",
|
email="admin@test.com",
|
||||||
first_name="Admin",
|
first_name="Admin",
|
||||||
last_name="Test",
|
last_name="Test",
|
||||||
oidc_id="abcd-efgh-ijkl-mnop",
|
oidc_id="abcd-efgh-ijkl-mnop",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
db.add(
|
db.add(
|
||||||
User(
|
User(
|
||||||
email="user@orgone.com",
|
email="user@orgone.com",
|
||||||
first_name="User",
|
first_name="User",
|
||||||
last_name="Test",
|
last_name="Test",
|
||||||
oidc_id="abcd-efgh-ijkl-qwer",
|
oidc_id="abcd-efgh-ijkl-qwer",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
db.add(
|
db.add(
|
||||||
User(
|
User(
|
||||||
email="root@orgtwo.com",
|
email="root@orgtwo.com",
|
||||||
first_name="Root",
|
first_name="Root",
|
||||||
last_name="Test",
|
last_name="Test",
|
||||||
oidc_id="abcd-efgh-ijkl-hjkl",
|
oidc_id="abcd-efgh-ijkl-hjkl",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
db.add(Contact(org_id=1, email="billing@orgone.com", phonenumber="07521539927"))
|
db.add(Contact(org_id=1, email="billing@orgone.com", phonenumber="07521539927"))
|
||||||
db.add(Contact(org_id=1, email="owner@orgone.com", phonenumber="07521539927"))
|
db.add(Contact(org_id=1, email="owner@orgone.com", phonenumber="07521539927"))
|
||||||
db.add(Contact(org_id=1, email="security@orgone.com", phonenumber="07521539927"))
|
db.add(Contact(org_id=1, email="security@orgone.com", phonenumber="07521539927"))
|
||||||
db.add(Contact(org_id=2, email="billing@orgtwo.com", phonenumber="07521539927"))
|
db.add(Contact(org_id=2, email="billing@orgtwo.com", phonenumber="07521539927"))
|
||||||
db.add(Contact(org_id=2, email="owner@orgtwo.com", phonenumber="07521539927"))
|
db.add(Contact(org_id=2, email="owner@orgtwo.com", phonenumber="07521539927"))
|
||||||
db.add(Contact(org_id=2, email="security@orgtwo.com", phonenumber="07521539927"))
|
db.add(Contact(org_id=2, email="security@orgtwo.com", phonenumber="07521539927"))
|
||||||
db.add(Contact(org_id=3, email="billing@orgthree.com", phonenumber="07521539927"))
|
db.add(Contact(org_id=3, email="billing@orgthree.com", phonenumber="07521539927"))
|
||||||
db.add(Contact(org_id=3, email="owner@orgthree.com", phonenumber="07521539927"))
|
db.add(Contact(org_id=3, email="owner@orgthree.com", phonenumber="07521539927"))
|
||||||
db.add(Contact(org_id=3, email="security@orgthree.com", phonenumber="07521539927"))
|
db.add(Contact(org_id=3, email="security@orgthree.com", phonenumber="07521539927"))
|
||||||
db.flush()
|
db.flush()
|
||||||
db.add(
|
db.add(
|
||||||
Org(
|
Org(
|
||||||
name="Org One",
|
name="Org One",
|
||||||
root_user_id=1,
|
root_user_id=1,
|
||||||
billing_contact_id=1,
|
billing_contact_id=1,
|
||||||
owner_contact_id=2,
|
owner_contact_id=2,
|
||||||
security_contact_id=3,
|
security_contact_id=3,
|
||||||
status="approved",
|
status="approved",
|
||||||
intake_questionnaire={
|
intake_questionnaire={
|
||||||
"metadata": {"version": 0, "submission_date": None},
|
"metadata": {"version": 0, "submission_date": None},
|
||||||
"questions": {"question_two": "answer two"},
|
"questions": {"question_two": "answer two"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
db.add(
|
db.add(
|
||||||
Org(
|
Org(
|
||||||
name="Org Two",
|
name="Org Two",
|
||||||
root_user_id=3,
|
root_user_id=3,
|
||||||
billing_contact_id=4,
|
billing_contact_id=4,
|
||||||
owner_contact_id=5,
|
owner_contact_id=5,
|
||||||
security_contact_id=6,
|
security_contact_id=6,
|
||||||
status="approved",
|
status="approved",
|
||||||
intake_questionnaire={
|
intake_questionnaire={
|
||||||
"metadata": {"version": 0, "submission_date": None},
|
"metadata": {"version": 0, "submission_date": None},
|
||||||
"questions": {"question_two": "answer two"},
|
"questions": {"question_two": "answer two"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
db.add(
|
db.add(
|
||||||
Org(
|
Org(
|
||||||
name="Org Three",
|
name="Org Three",
|
||||||
root_user_id=1,
|
root_user_id=1,
|
||||||
billing_contact_id=7,
|
billing_contact_id=7,
|
||||||
owner_contact_id=8,
|
owner_contact_id=8,
|
||||||
security_contact_id=9,
|
security_contact_id=9,
|
||||||
status="partial",
|
status="partial",
|
||||||
intake_questionnaire={
|
intake_questionnaire={
|
||||||
"metadata": {"version": 0, "submission_date": None},
|
"metadata": {"version": 0, "submission_date": None},
|
||||||
"questions": {"question_two": "answer two"},
|
"questions": {"question_two": "answer two"},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
db.add(OrgUsers(org_id=1, user_id=2))
|
db.add(OrgUsers(org_id=1, user_id=2))
|
||||||
db.add(Service(name="Test Service", api_key="123456789"))
|
db.add(Service(name="Test Service", api_key="123456789"))
|
||||||
db.add(Permission(service_id=1, resource="test_resource", action="read"))
|
db.add(Permission(service_id=1, resource="test_resource", action="read"))
|
||||||
db.add(Permission(service_id=1, resource="test_resource", action="move"))
|
db.add(Permission(service_id=1, resource="test_resource", action="move"))
|
||||||
db.add(Permission(service_id=1, resource="test_resource", action="delete"))
|
db.add(Permission(service_id=1, resource="test_resource", action="delete"))
|
||||||
db.add(OrgPermissions(org_id=1, permission_id=1))
|
db.add(OrgPermissions(org_id=1, permission_id=1))
|
||||||
db.add(OrgPermissions(org_id=1, permission_id=2))
|
db.add(OrgPermissions(org_id=1, permission_id=2))
|
||||||
db.add(Group(name="Org One Group", org_id=1))
|
db.add(Group(name="Org One Group", org_id=1))
|
||||||
db.add(Group(name="Org Two Group", org_id=2))
|
db.add(Group(name="Org Two Group", org_id=2))
|
||||||
db.add(Group(name="Org One Group Two", org_id=1))
|
db.add(Group(name="Org One Group Two", org_id=1))
|
||||||
db.flush()
|
db.flush()
|
||||||
group_model = db.get(Group, 1)
|
group_model = db.get(Group, 1)
|
||||||
perm_model = db.get(Permission, 1)
|
perm_model = db.get(Permission, 1)
|
||||||
group_model.permission_rel.append(perm_model)
|
group_model.permission_rel.append(perm_model)
|
||||||
user_model = db.get(User, 1)
|
user_model = db.get(User, 1)
|
||||||
org_model = db.get(Org, 1)
|
org_model = db.get(Org, 1)
|
||||||
org_model.user_rel.append(user_model)
|
org_model.user_rel.append(user_model)
|
||||||
org_model.group_rel.append(group_model)
|
org_model.group_rel.append(group_model)
|
||||||
db.flush()
|
db.flush()
|
||||||
group_model.user_rel.append(user_model)
|
group_model.user_rel.append(user_model)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
|
|
||||||
def generate_query_and_status(params) -> list[tuple[str, int]]:
|
def generate_query_and_status(params) -> list[tuple[str, int]]:
|
||||||
possible_values = [0, -1, 42, "banana", ""]
|
possible_values = [0, -1, 42, "banana", ""]
|
||||||
|
|
||||||
defaults = [f"{param}=1" for param in params]
|
defaults = [f"{param}=1" for param in params]
|
||||||
|
|
||||||
# Missing params
|
# Missing params
|
||||||
query_list = [
|
query_list = [
|
||||||
"&".join(combo)
|
"&".join(combo)
|
||||||
for r in range(len(defaults) + 1)
|
for r in range(len(defaults) + 1)
|
||||||
for combo in combinations(defaults, r)
|
for combo in combinations(defaults, r)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Complete query as default for invalid checks
|
# Complete query as default for invalid checks
|
||||||
default_query = query_list.pop(-1)
|
default_query = query_list.pop(-1)
|
||||||
|
|
||||||
# Checks for each param being invalid
|
# Checks for each param being invalid
|
||||||
for param in params:
|
for param in params:
|
||||||
for value in possible_values:
|
for value in possible_values:
|
||||||
new_value = f"&{param}={value}"
|
new_value = f"&{param}={value}"
|
||||||
query_list.append(default_query.replace(f"{param}=1", new_value))
|
query_list.append(default_query.replace(f"{param}=1", new_value))
|
||||||
|
|
||||||
query_and_status = []
|
query_and_status = []
|
||||||
|
|
||||||
# Assign expected status
|
# Assign expected status
|
||||||
for query in query_list:
|
for query in query_list:
|
||||||
# ID 42 is used to represent a non-existent entry. So it should 404.
|
# ID 42 is used to represent a non-existent entry. So it should 404.
|
||||||
status = 404 if "42" in query else 422
|
status = 404 if "42" in query else 422
|
||||||
# Remove leading "&" if present
|
# Remove leading "&" if present
|
||||||
query = query if len(query) > 1 and query[0] != "&" else query[1:]
|
query = query if len(query) > 1 and query[0] != "&" else query[1:]
|
||||||
query_and_status.append((query, status))
|
query_and_status.append((query, status))
|
||||||
|
|
||||||
return query_and_status
|
return query_and_status
|
||||||
|
|
||||||
|
|
||||||
def generate_body_and_status(params: dict[str, str]) -> list[tuple[dict, int]]:
|
def generate_body_and_status(params: dict[str, str]) -> list[tuple[dict, int]]:
|
||||||
possible_values_int = [0, -1, 42, "banana", ""]
|
possible_values_int = [0, -1, 42, "banana", ""]
|
||||||
possible_values_str = [0, "", "a"]
|
possible_values_str = [0, "", "a"]
|
||||||
|
|
||||||
defaults = [{param: 1 for param in params.keys()}]
|
defaults = [{param: 1 for param in params.keys()}]
|
||||||
|
|
||||||
# Missing params
|
# Missing params
|
||||||
body_list = [
|
body_list = [
|
||||||
{key: ("valid string" if params[key] == "str" else 1) for key in combo}
|
{key: ("valid string" if params[key] == "str" else 1) for key in combo}
|
||||||
for r in range(len(defaults[0].keys()) + 1)
|
for r in range(len(defaults[0].keys()) + 1)
|
||||||
for combo in combinations(defaults[0].keys(), r)
|
for combo in combinations(defaults[0].keys(), r)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Complete body as default for generating invalid checks
|
# Complete body as default for generating invalid checks
|
||||||
default_body = body_list.pop(-1)
|
default_body = body_list.pop(-1)
|
||||||
|
|
||||||
# Generates checks for each param being invalid
|
# Generates checks for each param being invalid
|
||||||
for param, typ in params.items():
|
for param, typ in params.items():
|
||||||
if typ == "int":
|
if typ == "int":
|
||||||
possible_values = possible_values_int
|
possible_values = possible_values_int
|
||||||
elif typ == "str":
|
elif typ == "str":
|
||||||
possible_values = possible_values_str
|
possible_values = possible_values_str
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Unknown type {typ}")
|
raise TypeError(f"Unknown type {typ}")
|
||||||
for value in possible_values:
|
for value in possible_values:
|
||||||
new_record = default_body.copy()
|
new_record = default_body.copy()
|
||||||
new_record[param] = value
|
new_record[param] = value
|
||||||
body_list.append(new_record)
|
body_list.append(new_record)
|
||||||
|
|
||||||
body_and_status = []
|
body_and_status = []
|
||||||
|
|
||||||
# Assign expected status
|
# Assign expected status
|
||||||
for body in body_list:
|
for body in body_list:
|
||||||
# ID 42 is used to represent a non-existent entry. So it should 404.
|
# ID 42 is used to represent a non-existent entry. So it should 404.
|
||||||
status = 404 if 42 in body.values() else 422
|
status = 404 if 42 in body.values() else 422
|
||||||
body_and_status.append((body, status))
|
body_and_status.append((body, status))
|
||||||
return body_and_status
|
return body_and_status
|
||||||
|
|
||||||
|
|
||||||
def get_testable_routes():
|
def get_testable_routes():
|
||||||
routes = []
|
routes = []
|
||||||
|
|
||||||
contexts = list(iter_route_contexts(app.routes))
|
for route in app.routes:
|
||||||
|
if not isinstance(route, APIRoute):
|
||||||
|
continue
|
||||||
|
|
||||||
for route in contexts:
|
for method in route.methods:
|
||||||
if not route.methods:
|
if method in {"HEAD", "OPTIONS"}:
|
||||||
continue
|
continue
|
||||||
if not isinstance(route.route, APIRoute):
|
|
||||||
continue
|
|
||||||
|
|
||||||
dep_func_names = set()
|
routes.append(
|
||||||
|
(
|
||||||
|
method,
|
||||||
|
route.path,
|
||||||
|
route.status_code,
|
||||||
|
route.response_model,
|
||||||
|
route.summary,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
unchecked = []
|
return routes
|
||||||
unchecked.append(route.route.dependant)
|
|
||||||
while unchecked:
|
|
||||||
dependant = unchecked.pop(0)
|
|
||||||
ck = dependant.cache_key[0]
|
|
||||||
if hasattr(ck, "__name__"):
|
|
||||||
dep_func_names.add(ck.__name__)
|
|
||||||
unchecked += [
|
|
||||||
dep for dep in dependant.dependencies if isinstance(dep, Dependant)
|
|
||||||
]
|
|
||||||
|
|
||||||
auth_level = None
|
|
||||||
if "get_current_user" in dep_func_names:
|
|
||||||
auth_level = "User"
|
|
||||||
if (
|
|
||||||
"org_body_root_claims" in dep_func_names
|
|
||||||
or "org_query_root_claims" in dep_func_names
|
|
||||||
):
|
|
||||||
auth_level = "Root User"
|
|
||||||
if "user_model_super_admin" in dep_func_names:
|
|
||||||
auth_level = "Super Admin"
|
|
||||||
if "valid_service_key" in dep_func_names:
|
|
||||||
auth_level = "API Key"
|
|
||||||
|
|
||||||
for method in route.methods:
|
|
||||||
if method in {"HEAD", "OPTIONS"}:
|
|
||||||
continue
|
|
||||||
|
|
||||||
routes.append(
|
|
||||||
(
|
|
||||||
method,
|
|
||||||
route.route.path,
|
|
||||||
route.route.status_code,
|
|
||||||
route.route.response_model,
|
|
||||||
route.route.summary,
|
|
||||||
auth_level,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return routes
|
|
||||||
|
|
||||||
|
|
||||||
|
# with open("endpoints.txt", "w") as f:
|
||||||
|
# for ep in get_testable_routes():
|
||||||
|
# f.write(f"[{ep[0]}]({ep[1]}) -> {ep[2]}: {ep[3]}\n")
|
||||||
|
#
|
||||||
|
#
|
||||||
### Docstring formatted output ###
|
### Docstring formatted output ###
|
||||||
with open("endpoints.txt", "w") as f:
|
# with open("endpoints.txt", "w") as f:
|
||||||
for ep in get_testable_routes():
|
# for ep in get_testable_routes():
|
||||||
f.write(f"- [{ep[0]}]({ep[1]}): [{ep[5]}]: {ep[4]}\n")
|
# f.write(f"- [{ep[0]}]({ep[1]}): []: {ep[4]}\n")
|
||||||
|
|
|
||||||
|
|
@ -8,181 +8,181 @@ import pytest
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
|
||||||
pytestmark = [
|
pytestmark = [
|
||||||
pytest.mark.auth,
|
pytest.mark.auth,
|
||||||
pytest.mark.preapproval,
|
pytest.mark.preapproval,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_org_auth_approval(no_su_client: AsyncClient):
|
async def test_get_org_auth_approval(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.get("/org?org_id=3")
|
resp = await no_su_client.get("/org?org_id=3")
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_patch_org_questionnaire_auth_approval(no_su_client: AsyncClient):
|
async def test_patch_org_questionnaire_auth_approval(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.patch(
|
resp = await no_su_client.patch(
|
||||||
"/org/questionnaire",
|
"/org/questionnaire",
|
||||||
json={
|
json={
|
||||||
"organisation_id": 3,
|
"organisation_id": 3,
|
||||||
"intake_questionnaire": {
|
"intake_questionnaire": {
|
||||||
"question_one": "new answer one",
|
"question_one": "new answer one",
|
||||||
"question_two": None,
|
"question_two": None,
|
||||||
"question_three": None,
|
"question_three": None,
|
||||||
},
|
},
|
||||||
"partial": True,
|
"partial": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_org_users_auth_approval(no_su_client: AsyncClient):
|
async def test_get_org_users_auth_approval(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.get("/org/users?org_id=3")
|
resp = await no_su_client.get("/org/users?org_id=3")
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert "has not been approved." in resp.json()["detail"]
|
assert "has not been approved." in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_org_groups_auth_approval(no_su_client: AsyncClient):
|
async def test_get_org_groups_auth_approval(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.get("/org/groups?org_id=3")
|
resp = await no_su_client.get("/org/groups?org_id=3")
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert "has not been approved." in resp.json()["detail"]
|
assert "has not been approved." in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_org_contact_auth_approval(no_su_client: AsyncClient):
|
async def test_get_org_contact_auth_approval(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.get("/org/contact?org_id=3&contact_type=billing")
|
resp = await no_su_client.get("/org/contact?org_id=3&contact_type=billing")
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_patch_org_contact_auth_approval(no_su_client: AsyncClient):
|
async def test_patch_org_contact_auth_approval(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.patch(
|
resp = await no_su_client.patch(
|
||||||
"/org/contact",
|
"/org/contact",
|
||||||
json={
|
json={
|
||||||
"organisation_id": 3,
|
"organisation_id": 3,
|
||||||
"contact_type": "billing",
|
"contact_type": "billing",
|
||||||
"email": "user@example.com",
|
"email": "user@example.com",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_service_auth_approval(no_su_client: AsyncClient):
|
async def test_get_service_auth_approval(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.get("/service?org_id=3")
|
resp = await no_su_client.get("/service?org_id=3")
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert "has not been approved." in resp.json()["detail"]
|
assert "has not been approved." in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_iam_group_permissions_auth_approval(no_su_client: AsyncClient):
|
async def test_get_iam_group_permissions_auth_approval(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.get("/iam/group/permissions?org_id=3&group_id=1")
|
resp = await no_su_client.get("/iam/group/permissions?org_id=3&group_id=1")
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert "has not been approved." in resp.json()["detail"]
|
assert "has not been approved." in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_iam_group_users_auth_approval(no_su_client: AsyncClient):
|
async def test_get_iam_group_users_auth_approval(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.get("/iam/group/users?org_id=3&group_id=1")
|
resp = await no_su_client.get("/iam/group/users?org_id=3&group_id=1")
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert "has not been approved." in resp.json()["detail"]
|
assert "has not been approved." in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_post_iam_group_auth_approval(no_su_client: AsyncClient):
|
async def test_post_iam_group_auth_approval(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.post(
|
resp = await no_su_client.post(
|
||||||
"/iam/group", json={"name": "New Group", "organisation_id": 3}
|
"/iam/group", json={"name": "New Group", "organisation_id": 3}
|
||||||
)
|
)
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert "has not been approved." in resp.json()["detail"]
|
assert "has not been approved." in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_put_iam_group_permission_auth_approval(no_su_client: AsyncClient):
|
async def test_put_iam_group_permission_auth_approval(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.put(
|
resp = await no_su_client.put(
|
||||||
"/iam/group/permission",
|
"/iam/group/permission",
|
||||||
json={"permission_id": 1, "group_id": 2, "organisation_id": 3},
|
json={"permission_id": 1, "group_id": 2, "organisation_id": 3},
|
||||||
)
|
)
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert "has not been approved." in resp.json()["detail"]
|
assert "has not been approved." in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_put_iam_group_user_auth_approval(no_su_client: AsyncClient):
|
async def test_put_iam_group_user_auth_approval(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.put(
|
resp = await no_su_client.put(
|
||||||
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 3}
|
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 3}
|
||||||
)
|
)
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert "has not been approved." in resp.json()["detail"]
|
assert "has not been approved." in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_iam_permissions_auth_approval(no_su_client: AsyncClient):
|
async def test_get_iam_permissions_auth_approval(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.get("/iam/permissions?org_id=3")
|
resp = await no_su_client.get("/iam/permissions?org_id=3")
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert "has not been approved." in resp.json()["detail"]
|
assert "has not been approved." in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_post_iam_permissions_search_auth_approval(no_su_client: AsyncClient):
|
async def test_post_iam_permissions_search_auth_approval(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.post(
|
resp = await no_su_client.post(
|
||||||
"/iam/permissions/search", json={"organisation_id": 3, "action": "read"}
|
"/iam/permissions/search", json={"organisation_id": 3, "action": "read"}
|
||||||
)
|
)
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert "has not been approved." in resp.json()["detail"]
|
assert "has not been approved." in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_delete_org_user_auth_approval(no_su_client: AsyncClient):
|
async def test_delete_org_user_auth_approval(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.delete("/org/user?org_id=3&user_id=1")
|
resp = await no_su_client.delete("/org/user?org_id=3&user_id=1")
|
||||||
|
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert "has not been approved." in resp.json()["detail"]
|
assert "has not been approved." in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_delete_preapproval_auth_approval(no_su_client: AsyncClient):
|
async def test_delete_preapproval_auth_approval(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.delete("/org/self?org_id=3")
|
resp = await no_su_client.delete("/org/self?org_id=3")
|
||||||
|
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 204
|
assert resp.status_code == 204
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_post_user_invitation_auth_approval(no_su_client: AsyncClient):
|
async def test_post_user_invitation_auth_approval(no_su_client: AsyncClient):
|
||||||
body = {"user_email": "admin@test.com", "organisation_id": 3}
|
body = {"user_email": "admin@test.com", "organisation_id": 3}
|
||||||
resp = await no_su_client.post("/user/invitation", json=body)
|
resp = await no_su_client.post("/user/invitation", json=body)
|
||||||
|
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert "has not been approved." in resp.json()["detail"]
|
assert "has not been approved." in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_delete_group_permissions_auth_approval(no_su_client: AsyncClient):
|
async def test_delete_group_permissions_auth_approval(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.delete("/iam/group/permission?org_id=3&group_id=1&perm_id=1")
|
resp = await no_su_client.delete("/iam/group/permission?org_id=3&group_id=1&perm_id=1")
|
||||||
|
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert "has not been approved." in resp.json()["detail"]
|
assert "has not been approved." in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_delete_group_users_success(no_su_client: AsyncClient):
|
async def test_delete_group_users_success(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.delete("/iam/group/user?org_id=3&group_id=1&user_id=1")
|
resp = await no_su_client.delete("/iam/group/user?org_id=3&group_id=1&user_id=1")
|
||||||
|
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert "has not been approved." in resp.json()["detail"]
|
assert "has not been approved." in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_put_group_user_invitation_success(no_su_client: AsyncClient):
|
async def test_put_group_user_invitation_success(no_su_client: AsyncClient):
|
||||||
body = {"user_email": "admin@test.com", "organisation_id": 3, "group_id": 1}
|
body = {"user_email": "admin@test.com", "organisation_id": 3, "group_id": 1}
|
||||||
resp = await no_su_client.put("/iam/group/user/invitation", json=body)
|
resp = await no_su_client.put("/iam/group/user/invitation", json=body)
|
||||||
|
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert "has not been approved." in resp.json()["detail"]
|
assert "has not been approved." in resp.json()["detail"]
|
||||||
|
|
|
||||||
|
|
@ -5,14 +5,14 @@ from httpx import AsyncClient
|
||||||
|
|
||||||
|
|
||||||
pytestmark = [
|
pytestmark = [
|
||||||
pytest.mark.auth,
|
pytest.mark.auth,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_org_auth_root_su(default_client: AsyncClient):
|
async def test_get_org_auth_root_su(default_client: AsyncClient):
|
||||||
# If a super admin can access a resource when not the root user
|
# If a super admin can access a resource when not the root user
|
||||||
resp = await default_client.get("/org?org_id=2")
|
resp = await default_client.get("/org?org_id=2")
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert resp.json()["organisations"][0]["name"] == "Org Two"
|
assert resp.json()["organisations"][0]["name"] == "Org Two"
|
||||||
|
|
|
||||||
|
|
@ -7,147 +7,147 @@ import pytest
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
|
||||||
pytestmark = [
|
pytestmark = [
|
||||||
pytest.mark.auth,
|
pytest.mark.auth,
|
||||||
pytest.mark.root_user,
|
pytest.mark.root_user,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_org_auth_root(no_su_client: AsyncClient):
|
async def test_get_org_auth_root(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.get("/org?org_id=2")
|
resp = await no_su_client.get("/org?org_id=2")
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
assert "Must be the org's root user" in resp.json()["detail"]
|
assert "Must be the org's root user" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_patch_org_questionnaire_auth_root(no_su_client: AsyncClient):
|
async def test_patch_org_questionnaire_auth_root(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.patch(
|
resp = await no_su_client.patch(
|
||||||
"/org/questionnaire",
|
"/org/questionnaire",
|
||||||
json={
|
json={
|
||||||
"organisation_id": 2,
|
"organisation_id": 2,
|
||||||
"intake_questionnaire": {
|
"intake_questionnaire": {
|
||||||
"question_one": "new answer one",
|
"question_one": "new answer one",
|
||||||
"question_two": None,
|
"question_two": None,
|
||||||
"question_three": None,
|
"question_three": None,
|
||||||
},
|
},
|
||||||
"partial": True,
|
"partial": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
assert "Must be the org's root user" in resp.json()["detail"]
|
assert "Must be the org's root user" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_org_users_auth_root(no_su_client: AsyncClient):
|
async def test_get_org_users_auth_root(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.get("/org/users?org_id=2")
|
resp = await no_su_client.get("/org/users?org_id=2")
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
assert "Must be the org's root user" in resp.json()["detail"]
|
assert "Must be the org's root user" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_org_groups_auth_root(no_su_client: AsyncClient):
|
async def test_get_org_groups_auth_root(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.get("/org/groups?org_id=2")
|
resp = await no_su_client.get("/org/groups?org_id=2")
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
assert "Must be the org's root user" in resp.json()["detail"]
|
assert "Must be the org's root user" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_org_contact_auth_root(no_su_client: AsyncClient):
|
async def test_get_org_contact_auth_root(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.get("/org/contact?org_id=2&contact_type=billing")
|
resp = await no_su_client.get("/org/contact?org_id=2&contact_type=billing")
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
assert "Must be the org's root user" in resp.json()["detail"]
|
assert "Must be the org's root user" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_patch_org_contact_auth_root(no_su_client: AsyncClient):
|
async def test_patch_org_contact_auth_root(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.patch(
|
resp = await no_su_client.patch(
|
||||||
"/org/contact",
|
"/org/contact",
|
||||||
json={
|
json={
|
||||||
"organisation_id": 2,
|
"organisation_id": 2,
|
||||||
"contact_type": "billing",
|
"contact_type": "billing",
|
||||||
"email": "user@example.com",
|
"email": "user@example.com",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
assert "Must be the org's root user" in resp.json()["detail"]
|
assert "Must be the org's root user" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_service_auth_root(no_su_client: AsyncClient):
|
async def test_get_service_auth_root(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.get("/service?org_id=2")
|
resp = await no_su_client.get("/service?org_id=2")
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
assert "Must be the org's root user" in resp.json()["detail"]
|
assert "Must be the org's root user" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_iam_group_permissions_auth_root(no_su_client: AsyncClient):
|
async def test_get_iam_group_permissions_auth_root(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.get("/iam/group/permissions?org_id=2&group_id=1")
|
resp = await no_su_client.get("/iam/group/permissions?org_id=2&group_id=1")
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
assert "Must be the org's root user" in resp.json()["detail"]
|
assert "Must be the org's root user" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_iam_group_users_auth_root(no_su_client: AsyncClient):
|
async def test_get_iam_group_users_auth_root(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.get("/iam/group/users?org_id=2&group_id=1")
|
resp = await no_su_client.get("/iam/group/users?org_id=2&group_id=1")
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
assert "Must be the org's root user" in resp.json()["detail"]
|
assert "Must be the org's root user" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_post_iam_group_auth_root(no_su_client: AsyncClient):
|
async def test_post_iam_group_auth_root(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.post(
|
resp = await no_su_client.post(
|
||||||
"/iam/group", json={"name": "New Group", "organisation_id": 2}
|
"/iam/group", json={"name": "New Group", "organisation_id": 2}
|
||||||
)
|
)
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
assert "Must be the org's root user" in resp.json()["detail"]
|
assert "Must be the org's root user" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_put_iam_group_permission_auth_root(no_su_client: AsyncClient):
|
async def test_put_iam_group_permission_auth_root(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.put(
|
resp = await no_su_client.put(
|
||||||
"/iam/group/permission",
|
"/iam/group/permission",
|
||||||
json={"permission_id": 1, "group_id": 2, "organisation_id": 2},
|
json={"permission_id": 1, "group_id": 2, "organisation_id": 2},
|
||||||
)
|
)
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
assert "Must be the org's root user" in resp.json()["detail"]
|
assert "Must be the org's root user" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_put_iam_group_user_auth_root(
|
async def test_put_iam_group_user_auth_root(
|
||||||
no_su_client: AsyncClient,
|
no_su_client: AsyncClient,
|
||||||
):
|
):
|
||||||
resp = await no_su_client.put(
|
resp = await no_su_client.put(
|
||||||
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2}
|
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2}
|
||||||
)
|
)
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
assert "Must be the org's root user" in resp.json()["detail"]
|
assert "Must be the org's root user" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_iam_permissions_auth_root(no_su_client: AsyncClient):
|
async def test_get_iam_permissions_auth_root(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.get("/iam/permissions?org_id=2")
|
resp = await no_su_client.get("/iam/permissions?org_id=2")
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
assert "Must be the org's root user" in resp.json()["detail"]
|
assert "Must be the org's root user" in resp.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_post_iam_permissions_search_auth_root(no_su_client: AsyncClient):
|
async def test_post_iam_permissions_search_auth_root(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.post(
|
resp = await no_su_client.post(
|
||||||
"/iam/permissions/search", json={"organisation_id": 2, "action": "read"}
|
"/iam/permissions/search", json={"organisation_id": 2, "action": "read"}
|
||||||
)
|
)
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
assert "Must be the org's root user" in resp.json()["detail"]
|
assert "Must be the org's root user" in resp.json()["detail"]
|
||||||
|
|
|
||||||
|
|
@ -7,69 +7,69 @@ import pytest
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
|
||||||
pytestmark = [
|
pytestmark = [
|
||||||
pytest.mark.auth,
|
pytest.mark.auth,
|
||||||
pytest.mark.super_admin,
|
pytest.mark.super_admin,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_user_auth_su(no_su_client: AsyncClient):
|
async def test_get_user_auth_su(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.get("/user?user_id=1")
|
resp = await no_su_client.get("/user?user_id=1")
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
assert resp.json()["detail"] == "Must be super admin"
|
assert resp.json()["detail"] == "Must be super admin"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_patch_org_status_auth_su(no_su_client: AsyncClient):
|
async def test_patch_org_status_auth_su(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.patch(
|
resp = await no_su_client.patch(
|
||||||
"/org/status", json={"organisation_id": 1, "status": "submitted"}
|
"/org/status", json={"organisation_id": 1, "status": "submitted"}
|
||||||
)
|
)
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
assert resp.json()["detail"] == "Must be super admin"
|
assert resp.json()["detail"] == "Must be super admin"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_patch_org_root_user_auth_su(no_su_client: AsyncClient):
|
async def test_patch_org_root_user_auth_su(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.patch(
|
resp = await no_su_client.patch(
|
||||||
"/org/root_user", json={"organisation_id": 1, "user_id": 2}
|
"/org/root_user", json={"organisation_id": 1, "user_id": 2}
|
||||||
)
|
)
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
assert resp.json()["detail"] == "Must be super admin"
|
assert resp.json()["detail"] == "Must be super admin"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_patch_service_key_auth_su(no_su_client: AsyncClient):
|
async def test_patch_service_key_auth_su(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.patch("/service/key", json={"service_id": 1})
|
resp = await no_su_client.patch("/service/key", json={"service_id": 1})
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
assert resp.json()["detail"] == "Must be super admin"
|
assert resp.json()["detail"] == "Must be super admin"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_post_service_auth_su(no_su_client: AsyncClient):
|
async def test_post_service_auth_su(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.post("/service", json={"name": "New Test Service"})
|
resp = await no_su_client.post("/service", json={"name": "New Test Service"})
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
assert resp.json()["detail"] == "Must be super admin"
|
assert resp.json()["detail"] == "Must be super admin"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_post_perm_auth_su(no_su_client: AsyncClient):
|
async def test_post_perm_auth_su(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.post(
|
resp = await no_su_client.post(
|
||||||
"/iam/permission",
|
"/iam/permission",
|
||||||
json={"service_id": 1, "resource": "test_resource", "action": "create"},
|
json={"service_id": 1, "resource": "test_resource", "action": "create"},
|
||||||
)
|
)
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
assert resp.json()["detail"] == "Must be super admin"
|
assert resp.json()["detail"] == "Must be super admin"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_post_org_user_auth_su(no_su_client: AsyncClient):
|
async def test_post_org_user_auth_su(no_su_client: AsyncClient):
|
||||||
resp = await no_su_client.post("/org/user", json={"organisation_id": 1, "user_id": 2})
|
resp = await no_su_client.post("/org/user", json={"organisation_id": 1, "user_id": 2})
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
assert "Must be super admin" in resp.json()["detail"]
|
assert "Must be super admin" in resp.json()["detail"]
|
||||||
|
|
|
||||||
|
|
@ -7,22 +7,22 @@ from httpx import AsyncClient
|
||||||
|
|
||||||
|
|
||||||
pytestmark = [
|
pytestmark = [
|
||||||
pytest.mark.auth,
|
pytest.mark.auth,
|
||||||
pytest.mark.user,
|
pytest.mark.user,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_self_db_auth_user(no_user_client: AsyncClient):
|
async def test_get_self_db_auth_user(no_user_client: AsyncClient):
|
||||||
resp = await no_user_client.get("/user/self/db")
|
resp = await no_user_client.get("/user/self/db")
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 401
|
assert resp.status_code == 401
|
||||||
assert resp.json()["detail"] == "Not authenticated"
|
assert resp.json()["detail"] == "Not authenticated"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_post_org_success_auth_user(no_user_client: AsyncClient):
|
async def test_post_org_success_auth_user(no_user_client: AsyncClient):
|
||||||
resp = await no_user_client.post("/org", json={"name": "New Test Org"})
|
resp = await no_user_client.post("/org", json={"name": "New Test Org"})
|
||||||
assert resp.status_code != 422
|
assert resp.status_code != 422
|
||||||
assert resp.status_code == 401
|
assert resp.status_code == 401
|
||||||
assert resp.json()["detail"] == "Not authenticated"
|
assert resp.json()["detail"] == "Not authenticated"
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from httpx import AsyncClient
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_healthcheck(default_client: AsyncClient):
|
async def test_healthcheck(default_client: AsyncClient):
|
||||||
resp = await default_client.get("/healthcheck")
|
resp = await default_client.get("/healthcheck")
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert resp.json() == {"status": "ok"}
|
assert resp.json() == {"status": "ok"}
|
||||||
|
|
|
||||||
1000
test/test_iam.py
1000
test/test_iam.py
File diff suppressed because it is too large
Load diff
|
|
@ -9,506 +9,506 @@ from .conftest import generate_query_and_status
|
||||||
|
|
||||||
|
|
||||||
pytestmark = [
|
pytestmark = [
|
||||||
pytest.mark.org_module,
|
pytest.mark.org_module,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_org_success(default_client: AsyncClient):
|
async def test_get_org_success(default_client: AsyncClient):
|
||||||
resp = await default_client.get("/org?org_id=1")
|
resp = await default_client.get("/org?org_id=1")
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
org = data["organisations"][0]
|
org = data["organisations"][0]
|
||||||
|
|
||||||
assert isinstance(org, dict)
|
assert isinstance(org, dict)
|
||||||
assert org["organisation_id"] == 1
|
assert org["organisation_id"] == 1
|
||||||
assert org["name"] == "Org One"
|
assert org["name"] == "Org One"
|
||||||
assert org["status"] == "approved"
|
assert org["status"] == "approved"
|
||||||
assert org["root_user_email"] == "admin@test.com"
|
assert org["root_user_email"] == "admin@test.com"
|
||||||
assert "intake_questionnaire" in org
|
assert "intake_questionnaire" in org
|
||||||
assert isinstance(org["intake_questionnaire"], dict)
|
assert isinstance(org["intake_questionnaire"], dict)
|
||||||
|
|
||||||
assert org["billing_contact"]["email"] == "billing@orgone.com"
|
assert org["billing_contact"]["email"] == "billing@orgone.com"
|
||||||
assert org["owner_contact"]["email"] == "owner@orgone.com"
|
assert org["owner_contact"]["email"] == "owner@orgone.com"
|
||||||
assert org["security_contact"]["email"] == "security@orgone.com"
|
assert org["security_contact"]["email"] == "security@orgone.com"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
|
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_org_status_checks(
|
async def test_get_org_status_checks(
|
||||||
default_client: AsyncClient, query: str, expected_status: int
|
default_client: AsyncClient, query: str, expected_status: int
|
||||||
):
|
):
|
||||||
resp = await default_client.get(f"/org?{query}")
|
resp = await default_client.get(f"/org?{query}")
|
||||||
|
|
||||||
assert resp.status_code == expected_status
|
assert resp.status_code == expected_status
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_post_org_success(default_client: AsyncClient):
|
async def test_post_org_success(default_client: AsyncClient):
|
||||||
resp = await default_client.post("/org", json={"name": "New Test Org"})
|
resp = await default_client.post("/org", json={"name": "New Test Org"})
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
|
||||||
assert resp.status_code == 201
|
assert resp.status_code == 201
|
||||||
assert data["name"] == "New Test Org"
|
assert data["name"] == "New Test Org"
|
||||||
assert data["status"] == "partial"
|
assert data["status"] == "partial"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"body, expected_status",
|
"body, expected_status",
|
||||||
[
|
[
|
||||||
({"name": 42}, 422),
|
({"name": 42}, 422),
|
||||||
({}, 422),
|
({}, 422),
|
||||||
({"name": "New Test Org", "intake_questionnaire": {"question_one": 42}}, 422),
|
({"name": "New Test Org", "intake_questionnaire": {"question_one": 42}}, 422),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_post_org_status_checks(
|
async def test_post_org_status_checks(
|
||||||
default_client: AsyncClient, body: dict[str, str], expected_status: int
|
default_client: AsyncClient, body: dict[str, str], expected_status: int
|
||||||
):
|
):
|
||||||
resp = await default_client.post("/org", json=body)
|
resp = await default_client.post("/org", json=body)
|
||||||
|
|
||||||
assert resp.status_code == expected_status
|
assert resp.status_code == expected_status
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_patch_org_questionnaire_partial_success(default_client: AsyncClient):
|
async def test_patch_org_questionnaire_partial_success(default_client: AsyncClient):
|
||||||
resp = await default_client.patch(
|
resp = await default_client.patch(
|
||||||
"/org/questionnaire",
|
"/org/questionnaire",
|
||||||
json={
|
json={
|
||||||
"organisation_id": 3,
|
"organisation_id": 3,
|
||||||
"intake_questionnaire": {
|
"intake_questionnaire": {
|
||||||
"question_one": "new answer one",
|
"question_one": "new answer one",
|
||||||
"question_two": None,
|
"question_two": None,
|
||||||
"question_three": None,
|
"question_three": None,
|
||||||
},
|
},
|
||||||
"partial": True,
|
"partial": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert data["name"] == "Org Three"
|
assert data["name"] == "Org Three"
|
||||||
assert data["status"] == "partial"
|
assert data["status"] == "partial"
|
||||||
assert "intake_questionnaire" in data
|
assert "intake_questionnaire" in data
|
||||||
assert isinstance(data["intake_questionnaire"], dict)
|
assert isinstance(data["intake_questionnaire"], dict)
|
||||||
metadata = data["intake_questionnaire"]["metadata"]
|
metadata = data["intake_questionnaire"]["metadata"]
|
||||||
assert metadata["version"] == 0
|
assert metadata["version"] == 0
|
||||||
assert metadata["submission_date"] is None
|
assert metadata["submission_date"] is None
|
||||||
questions = data["intake_questionnaire"]["questions"]
|
questions = data["intake_questionnaire"]["questions"]
|
||||||
assert questions["question_one"] == "new answer one"
|
assert questions["question_one"] == "new answer one"
|
||||||
assert questions["question_two"] == "answer two"
|
assert questions["question_two"] == "answer two"
|
||||||
assert questions["question_three"] is None
|
assert questions["question_three"] is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"body, expected_status",
|
"body, expected_status",
|
||||||
[
|
[
|
||||||
({"organisation_id": 42}, 404),
|
({"organisation_id": 42}, 404),
|
||||||
({"organisation_id": "Org One"}, 422),
|
({"organisation_id": "Org One"}, 422),
|
||||||
({"organisation_id": ""}, 422),
|
({"organisation_id": ""}, 422),
|
||||||
({}, 422),
|
({}, 422),
|
||||||
(
|
(
|
||||||
{
|
{
|
||||||
"organisation_id": "1",
|
"organisation_id": "1",
|
||||||
"intake_questionnaire": {"question_one": 42},
|
"intake_questionnaire": {"question_one": 42},
|
||||||
"partial": True,
|
"partial": True,
|
||||||
},
|
},
|
||||||
422,
|
422,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
{"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}},
|
{"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}},
|
||||||
422,
|
422,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
{
|
{
|
||||||
"organisation_id": "1",
|
"organisation_id": "1",
|
||||||
"intake_questionnaire": {"question_one": "valid"},
|
"intake_questionnaire": {"question_one": "valid"},
|
||||||
"partial": 42,
|
"partial": 42,
|
||||||
},
|
},
|
||||||
422,
|
422,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_patch_questionnaire_partial_status_checks(
|
async def test_patch_questionnaire_partial_status_checks(
|
||||||
default_client: AsyncClient, body: dict[str, str], expected_status: int
|
default_client: AsyncClient, body: dict[str, str], expected_status: int
|
||||||
):
|
):
|
||||||
resp = await default_client.patch("/org/questionnaire", json=body)
|
resp = await default_client.patch("/org/questionnaire", json=body)
|
||||||
|
|
||||||
assert resp.status_code == expected_status
|
assert resp.status_code == expected_status
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_patch_org_questionnaire_submit_success(default_client: AsyncClient):
|
async def test_patch_org_questionnaire_submit_success(default_client: AsyncClient):
|
||||||
resp = await default_client.patch(
|
resp = await default_client.patch(
|
||||||
"/org/questionnaire",
|
"/org/questionnaire",
|
||||||
json={
|
json={
|
||||||
"organisation_id": 3,
|
"organisation_id": 3,
|
||||||
"intake_questionnaire": {
|
"intake_questionnaire": {
|
||||||
"question_one": "new answer one",
|
"question_one": "new answer one",
|
||||||
"question_two": None,
|
"question_two": None,
|
||||||
"question_three": None,
|
"question_three": None,
|
||||||
},
|
},
|
||||||
"partial": False,
|
"partial": False,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert data["name"] == "Org Three"
|
assert data["name"] == "Org Three"
|
||||||
assert data["status"] == "submitted"
|
assert data["status"] == "submitted"
|
||||||
assert "intake_questionnaire" in data
|
assert "intake_questionnaire" in data
|
||||||
assert isinstance(data["intake_questionnaire"], dict)
|
assert isinstance(data["intake_questionnaire"], dict)
|
||||||
metadata = data["intake_questionnaire"]["metadata"]
|
metadata = data["intake_questionnaire"]["metadata"]
|
||||||
assert metadata["version"] == 0
|
assert metadata["version"] == 0
|
||||||
assert metadata["submission_date"] is not None
|
assert metadata["submission_date"] is not None
|
||||||
questions = data["intake_questionnaire"]["questions"]
|
questions = data["intake_questionnaire"]["questions"]
|
||||||
assert questions["question_one"] == "new answer one"
|
assert questions["question_one"] == "new answer one"
|
||||||
assert questions["question_two"] == "answer two"
|
assert questions["question_two"] == "answer two"
|
||||||
assert questions["question_three"] is None
|
assert questions["question_three"] is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"status", ["partial", "submitted", "remediation", "approved", "rejected", "removed"]
|
"status", ["partial", "submitted", "remediation", "approved", "rejected", "removed"]
|
||||||
)
|
)
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_patch_org_status_success(default_client: AsyncClient, status: str):
|
async def test_patch_org_status_success(default_client: AsyncClient, status: str):
|
||||||
resp = await default_client.patch(
|
resp = await default_client.patch(
|
||||||
"/org/status", json={"organisation_id": 1, "status": status}
|
"/org/status", json={"organisation_id": 1, "status": status}
|
||||||
)
|
)
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert data["name"] == "Org One"
|
assert data["name"] == "Org One"
|
||||||
assert data["status"] == status
|
assert data["status"] == status
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"body, expected_status",
|
"body, expected_status",
|
||||||
[
|
[
|
||||||
({"organisation_id": 42}, 404),
|
({"organisation_id": 42}, 404),
|
||||||
({"organisation_id": "Org One"}, 422),
|
({"organisation_id": "Org One"}, 422),
|
||||||
({"organisation_id": ""}, 422),
|
({"organisation_id": ""}, 422),
|
||||||
({}, 422),
|
({}, 422),
|
||||||
({"organisation_id": "1", "status": True}, 422),
|
({"organisation_id": "1", "status": True}, 422),
|
||||||
({"organisation_id": "1", "status": 42}, 422),
|
({"organisation_id": "1", "status": 42}, 422),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_patch_org_status_status_checks(
|
async def test_patch_org_status_status_checks(
|
||||||
default_client: AsyncClient, body: dict[str, str], expected_status: int
|
default_client: AsyncClient, body: dict[str, str], expected_status: int
|
||||||
):
|
):
|
||||||
resp = await default_client.patch("/org/status", json=body)
|
resp = await default_client.patch("/org/status", json=body)
|
||||||
|
|
||||||
assert resp.status_code == expected_status
|
assert resp.status_code == expected_status
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_org_users_success(default_client: AsyncClient):
|
async def test_get_org_users_success(default_client: AsyncClient):
|
||||||
resp = await default_client.get("/org/users?org_id=1")
|
resp = await default_client.get("/org/users?org_id=1")
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
assert "users" in data
|
assert "users" in data
|
||||||
assert isinstance(data["users"], list)
|
assert isinstance(data["users"], list)
|
||||||
assert len(data["users"]) == 2
|
assert len(data["users"]) == 2
|
||||||
|
|
||||||
user = data["users"][0]
|
user = data["users"][0]
|
||||||
assert isinstance(user, dict)
|
assert isinstance(user, dict)
|
||||||
assert user["email"] == "admin@test.com"
|
assert user["email"] == "admin@test.com"
|
||||||
assert user["id"] == 1
|
assert user["id"] == 1
|
||||||
|
|
||||||
assert "organisation" in data
|
assert "organisation" in data
|
||||||
assert data["organisation"]["name"] == "Org One"
|
assert data["organisation"]["name"] == "Org One"
|
||||||
assert data["organisation"]["id"] == 1
|
assert data["organisation"]["id"] == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
|
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_org_users_status_checks(
|
async def test_get_org_users_status_checks(
|
||||||
default_client: AsyncClient, query: str, expected_status: int
|
default_client: AsyncClient, query: str, expected_status: int
|
||||||
):
|
):
|
||||||
resp = await default_client.get(f"/org/users?{query}")
|
resp = await default_client.get(f"/org/users?{query}")
|
||||||
|
|
||||||
assert resp.status_code == expected_status
|
assert resp.status_code == expected_status
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_post_org_user_success(default_client: AsyncClient):
|
async def test_post_org_user_success(default_client: AsyncClient):
|
||||||
resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 3})
|
resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 3})
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
|
||||||
assert "organisation" in data
|
assert "organisation" in data
|
||||||
assert isinstance(data["organisation"], dict)
|
assert isinstance(data["organisation"], dict)
|
||||||
assert data["organisation"]["id"] == 1
|
assert data["organisation"]["id"] == 1
|
||||||
assert data["organisation"]["name"] == "Org One"
|
assert data["organisation"]["name"] == "Org One"
|
||||||
|
|
||||||
assert "users" in data
|
assert "users" in data
|
||||||
assert isinstance(data["users"], list)
|
assert isinstance(data["users"], list)
|
||||||
assert len([user for user in data["users"] if user["email"] == "root@orgtwo.com"]) == 1
|
assert len([user for user in data["users"] if user["email"] == "root@orgtwo.com"]) == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"body, expected_status",
|
"body, expected_status",
|
||||||
[
|
[
|
||||||
({"organisation_id": 42}, 404),
|
({"organisation_id": 42}, 404),
|
||||||
({}, 422),
|
({}, 422),
|
||||||
({"organisation_id": 1, "user_id": "id"}, 422),
|
({"organisation_id": 1, "user_id": "id"}, 422),
|
||||||
({"user_id": 2}, 422),
|
({"user_id": 2}, 422),
|
||||||
({"organisation_id": 1, "user_id": 42}, 404),
|
({"organisation_id": 1, "user_id": 42}, 404),
|
||||||
({"organisation_id": 1, "user_id": 1}, 409),
|
({"organisation_id": 1, "user_id": 1}, 409),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_post_org_user_status_checks(
|
async def test_post_org_user_status_checks(
|
||||||
default_client: AsyncClient, body: dict[str, str], expected_status: int
|
default_client: AsyncClient, body: dict[str, str], expected_status: int
|
||||||
):
|
):
|
||||||
resp = await default_client.post("/org/user", json=body)
|
resp = await default_client.post("/org/user", json=body)
|
||||||
|
|
||||||
assert resp.status_code == expected_status
|
assert resp.status_code == expected_status
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_patch_org_root_user_success(default_client: AsyncClient):
|
async def test_patch_org_root_user_success(default_client: AsyncClient):
|
||||||
resp = await default_client.patch(
|
resp = await default_client.patch(
|
||||||
"/org/root_user", json={"organisation_id": 1, "user_id": 2}
|
"/org/root_user", json={"organisation_id": 1, "user_id": 2}
|
||||||
)
|
)
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
|
||||||
assert data["name"] == "Org One"
|
assert data["name"] == "Org One"
|
||||||
assert data["root_user_email"] == "user@orgone.com"
|
assert data["root_user_email"] == "user@orgone.com"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"body, expected_status",
|
"body, expected_status",
|
||||||
[
|
[
|
||||||
({"organisation_id": 42, "user_id": 2}, 404),
|
({"organisation_id": 42, "user_id": 2}, 404),
|
||||||
({"organisation_id": "Org One", "user_id": 2}, 422),
|
({"organisation_id": "Org One", "user_id": 2}, 422),
|
||||||
({"organisation_id": "", "user_id": 2}, 422),
|
({"organisation_id": "", "user_id": 2}, 422),
|
||||||
({}, 422),
|
({}, 422),
|
||||||
({"user_id": 2}, 422),
|
({"user_id": 2}, 422),
|
||||||
({"user_id": 42}, 404),
|
({"user_id": 42}, 404),
|
||||||
({"organisation_id": 1, "user_id": "Test User"}, 422),
|
({"organisation_id": 1, "user_id": "Test User"}, 422),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_patch_root_user_status_checks(
|
async def test_patch_root_user_status_checks(
|
||||||
default_client: AsyncClient, body: dict[str, str], expected_status: int
|
default_client: AsyncClient, body: dict[str, str], expected_status: int
|
||||||
):
|
):
|
||||||
resp = await default_client.patch("/org/root_user", json=body)
|
resp = await default_client.patch("/org/root_user", json=body)
|
||||||
|
|
||||||
assert resp.status_code == expected_status
|
assert resp.status_code == expected_status
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_patch_org_root_user_non_member(default_client: AsyncClient):
|
async def test_patch_org_root_user_non_member(default_client: AsyncClient):
|
||||||
resp = await default_client.patch(
|
resp = await default_client.patch(
|
||||||
"/org/root_user", json={"organisation_id": 1, "user_id": 3}
|
"/org/root_user", json={"organisation_id": 1, "user_id": 3}
|
||||||
)
|
)
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
|
||||||
assert resp.status_code == 422
|
assert resp.status_code == 422
|
||||||
assert data["detail"] == "This user does not belong to your organisation."
|
assert data["detail"] == "This user does not belong to your organisation."
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_org_groups_success(default_client: AsyncClient):
|
async def test_get_org_groups_success(default_client: AsyncClient):
|
||||||
resp = await default_client.get("/org/groups?org_id=1")
|
resp = await default_client.get("/org/groups?org_id=1")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
|
||||||
assert "organisation" in data
|
assert "organisation" in data
|
||||||
assert isinstance(data["organisation"], dict)
|
assert isinstance(data["organisation"], dict)
|
||||||
assert data["organisation"]["id"] == 1
|
assert data["organisation"]["id"] == 1
|
||||||
assert data["organisation"]["name"] == "Org One"
|
assert data["organisation"]["name"] == "Org One"
|
||||||
|
|
||||||
assert "groups" in data
|
assert "groups" in data
|
||||||
assert isinstance(data["groups"], list)
|
assert isinstance(data["groups"], list)
|
||||||
group = data["groups"][0]
|
group = data["groups"][0]
|
||||||
assert isinstance(group, dict)
|
assert isinstance(group, dict)
|
||||||
assert group["id"] == 1
|
assert group["id"] == 1
|
||||||
assert group["name"] == "Org One Group"
|
assert group["name"] == "Org One Group"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
|
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_org_groups_status_checks(
|
async def test_get_org_groups_status_checks(
|
||||||
default_client: AsyncClient, query: str, expected_status: int
|
default_client: AsyncClient, query: str, expected_status: int
|
||||||
):
|
):
|
||||||
resp = await default_client.get(f"/org/groups?{query}")
|
resp = await default_client.get(f"/org/groups?{query}")
|
||||||
|
|
||||||
assert resp.status_code == expected_status
|
assert resp.status_code == expected_status
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("contact_type", ["billing", "security", "owner"])
|
@pytest.mark.parametrize("contact_type", ["billing", "security", "owner"])
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_org_contact_success(default_client: AsyncClient, contact_type: str):
|
async def test_get_org_contact_success(default_client: AsyncClient, contact_type: str):
|
||||||
resp = await default_client.get(f"/org/contact?org_id=1&contact_type={contact_type}")
|
resp = await default_client.get(f"/org/contact?org_id=1&contact_type={contact_type}")
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
assert "organisation" in data
|
assert "organisation" in data
|
||||||
assert data["organisation"]["id"] == 1
|
assert data["organisation"]["id"] == 1
|
||||||
assert data["organisation"]["name"] == "Org One"
|
assert data["organisation"]["name"] == "Org One"
|
||||||
|
|
||||||
attributes = [
|
attributes = [
|
||||||
"email",
|
"email",
|
||||||
"first_name",
|
"first_name",
|
||||||
"last_name",
|
"last_name",
|
||||||
"phonenumber",
|
"phonenumber",
|
||||||
"vat_number",
|
"vat_number",
|
||||||
"address",
|
"address",
|
||||||
]
|
]
|
||||||
|
|
||||||
for attribute in attributes:
|
for attribute in attributes:
|
||||||
assert attribute in data["contact"]
|
assert attribute in data["contact"]
|
||||||
|
|
||||||
address_attributes = [
|
address_attributes = [
|
||||||
"post_office_box_number",
|
"post_office_box_number",
|
||||||
"street_address",
|
"street_address",
|
||||||
"street_address_line_2",
|
"street_address_line_2",
|
||||||
"locality",
|
"locality",
|
||||||
"address_region",
|
"address_region",
|
||||||
"country_code",
|
"country_code",
|
||||||
"postal_code",
|
"postal_code",
|
||||||
]
|
]
|
||||||
|
|
||||||
for attribute in address_attributes:
|
for attribute in address_attributes:
|
||||||
assert attribute in data["contact"]["address"]
|
assert attribute in data["contact"]["address"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"query, expected_status",
|
"query, expected_status",
|
||||||
[
|
[
|
||||||
("org_id=42&contact_type=billing", 404),
|
("org_id=42&contact_type=billing", 404),
|
||||||
("org_id=banana&contact_type=billing", 422),
|
("org_id=banana&contact_type=billing", 422),
|
||||||
("", 422),
|
("", 422),
|
||||||
("org_id=1&contact_type=contact", 422),
|
("org_id=1&contact_type=contact", 422),
|
||||||
("contact_type=billing", 422),
|
("contact_type=billing", 422),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_org_contact_status_checks(
|
async def test_get_org_contact_status_checks(
|
||||||
default_client: AsyncClient, query: str, expected_status: int
|
default_client: AsyncClient, query: str, expected_status: int
|
||||||
):
|
):
|
||||||
resp = await default_client.get(f"/org/contact?{query}")
|
resp = await default_client.get(f"/org/contact?{query}")
|
||||||
|
|
||||||
assert resp.status_code == expected_status
|
assert resp.status_code == expected_status
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"key, value",
|
"key, value",
|
||||||
[
|
[
|
||||||
("email", "user@example.com"),
|
("email", "user@example.com"),
|
||||||
("first_name", "John"),
|
("first_name", "John"),
|
||||||
("last_name", "Doe"),
|
("last_name", "Doe"),
|
||||||
("phonenumber", "+441234567890"),
|
("phonenumber", "+441234567890"),
|
||||||
("vat_number", "GB123456789"),
|
("vat_number", "GB123456789"),
|
||||||
("post_office_box_number", "PO Box 123"),
|
("post_office_box_number", "PO Box 123"),
|
||||||
("street_address", "123 Example Street"),
|
("street_address", "123 Example Street"),
|
||||||
("street_address_line_2", "Suite 4B"),
|
("street_address_line_2", "Suite 4B"),
|
||||||
("locality", "Glasgow"),
|
("locality", "Glasgow"),
|
||||||
("address_region", "Glasgow City"),
|
("address_region", "Glasgow City"),
|
||||||
("country_code", "GB"),
|
("country_code", "GB"),
|
||||||
("postal_code", "G1 1AA"),
|
("postal_code", "G1 1AA"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_patch_org_contact_success(default_client: AsyncClient, key: str, value: str):
|
async def test_patch_org_contact_success(default_client: AsyncClient, key: str, value: str):
|
||||||
resp = await default_client.patch(
|
resp = await default_client.patch(
|
||||||
"/org/contact",
|
"/org/contact",
|
||||||
json={"organisation_id": 1, "contact_type": "billing", key: value},
|
json={"organisation_id": 1, "contact_type": "billing", key: value},
|
||||||
)
|
)
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
|
||||||
assert "organisation" in data
|
assert "organisation" in data
|
||||||
assert data["organisation"]["id"] == 1
|
assert data["organisation"]["id"] == 1
|
||||||
assert data["organisation"]["name"] == "Org One"
|
assert data["organisation"]["name"] == "Org One"
|
||||||
|
|
||||||
attributes = [
|
attributes = [
|
||||||
"email",
|
"email",
|
||||||
"first_name",
|
"first_name",
|
||||||
"last_name",
|
"last_name",
|
||||||
"phonenumber",
|
"phonenumber",
|
||||||
"vat_number",
|
"vat_number",
|
||||||
"address",
|
"address",
|
||||||
]
|
]
|
||||||
|
|
||||||
for attribute in attributes:
|
for attribute in attributes:
|
||||||
assert attribute in data["contact"]
|
assert attribute in data["contact"]
|
||||||
|
|
||||||
address_attributes = [
|
address_attributes = [
|
||||||
"post_office_box_number",
|
"post_office_box_number",
|
||||||
"street_address",
|
"street_address",
|
||||||
"street_address_line_2",
|
"street_address_line_2",
|
||||||
"locality",
|
"locality",
|
||||||
"address_region",
|
"address_region",
|
||||||
"country_code",
|
"country_code",
|
||||||
"postal_code",
|
"postal_code",
|
||||||
]
|
]
|
||||||
|
|
||||||
for attribute in address_attributes:
|
for attribute in address_attributes:
|
||||||
assert attribute in data["contact"]["address"]
|
assert attribute in data["contact"]["address"]
|
||||||
|
|
||||||
if key in data["contact"]:
|
if key in data["contact"]:
|
||||||
assert data["contact"][key] == value
|
assert data["contact"][key] == value
|
||||||
elif key in data["contact"]["address"]:
|
elif key in data["contact"]["address"]:
|
||||||
assert data["contact"]["address"][key] == value
|
assert data["contact"]["address"][key] == value
|
||||||
else:
|
else:
|
||||||
pytest.fail(f"Invalid contact key: {key}")
|
pytest.fail(f"Invalid contact key: {key}")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"body, expected_status",
|
"body, expected_status",
|
||||||
[
|
[
|
||||||
({"organisation_id": 42, "contact_type": "billing"}, 404),
|
({"organisation_id": 42, "contact_type": "billing"}, 404),
|
||||||
({"organisation_id": 1, "contact_type": "security"}, 200),
|
({"organisation_id": 1, "contact_type": "security"}, 200),
|
||||||
({"organisation_id": 1, "contact_type": "owner"}, 200),
|
({"organisation_id": 1, "contact_type": "owner"}, 200),
|
||||||
({"organisation_id": "Org One", "contact_type": "billing"}, 422),
|
({"organisation_id": "Org One", "contact_type": "billing"}, 422),
|
||||||
({"organisation_id": "", "contact_type": "billing"}, 422),
|
({"organisation_id": "", "contact_type": "billing"}, 422),
|
||||||
({}, 422),
|
({}, 422),
|
||||||
({"organisation_id": 1, "contact_type": "not_real"}, 422),
|
({"organisation_id": 1, "contact_type": "not_real"}, 422),
|
||||||
({"organisation_id": 1, "contact_type": 42}, 422),
|
({"organisation_id": 1, "contact_type": 42}, 422),
|
||||||
({"organisation_id": 1, "contact_type": ""}, 422),
|
({"organisation_id": 1, "contact_type": ""}, 422),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_patch_org_contact_status_checks(
|
async def test_patch_org_contact_status_checks(
|
||||||
default_client: AsyncClient, body: dict[str, str], expected_status: int
|
default_client: AsyncClient, body: dict[str, str], expected_status: int
|
||||||
):
|
):
|
||||||
resp = await default_client.patch("/org/contact", json=body)
|
resp = await default_client.patch("/org/contact", json=body)
|
||||||
|
|
||||||
assert resp.status_code == expected_status
|
assert resp.status_code == expected_status
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_delete_org_success(default_client: AsyncClient):
|
async def test_delete_org_success(default_client: AsyncClient):
|
||||||
resp = await default_client.delete("/org?org_id=1")
|
resp = await default_client.delete("/org?org_id=1")
|
||||||
|
|
||||||
assert resp.status_code == 204
|
assert resp.status_code == 204
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_delete_org_users_success(default_client: AsyncClient):
|
async def test_delete_org_users_success(default_client: AsyncClient):
|
||||||
resp = await default_client.delete("/org/user?org_id=1&user_id=2")
|
resp = await default_client.delete("/org/user?org_id=1&user_id=2")
|
||||||
|
|
||||||
assert resp.status_code == 204
|
assert resp.status_code == 204
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_delete_preapproval_org_success(default_client: AsyncClient):
|
async def test_delete_preapproval_org_success(default_client: AsyncClient):
|
||||||
resp = await default_client.delete("/org/self?org_id=3")
|
resp = await default_client.delete("/org/self?org_id=3")
|
||||||
|
|
||||||
assert resp.status_code == 204
|
assert resp.status_code == 204
|
||||||
|
|
|
||||||
|
|
@ -8,90 +8,90 @@ from httpx import AsyncClient
|
||||||
from .conftest import generate_query_and_status, generate_body_and_status
|
from .conftest import generate_query_and_status, generate_body_and_status
|
||||||
|
|
||||||
pytestmark = [
|
pytestmark = [
|
||||||
pytest.mark.service_module,
|
pytest.mark.service_module,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_services_success(default_client: AsyncClient):
|
async def test_get_services_success(default_client: AsyncClient):
|
||||||
resp = await default_client.get("/service?org_id=1")
|
resp = await default_client.get("/service?org_id=1")
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert "services" in data
|
assert "services" in data
|
||||||
assert isinstance(data["services"], list)
|
assert isinstance(data["services"], list)
|
||||||
assert data["services"][0]["id"] == 1
|
assert data["services"][0]["id"] == 1
|
||||||
assert data["services"][0]["name"] == "Test Service"
|
assert data["services"][0]["name"] == "Test Service"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
|
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_services_status_checks(
|
async def test_get_services_status_checks(
|
||||||
default_client: AsyncClient, query: str, expected_status: int
|
default_client: AsyncClient, query: str, expected_status: int
|
||||||
):
|
):
|
||||||
resp = await default_client.get(f"/service?{query}")
|
resp = await default_client.get(f"/service?{query}")
|
||||||
|
|
||||||
assert resp.status_code == expected_status
|
assert resp.status_code == expected_status
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_post_service_success(default_client: AsyncClient):
|
async def test_post_service_success(default_client: AsyncClient):
|
||||||
resp = await default_client.post("/service", json={"name": "New Test Service"})
|
resp = await default_client.post("/service", json={"name": "New Test Service"})
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert "service" in data
|
assert "service" in data
|
||||||
assert isinstance(data["service"], dict)
|
assert isinstance(data["service"], dict)
|
||||||
assert data["service"]["name"] == "New Test Service"
|
assert data["service"]["name"] == "New Test Service"
|
||||||
assert data["service"]["id"] == 2
|
assert data["service"]["id"] == 2
|
||||||
assert isinstance(data["service"]["api_key"], str)
|
assert isinstance(data["service"]["api_key"], str)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("body, expected_status", generate_body_and_status({"name": "str"}))
|
@pytest.mark.parametrize("body, expected_status", generate_body_and_status({"name": "str"}))
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_post_service_status_checks(
|
async def test_post_service_status_checks(
|
||||||
default_client: AsyncClient, body: dict[str, str], expected_status: int
|
default_client: AsyncClient, body: dict[str, str], expected_status: int
|
||||||
):
|
):
|
||||||
resp = await default_client.post("/service", json=body)
|
resp = await default_client.post("/service", json=body)
|
||||||
|
|
||||||
assert resp.status_code == expected_status
|
assert resp.status_code == expected_status
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_post_service_conflict(default_client: AsyncClient):
|
async def test_post_service_conflict(default_client: AsyncClient):
|
||||||
resp = await default_client.post("/service", json={"name": "Test Service"})
|
resp = await default_client.post("/service", json={"name": "Test Service"})
|
||||||
|
|
||||||
assert resp.status_code == 409
|
assert resp.status_code == 409
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_patch_service_success(default_client: AsyncClient):
|
async def test_patch_service_success(default_client: AsyncClient):
|
||||||
resp = await default_client.patch("/service/key", json={"service_id": 1})
|
resp = await default_client.patch("/service/key", json={"service_id": 1})
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert "service" in data
|
assert "service" in data
|
||||||
assert isinstance(data["service"], dict)
|
assert isinstance(data["service"], dict)
|
||||||
assert data["service"]["name"] == "Test Service"
|
assert data["service"]["name"] == "Test Service"
|
||||||
assert data["service"]["id"] == 1
|
assert data["service"]["id"] == 1
|
||||||
assert isinstance(data["service"]["api_key"], str)
|
assert isinstance(data["service"]["api_key"], str)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"body, expected_status",
|
"body, expected_status",
|
||||||
generate_body_and_status({"service_id": "int"}),
|
generate_body_and_status({"service_id": "int"}),
|
||||||
)
|
)
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_patch_services_status_checks(
|
async def test_patch_services_status_checks(
|
||||||
default_client: AsyncClient, body: dict[str, str], expected_status: int
|
default_client: AsyncClient, body: dict[str, str], expected_status: int
|
||||||
):
|
):
|
||||||
resp = await default_client.patch("/service/key", json=body)
|
resp = await default_client.patch("/service/key", json=body)
|
||||||
|
|
||||||
assert resp.status_code == expected_status
|
assert resp.status_code == expected_status
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_delete_service_success(default_client: AsyncClient):
|
async def test_delete_service_success(default_client: AsyncClient):
|
||||||
resp = await default_client.delete("/service?service_id=1")
|
resp = await default_client.delete("/service?service_id=1")
|
||||||
|
|
||||||
assert resp.status_code == 204
|
assert resp.status_code == 204
|
||||||
|
|
|
||||||
|
|
@ -5,202 +5,197 @@
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
from fastapi.routing import APIRoute, iter_route_contexts
|
from fastapi.routing import APIRoute
|
||||||
|
|
||||||
from .conftest import generate_query_and_status
|
from .conftest import generate_query_and_status
|
||||||
|
|
||||||
|
|
||||||
pytestmark = [
|
pytestmark = [
|
||||||
pytest.mark.user_module,
|
pytest.mark.user_module,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_self_db_success(default_client: AsyncClient):
|
async def test_get_self_db_success(default_client: AsyncClient):
|
||||||
resp = await default_client.get("/user/self/db")
|
resp = await default_client.get("/user/self/db")
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert data["first_name"] == "Admin"
|
assert data["first_name"] == "Admin"
|
||||||
assert data["last_name"] == "Test"
|
assert data["last_name"] == "Test"
|
||||||
assert data["email"] == "admin@test.com"
|
assert data["email"] == "admin@test.com"
|
||||||
assert "organisations" in data
|
assert "organisations" in data
|
||||||
assert isinstance(data["organisations"], list)
|
assert isinstance(data["organisations"], list)
|
||||||
assert "groups" in data
|
assert "groups" in data
|
||||||
assert isinstance(data["groups"], dict)
|
assert isinstance(data["groups"], dict)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_user_success(default_client: AsyncClient):
|
async def test_get_user_success(default_client: AsyncClient):
|
||||||
resp = await default_client.get("/user?user_id=1")
|
resp = await default_client.get("/user?user_id=1")
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert data["first_name"] == "Admin"
|
assert data["first_name"] == "Admin"
|
||||||
assert data["last_name"] == "Test"
|
assert data["last_name"] == "Test"
|
||||||
assert data["email"] == "admin@test.com"
|
assert data["email"] == "admin@test.com"
|
||||||
assert "organisations" in data
|
assert "organisations" in data
|
||||||
assert isinstance(data["organisations"], list)
|
assert isinstance(data["organisations"], list)
|
||||||
assert "groups" in data
|
assert "groups" in data
|
||||||
assert isinstance(data["groups"], dict)
|
assert isinstance(data["groups"], dict)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["user_id"]))
|
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["user_id"]))
|
||||||
async def test_get_user_status_checks(
|
async def test_get_user_status_checks(
|
||||||
default_client: AsyncClient, query: str, expected_status: int
|
default_client: AsyncClient, query: str, expected_status: int
|
||||||
):
|
):
|
||||||
resp = await default_client.get(f"/user?{query}")
|
resp = await default_client.get(f"/user?{query}")
|
||||||
|
|
||||||
assert resp.status_code == expected_status
|
assert resp.status_code == expected_status
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_delete_user_success(default_client: AsyncClient):
|
async def test_delete_user_success(default_client: AsyncClient):
|
||||||
resp = await default_client.delete("/user?user_id=1")
|
resp = await default_client.delete("/user?user_id=1")
|
||||||
|
|
||||||
assert resp.status_code == 204
|
assert resp.status_code == 204
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_post_user_invitation_success(default_client: AsyncClient):
|
async def test_post_user_invitation_success(default_client: AsyncClient):
|
||||||
body = {"user_email": "admin@test.com", "organisation_id": 1}
|
body = {"user_email": "admin@test.com", "organisation_id": 1}
|
||||||
resp = await default_client.post("/user/invitation", json=body)
|
resp = await default_client.post("/user/invitation", json=body)
|
||||||
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
assert "organisation" in data
|
assert "organisation" in data
|
||||||
assert isinstance(data["organisation"], dict)
|
assert isinstance(data["organisation"], dict)
|
||||||
assert data["organisation"]["id"] == 1
|
assert data["organisation"]["id"] == 1
|
||||||
assert data["organisation"]["name"] == "Org One"
|
assert data["organisation"]["name"] == "Org One"
|
||||||
|
|
||||||
assert "invited_email" in data
|
assert "invited_email" in data
|
||||||
assert isinstance(data["invited_email"], str)
|
assert isinstance(data["invited_email"], str)
|
||||||
assert data["invited_email"] == "admin@test.com"
|
assert data["invited_email"] == "admin@test.com"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"body, expected_status",
|
"body, expected_status",
|
||||||
[
|
[
|
||||||
({"organisation_id": 42, "user_email": "admin@test.com"}, 404),
|
({"organisation_id": 42, "user_email": "admin@test.com"}, 404),
|
||||||
({"organisation_id": "Test Org", "user_email": "admin@test.com"}, 422),
|
({"organisation_id": "Test Org", "user_email": "admin@test.com"}, 422),
|
||||||
({"organisation_id": "", "user_email": "admin@test.com"}, 422),
|
({"organisation_id": "", "user_email": "admin@test.com"}, 422),
|
||||||
({}, 422),
|
({}, 422),
|
||||||
({"user_email": 42}, 422),
|
({"user_email": 42}, 422),
|
||||||
({"organisation_id": 1, "user_email": "Test User"}, 422),
|
({"organisation_id": 1, "user_email": "Test User"}, 422),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_post_user_invitation_status_checks(
|
async def test_post_user_invitation_status_checks(
|
||||||
default_client: AsyncClient, body, expected_status
|
default_client: AsyncClient, body, expected_status
|
||||||
):
|
):
|
||||||
resp = await default_client.post("/user/invitation", json=body)
|
resp = await default_client.post("/user/invitation", json=body)
|
||||||
|
|
||||||
assert resp.status_code == expected_status
|
assert resp.status_code == expected_status
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"body, expected_status",
|
"body, expected_status",
|
||||||
[
|
[
|
||||||
({"jwt": "invalid"}, 401),
|
({"jwt": "invalid"}, 401),
|
||||||
({"jwt": ""}, 401),
|
({"jwt": ""}, 401),
|
||||||
({"jwt": None}, 422),
|
({"jwt": None}, 422),
|
||||||
({"jwt": 42}, 422),
|
({"jwt": 42}, 422),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_post_user_invitation_accept_status_checks(
|
async def test_post_user_invitation_accept_status_checks(
|
||||||
default_client: AsyncClient, body, expected_status
|
default_client: AsyncClient, body, expected_status
|
||||||
):
|
):
|
||||||
resp = await default_client.post("/user/invitation/accept", json=body)
|
resp = await default_client.post("/user/invitation/accept", json=body)
|
||||||
|
|
||||||
assert resp.status_code == expected_status
|
assert resp.status_code == expected_status
|
||||||
|
|
||||||
if resp.status_code == 401:
|
if resp.status_code == 401:
|
||||||
assert resp.json()["detail"] == "Invalid JWS"
|
assert resp.json()["detail"] == "Invalid JWS"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_self_orgs_success(default_client: AsyncClient):
|
async def test_get_self_orgs_success(default_client: AsyncClient):
|
||||||
resp = await default_client.get("/user/self/orgs")
|
resp = await default_client.get("/user/self/orgs")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
|
||||||
assert "organisations" in data
|
assert "organisations" in data
|
||||||
assert isinstance(data["organisations"], list)
|
assert isinstance(data["organisations"], list)
|
||||||
assert len(data["organisations"]) > 0
|
assert len(data["organisations"]) > 0
|
||||||
|
|
||||||
org = data["organisations"][0]
|
org = data["organisations"][0]
|
||||||
assert org["organisation_id"] == 1
|
assert org["organisation_id"] == 1
|
||||||
assert org["name"] == "Org One"
|
assert org["name"] == "Org One"
|
||||||
assert org["status"] == "approved"
|
assert org["status"] == "approved"
|
||||||
assert org["root_user_email"] == "admin@test.com"
|
assert org["root_user_email"] == "admin@test.com"
|
||||||
assert "intake_questionnaire" in org
|
assert "intake_questionnaire" in org
|
||||||
assert isinstance(org["intake_questionnaire"], dict)
|
assert isinstance(org["intake_questionnaire"], dict)
|
||||||
|
|
||||||
assert isinstance(org["billing_contact"], dict)
|
assert isinstance(org["billing_contact"], dict)
|
||||||
assert org["billing_contact"]["email"] == "billing@orgone.com"
|
assert org["billing_contact"]["email"] == "billing@orgone.com"
|
||||||
assert org["billing_contact"]["id"] == 1
|
assert org["billing_contact"]["id"] == 1
|
||||||
|
|
||||||
assert isinstance(org["owner_contact"], dict)
|
assert isinstance(org["owner_contact"], dict)
|
||||||
assert org["owner_contact"]["email"] == "owner@orgone.com"
|
assert org["owner_contact"]["email"] == "owner@orgone.com"
|
||||||
assert org["owner_contact"]["id"] == 2
|
assert org["owner_contact"]["id"] == 2
|
||||||
|
|
||||||
assert isinstance(org["security_contact"], dict)
|
assert isinstance(org["security_contact"], dict)
|
||||||
assert org["security_contact"]["email"] == "security@orgone.com"
|
assert org["security_contact"]["email"] == "security@orgone.com"
|
||||||
assert org["security_contact"]["id"] == 3
|
assert org["security_contact"]["id"] == 3
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_self_orgs_dynamic(default_client: AsyncClient):
|
async def test_get_self_orgs_dynamic(default_client: AsyncClient):
|
||||||
method = "GET"
|
method = "GET"
|
||||||
path = "/user/self/orgs"
|
path = "/user/self/orgs"
|
||||||
expected_data = {
|
expected_data = {
|
||||||
"organisations": [
|
"organisations": [
|
||||||
{
|
{
|
||||||
"organisation_id": 1,
|
"organisation_id": 1,
|
||||||
"name": "Org One",
|
"name": "Org One",
|
||||||
"status": "approved",
|
"status": "approved",
|
||||||
"root_user_email": "admin@test.com",
|
"root_user_email": "admin@test.com",
|
||||||
"owner_contact": {"email": "owner@orgone.com", "id": 2},
|
"owner_contact": {"email": "owner@orgone.com", "id": 2},
|
||||||
"security_contact": {"email": "security@orgone.com", "id": 3},
|
"security_contact": {"email": "security@orgone.com", "id": 3},
|
||||||
"billing_contact": {"email": "billing@orgone.com", "id": 1},
|
"billing_contact": {"email": "billing@orgone.com", "id": 1},
|
||||||
"intake_questionnaire": {
|
"intake_questionnaire": {
|
||||||
"questions": {
|
"questions": {
|
||||||
"question_one": None,
|
"question_one": None,
|
||||||
"question_three": None,
|
"question_three": None,
|
||||||
"question_two": "answer two",
|
"question_two": "answer two",
|
||||||
},
|
},
|
||||||
"metadata": {"version": 0, "submission_date": None},
|
"metadata": {"version": 0, "submission_date": None},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
resp = await default_client.get(path)
|
resp = await default_client.get(path)
|
||||||
|
|
||||||
contexts = list(iter_route_contexts(default_client._transport.app.routes)) # ty:ignore[unresolved-attribute]
|
route = next(
|
||||||
|
route
|
||||||
|
for route in default_client._transport.app.routes # ty:ignore[unresolved-attribute]
|
||||||
|
if isinstance(route, APIRoute) and path in route.path and method in route.methods
|
||||||
|
)
|
||||||
|
|
||||||
route = next(
|
assert resp.status_code == route.status_code
|
||||||
route.route
|
if route.status_code == 204:
|
||||||
for route in contexts
|
return
|
||||||
if isinstance(route.route, APIRoute)
|
|
||||||
and path in route.route.path
|
|
||||||
and isinstance(route.methods, set)
|
|
||||||
and method in route.methods
|
|
||||||
)
|
|
||||||
|
|
||||||
assert resp.status_code == route.status_code
|
expected_response_schema = route.response_model
|
||||||
if route.status_code == 204:
|
data = resp.json()
|
||||||
return
|
|
||||||
|
|
||||||
expected_response_schema = route.response_model
|
response_model = expected_response_schema(**data)
|
||||||
data = resp.json()
|
assert isinstance(response_model, expected_response_schema)
|
||||||
|
|
||||||
response_model = expected_response_schema(**data)
|
expected_response_model = expected_response_schema(**expected_data)
|
||||||
assert isinstance(response_model, expected_response_schema)
|
|
||||||
|
|
||||||
expected_response_model = expected_response_schema(**expected_data)
|
assert response_model == expected_response_model
|
||||||
|
|
||||||
assert response_model == expected_response_model
|
|
||||||
|
|
|
||||||
11
uv.lock
generated
11
uv.lock
generated
|
|
@ -6,9 +6,6 @@ requires-python = ">=3.12"
|
||||||
exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values.
|
exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values.
|
||||||
exclude-newer-span = "P2W"
|
exclude-newer-span = "P2W"
|
||||||
|
|
||||||
[options.exclude-newer-package]
|
|
||||||
fastapi = "2026-06-22T00:00:00Z"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "alembic"
|
name = "alembic"
|
||||||
version = "1.18.4"
|
version = "1.18.4"
|
||||||
|
|
@ -241,7 +238,7 @@ dev = [
|
||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "alembic", specifier = ">=1.18.4" },
|
{ name = "alembic", specifier = ">=1.18.4" },
|
||||||
{ name = "email-validator", specifier = ">=2.3.0" },
|
{ name = "email-validator", specifier = ">=2.3.0" },
|
||||||
{ name = "fastapi", specifier = ">=0.138.0" },
|
{ name = "fastapi", specifier = ">=0.136.3" },
|
||||||
{ name = "httptools", specifier = ">=0.7.1" },
|
{ name = "httptools", specifier = ">=0.7.1" },
|
||||||
{ name = "httpx", specifier = ">=0.28.1" },
|
{ name = "httpx", specifier = ">=0.28.1" },
|
||||||
{ name = "itsdangerous", specifier = ">=2.2.0" },
|
{ name = "itsdangerous", specifier = ">=2.2.0" },
|
||||||
|
|
@ -352,7 +349,7 @@ wheels = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fastapi"
|
name = "fastapi"
|
||||||
version = "0.138.0"
|
version = "0.136.3"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "annotated-doc" },
|
{ name = "annotated-doc" },
|
||||||
|
|
@ -361,9 +358,9 @@ dependencies = [
|
||||||
{ name = "typing-extensions" },
|
{ name = "typing-extensions" },
|
||||||
{ name = "typing-inspection" },
|
{ name = "typing-inspection" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/5b/58/ff455d9fe47c60abadb34b9e05a304b1f05f5ab8000ac01565156b6f5e43/fastapi-0.138.0.tar.gz", hash = "sha256:d445a4877636ad191e7053e08c9bf98cb921a6756776848400bb773d1740c061", size = 419240, upload-time = "2026-06-20T01:18:05.259Z" }
|
sdist = { url = "https://files.pythonhosted.org/packages/81/2d/ff8d91d7b564d464629a0fd50a4489c97fcb836ac230bf3a7269232a9b1f/fastapi-0.136.3.tar.gz", hash = "sha256:e487fae93ad408e6f47641ee4dfe389864fd7bec92e547ea8498fc13f43e83ab", size = 396410, upload-time = "2026-05-23T18:53:15.192Z" }
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/6c/ff/8496d9847a5fedae775eb49460722d3efaa80487854273e9647ae876218c/fastapi-0.138.0-py3-none-any.whl", hash = "sha256:b6f54fd1bd72c80b0f899f172c61a600f6f7af9b43d4d772a018f35624048cb0", size = 126779, upload-time = "2026-06-20T01:18:03.483Z" },
|
{ url = "https://files.pythonhosted.org/packages/e0/82/45359b62a067409bd929ae8a56b8ed13e5a8c8a61194b3c236920999ab83/fastapi-0.136.3-py3-none-any.whl", hash = "sha256:3d2a69bdf04b7e9f3afa292c3bc7a98816bbfafa10bc9b45f3f3700d2f761620", size = 117481, upload-time = "2026-05-23T18:53:16.924Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue