Compare commits

..

No commits in common. "c452c6c0d5421e337bc0628c6b20a03a3ea70147" and "da8917869f70c8ff8c9c1f218a23c7d54d2f1e24" have entirely different histories.

92 changed files with 714 additions and 1847 deletions

View file

@ -1,32 +0,0 @@
"""fix permission unique
Revision ID: b6c8614ef799
Revises: d9dc6986fe38
Create Date: 2026-06-08 16:00:27.533099
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'b6c8614ef799'
down_revision: Union[str, Sequence[str], None] = 'd9dc6986fe38'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_unique_constraint('uniq_permission_resource_and_action', 'permission', ['service_id', 'resource', 'action'])
# ### end Alembic commands ###
def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint('uniq_permission_resource_and_action', 'permission', type_='unique')
# ### end Alembic commands ###

View file

@ -7,10 +7,6 @@ exclude = [
".alembic" ".alembic"
] ]
[tool.ruff.format]
quote-style = "double"
indent-style = "tab"
[project] [project]
name = "cloud-api" name = "cloud-api"
version = "0.1.0" version = "0.1.0"

View file

@ -17,7 +17,6 @@ Exports:
- Dependencies should be used for db model get and validation where possible - Dependencies should be used for db model get and validation where possible
- Verify module level docstring is still accurate after updates - Verify module level docstring is still accurate after updates
""" """
from fastapi import APIRouter from fastapi import APIRouter

View file

@ -4,7 +4,6 @@ Router endpoints for the admin module
Exports: Exports:
- router: fastapi.APIRouter - router: fastapi.APIRouter
""" """
from fastapi import APIRouter from fastapi import APIRouter
router = APIRouter( router = APIRouter(

View file

@ -1,7 +1,6 @@
""" """
This module hooks the routers for the main endpoints into a single router for importing to the app. This module hooks the routers for the main endpoints into a single router for importing to the app.
""" """
from fastapi import APIRouter from fastapi import APIRouter
from src.auth.router import router as auth_router from src.auth.router import router as auth_router
@ -13,7 +12,9 @@ from src.iam.router import router as iam_router
from src.service.router import router as service_router from src.service.router import router as service_router
api_router = APIRouter(prefix="/api/v1") api_router = APIRouter(
prefix="/api/v1"
)
api_router.include_router(auth_router) api_router.include_router(auth_router)
api_router.include_router(contact_router) api_router.include_router(contact_router)

View file

@ -4,10 +4,8 @@ Configurations for the auth module
Exports: Exports:
- auth_settings: Contains OIDC information - auth_settings: Contains OIDC information
""" """
from src.config import CustomBaseSettings from src.config import CustomBaseSettings
class AuthConfig(CustomBaseSettings): class AuthConfig(CustomBaseSettings):
OIDC_CONFIG: str = "" OIDC_CONFIG: str = ""
OIDC_ISSUER: str = "" OIDC_ISSUER: str = ""

View file

@ -7,24 +7,18 @@ Exports:
- org_model_root_claim_body_dependency: org_model: verifies org exists and user is either root or su, gets org from body - org_model_root_claim_body_dependency: org_model: verifies org exists and user is either root or su, gets org from body
- super_admin_dependency: user_model: verifies the user is a super admin - super_admin_dependency: user_model: verifies the user is a super admin
""" """
from typing import Annotated from typing import Annotated
from fastapi import Depends from fastapi import Depends
from src.user.dependencies import user_model_claims_dependency from src.user.dependencies import user_model_claims_dependency
from src.user.models import User from src.user.models import User
from src.organisation.dependencies import ( from src.organisation.dependencies import org_model_query_dependency, org_model_body_dependency
org_model_query_dependency,
org_model_body_dependency,
)
from src.organisation.models import Organisation as Org from src.organisation.models import Organisation as Org
from src.auth.exceptions import UnauthorizedException from src.auth.exceptions import UnauthorizedException
async def org_query_user_claims( async def org_query_user_claims(org_model: org_model_query_dependency, user_model: user_model_claims_dependency):
org_model: org_model_query_dependency, user_model: user_model_claims_dependency
):
if user_model in org_model.user_rel: if user_model in org_model.user_rel:
return True return True
@ -34,11 +28,7 @@ async def org_query_user_claims(
org_query_user_claims_dependency = Annotated[bool, Depends(org_query_user_claims)] org_query_user_claims_dependency = Annotated[bool, Depends(org_query_user_claims)]
async def org_query_root_claims( async def org_query_root_claims(user_model: user_model_claims_dependency, org_model: org_model_query_dependency, su_emails: su_list_dependency):
user_model: user_model_claims_dependency,
org_model: org_model_query_dependency,
su_emails: su_list_dependency,
):
if org_model.root_user_id == user_model.id: if org_model.root_user_id == user_model.id:
return org_model return org_model
@ -51,16 +41,10 @@ async def org_query_root_claims(
raise UnauthorizedException(message="Must be the org's root user") raise UnauthorizedException(message="Must be the org's root user")
org_model_root_claim_query_dependency = Annotated[ org_model_root_claim_query_dependency = Annotated[type[Org], Depends(org_query_root_claims)]
type[Org], Depends(org_query_root_claims)
]
async def org_body_root_claims( async def org_body_root_claims(user_model: user_model_claims_dependency, org_model: org_model_body_dependency, su_emails: su_list_dependency):
user_model: user_model_claims_dependency,
org_model: org_model_body_dependency,
su_emails: su_list_dependency,
):
if org_model.root_user_id == user_model.id: if org_model.root_user_id == user_model.id:
return org_model return org_model
@ -73,29 +57,21 @@ async def org_body_root_claims(
raise UnauthorizedException(message="Must be the org's root user") raise UnauthorizedException(message="Must be the org's root user")
org_model_root_claim_body_dependency = Annotated[ org_model_root_claim_body_dependency = Annotated[type[Org], Depends(org_body_root_claims)]
type[Org], Depends(org_body_root_claims)
]
def get_super_admin_list(): def get_super_admin_list():
return [] return []
def empty_su_list(): def empty_su_list():
return [] return []
def testing_su_list(): def testing_su_list():
return ["admin@test.com"] return ["admin@test.com"]
su_list_dependency = Annotated[list[User], Depends(get_super_admin_list)] su_list_dependency = Annotated[list[User], Depends(get_super_admin_list)]
async def user_model_super_admin(user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency):
async def user_model_super_admin(
user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency
):
if user_model.email in super_admin_emails: if user_model.email in super_admin_emails:
return user_model return user_model

View file

@ -4,7 +4,6 @@ Module specific exceptions for the auth module
Exceptions: Exceptions:
- UnauthorizedException: Takes an optional message string - UnauthorizedException: Takes an optional message string
""" """
from typing import Optional from typing import Optional
from fastapi import HTTPException, status from fastapi import HTTPException, status

View file

@ -4,7 +4,6 @@ Router endpoints for the auth module
Exports: Exports:
- router: fastapi.APIRouter - router: fastapi.APIRouter
""" """
from fastapi import APIRouter from fastapi import APIRouter
router = APIRouter( router = APIRouter(

View file

@ -4,7 +4,6 @@ Module specific business logic for the auth module
Exports: Exports:
- claims_dependency: Dict[str, Any] containing OIDC claims and database ID - claims_dependency: Dict[str, Any] containing OIDC claims and database ID
""" """
import json import json
import requests import requests
@ -26,14 +25,11 @@ from src.database import db_dependency
oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG) oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG)
oidc_dependency = Annotated[str, Depends(oidc)] oidc_dependency = Annotated[str, Depends(oidc)]
def get_dev_user(): def get_dev_user():
return {"db_id": 1} return {"db_id": 1}
async def get_current_user( async def get_current_user(oidc_auth_string: oidc_dependency, db: db_dependency) -> dict[str, Any]:
oidc_auth_string: oidc_dependency, db: db_dependency
) -> dict[str, Any]:
config_url = urlopen(auth_settings.OIDC_CONFIG) config_url = urlopen(auth_settings.OIDC_CONFIG)
config = json.loads(config_url.read()) config = json.loads(config_url.read())
jwks_uri = config["jwks_uri"] jwks_uri = config["jwks_uri"]
@ -45,7 +41,10 @@ async def get_current_user(
"iss": {"essential": True, "value": auth_settings.OIDC_ISSUER}, "iss": {"essential": True, "value": auth_settings.OIDC_ISSUER},
} }
token = jwt.decode(oidc_auth_string.replace("Bearer ", ""), jwk_keys) token = jwt.decode(
oidc_auth_string.replace("Bearer ", ""),
jwk_keys
)
claims_requests = jwt.JWTClaimsRegistry(**claims_options) claims_requests = jwt.JWTClaimsRegistry(**claims_options)

View file

@ -36,7 +36,6 @@ class Config(CustomBaseSettings):
DATABASE_HOSTNAME: str = "localhost" DATABASE_HOSTNAME: str = "localhost"
DATABASE_CREDENTIALS: SecretStr = ":" DATABASE_CREDENTIALS: SecretStr = ":"
settings = Config() settings = Config()
DATABASE_NAME = settings.DATABASE_NAME DATABASE_NAME = settings.DATABASE_NAME
@ -44,14 +43,10 @@ DATABASE_PORT = settings.DATABASE_PORT
DATABASE_HOSTNAME = settings.DATABASE_HOSTNAME DATABASE_HOSTNAME = settings.DATABASE_HOSTNAME
DATABASE_CREDENTIALS = settings.DATABASE_CREDENTIALS.get_secret_value() DATABASE_CREDENTIALS = settings.DATABASE_CREDENTIALS.get_secret_value()
# this will support special chars for credentials # this will support special chars for credentials
_DATABASE_CREDENTIAL_USER, _DATABASE_CREDENTIAL_PASSWORD = str( _DATABASE_CREDENTIAL_USER, _DATABASE_CREDENTIAL_PASSWORD = str(DATABASE_CREDENTIALS).split(":")
DATABASE_CREDENTIALS
).split(":")
_QUOTED_DATABASE_PASSWORD = parse.quote_plus(str(_DATABASE_CREDENTIAL_PASSWORD)) _QUOTED_DATABASE_PASSWORD = parse.quote_plus(str(_DATABASE_CREDENTIAL_PASSWORD))
SQLALCHEMY_DATABASE_URI = SecretStr( SQLALCHEMY_DATABASE_URI = SecretStr(f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}")
f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}"
)
if settings.ENVIRONMENT == Environment.TESTING: if settings.ENVIRONMENT == Environment.TESTING:
SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:") SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:")

View file

@ -4,7 +4,6 @@ Global constants
Classes: Classes:
- Environment(StrEnum): LOCAL, TESTING, STAGING, PRODUCTION - Environment(StrEnum): LOCAL, TESTING, STAGING, PRODUCTION
""" """
from enum import StrEnum, auto from enum import StrEnum, auto

View file

@ -4,7 +4,6 @@ Exceptions related to the contact module
Exports: Exports:
- ContactNotFoundException: Takes an optional contact ID int - ContactNotFoundException: Takes an optional contact ID int
""" """
from typing import Optional from typing import Optional
from fastapi import HTTPException, status from fastapi import HTTPException, status
@ -12,11 +11,7 @@ from fastapi import HTTPException, status
class ContactNotFoundException(HTTPException): class ContactNotFoundException(HTTPException):
def __init__(self, contact_id: Optional[int] = None) -> None: def __init__(self, contact_id: Optional[int] = None) -> None:
detail = ( detail = "Contact not found" if contact_id is None else f"Contact with ID '{contact_id}' was not found."
"Contact not found"
if contact_id is None
else f"Contact with ID '{contact_id}' was not found."
)
super().__init__( super().__init__(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=detail, detail=detail,

View file

@ -5,7 +5,6 @@ Models:
- Contact: id[pk], email, first_name, last_name, phonenumber, vat_number - Contact: id[pk], email, first_name, last_name, phonenumber, vat_number
street_address, street_address_line_2, post_office_box_number, address_locality, country_code, address_region, postal_code street_address, street_address_line_2, post_office_box_number, address_locality, country_code, address_region, postal_code
""" """
from sqlalchemy import Column, Integer, String, ForeignKey from sqlalchemy import Column, Integer, String, ForeignKey
from src.database import Base from src.database import Base
@ -29,6 +28,4 @@ class Contact(Base):
address_region = Column(String, default=None, nullable=True) address_region = Column(String, default=None, nullable=True)
postal_code = Column(String) postal_code = Column(String)
org_id = Column( org_id = Column(Integer, ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False)
Integer, ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False
)

View file

@ -1,7 +1,6 @@
""" """
Router endpoints for the contact module Router endpoints for the contact module
""" """
from fastapi import APIRouter from fastapi import APIRouter

View file

@ -5,7 +5,6 @@ Models:
- ContactAddress - ContactAddress
- ContactModel: Contains ContactAddress as a property - ContactModel: Contains ContactAddress as a property
""" """
from typing import Optional from typing import Optional
from pydantic import EmailStr, ConfigDict from pydantic import EmailStr, ConfigDict

View file

@ -5,7 +5,6 @@ Exports:
- db_dependency - db_dependency
- Base (sqlalchemy base model) - Base (sqlalchemy base model)
""" """
from typing import Annotated from typing import Annotated
from sqlalchemy import create_engine, StaticPool from sqlalchemy import create_engine, StaticPool
from sqlalchemy.orm import DeclarativeBase, sessionmaker, Session from sqlalchemy.orm import DeclarativeBase, sessionmaker, Session
@ -17,11 +16,7 @@ from src.config import SQLALCHEMY_DATABASE_URI, settings as global_settings
if global_settings.ENVIRONMENT == Environment.TESTING: if global_settings.ENVIRONMENT == Environment.TESTING:
connect_args = {"check_same_thread": False} connect_args = {"check_same_thread": False}
engine = create_engine( engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value(), connect_args=connect_args, poolclass=StaticPool)
SQLALCHEMY_DATABASE_URI.get_secret_value(),
connect_args=connect_args,
poolclass=StaticPool,
)
else: else:
engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value()) engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value())
@ -41,7 +36,5 @@ def get_db():
db_dependency = Annotated[Session, Depends(get_db)] db_dependency = Annotated[Session, Depends(get_db)]
class Base(DeclarativeBase): class Base(DeclarativeBase):
pass pass

View file

@ -5,7 +5,6 @@ Exports:
- UnprocessableContentException - UnprocessableContentException
- ConflictException - ConflictException
""" """
from typing import Optional from typing import Optional
from fastapi import HTTPException, status from fastapi import HTTPException, status

View file

@ -6,7 +6,6 @@ Exports:
- group_model_body_dependency: group_model: Gets group model from db, if it exists. Uses group_id from request body. - group_model_body_dependency: group_model: Gets group model from db, if it exists. Uses group_id from request body.
- perm_model_body_dependency: perm_model: Gets perm model from db, if it exists. Uses perm_id from request body. - perm_model_body_dependency: perm_model: Gets perm model from db, if it exists. Uses perm_id from request body.
""" """
from typing import Annotated, Optional from typing import Annotated, Optional
from fastapi import Depends, Query from fastapi import Depends, Query
@ -18,22 +17,17 @@ from src.iam.exceptions import GroupNotFoundException, PermNotFoundException
from src.iam.schemas import GroupIDMixin, PermIDMixin from src.iam.schemas import GroupIDMixin, PermIDMixin
def get_group_model_query( def get_group_model_query(db: db_dependency, group_id: Annotated[int, Query(gt=0)]) -> type[Group]:
db: db_dependency, group_id: Annotated[int, Query(gt=0)]
) -> type[Group]:
group_model = db.get(Group, group_id) group_model = db.get(Group, group_id)
if group_model is None: if group_model is None:
raise GroupNotFoundException(group_id) raise GroupNotFoundException(group_id)
return group_model return group_model
group_model_query_dependency = Annotated[type[Group], Depends(get_group_model_query)] group_model_query_dependency = Annotated[type[Group], Depends(get_group_model_query)]
def get_group_model_body( def get_group_model_body(db: db_dependency, request_model: Optional[GroupIDMixin] = None) -> type[Group]:
db: db_dependency, request_model: Optional[GroupIDMixin] = None
) -> type[Group]:
group_id = getattr(request_model, "group_id", None) group_id = getattr(request_model, "group_id", None)
if group_id is None: if group_id is None:
raise GroupNotFoundException() raise GroupNotFoundException()
@ -43,13 +37,10 @@ def get_group_model_body(
return group_model return group_model
group_model_body_dependency = Annotated[type[Group], Depends(get_group_model_body)] group_model_body_dependency = Annotated[type[Group], Depends(get_group_model_body)]
def get_perm_model_body( def get_perm_model_body(db: db_dependency, request_model: Optional[PermIDMixin] = None) -> type[Permission]:
db: db_dependency, request_model: Optional[PermIDMixin] = None
) -> type[Permission]:
perm_id = getattr(request_model, "permission_id", None) perm_id = getattr(request_model, "permission_id", None)
if perm_id is None: if perm_id is None:
raise PermNotFoundException raise PermNotFoundException
@ -59,18 +50,4 @@ def get_perm_model_body(
return perm_model return perm_model
perm_model_body_dependency = Annotated[type[Permission], Depends(get_perm_model_body)] perm_model_body_dependency = Annotated[type[Permission], Depends(get_perm_model_body)]
def get_perm_model_query(
db: db_dependency, perm_id: Annotated[int, Query(gt=0)]
) -> type[Permission]:
perm_model = db.get(Permission, perm_id)
if perm_model is None:
raise PermNotFoundException(perm_id)
return perm_model
perm_model_query_dependency = Annotated[type[Permission], Depends(get_perm_model_query)]

View file

@ -5,7 +5,6 @@ Exceptions:
- GroupNotFoundException: Takes an optional group_id int - GroupNotFoundException: Takes an optional group_id int
- PermNotFoundException: Takes an optional perm_id int - PermNotFoundException: Takes an optional perm_id int
""" """
from typing import Optional from typing import Optional
from fastapi import HTTPException, status from fastapi import HTTPException, status
@ -13,11 +12,7 @@ from fastapi import HTTPException, status
class GroupNotFoundException(HTTPException): class GroupNotFoundException(HTTPException):
def __init__(self, group_id: Optional[int] = None) -> None: def __init__(self, group_id: Optional[int] = None) -> None:
detail = ( detail = "Group not found" if group_id is None else f"User with ID '{group_id}' was not found."
"Group not found"
if group_id is None
else f"User with ID '{group_id}' was not found."
)
super().__init__( super().__init__(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=detail, detail=detail,
@ -26,11 +21,7 @@ class GroupNotFoundException(HTTPException):
class PermNotFoundException(HTTPException): class PermNotFoundException(HTTPException):
def __init__(self, perm_id: Optional[int] = None) -> None: def __init__(self, perm_id: Optional[int] = None) -> None:
detail = ( detail = "Permission not found" if perm_id is None else f"User with ID '{perm_id}' was not found."
"Permission not found"
if perm_id is None
else f"User with ID '{perm_id}' was not found."
)
super().__init__( super().__init__(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=detail, detail=detail,

View file

@ -17,7 +17,6 @@ Models:
- UserGroups: - UserGroups:
- org_id[FK][PK], user_id[FK][PK], group_id[FK][PK] - org_id[FK][PK], user_id[FK][PK], group_id[FK][PK]
""" """
from sqlalchemy import Column, Integer, String, ForeignKey, UniqueConstraint from sqlalchemy import Column, Integer, String, ForeignKey, UniqueConstraint
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
@ -33,14 +32,7 @@ class Permission(Base):
service_id = Column(Integer, ForeignKey("service.id", ondelete="CASCADE")) service_id = Column(Integer, ForeignKey("service.id", ondelete="CASCADE"))
__table_args__ = ( UniqueConstraint("service_id", "resource", "action", name="uniq_permission_resource_and_action")
UniqueConstraint(
"service_id",
"resource",
"action",
name="uniq_permission_resource_and_action",
),
)
service_rel = relationship("Service", foreign_keys=[service_id]) service_rel = relationship("Service", foreign_keys=[service_id])
@ -49,10 +41,13 @@ class Permission(Base):
return self.service_rel.name return self.service_rel.name
group_rel = relationship( group_rel = relationship(
"Group", secondary="group_permissions", back_populates="permission_rel" "Group",
secondary="group_permissions",
back_populates="permission_rel"
) )
class Group(Base): class Group(Base):
__tablename__ = "group" __tablename__ = "group"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
@ -60,30 +55,28 @@ class Group(Base):
org_id = Column(Integer, ForeignKey("organisation.id", ondelete="CASCADE")) org_id = Column(Integer, ForeignKey("organisation.id", ondelete="CASCADE"))
user_rel = relationship("User", secondary="user_groups", back_populates="group_rel") user_rel = relationship(
"User",
secondary="user_groups",
back_populates="group_rel"
)
org_rel = relationship("Organisation", back_populates="group_rel") org_rel = relationship("Organisation", back_populates="group_rel")
permission_rel = relationship( permission_rel = relationship(
"Permission", secondary="group_permissions", back_populates="group_rel" "Permission",
secondary="group_permissions",
back_populates="group_rel"
) )
class GroupPermissions(Base): class GroupPermissions(Base):
__tablename__ = "group_permissions" __tablename__ = "group_permissions"
group_id = Column( group_id = Column(Integer, ForeignKey("group.id", ondelete="CASCADE"), primary_key=True)
Integer, ForeignKey("group.id", ondelete="CASCADE"), primary_key=True permission_id = Column(Integer, ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True)
)
permission_id = Column(
Integer, ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
)
class UserGroups(Base): class UserGroups(Base):
__tablename__ = "user_groups" __tablename__ = "user_groups"
user_id = Column( user_id = Column(Integer, ForeignKey("user.id", ondelete="CASCADE"), primary_key=True)
Integer, ForeignKey("user.id", ondelete="CASCADE"), primary_key=True group_id = Column(Integer, ForeignKey("group.id", ondelete="CASCADE"), primary_key=True)
)
group_id = Column(
Integer, ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
)

View file

@ -15,9 +15,9 @@ Endpoints:
- [DELETE](/iam/permission): [super admin]: Removes a permission - [DELETE](/iam/permission): [super admin]: Removes a permission
- [GET](/iam/permissions/search): [root user]: Returns a list of permissions matching a filter(service|resource|action) - [GET](/iam/permissions/search): [root user]: Returns a list of permissions matching a filter(service|resource|action)
""" """
from fastapi import APIRouter, status from fastapi import APIRouter, status
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from psycopg import errors
from src.service.exceptions import ServiceNotFoundException from src.service.exceptions import ServiceNotFoundException
from src.exceptions import ConflictException from src.exceptions import ConflictException
@ -25,50 +25,21 @@ from src.database import db_dependency
from src.schemas import ResourceName from src.schemas import ResourceName
from src.auth.exceptions import UnauthorizedException from src.auth.exceptions import UnauthorizedException
from src.auth.service import claims_dependency from src.auth.service import claims_dependency
from src.auth.dependencies import ( from src.auth.dependencies import org_model_root_claim_query_dependency, org_model_root_claim_body_dependency, \
org_model_root_claim_query_dependency, super_admin_dependency
org_model_root_claim_body_dependency,
super_admin_dependency,
)
from src.user.models import User from src.user.models import User
from src.user.dependencies import ( from src.user.dependencies import user_model_body_dependency
user_model_body_dependency,
user_model_query_dependency,
)
from src.organisation.models import Organisation as Org from src.organisation.models import Organisation as Org
from src.service.models import Service from src.service.models import Service
from src.iam.service import service_key_dependency from src.iam.service import service_key_dependency
from src.iam.models import ( from src.iam.models import Permission as Perm, GroupPermissions as GPerms, Group, UserGroups
Permission as Perm, from src.iam.dependencies import group_model_query_dependency, group_model_body_dependency, perm_model_body_dependency
GroupPermissions as GPerms, from src.iam.schemas import IAMGetGroupPermissionsResponse, IAMGetGroupUsersResponse, IAMPostGroupRequest, \
Group, GroupSchema, IAMPostGroupResponse, IAMPutGroupPermissionRequest, IAMPutGroupPermissionResponse, \
UserGroups, IAMPutGroupUserRequest, IAMPutGroupUserResponse, IAMDeleteGroupPermissionRequest, IAMDeleteGroupPermissionResponse, \
) IAMDeleteGroupUserRequest, IAMDeleteGroupUserResponse, IAMGetPermissionsResponse, IAMPostPermissionRequest, \
from src.iam.dependencies import ( IAMPostPermissionResponse, IAMDeletePermissionRequest, IAMGetPermissionsSearchRequest, IAMGetPermissionsSearchResponse
group_model_query_dependency,
group_model_body_dependency,
perm_model_body_dependency,
perm_model_query_dependency,
)
from src.iam.schemas import (
IAMGetGroupPermissionsResponse,
IAMGetGroupUsersResponse,
IAMPostGroupRequest,
GroupSchema,
IAMPostGroupResponse,
IAMPutGroupPermissionRequest,
IAMPutGroupPermissionResponse,
IAMPutGroupUserRequest,
IAMPutGroupUserResponse,
IAMDeleteGroupPermissionResponse,
IAMDeleteGroupUserResponse,
IAMGetPermissionsResponse,
IAMPostPermissionRequest,
IAMPostPermissionResponse,
IAMGetPermissionsSearchRequest,
IAMGetPermissionsSearchResponse,
)
router = APIRouter( router = APIRouter(
tags=["IAM"], tags=["IAM"],
@ -77,21 +48,15 @@ router = APIRouter(
@router.post("/can_act_on_resource") @router.post("/can_act_on_resource")
async def can_act_on_resource( async def can_act_on_resource(valid_key: service_key_dependency, db: db_dependency, user_claims: claims_dependency,
valid_key: service_key_dependency, rn: ResourceName, action: str) -> bool:
db: db_dependency,
user_claims: claims_dependency,
rn: ResourceName,
action: str,
) -> bool:
try: try:
user_id = user_claims["db_id"] user_id = user_claims["db_id"]
rn_org = rn.organisation rn_org = rn.organisation
rn_service = rn.service rn_service = rn.service
rn_resource = rn.resource rn_resource = rn.resource
result = ( result = (db.query(Perm)
db.query(Perm)
.join(Service, Service.id == Perm.service_id) .join(Service, Service.id == Perm.service_id)
.join(GPerms, GPerms.permission_id == Perm.id) .join(GPerms, GPerms.permission_id == Perm.id)
.join(Group, Group.id == GPerms.group_id) .join(Group, Group.id == GPerms.group_id)
@ -114,41 +79,28 @@ async def can_act_on_resource(
@router.get("/group/permissions", response_model=IAMGetGroupPermissionsResponse) @router.get("/group/permissions", response_model=IAMGetGroupPermissionsResponse)
async def get_group_permissions( async def get_group_permissions(group_model: group_model_query_dependency, org_model: org_model_root_claim_query_dependency):
group_model: group_model_query_dependency,
org_model: org_model_root_claim_query_dependency,
):
if group_model.org_id != org_model.id: if group_model.org_id != org_model.id:
raise UnauthorizedException("Group does not belong to this organization") raise UnauthorizedException("Group does not belong to this organization")
return {"permissions": group_model.permission_rel} return {"permissions": group_model.permission_rel}
@router.get("/group/users", response_model=IAMGetGroupUsersResponse) @router.get("/group/users", response_model=IAMGetGroupUsersResponse)
async def get_group_users( async def get_group_users(group_model: group_model_query_dependency, org_model: org_model_root_claim_query_dependency):
group_model: group_model_query_dependency,
org_model: org_model_root_claim_query_dependency,
):
if group_model.org_id != org_model.id: if group_model.org_id != org_model.id:
raise UnauthorizedException("Group does not belong to this organization") raise UnauthorizedException("Group does not belong to this organization")
return {"users": group_model.user_rel} return {"users": group_model.user_rel}
@router.post("/group", response_model=IAMPostGroupResponse) @router.post("/group", response_model=IAMPostGroupResponse)
async def create_group( async def create_group(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMPostGroupRequest):
db: db_dependency,
org_model: org_model_root_claim_body_dependency,
request_model: IAMPostGroupRequest,
):
group_model = Group(name=request_model.name, org_id=org_model.id) group_model = Group(name=request_model.name, org_id=org_model.id)
db.add(group_model) db.add(group_model)
try: try:
db.flush() db.flush()
except IntegrityError as e: except IntegrityError as e:
if ( if isinstance(e.orig, errors.UniqueViolation):
getattr(e.orig, "pgcode", None) == "23505" # Postgres unique violation
or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation
):
raise ConflictException("Group with this name already exists") raise ConflictException("Group with this name already exists")
response = GroupSchema(**group_model.__dict__) response = GroupSchema(**group_model.__dict__)
db.commit() db.commit()
@ -156,13 +108,7 @@ async def create_group(
@router.put("/group/permission", response_model=IAMPutGroupPermissionResponse) @router.put("/group/permission", response_model=IAMPutGroupPermissionResponse)
async def add_group_permission( async def add_group_permission(db: db_dependency, group_model: group_model_body_dependency, perm_model: perm_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMPutGroupPermissionRequest):
db: db_dependency,
group_model: group_model_body_dependency,
perm_model: perm_model_body_dependency,
org_model: org_model_root_claim_body_dependency,
request_model: IAMPutGroupPermissionRequest,
):
if group_model.org_id != org_model.id: if group_model.org_id != org_model.id:
raise UnauthorizedException("Group does not belong to this organization") raise UnauthorizedException("Group does not belong to this organization")
@ -172,22 +118,13 @@ async def add_group_permission(
group_model.permission_rel.append(perm_model) group_model.permission_rel.append(perm_model)
db.flush() db.flush()
response = IAMPutGroupPermissionResponse( response = IAMPutGroupPermissionResponse(group=GroupSchema(**group_model.__dict__), permissions=group_model.permission_rel)
group=GroupSchema(**group_model.__dict__),
permissions=group_model.permission_rel,
)
db.commit() db.commit()
return response return response
@router.put("/group/user", response_model=IAMPutGroupUserResponse) @router.put("/group/user", response_model=IAMPutGroupUserResponse)
async def add_group_user( async def add_group_user(db: db_dependency, group_model: group_model_body_dependency, user_model: user_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMPutGroupUserRequest):
db: db_dependency,
group_model: group_model_body_dependency,
user_model: user_model_body_dependency,
org_model: org_model_root_claim_body_dependency,
request_model: IAMPutGroupUserRequest,
):
if group_model.org_id != org_model.id: if group_model.org_id != org_model.id:
raise UnauthorizedException("Group does not belong to this organization") raise UnauthorizedException("Group does not belong to this organization")
@ -196,117 +133,76 @@ async def add_group_user(
group_model.user_rel.append(user_model) group_model.user_rel.append(user_model)
db.flush() db.flush()
response = IAMPutGroupUserResponse( response = IAMPutGroupUserResponse(group=GroupSchema(**group_model.__dict__), users=group_model.user_rel)
group=GroupSchema(**group_model.__dict__), users=group_model.user_rel
)
db.commit() db.commit()
return response return response
@router.delete("/group/permissions") @router.delete("/group/permissions")
async def remove_group_permissions( async def remove_group_permissions(db: db_dependency, group_model: group_model_body_dependency, perm_model: perm_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMDeleteGroupPermissionRequest):
db: db_dependency,
group_model: group_model_query_dependency,
perm_model: perm_model_query_dependency,
org_model: org_model_root_claim_query_dependency,
):
if group_model.org_id != org_model.id: if group_model.org_id != org_model.id:
raise UnauthorizedException("Group does not belong to this organization") raise UnauthorizedException("Group does not belong to this organization")
group_model.permission_rel.remove(perm_model) group_model.permission_rel.remove(perm_model)
db.flush() db.flush()
response = IAMDeleteGroupPermissionResponse( response = IAMDeleteGroupPermissionResponse(group=GroupSchema(**group_model.__dict__),
group=GroupSchema(**group_model.__dict__), permissions=group_model.permission_rel)
permissions=group_model.permission_rel,
)
db.commit() db.commit()
return response return response
@router.delete("/group/user") @router.delete("/group/user")
async def remove_group_user( async def remove_group_user(db: db_dependency, group_model: group_model_body_dependency, user_model: user_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMDeleteGroupUserRequest):
db: db_dependency,
group_model: group_model_query_dependency,
user_model: user_model_query_dependency,
org_model: org_model_root_claim_query_dependency,
):
if group_model.org_id != org_model.id: if group_model.org_id != org_model.id:
raise UnauthorizedException("Group does not belong to this organization") raise UnauthorizedException("Group does not belong to this organization")
user_model.group_rel.remove(group_model) user_model.group_rel.remove(group_model)
db.flush() db.flush()
response = IAMDeleteGroupUserResponse( response = IAMDeleteGroupUserResponse(group=GroupSchema(**group_model.__dict__), users=group_model.user_rel)
group=GroupSchema(**group_model.__dict__), users=group_model.user_rel
)
db.commit() db.commit()
return response return response
@router.get("/permissions", response_model=IAMGetPermissionsResponse) @router.get("/permissions", response_model=IAMGetPermissionsResponse)
async def get_permissions( async def get_permissions(db: db_dependency, org_model: org_model_root_claim_query_dependency):
db: db_dependency, org_model: org_model_root_claim_query_dependency
):
permission_models = db.query(Perm).all() permission_models = db.query(Perm).all()
return {"permissions": permission_models} return {"permissions": permission_models}
@router.post("/permission", response_model=IAMPostPermissionResponse) @router.post("/permission", response_model=IAMPostPermissionResponse)
async def create_new_permission( async def create_new_permission(db: db_dependency, su: super_admin_dependency, request_model: IAMPostPermissionRequest):
db: db_dependency,
su: super_admin_dependency,
request_model: IAMPostPermissionRequest,
):
service_model = db.get(Service, request_model.service_id) service_model = db.get(Service, request_model.service_id)
if service_model is None: if service_model is None:
raise ServiceNotFoundException(service_id=request_model.service_id) raise ServiceNotFoundException(service_id=request_model.service_id)
perm_model = Perm(**request_model.__dict__) perm_model = Perm(**request_model.__dict__)
db.add(perm_model)
try: try:
db.flush() db.add(perm_model)
except IntegrityError as e: except IntegrityError as e:
if ( if isinstance(e.orig, errors.UniqueViolation):
getattr(e.orig, "pgcode", None) == "23505" # Postgres unique violation
or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation
):
raise ConflictException(message="Permission already exists") raise ConflictException(message="Permission already exists")
response = { db.flush()
"service_name": perm_model.service_name, response = {"service_name": perm_model.service_name, "resource": perm_model.resource, "action": perm_model.action}
"resource": perm_model.resource,
"action": perm_model.action,
}
db.commit() db.commit()
return {"permission": response} return {"permission": response}
@router.delete("/permission", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/permission", status_code=status.HTTP_204_NO_CONTENT)
async def delete_permission( async def delete_permission(db: db_dependency, su: super_admin_dependency, perm_model: perm_model_body_dependency, request_model: IAMDeletePermissionRequest):
db: db_dependency,
su: super_admin_dependency,
perm_model: perm_model_query_dependency,
):
db.delete(perm_model) db.delete(perm_model)
db.commit() db.commit()
@router.post("/permissions/search", response_model=IAMGetPermissionsSearchResponse) @router.post("/permissions/search", response_model=IAMGetPermissionsSearchResponse)
async def post_permissions( async def post_permissions(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMGetPermissionsSearchRequest):
db: db_dependency,
org_model: org_model_root_claim_body_dependency,
request_model: IAMGetPermissionsSearchRequest,
):
permission_query = db.query(Perm) permission_query = db.query(Perm)
if request_model.service_id is not None: if request_model.service_id is not None:
permission_query = permission_query.filter( permission_query = permission_query.filter(Perm.service_id == request_model.service_id)
Perm.service_id == request_model.service_id
)
if request_model.resource is not None: if request_model.resource is not None:
permission_query = permission_query.filter( permission_query = permission_query.filter(Perm.resource == request_model.resource)
Perm.resource == request_model.resource
)
if request_model.action is not None: if request_model.action is not None:
permission_query = permission_query.filter(Perm.action == request_model. action) permission_query = permission_query.filter(Perm.action == request_model. action)

View file

@ -6,7 +6,6 @@ Models follow the nomenclature of:
- Mixins: "<Attribute>Mixin" - Mixins: "<Attribute>Mixin"
- Models: "<Module><Method><Resource><Opt:Resource><Direction>" ie "IAMGetGroupPermissionsResponse" - Models: "<Module><Method><Resource><Opt:Resource><Direction>" ie "IAMGetGroupPermissionsResponse"
""" """
from typing import Optional, Annotated from typing import Optional, Annotated
from pydantic import EmailStr, ConfigDict, Field from pydantic import EmailStr, ConfigDict, Field
@ -25,7 +24,6 @@ class UserSchema(CustomBaseModel):
last_name: str last_name: str
email: EmailStr email: EmailStr
class PermissionSchema(CustomBaseModel): class PermissionSchema(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore") model_config = ConfigDict(from_attributes=True, extra="ignore")
@ -33,82 +31,73 @@ class PermissionSchema(CustomBaseModel):
resource: str resource: str
action: str action: str
class GroupSchema(CustomBaseModel): class GroupSchema(CustomBaseModel):
id: int id: int
name: str name: str
class GroupIDMixin(CustomBaseModel): class GroupIDMixin(CustomBaseModel):
group_id: int = Field(gt=0) group_id: int = Field(gt=0)
class PermIDMixin(CustomBaseModel): class PermIDMixin(CustomBaseModel):
permission_id: int = Field(gt=0) permission_id: int = Field(gt=0)
class IAMGetGroupPermissionsResponse(CustomBaseModel): class IAMGetGroupPermissionsResponse(CustomBaseModel):
permissions: list[PermissionSchema] permissions: list[PermissionSchema]
class IAMGetGroupUsersResponse(CustomBaseModel): class IAMGetGroupUsersResponse(CustomBaseModel):
users : list[UserSchema] users : list[UserSchema]
class IAMPostGroupRequest(OrgIDMixin): class IAMPostGroupRequest(OrgIDMixin):
name: str = Field(min_length=3) name: str = Field(min_length=3)
class IAMPostGroupResponse(CustomBaseModel): class IAMPostGroupResponse(CustomBaseModel):
group: GroupSchema group: GroupSchema
class IAMPutGroupPermissionRequest(GroupIDMixin, PermIDMixin, OrgIDMixin): class IAMPutGroupPermissionRequest(GroupIDMixin, PermIDMixin, OrgIDMixin):
pass pass
class IAMPutGroupPermissionResponse(CustomBaseModel): class IAMPutGroupPermissionResponse(CustomBaseModel):
group: GroupSchema group: GroupSchema
permissions: list[PermissionSchema] permissions: list[PermissionSchema]
class IAMPutGroupUserRequest(GroupIDMixin, UserIDMixin, OrgIDMixin): class IAMPutGroupUserRequest(GroupIDMixin, UserIDMixin, OrgIDMixin):
pass pass
class IAMPutGroupUserResponse(CustomBaseModel): class IAMPutGroupUserResponse(CustomBaseModel):
group: GroupSchema group: GroupSchema
users: list[UserSchema] users: list[UserSchema]
class IAMDeleteGroupPermissionRequest(GroupIDMixin, PermIDMixin):
pass
class IAMDeleteGroupPermissionResponse(CustomBaseModel): class IAMDeleteGroupPermissionResponse(CustomBaseModel):
group: GroupSchema group: GroupSchema
permissions: list[PermissionSchema] permissions: list[PermissionSchema]
class IAMDeleteGroupUserRequest(GroupIDMixin, UserIDMixin):
pass
class IAMDeleteGroupUserResponse(CustomBaseModel): class IAMDeleteGroupUserResponse(CustomBaseModel):
group: GroupSchema group: GroupSchema
users: list[UserSchema] users: list[UserSchema]
class IAMGetPermissionsResponse(CustomBaseModel): class IAMGetPermissionsResponse(CustomBaseModel):
permissions: list[PermissionSchema] permissions: list[PermissionSchema]
class IAMPostPermissionRequest(ServiceIDMixin): class IAMPostPermissionRequest(ServiceIDMixin):
resource: str resource: str
action: str action: str
class IAMPostPermissionResponse(CustomBaseModel): class IAMPostPermissionResponse(CustomBaseModel):
permission: PermissionSchema permission: PermissionSchema
class IAMDeletePermissionRequest(PermIDMixin):
pass
class IAMGetPermissionsSearchRequest(OrgIDMixin): class IAMGetPermissionsSearchRequest(OrgIDMixin):
service_id: Annotated[int | None, Field(gt=0)] = None service_id: Annotated[int | None, Field(gt=0)] = None
resource: Optional[str] = None resource: Optional[str] = None
action: Optional[str] = None action: Optional[str] = None
class IAMGetPermissionsSearchResponse(CustomBaseModel): class IAMGetPermissionsSearchResponse(CustomBaseModel):
permissions: list[PermissionSchema] permissions: list[PermissionSchema]

View file

@ -4,7 +4,6 @@ Business logic reusable functions related to IAM
Exports: Exports:
- service_key_dependency: bool: verifies request headers contain the correct api key for the service - service_key_dependency: bool: verifies request headers contain the correct api key for the service
""" """
from typing import Annotated from typing import Annotated
from src.service.models import Service from src.service.models import Service
@ -20,16 +19,10 @@ def valid_service_key(db: db_dependency, request: Request, rn: ResourceName) ->
if not api_key: if not api_key:
raise UnauthorizedException("Missing API key") raise UnauthorizedException("Missing API key")
service = rn.service service = rn.service
result = ( result = db.query(Service).filter(Service.name == service).filter(Service.api_key == api_key).first()
db.query(Service)
.filter(Service.name == service)
.filter(Service.api_key == api_key)
.first()
)
if result is None: if result is None:
raise UnauthorizedException("Invalid API key") raise UnauthorizedException("Invalid API key")
return True return True
service_key_dependency = Annotated[bool, Depends(valid_service_key)] service_key_dependency = Annotated[bool, Depends(valid_service_key)]

View file

@ -1,7 +1,6 @@
""" """
Application root file: Inits the FastAPI application Application root file: Inits the FastAPI application
""" """
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import AsyncGenerator from typing import AsyncGenerator

View file

@ -1,3 +1,4 @@
""" """
Global database models Global database models
""" """

View file

@ -5,7 +5,6 @@ Classes:
- Status(StrEnum): PARTIAL, SUBMITTED, REMEDIATION, APPROVED, REJECTED, REMOVED - Status(StrEnum): PARTIAL, SUBMITTED, REMEDIATION, APPROVED, REJECTED, REMOVED
- ContactType(StrEnum): BILLING, SECURITY, OWNER - ContactType(StrEnum): BILLING, SECURITY, OWNER
""" """
from enum import StrEnum, auto from enum import StrEnum, auto

View file

@ -5,7 +5,6 @@ Exports:
- org_model_query_dependency: org_model: Gets org model from db, if it exists. Uses org_id from query param. Also verifies if the org has been approved. - org_model_query_dependency: org_model: Gets org model from db, if it exists. Uses org_id from query param. Also verifies if the org has been approved.
- org_model_body_dependency: org_model: Gets org model from db, if it exists. Uses org_id from request body. Also verifies if the org has been approved. - org_model_body_dependency: org_model: Gets org model from db, if it exists. Uses org_id from request body. Also verifies if the org has been approved.
""" """
from typing import Annotated, Optional from typing import Annotated, Optional
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -26,40 +25,25 @@ def get_org_model(db: Session, request: Request, org_id: int):
root = "/api/v1" root = "/api/v1"
pre_approval_endpoints = [ pre_approval_endpoints = [f"PATCH{root}/org/status", f"PATCH{root}/org/questionnaire", f"GET{root}/org", f"GET{root}/org/contact", f"PATCH{root}/org/contact"]
f"PATCH{root}/org/status",
f"PATCH{root}/org/questionnaire",
f"GET{root}/org",
f"GET{root}/org/contact",
f"PATCH{root}/org/contact",
]
current_request = f"{request.method}{request.url.path}" current_request = f"{request.method}{request.url.path}"
if ( if current_request not in pre_approval_endpoints and org_model.status != OrgStatus.APPROVED:
current_request not in pre_approval_endpoints
and org_model.status != OrgStatus.APPROVED
):
raise AwaitingApprovalException(org_id) raise AwaitingApprovalException(org_id)
return org_model return org_model
def get_org_model_query( def get_org_model_query(db: db_dependency, request: Request, org_id: Annotated[int, Query(gt=0)]) -> type[Org]:
db: db_dependency, request: Request, org_id: Annotated[int, Query(gt=0)]
) -> type[Org]:
return get_org_model(db, request, org_id) return get_org_model(db, request, org_id)
org_model_query_dependency = Annotated[type[Org], Depends(get_org_model_query)] org_model_query_dependency = Annotated[type[Org], Depends(get_org_model_query)]
def get_org_model_body( def get_org_model_body(db: db_dependency, request: Request, request_model: OrgIDMixin) -> type[Org]:
db: db_dependency, request: Request, request_model: OrgIDMixin
) -> type[Org]:
org_id: Optional[int] = getattr(request_model, "organisation_id", None) org_id: Optional[int] = getattr(request_model, "organisation_id", None)
if org_id is None: if org_id is None:
raise OrgNotFoundException raise OrgNotFoundException
return get_org_model(db, request, org_id) return get_org_model(db, request, org_id)
org_model_body_dependency = Annotated[type[Org], Depends(get_org_model_body)] org_model_body_dependency = Annotated[type[Org], Depends(get_org_model_body)]

View file

@ -5,7 +5,6 @@ Exceptions:
- OrgNotFoundException: Takes an optional org_id int - OrgNotFoundException: Takes an optional org_id int
- AwaitingApprovalException: Takes an optional org_id int - AwaitingApprovalException: Takes an optional org_id int
""" """
from typing import Optional from typing import Optional
from fastapi import HTTPException, status from fastapi import HTTPException, status
@ -13,24 +12,15 @@ from fastapi import HTTPException, status
class OrgNotFoundException(HTTPException): class OrgNotFoundException(HTTPException):
def __init__(self, org_id: Optional[int] = None) -> None: def __init__(self, org_id: Optional[int] = None) -> None:
detail = ( detail = "Organisation not found" if org_id is None else f"Organisation with ID '{org_id}' was not found."
"Organisation not found"
if org_id is None
else f"Organisation with ID '{org_id}' was not found."
)
super().__init__( super().__init__(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=detail, detail=detail,
) )
class AwaitingApprovalException(HTTPException): class AwaitingApprovalException(HTTPException):
def __init__(self, org_id: Optional[int] = None) -> None: def __init__(self, org_id: Optional[int] = None) -> None:
detail = ( detail = "Organisation has not been approved." if org_id is None else f"Organisation with ID '{org_id}' has not been approved."
"Organisation has not been approved."
if org_id is None
else f"Organisation with ID '{org_id}' has not been approved."
)
super().__init__( super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail, detail=detail,

View file

@ -13,7 +13,6 @@ Models:
- owner_contact_rel: ORM relationship to Contact with owner_contact FK - owner_contact_rel: ORM relationship to Contact with owner_contact FK
- OrgUsers: org_id[FK][PK], user_id[FK][PK] - OrgUsers: org_id[FK][PK], user_id[FK][PK]
""" """
from sqlalchemy import Column, Integer, String, ForeignKey, JSON from sqlalchemy import Column, Integer, String, ForeignKey, JSON
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
@ -35,7 +34,9 @@ class Organisation(Base):
owner_contact_id = Column(Integer, ForeignKey("contact.id")) owner_contact_id = Column(Integer, ForeignKey("contact.id"))
user_rel = relationship( user_rel = relationship(
"User", secondary="orgusers", back_populates="organisation_rel" "User",
secondary="orgusers",
back_populates="organisation_rel"
) )
group_rel = relationship("Group", back_populates="org_rel") group_rel = relationship("Group", back_populates="org_rel")
@ -53,9 +54,5 @@ class Organisation(Base):
class OrgUsers(Base): class OrgUsers(Base):
__tablename__ = "orgusers" __tablename__ = "orgusers"
org_id = Column( org_id = Column(Integer, ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True)
Integer, ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True user_id = Column(Integer, ForeignKey("user.id", ondelete="CASCADE"), primary_key=True)
)
user_id = Column(
Integer, ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
)

View file

@ -15,11 +15,11 @@ Endpoints:
- [GET](/org/contact): [root user]: Gets the (contact_type) contact for an org(id) - [GET](/org/contact): [root user]: Gets the (contact_type) contact for an org(id)
- [PATCH](/org/contact): [root user]: Updates the (contact_type) contact for an org(id). Any number of details can be changed. - [PATCH](/org/contact): [root user]: Updates the (contact_type) contact for an org(id). Any number of details can be changed.
""" """
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, status from fastapi import APIRouter, status
from fastapi.params import Query from fastapi.params import Query
from psycopg.errors import UniqueViolation
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from src.contact.schemas import ContactModel from src.contact.schemas import ContactModel
@ -28,42 +28,17 @@ from src.contact.models import Contact
from src.contact.schemas import ContactAddress from src.contact.schemas import ContactAddress
from src.contact.exceptions import ContactNotFoundException from src.contact.exceptions import ContactNotFoundException
from src.database import db_dependency from src.database import db_dependency
from src.user.dependencies import ( from src.user.dependencies import user_model_body_dependency, user_model_claims_dependency
user_model_body_dependency, from src.auth.dependencies import super_admin_dependency, org_model_root_claim_query_dependency, org_model_root_claim_body_dependency
user_model_claims_dependency,
user_model_query_dependency,
)
from src.auth.dependencies import (
super_admin_dependency,
org_model_root_claim_query_dependency,
org_model_root_claim_body_dependency,
)
from src.organisation.dependencies import ( from src.organisation.dependencies import org_model_body_dependency
org_model_body_dependency,
org_model_query_dependency,
)
from src.organisation.constants import ContactType from src.organisation.constants import ContactType
from src.organisation.models import Organisation as Org from src.organisation.models import Organisation as Org
from src.organisation.schemas import ( from src.organisation.schemas import OrgPostOrgRequest, OrgPatchQuestionnaireRequest, OrgPatchStatusRequest, \
OrgPostOrgRequest, OrgPatchContactRequest, \
OrgPatchQuestionnaireRequest, OrgPostUserRequest, OrgGetUserResponse, OrgGetContactResponse, OrgGetOrgResponse, OrgPatchRootRequest, \
OrgPatchStatusRequest, OrgGetGroupResponse, OrgDeleteUserRequest, OrgDeleteOrgRequest, OrgPostOrgResponse, OrgPatchQuestionnaireResponse, \
OrgPatchContactRequest, OrgPatchStatusResponse, OrgPostUserResponse, OrgPatchRootResponse, Questionnaire, OrgPatchContactResponse
OrgPostUserRequest,
OrgGetUserResponse,
OrgGetContactResponse,
OrgGetOrgResponse,
OrgPatchRootRequest,
OrgGetGroupResponse,
OrgPostOrgResponse,
OrgPatchQuestionnaireResponse,
OrgPatchStatusResponse,
OrgPostUserResponse,
OrgPatchRootResponse,
Questionnaire,
OrgPatchContactResponse,
)
router = APIRouter( router = APIRouter(
prefix="/org", prefix="/org",
@ -71,22 +46,16 @@ router = APIRouter(
) )
@router.get( @router.get("",
"",
summary="Get org details by ID.", summary="Get org details by ID.",
response_model=OrgGetOrgResponse, response_model=OrgGetOrgResponse,
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
responses={ responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"}, status.HTTP_200_OK: {"description": "Successful retrieval from database"},
status.HTTP_404_NOT_FOUND: {"description": "Organisation not found"}, status.HTTP_404_NOT_FOUND: {"description": "Organisation not found"},
status.HTTP_422_UNPROCESSABLE_CONTENT: { status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Missing or invalid org_id query parameter"},
"description": "Missing or invalid org_id query parameter" status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."},
}, })
status.HTTP_401_UNAUTHORIZED: {
"description": "Not authorised. Must be org root user."
},
},
)
async def get_org_by_id(org_model: org_model_root_claim_query_dependency): async def get_org_by_id(org_model: org_model_root_claim_query_dependency):
""" """
Returns organisation details including key member email addresses Returns organisation details including key member email addresses
@ -99,35 +68,23 @@ async def get_org_by_id(org_model: org_model_root_claim_query_dependency):
"billing_contact": org_model.billing_contact_rel.email, "billing_contact": org_model.billing_contact_rel.email,
"security_contact": org_model.security_contact_rel.email, "security_contact": org_model.security_contact_rel.email,
"root_user": org_model.root_user_email, "root_user": org_model.root_user_email,
"intake_questionnaire": org_model.intake_questionnaire, "intake_questionnaire": org_model.intake_questionnaire
} }
return response return response
@router.post( @router.post("",
"",
summary="Create new organisation.", summary="Create new organisation.",
status_code=status.HTTP_201_CREATED, status_code=status.HTTP_201_CREATED,
response_model=OrgPostOrgResponse, response_model=OrgPostOrgResponse,
responses={ responses={
status.HTTP_201_CREATED: {"description": "Successfully created organisation."}, status.HTTP_201_CREATED: {"description": "Successfully created organisation."},
status.HTTP_422_UNPROCESSABLE_CONTENT: { status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."},
"description": "Invalid data in request." status.HTTP_401_UNAUTHORIZED: {"description": "User must be logged in with OIDC to create organisation."},
}, status.HTTP_409_CONFLICT: {"description": "Organisation with this name already exists."},
status.HTTP_401_UNAUTHORIZED: { })
"description": "User must be logged in with OIDC to create organisation." async def create_org(db: db_dependency, user_model: user_model_claims_dependency, request_model: OrgPostOrgRequest):
},
status.HTTP_409_CONFLICT: {
"description": "Organisation with this name already exists."
},
},
)
async def create_org(
db: db_dependency,
user_model: user_model_claims_dependency,
request_model: OrgPostOrgRequest,
):
""" """
Creates a new organisation with optional questionnaire (to be completed or submitted). Creates a new organisation with optional questionnaire (to be completed or submitted).
ALl organisations are given the "partial" status on creation. See update_questionnaire() for more details. ALl organisations are given the "partial" status on creation. See update_questionnaire() for more details.
@ -144,21 +101,12 @@ async def create_org(
try: try:
db.flush() db.flush()
except IntegrityError as e: except IntegrityError as e:
if ( if isinstance(e.orig, UniqueViolation):
getattr(e.orig, "pgcode", None) == "23505" # Postgres unique violation raise ConflictException(message="Organisation with this name already exists")
or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation
):
raise ConflictException(
message="Organisation with this name already exists"
)
# Adds currently logged-in user to org users list and sets them as root_user # Adds currently logged-in user to org users list and sets them as root_user
org_model.user_rel.append(user_model) org_model.user_rel.append(user_model)
org_model.root_user_rel = user_model org_model.root_user_rel = user_model
for contact_type in [ for contact_type in ["billing_contact_id", "security_contact_id", "owner_contact_id"]:
"billing_contact_id",
"security_contact_id",
"owner_contact_id",
]:
contact_model = Contact(org_id=org_model.id) contact_model = Contact(org_id=org_model.id)
db.add(contact_model) db.add(contact_model)
db.flush() db.flush()
@ -168,26 +116,16 @@ async def create_org(
return response return response
@router.patch( @router.patch("/questionnaire",
"/questionnaire",
summary="Update questionnaire.", summary="Update questionnaire.",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=OrgPatchQuestionnaireResponse, response_model=OrgPatchQuestionnaireResponse,
responses={ responses={
status.HTTP_200_OK: {"description": "Successfully updated questionnaire."}, status.HTTP_200_OK: {"description": "Successfully updated questionnaire."},
status.HTTP_422_UNPROCESSABLE_CONTENT: { status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."},
"description": "Invalid data in request." status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."},
}, })
status.HTTP_401_UNAUTHORIZED: { async def update_questionnaire(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: OrgPatchQuestionnaireRequest):
"description": "Not authorised. Must be org root user."
},
},
)
async def update_questionnaire(
db: db_dependency,
org_model: org_model_root_claim_body_dependency,
request_model: OrgPatchQuestionnaireRequest,
):
""" """
Route for updating questionnaire. Route for updating questionnaire.
The partial bool allows for submission of partially completed questionnaire and/or The partial bool allows for submission of partially completed questionnaire and/or
@ -212,29 +150,16 @@ async def update_questionnaire(
return response return response
@router.patch( @router.patch("/status",
"/status",
summary="Update status of organisation.", summary="Update status of organisation.",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=OrgPatchStatusResponse, response_model=OrgPatchStatusResponse,
responses={ responses={
status.HTTP_200_OK: { status.HTTP_200_OK: {"description": "Successfully updated organisation status."},
"description": "Successfully updated organisation status." status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."},
}, status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be super admin."},
status.HTTP_422_UNPROCESSABLE_CONTENT: { })
"description": "Invalid data in request." async def update_status(db: db_dependency, org_model: org_model_body_dependency, su: super_admin_dependency, request_model: OrgPatchStatusRequest):
},
status.HTTP_401_UNAUTHORIZED: {
"description": "Not authorised. Must be super admin."
},
},
)
async def update_status(
db: db_dependency,
org_model: org_model_body_dependency,
su: super_admin_dependency,
request_model: OrgPatchStatusRequest,
):
""" """
Sets an organisation's status. This is the endpoint for approving or denying an organisation after reviewing the questionnaire. Sets an organisation's status. This is the endpoint for approving or denying an organisation after reviewing the questionnaire.
""" """
@ -245,57 +170,33 @@ async def update_status(
return response return response
@router.get( @router.get("/users",
"/users",
summary="Get email addresses of users of the organisation.", summary="Get email addresses of users of the organisation.",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=OrgGetUserResponse, response_model=OrgGetUserResponse,
responses={ responses={
status.HTTP_200_OK: {"description": "Successful retrieval of users."}, status.HTTP_200_OK: {"description": "Successful retrieval of users."},
status.HTTP_422_UNPROCESSABLE_CONTENT: { status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Org ID missing or invalid."},
"description": "Org ID missing or invalid." status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."},
}, })
status.HTTP_401_UNAUTHORIZED: {
"description": "Not authorised. Must be org root user."
},
},
)
async def get_users(org_model: org_model_root_claim_query_dependency): async def get_users(org_model: org_model_root_claim_query_dependency):
""" """
Returns a list of the email addresses of all users of the organisation. Returns a list of the email addresses of all users of the organisation.
""" """
return { return {"users": [user.email for user in org_model.user_rel], "organisation": org_model}
"users": [user.email for user in org_model.user_rel],
"organisation": org_model,
}
@router.post( @router.post("/user",
"/user",
summary="Add user to the organisation.", summary="Add user to the organisation.",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=OrgPostUserResponse, response_model=OrgPostUserResponse,
responses={ responses={
status.HTTP_200_OK: { status.HTTP_200_OK: {"description": "Successfully added user to the organisation."},
"description": "Successfully added user to the organisation." status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."},
}, status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."},
status.HTTP_401_UNAUTHORIZED: { status.HTTP_409_CONFLICT: {"description": "User is already a member of the organisation."},
"description": "Not authorised. Must be org root user." })
}, async def add_user_to_org(db: db_dependency, org_model: org_model_root_claim_body_dependency, user_model: user_model_body_dependency, request_model: OrgPostUserRequest):
status.HTTP_422_UNPROCESSABLE_CONTENT: {
"description": "Invalid data in request."
},
status.HTTP_409_CONFLICT: {
"description": "User is already a member of the organisation."
},
},
)
async def add_user_to_org(
db: db_dependency,
org_model: org_model_root_claim_body_dependency,
user_model: user_model_body_dependency,
request_model: OrgPostUserRequest,
):
""" """
Adds a user to the organisation. Adds a user to the organisation.
""" """
@ -308,27 +209,15 @@ async def add_user_to_org(
return response return response
@router.delete( @router.delete("",
"",
summary="Delete organisation from the hub.", summary="Delete organisation from the hub.",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
responses={ responses={
status.HTTP_204_NO_CONTENT: { status.HTTP_204_NO_CONTENT: {"description": "Successfully deleted organisation."},
"description": "Successfully deleted organisation." status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be super admin."},
}, status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Org ID missing or invalid."},
status.HTTP_401_UNAUTHORIZED: { })
"description": "Not authorised. Must be super admin." async def delete_organisation_by_id(db: db_dependency, org_model: org_model_body_dependency, su: super_admin_dependency, request_model: OrgDeleteOrgRequest):
},
status.HTTP_422_UNPROCESSABLE_CONTENT: {
"description": "Org ID missing or invalid."
},
},
)
async def delete_organisation_by_id(
db: db_dependency,
org_model: org_model_query_dependency,
su: super_admin_dependency,
):
""" """
Removes an organisation from the hub. Removes an organisation from the hub.
""" """
@ -336,59 +225,37 @@ async def delete_organisation_by_id(
db.commit() db.commit()
@router.patch( @router.patch("/root_user",
"/root_user",
summary="Update the root user of the organisation.", summary="Update the root user of the organisation.",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=OrgPatchRootResponse, response_model=OrgPatchRootResponse,
responses={ responses={
status.HTTP_200_OK: {"description": "Successfully updated root user."}, status.HTTP_200_OK: {"description": "Successfully updated root user."},
status.HTTP_422_UNPROCESSABLE_CONTENT: { status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."},
"description": "Invalid data in request." status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be super admin."},
}, })
status.HTTP_401_UNAUTHORIZED: { async def update_root_user(db: db_dependency, org_model: org_model_body_dependency, user_model: user_model_body_dependency, su: super_admin_dependency, request_model: OrgPatchRootRequest):
"description": "Not authorised. Must be super admin."
},
},
)
async def update_root_user(
db: db_dependency,
org_model: org_model_body_dependency,
user_model: user_model_body_dependency,
su: super_admin_dependency,
request_model: OrgPatchRootRequest,
):
""" """
Promotes an existing organisation user to the root user, giving them full control of the org. Promotes an existing organisation user to the root user, giving them full control of the org.
""" """
if user_model not in org_model.user_rel: if user_model not in org_model.user_rel:
raise UnprocessableContentException( raise UnprocessableContentException(message="This user does not belong to your organisation.")
message="This user does not belong to your organisation."
)
org_model.root_user_rel = user_model org_model.root_user_rel = user_model
db.flush() db.flush()
response = OrgPatchRootResponse( response = OrgPatchRootResponse(name=org_model.name, root_user_email=org_model.root_user_email)
name=org_model.name, root_user_email=org_model.root_user_email
)
db.commit() db.commit()
return response return response
@router.get( @router.get("/groups",
"/groups",
summary="Get all organisation IAM groups.", summary="Get all organisation IAM groups.",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=OrgGetGroupResponse, response_model=OrgGetGroupResponse,
responses={ responses={
status.HTTP_200_OK: {"description": "Successful retrieval of IAM groups."}, status.HTTP_200_OK: {"description": "Successful retrieval of IAM groups."},
status.HTTP_422_UNPROCESSABLE_CONTENT: { status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Org ID missing or invalid."},
"description": "Org ID missing or invalid." status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."},
}, })
status.HTTP_401_UNAUTHORIZED: {
"description": "Not authorised. Must be org root user."
},
},
)
async def get_org_groups(org_model: org_model_root_claim_query_dependency): async def get_org_groups(org_model: org_model_root_claim_query_dependency):
""" """
Returns a list of the names of all IAM groups created by the organisation. Returns a list of the names of all IAM groups created by the organisation.
@ -396,25 +263,15 @@ async def get_org_groups(org_model: org_model_root_claim_query_dependency):
return {"groups": [group.name for group in org_model.group_rel]} return {"groups": [group.name for group in org_model.group_rel]}
@router.delete( @router.delete("/user",
"/user",
summary="Remove user from organisation.", summary="Remove user from organisation.",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
responses={ responses={
status.HTTP_204_NO_CONTENT: {"description": "Successfully removed user."}, status.HTTP_204_NO_CONTENT: {"description": "Successfully removed user."},
status.HTTP_401_UNAUTHORIZED: { status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."},
"description": "Not authorised. Must be org root user." status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."},
}, })
status.HTTP_422_UNPROCESSABLE_CONTENT: { async def remove_user_from_org(db: db_dependency, org_model: org_model_root_claim_body_dependency, user_model: user_model_body_dependency, request_model: OrgDeleteUserRequest):
"description": "Invalid data in request."
},
},
)
async def remove_user_from_org(
db: db_dependency,
org_model: org_model_root_claim_query_dependency,
user_model: user_model_query_dependency,
):
""" """
Revokes a user's membership in an organisation. Revokes a user's membership in an organisation.
""" """
@ -425,27 +282,16 @@ async def remove_user_from_org(
db.commit() db.commit()
@router.get( @router.get("/contact",
"/contact",
summary="Get contact for organisation.", summary="Get contact for organisation.",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=OrgGetContactResponse, response_model=OrgGetContactResponse,
responses={ responses={
status.HTTP_200_OK: {"description": "Successful retrieval of contact."}, status.HTTP_200_OK: {"description": "Successful retrieval of contact."},
status.HTTP_422_UNPROCESSABLE_CONTENT: { status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."},
"description": "Invalid data in request." status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."},
}, })
status.HTTP_401_UNAUTHORIZED: { async def get_contact(org_model: org_model_root_claim_query_dependency, contact_type: Annotated[ContactType, Query(description="Must be billing|security|owner")]):
"description": "Not authorised. Must be org root user."
},
},
)
async def get_contact(
org_model: org_model_root_claim_query_dependency,
contact_type: Annotated[
ContactType, Query(description="Must be billing|security|owner")
],
):
""" """
Gets full details for a contact point at an organisation. Gets full details for a contact point at an organisation.
""" """
@ -463,33 +309,21 @@ async def get_contact(
raise ContactNotFoundException() raise ContactNotFoundException()
address = ContactAddress.model_validate(contact_model) address = ContactAddress.model_validate(contact_model)
contact_response = ContactModel.model_construct( contact_response = ContactModel.model_construct(**contact_model.__dict__, address=address)
**contact_model.__dict__, address=address
)
return {"contact": contact_response, "organisation": org_model} return {"contact": contact_response, "organisation": org_model}
@router.patch( @router.patch("/contact",
"/contact",
summary="Update contact for organisation.", summary="Update contact for organisation.",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=OrgPatchContactResponse, response_model=OrgPatchContactResponse,
responses={ responses={
status.HTTP_200_OK: {"description": "Successfully updated contact."}, status.HTTP_200_OK: {"description": "Successfully updated contact."},
status.HTTP_422_UNPROCESSABLE_CONTENT: { status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."},
"description": "Invalid data in request." status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."},
}, })
status.HTTP_401_UNAUTHORIZED: { async def update_contact(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: OrgPatchContactRequest):
"description": "Not authorised. Must be org root user."
},
},
)
async def update_contact(
db: db_dependency,
org_model: org_model_root_claim_body_dependency,
request_model: OrgPatchContactRequest,
):
""" """
Updates details for a contact point at an organisation. Updates details for a contact point at an organisation.
""" """
@ -517,9 +351,7 @@ async def update_contact(
db.flush() db.flush()
address = ContactAddress.model_validate(contact_model) address = ContactAddress.model_validate(contact_model)
contact_response = ContactModel.model_construct( contact_response = ContactModel.model_construct(**contact_model.__dict__, address=address)
**contact_model.__dict__, address=address
)
db.commit() db.commit()

View file

@ -6,7 +6,6 @@ Models follow the nomenclature of:
- Mixins: "<Attribute>Mixin" - Mixins: "<Attribute>Mixin"
- Models: "<Module><Method><Resource><Opt:Resource><Direction>" ie "OrgPostOrgRequest" - Models: "<Module><Method><Resource><Opt:Resource><Direction>" ie "OrgPostOrgRequest"
""" """
from typing import Optional from typing import Optional
from pydantic import EmailStr, ConfigDict, Field from pydantic import EmailStr, ConfigDict, Field
@ -21,13 +20,11 @@ from src.organisation.constants import Status, ContactType
class OrgIDMixin(CustomBaseModel): class OrgIDMixin(CustomBaseModel):
organisation_id: int = Field(gt=0) organisation_id: int = Field(gt=0)
class Questionnaire(CustomBaseModel): class Questionnaire(CustomBaseModel):
question_one: Optional[str] = None question_one: Optional[str] = None
question_two: Optional[str] = None question_two: Optional[str] = None
question_three: Optional[str] = None question_three: Optional[str] = None
class OrgSchema(CustomBaseModel): class OrgSchema(CustomBaseModel):
id: int id: int
name: str name: str
@ -37,32 +34,26 @@ class OrgPostOrgRequest(CustomBaseModel):
name: str name: str
intake_questionnaire: Optional[Questionnaire] = None intake_questionnaire: Optional[Questionnaire] = None
class OrgPostOrgResponse(CustomBaseModel): class OrgPostOrgResponse(CustomBaseModel):
name: str name: str
status: Status status: Status
class OrgPatchQuestionnaireRequest(OrgIDMixin): class OrgPatchQuestionnaireRequest(OrgIDMixin):
intake_questionnaire: Questionnaire intake_questionnaire: Questionnaire
partial: bool partial: bool
class OrgPatchQuestionnaireResponse(CustomBaseModel): class OrgPatchQuestionnaireResponse(CustomBaseModel):
name: str name: str
intake_questionnaire: Questionnaire intake_questionnaire: Questionnaire
status: Status status: Status
class OrgPatchStatusRequest(OrgIDMixin): class OrgPatchStatusRequest(OrgIDMixin):
status: Status status: Status
class OrgPatchStatusResponse(CustomBaseModel): class OrgPatchStatusResponse(CustomBaseModel):
name: str name: str
status: Status status: Status
class OrgPatchContactRequest(OrgIDMixin): class OrgPatchContactRequest(OrgIDMixin):
contact_type: ContactType contact_type: ContactType
@ -79,47 +70,41 @@ class OrgPatchContactRequest(OrgIDMixin):
country_code: Optional[str] = None country_code: Optional[str] = None
postal_code: Optional[str] = None postal_code: Optional[str] = None
class OrgPostUserRequest(OrgIDMixin, UserIDMixin): class OrgPostUserRequest(OrgIDMixin, UserIDMixin):
pass pass
class OrgPostUserResponse(CustomBaseModel): class OrgPostUserResponse(CustomBaseModel):
users: list[str] users: list[str]
class OrgDeleteUserRequest(OrgIDMixin, UserIDMixin):
pass
class OrgPatchRootRequest(OrgIDMixin, UserIDMixin): class OrgPatchRootRequest(OrgIDMixin, UserIDMixin):
pass pass
class OrgPatchRootResponse(CustomBaseModel): class OrgPatchRootResponse(CustomBaseModel):
name: str name: str
root_user_email: str root_user_email: str
class OrgGetUserResponse(CustomBaseModel): class OrgGetUserResponse(CustomBaseModel):
users: list[str] users: list[str]
organisation: OrgSchema organisation: OrgSchema
class OrgGetGroupResponse(CustomBaseModel): class OrgGetGroupResponse(CustomBaseModel):
groups: list[str] groups: list[str]
class OrgGetContactResponse(CustomBaseModel): class OrgGetContactResponse(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore") model_config = ConfigDict(from_attributes=True, extra="ignore")
contact: ContactModel contact: ContactModel
organisation: OrgSchema organisation: OrgSchema
class OrgPatchContactResponse(CustomBaseModel): class OrgPatchContactResponse(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore") model_config = ConfigDict(from_attributes=True, extra="ignore")
contact: ContactModel contact: ContactModel
organisation: OrgSchema organisation: OrgSchema
class OrgGetOrgResponse(CustomBaseModel): class OrgGetOrgResponse(CustomBaseModel):
id: int id: int
name: str name: str
@ -129,3 +114,6 @@ class OrgGetOrgResponse(CustomBaseModel):
billing_contact: Optional[str] = None billing_contact: Optional[str] = None
security_contact: Optional[str] = None security_contact: Optional[str] = None
intake_questionnaire: Optional[Questionnaire] = None intake_questionnaire: Optional[Questionnaire] = None
class OrgDeleteOrgRequest(OrgIDMixin):
pass

View file

@ -5,7 +5,6 @@ Exports:
- CustomBaseModel: Schema used for all other Pydantic models - CustomBaseModel: Schema used for all other Pydantic models
- ResourceName - ResourceName
""" """
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional from typing import Optional

View file

@ -5,7 +5,6 @@ Exports:
- service_model_query_dependency: service_model: Gets service model from db, if it exists. Uses service_id from query param. - service_model_query_dependency: service_model: Gets service model from db, if it exists. Uses service_id from query param.
- service_model_body_dependency: service_model: Gets service model from db, if it exists. Uses service_id from request body. - service_model_body_dependency: service_model: Gets service model from db, if it exists. Uses service_id from request body.
""" """
from typing import Annotated from typing import Annotated
from fastapi import Depends, Query from fastapi import Depends, Query
@ -16,19 +15,14 @@ from src.service.models import Service
from src.service.schemas import ServiceIDMixin from src.service.schemas import ServiceIDMixin
async def get_service_model_query( async def get_service_model_query(db: db_dependency, service_id: Annotated[int, Query(gt=0)]):
db: db_dependency, service_id: Annotated[int, Query(gt=0)]
):
service_model = db.get(Service, service_id) service_model = db.get(Service, service_id)
if service_model is None: if service_model is None:
raise ServiceNotFoundException(service_id=service_id) raise ServiceNotFoundException(service_id=service_id)
return service_model return service_model
service_model_query_dependency = Annotated[type[Service], Depends(get_service_model_query)]
service_model_query_dependency = Annotated[
type[Service], Depends(get_service_model_query)
]
async def get_service_model_body(db: db_dependency, request_model: ServiceIDMixin): async def get_service_model_body(db: db_dependency, request_model: ServiceIDMixin):
@ -38,7 +32,4 @@ async def get_service_model_body(db: db_dependency, request_model: ServiceIDMixi
return service_model return service_model
service_model_body_dependency = Annotated[type[Service], Depends(get_service_model_body)]
service_model_body_dependency = Annotated[
type[Service], Depends(get_service_model_body)
]

View file

@ -4,7 +4,6 @@ Exceptions related to the services module
Exceptions: Exceptions:
- ServiceNotFoundException: Takes an optional service_id int - ServiceNotFoundException: Takes an optional service_id int
""" """
from typing import Optional from typing import Optional
from fastapi import HTTPException, status from fastapi import HTTPException, status
@ -12,11 +11,7 @@ from fastapi import HTTPException, status
class ServiceNotFoundException(HTTPException): class ServiceNotFoundException(HTTPException):
def __init__(self, service_id: Optional[int] = None) -> None: def __init__(self, service_id: Optional[int] = None) -> None:
detail = ( detail = "Service not found" if service_id is None else f"Service with ID '{service_id}' was not found."
"Service not found"
if service_id is None
else f"Service with ID '{service_id}' was not found."
)
super().__init__( super().__init__(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=detail, detail=detail,

View file

@ -5,7 +5,6 @@ Models:
- Service: - Service:
- id[PK], name[U], api_key[U] - id[PK], name[U], api_key[U]
""" """
from sqlalchemy import Column, Integer, String from sqlalchemy import Column, Integer, String
from src.database import Base from src.database import Base

View file

@ -7,31 +7,19 @@ Endpoints:
- [PATCH](/key): [super_admin]: Refreshes the API key for a service(id), returning a new one. - [PATCH](/key): [super_admin]: Refreshes the API key for a service(id), returning a new one.
- [DELETE](/): [super_admin]: Removes a service(id) from the hub. - [DELETE](/): [super_admin]: Removes a service(id) from the hub.
""" """
from fastapi import APIRouter, status from fastapi import APIRouter, status
from psycopg.errors import UniqueViolation
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from src.exceptions import ConflictException from src.exceptions import ConflictException
from src.database import db_dependency from src.database import db_dependency
from src.auth.dependencies import ( from src.auth.dependencies import super_admin_dependency, org_model_root_claim_query_dependency
super_admin_dependency,
org_model_root_claim_query_dependency,
)
from src.service.models import Service from src.service.models import Service
from src.service.utils import generate_api_key from src.service.utils import generate_api_key
from src.service.dependencies import ( from src.service.dependencies import service_model_body_dependency
service_model_body_dependency, from src.service.schemas import ServiceGetServiceResponse, ServicePostServiceRequest, ServicePostServiceResponse, \
service_model_query_dependency, ServiceWithKeySchema, ServicePatchKeyResponse, ServicePatchKeyRequest, ServiceDeleteServiceRequest
)
from src.service.schemas import (
ServiceGetServiceResponse,
ServicePostServiceRequest,
ServicePostServiceResponse,
ServiceWithKeySchema,
ServicePatchKeyResponse,
ServicePatchKeyRequest,
)
router = APIRouter( router = APIRouter(
tags=["Service"], tags=["Service"],
@ -39,19 +27,15 @@ router = APIRouter(
) )
@router.get( @router.get("/",
"/",
summary="Get all services", summary="Get all services",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=ServiceGetServiceResponse, response_model=ServiceGetServiceResponse,
responses={ responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"}, status.HTTP_200_OK: {"description": "Successful retrieval from database"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
}, })
) async def get_all_services(db: db_dependency, org_model: org_model_root_claim_query_dependency):
async def get_all_services(
db: db_dependency, org_model: org_model_root_claim_query_dependency
):
""" """
Returns the ID and name of all services registered to the hub. Returns the ID and name of all services registered to the hub.
""" """
@ -60,24 +44,16 @@ async def get_all_services(
return {"services": permission_models} return {"services": permission_models}
@router.post( @router.post("/",
"/",
summary="Register a new service.", summary="Register a new service.",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=ServicePostServiceResponse, response_model=ServicePostServiceResponse,
responses={ responses={
status.HTTP_200_OK: {"description": "Successfully registered a new service"}, status.HTTP_200_OK: {"description": "Successfully registered a new service"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
status.HTTP_409_CONFLICT: { status.HTTP_409_CONFLICT: {"description": "Service with this name already exists"},
"description": "Service with this name already exists" })
}, async def register_service(db: db_dependency, su: super_admin_dependency, request_model: ServicePostServiceRequest):
},
)
async def register_service(
db: db_dependency,
su: super_admin_dependency,
request_model: ServicePostServiceRequest,
):
""" """
Registers a new service to the hub, generating and returning an API key for it. Registers a new service to the hub, generating and returning an API key for it.
""" """
@ -88,32 +64,23 @@ async def register_service(
try: try:
db.flush() db.flush()
except IntegrityError as e: except IntegrityError as e:
if ( if isinstance(e.orig, UniqueViolation):
getattr(e.orig, "pgcode", None) == "23505" # Postgres unique violation
or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation
):
raise ConflictException(message="Service with this name already exists") raise ConflictException(message="Service with this name already exists")
response = ServiceWithKeySchema(**service_model.__dict__) response = ServiceWithKeySchema(**service_model.__dict__)
db.commit() db.commit()
return {"service": response} return {"service": response}
@router.patch( @router.patch("/key",
"/key",
summary="Regenerate service API key.", summary="Regenerate service API key.",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=ServicePatchKeyResponse, response_model=ServicePatchKeyResponse,
responses={ responses={
status.HTTP_200_OK: {"description": "Successful update of API key"}, status.HTTP_200_OK: {"description": "Successful update of API key"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
}, })
) async def regenerate_api_key(db: db_dependency, su: super_admin_dependency,
async def regenerate_api_key( service_model: service_model_body_dependency, request_model: ServicePatchKeyRequest):
db: db_dependency,
su: super_admin_dependency,
service_model: service_model_body_dependency,
request_model: ServicePatchKeyRequest,
):
""" """
Generates and returns a new API key for the service to access the hub. Generates and returns a new API key for the service to access the hub.
""" """
@ -126,22 +93,15 @@ async def regenerate_api_key(
return {"service": response} return {"service": response}
@router.delete( @router.delete("/",
"/",
summary="Remove a service.", summary="Remove a service.",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
responses={ responses={
status.HTTP_204_NO_CONTENT: { status.HTTP_204_NO_CONTENT: {"description": "Successfully removed service from db"},
"description": "Successfully removed service from db"
},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
}, })
) async def remove_service(db: db_dependency, service_model: service_model_body_dependency, su: super_admin_dependency,
async def remove_service( request_model: ServiceDeleteServiceRequest):
db: db_dependency,
service_model: service_model_query_dependency,
su: super_admin_dependency,
):
""" """
Removes a service from the hub. Removes a service from the hub.
""" """

View file

@ -6,42 +6,36 @@ Models follow the nomenclature of:
- Mixins: "<Attribute>Mixin" - Mixins: "<Attribute>Mixin"
- Models: "<Module><Method><Resource><Opt:Resource><Direction>" ie "ServiceGetServiceResponse" - Models: "<Module><Method><Resource><Opt:Resource><Direction>" ie "ServiceGetServiceResponse"
""" """
from pydantic import ConfigDict, Field from pydantic import ConfigDict, Field
from src.schemas import CustomBaseModel from src.schemas import CustomBaseModel
class ServiceIDMixin(CustomBaseModel): class ServiceIDMixin(CustomBaseModel):
service_id: int = Field(gt=0) service_id: int = Field(gt=0)
class ServiceSchema(CustomBaseModel): class ServiceSchema(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore") model_config = ConfigDict(from_attributes=True, extra="ignore")
id: int id: int
name: str name: str
class ServiceWithKeySchema(ServiceSchema): class ServiceWithKeySchema(ServiceSchema):
api_key: str api_key: str
class ServiceGetServiceResponse(CustomBaseModel): class ServiceGetServiceResponse(CustomBaseModel):
services: list[ServiceSchema] services: list[ServiceSchema]
class ServicePostServiceRequest(CustomBaseModel): class ServicePostServiceRequest(CustomBaseModel):
name: str name: str
class ServicePostServiceResponse(CustomBaseModel): class ServicePostServiceResponse(CustomBaseModel):
service: ServiceWithKeySchema service: ServiceWithKeySchema
class ServicePatchKeyRequest(ServiceIDMixin): class ServicePatchKeyRequest(ServiceIDMixin):
pass pass
class ServicePatchKeyResponse(CustomBaseModel): class ServicePatchKeyResponse(CustomBaseModel):
service: ServiceWithKeySchema service: ServiceWithKeySchema
class ServiceDeleteServiceRequest(ServiceIDMixin):
pass

View file

@ -4,7 +4,6 @@ Non-business logic reusable functions and classes for the services module
Exports: Exports:
- generate_api_key(): returns a new UUID - generate_api_key(): returns a new UUID
""" """
import uuid import uuid

View file

@ -6,7 +6,6 @@ Exports:
- user_model_query_dependency: user_model: Gets user model from db, if it exists. Uses user_id from query param - user_model_query_dependency: user_model: Gets user model from db, if it exists. Uses user_id from query param
- user_model_body_dependency: user_model: Gets user model from db, if it exists. Uses user_id from request body. - user_model_body_dependency: user_model: Gets user model from db, if it exists. Uses user_id from request body.
""" """
from typing import Annotated from typing import Annotated
from fastapi import Depends, Query from fastapi import Depends, Query
@ -29,7 +28,6 @@ async def get_user_model_claims(claims: claims_dependency, db: db_dependency):
return user_model return user_model
user_model_claims_dependency = Annotated[type[User], Depends(get_user_model_claims)] user_model_claims_dependency = Annotated[type[User], Depends(get_user_model_claims)]
@ -40,7 +38,6 @@ async def get_user_model_query(db: db_dependency, user_id: Annotated[int, Query(
return user_model return user_model
user_model_query_dependency = Annotated[type[User], Depends(get_user_model_query)] user_model_query_dependency = Annotated[type[User], Depends(get_user_model_query)]
@ -51,5 +48,4 @@ async def get_user_model_body(db: db_dependency, request_model: UserIDMixin):
return user_model return user_model
user_model_body_dependency = Annotated[type[User], Depends(get_user_model_body)] user_model_body_dependency = Annotated[type[User], Depends(get_user_model_body)]

View file

@ -4,7 +4,6 @@ Exceptions related to the user module
Exceptions: Exceptions:
- UserNotFoundException: Takes an optional user_id int - UserNotFoundException: Takes an optional user_id int
""" """
from typing import Optional from typing import Optional
from fastapi import HTTPException, status from fastapi import HTTPException, status
@ -12,11 +11,7 @@ from fastapi import HTTPException, status
class UserNotFoundException(HTTPException): class UserNotFoundException(HTTPException):
def __init__(self, user_id: Optional[int] = None) -> None: def __init__(self, user_id: Optional[int] = None) -> None:
detail = ( detail = "User not found" if user_id is None else f"User with ID '{user_id}' was not found."
"User not found"
if user_id is None
else f"User with ID '{user_id}' was not found."
)
super().__init__( super().__init__(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=detail, detail=detail,

View file

@ -9,7 +9,6 @@ Models:
- organisations: Calc property list of organisation_rel.name - organisations: Calc property list of organisation_rel.name
- groups: Calc property dict of {group_rel.org_rel.name: group_rel.name} - groups: Calc property dict of {group_rel.org_rel.name: group_rel.name}
""" """
from collections import defaultdict from collections import defaultdict
from sqlalchemy import Column, Integer, String from sqlalchemy import Column, Integer, String

View file

@ -7,16 +7,11 @@ Endpoints:
- [GET](/user/): [super admin]: Returns user(id) details. - [GET](/user/): [super admin]: Returns user(id) details.
- [DELETE](/user/): [super admin]: Removes a User(id) from the hub database. - [DELETE](/user/): [super admin]: Removes a User(id) from the hub database.
""" """
from fastapi import APIRouter from fastapi import APIRouter
from starlette import status from starlette import status
from src.user.schemas import UserResponse, OIDCClaims from src.user.schemas import UserResponse, OIDCClaims, UserDeleteUserRequest
from src.user.dependencies import ( from src.user.dependencies import user_model_claims_dependency, user_model_query_dependency, user_model_body_dependency
user_model_claims_dependency,
user_model_query_dependency,
user_model_body_dependency,
)
from src.auth.dependencies import super_admin_dependency from src.auth.dependencies import super_admin_dependency
from src.auth.service import claims_dependency from src.auth.service import claims_dependency
@ -28,15 +23,13 @@ router = APIRouter(
) )
@router.get( @router.get("/self/claims",
"/self/claims",
summary="Get current user OIDC claims.", summary="Get current user OIDC claims.",
response_model=OIDCClaims, response_model=OIDCClaims,
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
responses={ responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"}, status.HTTP_200_OK: {"description": "Successful retrieval from database"},
}, })
)
async def current_user_claims(user: claims_dependency): async def current_user_claims(user: claims_dependency):
""" """
Returns the full OIDC claims associated with the currently logged-in user. Returns the full OIDC claims associated with the currently logged-in user.
@ -45,16 +38,14 @@ async def current_user_claims(user: claims_dependency):
return user return user
@router.get( @router.get("/self/db",
"/self/db",
summary="Get current user hub details.", summary="Get current user hub details.",
response_model=UserResponse, response_model=UserResponse,
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
responses={ responses={
status.HTTP_404_NOT_FOUND: {"description": "User not found"}, status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"}, status.HTTP_200_OK: {"description": "Successful retrieval from database"},
}, })
)
async def current_user(user_model: user_model_claims_dependency): async def current_user(user_model: user_model_claims_dependency):
""" """
Returns the database details associated with the currently logged-in user. Returns the database details associated with the currently logged-in user.
@ -62,39 +53,30 @@ async def current_user(user_model: user_model_claims_dependency):
return user_model return user_model
@router.get( @router.get("/",
"/",
summary="Get user hub details by ID.", summary="Get user hub details by ID.",
response_model=UserResponse, response_model=UserResponse,
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
responses={ responses={
status.HTTP_404_NOT_FOUND: {"description": "User not found"}, status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"}, status.HTTP_200_OK: {"description": "Successful retrieval from database"},
}, })
) async def get_user_by_id(user_model: user_model_query_dependency, su: super_admin_dependency):
async def get_user_by_id(
user_model: user_model_query_dependency, su: super_admin_dependency
):
""" """
Returns the database details associated with the provided user ID. Returns the database details associated with the provided user ID.
""" """
return user_model return user_model
@router.delete( @router.delete("/",
"/",
summary="Delete user from hub by ID.", summary="Delete user from hub by ID.",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
responses={ responses={
status.HTTP_204_NO_CONTENT: {"description": "User deleted"}, status.HTTP_204_NO_CONTENT: {"description": "User deleted"},
status.HTTP_404_NOT_FOUND: {"description": "User not found"}, status.HTTP_404_NOT_FOUND: {"description": "User not found"},
}, })
) async def delete_user_by_id(db: db_dependency, user_model: user_model_body_dependency, su: super_admin_dependency,
async def delete_user_by_id( request_model: UserDeleteUserRequest):
db: db_dependency,
user_model: user_model_query_dependency,
su: super_admin_dependency,
):
""" """
Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login. Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login.
""" """

View file

@ -1,7 +1,6 @@
""" """
Pydantic models for the user module Pydantic models for the user module
""" """
from typing import Optional from typing import Optional
from pydantic import Field from pydantic import Field
@ -55,3 +54,7 @@ class UserResponse(CustomBaseModel):
class OrgResponse(CustomBaseModel): class OrgResponse(CustomBaseModel):
org_id: int org_id: int
name: str name: str
class UserDeleteUserRequest(UserIDMixin):
pass

View file

@ -4,7 +4,6 @@ Module specific business logic for user module
Exports: Exports:
- add_user_to_db: Creates a User record from OIDC claims, or updates user details - add_user_to_db: Creates a User record from OIDC claims, or updates user details
""" """
from typing import Any from typing import Any
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -17,12 +16,7 @@ from src.user.models import User
async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int: async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int:
try: try:
valid_user = OIDCUser( valid_user = OIDCUser(first_name=user_claims["given_name"], last_name=user_claims["family_name"], email=user_claims["email"], oidc_id=user_claims["sub"])
first_name=user_claims["given_name"],
last_name=user_claims["family_name"],
email=user_claims["email"],
oidc_id=user_claims["sub"],
)
except Exception as e: except Exception as e:
print(e) print(e)
raise UnprocessableContentException("Invalid or missing OIDC data") raise UnprocessableContentException("Invalid or missing OIDC data")

View file

@ -37,14 +37,11 @@ def db_session():
async def default_client(db_session) -> AsyncGenerator[AsyncClient, None]: async def default_client(db_session) -> AsyncGenerator[AsyncClient, None]:
def get_db_override(): def get_db_override():
return db_session return db_session
app.dependency_overrides[get_db] = get_db_override app.dependency_overrides[get_db] = get_db_override
app.dependency_overrides[get_current_user] = get_dev_user app.dependency_overrides[get_current_user] = get_dev_user
app.dependency_overrides[get_super_admin_list] = testing_su_list app.dependency_overrides[get_super_admin_list] = testing_su_list
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
async with AsyncClient( async with AsyncClient(transport=transport, base_url="http://localhost:8000/api/v1") as ac:
transport=transport, base_url="http://localhost:8000/api/v1"
) as ac:
yield ac yield ac
app.dependency_overrides.clear() app.dependency_overrides.clear()
@ -54,58 +51,37 @@ async def default_client(db_session) -> AsyncGenerator[AsyncClient, None]:
async def no_user_client(db_session) -> AsyncGenerator[AsyncClient, None]: async def no_user_client(db_session) -> AsyncGenerator[AsyncClient, None]:
def get_db_override(): def get_db_override():
return db_session return db_session
app.dependency_overrides[get_db] = get_db_override app.dependency_overrides[get_db] = get_db_override
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
async with AsyncClient( async with AsyncClient(transport=transport, base_url="http://localhost:8000/api/v1") as ac:
transport=transport, base_url="http://localhost:8000/api/v1"
) as ac:
yield ac yield ac
app.dependency_overrides.clear() app.dependency_overrides.clear()
@pytest.fixture @pytest.fixture
async def no_su_client(db_session) -> AsyncGenerator[AsyncClient, None]: async def no_su_client(db_session) -> AsyncGenerator[AsyncClient, None]:
def get_db_override(): def get_db_override():
return db_session return db_session
app.dependency_overrides[get_db] = get_db_override app.dependency_overrides[get_db] = get_db_override
app.dependency_overrides[get_current_user] = get_dev_user app.dependency_overrides[get_current_user] = get_dev_user
app.dependency_overrides[get_super_admin_list] = empty_su_list app.dependency_overrides[get_super_admin_list] = empty_su_list
transport = ASGITransport(app=app) transport = ASGITransport(app=app)
async with AsyncClient( async with AsyncClient(transport=transport, base_url="http://localhost:8000/api/v1") as ac:
transport=transport, base_url="http://localhost:8000/api/v1"
) as ac:
yield ac yield ac
app.dependency_overrides.clear() app.dependency_overrides.clear()
def _seed(db): def _seed(db):
db.add( db.add(User(email="admin@test.com", first_name="Admin", last_name="Test", oidc_id="abcd-efgh-ijkl-mnop"))
User(
email="admin@test.com",
first_name="Admin",
last_name="Test",
oidc_id="abcd-efgh-ijkl-mnop",
)
)
db.add(Contact(org_id=1, email="billing@test.org", phonenumber="07521539927")) db.add(Contact(org_id=1, email="billing@test.org", phonenumber="07521539927"))
db.add(Contact(org_id=1, email="owner@test.org", phonenumber="07521539927")) db.add(Contact(org_id=1, email="owner@test.org", phonenumber="07521539927"))
db.add(Contact(org_id=1, email="security@test.org", phonenumber="07521539927")) db.add(Contact(org_id=1, email="security@test.org", phonenumber="07521539927"))
db.flush() db.flush()
db.add( db.add(Org(name="Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=2, security_contact_id=3,
Org( status="approved", intake_questionnaire={"question_two": "answer two"}))
name="Test Org",
root_user_id=1,
billing_contact_id=1,
owner_contact_id=2,
security_contact_id=3,
status="approved",
intake_questionnaire={"question_two": "answer two"},
)
)
db.add(Service(name="Test Service", api_key="123456789")) db.add(Service(name="Test Service", api_key="123456789"))
db.add(Permission(service_id=1, resource="test_resource", action="read")) db.add(Permission(service_id=1, resource="test_resource", action="read"))
db.add(Group(name="Test Group", org_id=1)) db.add(Group(name="Test Group", org_id=1))
@ -155,7 +131,6 @@ def generate_query_and_status(params) -> list[tuple[str, int]]:
return query_and_status return query_and_status
# # Produces a text file with method and path for every endpoint in the API # # Produces a text file with method and path for every endpoint in the API
# from fastapi.routing import APIRoute # from fastapi.routing import APIRoute
# #

View file

@ -3,7 +3,6 @@ This test module checks relevant endpoints to ensure only approved orgs get acce
Endpoints not checked here are endpoints that do not require an org check. Endpoints not checked here are endpoints that do not require an org check.
Delete endpoints are currently skipped because the testing system cannot use bodies in deletes. Delete endpoints are currently skipped because the testing system cannot use bodies in deletes.
""" """
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
@ -28,27 +27,18 @@ async def test_get_org_auth_approval(default_client: AsyncClient):
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_questionnaire_auth_approval(default_client: AsyncClient): async def test_patch_org_questionnaire_auth_approval(default_client: AsyncClient):
resp = await default_client.patch( resp = await default_client.patch("/org/questionnaire", json={"organisation_id": 1,
"/org/questionnaire", "intake_questionnaire": {"question_one": "new answer one",
json={
"organisation_id": 1,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None, "question_two": None,
"question_three": None, "question_three": None},
}, "partial": True})
"partial": True,
},
)
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 200 assert resp.status_code == 200
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_status_auth_approval(default_client: AsyncClient): async def test_patch_org_status_auth_approval(default_client: AsyncClient):
resp = await default_client.patch( resp = await default_client.patch("/org/status", json={"organisation_id": 1, "status": "submitted"})
"/org/status", json={"organisation_id": 1, "status": "submitted"}
)
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 200 assert resp.status_code == 200
@ -62,42 +52,22 @@ async def test_get_org_users_auth_approval(default_client: AsyncClient):
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_user_auth_approval(default_client: AsyncClient, db_session): async def test_post_org_user_auth_approval(default_client: AsyncClient, db_session):
db_session.add( db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush() db_session.flush()
resp = await default_client.post( resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 2})
"/org/user", json={"organisation_id": 1, "user_id": 2}
)
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_root_user_auth_approval( async def test_patch_org_root_user_auth_approval(default_client: AsyncClient, db_session):
default_client: AsyncClient, db_session db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
):
db_session.add(
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush() db_session.flush()
db_session.add(OrgUsers(org_id=1, user_id=2)) db_session.add(OrgUsers(org_id=1, user_id=2))
db_session.flush() db_session.flush()
resp = await default_client.patch( resp = await default_client.patch("/org/root_user", json={"organisation_id": 1, "user_id": 2})
"/org/root_user", json={"organisation_id": 1, "user_id": 2}
)
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@ -118,14 +88,8 @@ async def test_get_org_contact_auth_approval(default_client: AsyncClient):
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_contact_auth_approval(default_client: AsyncClient): async def test_patch_org_contact_auth_approval(default_client: AsyncClient):
resp = await default_client.patch( resp = await default_client.patch("/org/contact",
"/org/contact", json={"organisation_id": 1, "contact_type": "billing", "email": "user@example.com"})
json={
"organisation_id": 1,
"contact_type": "billing",
"email": "user@example.com",
},
)
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 200 assert resp.status_code == 200
@ -153,44 +117,26 @@ async def test_get_iam_group_users_auth_approval(default_client: AsyncClient):
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_iam_group_auth_approval(default_client: AsyncClient): async def test_post_iam_group_auth_approval(default_client: AsyncClient):
resp = await default_client.post( resp = await default_client.post("/iam/group", json={"name": "New Group", "organisation_id": 1})
"/iam/group", json={"name": "New Group", "organisation_id": 1}
)
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_iam_group_permission_auth_approval( async def test_put_iam_group_permission_auth_approval(default_client: AsyncClient, db_session):
default_client: AsyncClient, db_session
):
db_session.add(Group(name="Test Group Two", org_id=1)) db_session.add(Group(name="Test Group Two", org_id=1))
db_session.flush() db_session.flush()
resp = await default_client.put( resp = await default_client.put("/iam/group/permission", json={"permission_id": 1, "group_id": 2, "organisation_id": 1})
"/iam/group/permission",
json={"permission_id": 1, "group_id": 2, "organisation_id": 1},
)
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_iam_group_user_auth_approval( async def test_put_iam_group_user_auth_approval(default_client: AsyncClient, db_session):
default_client: AsyncClient, db_session db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
):
db_session.add(
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush() db_session.flush()
resp = await default_client.put( resp = await default_client.put("/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1})
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1}
)
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@ -204,8 +150,6 @@ async def test_get_iam_permissions_auth_approval(default_client: AsyncClient):
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_iam_permissions_search_auth_approval(default_client: AsyncClient): async def test_post_iam_permissions_search_auth_approval(default_client: AsyncClient):
resp = await default_client.post( resp = await default_client.post("/iam/permissions/search", json={"organisation_id": 1, "action": "read"})
"/iam/permissions/search", json={"organisation_id": 1, "action": "read"}
)
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]

View file

@ -1,5 +1,5 @@
""" """ """
"""
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
@ -10,26 +10,11 @@ from src.user.models import User
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_auth_root_su(default_client: AsyncClient, db_session): async def test_get_org_auth_root_su(default_client: AsyncClient, db_session):
# If a super admin can access a resource when not the root user # If a super admin can access a resource when not the root user
db_session.add( db_session.add(User(email="admin@test.org", first_name="Admin", last_name="Test", oidc_id="abcd-efgh-ijkl-4321"))
User(
email="admin@test.org",
first_name="Admin",
last_name="Test",
oidc_id="abcd-efgh-ijkl-4321",
)
)
db_session.flush() db_session.flush()
db_session.add( db_session.add(
Org( Org(name="Test Org Two", root_user_id=2, billing_contact_id=1, owner_contact_id=2, security_contact_id=3,
name="Test Org Two", status="approved", intake_questionnaire={}))
root_user_id=2,
billing_contact_id=1,
owner_contact_id=2,
security_contact_id=3,
status="approved",
intake_questionnaire={},
)
)
db_session.flush() db_session.flush()
resp = await default_client.get("/org?org_id=2") resp = await default_client.get("/org?org_id=2")

View file

@ -2,7 +2,6 @@
This module ensures root user only endpoints do return a correctly formatted 401 when user is not the root user for the org This module ensures root user only endpoints do return a correctly formatted 401 when user is not the root user for the org
DELETE endpoints are not tested DELETE endpoints are not tested
""" """
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
@ -13,26 +12,10 @@ from src.iam.models import Group
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def add_second_org(db_session): def add_second_org(db_session):
db_session.add( db_session.add(User(email="admin@test.org", first_name="Admin", last_name="Test", oidc_id="abcd-efgh-ijkl-4321"))
User(
email="admin@test.org",
first_name="Admin",
last_name="Test",
oidc_id="abcd-efgh-ijkl-4321",
)
)
db_session.flush() db_session.flush()
db_session.add( db_session.add(Org(name="Test Org Two", root_user_id=2, billing_contact_id=1, owner_contact_id=2, security_contact_id=3,
Org( status="approved", intake_questionnaire={}))
name="Test Org Two",
root_user_id=2,
billing_contact_id=1,
owner_contact_id=2,
security_contact_id=3,
status="approved",
intake_questionnaire={},
)
)
db_session.flush() db_session.flush()
@ -46,18 +29,11 @@ async def test_get_org_auth_root(no_su_client: AsyncClient):
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_questionnaire_auth_root(no_su_client: AsyncClient): async def test_patch_org_questionnaire_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.patch( resp = await no_su_client.patch("/org/questionnaire", json={"organisation_id": 2,
"/org/questionnaire", "intake_questionnaire": {"question_one": "new answer one",
json={
"organisation_id": 2,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None, "question_two": None,
"question_three": None, "question_three": None},
}, "partial": True})
"partial": True,
},
)
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 401 assert resp.status_code == 401
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@ -73,19 +49,10 @@ async def test_get_org_users_auth_root(no_su_client: AsyncClient):
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_user_auth_root(no_su_client: AsyncClient, db_session): async def test_post_org_user_auth_root(no_su_client: AsyncClient, db_session):
db_session.add( db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush() db_session.flush()
resp = await no_su_client.post( resp = await no_su_client.post("/org/user", json={"organisation_id": 2, "user_id": 2})
"/org/user", json={"organisation_id": 2, "user_id": 2}
)
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 401 assert resp.status_code == 401
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@ -109,14 +76,8 @@ async def test_get_org_contact_auth_root(no_su_client: AsyncClient):
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_contact_auth_root(no_su_client: AsyncClient): async def test_patch_org_contact_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.patch( resp = await no_su_client.patch("/org/contact",
"/org/contact", json={"organisation_id": 2, "contact_type": "billing", "email": "user@example.com"})
json={
"organisation_id": 2,
"contact_type": "billing",
"email": "user@example.com",
},
)
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 401 assert resp.status_code == 401
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@ -148,24 +109,17 @@ async def test_get_iam_group_users_auth_root(no_su_client: AsyncClient):
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_iam_group_auth_root(no_su_client: AsyncClient): async def test_post_iam_group_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.post( resp = await no_su_client.post("/iam/group", json={"name": "New Group", "organisation_id": 2})
"/iam/group", json={"name": "New Group", "organisation_id": 2}
)
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 401 assert resp.status_code == 401
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_iam_group_permission_auth_root( async def test_put_iam_group_permission_auth_root(no_su_client: AsyncClient, db_session):
no_su_client: AsyncClient, db_session
):
db_session.add(Group(name="Test Group Two", org_id=2)) db_session.add(Group(name="Test Group Two", org_id=2))
db_session.flush() db_session.flush()
resp = await no_su_client.put( resp = await no_su_client.put("/iam/group/permission", json={"permission_id": 1, "group_id": 2, "organisation_id": 2})
"/iam/group/permission",
json={"permission_id": 1, "group_id": 2, "organisation_id": 2},
)
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 401 assert resp.status_code == 401
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@ -173,19 +127,10 @@ async def test_put_iam_group_permission_auth_root(
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_iam_group_user_auth_root(no_su_client: AsyncClient, db_session): async def test_put_iam_group_user_auth_root(no_su_client: AsyncClient, db_session):
db_session.add( db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush() db_session.flush()
resp = await no_su_client.put( resp = await no_su_client.put("/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2})
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2}
)
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 401 assert resp.status_code == 401
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@ -201,9 +146,7 @@ async def test_get_iam_permissions_auth_root(no_su_client: AsyncClient):
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_iam_permissions_search_auth_root(no_su_client: AsyncClient): async def test_post_iam_permissions_search_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.post( resp = await no_su_client.post("/iam/permissions/search", json={"organisation_id": 2, "action": "read"})
"/iam/permissions/search", json={"organisation_id": 2, "action": "read"}
)
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 401 assert resp.status_code == 401
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]

View file

@ -2,7 +2,6 @@
This module ensures super admin only endpoints do return a correctly formatted 401 when user is not a super admin This module ensures super admin only endpoints do return a correctly formatted 401 when user is not a super admin
DELETE endpoints are not tested DELETE endpoints are not tested
""" """
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
@ -20,9 +19,7 @@ async def test_get_user_auth_su(no_su_client: AsyncClient):
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_status_auth_su(no_su_client: AsyncClient): async def test_patch_org_status_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.patch( resp = await no_su_client.patch("/org/status", json={"organisation_id": 1, "status": "submitted"})
"/org/status", json={"organisation_id": 1, "status": "submitted"}
)
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 401 assert resp.status_code == 401
assert resp.json()["detail"] == "Must be super admin" assert resp.json()["detail"] == "Must be super admin"
@ -30,21 +27,12 @@ async def test_patch_org_status_auth_su(no_su_client: AsyncClient):
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_root_user_auth_su(no_su_client: AsyncClient, db_session): async def test_patch_org_root_user_auth_su(no_su_client: AsyncClient, db_session):
db_session.add( db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush() db_session.flush()
db_session.add(OrgUsers(org_id=1, user_id=2)) db_session.add(OrgUsers(org_id=1, user_id=2))
db_session.flush() db_session.flush()
resp = await no_su_client.patch( resp = await no_su_client.patch("/org/root_user", json={"organisation_id": 1, "user_id": 2})
"/org/root_user", json={"organisation_id": 1, "user_id": 2}
)
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 401 assert resp.status_code == 401
assert resp.json()["detail"] == "Must be super admin" assert resp.json()["detail"] == "Must be super admin"
@ -68,10 +56,7 @@ async def test_post_service_auth_su(no_su_client: AsyncClient):
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_perm_success(no_su_client: AsyncClient, db_session): async def test_post_perm_success(no_su_client: AsyncClient, db_session):
resp = await no_su_client.post( resp = await no_su_client.post("/iam/permission", json={"service_id": 1, "resource": "test_resource", "action": "create"})
"/iam/permission",
json={"service_id": 1, "resource": "test_resource", "action": "create"},
)
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 401 assert resp.status_code == 401
assert resp.json()["detail"] == "Must be super admin" assert resp.json()["detail"] == "Must be super admin"

View file

@ -1,7 +1,6 @@
""" """
This testing module removes the testing user override to verify that endpoints with only the user requirement return a 401 error when not logged in This testing module removes the testing user override to verify that endpoints with only the user requirement return a 401 error when not logged in
""" """
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient

View file

@ -1,5 +1,5 @@
""" """ """
"""
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
@ -15,15 +15,13 @@ async def test_post_act_on_resource_endpoint_success(default_client: AsyncClient
body = { body = {
"service": "Test Service", "service": "Test Service",
"organisation": "Test Org", "organisation": "Test Org",
"resource": "test_resource", "resource": "test_resource"
} }
headers = { headers = {
"Authorization": "Bearer not_checked_when_auth_is_disabled", "Authorization": "Bearer not_checked_when_auth_is_disabled",
"X-API-Key": "123456789", "X-API-Key": "123456789"
} }
resp = await default_client.post( resp = await default_client.post("/iam/can_act_on_resource?action=read", json=body, headers=headers)
"/iam/can_act_on_resource?action=read", json=body, headers=headers
)
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
@ -32,20 +30,23 @@ async def test_post_act_on_resource_endpoint_success(default_client: AsyncClient
@pytest.mark.parametrize( @pytest.mark.parametrize(
"service, api_key", "service, api_key",
[("Test Service", "not_the_correct_key"), ("Test Service Two", "123456789")], [
("Test Service", "not_the_correct_key"),
("Test Service Two", "123456789")
],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_act_on_resource_wrong_key( async def test_act_on_resource_wrong_key(default_client: AsyncClient, db_session, service: str, api_key: str):
default_client: AsyncClient, db_session, service: str, api_key: str body = {
): "service": service,
body = {"service": service, "organisation": "Test Org", "resource": "test_resource"} "organisation": "Test Org",
"resource": "test_resource"
}
headers = { headers = {
"Authorization": "Bearer not_checked_when_auth_is_disabled", "Authorization": "Bearer not_checked_when_auth_is_disabled",
"X-API-Key": api_key, "X-API-Key": api_key
} }
resp = await default_client.post( resp = await default_client.post("/iam/can_act_on_resource?action=read", json=body, headers=headers)
"/iam/can_act_on_resource?action=read", json=body, headers=headers
)
data = resp.json() data = resp.json()
assert resp.status_code == 401 assert resp.status_code == 401
@ -57,12 +58,12 @@ async def test_act_on_resource_missing_key(default_client: AsyncClient):
body = { body = {
"service": "Test Service", "service": "Test Service",
"organisation": "Test Org", "organisation": "Test Org",
"resource": "test_resource", "resource": "test_resource"
} }
headers = {"Authorization": "Bearer not_checked_when_auth_is_disabled"} headers = {
resp = await default_client.post( "Authorization": "Bearer not_checked_when_auth_is_disabled"
"/iam/can_act_on_resource?action=read", json=body, headers=headers }
) resp = await default_client.post("/iam/can_act_on_resource?action=read", json=body, headers=headers)
data = resp.json() data = resp.json()
assert resp.status_code == 401 assert resp.status_code == 401
@ -81,17 +82,18 @@ async def test_act_on_resource_missing_key(default_client: AsyncClient):
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_act_on_resource_endpoint_status_checks( async def test_act_on_resource_endpoint_status_checks(default_client: AsyncClient, service, org, resource, action,
default_client: AsyncClient, service, org, resource, action, expected_status: int expected_status: int):
): body = {
body = {"service": service, "organisation": org, "resource": resource} "service": service,
"organisation": org,
"resource": resource
}
headers = { headers = {
"Authorization": "Bearer not_checked_when_auth_is_disabled", "Authorization": "Bearer not_checked_when_auth_is_disabled",
"X-API-Key": "123456789", "X-API-Key": "123456789"
} }
resp = await default_client.post( resp = await default_client.post(f"/iam/can_act_on_resource?action={action}", json=body, headers=headers)
f"/iam/can_act_on_resource?action={action}", json=body, headers=headers
)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -106,23 +108,18 @@ async def test_act_on_resource_endpoint_status_checks(
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_act_on_resource_logic( async def test_act_on_resource_logic(default_client: AsyncClient, db_session, service, org, resource, action,
default_client: AsyncClient, expected_response: bool):
db_session, body = {
service, "service": service,
org, "organisation": org,
resource, "resource": resource
action, }
expected_response: bool,
):
body = {"service": service, "organisation": org, "resource": resource}
headers = { headers = {
"Authorization": "Bearer not_checked_when_auth_is_disabled", "Authorization": "Bearer not_checked_when_auth_is_disabled",
"X-API-Key": "123456789", "X-API-Key": "123456789"
} }
resp = await default_client.post( resp = await default_client.post(f"/iam/can_act_on_resource?action={action}", json=body, headers=headers)
f"/iam/can_act_on_resource?action={action}", json=body, headers=headers
)
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
@ -143,12 +140,11 @@ async def test_get_group_permissions_success(default_client: AsyncClient):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"query, expected_status", generate_query_and_status(["group_id", "org_id"]) "query, expected_status",
generate_query_and_status(["group_id", "org_id"])
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_group_permissions_status_checks( async def test_get_group_permissions_status_checks(default_client: AsyncClient, db_session, query: str, expected_status: int):
default_client: AsyncClient, db_session, query: str, expected_status: int
):
resp = await default_client.get(f"/iam/group/permissions?{query}") resp = await default_client.get(f"/iam/group/permissions?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -162,19 +158,8 @@ async def test_get_group_permissions_status_checks(
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_group_permissions_mismatch( async def test_get_group_permissions_mismatch(default_client: AsyncClient, db_session, query: str):
default_client: AsyncClient, db_session, query: str db_session.add(Org(name="Another Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=2, security_contact_id=3, status="approved"))
):
db_session.add(
Org(
name="Another Test Org",
root_user_id=1,
billing_contact_id=1,
owner_contact_id=2,
security_contact_id=3,
status="approved",
)
)
db_session.add(Group(name="Another Test Group", org_id=2)) db_session.add(Group(name="Another Test Group", org_id=2))
db_session.flush() db_session.flush()
resp = await default_client.get(f"/iam/group/permissions?{query}") resp = await default_client.get(f"/iam/group/permissions?{query}")
@ -198,12 +183,11 @@ async def test_get_group_users_success(default_client: AsyncClient):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"query, expected_status", generate_query_and_status(["group_id", "org_id"]) "query, expected_status",
generate_query_and_status(["group_id", "org_id"])
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_group_users_status_checks( async def test_get_group_users_status_checks(default_client: AsyncClient, query: str, expected_status: int):
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/iam/group/users?{query}") resp = await default_client.get(f"/iam/group/users?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -217,19 +201,8 @@ async def test_get_group_users_status_checks(
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_group_users_mismatch( async def test_get_group_users_mismatch(default_client: AsyncClient, db_session, query: str):
default_client: AsyncClient, db_session, query: str db_session.add(Org(name="Another Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=2, security_contact_id=3, status="approved"))
):
db_session.add(
Org(
name="Another Test Org",
root_user_id=1,
billing_contact_id=1,
owner_contact_id=2,
security_contact_id=3,
status="approved",
)
)
db_session.add(Group(name="Another Test Group", org_id=2)) db_session.add(Group(name="Another Test Group", org_id=2))
db_session.flush() db_session.flush()
resp = await default_client.get(f"/iam/group/users?{query}") resp = await default_client.get(f"/iam/group/users?{query}")
@ -240,9 +213,7 @@ async def test_get_group_users_mismatch(
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_group_success(default_client: AsyncClient): async def test_post_group_success(default_client: AsyncClient):
resp = await default_client.post( resp = await default_client.post("/iam/group", json={"name": "New Group", "organisation_id": 1})
"/iam/group", json={"name": "New Group", "organisation_id": 1}
)
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
@ -255,23 +226,10 @@ async def test_post_group_success(default_client: AsyncClient):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
({"organisation_id": 1, "name": "Test Group"}, 409), ({"organisation_id": 2, "name": "new group"}, 404), # Non-existent organisation, valid name
( ({"organisation_id": "banana", "name": "new group"}, 422), # Invalid organisation ID, valid name
{"organisation_id": 2, "name": "new group"}, ({"organisation_id": "", "name": "new group"}, 422), # Blank organisation ID, valid name
404, ({"organisation_id": -1, "name": "new group"}, 422), # Negative organisation ID, valid name
), # Non-existent organisation, valid name
(
{"organisation_id": "banana", "name": "new group"},
422,
), # Invalid organisation ID, valid name
(
{"organisation_id": "", "name": "new group"},
422,
), # Blank organisation ID, valid name
(
{"organisation_id": -1, "name": "new group"},
422,
), # Negative organisation ID, valid name
({"name": 1}, 422), # Only name ({"name": 1}, 422), # Only name
({}, 422), # Blank body ({}, 422), # Blank body
({"organisation_id": "", "name": ""}, 422), # Both blank ({"organisation_id": "", "name": ""}, 422), # Both blank
@ -282,9 +240,7 @@ async def test_post_group_success(default_client: AsyncClient):
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_group_status_checks( async def test_post_group_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.post("/iam/group", json=body) resp = await default_client.post("/iam/group", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -294,10 +250,7 @@ async def test_post_group_status_checks(
async def test_put_group_perm_success(default_client: AsyncClient, db_session): async def test_put_group_perm_success(default_client: AsyncClient, db_session):
db_session.add(Group(name="Test Group Two", org_id=1)) db_session.add(Group(name="Test Group Two", org_id=1))
db_session.flush() db_session.flush()
resp = await default_client.put( resp = await default_client.put("/iam/group/permission", json={"permission_id": 1, "group_id": 2, "organisation_id": 1})
"/iam/group/permission",
json={"permission_id": 1, "group_id": 2, "organisation_id": 1},
)
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
@ -316,71 +269,36 @@ async def test_put_group_perm_success(default_client: AsyncClient, db_session):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
( ({"organisation_id": 42, "group_id": 1, "permission_id": 1}, 404), # Non-existent organisation
{"organisation_id": 42, "group_id": 1, "permission_id": 1}, ({"organisation_id": "banana", "group_id": 1, "permission_id": 1}, 422), # Invalid organisation ID
404, ({"organisation_id": "", "group_id": 1, "permission_id": 1}, 422), # Blank organisation ID
), # Non-existent organisation ({"organisation_id": -1, "group_id": 1, "permission_id": 1}, 422), # Negative organisation ID
(
{"organisation_id": "banana", "group_id": 1, "permission_id": 1}, ({"organisation_id": 1, "group_id": 42, "permission_id": 1}, 404), # Non-existent group
422, ({"organisation_id": 1, "group_id": "banana", "permission_id": 1}, 422), # Invalid group ID
), # Invalid organisation ID ({"organisation_id": 1, "group_id": "", "permission_id": 1}, 422), # Blank group ID
( ({"organisation_id": 1, "group_id": -1, "permission_id": 1}, 422), # Negative group ID
{"organisation_id": "", "group_id": 1, "permission_id": 1},
422, ({"organisation_id": 1, "group_id": 1, "permission_id": 42}, 404), # Non-existent permission
), # Blank organisation ID ({"organisation_id": 1, "group_id": 1, "permission_id": "banana"}, 422), # Invalid permission ID
( ({"organisation_id": 1, "group_id": 1, "permission_id": ""}, 422), # Blank permission ID
{"organisation_id": -1, "group_id": 1, "permission_id": 1}, ({"organisation_id": 1, "group_id": 1, "permission_id": -1}, 422), # Negative permission ID
422,
), # Negative organisation ID
(
{"organisation_id": 1, "group_id": 42, "permission_id": 1},
404,
), # Non-existent group
(
{"organisation_id": 1, "group_id": "banana", "permission_id": 1},
422,
), # Invalid group ID
(
{"organisation_id": 1, "group_id": "", "permission_id": 1},
422,
), # Blank group ID
(
{"organisation_id": 1, "group_id": -1, "permission_id": 1},
422,
), # Negative group ID
(
{"organisation_id": 1, "group_id": 1, "permission_id": 42},
404,
), # Non-existent permission
(
{"organisation_id": 1, "group_id": 1, "permission_id": "banana"},
422,
), # Invalid permission ID
(
{"organisation_id": 1, "group_id": 1, "permission_id": ""},
422,
), # Blank permission ID
(
{"organisation_id": 1, "group_id": 1, "permission_id": -1},
422,
), # Negative permission ID
({}, 422), # Blank body ({}, 422), # Blank body
({"permission_id": 1}, 422), # Only permission ({"permission_id": 1}, 422), # Only permission
({"organisation_id": 1}, 422), # Only organisation ({"organisation_id": 1}, 422), # Only organisation
({"group_id": 1}, 422), # Only group ({"group_id": 1}, 422), # Only group
({"organisation_id": 1, "permission_id": 1}, 422), # Missing group ({"organisation_id": 1, "permission_id": 1}, 422), # Missing group
({"group_id": 1, "permission_id": 1}, 422), # Missing organisation ({"group_id": 1, "permission_id": 1}, 422), # Missing organisation
({"organisation_id": 1, "group_id": 1}, 422), # Missing permission ({"organisation_id": 1, "group_id": 1}, 422), # Missing permission
(
{"organisation_id": 1, "group_id": 1, "permission_id": 1}, ({"organisation_id": 1, "group_id": 1, "permission_id": 1}, 409), # Permission already in group
409,
), # Permission already in group
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_group_perm_status_checks( async def test_put_group_perm_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.put("/iam/group/permission", json=body) resp = await default_client.put("/iam/group/permission", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -394,19 +312,8 @@ async def test_put_group_perm_status_checks(
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_group_perm_mismatch( async def test_put_group_perm_mismatch(default_client: AsyncClient, db_session, body: dict):
default_client: AsyncClient, db_session, body: dict db_session.add(Org(name="Another Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=2, security_contact_id=3, status="approved"))
):
db_session.add(
Org(
name="Another Test Org",
root_user_id=1,
billing_contact_id=1,
owner_contact_id=2,
security_contact_id=3,
status="approved",
)
)
db_session.add(Group(name="Another Test Group", org_id=2)) db_session.add(Group(name="Another Test Group", org_id=2))
db_session.flush() db_session.flush()
resp = await default_client.put("/iam/group/permission", json=body) resp = await default_client.put("/iam/group/permission", json=body)
@ -417,19 +324,10 @@ async def test_put_group_perm_mismatch(
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_group_user_success(default_client: AsyncClient, db_session): async def test_put_group_user_success(default_client: AsyncClient, db_session):
db_session.add( db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush() db_session.flush()
resp = await default_client.put( resp = await default_client.put("/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1})
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1}
)
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
@ -449,58 +347,34 @@ async def test_put_group_user_success(default_client: AsyncClient, db_session):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
( ({"organisation_id": 42, "group_id": 1, "user_id": 1}, 404), # Non-existent organisation
{"organisation_id": 42, "group_id": 1, "user_id": 1}, ({"organisation_id": "banana", "group_id": 1, "user_id": 1}, 422), # Invalid organisation ID
404, ({"organisation_id": "", "group_id": 1, "user_id": 1}, 422), # Blank organisation ID
), # Non-existent organisation ({"organisation_id": -1, "group_id": 1, "user_id": 1}, 422), # Negative organisation ID
(
{"organisation_id": "banana", "group_id": 1, "user_id": 1}, ({"organisation_id": 1, "group_id": 42, "user_id": 1}, 404), # Non-existent group
422, ({"organisation_id": 1, "group_id": "banana", "user_id": 1}, 422), # Invalid group ID
), # Invalid organisation ID
(
{"organisation_id": "", "group_id": 1, "user_id": 1},
422,
), # Blank organisation ID
(
{"organisation_id": -1, "group_id": 1, "user_id": 1},
422,
), # Negative organisation ID
(
{"organisation_id": 1, "group_id": 42, "user_id": 1},
404,
), # Non-existent group
(
{"organisation_id": 1, "group_id": "banana", "user_id": 1},
422,
), # Invalid group ID
({"organisation_id": 1, "group_id": "", "user_id": 1}, 422), # Blank group ID ({"organisation_id": 1, "group_id": "", "user_id": 1}, 422), # Blank group ID
( ({"organisation_id": 1, "group_id": -1, "user_id": 1}, 422), # Negative group ID
{"organisation_id": 1, "group_id": -1, "user_id": 1},
422, ({"organisation_id": 1, "group_id": 1, "user_id": 42}, 404), # Non-existent user
), # Negative group ID ({"organisation_id": 1, "group_id": 1, "user_id": "banana"}, 422), # Invalid user ID
(
{"organisation_id": 1, "group_id": 1, "user_id": 42},
404,
), # Non-existent user
(
{"organisation_id": 1, "group_id": 1, "user_id": "banana"},
422,
), # Invalid user ID
({"organisation_id": 1, "group_id": 1, "user_id": ""}, 422), # Blank user ID ({"organisation_id": 1, "group_id": 1, "user_id": ""}, 422), # Blank user ID
({"organisation_id": 1, "group_id": 1, "user_id": -1}, 422), # Negative user ID ({"organisation_id": 1, "group_id": 1, "user_id": -1}, 422), # Negative user ID
({}, 422), # Blank body ({}, 422), # Blank body
({"user_id": 1}, 422), # Only user ({"user_id": 1}, 422), # Only user
({"organisation_id": 1}, 422), # Only organisation ({"organisation_id": 1}, 422), # Only organisation
({"group_id": 1}, 422), # Only group ({"group_id": 1}, 422), # Only group
({"organisation_id": 1, "user_id": 1}, 422), # Missing group ({"organisation_id": 1, "user_id": 1}, 422), # Missing group
({"group_id": 1, "user_id": 1}, 422), # Missing organisation ({"group_id": 1, "user_id": 1}, 422), # Missing organisation
({"organisation_id": 1, "group_id": 1}, 422), # Missing user ({"organisation_id": 1, "group_id": 1}, 422), # Missing user
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_group_user_status_checks( async def test_put_group_user_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.put("/iam/group/user", json=body) resp = await default_client.put("/iam/group/user", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -520,12 +394,11 @@ async def test_get_permissions_success(default_client: AsyncClient):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"query, expected_status", generate_query_and_status(["org_id"]) "query, expected_status",
generate_query_and_status(["org_id"])
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_permissions_status_checks( async def test_get_permissions_status_checks(default_client: AsyncClient, query: str, expected_status: int):
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/iam/permissions?{query}") resp = await default_client.get(f"/iam/permissions?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -533,10 +406,7 @@ async def test_get_permissions_status_checks(
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_perm_success(default_client: AsyncClient, db_session): async def test_post_perm_success(default_client: AsyncClient, db_session):
resp = await default_client.post( resp = await default_client.post("/iam/permission", json={"service_id": 1, "resource": "test_resource", "action": "create"})
"/iam/permission",
json={"service_id": 1, "resource": "test_resource", "action": "create"},
)
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
@ -549,51 +419,32 @@ async def test_post_perm_success(default_client: AsyncClient, db_session):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
(
{"service_id": 1, "resource": "test_resource", "action": "read"},
409,
),
# service_id tests # service_id tests
( ({"service_id": 42, "resource": "test_resource", "action": "read"}, 404), # Non-existent service
{"service_id": 42, "resource": "test_resource", "action": "read"}, ({"service_id": "banana", "resource": "test_resource", "action": "read"}, 422), # Invalid service ID
404, ({"service_id": "", "resource": "test_resource", "action": "read"}, 422), # Blank service ID
), # Non-existent service ({"service_id": -1, "resource": "test_resource", "action": "read"}, 422), # Negative service ID
(
{"service_id": "banana", "resource": "test_resource", "action": "read"},
422,
), # Invalid service ID
(
{"service_id": "", "resource": "test_resource", "action": "read"},
422,
), # Blank service ID
(
{"service_id": -1, "resource": "test_resource", "action": "read"},
422,
), # Negative service ID
# resource tests # resource tests
( ({"service_id": 1, "resource": 42, "action": "read"}, 422), # Invalid resource type
{"service_id": 1, "resource": 42, "action": "read"},
422,
), # Invalid resource type
# action tests # action tests
( ({"service_id": 1, "resource": "test_resource", "action": 42}, 422), # Invalid action type
{"service_id": 1, "resource": "test_resource", "action": 42},
422,
), # Invalid action type
# missing/partial body tests # missing/partial body tests
({}, 422), # Blank body ({}, 422), # Blank body
({"resource": "test_resource"}, 422), # Only resource ({"resource": "test_resource"}, 422), # Only resource
({"action": "read"}, 422), # Only action ({"action": "read"}, 422), # Only action
({"service_id": 1}, 422), # Only service ({"service_id": 1}, 422), # Only service
({"service_id": 1, "action": "read"}, 422), # Missing resource ({"service_id": 1, "action": "read"}, 422), # Missing resource
({"service_id": 1, "resource": "test_resource"}, 422), # Missing action ({"service_id": 1, "resource": "test_resource"}, 422), # Missing action
({"resource": "test_resource", "action": "read"}, 422), # Missing service ({"resource": "test_resource", "action": "read"}, 422), # Missing service
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_perm_status_checks( async def test_post_perm_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.post("/iam/permission", json=body) resp = await default_client.post("/iam/permission", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -602,12 +453,8 @@ async def test_post_perm_status_checks(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body", "body",
[ [
{ {"organisation_id": 1, "service_id": 1, "resource": "test_resource", "action": "read"},
"organisation_id": 1,
"service_id": 1,
"resource": "test_resource",
"action": "read",
},
{"organisation_id": 1, "service_id": 1}, {"organisation_id": 1, "service_id": 1},
{"organisation_id": 1, "resource": "test_resource"}, {"organisation_id": 1, "resource": "test_resource"},
{"organisation_id": 1, "action": "read"}, {"organisation_id": 1, "action": "read"},
@ -633,130 +480,28 @@ async def test_post_perm_search_success(default_client: AsyncClient, db_session,
"body, expected_status", "body, expected_status",
[ [
# organisation_id tests # organisation_id tests
( ({"organisation_id": 42, "service_id": 1, "resource": "test_resource", "action": "read"}, 404), # Non-existent organisation
{ ({"organisation_id": "banana", "service_id": 1, "resource": "test_resource", "action": "read"}, 422), # Invalid organisation ID
"organisation_id": 42, ({"organisation_id": "", "service_id": 1, "resource": "test_resource", "action": "read"}, 422), # Blank organisation ID
"service_id": 1, ({"organisation_id": -1, "service_id": 1, "resource": "test_resource", "action": "read"}, 422), # Negative organisation ID
"resource": "test_resource",
"action": "read",
},
404,
), # Non-existent organisation
(
{
"organisation_id": "banana",
"service_id": 1,
"resource": "test_resource",
"action": "read",
},
422,
), # Invalid organisation ID
(
{
"organisation_id": "",
"service_id": 1,
"resource": "test_resource",
"action": "read",
},
422,
), # Blank organisation ID
(
{
"organisation_id": -1,
"service_id": 1,
"resource": "test_resource",
"action": "read",
},
422,
), # Negative organisation ID
# service_id tests # service_id tests
( ({"organisation_id": 1, "service_id": "banana", "resource": "test_resource", "action": "read"}, 422), # Invalid service ID
{ ({"organisation_id": 1, "service_id": "", "resource": "test_resource", "action": "read"}, 422), # Blank service ID
"organisation_id": 1, ({"organisation_id": 1, "service_id": -1, "resource": "test_resource", "action": "read"}, 422), # Negative service ID
"service_id": "banana",
"resource": "test_resource",
"action": "read",
},
422,
), # Invalid service ID
(
{
"organisation_id": 1,
"service_id": "",
"resource": "test_resource",
"action": "read",
},
422,
), # Blank service ID
(
{
"organisation_id": 1,
"service_id": -1,
"resource": "test_resource",
"action": "read",
},
422,
), # Negative service ID
# resource tests # resource tests
( ({"organisation_id": 1, "service_id": 1, "resource": 42, "action": "read"}, 422), # Invalid resource type
{"organisation_id": 1, "service_id": 1, "resource": 42, "action": "read"},
422,
), # Invalid resource type
# action tests # action tests
( ({"organisation_id": 1, "service_id": 1, "resource": "test_resource", "action": 42}, 422), # Invalid action type
{
"organisation_id": 1,
"service_id": 1,
"resource": "test_resource",
"action": 42,
},
422,
), # Invalid action type
# missing/partial body tests # missing/partial body tests
({}, 422), # Blank body ({}, 422), # Blank body
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_perm_search_status_checks( async def test_post_perm_search_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.post("/iam/permissions/search", json=body) resp = await default_client.post("/iam/permissions/search", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_delete_group_permissions_success(default_client: AsyncClient):
resp = await default_client.delete(
"/iam/group/permissions?org_id=1&group_id=1&perm_id=1"
)
data = resp.json()
assert resp.status_code == 200
assert "permissions" in data
assert isinstance(data["permissions"], list)
assert len(data["permissions"]) == 0
assert "group" in data
assert data["group"]["id"] == 1
assert data["group"]["name"] == "Test Group"
@pytest.mark.anyio
async def test_delete_permissions_success(default_client: AsyncClient):
resp = await default_client.delete("/iam/permission?perm_id=1")
assert resp.status_code == 204
@pytest.mark.anyio
async def test_delete_group_users_success(default_client: AsyncClient):
resp = await default_client.delete("/iam/group/user?org_id=1&group_id=1&user_id=1")
data = resp.json()
assert resp.status_code == 200
assert "users" in data
assert isinstance(data["users"], list)
assert len(data["users"]) == 0
assert "group" in data
assert data["group"]["id"] == 1
assert data["group"]["name"] == "Test Group"

View file

@ -1,7 +1,6 @@
""" """
[DELETE] /org/ is not tested because the testing client cannot attach a body to a delete request. [DELETE] /org/ is not tested because the testing client cannot attach a body to a delete request.
""" """
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
@ -25,12 +24,11 @@ async def test_get_org_success(default_client: AsyncClient):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"query, expected_status", generate_query_and_status(["org_id"]) "query, expected_status",
generate_query_and_status(["org_id"])
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_status_checks( async def test_get_org_status_checks(default_client: AsyncClient, query: str, expected_status: int):
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/org?{query}") resp = await default_client.get(f"/org?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -49,40 +47,24 @@ async def test_post_org_success(default_client: AsyncClient):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
({"name": "Test Org"}, 409),
({"name": 42}, 422), ({"name": 42}, 422),
({}, 422), ({}, 422),
({"name": "New Test Org", "intake_questionnaire": {"question_one": 42}}, 422), ({"name": "New Test Org", "intake_questionnaire": {"question_one": 42}}, 422),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_status_checks( async def test_post_org_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.post("/org", json=body) resp = await default_client.post("/org", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_questionnaire_partial_success( async def test_patch_org_questionnaire_partial_success(default_client: AsyncClient, db_session):
default_client: AsyncClient, db_session
):
org_model = db_session.get(Organisation, 1) org_model = db_session.get(Organisation, 1)
org_model.status = "partial" org_model.status = "partial"
db_session.flush() db_session.flush()
resp = await default_client.patch( resp = await default_client.patch("/org/questionnaire", json={"organisation_id": 1, "intake_questionnaire": {"question_one": "new answer one", "question_two": None, "question_three": None}, "partial": True})
"/org/questionnaire",
json={
"organisation_id": 1,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": True,
},
)
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
@ -101,56 +83,24 @@ async def test_patch_org_questionnaire_partial_success(
({"organisation_id": "Test Org"}, 422), ({"organisation_id": "Test Org"}, 422),
({"organisation_id": ""}, 422), ({"organisation_id": ""}, 422),
({}, 422), ({}, 422),
( ({"organisation_id": "1", "intake_questionnaire": {"question_one": 42}, "partial": True}, 422),
{ ({"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}}, 422),
"organisation_id": "1", ({"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}, "partial": 42}, 422),
"intake_questionnaire": {"question_one": 42},
"partial": True,
},
422,
),
(
{"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}},
422,
),
(
{
"organisation_id": "1",
"intake_questionnaire": {"question_one": "valid"},
"partial": 42,
},
422,
),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_questionnaire_partial_status_checks( async def test_patch_questionnaire_partial_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.patch("/org/questionnaire", json=body) resp = await default_client.patch("/org/questionnaire", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_questionnaire_submit_success( async def test_patch_org_questionnaire_submit_success(default_client: AsyncClient, db_session):
default_client: AsyncClient, db_session
):
org_model = db_session.get(Organisation, 1) org_model = db_session.get(Organisation, 1)
org_model.status = "partial" org_model.status = "partial"
db_session.flush() db_session.flush()
resp = await default_client.patch( resp = await default_client.patch("/org/questionnaire", json={"organisation_id": 1, "intake_questionnaire": {"question_one": "new answer one", "question_two": None, "question_three": None}, "partial": False})
"/org/questionnaire",
json={
"organisation_id": 1,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": False,
},
)
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
@ -163,13 +113,12 @@ async def test_patch_org_questionnaire_submit_success(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"status", ["partial", "submitted", "remediation", "approved", "rejected", "removed"] "status",
["partial", "submitted", "remediation", "approved", "rejected", "removed"]
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_status_success(default_client: AsyncClient, status: str): async def test_patch_org_status_success(default_client: AsyncClient, status: str):
resp = await default_client.patch( resp = await default_client.patch("/org/status", json={"organisation_id": 1, "status": status})
"/org/status", json={"organisation_id": 1, "status": status}
)
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
@ -189,9 +138,7 @@ async def test_patch_org_status_success(default_client: AsyncClient, status: str
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_status_status_checks( async def test_patch_org_status_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.patch("/org/status", json=body) resp = await default_client.patch("/org/status", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -214,12 +161,11 @@ async def test_get_org_users_success(default_client: AsyncClient):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"query, expected_status", generate_query_and_status(["org_id"]) "query, expected_status",
generate_query_and_status(["org_id"])
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_users_status_checks( async def test_get_org_users_status_checks(default_client: AsyncClient, query: str, expected_status: int):
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/org/users?{query}") resp = await default_client.get(f"/org/users?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -227,19 +173,10 @@ async def test_get_org_users_status_checks(
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_user_success(default_client: AsyncClient, db_session): async def test_post_org_user_success(default_client: AsyncClient, db_session):
db_session.add( db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush() db_session.flush()
resp = await default_client.post( resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 2})
"/org/user", json={"organisation_id": 1, "user_id": 2}
)
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
@ -260,17 +197,8 @@ async def test_post_org_user_success(default_client: AsyncClient, db_session):
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_user_status_checks( async def test_post_org_user_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int, db_session):
default_client: AsyncClient, body: dict[str, str], expected_status: int, db_session db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
):
db_session.add(
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush() db_session.flush()
resp = await default_client.post("/org/user", json=body) resp = await default_client.post("/org/user", json=body)
@ -280,21 +208,12 @@ async def test_post_org_user_status_checks(
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_root_user_success(default_client: AsyncClient, db_session): async def test_patch_org_root_user_success(default_client: AsyncClient, db_session):
db_session.add( db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush() db_session.flush()
db_session.add(OrgUsers(org_id=1, user_id=2)) db_session.add(OrgUsers(org_id=1, user_id=2))
db_session.flush() db_session.flush()
resp = await default_client.patch( resp = await default_client.patch("/org/root_user", json={"organisation_id": 1, "user_id": 2})
"/org/root_user", json={"organisation_id": 1, "user_id": 2}
)
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
@ -315,17 +234,8 @@ async def test_patch_org_root_user_success(default_client: AsyncClient, db_sessi
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_root_user_status_checks( async def test_patch_root_user_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int, db_session):
default_client: AsyncClient, body: dict[str, str], expected_status: int, db_session db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
):
db_session.add(
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush() db_session.flush()
db_session.add(OrgUsers(org_id=1, user_id=2)) db_session.add(OrgUsers(org_id=1, user_id=2))
db_session.flush() db_session.flush()
@ -337,19 +247,10 @@ async def test_patch_root_user_status_checks(
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_root_user_non_member(default_client: AsyncClient, db_session): async def test_patch_org_root_user_non_member(default_client: AsyncClient, db_session):
db_session.add( db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush() db_session.flush()
resp = await default_client.patch( resp = await default_client.patch("/org/root_user", json={"organisation_id": 1, "user_id": 2})
"/org/root_user", json={"organisation_id": 1, "user_id": 2}
)
data = resp.json() data = resp.json()
assert resp.status_code == 422 assert resp.status_code == 422
@ -368,23 +269,23 @@ async def test_get_org_groups_success(default_client: AsyncClient):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"query, expected_status", generate_query_and_status(["org_id"]) "query, expected_status",
generate_query_and_status(["org_id"])
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_groups_status_checks( async def test_get_org_groups_status_checks(default_client: AsyncClient, query: str, expected_status: int):
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/org/groups?{query}") resp = await default_client.get(f"/org/groups?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.parametrize("contact_type", ["billing", "security", "owner"]) @pytest.mark.parametrize(
"contact_type",
["billing", "security", "owner"]
)
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_contact_success(default_client: AsyncClient, contact_type: str): async def test_get_org_contact_success(default_client: AsyncClient, contact_type: str):
resp = await default_client.get( resp = await default_client.get(f"/org/contact?org_id=1&contact_type={contact_type}")
f"/org/contact?org_id=1&contact_type={contact_type}"
)
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
@ -426,9 +327,7 @@ async def test_get_org_contact_success(default_client: AsyncClient, contact_type
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_contact_status_checks( async def test_get_org_contact_status_checks(default_client: AsyncClient, query: str, expected_status: int):
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/org/contact?{query}") resp = await default_client.get(f"/org/contact?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -449,16 +348,11 @@ async def test_get_org_contact_status_checks(
("address_region", "Glasgow City"), ("address_region", "Glasgow City"),
("country_code", "GB"), ("country_code", "GB"),
("postal_code", "G1 1AA"), ("postal_code", "G1 1AA"),
], ]
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_contact_success( async def test_patch_org_contact_success(default_client: AsyncClient, key: str, value: str):
default_client: AsyncClient, key: str, value: str resp = await default_client.patch("/org/contact", json={"organisation_id": 1, "contact_type": "billing", key: value})
):
resp = await default_client.patch(
"/org/contact",
json={"organisation_id": 1, "contact_type": "billing", key: value},
)
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
@ -485,32 +379,7 @@ async def test_patch_org_contact_success(
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_contact_status_checks( async def test_patch_org_contact_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.patch("/org/contact", json=body) resp = await default_client.patch("/org/contact", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_delete_org_success(default_client: AsyncClient):
resp = await default_client.delete("/org?org_id=1")
assert resp.status_code == 204
@pytest.mark.anyio
async def test_delete_org_users_success(db_session, default_client: AsyncClient):
db_session.add(
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush()
resp = await default_client.delete("/org/user?org_id=1&user_id=2")
assert resp.status_code == 204

View file

@ -1,7 +1,6 @@
""" """
409 on [POST]/service/ not tested because SQLite throws a different error than Postgres 409 on [POST]/service/ not tested because SQLite throws a different error than Postgres
""" """
import pytest import pytest
from httpx import AsyncClient from httpx import AsyncClient
@ -20,12 +19,11 @@ async def test_get_services_success(default_client: AsyncClient):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"query, expected_status", generate_query_and_status(["org_id"]) "query, expected_status",
generate_query_and_status(["org_id"])
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_services_status_checks( async def test_get_services_status_checks(default_client: AsyncClient, query: str, expected_status: int):
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/service/?{query}") resp = await default_client.get(f"/service/?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -46,15 +44,12 @@ async def test_post_service_success(default_client: AsyncClient):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
({"name": "Test Service"}, 409),
({"name": 42}, 422), ({"name": 42}, 422),
({}, 422), ({}, 422),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_service_status_checks( async def test_post_services_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.post("/service/", json=body) resp = await default_client.post("/service/", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@ -82,16 +77,7 @@ async def test_patch_service_success(default_client: AsyncClient):
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_services_status_checks( async def test_patch_services_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.patch("/service/key", json=body) resp = await default_client.patch("/service/key", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_delete_service_success(default_client: AsyncClient):
resp = await default_client.delete("/service/?service_id=1")
assert resp.status_code == 204

View file

@ -8,7 +8,6 @@ from httpx import AsyncClient
from .conftest import generate_query_and_status from .conftest import generate_query_and_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_self_db_success(default_client: AsyncClient): async def test_get_self_db_success(default_client: AsyncClient):
resp = await default_client.get("/user/self/db") resp = await default_client.get("/user/self/db")
@ -37,18 +36,10 @@ async def test_get_user_success(default_client: AsyncClient):
@pytest.mark.anyio @pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"query, expected_status", generate_query_and_status(["user_id"]) "query, expected_status",
generate_query_and_status(["user_id"])
) )
async def test_get_user_status_checks( async def test_get_user_status_checks(default_client: AsyncClient, query: str, expected_status: int):
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/user/?{query}") resp = await default_client.get(f"/user/?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_delete_user_success(default_client: AsyncClient):
resp = await default_client.delete("/user/?user_id=1")
assert resp.status_code == 204