minor: ruff format
Some checks failed
ci / ruff (push) Successful in 4s
ci / ty (push) Successful in 4s
ci / tests (push) Failing after 7s
ci / build (push) Has been cancelled

Tabs -> spaces
This commit is contained in:
Chris Milne 2026-06-22 15:04:11 +01:00
parent b2921b73b8
commit fab228bf8f
56 changed files with 3629 additions and 3630 deletions

View file

@ -22,5 +22,5 @@ from fastapi import APIRouter
router = APIRouter( router = APIRouter(
tags=[""], tags=[""],
) )

View file

@ -8,6 +8,6 @@ Exports:
from fastapi import APIRouter from fastapi import APIRouter
router = APIRouter( router = APIRouter(
tags=["admin"], tags=["admin"],
prefix="/admin", prefix="/admin",
) )

View file

@ -26,15 +26,15 @@ api_router.include_router(iam_router)
class HealthCheckResponse(CustomBaseModel): class HealthCheckResponse(CustomBaseModel):
status: str status: str
@api_router.get( @api_router.get(
path="/healthcheck", path="/healthcheck",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=HealthCheckResponse, response_model=HealthCheckResponse,
include_in_schema=False, include_in_schema=False,
) )
def healthcheck(): def healthcheck():
"""Simple health check endpoint.""" """Simple health check endpoint."""
return {"status": "ok"} return {"status": "ok"}

View file

@ -9,9 +9,9 @@ from src.config import CustomBaseSettings
class AuthConfig(CustomBaseSettings): class AuthConfig(CustomBaseSettings):
OIDC_CONFIG: str = "" OIDC_CONFIG: str = ""
OIDC_ISSUER: str = "" OIDC_ISSUER: str = ""
CLIENT_ID: str = "" CLIENT_ID: str = ""
auth_settings = AuthConfig() auth_settings = AuthConfig()

View file

@ -16,92 +16,92 @@ from src.exceptions import ForbiddenException
from src.user.dependencies import user_model_claims_dependency from src.user.dependencies import user_model_claims_dependency
from src.user.models import User from src.user.models import User
from src.organisation.dependencies import ( from src.organisation.dependencies import (
org_model_query_dependency, org_model_query_dependency,
org_model_body_dependency, org_model_body_dependency,
) )
from src.organisation.models import Organisation as Org from src.organisation.models import Organisation as Org
async def org_query_user_claims( async def org_query_user_claims(
org_model: org_model_query_dependency, user_model: user_model_claims_dependency org_model: org_model_query_dependency, user_model: user_model_claims_dependency
): ):
if user_model in org_model.user_rel: if user_model in org_model.user_rel:
return True return True
raise ForbiddenException() raise ForbiddenException()
org_query_user_claims_dependency = Annotated[bool, Depends(org_query_user_claims)] org_query_user_claims_dependency = Annotated[bool, Depends(org_query_user_claims)]
def get_super_admin_list(): def get_super_admin_list():
return [] return []
def empty_su_list(): def empty_su_list():
return [] return []
def testing_su_list(): def testing_su_list():
return ["admin@test.com"] return ["admin@test.com"]
su_list_dependency = Annotated[list[str | None], Depends(get_super_admin_list)] su_list_dependency = Annotated[list[str | None], Depends(get_super_admin_list)]
async def user_model_super_admin( async def user_model_super_admin(
user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency
): ):
if user_model.email in super_admin_emails: if user_model.email in super_admin_emails:
return user_model return user_model
raise ForbiddenException(message="Must be super admin") raise ForbiddenException(message="Must be super admin")
super_admin_dependency = Annotated[User, Depends(user_model_super_admin)] super_admin_dependency = Annotated[User, Depends(user_model_super_admin)]
async def org_query_root_claims( async def org_query_root_claims(
user_model: user_model_claims_dependency, user_model: user_model_claims_dependency,
org_model: org_model_query_dependency, org_model: org_model_query_dependency,
su_emails: su_list_dependency, su_emails: su_list_dependency,
request: Request, request: Request,
): ):
try: try:
if await user_model_super_admin(user_model, su_emails): if await user_model_super_admin(user_model, su_emails):
return org_model return org_model
except ForbiddenException: except ForbiddenException:
pass pass
await org_status_check(org_model, request) await org_status_check(org_model, request)
if org_model.root_user_id == user_model.id: if org_model.root_user_id == user_model.id:
return org_model return org_model
raise ForbiddenException(message="Must be the org's root user") raise ForbiddenException(message="Must be the org's root user")
org_model_root_claim_query_dependency = Annotated[Org, Depends(org_query_root_claims)] org_model_root_claim_query_dependency = Annotated[Org, Depends(org_query_root_claims)]
async def org_body_root_claims( async def org_body_root_claims(
user_model: user_model_claims_dependency, user_model: user_model_claims_dependency,
org_model: org_model_body_dependency, org_model: org_model_body_dependency,
su_emails: su_list_dependency, su_emails: su_list_dependency,
request: Request, request: Request,
): ):
try: try:
if await user_model_super_admin(user_model, su_emails): if await user_model_super_admin(user_model, su_emails):
return org_model return org_model
except ForbiddenException: except ForbiddenException:
pass pass
await org_status_check(org_model, request) await org_status_check(org_model, request)
if org_model.root_user_id == user_model.id: if org_model.root_user_id == user_model.id:
return org_model return org_model
raise ForbiddenException(message="Must be the org's root user") raise ForbiddenException(message="Must be the org's root user")
org_model_root_claim_body_dependency = Annotated[Org, Depends(org_body_root_claims)] org_model_root_claim_body_dependency = Annotated[Org, Depends(org_body_root_claims)]

View file

@ -8,5 +8,5 @@ Exports:
from fastapi import APIRouter from fastapi import APIRouter
router = APIRouter( router = APIRouter(
tags=["auth"], tags=["auth"],
) )

View file

@ -31,56 +31,56 @@ oidc_dependency = Annotated[str, Depends(oidc)]
async def get_dev_user(): async def get_dev_user():
return {"db_id": 1, "email": "chris@sr2.uk"} return {"db_id": 1, "email": "chris@sr2.uk"}
async def get_current_user( async def get_current_user(
oidc_auth_string: oidc_dependency, db: DbSession oidc_auth_string: oidc_dependency, db: DbSession
) -> dict[str, Any]: ) -> dict[str, Any]:
config_url = urlopen(auth_settings.OIDC_CONFIG) config_url = urlopen(auth_settings.OIDC_CONFIG)
config = json.loads(config_url.read()) config = json.loads(config_url.read())
jwks_uri = config["jwks_uri"] jwks_uri = config["jwks_uri"]
key_response = requests.get(jwks_uri) key_response = requests.get(jwks_uri)
jwk_keys = KeySet.import_key_set(key_response.json()) jwk_keys = KeySet.import_key_set(key_response.json())
token = jwt.decode(oidc_auth_string.replace("Bearer ", ""), jwk_keys) token = jwt.decode(oidc_auth_string.replace("Bearer ", ""), jwk_keys)
claims_requests = jwt.JWTClaimsRegistry( claims_requests = jwt.JWTClaimsRegistry(
exp={"essential": True}, iss={"essential": True, "value": auth_settings.OIDC_ISSUER} exp={"essential": True}, iss={"essential": True, "value": auth_settings.OIDC_ISSUER}
) )
try: try:
claims_requests.validate(token.claims) claims_requests.validate(token.claims)
except ExpiredTokenError: except ExpiredTokenError:
raise UnauthorizedException(message="Token is expired") raise UnauthorizedException(message="Token is expired")
db_id = await add_user(db, token.claims) db_id = await add_user(db, token.claims)
token.claims["db_id"] = db_id token.claims["db_id"] = db_id
return token.claims return token.claims
claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)] claims_dependency = Annotated[dict[str, Any], Depends(get_current_user)]
async def org_status_check(org_model: Org, request: Request): async def org_status_check(org_model: Org, request: Request):
org_status = OrgStatus(org_model.status) org_status = OrgStatus(org_model.status)
if org_status.is_blocked: if org_status.is_blocked:
raise ForbiddenException("This organisation cannot perform this action.") raise ForbiddenException("This organisation cannot perform this action.")
root = "/api/v1" root = "/api/v1"
pre_approval_endpoints = [ pre_approval_endpoints = [
f"PATCH{root}/org/status", f"PATCH{root}/org/status",
f"PATCH{root}/org/questionnaire", f"PATCH{root}/org/questionnaire",
f"GET{root}/org", f"GET{root}/org",
f"GET{root}/org/contact", f"GET{root}/org/contact",
f"PATCH{root}/org/contact", f"PATCH{root}/org/contact",
f"DELETE{root}/org/self", f"DELETE{root}/org/self",
] ]
current_request = f"{request.method}{request.url.path}" current_request = f"{request.method}{request.url.path}"
if ( if (
current_request not in pre_approval_endpoints current_request not in pre_approval_endpoints
and org_model.status != OrgStatus.APPROVED and org_model.status != OrgStatus.APPROVED
): ):
raise AwaitingApprovalException(org_model.id) raise AwaitingApprovalException(org_model.id)

View file

@ -16,31 +16,31 @@ from src.constants import Environment
class CustomBaseSettings(BaseSettings): class CustomBaseSettings(BaseSettings):
model_config = SettingsConfigDict( model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", extra="ignore" env_file=".env", env_file_encoding="utf-8", extra="ignore"
) )
class Config(CustomBaseSettings): class Config(CustomBaseSettings):
APP_VERSION: str = "0.1" APP_VERSION: str = "0.1"
ENVIRONMENT: Environment = Environment.PRODUCTION ENVIRONMENT: Environment = Environment.PRODUCTION
SECRET_KEY: SecretStr = SecretStr("") SECRET_KEY: SecretStr = SecretStr("")
DISABLE_AUTH: bool = False DISABLE_AUTH: bool = False
CORS_ORIGINS: list[str] = ["*"] CORS_ORIGINS: list[str] = ["*"]
CORS_ORIGINS_REGEX: str | None = None CORS_ORIGINS_REGEX: str | None = None
CORS_HEADERS: list[str] = ["*"] CORS_HEADERS: list[str] = ["*"]
DATABASE_NAME: str = "fastapi-exp" DATABASE_NAME: str = "fastapi-exp"
DATABASE_PORT: str = "5432" DATABASE_PORT: str = "5432"
DATABASE_HOSTNAME: str = "localhost" DATABASE_HOSTNAME: str = "localhost"
DATABASE_CREDENTIALS: SecretStr = SecretStr(":") DATABASE_CREDENTIALS: SecretStr = SecretStr(":")
DATABASE_POOL_SIZE: int = 16 DATABASE_POOL_SIZE: int = 16
DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes
DATABASE_POOL_PRE_PING: bool = True DATABASE_POOL_PRE_PING: bool = True
LETTERMINT_API_TOKEN: SecretStr = SecretStr("") LETTERMINT_API_TOKEN: SecretStr = SecretStr("")
settings = Config() settings = Config()
@ -51,20 +51,20 @@ DATABASE_HOSTNAME = settings.DATABASE_HOSTNAME
DATABASE_CREDENTIALS = settings.DATABASE_CREDENTIALS.get_secret_value() DATABASE_CREDENTIALS = settings.DATABASE_CREDENTIALS.get_secret_value()
# this will support special chars for credentials # this will support special chars for credentials
_DATABASE_CREDENTIAL_USER, _DATABASE_CREDENTIAL_PASSWORD = str(DATABASE_CREDENTIALS).split( _DATABASE_CREDENTIAL_USER, _DATABASE_CREDENTIAL_PASSWORD = str(DATABASE_CREDENTIALS).split(
":" ":"
) )
_QUOTED_DATABASE_PASSWORD = parse.quote_plus(str(_DATABASE_CREDENTIAL_PASSWORD)) _QUOTED_DATABASE_PASSWORD = parse.quote_plus(str(_DATABASE_CREDENTIAL_PASSWORD))
SQLALCHEMY_DATABASE_URI = SecretStr( SQLALCHEMY_DATABASE_URI = SecretStr(
f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}" f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}"
) )
if settings.ENVIRONMENT == Environment.TESTING: if settings.ENVIRONMENT == Environment.TESTING:
SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:") SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:")
app_configs: dict[str, Any] = {"title": "App API"} app_configs: dict[str, Any] = {"title": "App API"}
if settings.ENVIRONMENT.is_deployed: if settings.ENVIRONMENT.is_deployed:
app_configs["root_path"] = f"/v{settings.APP_VERSION}" app_configs["root_path"] = f"/v{settings.APP_VERSION}"
if not settings.ENVIRONMENT.is_debug: if not settings.ENVIRONMENT.is_debug:
app_configs["openapi_url"] = None # hide docs app_configs["openapi_url"] = None # hide docs

View file

@ -9,29 +9,29 @@ from enum import StrEnum, auto
class Environment(StrEnum): class Environment(StrEnum):
""" """
Enumeration of environments. Enumeration of environments.
Attributes: Attributes:
LOCAL (str): Application is running locally LOCAL (str): Application is running locally
TESTING (str): Application is running in testing mode TESTING (str): Application is running in testing mode
STAGING (str): Application is running in staging mode (ie not testing) STAGING (str): Application is running in staging mode (ie not testing)
PRODUCTION (str): Application is running in production mode PRODUCTION (str): Application is running in production mode
""" """
LOCAL = auto() LOCAL = auto()
TESTING = auto() TESTING = auto()
STAGING = auto() STAGING = auto()
PRODUCTION = auto() PRODUCTION = auto()
@property @property
def is_debug(self): def is_debug(self):
return self in (self.LOCAL, self.STAGING, self.TESTING) return self in (self.LOCAL, self.STAGING, self.TESTING)
@property @property
def is_testing(self): def is_testing(self):
return self == self.TESTING return self == self.TESTING
@property @property
def is_deployed(self) -> bool: def is_deployed(self) -> bool:
return self in (self.STAGING, self.PRODUCTION) return self in (self.STAGING, self.PRODUCTION)

View file

@ -11,13 +11,13 @@ from fastapi import HTTPException, status
class ContactNotFoundException(HTTPException): class ContactNotFoundException(HTTPException):
def __init__(self, contact_id: Optional[int] = None) -> None: def __init__(self, contact_id: Optional[int] = None) -> None:
detail = ( detail = (
"Contact not found" "Contact not found"
if contact_id is None if contact_id is None
else f"Contact with ID '{contact_id}' was not found." else f"Contact with ID '{contact_id}' was not found."
) )
super().__init__( super().__init__(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=detail, detail=detail,
) )

View file

@ -15,22 +15,22 @@ from src.models import CustomBase
class Contact(CustomBase, IdMixin): class Contact(CustomBase, IdMixin):
__tablename__ = "contact" __tablename__ = "contact"
email: Mapped[str] = mapped_column(default=None, nullable=True) email: Mapped[str] = mapped_column(default=None, nullable=True)
first_name: Mapped[str] = mapped_column(default=None, nullable=True) first_name: Mapped[str] = mapped_column(default=None, nullable=True)
last_name: Mapped[str] = mapped_column(default=None, nullable=True) last_name: Mapped[str] = mapped_column(default=None, nullable=True)
phonenumber: Mapped[str] = mapped_column(default=None, nullable=True) phonenumber: Mapped[str] = mapped_column(default=None, nullable=True)
vat_number: Mapped[str | None] = mapped_column(default=None, nullable=True) vat_number: Mapped[str | None] = mapped_column(default=None, nullable=True)
street_address: Mapped[str] = mapped_column(default=None, nullable=True) street_address: Mapped[str] = mapped_column(default=None, nullable=True)
street_address_line_2: Mapped[str] = mapped_column(default=None, nullable=True) street_address_line_2: Mapped[str] = mapped_column(default=None, nullable=True)
post_office_box_number: Mapped[str | None] = mapped_column(default=None, nullable=True) post_office_box_number: Mapped[str | None] = mapped_column(default=None, nullable=True)
locality: Mapped[str] = mapped_column(default=None, nullable=True) # Ie City locality: Mapped[str] = mapped_column(default=None, nullable=True) # Ie City
country_code: Mapped[str] = mapped_column(default=None, nullable=True) # Eg GB country_code: Mapped[str] = mapped_column(default=None, nullable=True) # Eg GB
address_region: Mapped[str | None] = mapped_column(default=None, nullable=True) address_region: Mapped[str | None] = mapped_column(default=None, nullable=True)
postal_code: Mapped[str] = mapped_column(default=None, nullable=True) postal_code: Mapped[str] = mapped_column(default=None, nullable=True)
org_id: Mapped[int] = mapped_column( org_id: Mapped[int] = mapped_column(
ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False
) )

View file

@ -6,6 +6,6 @@ from fastapi import APIRouter
router = APIRouter( router = APIRouter(
prefix="/contact", prefix="/contact",
tags=["contact"], tags=["contact"],
) )

View file

@ -14,22 +14,22 @@ from src.schemas import CustomBaseModel
class ContactAddress(CustomBaseModel): class ContactAddress(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore") model_config = ConfigDict(from_attributes=True, extra="ignore")
post_office_box_number: Optional[str] = None post_office_box_number: Optional[str] = None
street_address: Optional[str] = None street_address: Optional[str] = None
street_address_line_2: Optional[str] = None street_address_line_2: Optional[str] = None
locality: Optional[str] = None locality: Optional[str] = None
address_region: Optional[str] = None address_region: Optional[str] = None
country_code: Optional[str] = None country_code: Optional[str] = None
postal_code: Optional[str] = None postal_code: Optional[str] = None
class ContactModel(CustomBaseModel): class ContactModel(CustomBaseModel):
email: Optional[EmailStr] = None email: Optional[EmailStr] = None
first_name: Optional[str] = None first_name: Optional[str] = None
last_name: Optional[str] = None last_name: Optional[str] = None
phonenumber: Optional[str] = None phonenumber: Optional[str] = None
vat_number: Optional[str] = None vat_number: Optional[str] = None
address: ContactAddress address: ContactAddress

View file

@ -1,6 +1,7 @@
""" """
Database connection and session utilities Database connection and session utilities
""" """
from contextlib import contextmanager from contextlib import contextmanager
from typing import Annotated, Generator from typing import Annotated, Generator
from sqlalchemy import create_engine, StaticPool, Connection from sqlalchemy import create_engine, StaticPool, Connection
@ -29,6 +30,7 @@ else:
sm = sessionmaker(autocommit=False, expire_on_commit=False, bind=engine) sm = sessionmaker(autocommit=False, expire_on_commit=False, bind=engine)
@contextmanager @contextmanager
def get_db_connection() -> Generator[Connection, None, None]: def get_db_connection() -> Generator[Connection, None, None]:
with engine.connect() as connection: with engine.connect() as connection:
@ -38,12 +40,15 @@ def get_db_connection() -> Generator[Connection, None, None]:
connection.rollback() connection.rollback()
raise raise
def _get_db_connection() -> Generator[Connection, None]: def _get_db_connection() -> Generator[Connection, None]:
with get_db_connection() as connection: with get_db_connection() as connection:
yield connection yield connection
DbConnection = Annotated[Connection, Depends(_get_db_connection)] DbConnection = Annotated[Connection, Depends(_get_db_connection)]
@contextmanager @contextmanager
def get_db_session() -> Generator[Session, None, None]: def get_db_session() -> Generator[Session, None, None]:
session = sm() session = sm()
@ -60,4 +65,5 @@ def _get_db_session() -> Generator[Session, None]:
with get_db_session() as session: with get_db_session() as session:
yield session yield session
DbSession = Annotated[Session, Depends(_get_db_session)] DbSession = Annotated[Session, Depends(_get_db_session)]

View file

@ -12,36 +12,36 @@ from fastapi import HTTPException, status
class UnprocessableContentException(HTTPException): class UnprocessableContentException(HTTPException):
def __init__(self, message: Optional[str] = None) -> None: def __init__(self, message: Optional[str] = None) -> None:
detail = "Unprocessable content" if not message else message detail = "Unprocessable content" if not message else message
super().__init__( super().__init__(
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail=detail, detail=detail,
) )
class ConflictException(HTTPException): class ConflictException(HTTPException):
def __init__(self, message: Optional[str] = None) -> None: def __init__(self, message: Optional[str] = None) -> None:
detail = "Conflict" if not message else message detail = "Conflict" if not message else message
super().__init__( super().__init__(
status_code=status.HTTP_409_CONFLICT, status_code=status.HTTP_409_CONFLICT,
detail=detail, detail=detail,
) )
class ForbiddenException(HTTPException): class ForbiddenException(HTTPException):
def __init__(self, message: Optional[str] = None) -> None: def __init__(self, message: Optional[str] = None) -> None:
detail = "Forbidden" if not message else message detail = "Forbidden" if not message else message
super().__init__( super().__init__(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=detail, detail=detail,
) )
class UnauthorizedException(HTTPException): class UnauthorizedException(HTTPException):
def __init__(self, message: Optional[str] = None) -> None: def __init__(self, message: Optional[str] = None) -> None:
detail = "Not authorized" if not message else message detail = "Not authorized" if not message else message
super().__init__( super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail, detail=detail,
) )

View file

@ -18,59 +18,55 @@ from src.iam.exceptions import GroupNotFoundException, PermNotFoundException
from src.iam.schemas import GroupIDMixin, PermIDMixin from src.iam.schemas import GroupIDMixin, PermIDMixin
def get_group_model_query( def get_group_model_query(db: DbSession, group_id: Annotated[int, Query(gt=0)]) -> Group:
db: DbSession, group_id: Annotated[int, Query(gt=0)] group_model = db.get(Group, group_id)
) -> Group: if group_model is None:
group_model = db.get(Group, group_id) raise GroupNotFoundException(group_id)
if group_model is None:
raise GroupNotFoundException(group_id)
return group_model return group_model
group_model_query_dependency = Annotated[Group, Depends(get_group_model_query)] group_model_query_dependency = Annotated[Group, Depends(get_group_model_query)]
def get_group_model_body( def get_group_model_body(
db: DbSession, request_model: Optional[GroupIDMixin] = None db: DbSession, request_model: Optional[GroupIDMixin] = None
) -> Group: ) -> Group:
group_id = getattr(request_model, "group_id", None) group_id = getattr(request_model, "group_id", None)
if group_id is None: if group_id is None:
raise GroupNotFoundException() raise GroupNotFoundException()
group_model = db.get(Group, group_id) group_model = db.get(Group, group_id)
if group_model is None: if group_model is None:
raise GroupNotFoundException(group_id) raise GroupNotFoundException(group_id)
return group_model return group_model
group_model_body_dependency = Annotated[Group, Depends(get_group_model_body)] group_model_body_dependency = Annotated[Group, Depends(get_group_model_body)]
def get_perm_model_body( def get_perm_model_body(
db: DbSession, request_model: Optional[PermIDMixin] = None db: DbSession, request_model: Optional[PermIDMixin] = None
) -> Permission: ) -> Permission:
perm_id = getattr(request_model, "permission_id", None) perm_id = getattr(request_model, "permission_id", None)
if perm_id is None: if perm_id is None:
raise PermNotFoundException raise PermNotFoundException
perm_model = db.get(Permission, perm_id) perm_model = db.get(Permission, perm_id)
if perm_model is None: if perm_model is None:
raise PermNotFoundException(perm_id) raise PermNotFoundException(perm_id)
return perm_model return perm_model
perm_model_body_dependency = Annotated[Permission, Depends(get_perm_model_body)] perm_model_body_dependency = Annotated[Permission, Depends(get_perm_model_body)]
def get_perm_model_query( def get_perm_model_query(db: DbSession, perm_id: Annotated[int, Query(gt=0)]) -> Permission:
db: DbSession, perm_id: Annotated[int, Query(gt=0)] perm_model = db.get(Permission, perm_id)
) -> Permission: if perm_model is None:
perm_model = db.get(Permission, perm_id) raise PermNotFoundException(perm_id)
if perm_model is None:
raise PermNotFoundException(perm_id)
return perm_model return perm_model
perm_model_query_dependency = Annotated[Permission, Depends(get_perm_model_query)] perm_model_query_dependency = Annotated[Permission, Depends(get_perm_model_query)]

View file

@ -12,26 +12,26 @@ from fastapi import HTTPException, status
class GroupNotFoundException(HTTPException): class GroupNotFoundException(HTTPException):
def __init__(self, group_id: Optional[int] = None) -> None: def __init__(self, group_id: Optional[int] = None) -> None:
detail = ( detail = (
"Group not found" "Group not found"
if group_id is None if group_id is None
else f"User with ID '{group_id}' was not found." else f"User with ID '{group_id}' was not found."
) )
super().__init__( super().__init__(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=detail, detail=detail,
) )
class PermNotFoundException(HTTPException): class PermNotFoundException(HTTPException):
def __init__(self, perm_id: Optional[int] = None) -> None: def __init__(self, perm_id: Optional[int] = None) -> None:
detail = ( detail = (
"Permission not found" "Permission not found"
if perm_id is None if perm_id is None
else f"User with ID '{perm_id}' was not found." else f"User with ID '{perm_id}' was not found."
) )
super().__init__( super().__init__(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=detail, detail=detail,
) )

View file

@ -25,90 +25,90 @@ from src.models import CustomBase, IdMixin
class Permission(CustomBase, IdMixin): class Permission(CustomBase, IdMixin):
__tablename__ = "permission" __tablename__ = "permission"
resource: Mapped[str] resource: Mapped[str]
action: Mapped[str] action: Mapped[str]
service_id: Mapped[int] = mapped_column(ForeignKey("service.id", ondelete="CASCADE")) service_id: Mapped[int] = mapped_column(ForeignKey("service.id", ondelete="CASCADE"))
__table_args__ = ( __table_args__ = (
UniqueConstraint( UniqueConstraint(
"service_id", "service_id",
"resource", "resource",
"action", "action",
name="uniq_permission_resource_and_action", name="uniq_permission_resource_and_action",
), ),
) )
service_rel = relationship( service_rel = relationship(
"Service", "Service",
back_populates="permission_rel", back_populates="permission_rel",
foreign_keys="Permission.service_id", foreign_keys="Permission.service_id",
) )
group_rel = relationship( group_rel = relationship(
"Group", secondary="group_permissions", back_populates="permission_rel" "Group", secondary="group_permissions", back_populates="permission_rel"
) )
org_rel = relationship( org_rel = relationship(
"Organisation", secondary="org_permissions", back_populates="permission_rel" "Organisation", secondary="org_permissions", back_populates="permission_rel"
) )
@property @property
def service_name(self): def service_name(self):
return self.service_rel.name return self.service_rel.name
class Group(CustomBase, IdMixin): class Group(CustomBase, IdMixin):
__tablename__ = "group" __tablename__ = "group"
name: Mapped[str] name: Mapped[str]
org_id: Mapped[int] = mapped_column(ForeignKey("organisation.id", ondelete="CASCADE")) org_id: Mapped[int] = mapped_column(ForeignKey("organisation.id", ondelete="CASCADE"))
__table_args__ = ( __table_args__ = (
UniqueConstraint( UniqueConstraint(
"name", "name",
"org_id", "org_id",
name="uniq_group_name_org_id", name="uniq_group_name_org_id",
), ),
) )
user_rel = relationship("User", secondary="user_groups", back_populates="group_rel") user_rel = relationship("User", secondary="user_groups", back_populates="group_rel")
org_rel = relationship("Organisation", back_populates="group_rel") org_rel = relationship("Organisation", back_populates="group_rel")
permission_rel = relationship( permission_rel = relationship(
"Permission", secondary="group_permissions", back_populates="group_rel" "Permission", secondary="group_permissions", back_populates="group_rel"
) )
class GroupPermissions(CustomBase): class GroupPermissions(CustomBase):
__tablename__ = "group_permissions" __tablename__ = "group_permissions"
group_id: Mapped[int] = mapped_column( group_id: Mapped[int] = mapped_column(
ForeignKey("group.id", ondelete="CASCADE"), primary_key=True ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
) )
permission_id: Mapped[int] = mapped_column( permission_id: Mapped[int] = mapped_column(
ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
) )
class UserGroups(CustomBase): class UserGroups(CustomBase):
__tablename__ = "user_groups" __tablename__ = "user_groups"
user_id: Mapped[int] = mapped_column( user_id: Mapped[int] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
) )
group_id: Mapped[int] = mapped_column( group_id: Mapped[int] = mapped_column(
ForeignKey("group.id", ondelete="CASCADE"), primary_key=True ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
) )
class OrgPermissions(CustomBase): class OrgPermissions(CustomBase):
__tablename__ = "org_permissions" __tablename__ = "org_permissions"
org_id: Mapped[int] = mapped_column( org_id: Mapped[int] = mapped_column(
ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True
) )
permission_id: Mapped[int] = mapped_column( permission_id: Mapped[int] = mapped_column(
ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
) )

File diff suppressed because it is too large Load diff

View file

@ -11,151 +11,151 @@ from typing import Optional, Annotated
from pydantic import EmailStr, ConfigDict, Field from pydantic import EmailStr, ConfigDict, Field
from src.schemas import ( from src.schemas import (
CustomBaseModel, CustomBaseModel,
ResourceName, ResourceName,
ServiceIDMixin, ServiceIDMixin,
OrgIDMixin, OrgIDMixin,
UserIDMixin, UserIDMixin,
PermIDMixin, PermIDMixin,
GroupIDMixin, GroupIDMixin,
GroupSummary, GroupSummary,
OrgSummary, OrgSummary,
UserSummary, UserSummary,
) )
class UserSchema(CustomBaseModel): class UserSchema(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore") model_config = ConfigDict(from_attributes=True, extra="ignore")
id: int id: int
first_name: str first_name: str
last_name: str last_name: str
email: EmailStr email: EmailStr
class PermissionSchema(CustomBaseModel): class PermissionSchema(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore") model_config = ConfigDict(from_attributes=True, extra="ignore")
id: int id: int
service_name: str service_name: str
resource: str resource: str
action: str action: str
class GroupDetails(CustomBaseModel): class GroupDetails(CustomBaseModel):
details: GroupSummary details: GroupSummary
permissions: list[PermissionSchema] permissions: list[PermissionSchema]
class IAMCAoRRequest(CustomBaseModel): class IAMCAoRRequest(CustomBaseModel):
action: str action: str
rn: ResourceName rn: ResourceName
class IAMCAoRResponse(CustomBaseModel): class IAMCAoRResponse(CustomBaseModel):
allowed: bool allowed: bool
user: UserSummary user: UserSummary
action: str action: str
rn: ResourceName rn: ResourceName
class IAMGetGroupPermissionsResponse(CustomBaseModel): class IAMGetGroupPermissionsResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
group: GroupSummary group: GroupSummary
permissions: list[PermissionSchema] permissions: list[PermissionSchema]
class IAMGetGroupUsersResponse(CustomBaseModel): class IAMGetGroupUsersResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
group: GroupSummary group: GroupSummary
users: list[UserSummary] users: list[UserSummary]
class IAMPostGroupRequest(OrgIDMixin): class IAMPostGroupRequest(OrgIDMixin):
name: str = Field(min_length=3) name: str = Field(min_length=3)
class IAMPostGroupResponse(CustomBaseModel): class IAMPostGroupResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
group: GroupSummary group: GroupSummary
class IAMPutGroupPermissionRequest(GroupIDMixin, PermIDMixin, OrgIDMixin): class IAMPutGroupPermissionRequest(GroupIDMixin, PermIDMixin, OrgIDMixin):
pass pass
class IAMPutGroupPermissionResponse(CustomBaseModel): class IAMPutGroupPermissionResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
group: GroupSummary group: GroupSummary
permissions: list[PermissionSchema] permissions: list[PermissionSchema]
class IAMPutGroupUserRequest(GroupIDMixin, UserIDMixin, OrgIDMixin): class IAMPutGroupUserRequest(GroupIDMixin, UserIDMixin, OrgIDMixin):
pass pass
class IAMPutGroupUserResponse(CustomBaseModel): class IAMPutGroupUserResponse(CustomBaseModel):
group: GroupSummary group: GroupSummary
users: list[UserSchema] users: list[UserSchema]
class IAMDeleteGroupPermissionResponse(CustomBaseModel): class IAMDeleteGroupPermissionResponse(CustomBaseModel):
group: GroupSummary group: GroupSummary
permissions: list[PermissionSchema] permissions: list[PermissionSchema]
class IAMDeleteGroupUserResponse(CustomBaseModel): class IAMDeleteGroupUserResponse(CustomBaseModel):
group: GroupSummary group: GroupSummary
users: list[UserSchema] users: list[UserSchema]
class IAMGetPermissionsResponse(CustomBaseModel): class IAMGetPermissionsResponse(CustomBaseModel):
permissions: list[PermissionSchema] permissions: list[PermissionSchema]
class IAMPostPermissionRequest(ServiceIDMixin): class IAMPostPermissionRequest(ServiceIDMixin):
resource: str resource: str
action: str action: str
class IAMPostPermissionResponse(CustomBaseModel): class IAMPostPermissionResponse(CustomBaseModel):
permission: PermissionSchema permission: PermissionSchema
class IAMGetPermissionsSearchRequest(OrgIDMixin): class IAMGetPermissionsSearchRequest(OrgIDMixin):
service_id: Annotated[int | None, Field(gt=0)] = None service_id: Annotated[int | None, Field(gt=0)] = None
resource: Optional[str] = None resource: Optional[str] = None
action: Optional[str] = None action: Optional[str] = None
class IAMGetPermissionsSearchResponse(CustomBaseModel): class IAMGetPermissionsSearchResponse(CustomBaseModel):
permissions: list[PermissionSchema] permissions: list[PermissionSchema]
class IAMPutGroupInvitationRequest(OrgIDMixin, GroupIDMixin): class IAMPutGroupInvitationRequest(OrgIDMixin, GroupIDMixin):
user_email: EmailStr user_email: EmailStr
class IAMPutGroupInvitationResponse(CustomBaseModel): class IAMPutGroupInvitationResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
group: GroupSummary group: GroupSummary
invited_email: EmailStr invited_email: EmailStr
class IAMPutGroupInvitationAcceptRequest(CustomBaseModel): class IAMPutGroupInvitationAcceptRequest(CustomBaseModel):
jwt: str jwt: str
class IAMPutGroupInvitationAcceptResponse(CustomBaseModel): class IAMPutGroupInvitationAcceptResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
user: UserSummary user: UserSummary
group: GroupDetails group: GroupDetails
class IAMPutOrgPermissionsRequest(OrgIDMixin): class IAMPutOrgPermissionsRequest(OrgIDMixin):
permissions: list[int] permissions: list[int]
class IAMPutOrgPermissionsResponse(CustomBaseModel): class IAMPutOrgPermissionsResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
permissions: list[PermissionSchema] permissions: list[PermissionSchema]

View file

@ -23,90 +23,90 @@ from src.service.schemas import HasServiceName
def valid_service_key( def valid_service_key(
db: DbSession, request: Request, request_model: HasServiceName db: DbSession, request: Request, request_model: HasServiceName
) -> bool: ) -> bool:
rn = request_model.rn rn = request_model.rn
api_key = request.headers.get("X-API-Key", None) api_key = request.headers.get("X-API-Key", None)
if not api_key: if not api_key:
raise UnauthorizedException("Missing API key") raise UnauthorizedException("Missing API key")
service = rn.service service = rn.service
result = ( result = (
db.query(Service) db.query(Service)
.filter(Service.name == service) .filter(Service.name == service)
.filter(Service.api_key == api_key) .filter(Service.api_key == api_key)
.first() .first()
) )
if result is None: if result is None:
raise UnauthorizedException("Invalid API key") raise UnauthorizedException("Invalid API key")
return True return True
service_key_dependency = Annotated[bool, Depends(valid_service_key)] service_key_dependency = Annotated[bool, Depends(valid_service_key)]
async def send_user_group_invitation( async def send_user_group_invitation(
user_email: str, org_name: str, org_id: int, group_id: int, group_name: str user_email: str, org_name: str, org_id: int, group_id: int, group_name: str
): ):
expiry_delta = timedelta(hours=24) expiry_delta = timedelta(hours=24)
expiry = datetime.now(timezone.utc) + expiry_delta expiry = datetime.now(timezone.utc) + expiry_delta
claims = { claims = {
"email": user_email, "email": user_email,
"org_id": org_id, "org_id": org_id,
"group_id": group_id, "group_id": group_id,
"group_name": group_name, "group_name": group_name,
"exp": expiry, "exp": expiry,
"type": "group_invite", "type": "group_invite",
} }
token = await generate_jwt(claims) token = await generate_jwt(claims)
subject = f"You have been invited to join a group of {org_name}" subject = f"You have been invited to join a group of {org_name}"
body = f"You have been invited to join {group_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}" body = f"You have been invited to join {group_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
await send_email( await send_email(
recipient=user_email, recipient=user_email,
subject=subject, subject=subject,
body=body, body=body,
) )
async def create_group_and_assign_perms( async def create_group_and_assign_perms(
db: Session, org_model: Org, group_name: str, perm_list: list[int] db: Session, org_model: Org, group_name: str, perm_list: list[int]
): ):
new_group = Group(name=group_name, org_id=org_model.id) new_group = Group(name=group_name, org_id=org_model.id)
db.add(new_group) db.add(new_group)
db.flush() db.flush()
for permission in perm_list: for permission in perm_list:
perm_model = db.get(Perm, permission) perm_model = db.get(Perm, permission)
if perm_model is None: if perm_model is None:
continue continue
new_group.permission_rel.append(perm_model) new_group.permission_rel.append(perm_model)
db.flush() db.flush()
return new_group return new_group
async def assign_default_group( async def assign_default_group(
db: DbSession, db: DbSession,
org_model: Org, org_model: Org,
user_model: User, user_model: User,
group_name: str, group_name: str,
perm_list: list[int], perm_list: list[int],
): ):
group_model = ( group_model = (
db.query(Group) db.query(Group)
.filter(Group.org_id == org_model.id) .filter(Group.org_id == org_model.id)
.filter(Group.name == group_name) .filter(Group.name == group_name)
.first() .first()
) )
if group_model is None: if group_model is None:
group_model = await create_group_and_assign_perms( group_model = await create_group_and_assign_perms(
db=db, group_name=group_name, org_model=org_model, perm_list=perm_list db=db, group_name=group_name, org_model=org_model, perm_list=perm_list
) )
user_model.group_rel.append(group_model) user_model.group_rel.append(group_model)
db.flush() db.flush()

View file

@ -1,6 +1,7 @@
""" """
Application root file: Inits the FastAPI application Application root file: Inits the FastAPI application
""" """
import os.path import os.path
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import AsyncGenerator from typing import AsyncGenerator
@ -19,43 +20,43 @@ from src.auth.service import get_current_user, get_dev_user
@asynccontextmanager @asynccontextmanager
async def lifespan(_application: FastAPI) -> AsyncGenerator: async def lifespan(_application: FastAPI) -> AsyncGenerator:
# Startup # Startup
yield yield
# Shutdown # Shutdown
if settings.ENVIRONMENT.is_deployed: if settings.ENVIRONMENT.is_deployed:
# Just a precaution, should be False anyway # Just a precaution, should be False anyway
settings.DISABLE_AUTH = False settings.DISABLE_AUTH = False
tags_metadata = [ tags_metadata = [
{ {
"name": "User", "name": "User",
"description": "User related operations, includes getting information about the current user", "description": "User related operations, includes getting information about the current user",
}, },
{ {
"name": "Organisation", "name": "Organisation",
"description": "Organisation related operations, includes getting lists of users etc associated with orgs", "description": "Organisation related operations, includes getting lists of users etc associated with orgs",
}, },
{ {
"name": "Service", "name": "Service",
"description": "Services related operations, includes registering services and reissuing API keys", "description": "Services related operations, includes registering services and reissuing API keys",
}, },
{ {
"name": "IAM", "name": "IAM",
"description": "Operations related to the role based identity and access management system. This includes management of groups, permissions, and related users.", "description": "Operations related to the role based identity and access management system. This includes management of groups, permissions, and related users.",
}, },
] ]
app = FastAPI( app = FastAPI(
swagger_ui_init_oauth={ swagger_ui_init_oauth={
"clientId": auth_settings.CLIENT_ID, "clientId": auth_settings.CLIENT_ID,
"usePkceWithAuthorizationCodeGrant": True, "usePkceWithAuthorizationCodeGrant": True,
"scopes": "openid profile email", "scopes": "openid profile email",
}, },
openapi_tags=tags_metadata, openapi_tags=tags_metadata,
) )
# Type inspection disabled for middleware injection. # Type inspection disabled for middleware injection.
@ -64,19 +65,19 @@ app = FastAPI(
app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY.get_secret_value()) app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY.get_secret_value())
# noinspection PyTypeChecker # noinspection PyTypeChecker
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=settings.CORS_ORIGINS, allow_origins=settings.CORS_ORIGINS,
allow_origin_regex=settings.CORS_ORIGINS_REGEX, allow_origin_regex=settings.CORS_ORIGINS_REGEX,
allow_credentials=True, allow_credentials=True,
allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"), allow_methods=("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"),
allow_headers=settings.CORS_HEADERS, allow_headers=settings.CORS_HEADERS,
) )
if settings.DISABLE_AUTH and (settings.ENVIRONMENT == Environment.LOCAL): if settings.DISABLE_AUTH and (settings.ENVIRONMENT == Environment.LOCAL):
app.dependency_overrides[get_current_user] = get_dev_user app.dependency_overrides[get_current_user] = get_dev_user
app.include_router(api_router) app.include_router(api_router)
if os.path.exists("/app/static"): if os.path.exists("/app/static"):
app.frontend("/ui", directory="/app/static", fallback="index.html") app.frontend("/ui", directory="/app/static", fallback="index.html")

View file

@ -10,28 +10,28 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class CustomBase(DeclarativeBase): class CustomBase(DeclarativeBase):
type_annotation_map = { type_annotation_map = {
datetime: DateTime(timezone=True), datetime: DateTime(timezone=True),
dict[str, Any]: JSON, dict[str, Any]: JSON,
} }
class ActivatedMixin: class ActivatedMixin:
active: Mapped[bool] = mapped_column(default=True) active: Mapped[bool] = mapped_column(default=True)
class DeletedTimestampMixin: class DeletedTimestampMixin:
deleted_at: Mapped[datetime | None] = mapped_column(nullable=True) deleted_at: Mapped[datetime | None] = mapped_column(nullable=True)
class DescriptionMixin: class DescriptionMixin:
description: Mapped[str] description: Mapped[str]
class IdMixin: class IdMixin:
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
class TimestampMixin: class TimestampMixin:
created_at: Mapped[datetime] = mapped_column(default=func.now()) created_at: Mapped[datetime] = mapped_column(default=func.now())
updated_at: Mapped[datetime] = mapped_column(default=func.now(), onupdate=func.now()) updated_at: Mapped[datetime] = mapped_column(default=func.now(), onupdate=func.now())

View file

@ -10,48 +10,48 @@ from enum import StrEnum, auto
class Status(StrEnum): class Status(StrEnum):
""" """
Enumeration of organisation statuses. Enumeration of organisation statuses.
Attributes: Attributes:
PARTIAL(str): Organisation has been created but questionnaire hasn't been submitted. PARTIAL(str): Organisation has been created but questionnaire hasn't been submitted.
SUBMITTED (str): Questionnaire submitted but not approved. SUBMITTED (str): Questionnaire submitted but not approved.
REMEDIATION (str): Questionnaire submitted but requires revisions. REMEDIATION (str): Questionnaire submitted but requires revisions.
APPROVED (str): Questionnaire has been approved by an admin. APPROVED (str): Questionnaire has been approved by an admin.
REJECTED (str): Questionnaire has been rejected by an admin. REJECTED (str): Questionnaire has been rejected by an admin.
REMOVED (str): Organisation has been removed. REMOVED (str): Organisation has been removed.
""" """
PARTIAL = auto() PARTIAL = auto()
SUBMITTED = auto() SUBMITTED = auto()
REMEDIATION = auto() REMEDIATION = auto()
APPROVED = auto() APPROVED = auto()
REJECTED = auto() REJECTED = auto()
REMOVED = auto() REMOVED = auto()
@property @property
def is_pre_approval(self): def is_pre_approval(self):
return self in (self.PARTIAL, self.SUBMITTED, self.REMEDIATION) return self in (self.PARTIAL, self.SUBMITTED, self.REMEDIATION)
@property @property
def is_pre_submission(self): def is_pre_submission(self):
return self in (self.PARTIAL, self.REMEDIATION) return self in (self.PARTIAL, self.REMEDIATION)
@property @property
def is_blocked(self): def is_blocked(self):
return self in (self.REMOVED, self.REJECTED) return self in (self.REMOVED, self.REJECTED)
class ContactType(StrEnum): class ContactType(StrEnum):
""" """
Enumeration of organisation contact types. Enumeration of organisation contact types.
Attributes: Attributes:
BILLING(str): Billing contact. BILLING(str): Billing contact.
SECURITY (str): Security contact. SECURITY (str): Security contact.
OWNER (str): Owner contact. OWNER (str): Owner contact.
""" """
BILLING = auto() BILLING = auto()
SECURITY = auto() SECURITY = auto()
OWNER = auto() OWNER = auto()

View file

@ -18,25 +18,25 @@ from src.organisation.exceptions import OrgNotFoundException
def get_org_model_query(db: DbSession, org_id: Annotated[int, Query(gt=0)]) -> Org: def get_org_model_query(db: DbSession, org_id: Annotated[int, Query(gt=0)]) -> Org:
org_model = db.get(Org, org_id) org_model = db.get(Org, org_id)
if org_model is None: if org_model is None:
raise OrgNotFoundException(org_id) raise OrgNotFoundException(org_id)
return org_model return org_model
org_model_query_dependency = Annotated[Org, Depends(get_org_model_query)] org_model_query_dependency = Annotated[Org, Depends(get_org_model_query)]
def get_org_model_body(db: DbSession, request_model: OrgIDMixin) -> Org: def get_org_model_body(db: DbSession, request_model: OrgIDMixin) -> Org:
org_id: Optional[int] = getattr(request_model, "organisation_id", None) org_id: Optional[int] = getattr(request_model, "organisation_id", None)
if org_id is None: if org_id is None:
raise OrgNotFoundException() raise OrgNotFoundException()
org_model = db.get(Org, org_id) org_model = db.get(Org, org_id)
if org_model is None: if org_model is None:
raise OrgNotFoundException(org_id) raise OrgNotFoundException(org_id)
return org_model return org_model
org_model_body_dependency = Annotated[Org, Depends(get_org_model_body)] org_model_body_dependency = Annotated[Org, Depends(get_org_model_body)]

View file

@ -12,26 +12,26 @@ from fastapi import HTTPException, status
class OrgNotFoundException(HTTPException): class OrgNotFoundException(HTTPException):
def __init__(self, org_id: Optional[int] = None) -> None: def __init__(self, org_id: Optional[int] = None) -> None:
detail = ( detail = (
"Organisation not found" "Organisation not found"
if org_id is None if org_id is None
else f"Organisation with ID '{org_id}' was not found." else f"Organisation with ID '{org_id}' was not found."
) )
super().__init__( super().__init__(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=detail, detail=detail,
) )
class AwaitingApprovalException(HTTPException): class AwaitingApprovalException(HTTPException):
def __init__(self, org_id: Optional[int] = None) -> None: def __init__(self, org_id: Optional[int] = None) -> None:
detail = ( detail = (
"Organisation has not been approved." "Organisation has not been approved."
if org_id is None if org_id is None
else f"Organisation with ID '{org_id}' has not been approved." else f"Organisation with ID '{org_id}' has not been approved."
) )
super().__init__( super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail, detail=detail,
) )

View file

@ -25,51 +25,51 @@ from src.models import CustomBase, TimestampMixin
class Organisation(CustomBase, IdMixin, TimestampMixin, DeletedTimestampMixin): class Organisation(CustomBase, IdMixin, TimestampMixin, DeletedTimestampMixin):
__tablename__ = "organisation" __tablename__ = "organisation"
name: Mapped[str] name: Mapped[str]
status: Mapped[str] = mapped_column(default="partial") status: Mapped[str] = mapped_column(default="partial")
intake_questionnaire: Mapped[dict[str, Any] | None] intake_questionnaire: Mapped[dict[str, Any] | None]
root_user_id: Mapped[int] = mapped_column(ForeignKey("user.id")) root_user_id: Mapped[int] = mapped_column(ForeignKey("user.id"))
billing_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True) billing_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True)
security_contact_id: Mapped[int] = mapped_column( security_contact_id: Mapped[int] = mapped_column(
ForeignKey("contact.id"), nullable=True ForeignKey("contact.id"), nullable=True
) )
owner_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True) owner_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True)
user_rel = relationship("User", secondary="orgusers", back_populates="organisation_rel") user_rel = relationship("User", secondary="orgusers", back_populates="organisation_rel")
group_rel = relationship( group_rel = relationship(
"Group", back_populates="org_rel", cascade="all, delete-orphan" "Group", back_populates="org_rel", cascade="all, delete-orphan"
) )
root_user_rel = relationship("User", foreign_keys="Organisation.root_user_id") root_user_rel = relationship("User", foreign_keys="Organisation.root_user_id")
billing_contact_rel = relationship( billing_contact_rel = relationship(
"Contact", foreign_keys="Organisation.billing_contact_id" "Contact", foreign_keys="Organisation.billing_contact_id"
) )
security_contact_rel = relationship( security_contact_rel = relationship(
"Contact", foreign_keys="Organisation.security_contact_id" "Contact", foreign_keys="Organisation.security_contact_id"
) )
owner_contact_rel = relationship( owner_contact_rel = relationship(
"Contact", foreign_keys="Organisation.owner_contact_id" "Contact", foreign_keys="Organisation.owner_contact_id"
) )
permission_rel = relationship( permission_rel = relationship(
"Permission", secondary="org_permissions", back_populates="org_rel" "Permission", secondary="org_permissions", back_populates="org_rel"
) )
@property @property
def root_user_email(self) -> str: def root_user_email(self) -> str:
return self.root_user_rel.email if self.root_user_rel else "" return self.root_user_rel.email if self.root_user_rel else ""
class OrgUsers(CustomBase): class OrgUsers(CustomBase):
__tablename__ = "orgusers" __tablename__ = "orgusers"
org_id: Mapped[int] = mapped_column( org_id: Mapped[int] = mapped_column(
ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True
) )
user_id: Mapped[int] = mapped_column( user_id: Mapped[int] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
) )

File diff suppressed because it is too large Load diff

View file

@ -12,139 +12,139 @@ from datetime import datetime
from pydantic import EmailStr, ConfigDict, Field from pydantic import EmailStr, ConfigDict, Field
from src.schemas import ( from src.schemas import (
CustomBaseModel, CustomBaseModel,
OrgIDMixin, OrgIDMixin,
UserIDMixin, UserIDMixin,
GroupSummary, GroupSummary,
OrgSummary, OrgSummary,
UserSummary, UserSummary,
) )
from src.contact.schemas import ContactModel from src.contact.schemas import ContactModel
from src.organisation.constants import Status, ContactType from src.organisation.constants import Status, ContactType
from src.organisation.schemas_questionnaires import ( from src.organisation.schemas_questionnaires import (
QuestionnaireQuestionsVersion0 as CurrentQuestions, QuestionnaireQuestionsVersion0 as CurrentQuestions,
questionnaire_union, questionnaire_union,
) )
class QuestionnaireMetadata(CustomBaseModel): class QuestionnaireMetadata(CustomBaseModel):
version: int version: int
submission_date: Optional[datetime] = None submission_date: Optional[datetime] = None
class Questionnaire(CustomBaseModel): class Questionnaire(CustomBaseModel):
metadata: QuestionnaireMetadata metadata: QuestionnaireMetadata
questions: questionnaire_union questions: questionnaire_union
class ContactSummary(CustomBaseModel): class ContactSummary(CustomBaseModel):
id: int id: int
email: Optional[EmailStr] = None email: Optional[EmailStr] = None
class OrgSchema(OrgIDMixin): class OrgSchema(OrgIDMixin):
name: str name: str
status: Status status: Status
root_user_email: EmailStr root_user_email: EmailStr
intake_questionnaire: Optional[Questionnaire] = None intake_questionnaire: Optional[Questionnaire] = None
billing_contact: ContactSummary billing_contact: ContactSummary
owner_contact: ContactSummary owner_contact: ContactSummary
security_contact: ContactSummary security_contact: ContactSummary
class OrgPostOrgRequest(CustomBaseModel): class OrgPostOrgRequest(CustomBaseModel):
name: str = Field(min_length=3) name: str = Field(min_length=3)
intake_questionnaire: Optional[CurrentQuestions] = None intake_questionnaire: Optional[CurrentQuestions] = None
class OrgPostOrgResponse(CustomBaseModel): class OrgPostOrgResponse(CustomBaseModel):
id: int id: int
name: str name: str
status: Status status: Status
class OrgPatchQuestionnaireRequest(OrgIDMixin): class OrgPatchQuestionnaireRequest(OrgIDMixin):
intake_questionnaire: CurrentQuestions intake_questionnaire: CurrentQuestions
partial: bool partial: bool
class OrgPatchQuestionnaireResponse(CustomBaseModel): class OrgPatchQuestionnaireResponse(CustomBaseModel):
id: int id: int
name: str name: str
intake_questionnaire: Questionnaire intake_questionnaire: Questionnaire
status: Status status: Status
class OrgPatchStatusRequest(OrgIDMixin): class OrgPatchStatusRequest(OrgIDMixin):
status: Status status: Status
class OrgPatchStatusResponse(CustomBaseModel): class OrgPatchStatusResponse(CustomBaseModel):
id: int id: int
name: str name: str
status: Status status: Status
class OrgPatchContactRequest(OrgIDMixin): class OrgPatchContactRequest(OrgIDMixin):
contact_type: ContactType contact_type: ContactType
email: Optional[EmailStr] = None email: Optional[EmailStr] = None
first_name: Optional[str] = None first_name: Optional[str] = None
last_name: Optional[str] = None last_name: Optional[str] = None
phonenumber: Optional[str] = None phonenumber: Optional[str] = None
vat_number: Optional[str] = None vat_number: Optional[str] = None
post_office_box_number: Optional[str] = None post_office_box_number: Optional[str] = None
street_address: Optional[str] = None street_address: Optional[str] = None
street_address_line_2: Optional[str] = None street_address_line_2: Optional[str] = None
locality: Optional[str] = None locality: Optional[str] = None
address_region: Optional[str] = None address_region: Optional[str] = None
country_code: Optional[str] = None country_code: Optional[str] = None
postal_code: Optional[str] = None postal_code: Optional[str] = None
class OrgPostUserRequest(OrgIDMixin, UserIDMixin): class OrgPostUserRequest(OrgIDMixin, UserIDMixin):
pass pass
class OrgPostUserResponse(CustomBaseModel): class OrgPostUserResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
users: list[UserSummary] users: list[UserSummary]
class OrgPatchRootRequest(OrgIDMixin, UserIDMixin): class OrgPatchRootRequest(OrgIDMixin, UserIDMixin):
pass pass
class OrgPatchRootResponse(CustomBaseModel): class OrgPatchRootResponse(CustomBaseModel):
name: str name: str
root_user_email: str root_user_email: str
class OrgGetUserResponse(CustomBaseModel): class OrgGetUserResponse(CustomBaseModel):
users: list[dict[str, str | int]] users: list[dict[str, str | int]]
organisation: OrgSummary organisation: OrgSummary
class OrgGetGroupResponse(CustomBaseModel): class OrgGetGroupResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
groups: list[GroupSummary] groups: list[GroupSummary]
class OrgGetContactResponse(CustomBaseModel): class OrgGetContactResponse(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore") model_config = ConfigDict(from_attributes=True, extra="ignore")
contact: ContactModel contact: ContactModel
organisation: OrgSummary organisation: OrgSummary
class OrgPatchContactResponse(CustomBaseModel): class OrgPatchContactResponse(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore") model_config = ConfigDict(from_attributes=True, extra="ignore")
contact: ContactModel contact: ContactModel
organisation: OrgSummary organisation: OrgSummary
class OrgGetOrgResponse(CustomBaseModel): class OrgGetOrgResponse(CustomBaseModel):
organisations: list[OrgSchema] organisations: list[OrgSchema]

View file

@ -4,13 +4,13 @@ from src.schemas import CustomBaseModel
class QuestionnaireQuestions(CustomBaseModel): class QuestionnaireQuestions(CustomBaseModel):
pass pass
class QuestionnaireQuestionsVersion0(QuestionnaireQuestions): class QuestionnaireQuestionsVersion0(QuestionnaireQuestions):
question_one: Optional[str] = None question_one: Optional[str] = None
question_two: Optional[str] = None question_two: Optional[str] = None
question_three: Optional[str] = None question_three: Optional[str] = None
questionnaire_union = QuestionnaireQuestionsVersion0 # | QuestionnaireQuestionsVersion1 questionnaire_union = QuestionnaireQuestionsVersion0 # | QuestionnaireQuestionsVersion1

View file

@ -11,57 +11,57 @@ from src.user.models import User
async def add_default_org_permissions( async def add_default_org_permissions(
db: Session, db: Session,
org_model: Org, org_model: Org,
perm_list: list[int], perm_list: list[int],
): ):
for permission in perm_list: for permission in perm_list:
perm_model = db.get(Perm, permission) perm_model = db.get(Perm, permission)
if perm_model is None: if perm_model is None:
continue continue
if perm_model in org_model.permission_rel: if perm_model in org_model.permission_rel:
continue continue
org_model.permission_rel.append(perm_model) org_model.permission_rel.append(perm_model)
db.flush() db.flush()
db.commit() db.commit()
async def assign_defaults( async def assign_defaults(
db: Session, db: Session,
org_id: int, org_id: int,
user_id: int, user_id: int,
): ):
default_org_permissions = [] default_org_permissions = []
default_user_permissions = [] default_user_permissions = []
org_model = db.get(Org, org_id) org_model = db.get(Org, org_id)
if org_model is None: if org_model is None:
print("Org not found while adding defaults") print("Org not found while adding defaults")
return return
user_model = db.get(User, user_id) user_model = db.get(User, user_id)
if user_model is None: if user_model is None:
print("User not found while adding defaults") print("User not found while adding defaults")
return return
await add_default_org_permissions(db, org_model, default_org_permissions) await add_default_org_permissions(db, org_model, default_org_permissions)
await assign_default_group( await assign_default_group(
db=db, db=db,
org_model=org_model, org_model=org_model,
user_model=user_model, user_model=user_model,
group_name="Default Users", group_name="Default Users",
perm_list=default_user_permissions, perm_list=default_user_permissions,
) )
await assign_default_group( await assign_default_group(
db=db, db=db,
org_model=org_model, org_model=org_model,
user_model=user_model, user_model=user_model,
group_name="Root User", group_name="Root User",
perm_list=default_org_permissions, perm_list=default_org_permissions,
) )
db.commit() db.commit()

View file

@ -11,54 +11,54 @@ from typing import Optional
class CustomBaseModel(BaseModel): class CustomBaseModel(BaseModel):
pass pass
### Mixins ### ### Mixins ###
class OrgIDMixin(CustomBaseModel): class OrgIDMixin(CustomBaseModel):
organisation_id: int = Field(gt=0) organisation_id: int = Field(gt=0)
class GroupIDMixin(CustomBaseModel): class GroupIDMixin(CustomBaseModel):
group_id: int = Field(gt=0) group_id: int = Field(gt=0)
class PermIDMixin(CustomBaseModel): class PermIDMixin(CustomBaseModel):
permission_id: int = Field(gt=0) permission_id: int = Field(gt=0)
class ServiceIDMixin(CustomBaseModel): class ServiceIDMixin(CustomBaseModel):
service_id: int = Field(gt=0) service_id: int = Field(gt=0)
class UserIDMixin(CustomBaseModel): class UserIDMixin(CustomBaseModel):
user_id: int = Field(gt=0) user_id: int = Field(gt=0)
class ServiceNameMixin(CustomBaseModel): class ServiceNameMixin(CustomBaseModel):
service: str service: str
class OrgSummary(CustomBaseModel): class OrgSummary(CustomBaseModel):
id: int id: int
name: str name: str
class GroupSummary(CustomBaseModel): class GroupSummary(CustomBaseModel):
id: int id: int
name: str name: str
class UserSummary(CustomBaseModel): class UserSummary(CustomBaseModel):
id: int id: int
email: str email: str
class ServiceSummary(CustomBaseModel): class ServiceSummary(CustomBaseModel):
id: int id: int
name: str name: str
class ResourceName(ServiceNameMixin, OrgIDMixin): class ResourceName(ServiceNameMixin, OrgIDMixin):
resource: str resource: str
instance: Optional[str] = None instance: Optional[str] = None

View file

@ -16,25 +16,23 @@ from src.service.models import Service
from src.service.schemas import ServiceIDMixin from src.service.schemas import ServiceIDMixin
async def get_service_model_query( async def get_service_model_query(db: DbSession, service_id: Annotated[int, Query(gt=0)]):
db: DbSession, service_id: Annotated[int, Query(gt=0)] service_model = db.get(Service, service_id)
): if service_model is None:
service_model = db.get(Service, service_id) raise ServiceNotFoundException(service_id=service_id)
if service_model is None:
raise ServiceNotFoundException(service_id=service_id)
return service_model return service_model
service_model_query_dependency = Annotated[Service, Depends(get_service_model_query)] service_model_query_dependency = Annotated[Service, Depends(get_service_model_query)]
async def get_service_model_body(db: DbSession, request_model: ServiceIDMixin): async def get_service_model_body(db: DbSession, request_model: ServiceIDMixin):
service_model = db.get(Service, request_model.service_id) service_model = db.get(Service, request_model.service_id)
if service_model is None: if service_model is None:
raise ServiceNotFoundException(service_id=request_model.service_id) raise ServiceNotFoundException(service_id=request_model.service_id)
return service_model return service_model
service_model_body_dependency = Annotated[Service, Depends(get_service_model_body)] service_model_body_dependency = Annotated[Service, Depends(get_service_model_body)]

View file

@ -11,13 +11,13 @@ from fastapi import HTTPException, status
class ServiceNotFoundException(HTTPException): class ServiceNotFoundException(HTTPException):
def __init__(self, service_id: Optional[int] = None) -> None: def __init__(self, service_id: Optional[int] = None) -> None:
detail = ( detail = (
"Service not found" "Service not found"
if service_id is None if service_id is None
else f"Service with ID '{service_id}' was not found." else f"Service with ID '{service_id}' was not found."
) )
super().__init__( super().__init__(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=detail, detail=detail,
) )

View file

@ -12,11 +12,11 @@ from src.models import CustomBase, IdMixin
class Service(CustomBase, IdMixin): class Service(CustomBase, IdMixin):
__tablename__ = "service" __tablename__ = "service"
name: Mapped[str] = mapped_column(unique=True) name: Mapped[str] = mapped_column(unique=True)
api_key: Mapped[str] api_key: Mapped[str]
permission_rel = relationship( permission_rel = relationship(
"Permission", back_populates="service_rel", cascade="all, delete-orphan" "Permission", back_populates="service_rel", cascade="all, delete-orphan"
) )

View file

@ -15,8 +15,8 @@ from psycopg.errors import UniqueViolation
from src.exceptions import ConflictException from src.exceptions import ConflictException
from src.database import DbSession from src.database import DbSession
from src.auth.dependencies import ( from src.auth.dependencies import (
super_admin_dependency, super_admin_dependency,
org_model_root_claim_query_dependency, org_model_root_claim_query_dependency,
) )
from src.iam.service import service_key_dependency from src.iam.service import service_key_dependency
from src.iam.models import Permission as Perm from src.iam.models import Permission as Perm
@ -25,212 +25,210 @@ from src.service.exceptions import ServiceNotFoundException
from src.service.models import Service from src.service.models import Service
from src.service.utils import generate_api_key from src.service.utils import generate_api_key
from src.service.dependencies import ( from src.service.dependencies import (
service_model_body_dependency, service_model_body_dependency,
service_model_query_dependency, service_model_query_dependency,
) )
from src.service.schemas import ( from src.service.schemas import (
ServiceGetServiceResponse, ServiceGetServiceResponse,
ServicePostServiceRequest, ServicePostServiceRequest,
ServicePostServiceResponse, ServicePostServiceResponse,
ServiceWithKeySchema, ServiceWithKeySchema,
ServicePatchKeyResponse, ServicePatchKeyResponse,
ServicePatchKeyRequest, ServicePatchKeyRequest,
ServicePostPermissionsResponse, ServicePostPermissionsResponse,
ServicePostPermissionsRequest, ServicePostPermissionsRequest,
) )
router = APIRouter( router = APIRouter(
tags=["Service"], tags=["Service"],
prefix="/service", prefix="/service",
) )
@router.get( @router.get(
"", "",
summary="Get all services", summary="Get all services",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=ServiceGetServiceResponse, response_model=ServiceGetServiceResponse,
responses={ responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"}, status.HTTP_200_OK: {"description": "Successful retrieval from database"},
status.HTTP_401_UNAUTHORIZED: { status.HTTP_401_UNAUTHORIZED: {
"description": "Unauthorized", "description": "Unauthorized",
"content": { "content": {
"application/json": { "application/json": {
"examples": { "examples": {
"awaiting_approval": { "awaiting_approval": {
"summary": "Organisation has not yet been approved." "summary": "Organisation has not yet been approved."
}, },
} }
} }
}, },
}, },
status.HTTP_403_FORBIDDEN: { status.HTTP_403_FORBIDDEN: {
"description": "Forbidden", "description": "Forbidden",
"content": { "content": {
"application/json": { "application/json": {
"examples": { "examples": {
"not_root": {"summary": "Not authorised. Must be root user."}, "not_root": {"summary": "Not authorised. Must be root user."},
} }
} }
}, },
}, },
}, },
) )
async def get_all_services( async def get_all_services(db: DbSession, org_model: org_model_root_claim_query_dependency):
db: DbSession, org_model: org_model_root_claim_query_dependency """
): Returns the ID and name of all services registered to the hub.
""" """
Returns the ID and name of all services registered to the hub. permission_models = db.query(Service).all()
"""
permission_models = db.query(Service).all()
return {"services": permission_models} return {"services": permission_models}
@router.post( @router.post(
"", "",
summary="Register a new service.", summary="Register a new service.",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=ServicePostServiceResponse, response_model=ServicePostServiceResponse,
responses={ responses={
status.HTTP_200_OK: {"description": "Successfully registered a new service"}, status.HTTP_200_OK: {"description": "Successfully registered a new service"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
status.HTTP_409_CONFLICT: {"description": "Service with this name already exists"}, status.HTTP_409_CONFLICT: {"description": "Service with this name already exists"},
}, },
) )
async def register_service( async def register_service(
db: DbSession, db: DbSession,
su: super_admin_dependency, su: super_admin_dependency,
request_model: ServicePostServiceRequest, request_model: ServicePostServiceRequest,
): ):
""" """
Registers a new service to the hub, generating and returning an API key for it. Registers a new service to the hub, generating and returning an API key for it.
""" """
key = generate_api_key() key = generate_api_key()
service_model = Service(name=request_model.name, api_key=key) service_model = Service(name=request_model.name, api_key=key)
db.add(service_model) db.add(service_model)
try: try:
db.flush() db.flush()
except IntegrityError as e: except IntegrityError as e:
if ( if (
isinstance(e.orig, UniqueViolation) # Postgres unique violation isinstance(e.orig, UniqueViolation) # Postgres unique violation
or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation
): ):
raise ConflictException(message="Service with this name already exists") raise ConflictException(message="Service with this name already exists")
raise raise
response = ServiceWithKeySchema(**service_model.__dict__) response = ServiceWithKeySchema(**service_model.__dict__)
db.commit() db.commit()
return {"service": response} return {"service": response}
@router.patch( @router.patch(
"/key", "/key",
summary="Regenerate service API key.", summary="Regenerate service API key.",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=ServicePatchKeyResponse, response_model=ServicePatchKeyResponse,
responses={ responses={
status.HTTP_200_OK: {"description": "Successful update of API key"}, status.HTTP_200_OK: {"description": "Successful update of API key"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
}, },
) )
async def regenerate_api_key( async def regenerate_api_key(
db: DbSession, db: DbSession,
su: super_admin_dependency, su: super_admin_dependency,
service_model: service_model_body_dependency, service_model: service_model_body_dependency,
request_model: ServicePatchKeyRequest, request_model: ServicePatchKeyRequest,
): ):
""" """
Generates and returns a new API key for the service to access the hub. Generates and returns a new API key for the service to access the hub.
""" """
key = generate_api_key() key = generate_api_key()
service_model.api_key = key service_model.api_key = key
db.flush() db.flush()
response = ServiceWithKeySchema(**service_model.__dict__) response = ServiceWithKeySchema(**service_model.__dict__)
db.commit() db.commit()
return {"service": response} return {"service": response}
@router.delete( @router.delete(
"", "",
summary="Remove a service.", summary="Remove a service.",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
responses={ responses={
status.HTTP_204_NO_CONTENT: {"description": "Successfully removed service from db"}, status.HTTP_204_NO_CONTENT: {"description": "Successfully removed service from db"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
}, },
) )
async def remove_service( async def remove_service(
db: DbSession, db: DbSession,
service_model: service_model_query_dependency, service_model: service_model_query_dependency,
su: super_admin_dependency, su: super_admin_dependency,
): ):
""" """
Removes a service from the hub. Removes a service from the hub.
""" """
db.delete(service_model) db.delete(service_model)
db.commit() db.commit()
@router.post( @router.post(
path="/permissions", path="/permissions",
summary="Service endpoint for creating its own permissions.", summary="Service endpoint for creating its own permissions.",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=ServicePostPermissionsResponse, response_model=ServicePostPermissionsResponse,
responses={ responses={
status.HTTP_401_UNAUTHORIZED: { status.HTTP_401_UNAUTHORIZED: {
"description": "API Key missing or invalid | Issue verifying user OIDC claims" "description": "API Key missing or invalid | Issue verifying user OIDC claims"
}, },
}, },
) )
async def service_create_new_permissions( async def service_create_new_permissions(
db: DbSession, db: DbSession,
request_model: ServicePostPermissionsRequest, request_model: ServicePostPermissionsRequest,
valid_key: service_key_dependency, valid_key: service_key_dependency,
): ):
""" """
Allows a service to register its own set of permissions. Allows a service to register its own set of permissions.
""" """
service_model = ( service_model = (
db.query(Service).filter(Service.name == request_model.rn.service).first() db.query(Service).filter(Service.name == request_model.rn.service).first()
) )
if service_model is None: if service_model is None:
raise ServiceNotFoundException() raise ServiceNotFoundException()
else: else:
service_id = service_model.id service_id = service_model.id
response_list = [] response_list = []
for new_permission in request_model.permissions: for new_permission in request_model.permissions:
perm_model = ( perm_model = (
db.query(Perm) db.query(Perm)
.filter(Perm.service_id == service_id) .filter(Perm.service_id == service_id)
.filter(Perm.resource == new_permission.resource) .filter(Perm.resource == new_permission.resource)
.filter(Perm.action == new_permission.action) .filter(Perm.action == new_permission.action)
.first() .first()
) )
if perm_model is not None: if perm_model is not None:
response_code = 409 response_code = 409
response = { response = {
"id": perm_model.id, "id": perm_model.id,
"service_name": perm_model.service_name, "service_name": perm_model.service_name,
"resource": perm_model.resource, "resource": perm_model.resource,
"action": perm_model.action, "action": perm_model.action,
} }
response_list.append((response, response_code)) response_list.append((response, response_code))
continue continue
new_perm_model = Perm(**new_permission.__dict__) new_perm_model = Perm(**new_permission.__dict__)
new_perm_model.service_id = service_id new_perm_model.service_id = service_id
db.add(new_perm_model) db.add(new_perm_model)
db.flush() db.flush()
response_code = 201 response_code = 201
response = { response = {
"id": new_perm_model.id, "id": new_perm_model.id,
"service_name": new_perm_model.service_name, "service_name": new_perm_model.service_name,
"resource": new_perm_model.resource, "resource": new_perm_model.resource,
"action": new_perm_model.action, "action": new_perm_model.action,
} }
response_list.append((response, response_code)) response_list.append((response, response_code))
db.commit() db.commit()
return {"permissions": response_list} return {"permissions": response_list}

View file

@ -10,10 +10,10 @@ from typing import Generic, TypeVar
from pydantic import Field, ConfigDict from pydantic import Field, ConfigDict
from src.schemas import ( from src.schemas import (
CustomBaseModel, CustomBaseModel,
ServiceIDMixin, ServiceIDMixin,
ServiceSummary, ServiceSummary,
ServiceNameMixin, ServiceNameMixin,
) )
@ -21,51 +21,51 @@ T = TypeVar("T", bound=ServiceNameMixin)
class HasServiceName(CustomBaseModel, Generic[T]): class HasServiceName(CustomBaseModel, Generic[T]):
rn: T rn: T
class PermissionResponseSchema(CustomBaseModel): class PermissionResponseSchema(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore") model_config = ConfigDict(from_attributes=True, extra="ignore")
id: int id: int
service_name: str service_name: str
resource: str resource: str
action: str action: str
class PermissionRequestSchema(CustomBaseModel): class PermissionRequestSchema(CustomBaseModel):
resource: str resource: str
action: str action: str
class ServiceWithKeySchema(ServiceSummary): class ServiceWithKeySchema(ServiceSummary):
api_key: str api_key: str
class ServiceGetServiceResponse(CustomBaseModel): class ServiceGetServiceResponse(CustomBaseModel):
services: list[ServiceSummary] services: list[ServiceSummary]
class ServicePostServiceRequest(CustomBaseModel): class ServicePostServiceRequest(CustomBaseModel):
name: str = Field(min_length=3) name: str = Field(min_length=3)
class ServicePostServiceResponse(CustomBaseModel): class ServicePostServiceResponse(CustomBaseModel):
service: ServiceWithKeySchema service: ServiceWithKeySchema
class ServicePatchKeyRequest(ServiceIDMixin): class ServicePatchKeyRequest(ServiceIDMixin):
pass pass
class ServicePatchKeyResponse(CustomBaseModel): class ServicePatchKeyResponse(CustomBaseModel):
service: ServiceWithKeySchema service: ServiceWithKeySchema
class ServicePostPermissionsRequest(CustomBaseModel): class ServicePostPermissionsRequest(CustomBaseModel):
rn: ServiceNameMixin rn: ServiceNameMixin
permissions: list[PermissionRequestSchema] permissions: list[PermissionRequestSchema]
class ServicePostPermissionsResponse(CustomBaseModel): class ServicePostPermissionsResponse(CustomBaseModel):
permissions: list[tuple[PermissionResponseSchema, int]] permissions: list[tuple[PermissionResponseSchema, int]]

View file

@ -9,4 +9,4 @@ import uuid
def generate_api_key() -> str: def generate_api_key() -> str:
return str(uuid.uuid4()) return str(uuid.uuid4())

View file

@ -19,37 +19,37 @@ from src.schemas import UserIDMixin
async def get_user_model_claims(claims: claims_dependency, db: DbSession): async def get_user_model_claims(claims: claims_dependency, db: DbSession):
user_id = claims.get("db_id", None) user_id = claims.get("db_id", None)
if user_id is None: if user_id is None:
raise UserNotFoundException() raise UserNotFoundException()
user_model = db.get(User, user_id) user_model = db.get(User, user_id)
if user_model is None: if user_model is None:
raise UserNotFoundException(user_id=user_id) raise UserNotFoundException(user_id=user_id)
return user_model return user_model
user_model_claims_dependency = Annotated[User, Depends(get_user_model_claims)] user_model_claims_dependency = Annotated[User, Depends(get_user_model_claims)]
async def get_user_model_query(db: DbSession, user_id: Annotated[int, Query(gt=0)]): async def get_user_model_query(db: DbSession, user_id: Annotated[int, Query(gt=0)]):
user_model = db.get(User, user_id) user_model = db.get(User, user_id)
if user_model is None: if user_model is None:
raise UserNotFoundException(user_id=user_id) raise UserNotFoundException(user_id=user_id)
return user_model return user_model
user_model_query_dependency = Annotated[User, Depends(get_user_model_query)] user_model_query_dependency = Annotated[User, Depends(get_user_model_query)]
async def get_user_model_body(db: DbSession, request_model: UserIDMixin): async def get_user_model_body(db: DbSession, request_model: UserIDMixin):
user_model = db.get(User, request_model.user_id) user_model = db.get(User, request_model.user_id)
if user_model is None: if user_model is None:
raise UserNotFoundException(user_id=request_model.user_id) raise UserNotFoundException(user_id=request_model.user_id)
return user_model return user_model
user_model_body_dependency = Annotated[User, Depends(get_user_model_body)] user_model_body_dependency = Annotated[User, Depends(get_user_model_body)]

View file

@ -11,13 +11,13 @@ from fastapi import HTTPException, status
class UserNotFoundException(HTTPException): class UserNotFoundException(HTTPException):
def __init__(self, user_id: Optional[int] = None) -> None: def __init__(self, user_id: Optional[int] = None) -> None:
detail = ( detail = (
"User not found" "User not found"
if user_id is None if user_id is None
else f"User with ID '{user_id}' was not found." else f"User with ID '{user_id}' was not found."
) )
super().__init__( super().__init__(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=detail, detail=detail,
) )

View file

@ -20,26 +20,26 @@ from src.models import CustomBase
class User(CustomBase, IdMixin, ActivatedMixin, TimestampMixin, DeletedTimestampMixin): class User(CustomBase, IdMixin, ActivatedMixin, TimestampMixin, DeletedTimestampMixin):
__tablename__ = "user" __tablename__ = "user"
email: Mapped[str] email: Mapped[str]
first_name: Mapped[str] first_name: Mapped[str]
last_name: Mapped[str] last_name: Mapped[str]
oidc_id: Mapped[str] = mapped_column(index=True, unique=True) oidc_id: Mapped[str] = mapped_column(index=True, unique=True)
organisation_rel = relationship( organisation_rel = relationship(
"Organisation", secondary="orgusers", back_populates="user_rel" "Organisation", secondary="orgusers", back_populates="user_rel"
) )
group_rel = relationship("Group", secondary="user_groups", back_populates="user_rel") group_rel = relationship("Group", secondary="user_groups", back_populates="user_rel")
@property @property
def organisations(self): def organisations(self):
return [{"name": org.name, "id": org.id} for org in self.organisation_rel] return [{"name": org.name, "id": org.id} for org in self.organisation_rel]
@property @property
def groups(self): def groups(self):
result = defaultdict(list) result = defaultdict(list)
for group in self.group_rel: for group in self.group_rel:
result[group.org_rel.name].append({"name": group.name, "id": group.id}) result[group.org_rel.name].append({"name": group.name, "id": group.id})
return dict(result) return dict(result)

View file

@ -13,205 +13,205 @@ from fastapi import APIRouter, status, BackgroundTasks
from src.iam.models import Group from src.iam.models import Group
from src.organisation.exceptions import OrgNotFoundException from src.organisation.exceptions import OrgNotFoundException
from src.user.schemas import ( from src.user.schemas import (
UserResponse, UserResponse,
OIDCClaims, OIDCClaims,
UserPostInvitationRequest, UserPostInvitationRequest,
UserPostInvitationAcceptRequest, UserPostInvitationAcceptRequest,
UserGetSelfOrgsResponse, UserGetSelfOrgsResponse,
UserPostInvitationResponse, UserPostInvitationResponse,
UserPostInvitationAcceptResponse, UserPostInvitationAcceptResponse,
) )
from src.user.dependencies import ( from src.user.dependencies import (
user_model_claims_dependency, user_model_claims_dependency,
user_model_query_dependency, user_model_query_dependency,
) )
from src.user.service import send_invitation from src.user.service import send_invitation
from src.organisation.models import Organisation as Org from src.organisation.models import Organisation as Org
from src.auth.dependencies import ( from src.auth.dependencies import (
super_admin_dependency, super_admin_dependency,
org_model_root_claim_body_dependency, org_model_root_claim_body_dependency,
) )
from src.auth.service import claims_dependency from src.auth.service import claims_dependency
from src.database import DbSession from src.database import DbSession
from src.utils import verify_email_token from src.utils import verify_email_token
router = APIRouter( router = APIRouter(
prefix="/user", prefix="/user",
tags=["User"], tags=["User"],
) )
@router.get( @router.get(
"/self/claims", "/self/claims",
summary="Get current user OIDC claims.", summary="Get current user OIDC claims.",
response_model=OIDCClaims, response_model=OIDCClaims,
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
responses={ responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"}, status.HTTP_200_OK: {"description": "Successful retrieval from database"},
}, },
) )
async def current_user_claims(user: claims_dependency): async def current_user_claims(user: claims_dependency):
""" """
Returns the full OIDC claims associated with the currently logged-in user. Returns the full OIDC claims associated with the currently logged-in user.
""" """
user["allowed_origins"] = user.get("allowed-origins", []) user["allowed_origins"] = user.get("allowed-origins", [])
return user return user
@router.get( @router.get(
"/self/db", "/self/db",
summary="Get current user hub details.", summary="Get current user hub details.",
response_model=UserResponse, response_model=UserResponse,
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
responses={ responses={
status.HTTP_404_NOT_FOUND: {"description": "User not found"}, status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"}, status.HTTP_200_OK: {"description": "Successful retrieval from database"},
}, },
) )
async def current_user(user_model: user_model_claims_dependency): async def current_user(user_model: user_model_claims_dependency):
""" """
Returns the database details associated with the currently logged-in user. Returns the database details associated with the currently logged-in user.
""" """
return user_model return user_model
@router.get( @router.get(
"", "",
summary="Get user hub details by ID.", summary="Get user hub details by ID.",
response_model=UserResponse, response_model=UserResponse,
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
responses={ responses={
status.HTTP_404_NOT_FOUND: {"description": "User not found"}, status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"}, status.HTTP_200_OK: {"description": "Successful retrieval from database"},
}, },
) )
async def get_user_by_id( async def get_user_by_id(
user_model: user_model_query_dependency, su: super_admin_dependency user_model: user_model_query_dependency, su: super_admin_dependency
): ):
""" """
Returns the database details associated with the provided user ID. Returns the database details associated with the provided user ID.
""" """
return user_model return user_model
@router.delete( @router.delete(
"", "",
summary="Delete user from hub by ID.", summary="Delete user from hub by ID.",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
responses={ responses={
status.HTTP_204_NO_CONTENT: {"description": "User deleted"}, status.HTTP_204_NO_CONTENT: {"description": "User deleted"},
status.HTTP_404_NOT_FOUND: {"description": "User not found"}, status.HTTP_404_NOT_FOUND: {"description": "User not found"},
}, },
) )
async def delete_user_by_id( async def delete_user_by_id(
db: DbSession, db: DbSession,
user_model: user_model_query_dependency, user_model: user_model_query_dependency,
su: super_admin_dependency, su: super_admin_dependency,
): ):
""" """
Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login. Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login.
""" """
db.delete(user_model) db.delete(user_model)
db.commit() db.commit()
@router.get( @router.get(
"/self/orgs", "/self/orgs",
summary="Get all orgs the current user is a member of", summary="Get all orgs the current user is a member of",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=UserGetSelfOrgsResponse, response_model=UserGetSelfOrgsResponse,
responses={}, responses={},
) )
async def get_user_orgs(user_model: user_model_claims_dependency): async def get_user_orgs(user_model: user_model_claims_dependency):
user_orgs = user_model.organisation_rel user_orgs = user_model.organisation_rel
response = [] response = []
for org in user_orgs: for org in user_orgs:
response.append( response.append(
{ {
"organisation_id": org.id, "organisation_id": org.id,
"name": org.name, "name": org.name,
"status": org.status, "status": org.status,
"intake_questionnaire": org.intake_questionnaire, "intake_questionnaire": org.intake_questionnaire,
"root_user_email": org.root_user_email, "root_user_email": org.root_user_email,
"billing_contact": { "billing_contact": {
"id": org.billing_contact_id, "id": org.billing_contact_id,
"email": org.billing_contact_rel.email, "email": org.billing_contact_rel.email,
}, },
"owner_contact": { "owner_contact": {
"id": org.owner_contact_id, "id": org.owner_contact_id,
"email": org.owner_contact_rel.email, "email": org.owner_contact_rel.email,
}, },
"security_contact": { "security_contact": {
"id": org.security_contact_id, "id": org.security_contact_id,
"email": org.security_contact_rel.email, "email": org.security_contact_rel.email,
}, },
} }
) )
return {"organisations": response} return {"organisations": response}
@router.post( @router.post(
"/invitation", "/invitation",
summary="Send an email invitation for a user to join an org", summary="Send an email invitation for a user to join an org",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=UserPostInvitationResponse, response_model=UserPostInvitationResponse,
) )
async def invitation( async def invitation(
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
org_model: org_model_root_claim_body_dependency, org_model: org_model_root_claim_body_dependency,
request_model: UserPostInvitationRequest, request_model: UserPostInvitationRequest,
): ):
org_id = org_model.id org_id = org_model.id
org_name = org_model.name org_name = org_model.name
user_email = request_model.user_email user_email = request_model.user_email
background_tasks.add_task( background_tasks.add_task(
send_invitation, org_id=org_id, org_name=org_name, user_email=user_email send_invitation, org_id=org_id, org_name=org_name, user_email=user_email
) )
response = { response = {
"organisation": org_model, "organisation": org_model,
"invited_email": user_email, "invited_email": user_email,
} }
return response return response
@router.post( @router.post(
"/invitation/accept", "/invitation/accept",
summary="Accept email invitation to join an org", summary="Accept email invitation to join an org",
status_code=status.HTTP_200_OK, status_code=status.HTTP_200_OK,
response_model=UserPostInvitationAcceptResponse, response_model=UserPostInvitationAcceptResponse,
) )
async def accept_invitation( async def accept_invitation(
db: DbSession, db: DbSession,
user_model: user_model_claims_dependency, user_model: user_model_claims_dependency,
request_model: UserPostInvitationAcceptRequest, request_model: UserPostInvitationAcceptRequest,
): ):
email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model) email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model)
org_model = db.get(Org, email_claims["org_id"]) org_model = db.get(Org, email_claims["org_id"])
if org_model is None: if org_model is None:
raise OrgNotFoundException() raise OrgNotFoundException()
org_model.user_rel.append(user_model) org_model.user_rel.append(user_model)
db.flush() db.flush()
group_model = ( group_model = (
db.query(Group) db.query(Group)
.filter(Group.org_id == org_model.id) .filter(Group.org_id == org_model.id)
.filter(Group.name == "Default Users") .filter(Group.name == "Default Users")
.first() .first()
) )
if group_model is not None: if group_model is not None:
user_model.group_rel.append(group_model) user_model.group_rel.append(group_model)
response = { response = {
"organisation": org_model, "organisation": org_model,
"user": user_model, "user": user_model,
} }
db.commit() db.commit()
return response return response

View file

@ -10,63 +10,63 @@ from src.schemas import CustomBaseModel, OrgIDMixin, OrgSummary, UserSummary
class OIDCClaims(CustomBaseModel): class OIDCClaims(CustomBaseModel):
exp: int exp: int
iat: int iat: int
auth_time: int auth_time: int
jti: str jti: str
iss: str iss: str
aud: str aud: str
sub: str sub: str
typ: str typ: str
azp: str azp: str
sid: str sid: str
acr: str acr: str
allowed_origins: list[str] allowed_origins: list[str]
realm_access: dict[str, list[str]] realm_access: dict[str, list[str]]
resource_access: dict[str, dict[str, list[str]]] resource_access: dict[str, dict[str, list[str]]]
scope: str scope: str
email_verified: bool email_verified: bool
name: str name: str
preferred_username: str preferred_username: str
given_name: str given_name: str
family_name: str family_name: str
email: str email: str
db_id: int db_id: int
class OIDCUser(CustomBaseModel): class OIDCUser(CustomBaseModel):
first_name: str first_name: str
last_name: str last_name: str
email: str email: str
oidc_id: str oidc_id: str
class UserResponse(CustomBaseModel): class UserResponse(CustomBaseModel):
id: int id: int
first_name: str first_name: str
last_name: str last_name: str
email: str email: str
organisations: list[Optional[dict[str, str | int]]] organisations: list[Optional[dict[str, str | int]]]
groups: Optional[dict[str, list[dict[str, str | int]]]] = None groups: Optional[dict[str, list[dict[str, str | int]]]] = None
class UserPostInvitationRequest(OrgIDMixin): class UserPostInvitationRequest(OrgIDMixin):
user_email: EmailStr user_email: EmailStr
class UserPostInvitationAcceptRequest(CustomBaseModel): class UserPostInvitationAcceptRequest(CustomBaseModel):
jwt: str jwt: str
class UserGetSelfOrgsResponse(CustomBaseModel): class UserGetSelfOrgsResponse(CustomBaseModel):
organisations: list[OrgSchema] organisations: list[OrgSchema]
class UserPostInvitationResponse(CustomBaseModel): class UserPostInvitationResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
invited_email: EmailStr invited_email: EmailStr
class UserPostInvitationAcceptResponse(CustomBaseModel): class UserPostInvitationAcceptResponse(CustomBaseModel):
organisation: OrgSummary organisation: OrgSummary
user: UserSummary user: UserSummary

View file

@ -16,49 +16,49 @@ from src.user.models import User
async def add_user(db: Session, user_claims: dict[str, Any]) -> int: async def add_user(db: Session, user_claims: dict[str, Any]) -> int:
try: try:
valid_user = OIDCUser( valid_user = OIDCUser(
first_name=user_claims["given_name"], first_name=user_claims["given_name"],
last_name=user_claims["family_name"], last_name=user_claims["family_name"],
email=user_claims["email"], email=user_claims["email"],
oidc_id=user_claims["sub"], oidc_id=user_claims["sub"],
) )
except Exception as e: except Exception as e:
logging.exception(e) logging.exception(e)
raise UnprocessableContentException("Invalid or missing OIDC data") raise UnprocessableContentException("Invalid or missing OIDC data")
db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first() db_user = db.query(User).filter(User.oidc_id == valid_user.oidc_id).first()
if not db_user: if not db_user:
user_model = User(**valid_user.model_dump()) user_model = User(**valid_user.model_dump())
db.add(user_model) db.add(user_model)
user_id = user_model.id user_id = user_model.id
db.commit() db.commit()
return user_id return user_id
user_id = db_user.id user_id = db_user.id
db_user.first_name = valid_user.first_name db_user.first_name = valid_user.first_name
db_user.last_name = valid_user.last_name db_user.last_name = valid_user.last_name
db.commit() db.commit()
return user_id return user_id
async def send_invitation(user_email: str, org_name: str, org_id: int): async def send_invitation(user_email: str, org_name: str, org_id: int):
expiry_delta = timedelta(hours=24) expiry_delta = timedelta(hours=24)
expiry = datetime.now(timezone.utc) + expiry_delta expiry = datetime.now(timezone.utc) + expiry_delta
claims = { claims = {
"email": user_email, "email": user_email,
"org_id": org_id, "org_id": org_id,
"exp": expiry, "exp": expiry,
"type": "org_invite", "type": "org_invite",
} }
token = await generate_jwt(claims) token = await generate_jwt(claims)
subject = f"You have been invited to join {org_name}" subject = f"You have been invited to join {org_name}"
body = f"You have been invited to join {org_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}" body = f"You have been invited to join {org_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
await send_email( await send_email(
recipient=user_email, recipient=user_email,
subject=subject, subject=subject,
body=body, body=body,
) )

View file

@ -11,52 +11,56 @@ KEY = jwk.import_key(settings.SECRET_KEY.get_secret_value(), "oct")
async def generate_jwt(claims): async def generate_jwt(claims):
jwt_token = jwt.encode(header={"alg": "HS256"}, key=KEY, claims=claims) jwt_token = jwt.encode(header={"alg": "HS256"}, key=KEY, claims=claims)
return jwt_token return jwt_token
async def decode_jwt(encoded): async def decode_jwt(encoded):
try: try:
token = jwt.decode(encoded, key=KEY) token = jwt.decode(encoded, key=KEY)
return token.claims return token.claims
except errors.DecodeError: except errors.DecodeError:
raise UnauthorizedException("Invalid JWS") raise UnauthorizedException("Invalid JWS")
async def verify_email_token(user_model, token): async def verify_email_token(user_model, token):
email_claims = await decode_jwt(token) email_claims = await decode_jwt(token)
claimed_email = email_claims["email"] claimed_email = email_claims["email"]
expiry = datetime.fromtimestamp(email_claims["exp"], timezone.utc) expiry = datetime.fromtimestamp(email_claims["exp"], timezone.utc)
if expiry < datetime.now(timezone.utc): if expiry < datetime.now(timezone.utc):
raise UnauthorizedException("Invitation expired.") raise UnauthorizedException("Invitation expired.")
if user_model.email != claimed_email: if user_model.email != claimed_email:
raise ForbiddenException("The logged in user and email do not match.") raise ForbiddenException("The logged in user and email do not match.")
return email_claims return email_claims
async def send_email(recipient: str, subject: str, body: str): async def send_email(recipient: str, subject: str, body: str):
if settings.ENVIRONMENT.is_testing: if settings.ENVIRONMENT.is_testing:
return return
lettermint = Lettermint(api_token=settings.LETTERMINT_API_TOKEN.get_secret_value()) lettermint = Lettermint(api_token=settings.LETTERMINT_API_TOKEN.get_secret_value())
if settings.ENVIRONMENT == "local": if settings.ENVIRONMENT == "local":
recipient = "ok@testing.lettermint.co" recipient = "ok@testing.lettermint.co"
try: try:
response = ( response = (
lettermint.email.from_("noreply@sr2.uk") lettermint.email.from_("noreply@sr2.uk")
.to(recipient) .to(recipient)
.subject(subject) .subject(subject)
.text(body) .text(body)
.send() .send()
) )
logging.info("Email sent to {} with subject {} (Status: {})".format(recipient, subject, response.status_code)) logging.info(
except ValidationError as e: "Email sent to {} with subject {} (Status: {})".format(
logging.exception(e) recipient, subject, response.status_code
)
)
except ValidationError as e:
logging.exception(e)

View file

@ -22,15 +22,15 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@pytest.fixture() @pytest.fixture()
def db_session(): def db_session():
CustomBase.metadata.drop_all(bind=engine) CustomBase.metadata.drop_all(bind=engine)
CustomBase.metadata.create_all(bind=engine) CustomBase.metadata.create_all(bind=engine)
db = SessionLocal() db = SessionLocal()
try: try:
_seed(db) # extracted seeding logic into a plain function _seed(db) # extracted seeding logic into a plain function
yield db yield db
finally: finally:
db.rollback() db.rollback()
db.close() db.close()
@pytest.fixture @pytest.fixture
@ -83,176 +83,176 @@ async def no_su_client(db_session) -> AsyncGenerator[AsyncClient, None]:
def _seed(db): def _seed(db):
db.add( db.add(
User( User(
email="admin@test.com", email="admin@test.com",
first_name="Admin", first_name="Admin",
last_name="Test", last_name="Test",
oidc_id="abcd-efgh-ijkl-mnop", oidc_id="abcd-efgh-ijkl-mnop",
) )
) )
db.add( db.add(
User( User(
email="user@orgone.com", email="user@orgone.com",
first_name="User", first_name="User",
last_name="Test", last_name="Test",
oidc_id="abcd-efgh-ijkl-qwer", oidc_id="abcd-efgh-ijkl-qwer",
) )
) )
db.add( db.add(
User( User(
email="root@orgtwo.com", email="root@orgtwo.com",
first_name="Root", first_name="Root",
last_name="Test", last_name="Test",
oidc_id="abcd-efgh-ijkl-hjkl", oidc_id="abcd-efgh-ijkl-hjkl",
) )
) )
db.add(Contact(org_id=1, email="billing@orgone.com", phonenumber="07521539927")) db.add(Contact(org_id=1, email="billing@orgone.com", phonenumber="07521539927"))
db.add(Contact(org_id=1, email="owner@orgone.com", phonenumber="07521539927")) db.add(Contact(org_id=1, email="owner@orgone.com", phonenumber="07521539927"))
db.add(Contact(org_id=1, email="security@orgone.com", phonenumber="07521539927")) db.add(Contact(org_id=1, email="security@orgone.com", phonenumber="07521539927"))
db.add(Contact(org_id=2, email="billing@orgtwo.com", phonenumber="07521539927")) db.add(Contact(org_id=2, email="billing@orgtwo.com", phonenumber="07521539927"))
db.add(Contact(org_id=2, email="owner@orgtwo.com", phonenumber="07521539927")) db.add(Contact(org_id=2, email="owner@orgtwo.com", phonenumber="07521539927"))
db.add(Contact(org_id=2, email="security@orgtwo.com", phonenumber="07521539927")) db.add(Contact(org_id=2, email="security@orgtwo.com", phonenumber="07521539927"))
db.add(Contact(org_id=3, email="billing@orgthree.com", phonenumber="07521539927")) db.add(Contact(org_id=3, email="billing@orgthree.com", phonenumber="07521539927"))
db.add(Contact(org_id=3, email="owner@orgthree.com", phonenumber="07521539927")) db.add(Contact(org_id=3, email="owner@orgthree.com", phonenumber="07521539927"))
db.add(Contact(org_id=3, email="security@orgthree.com", phonenumber="07521539927")) db.add(Contact(org_id=3, email="security@orgthree.com", phonenumber="07521539927"))
db.flush() db.flush()
db.add( db.add(
Org( Org(
name="Org One", name="Org One",
root_user_id=1, root_user_id=1,
billing_contact_id=1, billing_contact_id=1,
owner_contact_id=2, owner_contact_id=2,
security_contact_id=3, security_contact_id=3,
status="approved", status="approved",
intake_questionnaire={ intake_questionnaire={
"metadata": {"version": 0, "submission_date": None}, "metadata": {"version": 0, "submission_date": None},
"questions": {"question_two": "answer two"}, "questions": {"question_two": "answer two"},
}, },
) )
) )
db.add( db.add(
Org( Org(
name="Org Two", name="Org Two",
root_user_id=3, root_user_id=3,
billing_contact_id=4, billing_contact_id=4,
owner_contact_id=5, owner_contact_id=5,
security_contact_id=6, security_contact_id=6,
status="approved", status="approved",
intake_questionnaire={ intake_questionnaire={
"metadata": {"version": 0, "submission_date": None}, "metadata": {"version": 0, "submission_date": None},
"questions": {"question_two": "answer two"}, "questions": {"question_two": "answer two"},
}, },
) )
) )
db.add( db.add(
Org( Org(
name="Org Three", name="Org Three",
root_user_id=1, root_user_id=1,
billing_contact_id=7, billing_contact_id=7,
owner_contact_id=8, owner_contact_id=8,
security_contact_id=9, security_contact_id=9,
status="partial", status="partial",
intake_questionnaire={ intake_questionnaire={
"metadata": {"version": 0, "submission_date": None}, "metadata": {"version": 0, "submission_date": None},
"questions": {"question_two": "answer two"}, "questions": {"question_two": "answer two"},
}, },
) )
) )
db.add(OrgUsers(org_id=1, user_id=2)) db.add(OrgUsers(org_id=1, user_id=2))
db.add(Service(name="Test Service", api_key="123456789")) db.add(Service(name="Test Service", api_key="123456789"))
db.add(Permission(service_id=1, resource="test_resource", action="read")) db.add(Permission(service_id=1, resource="test_resource", action="read"))
db.add(Permission(service_id=1, resource="test_resource", action="move")) db.add(Permission(service_id=1, resource="test_resource", action="move"))
db.add(Permission(service_id=1, resource="test_resource", action="delete")) db.add(Permission(service_id=1, resource="test_resource", action="delete"))
db.add(OrgPermissions(org_id=1, permission_id=1)) db.add(OrgPermissions(org_id=1, permission_id=1))
db.add(OrgPermissions(org_id=1, permission_id=2)) db.add(OrgPermissions(org_id=1, permission_id=2))
db.add(Group(name="Org One Group", org_id=1)) db.add(Group(name="Org One Group", org_id=1))
db.add(Group(name="Org Two Group", org_id=2)) db.add(Group(name="Org Two Group", org_id=2))
db.add(Group(name="Org One Group Two", org_id=1)) db.add(Group(name="Org One Group Two", org_id=1))
db.flush() db.flush()
group_model = db.get(Group, 1) group_model = db.get(Group, 1)
perm_model = db.get(Permission, 1) perm_model = db.get(Permission, 1)
group_model.permission_rel.append(perm_model) group_model.permission_rel.append(perm_model)
user_model = db.get(User, 1) user_model = db.get(User, 1)
org_model = db.get(Org, 1) org_model = db.get(Org, 1)
org_model.user_rel.append(user_model) org_model.user_rel.append(user_model)
org_model.group_rel.append(group_model) org_model.group_rel.append(group_model)
db.flush() db.flush()
group_model.user_rel.append(user_model) group_model.user_rel.append(user_model)
db.commit() db.commit()
def generate_query_and_status(params) -> list[tuple[str, int]]: def generate_query_and_status(params) -> list[tuple[str, int]]:
possible_values = [0, -1, 42, "banana", ""] possible_values = [0, -1, 42, "banana", ""]
defaults = [f"{param}=1" for param in params] defaults = [f"{param}=1" for param in params]
# Missing params # Missing params
query_list = [ query_list = [
"&".join(combo) "&".join(combo)
for r in range(len(defaults) + 1) for r in range(len(defaults) + 1)
for combo in combinations(defaults, r) for combo in combinations(defaults, r)
] ]
# Complete query as default for invalid checks # Complete query as default for invalid checks
default_query = query_list.pop(-1) default_query = query_list.pop(-1)
# Checks for each param being invalid # Checks for each param being invalid
for param in params: for param in params:
for value in possible_values: for value in possible_values:
new_value = f"&{param}={value}" new_value = f"&{param}={value}"
query_list.append(default_query.replace(f"{param}=1", new_value)) query_list.append(default_query.replace(f"{param}=1", new_value))
query_and_status = [] query_and_status = []
# Assign expected status # Assign expected status
for query in query_list: for query in query_list:
# ID 42 is used to represent a non-existent entry. So it should 404. # ID 42 is used to represent a non-existent entry. So it should 404.
status = 404 if "42" in query else 422 status = 404 if "42" in query else 422
# Remove leading "&" if present # Remove leading "&" if present
query = query if len(query) > 1 and query[0] != "&" else query[1:] query = query if len(query) > 1 and query[0] != "&" else query[1:]
query_and_status.append((query, status)) query_and_status.append((query, status))
return query_and_status return query_and_status
def generate_body_and_status(params: dict[str, str]) -> list[tuple[dict, int]]: def generate_body_and_status(params: dict[str, str]) -> list[tuple[dict, int]]:
possible_values_int = [0, -1, 42, "banana", ""] possible_values_int = [0, -1, 42, "banana", ""]
possible_values_str = [0, "", "a"] possible_values_str = [0, "", "a"]
defaults = [{param: 1 for param in params.keys()}] defaults = [{param: 1 for param in params.keys()}]
# Missing params # Missing params
body_list = [ body_list = [
{key: ("valid string" if params[key] == "str" else 1) for key in combo} {key: ("valid string" if params[key] == "str" else 1) for key in combo}
for r in range(len(defaults[0].keys()) + 1) for r in range(len(defaults[0].keys()) + 1)
for combo in combinations(defaults[0].keys(), r) for combo in combinations(defaults[0].keys(), r)
] ]
# Complete body as default for generating invalid checks # Complete body as default for generating invalid checks
default_body = body_list.pop(-1) default_body = body_list.pop(-1)
# Generates checks for each param being invalid # Generates checks for each param being invalid
for param, typ in params.items(): for param, typ in params.items():
if typ == "int": if typ == "int":
possible_values = possible_values_int possible_values = possible_values_int
elif typ == "str": elif typ == "str":
possible_values = possible_values_str possible_values = possible_values_str
else: else:
raise TypeError(f"Unknown type {typ}") raise TypeError(f"Unknown type {typ}")
for value in possible_values: for value in possible_values:
new_record = default_body.copy() new_record = default_body.copy()
new_record[param] = value new_record[param] = value
body_list.append(new_record) body_list.append(new_record)
body_and_status = [] body_and_status = []
# Assign expected status # Assign expected status
for body in body_list: for body in body_list:
# ID 42 is used to represent a non-existent entry. So it should 404. # ID 42 is used to represent a non-existent entry. So it should 404.
status = 404 if 42 in body.values() else 422 status = 404 if 42 in body.values() else 422
body_and_status.append((body, status)) body_and_status.append((body, status))
return body_and_status return body_and_status
def get_testable_routes(): def get_testable_routes():

View file

@ -8,181 +8,181 @@ import pytest
from httpx import AsyncClient from httpx import AsyncClient
pytestmark = [ pytestmark = [
pytest.mark.auth, pytest.mark.auth,
pytest.mark.preapproval, pytest.mark.preapproval,
] ]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_auth_approval(no_su_client: AsyncClient): async def test_get_org_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/org?org_id=3") resp = await no_su_client.get("/org?org_id=3")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 200 assert resp.status_code == 200
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_questionnaire_auth_approval(no_su_client: AsyncClient): async def test_patch_org_questionnaire_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.patch( resp = await no_su_client.patch(
"/org/questionnaire", "/org/questionnaire",
json={ json={
"organisation_id": 3, "organisation_id": 3,
"intake_questionnaire": { "intake_questionnaire": {
"question_one": "new answer one", "question_one": "new answer one",
"question_two": None, "question_two": None,
"question_three": None, "question_three": None,
}, },
"partial": True, "partial": True,
}, },
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 200 assert resp.status_code == 200
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_users_auth_approval(no_su_client: AsyncClient): async def test_get_org_users_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/users?org_id=3") resp = await no_su_client.get("/org/users?org_id=3")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_groups_auth_approval(no_su_client: AsyncClient): async def test_get_org_groups_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/groups?org_id=3") resp = await no_su_client.get("/org/groups?org_id=3")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_contact_auth_approval(no_su_client: AsyncClient): async def test_get_org_contact_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/contact?org_id=3&contact_type=billing") resp = await no_su_client.get("/org/contact?org_id=3&contact_type=billing")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 200 assert resp.status_code == 200
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_contact_auth_approval(no_su_client: AsyncClient): async def test_patch_org_contact_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.patch( resp = await no_su_client.patch(
"/org/contact", "/org/contact",
json={ json={
"organisation_id": 3, "organisation_id": 3,
"contact_type": "billing", "contact_type": "billing",
"email": "user@example.com", "email": "user@example.com",
}, },
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 200 assert resp.status_code == 200
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_service_auth_approval(no_su_client: AsyncClient): async def test_get_service_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/service?org_id=3") resp = await no_su_client.get("/service?org_id=3")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_iam_group_permissions_auth_approval(no_su_client: AsyncClient): async def test_get_iam_group_permissions_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/group/permissions?org_id=3&group_id=1") resp = await no_su_client.get("/iam/group/permissions?org_id=3&group_id=1")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_iam_group_users_auth_approval(no_su_client: AsyncClient): async def test_get_iam_group_users_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/group/users?org_id=3&group_id=1") resp = await no_su_client.get("/iam/group/users?org_id=3&group_id=1")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_iam_group_auth_approval(no_su_client: AsyncClient): async def test_post_iam_group_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.post( resp = await no_su_client.post(
"/iam/group", json={"name": "New Group", "organisation_id": 3} "/iam/group", json={"name": "New Group", "organisation_id": 3}
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_iam_group_permission_auth_approval(no_su_client: AsyncClient): async def test_put_iam_group_permission_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.put( resp = await no_su_client.put(
"/iam/group/permission", "/iam/group/permission",
json={"permission_id": 1, "group_id": 2, "organisation_id": 3}, json={"permission_id": 1, "group_id": 2, "organisation_id": 3},
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_iam_group_user_auth_approval(no_su_client: AsyncClient): async def test_put_iam_group_user_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.put( resp = await no_su_client.put(
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 3} "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 3}
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_iam_permissions_auth_approval(no_su_client: AsyncClient): async def test_get_iam_permissions_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/permissions?org_id=3") resp = await no_su_client.get("/iam/permissions?org_id=3")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_iam_permissions_search_auth_approval(no_su_client: AsyncClient): async def test_post_iam_permissions_search_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.post( resp = await no_su_client.post(
"/iam/permissions/search", json={"organisation_id": 3, "action": "read"} "/iam/permissions/search", json={"organisation_id": 3, "action": "read"}
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_delete_org_user_auth_approval(no_su_client: AsyncClient): async def test_delete_org_user_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.delete("/org/user?org_id=3&user_id=1") resp = await no_su_client.delete("/org/user?org_id=3&user_id=1")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_delete_preapproval_auth_approval(no_su_client: AsyncClient): async def test_delete_preapproval_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.delete("/org/self?org_id=3") resp = await no_su_client.delete("/org/self?org_id=3")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 204 assert resp.status_code == 204
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_user_invitation_auth_approval(no_su_client: AsyncClient): async def test_post_user_invitation_auth_approval(no_su_client: AsyncClient):
body = {"user_email": "admin@test.com", "organisation_id": 3} body = {"user_email": "admin@test.com", "organisation_id": 3}
resp = await no_su_client.post("/user/invitation", json=body) resp = await no_su_client.post("/user/invitation", json=body)
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_delete_group_permissions_auth_approval(no_su_client: AsyncClient): async def test_delete_group_permissions_auth_approval(no_su_client: AsyncClient):
resp = await no_su_client.delete("/iam/group/permission?org_id=3&group_id=1&perm_id=1") resp = await no_su_client.delete("/iam/group/permission?org_id=3&group_id=1&perm_id=1")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_delete_group_users_success(no_su_client: AsyncClient): async def test_delete_group_users_success(no_su_client: AsyncClient):
resp = await no_su_client.delete("/iam/group/user?org_id=3&group_id=1&user_id=1") resp = await no_su_client.delete("/iam/group/user?org_id=3&group_id=1&user_id=1")
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_group_user_invitation_success(no_su_client: AsyncClient): async def test_put_group_user_invitation_success(no_su_client: AsyncClient):
body = {"user_email": "admin@test.com", "organisation_id": 3, "group_id": 1} body = {"user_email": "admin@test.com", "organisation_id": 3, "group_id": 1}
resp = await no_su_client.put("/iam/group/user/invitation", json=body) resp = await no_su_client.put("/iam/group/user/invitation", json=body)
assert resp.status_code != 422 assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"] assert "has not been approved." in resp.json()["detail"]

View file

@ -5,14 +5,14 @@ from httpx import AsyncClient
pytestmark = [ pytestmark = [
pytest.mark.auth, pytest.mark.auth,
] ]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_auth_root_su(default_client: AsyncClient): async def test_get_org_auth_root_su(default_client: AsyncClient):
# If a super admin can access a resource when not the root user # If a super admin can access a resource when not the root user
resp = await default_client.get("/org?org_id=2") resp = await default_client.get("/org?org_id=2")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 200 assert resp.status_code == 200
assert resp.json()["organisations"][0]["name"] == "Org Two" assert resp.json()["organisations"][0]["name"] == "Org Two"

View file

@ -7,147 +7,147 @@ import pytest
from httpx import AsyncClient from httpx import AsyncClient
pytestmark = [ pytestmark = [
pytest.mark.auth, pytest.mark.auth,
pytest.mark.root_user, pytest.mark.root_user,
] ]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_auth_root(no_su_client: AsyncClient): async def test_get_org_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/org?org_id=2") resp = await no_su_client.get("/org?org_id=2")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_questionnaire_auth_root(no_su_client: AsyncClient): async def test_patch_org_questionnaire_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.patch( resp = await no_su_client.patch(
"/org/questionnaire", "/org/questionnaire",
json={ json={
"organisation_id": 2, "organisation_id": 2,
"intake_questionnaire": { "intake_questionnaire": {
"question_one": "new answer one", "question_one": "new answer one",
"question_two": None, "question_two": None,
"question_three": None, "question_three": None,
}, },
"partial": True, "partial": True,
}, },
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_users_auth_root(no_su_client: AsyncClient): async def test_get_org_users_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/users?org_id=2") resp = await no_su_client.get("/org/users?org_id=2")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_groups_auth_root(no_su_client: AsyncClient): async def test_get_org_groups_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/groups?org_id=2") resp = await no_su_client.get("/org/groups?org_id=2")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_contact_auth_root(no_su_client: AsyncClient): async def test_get_org_contact_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/org/contact?org_id=2&contact_type=billing") resp = await no_su_client.get("/org/contact?org_id=2&contact_type=billing")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_contact_auth_root(no_su_client: AsyncClient): async def test_patch_org_contact_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.patch( resp = await no_su_client.patch(
"/org/contact", "/org/contact",
json={ json={
"organisation_id": 2, "organisation_id": 2,
"contact_type": "billing", "contact_type": "billing",
"email": "user@example.com", "email": "user@example.com",
}, },
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_service_auth_root(no_su_client: AsyncClient): async def test_get_service_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/service?org_id=2") resp = await no_su_client.get("/service?org_id=2")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_iam_group_permissions_auth_root(no_su_client: AsyncClient): async def test_get_iam_group_permissions_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/group/permissions?org_id=2&group_id=1") resp = await no_su_client.get("/iam/group/permissions?org_id=2&group_id=1")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_iam_group_users_auth_root(no_su_client: AsyncClient): async def test_get_iam_group_users_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/group/users?org_id=2&group_id=1") resp = await no_su_client.get("/iam/group/users?org_id=2&group_id=1")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_iam_group_auth_root(no_su_client: AsyncClient): async def test_post_iam_group_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.post( resp = await no_su_client.post(
"/iam/group", json={"name": "New Group", "organisation_id": 2} "/iam/group", json={"name": "New Group", "organisation_id": 2}
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_iam_group_permission_auth_root(no_su_client: AsyncClient): async def test_put_iam_group_permission_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.put( resp = await no_su_client.put(
"/iam/group/permission", "/iam/group/permission",
json={"permission_id": 1, "group_id": 2, "organisation_id": 2}, json={"permission_id": 1, "group_id": 2, "organisation_id": 2},
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_iam_group_user_auth_root( async def test_put_iam_group_user_auth_root(
no_su_client: AsyncClient, no_su_client: AsyncClient,
): ):
resp = await no_su_client.put( resp = await no_su_client.put(
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2} "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2}
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_iam_permissions_auth_root(no_su_client: AsyncClient): async def test_get_iam_permissions_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.get("/iam/permissions?org_id=2") resp = await no_su_client.get("/iam/permissions?org_id=2")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_iam_permissions_search_auth_root(no_su_client: AsyncClient): async def test_post_iam_permissions_search_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.post( resp = await no_su_client.post(
"/iam/permissions/search", json={"organisation_id": 2, "action": "read"} "/iam/permissions/search", json={"organisation_id": 2, "action": "read"}
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be the org's root user" in resp.json()["detail"] assert "Must be the org's root user" in resp.json()["detail"]

View file

@ -7,69 +7,69 @@ import pytest
from httpx import AsyncClient from httpx import AsyncClient
pytestmark = [ pytestmark = [
pytest.mark.auth, pytest.mark.auth,
pytest.mark.super_admin, pytest.mark.super_admin,
] ]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_user_auth_su(no_su_client: AsyncClient): async def test_get_user_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.get("/user?user_id=1") resp = await no_su_client.get("/user?user_id=1")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin" assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_status_auth_su(no_su_client: AsyncClient): async def test_patch_org_status_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.patch( resp = await no_su_client.patch(
"/org/status", json={"organisation_id": 1, "status": "submitted"} "/org/status", json={"organisation_id": 1, "status": "submitted"}
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin" assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_root_user_auth_su(no_su_client: AsyncClient): async def test_patch_org_root_user_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.patch( resp = await no_su_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 2} "/org/root_user", json={"organisation_id": 1, "user_id": 2}
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin" assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_service_key_auth_su(no_su_client: AsyncClient): async def test_patch_service_key_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.patch("/service/key", json={"service_id": 1}) resp = await no_su_client.patch("/service/key", json={"service_id": 1})
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin" assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_service_auth_su(no_su_client: AsyncClient): async def test_post_service_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.post("/service", json={"name": "New Test Service"}) resp = await no_su_client.post("/service", json={"name": "New Test Service"})
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin" assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_perm_auth_su(no_su_client: AsyncClient): async def test_post_perm_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.post( resp = await no_su_client.post(
"/iam/permission", "/iam/permission",
json={"service_id": 1, "resource": "test_resource", "action": "create"}, json={"service_id": 1, "resource": "test_resource", "action": "create"},
) )
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert resp.json()["detail"] == "Must be super admin" assert resp.json()["detail"] == "Must be super admin"
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_user_auth_su(no_su_client: AsyncClient): async def test_post_org_user_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.post("/org/user", json={"organisation_id": 1, "user_id": 2}) resp = await no_su_client.post("/org/user", json={"organisation_id": 1, "user_id": 2})
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 403 assert resp.status_code == 403
assert "Must be super admin" in resp.json()["detail"] assert "Must be super admin" in resp.json()["detail"]

View file

@ -7,22 +7,22 @@ from httpx import AsyncClient
pytestmark = [ pytestmark = [
pytest.mark.auth, pytest.mark.auth,
pytest.mark.user, pytest.mark.user,
] ]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_self_db_auth_user(no_user_client: AsyncClient): async def test_get_self_db_auth_user(no_user_client: AsyncClient):
resp = await no_user_client.get("/user/self/db") resp = await no_user_client.get("/user/self/db")
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 401 assert resp.status_code == 401
assert resp.json()["detail"] == "Not authenticated" assert resp.json()["detail"] == "Not authenticated"
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_success_auth_user(no_user_client: AsyncClient): async def test_post_org_success_auth_user(no_user_client: AsyncClient):
resp = await no_user_client.post("/org", json={"name": "New Test Org"}) resp = await no_user_client.post("/org", json={"name": "New Test Org"})
assert resp.status_code != 422 assert resp.status_code != 422
assert resp.status_code == 401 assert resp.status_code == 401
assert resp.json()["detail"] == "Not authenticated" assert resp.json()["detail"] == "Not authenticated"

View file

@ -4,7 +4,7 @@ from httpx import AsyncClient
@pytest.mark.anyio @pytest.mark.anyio
async def test_healthcheck(default_client: AsyncClient): async def test_healthcheck(default_client: AsyncClient):
resp = await default_client.get("/healthcheck") resp = await default_client.get("/healthcheck")
assert resp.status_code == 200 assert resp.status_code == 200
assert resp.json() == {"status": "ok"} assert resp.json() == {"status": "ok"}

File diff suppressed because it is too large Load diff

View file

@ -9,506 +9,506 @@ from .conftest import generate_query_and_status
pytestmark = [ pytestmark = [
pytest.mark.org_module, pytest.mark.org_module,
] ]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_success(default_client: AsyncClient): async def test_get_org_success(default_client: AsyncClient):
resp = await default_client.get("/org?org_id=1") resp = await default_client.get("/org?org_id=1")
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
org = data["organisations"][0] org = data["organisations"][0]
assert isinstance(org, dict) assert isinstance(org, dict)
assert org["organisation_id"] == 1 assert org["organisation_id"] == 1
assert org["name"] == "Org One" assert org["name"] == "Org One"
assert org["status"] == "approved" assert org["status"] == "approved"
assert org["root_user_email"] == "admin@test.com" assert org["root_user_email"] == "admin@test.com"
assert "intake_questionnaire" in org assert "intake_questionnaire" in org
assert isinstance(org["intake_questionnaire"], dict) assert isinstance(org["intake_questionnaire"], dict)
assert org["billing_contact"]["email"] == "billing@orgone.com" assert org["billing_contact"]["email"] == "billing@orgone.com"
assert org["owner_contact"]["email"] == "owner@orgone.com" assert org["owner_contact"]["email"] == "owner@orgone.com"
assert org["security_contact"]["email"] == "security@orgone.com" assert org["security_contact"]["email"] == "security@orgone.com"
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_status_checks( async def test_get_org_status_checks(
default_client: AsyncClient, query: str, expected_status: int default_client: AsyncClient, query: str, expected_status: int
): ):
resp = await default_client.get(f"/org?{query}") resp = await default_client.get(f"/org?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_success(default_client: AsyncClient): async def test_post_org_success(default_client: AsyncClient):
resp = await default_client.post("/org", json={"name": "New Test Org"}) resp = await default_client.post("/org", json={"name": "New Test Org"})
data = resp.json() data = resp.json()
assert resp.status_code == 201 assert resp.status_code == 201
assert data["name"] == "New Test Org" assert data["name"] == "New Test Org"
assert data["status"] == "partial" assert data["status"] == "partial"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
({"name": 42}, 422), ({"name": 42}, 422),
({}, 422), ({}, 422),
({"name": "New Test Org", "intake_questionnaire": {"question_one": 42}}, 422), ({"name": "New Test Org", "intake_questionnaire": {"question_one": 42}}, 422),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_status_checks( async def test_post_org_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int default_client: AsyncClient, body: dict[str, str], expected_status: int
): ):
resp = await default_client.post("/org", json=body) resp = await default_client.post("/org", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_questionnaire_partial_success(default_client: AsyncClient): async def test_patch_org_questionnaire_partial_success(default_client: AsyncClient):
resp = await default_client.patch( resp = await default_client.patch(
"/org/questionnaire", "/org/questionnaire",
json={ json={
"organisation_id": 3, "organisation_id": 3,
"intake_questionnaire": { "intake_questionnaire": {
"question_one": "new answer one", "question_one": "new answer one",
"question_two": None, "question_two": None,
"question_three": None, "question_three": None,
}, },
"partial": True, "partial": True,
}, },
) )
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
assert data["name"] == "Org Three" assert data["name"] == "Org Three"
assert data["status"] == "partial" assert data["status"] == "partial"
assert "intake_questionnaire" in data assert "intake_questionnaire" in data
assert isinstance(data["intake_questionnaire"], dict) assert isinstance(data["intake_questionnaire"], dict)
metadata = data["intake_questionnaire"]["metadata"] metadata = data["intake_questionnaire"]["metadata"]
assert metadata["version"] == 0 assert metadata["version"] == 0
assert metadata["submission_date"] is None assert metadata["submission_date"] is None
questions = data["intake_questionnaire"]["questions"] questions = data["intake_questionnaire"]["questions"]
assert questions["question_one"] == "new answer one" assert questions["question_one"] == "new answer one"
assert questions["question_two"] == "answer two" assert questions["question_two"] == "answer two"
assert questions["question_three"] is None assert questions["question_three"] is None
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
({"organisation_id": 42}, 404), ({"organisation_id": 42}, 404),
({"organisation_id": "Org One"}, 422), ({"organisation_id": "Org One"}, 422),
({"organisation_id": ""}, 422), ({"organisation_id": ""}, 422),
({}, 422), ({}, 422),
( (
{ {
"organisation_id": "1", "organisation_id": "1",
"intake_questionnaire": {"question_one": 42}, "intake_questionnaire": {"question_one": 42},
"partial": True, "partial": True,
}, },
422, 422,
), ),
( (
{"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}}, {"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}},
422, 422,
), ),
( (
{ {
"organisation_id": "1", "organisation_id": "1",
"intake_questionnaire": {"question_one": "valid"}, "intake_questionnaire": {"question_one": "valid"},
"partial": 42, "partial": 42,
}, },
422, 422,
), ),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_questionnaire_partial_status_checks( async def test_patch_questionnaire_partial_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int default_client: AsyncClient, body: dict[str, str], expected_status: int
): ):
resp = await default_client.patch("/org/questionnaire", json=body) resp = await default_client.patch("/org/questionnaire", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_questionnaire_submit_success(default_client: AsyncClient): async def test_patch_org_questionnaire_submit_success(default_client: AsyncClient):
resp = await default_client.patch( resp = await default_client.patch(
"/org/questionnaire", "/org/questionnaire",
json={ json={
"organisation_id": 3, "organisation_id": 3,
"intake_questionnaire": { "intake_questionnaire": {
"question_one": "new answer one", "question_one": "new answer one",
"question_two": None, "question_two": None,
"question_three": None, "question_three": None,
}, },
"partial": False, "partial": False,
}, },
) )
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
assert data["name"] == "Org Three" assert data["name"] == "Org Three"
assert data["status"] == "submitted" assert data["status"] == "submitted"
assert "intake_questionnaire" in data assert "intake_questionnaire" in data
assert isinstance(data["intake_questionnaire"], dict) assert isinstance(data["intake_questionnaire"], dict)
metadata = data["intake_questionnaire"]["metadata"] metadata = data["intake_questionnaire"]["metadata"]
assert metadata["version"] == 0 assert metadata["version"] == 0
assert metadata["submission_date"] is not None assert metadata["submission_date"] is not None
questions = data["intake_questionnaire"]["questions"] questions = data["intake_questionnaire"]["questions"]
assert questions["question_one"] == "new answer one" assert questions["question_one"] == "new answer one"
assert questions["question_two"] == "answer two" assert questions["question_two"] == "answer two"
assert questions["question_three"] is None assert questions["question_three"] is None
@pytest.mark.parametrize( @pytest.mark.parametrize(
"status", ["partial", "submitted", "remediation", "approved", "rejected", "removed"] "status", ["partial", "submitted", "remediation", "approved", "rejected", "removed"]
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_status_success(default_client: AsyncClient, status: str): async def test_patch_org_status_success(default_client: AsyncClient, status: str):
resp = await default_client.patch( resp = await default_client.patch(
"/org/status", json={"organisation_id": 1, "status": status} "/org/status", json={"organisation_id": 1, "status": status}
) )
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
assert data["name"] == "Org One" assert data["name"] == "Org One"
assert data["status"] == status assert data["status"] == status
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
({"organisation_id": 42}, 404), ({"organisation_id": 42}, 404),
({"organisation_id": "Org One"}, 422), ({"organisation_id": "Org One"}, 422),
({"organisation_id": ""}, 422), ({"organisation_id": ""}, 422),
({}, 422), ({}, 422),
({"organisation_id": "1", "status": True}, 422), ({"organisation_id": "1", "status": True}, 422),
({"organisation_id": "1", "status": 42}, 422), ({"organisation_id": "1", "status": 42}, 422),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_status_status_checks( async def test_patch_org_status_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int default_client: AsyncClient, body: dict[str, str], expected_status: int
): ):
resp = await default_client.patch("/org/status", json=body) resp = await default_client.patch("/org/status", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_users_success(default_client: AsyncClient): async def test_get_org_users_success(default_client: AsyncClient):
resp = await default_client.get("/org/users?org_id=1") resp = await default_client.get("/org/users?org_id=1")
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
assert "users" in data assert "users" in data
assert isinstance(data["users"], list) assert isinstance(data["users"], list)
assert len(data["users"]) == 2 assert len(data["users"]) == 2
user = data["users"][0] user = data["users"][0]
assert isinstance(user, dict) assert isinstance(user, dict)
assert user["email"] == "admin@test.com" assert user["email"] == "admin@test.com"
assert user["id"] == 1 assert user["id"] == 1
assert "organisation" in data assert "organisation" in data
assert data["organisation"]["name"] == "Org One" assert data["organisation"]["name"] == "Org One"
assert data["organisation"]["id"] == 1 assert data["organisation"]["id"] == 1
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_users_status_checks( async def test_get_org_users_status_checks(
default_client: AsyncClient, query: str, expected_status: int default_client: AsyncClient, query: str, expected_status: int
): ):
resp = await default_client.get(f"/org/users?{query}") resp = await default_client.get(f"/org/users?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_user_success(default_client: AsyncClient): async def test_post_org_user_success(default_client: AsyncClient):
resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 3}) resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 3})
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert "organisation" in data assert "organisation" in data
assert isinstance(data["organisation"], dict) assert isinstance(data["organisation"], dict)
assert data["organisation"]["id"] == 1 assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One" assert data["organisation"]["name"] == "Org One"
assert "users" in data assert "users" in data
assert isinstance(data["users"], list) assert isinstance(data["users"], list)
assert len([user for user in data["users"] if user["email"] == "root@orgtwo.com"]) == 1 assert len([user for user in data["users"] if user["email"] == "root@orgtwo.com"]) == 1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
({"organisation_id": 42}, 404), ({"organisation_id": 42}, 404),
({}, 422), ({}, 422),
({"organisation_id": 1, "user_id": "id"}, 422), ({"organisation_id": 1, "user_id": "id"}, 422),
({"user_id": 2}, 422), ({"user_id": 2}, 422),
({"organisation_id": 1, "user_id": 42}, 404), ({"organisation_id": 1, "user_id": 42}, 404),
({"organisation_id": 1, "user_id": 1}, 409), ({"organisation_id": 1, "user_id": 1}, 409),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_org_user_status_checks( async def test_post_org_user_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int default_client: AsyncClient, body: dict[str, str], expected_status: int
): ):
resp = await default_client.post("/org/user", json=body) resp = await default_client.post("/org/user", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_root_user_success(default_client: AsyncClient): async def test_patch_org_root_user_success(default_client: AsyncClient):
resp = await default_client.patch( resp = await default_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 2} "/org/root_user", json={"organisation_id": 1, "user_id": 2}
) )
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert data["name"] == "Org One" assert data["name"] == "Org One"
assert data["root_user_email"] == "user@orgone.com" assert data["root_user_email"] == "user@orgone.com"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
({"organisation_id": 42, "user_id": 2}, 404), ({"organisation_id": 42, "user_id": 2}, 404),
({"organisation_id": "Org One", "user_id": 2}, 422), ({"organisation_id": "Org One", "user_id": 2}, 422),
({"organisation_id": "", "user_id": 2}, 422), ({"organisation_id": "", "user_id": 2}, 422),
({}, 422), ({}, 422),
({"user_id": 2}, 422), ({"user_id": 2}, 422),
({"user_id": 42}, 404), ({"user_id": 42}, 404),
({"organisation_id": 1, "user_id": "Test User"}, 422), ({"organisation_id": 1, "user_id": "Test User"}, 422),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_root_user_status_checks( async def test_patch_root_user_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int default_client: AsyncClient, body: dict[str, str], expected_status: int
): ):
resp = await default_client.patch("/org/root_user", json=body) resp = await default_client.patch("/org/root_user", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_root_user_non_member(default_client: AsyncClient): async def test_patch_org_root_user_non_member(default_client: AsyncClient):
resp = await default_client.patch( resp = await default_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 3} "/org/root_user", json={"organisation_id": 1, "user_id": 3}
) )
data = resp.json() data = resp.json()
assert resp.status_code == 422 assert resp.status_code == 422
assert data["detail"] == "This user does not belong to your organisation." assert data["detail"] == "This user does not belong to your organisation."
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_groups_success(default_client: AsyncClient): async def test_get_org_groups_success(default_client: AsyncClient):
resp = await default_client.get("/org/groups?org_id=1") resp = await default_client.get("/org/groups?org_id=1")
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert "organisation" in data assert "organisation" in data
assert isinstance(data["organisation"], dict) assert isinstance(data["organisation"], dict)
assert data["organisation"]["id"] == 1 assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One" assert data["organisation"]["name"] == "Org One"
assert "groups" in data assert "groups" in data
assert isinstance(data["groups"], list) assert isinstance(data["groups"], list)
group = data["groups"][0] group = data["groups"][0]
assert isinstance(group, dict) assert isinstance(group, dict)
assert group["id"] == 1 assert group["id"] == 1
assert group["name"] == "Org One Group" assert group["name"] == "Org One Group"
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_groups_status_checks( async def test_get_org_groups_status_checks(
default_client: AsyncClient, query: str, expected_status: int default_client: AsyncClient, query: str, expected_status: int
): ):
resp = await default_client.get(f"/org/groups?{query}") resp = await default_client.get(f"/org/groups?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.parametrize("contact_type", ["billing", "security", "owner"]) @pytest.mark.parametrize("contact_type", ["billing", "security", "owner"])
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_contact_success(default_client: AsyncClient, contact_type: str): async def test_get_org_contact_success(default_client: AsyncClient, contact_type: str):
resp = await default_client.get(f"/org/contact?org_id=1&contact_type={contact_type}") resp = await default_client.get(f"/org/contact?org_id=1&contact_type={contact_type}")
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
assert "organisation" in data assert "organisation" in data
assert data["organisation"]["id"] == 1 assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One" assert data["organisation"]["name"] == "Org One"
attributes = [ attributes = [
"email", "email",
"first_name", "first_name",
"last_name", "last_name",
"phonenumber", "phonenumber",
"vat_number", "vat_number",
"address", "address",
] ]
for attribute in attributes: for attribute in attributes:
assert attribute in data["contact"] assert attribute in data["contact"]
address_attributes = [ address_attributes = [
"post_office_box_number", "post_office_box_number",
"street_address", "street_address",
"street_address_line_2", "street_address_line_2",
"locality", "locality",
"address_region", "address_region",
"country_code", "country_code",
"postal_code", "postal_code",
] ]
for attribute in address_attributes: for attribute in address_attributes:
assert attribute in data["contact"]["address"] assert attribute in data["contact"]["address"]
@pytest.mark.parametrize( @pytest.mark.parametrize(
"query, expected_status", "query, expected_status",
[ [
("org_id=42&contact_type=billing", 404), ("org_id=42&contact_type=billing", 404),
("org_id=banana&contact_type=billing", 422), ("org_id=banana&contact_type=billing", 422),
("", 422), ("", 422),
("org_id=1&contact_type=contact", 422), ("org_id=1&contact_type=contact", 422),
("contact_type=billing", 422), ("contact_type=billing", 422),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_org_contact_status_checks( async def test_get_org_contact_status_checks(
default_client: AsyncClient, query: str, expected_status: int default_client: AsyncClient, query: str, expected_status: int
): ):
resp = await default_client.get(f"/org/contact?{query}") resp = await default_client.get(f"/org/contact?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.parametrize( @pytest.mark.parametrize(
"key, value", "key, value",
[ [
("email", "user@example.com"), ("email", "user@example.com"),
("first_name", "John"), ("first_name", "John"),
("last_name", "Doe"), ("last_name", "Doe"),
("phonenumber", "+441234567890"), ("phonenumber", "+441234567890"),
("vat_number", "GB123456789"), ("vat_number", "GB123456789"),
("post_office_box_number", "PO Box 123"), ("post_office_box_number", "PO Box 123"),
("street_address", "123 Example Street"), ("street_address", "123 Example Street"),
("street_address_line_2", "Suite 4B"), ("street_address_line_2", "Suite 4B"),
("locality", "Glasgow"), ("locality", "Glasgow"),
("address_region", "Glasgow City"), ("address_region", "Glasgow City"),
("country_code", "GB"), ("country_code", "GB"),
("postal_code", "G1 1AA"), ("postal_code", "G1 1AA"),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_contact_success(default_client: AsyncClient, key: str, value: str): async def test_patch_org_contact_success(default_client: AsyncClient, key: str, value: str):
resp = await default_client.patch( resp = await default_client.patch(
"/org/contact", "/org/contact",
json={"organisation_id": 1, "contact_type": "billing", key: value}, json={"organisation_id": 1, "contact_type": "billing", key: value},
) )
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert "organisation" in data assert "organisation" in data
assert data["organisation"]["id"] == 1 assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One" assert data["organisation"]["name"] == "Org One"
attributes = [ attributes = [
"email", "email",
"first_name", "first_name",
"last_name", "last_name",
"phonenumber", "phonenumber",
"vat_number", "vat_number",
"address", "address",
] ]
for attribute in attributes: for attribute in attributes:
assert attribute in data["contact"] assert attribute in data["contact"]
address_attributes = [ address_attributes = [
"post_office_box_number", "post_office_box_number",
"street_address", "street_address",
"street_address_line_2", "street_address_line_2",
"locality", "locality",
"address_region", "address_region",
"country_code", "country_code",
"postal_code", "postal_code",
] ]
for attribute in address_attributes: for attribute in address_attributes:
assert attribute in data["contact"]["address"] assert attribute in data["contact"]["address"]
if key in data["contact"]: if key in data["contact"]:
assert data["contact"][key] == value assert data["contact"][key] == value
elif key in data["contact"]["address"]: elif key in data["contact"]["address"]:
assert data["contact"]["address"][key] == value assert data["contact"]["address"][key] == value
else: else:
pytest.fail(f"Invalid contact key: {key}") pytest.fail(f"Invalid contact key: {key}")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
({"organisation_id": 42, "contact_type": "billing"}, 404), ({"organisation_id": 42, "contact_type": "billing"}, 404),
({"organisation_id": 1, "contact_type": "security"}, 200), ({"organisation_id": 1, "contact_type": "security"}, 200),
({"organisation_id": 1, "contact_type": "owner"}, 200), ({"organisation_id": 1, "contact_type": "owner"}, 200),
({"organisation_id": "Org One", "contact_type": "billing"}, 422), ({"organisation_id": "Org One", "contact_type": "billing"}, 422),
({"organisation_id": "", "contact_type": "billing"}, 422), ({"organisation_id": "", "contact_type": "billing"}, 422),
({}, 422), ({}, 422),
({"organisation_id": 1, "contact_type": "not_real"}, 422), ({"organisation_id": 1, "contact_type": "not_real"}, 422),
({"organisation_id": 1, "contact_type": 42}, 422), ({"organisation_id": 1, "contact_type": 42}, 422),
({"organisation_id": 1, "contact_type": ""}, 422), ({"organisation_id": 1, "contact_type": ""}, 422),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_org_contact_status_checks( async def test_patch_org_contact_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int default_client: AsyncClient, body: dict[str, str], expected_status: int
): ):
resp = await default_client.patch("/org/contact", json=body) resp = await default_client.patch("/org/contact", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_delete_org_success(default_client: AsyncClient): async def test_delete_org_success(default_client: AsyncClient):
resp = await default_client.delete("/org?org_id=1") resp = await default_client.delete("/org?org_id=1")
assert resp.status_code == 204 assert resp.status_code == 204
@pytest.mark.anyio @pytest.mark.anyio
async def test_delete_org_users_success(default_client: AsyncClient): async def test_delete_org_users_success(default_client: AsyncClient):
resp = await default_client.delete("/org/user?org_id=1&user_id=2") resp = await default_client.delete("/org/user?org_id=1&user_id=2")
assert resp.status_code == 204 assert resp.status_code == 204
@pytest.mark.anyio @pytest.mark.anyio
async def test_delete_preapproval_org_success(default_client: AsyncClient): async def test_delete_preapproval_org_success(default_client: AsyncClient):
resp = await default_client.delete("/org/self?org_id=3") resp = await default_client.delete("/org/self?org_id=3")
assert resp.status_code == 204 assert resp.status_code == 204

View file

@ -8,90 +8,90 @@ from httpx import AsyncClient
from .conftest import generate_query_and_status, generate_body_and_status from .conftest import generate_query_and_status, generate_body_and_status
pytestmark = [ pytestmark = [
pytest.mark.service_module, pytest.mark.service_module,
] ]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_services_success(default_client: AsyncClient): async def test_get_services_success(default_client: AsyncClient):
resp = await default_client.get("/service?org_id=1") resp = await default_client.get("/service?org_id=1")
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
assert "services" in data assert "services" in data
assert isinstance(data["services"], list) assert isinstance(data["services"], list)
assert data["services"][0]["id"] == 1 assert data["services"][0]["id"] == 1
assert data["services"][0]["name"] == "Test Service" assert data["services"][0]["name"] == "Test Service"
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"]))
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_services_status_checks( async def test_get_services_status_checks(
default_client: AsyncClient, query: str, expected_status: int default_client: AsyncClient, query: str, expected_status: int
): ):
resp = await default_client.get(f"/service?{query}") resp = await default_client.get(f"/service?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_service_success(default_client: AsyncClient): async def test_post_service_success(default_client: AsyncClient):
resp = await default_client.post("/service", json={"name": "New Test Service"}) resp = await default_client.post("/service", json={"name": "New Test Service"})
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
assert "service" in data assert "service" in data
assert isinstance(data["service"], dict) assert isinstance(data["service"], dict)
assert data["service"]["name"] == "New Test Service" assert data["service"]["name"] == "New Test Service"
assert data["service"]["id"] == 2 assert data["service"]["id"] == 2
assert isinstance(data["service"]["api_key"], str) assert isinstance(data["service"]["api_key"], str)
@pytest.mark.parametrize("body, expected_status", generate_body_and_status({"name": "str"})) @pytest.mark.parametrize("body, expected_status", generate_body_and_status({"name": "str"}))
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_service_status_checks( async def test_post_service_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int default_client: AsyncClient, body: dict[str, str], expected_status: int
): ):
resp = await default_client.post("/service", json=body) resp = await default_client.post("/service", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_service_conflict(default_client: AsyncClient): async def test_post_service_conflict(default_client: AsyncClient):
resp = await default_client.post("/service", json={"name": "Test Service"}) resp = await default_client.post("/service", json={"name": "Test Service"})
assert resp.status_code == 409 assert resp.status_code == 409
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_service_success(default_client: AsyncClient): async def test_patch_service_success(default_client: AsyncClient):
resp = await default_client.patch("/service/key", json={"service_id": 1}) resp = await default_client.patch("/service/key", json={"service_id": 1})
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
assert "service" in data assert "service" in data
assert isinstance(data["service"], dict) assert isinstance(data["service"], dict)
assert data["service"]["name"] == "Test Service" assert data["service"]["name"] == "Test Service"
assert data["service"]["id"] == 1 assert data["service"]["id"] == 1
assert isinstance(data["service"]["api_key"], str) assert isinstance(data["service"]["api_key"], str)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
generate_body_and_status({"service_id": "int"}), generate_body_and_status({"service_id": "int"}),
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_patch_services_status_checks( async def test_patch_services_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int default_client: AsyncClient, body: dict[str, str], expected_status: int
): ):
resp = await default_client.patch("/service/key", json=body) resp = await default_client.patch("/service/key", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_delete_service_success(default_client: AsyncClient): async def test_delete_service_success(default_client: AsyncClient):
resp = await default_client.delete("/service?service_id=1") resp = await default_client.delete("/service?service_id=1")
assert resp.status_code == 204 assert resp.status_code == 204

View file

@ -11,191 +11,191 @@ from .conftest import generate_query_and_status
pytestmark = [ pytestmark = [
pytest.mark.user_module, pytest.mark.user_module,
] ]
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_self_db_success(default_client: AsyncClient): async def test_get_self_db_success(default_client: AsyncClient):
resp = await default_client.get("/user/self/db") resp = await default_client.get("/user/self/db")
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
assert data["first_name"] == "Admin" assert data["first_name"] == "Admin"
assert data["last_name"] == "Test" assert data["last_name"] == "Test"
assert data["email"] == "admin@test.com" assert data["email"] == "admin@test.com"
assert "organisations" in data assert "organisations" in data
assert isinstance(data["organisations"], list) assert isinstance(data["organisations"], list)
assert "groups" in data assert "groups" in data
assert isinstance(data["groups"], dict) assert isinstance(data["groups"], dict)
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_user_success(default_client: AsyncClient): async def test_get_user_success(default_client: AsyncClient):
resp = await default_client.get("/user?user_id=1") resp = await default_client.get("/user?user_id=1")
data = resp.json() data = resp.json()
assert resp.status_code == 200 assert resp.status_code == 200
assert data["first_name"] == "Admin" assert data["first_name"] == "Admin"
assert data["last_name"] == "Test" assert data["last_name"] == "Test"
assert data["email"] == "admin@test.com" assert data["email"] == "admin@test.com"
assert "organisations" in data assert "organisations" in data
assert isinstance(data["organisations"], list) assert isinstance(data["organisations"], list)
assert "groups" in data assert "groups" in data
assert isinstance(data["groups"], dict) assert isinstance(data["groups"], dict)
@pytest.mark.anyio @pytest.mark.anyio
@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["user_id"])) @pytest.mark.parametrize("query, expected_status", generate_query_and_status(["user_id"]))
async def test_get_user_status_checks( async def test_get_user_status_checks(
default_client: AsyncClient, query: str, expected_status: int default_client: AsyncClient, query: str, expected_status: int
): ):
resp = await default_client.get(f"/user?{query}") resp = await default_client.get(f"/user?{query}")
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.anyio @pytest.mark.anyio
async def test_delete_user_success(default_client: AsyncClient): async def test_delete_user_success(default_client: AsyncClient):
resp = await default_client.delete("/user?user_id=1") resp = await default_client.delete("/user?user_id=1")
assert resp.status_code == 204 assert resp.status_code == 204
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_user_invitation_success(default_client: AsyncClient): async def test_post_user_invitation_success(default_client: AsyncClient):
body = {"user_email": "admin@test.com", "organisation_id": 1} body = {"user_email": "admin@test.com", "organisation_id": 1}
resp = await default_client.post("/user/invitation", json=body) resp = await default_client.post("/user/invitation", json=body)
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert "organisation" in data assert "organisation" in data
assert isinstance(data["organisation"], dict) assert isinstance(data["organisation"], dict)
assert data["organisation"]["id"] == 1 assert data["organisation"]["id"] == 1
assert data["organisation"]["name"] == "Org One" assert data["organisation"]["name"] == "Org One"
assert "invited_email" in data assert "invited_email" in data
assert isinstance(data["invited_email"], str) assert isinstance(data["invited_email"], str)
assert data["invited_email"] == "admin@test.com" assert data["invited_email"] == "admin@test.com"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
({"organisation_id": 42, "user_email": "admin@test.com"}, 404), ({"organisation_id": 42, "user_email": "admin@test.com"}, 404),
({"organisation_id": "Test Org", "user_email": "admin@test.com"}, 422), ({"organisation_id": "Test Org", "user_email": "admin@test.com"}, 422),
({"organisation_id": "", "user_email": "admin@test.com"}, 422), ({"organisation_id": "", "user_email": "admin@test.com"}, 422),
({}, 422), ({}, 422),
({"user_email": 42}, 422), ({"user_email": 42}, 422),
({"organisation_id": 1, "user_email": "Test User"}, 422), ({"organisation_id": 1, "user_email": "Test User"}, 422),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_user_invitation_status_checks( async def test_post_user_invitation_status_checks(
default_client: AsyncClient, body, expected_status default_client: AsyncClient, body, expected_status
): ):
resp = await default_client.post("/user/invitation", json=body) resp = await default_client.post("/user/invitation", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
@pytest.mark.parametrize( @pytest.mark.parametrize(
"body, expected_status", "body, expected_status",
[ [
({"jwt": "invalid"}, 401), ({"jwt": "invalid"}, 401),
({"jwt": ""}, 401), ({"jwt": ""}, 401),
({"jwt": None}, 422), ({"jwt": None}, 422),
({"jwt": 42}, 422), ({"jwt": 42}, 422),
], ],
) )
@pytest.mark.anyio @pytest.mark.anyio
async def test_post_user_invitation_accept_status_checks( async def test_post_user_invitation_accept_status_checks(
default_client: AsyncClient, body, expected_status default_client: AsyncClient, body, expected_status
): ):
resp = await default_client.post("/user/invitation/accept", json=body) resp = await default_client.post("/user/invitation/accept", json=body)
assert resp.status_code == expected_status assert resp.status_code == expected_status
if resp.status_code == 401: if resp.status_code == 401:
assert resp.json()["detail"] == "Invalid JWS" assert resp.json()["detail"] == "Invalid JWS"
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_self_orgs_success(default_client: AsyncClient): async def test_get_self_orgs_success(default_client: AsyncClient):
resp = await default_client.get("/user/self/orgs") resp = await default_client.get("/user/self/orgs")
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert "organisations" in data assert "organisations" in data
assert isinstance(data["organisations"], list) assert isinstance(data["organisations"], list)
assert len(data["organisations"]) > 0 assert len(data["organisations"]) > 0
org = data["organisations"][0] org = data["organisations"][0]
assert org["organisation_id"] == 1 assert org["organisation_id"] == 1
assert org["name"] == "Org One" assert org["name"] == "Org One"
assert org["status"] == "approved" assert org["status"] == "approved"
assert org["root_user_email"] == "admin@test.com" assert org["root_user_email"] == "admin@test.com"
assert "intake_questionnaire" in org assert "intake_questionnaire" in org
assert isinstance(org["intake_questionnaire"], dict) assert isinstance(org["intake_questionnaire"], dict)
assert isinstance(org["billing_contact"], dict) assert isinstance(org["billing_contact"], dict)
assert org["billing_contact"]["email"] == "billing@orgone.com" assert org["billing_contact"]["email"] == "billing@orgone.com"
assert org["billing_contact"]["id"] == 1 assert org["billing_contact"]["id"] == 1
assert isinstance(org["owner_contact"], dict) assert isinstance(org["owner_contact"], dict)
assert org["owner_contact"]["email"] == "owner@orgone.com" assert org["owner_contact"]["email"] == "owner@orgone.com"
assert org["owner_contact"]["id"] == 2 assert org["owner_contact"]["id"] == 2
assert isinstance(org["security_contact"], dict) assert isinstance(org["security_contact"], dict)
assert org["security_contact"]["email"] == "security@orgone.com" assert org["security_contact"]["email"] == "security@orgone.com"
assert org["security_contact"]["id"] == 3 assert org["security_contact"]["id"] == 3
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_self_orgs_dynamic(default_client: AsyncClient): async def test_get_self_orgs_dynamic(default_client: AsyncClient):
method = "GET" method = "GET"
path = "/user/self/orgs" path = "/user/self/orgs"
expected_data = { expected_data = {
"organisations": [ "organisations": [
{ {
"organisation_id": 1, "organisation_id": 1,
"name": "Org One", "name": "Org One",
"status": "approved", "status": "approved",
"root_user_email": "admin@test.com", "root_user_email": "admin@test.com",
"owner_contact": {"email": "owner@orgone.com", "id": 2}, "owner_contact": {"email": "owner@orgone.com", "id": 2},
"security_contact": {"email": "security@orgone.com", "id": 3}, "security_contact": {"email": "security@orgone.com", "id": 3},
"billing_contact": {"email": "billing@orgone.com", "id": 1}, "billing_contact": {"email": "billing@orgone.com", "id": 1},
"intake_questionnaire": { "intake_questionnaire": {
"questions": { "questions": {
"question_one": None, "question_one": None,
"question_three": None, "question_three": None,
"question_two": "answer two", "question_two": "answer two",
}, },
"metadata": {"version": 0, "submission_date": None}, "metadata": {"version": 0, "submission_date": None},
}, },
} }
] ]
} }
resp = await default_client.get(path) resp = await default_client.get(path)
route = next( route = next(
route route
for route in default_client._transport.app.routes # ty:ignore[unresolved-attribute] for route in default_client._transport.app.routes # ty:ignore[unresolved-attribute]
if isinstance(route, APIRoute) and path in route.path and method in route.methods if isinstance(route, APIRoute) and path in route.path and method in route.methods
) )
assert resp.status_code == route.status_code assert resp.status_code == route.status_code
if route.status_code == 204: if route.status_code == 204:
return return
expected_response_schema = route.response_model expected_response_schema = route.response_model
data = resp.json() data = resp.json()
response_model = expected_response_schema(**data) response_model = expected_response_schema(**data)
assert isinstance(response_model, expected_response_schema) assert isinstance(response_model, expected_response_schema)
expected_response_model = expected_response_schema(**expected_data) expected_response_model = expected_response_schema(**expected_data)
assert response_model == expected_response_model assert response_model == expected_response_model