minor: ruff formatter

All changes are either:
- Correcting tabs
- Adding/removing line breaks
- Adding trailing commas
This commit is contained in:
Chris Milne 2026-06-08 15:31:37 +01:00
parent b2e5dd2ebb
commit c689ac1e10
91 changed files with 1710 additions and 689 deletions

View file

@ -7,6 +7,10 @@ exclude = [
".alembic"
]
[tool.ruff.format]
quote-style = "double"
indent-style = "tab"
[project]
name = "cloud-api"
version = "0.1.0"

View file

@ -2,4 +2,4 @@
Configurations for the <this> module
Exports:
"""
"""

View file

@ -2,4 +2,4 @@
Constants for the <this> module
Exports:
"""
"""

View file

@ -3,4 +3,4 @@ Dependencies related to the <this> module
Exports:
- <dep_name>: <return_type>: <description>
"""
"""

View file

@ -3,4 +3,4 @@ Exceptions related to the <this> modules
Exceptions:
- <ExceptionName>: Details e.g. optional params
"""
"""

View file

@ -6,4 +6,4 @@ Models:
- <normal_columns[FK][PK]>
- <orm_relationships>
- <calculated_properties>
"""
"""

View file

@ -17,6 +17,7 @@ Exports:
- Dependencies should be used for db model get and validation where possible
- Verify module level docstring is still accurate after updates
"""
from fastapi import APIRouter

View file

@ -5,4 +5,4 @@ Models follow the nomenclature of:
- Sub-models: "<Resource><Opt:>Schema"
- Mixins: "<Attribute>Mixin"
- Models: "<Module><Method><Resource><Opt:Resource><Direction>" ie ""
"""
"""

View file

@ -2,4 +2,4 @@
Module specific business logic for the <this> module
Exports:
"""
"""

View file

@ -1,3 +1,3 @@
"""
Non-business logic reusable functions and classes for the <this> module
"""
"""

View file

@ -1,3 +1,3 @@
"""
Configurations for the admin module
"""
"""

View file

@ -1,3 +1,3 @@
"""
Constants for the admin module
"""
"""

View file

@ -1,3 +1,3 @@
"""
Dependencies for the admin module
"""
"""

View file

@ -1,3 +1,3 @@
"""
Custom exceptions for the admin module
"""
"""

View file

@ -1,3 +1,3 @@
"""
Database models for the admin module
"""
"""

View file

@ -4,6 +4,7 @@ Router endpoints for the admin module
Exports:
- router: fastapi.APIRouter
"""
from fastapi import APIRouter
router = APIRouter(

View file

@ -1,3 +1,3 @@
"""
Pydantic models for the admin module
"""
"""

View file

@ -1,3 +1,3 @@
"""
Module specific business logic for the admin module
"""
"""

View file

@ -1,3 +1,3 @@
"""
Non-business logic reusable functions and classes for the admin module
"""
"""

View file

@ -1,6 +1,7 @@
"""
This module hooks the routers for the main endpoints into a single router for importing to the app.
"""
from fastapi import APIRouter
from src.auth.router import router as auth_router
@ -12,9 +13,7 @@ from src.iam.router import router as iam_router
from src.service.router import router as service_router
api_router = APIRouter(
prefix="/api/v1"
)
api_router = APIRouter(prefix="/api/v1")
api_router.include_router(auth_router)
api_router.include_router(contact_router)
@ -27,5 +26,5 @@ api_router.include_router(iam_router)
@api_router.get("/healthcheck", include_in_schema=False)
def healthcheck():
"""Simple healthcheck endpoint."""
return {"status": "ok"}
"""Simple healthcheck endpoint."""
return {"status": "ok"}

View file

@ -4,12 +4,14 @@ Configurations for the auth module
Exports:
- auth_settings: Contains OIDC information
"""
from src.config import CustomBaseSettings
class AuthConfig(CustomBaseSettings):
OIDC_CONFIG: str = ""
OIDC_ISSUER: str = ""
CLIENT_ID: str = ""
OIDC_CONFIG: str = ""
OIDC_ISSUER: str = ""
CLIENT_ID: str = ""
auth_settings = AuthConfig()

View file

@ -1,3 +1,3 @@
"""
Constants for the auth module
"""
"""

View file

@ -7,18 +7,24 @@ Exports:
- org_model_root_claim_body_dependency: org_model: verifies org exists and user is either root or su, gets org from body
- super_admin_dependency: user_model: verifies the user is a super admin
"""
from typing import Annotated
from fastapi import Depends
from src.user.dependencies import user_model_claims_dependency
from src.user.models import User
from src.organisation.dependencies import org_model_query_dependency, org_model_body_dependency
from src.organisation.dependencies import (
org_model_query_dependency,
org_model_body_dependency,
)
from src.organisation.models import Organisation as Org
from src.auth.exceptions import UnauthorizedException
async def org_query_user_claims(org_model: org_model_query_dependency, user_model: user_model_claims_dependency):
async def org_query_user_claims(
org_model: org_model_query_dependency, user_model: user_model_claims_dependency
):
if user_model in org_model.user_rel:
return True
@ -28,7 +34,11 @@ async def org_query_user_claims(org_model: org_model_query_dependency, user_mode
org_query_user_claims_dependency = Annotated[bool, Depends(org_query_user_claims)]
async def org_query_root_claims(user_model: user_model_claims_dependency, org_model: org_model_query_dependency, su_emails: su_list_dependency):
async def org_query_root_claims(
user_model: user_model_claims_dependency,
org_model: org_model_query_dependency,
su_emails: su_list_dependency,
):
if org_model.root_user_id == user_model.id:
return org_model
@ -41,10 +51,16 @@ async def org_query_root_claims(user_model: user_model_claims_dependency, org_mo
raise UnauthorizedException(message="Must be the org's root user")
org_model_root_claim_query_dependency = Annotated[type[Org], Depends(org_query_root_claims)]
org_model_root_claim_query_dependency = Annotated[
type[Org], Depends(org_query_root_claims)
]
async def org_body_root_claims(user_model: user_model_claims_dependency, org_model: org_model_body_dependency, su_emails: su_list_dependency):
async def org_body_root_claims(
user_model: user_model_claims_dependency,
org_model: org_model_body_dependency,
su_emails: su_list_dependency,
):
if org_model.root_user_id == user_model.id:
return org_model
@ -57,21 +73,29 @@ async def org_body_root_claims(user_model: user_model_claims_dependency, org_mod
raise UnauthorizedException(message="Must be the org's root user")
org_model_root_claim_body_dependency = Annotated[type[Org], Depends(org_body_root_claims)]
org_model_root_claim_body_dependency = Annotated[
type[Org], Depends(org_body_root_claims)
]
def get_super_admin_list():
return []
def empty_su_list():
return []
def testing_su_list():
return ["admin@test.com"]
su_list_dependency = Annotated[list[User], Depends(get_super_admin_list)]
async def user_model_super_admin(user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency):
async def user_model_super_admin(
user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency
):
if user_model.email in super_admin_emails:
return user_model

View file

@ -4,6 +4,7 @@ Module specific exceptions for the auth module
Exceptions:
- UnauthorizedException: Takes an optional message string
"""
from typing import Optional
from fastapi import HTTPException, status

View file

@ -1,3 +1,3 @@
"""
Database models for the auth module
"""
"""

View file

@ -4,8 +4,9 @@ Router endpoints for the auth module
Exports:
- router: fastapi.APIRouter
"""
from fastapi import APIRouter
router = APIRouter(
tags=["auth"],
)
)

View file

@ -1,3 +1,3 @@
"""
Pydantic models for the auth module
"""
"""

View file

@ -4,6 +4,7 @@ Module specific business logic for the auth module
Exports:
- claims_dependency: Dict[str, Any] containing OIDC claims and database ID
"""
import json
import requests
@ -25,11 +26,14 @@ from src.database import db_dependency
oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG)
oidc_dependency = Annotated[str, Depends(oidc)]
def get_dev_user():
return {"db_id": 1}
async def get_current_user(oidc_auth_string: oidc_dependency, db: db_dependency) -> dict[str, Any]:
async def get_current_user(
oidc_auth_string: oidc_dependency, db: db_dependency
) -> dict[str, Any]:
config_url = urlopen(auth_settings.OIDC_CONFIG)
config = json.loads(config_url.read())
jwks_uri = config["jwks_uri"]
@ -41,10 +45,7 @@ async def get_current_user(oidc_auth_string: oidc_dependency, db: db_dependency)
"iss": {"essential": True, "value": auth_settings.OIDC_ISSUER},
}
token = jwt.decode(
oidc_auth_string.replace("Bearer ", ""),
jwk_keys
)
token = jwt.decode(oidc_auth_string.replace("Bearer ", ""), jwk_keys)
claims_requests = jwt.JWTClaimsRegistry(**claims_options)

View file

@ -1,3 +1,3 @@
"""
Non-business logic reusable functions and classes for the auth module
"""
"""

View file

@ -16,25 +16,26 @@ from src.constants import Environment
class CustomBaseSettings(BaseSettings):
model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", extra="ignore"
)
model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", extra="ignore"
)
class Config(CustomBaseSettings):
APP_VERSION: str = "0.1"
ENVIRONMENT: Environment = Environment.PRODUCTION
SECRET_KEY: SecretStr = ""
DISABLE_AUTH: bool = False
APP_VERSION: str = "0.1"
ENVIRONMENT: Environment = Environment.PRODUCTION
SECRET_KEY: SecretStr = ""
DISABLE_AUTH: bool = False
CORS_ORIGINS: list[str] = ["*"]
CORS_ORIGINS_REGEX: str | None = None
CORS_HEADERS: list[str] = ["*"]
CORS_ORIGINS: list[str] = ["*"]
CORS_ORIGINS_REGEX: str | None = None
CORS_HEADERS: list[str] = ["*"]
DATABASE_NAME: str = "fastapi-exp"
DATABASE_PORT: str = "5432"
DATABASE_HOSTNAME: str = "localhost"
DATABASE_CREDENTIALS: SecretStr = ":"
DATABASE_NAME: str = "fastapi-exp"
DATABASE_PORT: str = "5432"
DATABASE_HOSTNAME: str = "localhost"
DATABASE_CREDENTIALS: SecretStr = ":"
settings = Config()
@ -43,17 +44,21 @@ DATABASE_PORT = settings.DATABASE_PORT
DATABASE_HOSTNAME = settings.DATABASE_HOSTNAME
DATABASE_CREDENTIALS = settings.DATABASE_CREDENTIALS.get_secret_value()
# this will support special chars for credentials
_DATABASE_CREDENTIAL_USER, _DATABASE_CREDENTIAL_PASSWORD = str(DATABASE_CREDENTIALS).split(":")
_DATABASE_CREDENTIAL_USER, _DATABASE_CREDENTIAL_PASSWORD = str(
DATABASE_CREDENTIALS
).split(":")
_QUOTED_DATABASE_PASSWORD = parse.quote_plus(str(_DATABASE_CREDENTIAL_PASSWORD))
SQLALCHEMY_DATABASE_URI = SecretStr(f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}")
SQLALCHEMY_DATABASE_URI = SecretStr(
f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}"
)
if settings.ENVIRONMENT == Environment.TESTING:
SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:")
SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:")
app_configs: dict[str, Any] = {"title": "App API"}
if settings.ENVIRONMENT.is_deployed:
app_configs["root_path"] = f"/v{settings.APP_VERSION}"
app_configs["root_path"] = f"/v{settings.APP_VERSION}"
if not settings.ENVIRONMENT.is_debug:
app_configs["openapi_url"] = None # hide docs
app_configs["openapi_url"] = None # hide docs

View file

@ -4,6 +4,7 @@ Global constants
Classes:
- Environment(StrEnum): LOCAL, TESTING, STAGING, PRODUCTION
"""
from enum import StrEnum, auto

View file

@ -1,3 +1,3 @@
"""
Configurations for the contact module
"""
"""

View file

@ -1,3 +1,3 @@
"""
Constants for the contact module
"""
"""

View file

@ -1,3 +1,3 @@
"""
Dependencies for the contact module
"""
"""

View file

@ -4,6 +4,7 @@ Exceptions related to the contact module
Exports:
- ContactNotFoundException: Takes an optional contact ID int
"""
from typing import Optional
from fastapi import HTTPException, status
@ -11,7 +12,11 @@ from fastapi import HTTPException, status
class ContactNotFoundException(HTTPException):
def __init__(self, contact_id: Optional[int] = None) -> None:
detail = "Contact not found" if contact_id is None else f"Contact with ID '{contact_id}' was not found."
detail = (
"Contact not found"
if contact_id is None
else f"Contact with ID '{contact_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,

View file

@ -5,6 +5,7 @@ Models:
- Contact: id[pk], email, first_name, last_name, phonenumber, vat_number
street_address, street_address_line_2, post_office_box_number, address_locality, country_code, address_region, postal_code
"""
from sqlalchemy import Column, Integer, String, ForeignKey
from src.database import Base
@ -23,9 +24,11 @@ class Contact(Base):
street_address = Column(String)
street_address_line_2 = Column(String)
post_office_box_number = Column(String, default=None, nullable=True)
locality = Column(String) # Ie City
country_code = Column(String) # Eg GB
locality = Column(String) # Ie City
country_code = Column(String) # Eg GB
address_region = Column(String, default=None, nullable=True)
postal_code = Column(String)
org_id = Column(Integer, ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False)
org_id = Column(
Integer, ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False
)

View file

@ -1,10 +1,11 @@
"""
Router endpoints for the contact module
"""
from fastapi import APIRouter
router = APIRouter(
prefix="/contact",
tags=["contact"],
)
)

View file

@ -5,6 +5,7 @@ Models:
- ContactAddress
- ContactModel: Contains ContactAddress as a property
"""
from typing import Optional
from pydantic import EmailStr, ConfigDict

View file

@ -1,3 +1,3 @@
"""
Module specific business logic for the contact module
"""
"""

View file

@ -1,3 +1,3 @@
"""
Non-business logic reusable functions and classes for the contact module
"""
"""

View file

@ -5,6 +5,7 @@ Exports:
- db_dependency
- Base (sqlalchemy base model)
"""
from typing import Annotated
from sqlalchemy import create_engine, StaticPool
from sqlalchemy.orm import DeclarativeBase, sessionmaker, Session
@ -16,7 +17,11 @@ from src.config import SQLALCHEMY_DATABASE_URI, settings as global_settings
if global_settings.ENVIRONMENT == Environment.TESTING:
connect_args = {"check_same_thread": False}
engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value(), connect_args=connect_args, poolclass=StaticPool)
engine = create_engine(
SQLALCHEMY_DATABASE_URI.get_secret_value(),
connect_args=connect_args,
poolclass=StaticPool,
)
else:
engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value())
@ -36,5 +41,7 @@ def get_db():
db_dependency = Annotated[Session, Depends(get_db)]
class Base(DeclarativeBase):
pass

View file

@ -5,6 +5,7 @@ Exports:
- UnprocessableContentException
- ConflictException
"""
from typing import Optional
from fastapi import HTTPException, status

View file

@ -1,3 +1,3 @@
"""
Configurations for the IAM module
"""
"""

View file

@ -1,3 +1,3 @@
"""
Constants for the IAM module
"""
"""

View file

@ -6,6 +6,7 @@ Exports:
- group_model_body_dependency: group_model: Gets group model from db, if it exists. Uses group_id from request body.
- perm_model_body_dependency: perm_model: Gets perm model from db, if it exists. Uses perm_id from request body.
"""
from typing import Annotated, Optional
from fastapi import Depends, Query
@ -17,17 +18,22 @@ from src.iam.exceptions import GroupNotFoundException, PermNotFoundException
from src.iam.schemas import GroupIDMixin, PermIDMixin
def get_group_model_query(db: db_dependency, group_id: Annotated[int, Query(gt=0)]) -> type[Group]:
def get_group_model_query(
db: db_dependency, group_id: Annotated[int, Query(gt=0)]
) -> type[Group]:
group_model = db.get(Group, group_id)
if group_model is None:
raise GroupNotFoundException(group_id)
return group_model
group_model_query_dependency = Annotated[type[Group], Depends(get_group_model_query)]
def get_group_model_body(db: db_dependency, request_model: Optional[GroupIDMixin] = None) -> type[Group]:
def get_group_model_body(
db: db_dependency, request_model: Optional[GroupIDMixin] = None
) -> type[Group]:
group_id = getattr(request_model, "group_id", None)
if group_id is None:
raise GroupNotFoundException()
@ -37,10 +43,13 @@ def get_group_model_body(db: db_dependency, request_model: Optional[GroupIDMixin
return group_model
group_model_body_dependency = Annotated[type[Group], Depends(get_group_model_body)]
def get_perm_model_body(db: db_dependency, request_model: Optional[PermIDMixin] = None) -> type[Permission]:
def get_perm_model_body(
db: db_dependency, request_model: Optional[PermIDMixin] = None
) -> type[Permission]:
perm_id = getattr(request_model, "permission_id", None)
if perm_id is None:
raise PermNotFoundException
@ -50,4 +59,5 @@ def get_perm_model_body(db: db_dependency, request_model: Optional[PermIDMixin]
return perm_model
perm_model_body_dependency = Annotated[type[Permission], Depends(get_perm_model_body)]

View file

@ -5,6 +5,7 @@ Exceptions:
- GroupNotFoundException: Takes an optional group_id int
- PermNotFoundException: Takes an optional perm_id int
"""
from typing import Optional
from fastapi import HTTPException, status
@ -12,7 +13,11 @@ from fastapi import HTTPException, status
class GroupNotFoundException(HTTPException):
def __init__(self, group_id: Optional[int] = None) -> None:
detail = "Group not found" if group_id is None else f"User with ID '{group_id}' was not found."
detail = (
"Group not found"
if group_id is None
else f"User with ID '{group_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
@ -21,7 +26,11 @@ class GroupNotFoundException(HTTPException):
class PermNotFoundException(HTTPException):
def __init__(self, perm_id: Optional[int] = None) -> None:
detail = "Permission not found" if perm_id is None else f"User with ID '{perm_id}' was not found."
detail = (
"Permission not found"
if perm_id is None
else f"User with ID '{perm_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,

View file

@ -17,6 +17,7 @@ Models:
- UserGroups:
- org_id[FK][PK], user_id[FK][PK], group_id[FK][PK]
"""
from sqlalchemy import Column, Integer, String, ForeignKey, UniqueConstraint
from sqlalchemy.orm import relationship
@ -32,7 +33,9 @@ class Permission(Base):
service_id = Column(Integer, ForeignKey("service.id", ondelete="CASCADE"))
UniqueConstraint("service_id", "resource", "action", name="uniq_permission_resource_and_action")
UniqueConstraint(
"service_id", "resource", "action", name="uniq_permission_resource_and_action"
)
service_rel = relationship("Service", foreign_keys=[service_id])
@ -41,13 +44,10 @@ class Permission(Base):
return self.service_rel.name
group_rel = relationship(
"Group",
secondary="group_permissions",
back_populates="permission_rel"
"Group", secondary="group_permissions", back_populates="permission_rel"
)
class Group(Base):
__tablename__ = "group"
id = Column(Integer, primary_key=True)
@ -55,28 +55,30 @@ class Group(Base):
org_id = Column(Integer, ForeignKey("organisation.id", ondelete="CASCADE"))
user_rel = relationship(
"User",
secondary="user_groups",
back_populates="group_rel"
)
user_rel = relationship("User", secondary="user_groups", back_populates="group_rel")
org_rel = relationship("Organisation", back_populates="group_rel")
permission_rel = relationship(
"Permission",
secondary="group_permissions",
back_populates="group_rel"
"Permission", secondary="group_permissions", back_populates="group_rel"
)
class GroupPermissions(Base):
__tablename__ = "group_permissions"
group_id = Column(Integer, ForeignKey("group.id", ondelete="CASCADE"), primary_key=True)
permission_id = Column(Integer, ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True)
group_id = Column(
Integer, ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
)
permission_id = Column(
Integer, ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True
)
class UserGroups(Base):
__tablename__ = "user_groups"
user_id = Column(Integer, ForeignKey("user.id", ondelete="CASCADE"), primary_key=True)
group_id = Column(Integer, ForeignKey("group.id", ondelete="CASCADE"), primary_key=True)
user_id = Column(
Integer, ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
)
group_id = Column(
Integer, ForeignKey("group.id", ondelete="CASCADE"), primary_key=True
)

View file

@ -15,6 +15,7 @@ Endpoints:
- [DELETE](/iam/permission): [super admin]: Removes a permission
- [GET](/iam/permissions/search): [root user]: Returns a list of permissions matching a filter(service|resource|action)
"""
from fastapi import APIRouter, status
from sqlalchemy.exc import IntegrityError
from psycopg import errors
@ -25,21 +26,49 @@ from src.database import db_dependency
from src.schemas import ResourceName
from src.auth.exceptions import UnauthorizedException
from src.auth.service import claims_dependency
from src.auth.dependencies import org_model_root_claim_query_dependency, org_model_root_claim_body_dependency, \
super_admin_dependency
from src.auth.dependencies import (
org_model_root_claim_query_dependency,
org_model_root_claim_body_dependency,
super_admin_dependency,
)
from src.user.models import User
from src.user.dependencies import user_model_body_dependency
from src.organisation.models import Organisation as Org
from src.service.models import Service
from src.iam.service import service_key_dependency
from src.iam.models import Permission as Perm, GroupPermissions as GPerms, Group, UserGroups
from src.iam.dependencies import group_model_query_dependency, group_model_body_dependency, perm_model_body_dependency
from src.iam.schemas import IAMGetGroupPermissionsResponse, IAMGetGroupUsersResponse, IAMPostGroupRequest, \
GroupSchema, IAMPostGroupResponse, IAMPutGroupPermissionRequest, IAMPutGroupPermissionResponse, \
IAMPutGroupUserRequest, IAMPutGroupUserResponse, IAMDeleteGroupPermissionRequest, IAMDeleteGroupPermissionResponse, \
IAMDeleteGroupUserRequest, IAMDeleteGroupUserResponse, IAMGetPermissionsResponse, IAMPostPermissionRequest, \
IAMPostPermissionResponse, IAMDeletePermissionRequest, IAMGetPermissionsSearchRequest, IAMGetPermissionsSearchResponse
from src.iam.models import (
Permission as Perm,
GroupPermissions as GPerms,
Group,
UserGroups,
)
from src.iam.dependencies import (
group_model_query_dependency,
group_model_body_dependency,
perm_model_body_dependency,
)
from src.iam.schemas import (
IAMGetGroupPermissionsResponse,
IAMGetGroupUsersResponse,
IAMPostGroupRequest,
GroupSchema,
IAMPostGroupResponse,
IAMPutGroupPermissionRequest,
IAMPutGroupPermissionResponse,
IAMPutGroupUserRequest,
IAMPutGroupUserResponse,
IAMDeleteGroupPermissionRequest,
IAMDeleteGroupPermissionResponse,
IAMDeleteGroupUserRequest,
IAMDeleteGroupUserResponse,
IAMGetPermissionsResponse,
IAMPostPermissionRequest,
IAMPostPermissionResponse,
IAMDeletePermissionRequest,
IAMGetPermissionsSearchRequest,
IAMGetPermissionsSearchResponse,
)
router = APIRouter(
tags=["IAM"],
@ -48,26 +77,32 @@ router = APIRouter(
@router.post("/can_act_on_resource")
async def can_act_on_resource(valid_key: service_key_dependency, db: db_dependency, user_claims: claims_dependency,
rn: ResourceName, action: str) -> bool:
async def can_act_on_resource(
valid_key: service_key_dependency,
db: db_dependency,
user_claims: claims_dependency,
rn: ResourceName,
action: str,
) -> bool:
try:
user_id = user_claims["db_id"]
rn_org = rn.organisation
rn_service = rn.service
rn_resource = rn.resource
result = (db.query(Perm)
.join(Service, Service.id == Perm.service_id)
.join(GPerms, GPerms.permission_id == Perm.id)
.join(Group, Group.id == GPerms.group_id)
.join(Org, Org.id == Group.org_id)
.join(UserGroups, UserGroups.group_id == Group.id)
.join(User, User.id == UserGroups.user_id)
.filter(User.id == user_id)
.filter(Org.name == rn_org)
.filter(Service.name == rn_service)
.filter(Perm.resource == rn_resource)
.filter(Perm.action == action)
result = (
db.query(Perm)
.join(Service, Service.id == Perm.service_id)
.join(GPerms, GPerms.permission_id == Perm.id)
.join(Group, Group.id == GPerms.group_id)
.join(Org, Org.id == Group.org_id)
.join(UserGroups, UserGroups.group_id == Group.id)
.join(User, User.id == UserGroups.user_id)
.filter(User.id == user_id)
.filter(Org.name == rn_org)
.filter(Service.name == rn_service)
.filter(Perm.resource == rn_resource)
.filter(Perm.action == action)
).first()
if result:
@ -79,21 +114,31 @@ async def can_act_on_resource(valid_key: service_key_dependency, db: db_dependen
@router.get("/group/permissions", response_model=IAMGetGroupPermissionsResponse)
async def get_group_permissions(group_model: group_model_query_dependency, org_model: org_model_root_claim_query_dependency):
async def get_group_permissions(
group_model: group_model_query_dependency,
org_model: org_model_root_claim_query_dependency,
):
if group_model.org_id != org_model.id:
raise UnauthorizedException("Group does not belong to this organization")
return {"permissions": group_model.permission_rel}
@router.get("/group/users", response_model=IAMGetGroupUsersResponse)
async def get_group_users(group_model: group_model_query_dependency, org_model: org_model_root_claim_query_dependency):
async def get_group_users(
group_model: group_model_query_dependency,
org_model: org_model_root_claim_query_dependency,
):
if group_model.org_id != org_model.id:
raise UnauthorizedException("Group does not belong to this organization")
return {"users": group_model.user_rel}
@router.post("/group", response_model=IAMPostGroupResponse)
async def create_group(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMPostGroupRequest):
async def create_group(
db: db_dependency,
org_model: org_model_root_claim_body_dependency,
request_model: IAMPostGroupRequest,
):
group_model = Group(name=request_model.name, org_id=org_model.id)
db.add(group_model)
@ -101,9 +146,9 @@ async def create_group(db: db_dependency, org_model: org_model_root_claim_body_d
db.flush()
except IntegrityError as e:
if (
getattr(e.orig, "pgcode", None) == "23505" # Postgres unique violation
or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation
):
getattr(e.orig, "pgcode", None) == "23505" # Postgres unique violation
or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation
):
raise ConflictException("Group with this name already exists")
response = GroupSchema(**group_model.__dict__)
db.commit()
@ -111,7 +156,13 @@ async def create_group(db: db_dependency, org_model: org_model_root_claim_body_d
@router.put("/group/permission", response_model=IAMPutGroupPermissionResponse)
async def add_group_permission(db: db_dependency, group_model: group_model_body_dependency, perm_model: perm_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMPutGroupPermissionRequest):
async def add_group_permission(
db: db_dependency,
group_model: group_model_body_dependency,
perm_model: perm_model_body_dependency,
org_model: org_model_root_claim_body_dependency,
request_model: IAMPutGroupPermissionRequest,
):
if group_model.org_id != org_model.id:
raise UnauthorizedException("Group does not belong to this organization")
@ -121,13 +172,22 @@ async def add_group_permission(db: db_dependency, group_model: group_model_body_
group_model.permission_rel.append(perm_model)
db.flush()
response = IAMPutGroupPermissionResponse(group=GroupSchema(**group_model.__dict__), permissions=group_model.permission_rel)
response = IAMPutGroupPermissionResponse(
group=GroupSchema(**group_model.__dict__),
permissions=group_model.permission_rel,
)
db.commit()
return response
@router.put("/group/user", response_model=IAMPutGroupUserResponse)
async def add_group_user(db: db_dependency, group_model: group_model_body_dependency, user_model: user_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMPutGroupUserRequest):
async def add_group_user(
db: db_dependency,
group_model: group_model_body_dependency,
user_model: user_model_body_dependency,
org_model: org_model_root_claim_body_dependency,
request_model: IAMPutGroupUserRequest,
):
if group_model.org_id != org_model.id:
raise UnauthorizedException("Group does not belong to this organization")
@ -136,46 +196,70 @@ async def add_group_user(db: db_dependency, group_model: group_model_body_depend
group_model.user_rel.append(user_model)
db.flush()
response = IAMPutGroupUserResponse(group=GroupSchema(**group_model.__dict__), users=group_model.user_rel)
response = IAMPutGroupUserResponse(
group=GroupSchema(**group_model.__dict__), users=group_model.user_rel
)
db.commit()
return response
@router.delete("/group/permissions")
async def remove_group_permissions(db: db_dependency, group_model: group_model_body_dependency, perm_model: perm_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMDeleteGroupPermissionRequest):
async def remove_group_permissions(
db: db_dependency,
group_model: group_model_body_dependency,
perm_model: perm_model_body_dependency,
org_model: org_model_root_claim_body_dependency,
request_model: IAMDeleteGroupPermissionRequest,
):
if group_model.org_id != org_model.id:
raise UnauthorizedException("Group does not belong to this organization")
group_model.permission_rel.remove(perm_model)
db.flush()
response = IAMDeleteGroupPermissionResponse(group=GroupSchema(**group_model.__dict__),
permissions=group_model.permission_rel)
response = IAMDeleteGroupPermissionResponse(
group=GroupSchema(**group_model.__dict__),
permissions=group_model.permission_rel,
)
db.commit()
return response
@router.delete("/group/user")
async def remove_group_user(db: db_dependency, group_model: group_model_body_dependency, user_model: user_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMDeleteGroupUserRequest):
async def remove_group_user(
db: db_dependency,
group_model: group_model_body_dependency,
user_model: user_model_body_dependency,
org_model: org_model_root_claim_body_dependency,
request_model: IAMDeleteGroupUserRequest,
):
if group_model.org_id != org_model.id:
raise UnauthorizedException("Group does not belong to this organization")
user_model.group_rel.remove(group_model)
db.flush()
response = IAMDeleteGroupUserResponse(group=GroupSchema(**group_model.__dict__), users=group_model.user_rel)
response = IAMDeleteGroupUserResponse(
group=GroupSchema(**group_model.__dict__), users=group_model.user_rel
)
db.commit()
return response
@router.get("/permissions", response_model=IAMGetPermissionsResponse)
async def get_permissions(db: db_dependency, org_model: org_model_root_claim_query_dependency):
async def get_permissions(
db: db_dependency, org_model: org_model_root_claim_query_dependency
):
permission_models = db.query(Perm).all()
return {"permissions": permission_models}
@router.post("/permission", response_model=IAMPostPermissionResponse)
async def create_new_permission(db: db_dependency, su: super_admin_dependency, request_model: IAMPostPermissionRequest):
async def create_new_permission(
db: db_dependency,
su: super_admin_dependency,
request_model: IAMPostPermissionRequest,
):
service_model = db.get(Service, request_model.service_id)
if service_model is None:
raise ServiceNotFoundException(service_id=request_model.service_id)
@ -186,29 +270,46 @@ async def create_new_permission(db: db_dependency, su: super_admin_dependency, r
if isinstance(e.orig, errors.UniqueViolation):
raise ConflictException(message="Permission already exists")
db.flush()
response = {"service_name": perm_model.service_name, "resource": perm_model.resource, "action": perm_model.action}
response = {
"service_name": perm_model.service_name,
"resource": perm_model.resource,
"action": perm_model.action,
}
db.commit()
return {"permission": response}
@router.delete("/permission", status_code=status.HTTP_204_NO_CONTENT)
async def delete_permission(db: db_dependency, su: super_admin_dependency, perm_model: perm_model_body_dependency, request_model: IAMDeletePermissionRequest):
async def delete_permission(
db: db_dependency,
su: super_admin_dependency,
perm_model: perm_model_body_dependency,
request_model: IAMDeletePermissionRequest,
):
db.delete(perm_model)
db.commit()
@router.post("/permissions/search", response_model=IAMGetPermissionsSearchResponse)
async def post_permissions(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMGetPermissionsSearchRequest):
async def post_permissions(
db: db_dependency,
org_model: org_model_root_claim_body_dependency,
request_model: IAMGetPermissionsSearchRequest,
):
permission_query = db.query(Perm)
if request_model.service_id is not None:
permission_query = permission_query.filter(Perm.service_id == request_model.service_id)
permission_query = permission_query.filter(
Perm.service_id == request_model.service_id
)
if request_model.resource is not None:
permission_query = permission_query.filter(Perm.resource == request_model.resource)
permission_query = permission_query.filter(
Perm.resource == request_model.resource
)
if request_model.action is not None:
permission_query = permission_query.filter(Perm.action == request_model. action)
permission_query = permission_query.filter(Perm.action == request_model.action)
permission_models = permission_query.all()

View file

@ -6,6 +6,7 @@ Models follow the nomenclature of:
- Mixins: "<Attribute>Mixin"
- Models: "<Module><Method><Resource><Opt:Resource><Direction>" ie "IAMGetGroupPermissionsResponse"
"""
from typing import Optional, Annotated
from pydantic import EmailStr, ConfigDict, Field
@ -24,6 +25,7 @@ class UserSchema(CustomBaseModel):
last_name: str
email: EmailStr
class PermissionSchema(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore")
@ -31,73 +33,94 @@ class PermissionSchema(CustomBaseModel):
resource: str
action: str
class GroupSchema(CustomBaseModel):
id: int
name: str
class GroupIDMixin(CustomBaseModel):
group_id: int = Field(gt=0)
class PermIDMixin(CustomBaseModel):
permission_id: int = Field(gt=0)
class IAMGetGroupPermissionsResponse(CustomBaseModel):
permissions: list[PermissionSchema]
class IAMGetGroupUsersResponse(CustomBaseModel):
users : list[UserSchema]
users: list[UserSchema]
class IAMPostGroupRequest(OrgIDMixin):
name: str = Field(min_length=3)
class IAMPostGroupResponse(CustomBaseModel):
group: GroupSchema
class IAMPutGroupPermissionRequest(GroupIDMixin, PermIDMixin, OrgIDMixin):
pass
class IAMPutGroupPermissionResponse(CustomBaseModel):
group: GroupSchema
permissions: list[PermissionSchema]
class IAMPutGroupUserRequest(GroupIDMixin, UserIDMixin, OrgIDMixin):
pass
class IAMPutGroupUserResponse(CustomBaseModel):
group: GroupSchema
users: list[UserSchema]
class IAMDeleteGroupPermissionRequest(GroupIDMixin, PermIDMixin):
pass
class IAMDeleteGroupPermissionResponse(CustomBaseModel):
group: GroupSchema
permissions: list[PermissionSchema]
class IAMDeleteGroupUserRequest(GroupIDMixin, UserIDMixin):
pass
class IAMDeleteGroupUserResponse(CustomBaseModel):
group: GroupSchema
users: list[UserSchema]
class IAMGetPermissionsResponse(CustomBaseModel):
permissions: list[PermissionSchema]
class IAMPostPermissionRequest(ServiceIDMixin):
resource: str
action: str
class IAMPostPermissionResponse(CustomBaseModel):
permission: PermissionSchema
class IAMDeletePermissionRequest(PermIDMixin):
pass
class IAMGetPermissionsSearchRequest(OrgIDMixin):
service_id: Annotated[int | None, Field(gt=0)] = None
resource: Optional[str] = None
action: Optional[str] = None
class IAMGetPermissionsSearchResponse(CustomBaseModel):
permissions: list[PermissionSchema]

View file

@ -4,6 +4,7 @@ Business logic reusable functions related to IAM
Exports:
- service_key_dependency: bool: verifies request headers contain the correct api key for the service
"""
from typing import Annotated
from src.service.models import Service
@ -19,10 +20,16 @@ def valid_service_key(db: db_dependency, request: Request, rn: ResourceName) ->
if not api_key:
raise UnauthorizedException("Missing API key")
service = rn.service
result = db.query(Service).filter(Service.name == service).filter(Service.api_key == api_key).first()
result = (
db.query(Service)
.filter(Service.name == service)
.filter(Service.api_key == api_key)
.first()
)
if result is None:
raise UnauthorizedException("Invalid API key")
return True
service_key_dependency = Annotated[bool, Depends(valid_service_key)]

View file

@ -1,6 +1,7 @@
"""
Application root file: Inits the FastAPI application
"""
from contextlib import asynccontextmanager
from typing import AsyncGenerator

View file

@ -1,4 +1,3 @@
"""
Global database models
"""

View file

@ -1,3 +1,3 @@
"""
Configurations for the organisation module
"""
"""

View file

@ -5,6 +5,7 @@ Classes:
- Status(StrEnum): PARTIAL, SUBMITTED, REMEDIATION, APPROVED, REJECTED, REMOVED
- ContactType(StrEnum): BILLING, SECURITY, OWNER
"""
from enum import StrEnum, auto

View file

@ -5,6 +5,7 @@ Exports:
- org_model_query_dependency: org_model: Gets org model from db, if it exists. Uses org_id from query param. Also verifies if the org has been approved.
- org_model_body_dependency: org_model: Gets org model from db, if it exists. Uses org_id from request body. Also verifies if the org has been approved.
"""
from typing import Annotated, Optional
from sqlalchemy.orm import Session
@ -25,25 +26,40 @@ def get_org_model(db: Session, request: Request, org_id: int):
root = "/api/v1"
pre_approval_endpoints = [f"PATCH{root}/org/status", f"PATCH{root}/org/questionnaire", f"GET{root}/org", f"GET{root}/org/contact", f"PATCH{root}/org/contact"]
pre_approval_endpoints = [
f"PATCH{root}/org/status",
f"PATCH{root}/org/questionnaire",
f"GET{root}/org",
f"GET{root}/org/contact",
f"PATCH{root}/org/contact",
]
current_request = f"{request.method}{request.url.path}"
if current_request not in pre_approval_endpoints and org_model.status != OrgStatus.APPROVED:
if (
current_request not in pre_approval_endpoints
and org_model.status != OrgStatus.APPROVED
):
raise AwaitingApprovalException(org_id)
return org_model
def get_org_model_query(db: db_dependency, request: Request, org_id: Annotated[int, Query(gt=0)]) -> type[Org]:
def get_org_model_query(
db: db_dependency, request: Request, org_id: Annotated[int, Query(gt=0)]
) -> type[Org]:
return get_org_model(db, request, org_id)
org_model_query_dependency = Annotated[type[Org], Depends(get_org_model_query)]
def get_org_model_body(db: db_dependency, request: Request, request_model: OrgIDMixin) -> type[Org]:
def get_org_model_body(
db: db_dependency, request: Request, request_model: OrgIDMixin
) -> type[Org]:
org_id: Optional[int] = getattr(request_model, "organisation_id", None)
if org_id is None:
raise OrgNotFoundException
return get_org_model(db, request, org_id)
org_model_body_dependency = Annotated[type[Org], Depends(get_org_model_body)]

View file

@ -5,6 +5,7 @@ Exceptions:
- OrgNotFoundException: Takes an optional org_id int
- AwaitingApprovalException: Takes an optional org_id int
"""
from typing import Optional
from fastapi import HTTPException, status
@ -12,15 +13,24 @@ from fastapi import HTTPException, status
class OrgNotFoundException(HTTPException):
def __init__(self, org_id: Optional[int] = None) -> None:
detail = "Organisation not found" if org_id is None else f"Organisation with ID '{org_id}' was not found."
detail = (
"Organisation not found"
if org_id is None
else f"Organisation with ID '{org_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,
)
class AwaitingApprovalException(HTTPException):
def __init__(self, org_id: Optional[int] = None) -> None:
detail = "Organisation has not been approved." if org_id is None else f"Organisation with ID '{org_id}' has not been approved."
detail = (
"Organisation has not been approved."
if org_id is None
else f"Organisation with ID '{org_id}' has not been approved."
)
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,

View file

@ -13,6 +13,7 @@ Models:
- owner_contact_rel: ORM relationship to Contact with owner_contact FK
- OrgUsers: org_id[FK][PK], user_id[FK][PK]
"""
from sqlalchemy import Column, Integer, String, ForeignKey, JSON
from sqlalchemy.orm import relationship
@ -34,9 +35,7 @@ class Organisation(Base):
owner_contact_id = Column(Integer, ForeignKey("contact.id"))
user_rel = relationship(
"User",
secondary="orgusers",
back_populates="organisation_rel"
"User", secondary="orgusers", back_populates="organisation_rel"
)
group_rel = relationship("Group", back_populates="org_rel")
@ -54,5 +53,9 @@ class Organisation(Base):
class OrgUsers(Base):
__tablename__ = "orgusers"
org_id = Column(Integer, ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True)
user_id = Column(Integer, ForeignKey("user.id", ondelete="CASCADE"), primary_key=True)
org_id = Column(
Integer, ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True
)
user_id = Column(
Integer, ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
)

View file

@ -15,6 +15,7 @@ Endpoints:
- [GET](/org/contact): [root user]: Gets the (contact_type) contact for an org(id)
- [PATCH](/org/contact): [root user]: Updates the (contact_type) contact for an org(id). Any number of details can be changed.
"""
from typing import Annotated
from fastapi import APIRouter, status
@ -28,17 +29,40 @@ from src.contact.models import Contact
from src.contact.schemas import ContactAddress
from src.contact.exceptions import ContactNotFoundException
from src.database import db_dependency
from src.user.dependencies import user_model_body_dependency, user_model_claims_dependency
from src.auth.dependencies import super_admin_dependency, org_model_root_claim_query_dependency, org_model_root_claim_body_dependency
from src.user.dependencies import (
user_model_body_dependency,
user_model_claims_dependency,
)
from src.auth.dependencies import (
super_admin_dependency,
org_model_root_claim_query_dependency,
org_model_root_claim_body_dependency,
)
from src.organisation.dependencies import org_model_body_dependency
from src.organisation.constants import ContactType
from src.organisation.models import Organisation as Org
from src.organisation.schemas import OrgPostOrgRequest, OrgPatchQuestionnaireRequest, OrgPatchStatusRequest, \
OrgPatchContactRequest, \
OrgPostUserRequest, OrgGetUserResponse, OrgGetContactResponse, OrgGetOrgResponse, OrgPatchRootRequest, \
OrgGetGroupResponse, OrgDeleteUserRequest, OrgDeleteOrgRequest, OrgPostOrgResponse, OrgPatchQuestionnaireResponse, \
OrgPatchStatusResponse, OrgPostUserResponse, OrgPatchRootResponse, Questionnaire, OrgPatchContactResponse
from src.organisation.schemas import (
OrgPostOrgRequest,
OrgPatchQuestionnaireRequest,
OrgPatchStatusRequest,
OrgPatchContactRequest,
OrgPostUserRequest,
OrgGetUserResponse,
OrgGetContactResponse,
OrgGetOrgResponse,
OrgPatchRootRequest,
OrgGetGroupResponse,
OrgDeleteUserRequest,
OrgDeleteOrgRequest,
OrgPostOrgResponse,
OrgPatchQuestionnaireResponse,
OrgPatchStatusResponse,
OrgPostUserResponse,
OrgPatchRootResponse,
Questionnaire,
OrgPatchContactResponse,
)
router = APIRouter(
prefix="/org",
@ -46,16 +70,22 @@ router = APIRouter(
)
@router.get("",
summary="Get org details by ID.",
response_model=OrgGetOrgResponse,
status_code=status.HTTP_200_OK,
responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
status.HTTP_404_NOT_FOUND: {"description": "Organisation not found"},
status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Missing or invalid org_id query parameter"},
status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."},
})
@router.get(
"",
summary="Get org details by ID.",
response_model=OrgGetOrgResponse,
status_code=status.HTTP_200_OK,
responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
status.HTTP_404_NOT_FOUND: {"description": "Organisation not found"},
status.HTTP_422_UNPROCESSABLE_CONTENT: {
"description": "Missing or invalid org_id query parameter"
},
status.HTTP_401_UNAUTHORIZED: {
"description": "Not authorised. Must be org root user."
},
},
)
async def get_org_by_id(org_model: org_model_root_claim_query_dependency):
"""
Returns organisation details including key member email addresses
@ -68,23 +98,35 @@ async def get_org_by_id(org_model: org_model_root_claim_query_dependency):
"billing_contact": org_model.billing_contact_rel.email,
"security_contact": org_model.security_contact_rel.email,
"root_user": org_model.root_user_email,
"intake_questionnaire": org_model.intake_questionnaire
"intake_questionnaire": org_model.intake_questionnaire,
}
return response
@router.post("",
summary="Create new organisation.",
status_code=status.HTTP_201_CREATED,
response_model=OrgPostOrgResponse,
responses={
status.HTTP_201_CREATED: {"description": "Successfully created organisation."},
status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."},
status.HTTP_401_UNAUTHORIZED: {"description": "User must be logged in with OIDC to create organisation."},
status.HTTP_409_CONFLICT: {"description": "Organisation with this name already exists."},
})
async def create_org(db: db_dependency, user_model: user_model_claims_dependency, request_model: OrgPostOrgRequest):
@router.post(
"",
summary="Create new organisation.",
status_code=status.HTTP_201_CREATED,
response_model=OrgPostOrgResponse,
responses={
status.HTTP_201_CREATED: {"description": "Successfully created organisation."},
status.HTTP_422_UNPROCESSABLE_CONTENT: {
"description": "Invalid data in request."
},
status.HTTP_401_UNAUTHORIZED: {
"description": "User must be logged in with OIDC to create organisation."
},
status.HTTP_409_CONFLICT: {
"description": "Organisation with this name already exists."
},
},
)
async def create_org(
db: db_dependency,
user_model: user_model_claims_dependency,
request_model: OrgPostOrgRequest,
):
"""
Creates a new organisation with optional questionnaire (to be completed or submitted).
ALl organisations are given the "partial" status on creation. See update_questionnaire() for more details.
@ -102,11 +144,17 @@ async def create_org(db: db_dependency, user_model: user_model_claims_dependency
db.flush()
except IntegrityError as e:
if isinstance(e.orig, UniqueViolation):
raise ConflictException(message="Organisation with this name already exists")
raise ConflictException(
message="Organisation with this name already exists"
)
# Adds currently logged-in user to org users list and sets them as root_user
org_model.user_rel.append(user_model)
org_model.root_user_rel = user_model
for contact_type in ["billing_contact_id", "security_contact_id", "owner_contact_id"]:
for contact_type in [
"billing_contact_id",
"security_contact_id",
"owner_contact_id",
]:
contact_model = Contact(org_id=org_model.id)
db.add(contact_model)
db.flush()
@ -116,16 +164,26 @@ async def create_org(db: db_dependency, user_model: user_model_claims_dependency
return response
@router.patch("/questionnaire",
summary="Update questionnaire.",
status_code=status.HTTP_200_OK,
response_model=OrgPatchQuestionnaireResponse,
responses={
status.HTTP_200_OK: {"description": "Successfully updated questionnaire."},
status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."},
status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."},
})
async def update_questionnaire(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: OrgPatchQuestionnaireRequest):
@router.patch(
"/questionnaire",
summary="Update questionnaire.",
status_code=status.HTTP_200_OK,
response_model=OrgPatchQuestionnaireResponse,
responses={
status.HTTP_200_OK: {"description": "Successfully updated questionnaire."},
status.HTTP_422_UNPROCESSABLE_CONTENT: {
"description": "Invalid data in request."
},
status.HTTP_401_UNAUTHORIZED: {
"description": "Not authorised. Must be org root user."
},
},
)
async def update_questionnaire(
db: db_dependency,
org_model: org_model_root_claim_body_dependency,
request_model: OrgPatchQuestionnaireRequest,
):
"""
Route for updating questionnaire.
The partial bool allows for submission of partially completed questionnaire and/or
@ -150,16 +208,29 @@ async def update_questionnaire(db: db_dependency, org_model: org_model_root_clai
return response
@router.patch("/status",
summary="Update status of organisation.",
status_code=status.HTTP_200_OK,
response_model=OrgPatchStatusResponse,
responses={
status.HTTP_200_OK: {"description": "Successfully updated organisation status."},
status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."},
status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be super admin."},
})
async def update_status(db: db_dependency, org_model: org_model_body_dependency, su: super_admin_dependency, request_model: OrgPatchStatusRequest):
@router.patch(
"/status",
summary="Update status of organisation.",
status_code=status.HTTP_200_OK,
response_model=OrgPatchStatusResponse,
responses={
status.HTTP_200_OK: {
"description": "Successfully updated organisation status."
},
status.HTTP_422_UNPROCESSABLE_CONTENT: {
"description": "Invalid data in request."
},
status.HTTP_401_UNAUTHORIZED: {
"description": "Not authorised. Must be super admin."
},
},
)
async def update_status(
db: db_dependency,
org_model: org_model_body_dependency,
su: super_admin_dependency,
request_model: OrgPatchStatusRequest,
):
"""
Sets an organisation's status. This is the endpoint for approving or denying an organisation after reviewing the questionnaire.
"""
@ -170,33 +241,57 @@ async def update_status(db: db_dependency, org_model: org_model_body_dependency,
return response
@router.get("/users",
summary="Get email addresses of users of the organisation.",
status_code=status.HTTP_200_OK,
response_model=OrgGetUserResponse,
responses={
status.HTTP_200_OK: {"description": "Successful retrieval of users."},
status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Org ID missing or invalid."},
status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."},
})
@router.get(
"/users",
summary="Get email addresses of users of the organisation.",
status_code=status.HTTP_200_OK,
response_model=OrgGetUserResponse,
responses={
status.HTTP_200_OK: {"description": "Successful retrieval of users."},
status.HTTP_422_UNPROCESSABLE_CONTENT: {
"description": "Org ID missing or invalid."
},
status.HTTP_401_UNAUTHORIZED: {
"description": "Not authorised. Must be org root user."
},
},
)
async def get_users(org_model: org_model_root_claim_query_dependency):
"""
Returns a list of the email addresses of all users of the organisation.
"""
return {"users": [user.email for user in org_model.user_rel], "organisation": org_model}
return {
"users": [user.email for user in org_model.user_rel],
"organisation": org_model,
}
@router.post("/user",
summary="Add user to the organisation.",
status_code=status.HTTP_200_OK,
response_model=OrgPostUserResponse,
responses={
status.HTTP_200_OK: {"description": "Successfully added user to the organisation."},
status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."},
status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."},
status.HTTP_409_CONFLICT: {"description": "User is already a member of the organisation."},
})
async def add_user_to_org(db: db_dependency, org_model: org_model_root_claim_body_dependency, user_model: user_model_body_dependency, request_model: OrgPostUserRequest):
@router.post(
"/user",
summary="Add user to the organisation.",
status_code=status.HTTP_200_OK,
response_model=OrgPostUserResponse,
responses={
status.HTTP_200_OK: {
"description": "Successfully added user to the organisation."
},
status.HTTP_401_UNAUTHORIZED: {
"description": "Not authorised. Must be org root user."
},
status.HTTP_422_UNPROCESSABLE_CONTENT: {
"description": "Invalid data in request."
},
status.HTTP_409_CONFLICT: {
"description": "User is already a member of the organisation."
},
},
)
async def add_user_to_org(
db: db_dependency,
org_model: org_model_root_claim_body_dependency,
user_model: user_model_body_dependency,
request_model: OrgPostUserRequest,
):
"""
Adds a user to the organisation.
"""
@ -209,15 +304,28 @@ async def add_user_to_org(db: db_dependency, org_model: org_model_root_claim_bod
return response
@router.delete("",
summary="Delete organisation from the hub.",
status_code=status.HTTP_204_NO_CONTENT,
responses={
status.HTTP_204_NO_CONTENT: {"description": "Successfully deleted organisation."},
status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be super admin."},
status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Org ID missing or invalid."},
})
async def delete_organisation_by_id(db: db_dependency, org_model: org_model_body_dependency, su: super_admin_dependency, request_model: OrgDeleteOrgRequest):
@router.delete(
"",
summary="Delete organisation from the hub.",
status_code=status.HTTP_204_NO_CONTENT,
responses={
status.HTTP_204_NO_CONTENT: {
"description": "Successfully deleted organisation."
},
status.HTTP_401_UNAUTHORIZED: {
"description": "Not authorised. Must be super admin."
},
status.HTTP_422_UNPROCESSABLE_CONTENT: {
"description": "Org ID missing or invalid."
},
},
)
async def delete_organisation_by_id(
db: db_dependency,
org_model: org_model_body_dependency,
su: super_admin_dependency,
request_model: OrgDeleteOrgRequest,
):
"""
Removes an organisation from the hub.
"""
@ -225,37 +333,59 @@ async def delete_organisation_by_id(db: db_dependency, org_model: org_model_body
db.commit()
@router.patch("/root_user",
summary="Update the root user of the organisation.",
status_code=status.HTTP_200_OK,
response_model=OrgPatchRootResponse,
responses={
status.HTTP_200_OK: {"description": "Successfully updated root user."},
status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."},
status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be super admin."},
})
async def update_root_user(db: db_dependency, org_model: org_model_body_dependency, user_model: user_model_body_dependency, su: super_admin_dependency, request_model: OrgPatchRootRequest):
@router.patch(
"/root_user",
summary="Update the root user of the organisation.",
status_code=status.HTTP_200_OK,
response_model=OrgPatchRootResponse,
responses={
status.HTTP_200_OK: {"description": "Successfully updated root user."},
status.HTTP_422_UNPROCESSABLE_CONTENT: {
"description": "Invalid data in request."
},
status.HTTP_401_UNAUTHORIZED: {
"description": "Not authorised. Must be super admin."
},
},
)
async def update_root_user(
db: db_dependency,
org_model: org_model_body_dependency,
user_model: user_model_body_dependency,
su: super_admin_dependency,
request_model: OrgPatchRootRequest,
):
"""
Promotes an existing organisation user to the root user, giving them full control of the org.
"""
if user_model not in org_model.user_rel:
raise UnprocessableContentException(message="This user does not belong to your organisation.")
raise UnprocessableContentException(
message="This user does not belong to your organisation."
)
org_model.root_user_rel = user_model
db.flush()
response = OrgPatchRootResponse(name=org_model.name, root_user_email=org_model.root_user_email)
response = OrgPatchRootResponse(
name=org_model.name, root_user_email=org_model.root_user_email
)
db.commit()
return response
@router.get("/groups",
summary="Get all organisation IAM groups.",
status_code=status.HTTP_200_OK,
response_model=OrgGetGroupResponse,
responses={
status.HTTP_200_OK: {"description": "Successful retrieval of IAM groups."},
status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Org ID missing or invalid."},
status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."},
})
@router.get(
"/groups",
summary="Get all organisation IAM groups.",
status_code=status.HTTP_200_OK,
response_model=OrgGetGroupResponse,
responses={
status.HTTP_200_OK: {"description": "Successful retrieval of IAM groups."},
status.HTTP_422_UNPROCESSABLE_CONTENT: {
"description": "Org ID missing or invalid."
},
status.HTTP_401_UNAUTHORIZED: {
"description": "Not authorised. Must be org root user."
},
},
)
async def get_org_groups(org_model: org_model_root_claim_query_dependency):
"""
Returns a list of the names of all IAM groups created by the organisation.
@ -263,15 +393,26 @@ async def get_org_groups(org_model: org_model_root_claim_query_dependency):
return {"groups": [group.name for group in org_model.group_rel]}
@router.delete("/user",
summary="Remove user from organisation.",
status_code=status.HTTP_204_NO_CONTENT,
responses={
status.HTTP_204_NO_CONTENT: {"description": "Successfully removed user."},
status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."},
status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."},
})
async def remove_user_from_org(db: db_dependency, org_model: org_model_root_claim_body_dependency, user_model: user_model_body_dependency, request_model: OrgDeleteUserRequest):
@router.delete(
"/user",
summary="Remove user from organisation.",
status_code=status.HTTP_204_NO_CONTENT,
responses={
status.HTTP_204_NO_CONTENT: {"description": "Successfully removed user."},
status.HTTP_401_UNAUTHORIZED: {
"description": "Not authorised. Must be org root user."
},
status.HTTP_422_UNPROCESSABLE_CONTENT: {
"description": "Invalid data in request."
},
},
)
async def remove_user_from_org(
db: db_dependency,
org_model: org_model_root_claim_body_dependency,
user_model: user_model_body_dependency,
request_model: OrgDeleteUserRequest,
):
"""
Revokes a user's membership in an organisation.
"""
@ -282,16 +423,27 @@ async def remove_user_from_org(db: db_dependency, org_model: org_model_root_clai
db.commit()
@router.get("/contact",
summary="Get contact for organisation.",
status_code=status.HTTP_200_OK,
response_model=OrgGetContactResponse,
responses={
status.HTTP_200_OK: {"description": "Successful retrieval of contact."},
status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."},
status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."},
})
async def get_contact(org_model: org_model_root_claim_query_dependency, contact_type: Annotated[ContactType, Query(description="Must be billing|security|owner")]):
@router.get(
"/contact",
summary="Get contact for organisation.",
status_code=status.HTTP_200_OK,
response_model=OrgGetContactResponse,
responses={
status.HTTP_200_OK: {"description": "Successful retrieval of contact."},
status.HTTP_422_UNPROCESSABLE_CONTENT: {
"description": "Invalid data in request."
},
status.HTTP_401_UNAUTHORIZED: {
"description": "Not authorised. Must be org root user."
},
},
)
async def get_contact(
org_model: org_model_root_claim_query_dependency,
contact_type: Annotated[
ContactType, Query(description="Must be billing|security|owner")
],
):
"""
Gets full details for a contact point at an organisation.
"""
@ -309,21 +461,33 @@ async def get_contact(org_model: org_model_root_claim_query_dependency, contact_
raise ContactNotFoundException()
address = ContactAddress.model_validate(contact_model)
contact_response = ContactModel.model_construct(**contact_model.__dict__, address=address)
contact_response = ContactModel.model_construct(
**contact_model.__dict__, address=address
)
return {"contact": contact_response, "organisation": org_model}
@router.patch("/contact",
summary="Update contact for organisation.",
status_code=status.HTTP_200_OK,
response_model=OrgPatchContactResponse,
responses={
status.HTTP_200_OK: {"description": "Successfully updated contact."},
status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."},
status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."},
})
async def update_contact(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: OrgPatchContactRequest):
@router.patch(
"/contact",
summary="Update contact for organisation.",
status_code=status.HTTP_200_OK,
response_model=OrgPatchContactResponse,
responses={
status.HTTP_200_OK: {"description": "Successfully updated contact."},
status.HTTP_422_UNPROCESSABLE_CONTENT: {
"description": "Invalid data in request."
},
status.HTTP_401_UNAUTHORIZED: {
"description": "Not authorised. Must be org root user."
},
},
)
async def update_contact(
db: db_dependency,
org_model: org_model_root_claim_body_dependency,
request_model: OrgPatchContactRequest,
):
"""
Updates details for a contact point at an organisation.
"""
@ -351,7 +515,9 @@ async def update_contact(db: db_dependency, org_model: org_model_root_claim_body
db.flush()
address = ContactAddress.model_validate(contact_model)
contact_response = ContactModel.model_construct(**contact_model.__dict__, address=address)
contact_response = ContactModel.model_construct(
**contact_model.__dict__, address=address
)
db.commit()

View file

@ -6,6 +6,7 @@ Models follow the nomenclature of:
- Mixins: "<Attribute>Mixin"
- Models: "<Module><Method><Resource><Opt:Resource><Direction>" ie "OrgPostOrgRequest"
"""
from typing import Optional
from pydantic import EmailStr, ConfigDict, Field
@ -20,11 +21,13 @@ from src.organisation.constants import Status, ContactType
class OrgIDMixin(CustomBaseModel):
organisation_id: int = Field(gt=0)
class Questionnaire(CustomBaseModel):
question_one: Optional[str] = None
question_two: Optional[str] = None
question_three: Optional[str] = None
class OrgSchema(CustomBaseModel):
id: int
name: str
@ -34,26 +37,32 @@ class OrgPostOrgRequest(CustomBaseModel):
name: str
intake_questionnaire: Optional[Questionnaire] = None
class OrgPostOrgResponse(CustomBaseModel):
name: str
status: Status
class OrgPatchQuestionnaireRequest(OrgIDMixin):
intake_questionnaire: Questionnaire
partial: bool
class OrgPatchQuestionnaireResponse(CustomBaseModel):
name: str
intake_questionnaire: Questionnaire
status: Status
class OrgPatchStatusRequest(OrgIDMixin):
status: Status
class OrgPatchStatusResponse(CustomBaseModel):
name: str
status: Status
class OrgPatchContactRequest(OrgIDMixin):
contact_type: ContactType
@ -70,41 +79,51 @@ class OrgPatchContactRequest(OrgIDMixin):
country_code: Optional[str] = None
postal_code: Optional[str] = None
class OrgPostUserRequest(OrgIDMixin, UserIDMixin):
pass
class OrgPostUserResponse(CustomBaseModel):
users: list[str]
class OrgDeleteUserRequest(OrgIDMixin, UserIDMixin):
pass
class OrgPatchRootRequest(OrgIDMixin, UserIDMixin):
pass
class OrgPatchRootResponse(CustomBaseModel):
name: str
root_user_email: str
class OrgGetUserResponse(CustomBaseModel):
users: list[str]
organisation: OrgSchema
class OrgGetGroupResponse(CustomBaseModel):
groups: list[str]
class OrgGetContactResponse(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore")
contact: ContactModel
organisation: OrgSchema
class OrgPatchContactResponse(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore")
contact: ContactModel
organisation: OrgSchema
class OrgGetOrgResponse(CustomBaseModel):
id: int
name: str
@ -115,5 +134,6 @@ class OrgGetOrgResponse(CustomBaseModel):
security_contact: Optional[str] = None
intake_questionnaire: Optional[Questionnaire] = None
class OrgDeleteOrgRequest(OrgIDMixin):
pass

View file

@ -1,3 +1,3 @@
"""
Reusable business logic functions for the organisation module
"""
"""

View file

@ -1,3 +1,3 @@
"""
Non-business logic reusable functions and classes for the organisation module
"""
"""

View file

@ -5,6 +5,7 @@ Exports:
- CustomBaseModel: Schema used for all other Pydantic models
- ResourceName
"""
from pydantic import BaseModel
from typing import Optional

View file

@ -1,3 +1,3 @@
"""
Configurations for the services module
"""
"""

View file

@ -1,3 +1,3 @@
"""
Constants for the services module
"""
"""

View file

@ -5,6 +5,7 @@ Exports:
- service_model_query_dependency: service_model: Gets service model from db, if it exists. Uses service_id from query param.
- service_model_body_dependency: service_model: Gets service model from db, if it exists. Uses service_id from request body.
"""
from typing import Annotated
from fastapi import Depends, Query
@ -15,14 +16,19 @@ from src.service.models import Service
from src.service.schemas import ServiceIDMixin
async def get_service_model_query(db: db_dependency, service_id: Annotated[int, Query(gt=0)]):
async def get_service_model_query(
db: db_dependency, service_id: Annotated[int, Query(gt=0)]
):
service_model = db.get(Service, service_id)
if service_model is None:
raise ServiceNotFoundException(service_id=service_id)
return service_model
service_model_query_dependency = Annotated[type[Service], Depends(get_service_model_query)]
service_model_query_dependency = Annotated[
type[Service], Depends(get_service_model_query)
]
async def get_service_model_body(db: db_dependency, request_model: ServiceIDMixin):
@ -32,4 +38,7 @@ async def get_service_model_body(db: db_dependency, request_model: ServiceIDMixi
return service_model
service_model_body_dependency = Annotated[type[Service], Depends(get_service_model_body)]
service_model_body_dependency = Annotated[
type[Service], Depends(get_service_model_body)
]

View file

@ -4,6 +4,7 @@ Exceptions related to the services module
Exceptions:
- ServiceNotFoundException: Takes an optional service_id int
"""
from typing import Optional
from fastapi import HTTPException, status
@ -11,7 +12,11 @@ from fastapi import HTTPException, status
class ServiceNotFoundException(HTTPException):
def __init__(self, service_id: Optional[int] = None) -> None:
detail = "Service not found" if service_id is None else f"Service with ID '{service_id}' was not found."
detail = (
"Service not found"
if service_id is None
else f"Service with ID '{service_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,

View file

@ -5,6 +5,7 @@ Models:
- Service:
- id[PK], name[U], api_key[U]
"""
from sqlalchemy import Column, Integer, String
from src.database import Base

View file

@ -7,19 +7,30 @@ Endpoints:
- [PATCH](/key): [super_admin]: Refreshes the API key for a service(id), returning a new one.
- [DELETE](/): [super_admin]: Removes a service(id) from the hub.
"""
from fastapi import APIRouter, status
from psycopg.errors import UniqueViolation
from sqlalchemy.exc import IntegrityError
from src.exceptions import ConflictException
from src.database import db_dependency
from src.auth.dependencies import super_admin_dependency, org_model_root_claim_query_dependency
from src.auth.dependencies import (
super_admin_dependency,
org_model_root_claim_query_dependency,
)
from src.service.models import Service
from src.service.utils import generate_api_key
from src.service.dependencies import service_model_body_dependency
from src.service.schemas import ServiceGetServiceResponse, ServicePostServiceRequest, ServicePostServiceResponse, \
ServiceWithKeySchema, ServicePatchKeyResponse, ServicePatchKeyRequest, ServiceDeleteServiceRequest
from src.service.schemas import (
ServiceGetServiceResponse,
ServicePostServiceRequest,
ServicePostServiceResponse,
ServiceWithKeySchema,
ServicePatchKeyResponse,
ServicePatchKeyRequest,
ServiceDeleteServiceRequest,
)
router = APIRouter(
tags=["Service"],
@ -27,15 +38,19 @@ router = APIRouter(
)
@router.get("/",
summary="Get all services",
status_code=status.HTTP_200_OK,
response_model=ServiceGetServiceResponse,
responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
})
async def get_all_services(db: db_dependency, org_model: org_model_root_claim_query_dependency):
@router.get(
"/",
summary="Get all services",
status_code=status.HTTP_200_OK,
response_model=ServiceGetServiceResponse,
responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
},
)
async def get_all_services(
db: db_dependency, org_model: org_model_root_claim_query_dependency
):
"""
Returns the ID and name of all services registered to the hub.
"""
@ -44,16 +59,24 @@ async def get_all_services(db: db_dependency, org_model: org_model_root_claim_qu
return {"services": permission_models}
@router.post("/",
summary="Register a new service.",
status_code=status.HTTP_200_OK,
response_model=ServicePostServiceResponse,
responses={
status.HTTP_200_OK: {"description": "Successfully registered a new service"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
status.HTTP_409_CONFLICT: {"description": "Service with this name already exists"},
})
async def register_service(db: db_dependency, su: super_admin_dependency, request_model: ServicePostServiceRequest):
@router.post(
"/",
summary="Register a new service.",
status_code=status.HTTP_200_OK,
response_model=ServicePostServiceResponse,
responses={
status.HTTP_200_OK: {"description": "Successfully registered a new service"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
status.HTTP_409_CONFLICT: {
"description": "Service with this name already exists"
},
},
)
async def register_service(
db: db_dependency,
su: super_admin_dependency,
request_model: ServicePostServiceRequest,
):
"""
Registers a new service to the hub, generating and returning an API key for it.
"""
@ -71,16 +94,22 @@ async def register_service(db: db_dependency, su: super_admin_dependency, reques
return {"service": response}
@router.patch("/key",
summary="Regenerate service API key.",
status_code=status.HTTP_200_OK,
response_model=ServicePatchKeyResponse,
responses={
status.HTTP_200_OK: {"description": "Successful update of API key"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
})
async def regenerate_api_key(db: db_dependency, su: super_admin_dependency,
service_model: service_model_body_dependency, request_model: ServicePatchKeyRequest):
@router.patch(
"/key",
summary="Regenerate service API key.",
status_code=status.HTTP_200_OK,
response_model=ServicePatchKeyResponse,
responses={
status.HTTP_200_OK: {"description": "Successful update of API key"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
},
)
async def regenerate_api_key(
db: db_dependency,
su: super_admin_dependency,
service_model: service_model_body_dependency,
request_model: ServicePatchKeyRequest,
):
"""
Generates and returns a new API key for the service to access the hub.
"""
@ -93,15 +122,23 @@ async def regenerate_api_key(db: db_dependency, su: super_admin_dependency,
return {"service": response}
@router.delete("/",
summary="Remove a service.",
status_code=status.HTTP_204_NO_CONTENT,
responses={
status.HTTP_204_NO_CONTENT: {"description": "Successfully removed service from db"},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
})
async def remove_service(db: db_dependency, service_model: service_model_body_dependency, su: super_admin_dependency,
request_model: ServiceDeleteServiceRequest):
@router.delete(
"/",
summary="Remove a service.",
status_code=status.HTTP_204_NO_CONTENT,
responses={
status.HTTP_204_NO_CONTENT: {
"description": "Successfully removed service from db"
},
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
},
)
async def remove_service(
db: db_dependency,
service_model: service_model_body_dependency,
su: super_admin_dependency,
request_model: ServiceDeleteServiceRequest,
):
"""
Removes a service from the hub.
"""

View file

@ -6,36 +6,46 @@ Models follow the nomenclature of:
- Mixins: "<Attribute>Mixin"
- Models: "<Module><Method><Resource><Opt:Resource><Direction>" ie "ServiceGetServiceResponse"
"""
from pydantic import ConfigDict, Field
from src.schemas import CustomBaseModel
class ServiceIDMixin(CustomBaseModel):
service_id: int = Field(gt=0)
class ServiceSchema(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore")
id: int
name: str
class ServiceWithKeySchema(ServiceSchema):
api_key: str
class ServiceGetServiceResponse(CustomBaseModel):
services: list[ServiceSchema]
class ServicePostServiceRequest(CustomBaseModel):
name: str
class ServicePostServiceResponse(CustomBaseModel):
service: ServiceWithKeySchema
class ServicePatchKeyRequest(ServiceIDMixin):
pass
class ServicePatchKeyResponse(CustomBaseModel):
service: ServiceWithKeySchema
class ServiceDeleteServiceRequest(ServiceIDMixin):
pass

View file

@ -1,3 +1,3 @@
"""
Business logic for the services module
"""
"""

View file

@ -4,6 +4,7 @@ Non-business logic reusable functions and classes for the services module
Exports:
- generate_api_key(): returns a new UUID
"""
import uuid

View file

@ -1,3 +1,3 @@
"""
Configurations for the user module
"""
"""

View file

@ -1,3 +1,3 @@
"""
Constants for the user module
"""
"""

View file

@ -6,6 +6,7 @@ Exports:
- user_model_query_dependency: user_model: Gets user model from db, if it exists. Uses user_id from query param
- user_model_body_dependency: user_model: Gets user model from db, if it exists. Uses user_id from request body.
"""
from typing import Annotated
from fastapi import Depends, Query
@ -28,6 +29,7 @@ async def get_user_model_claims(claims: claims_dependency, db: db_dependency):
return user_model
user_model_claims_dependency = Annotated[type[User], Depends(get_user_model_claims)]
@ -38,6 +40,7 @@ async def get_user_model_query(db: db_dependency, user_id: Annotated[int, Query(
return user_model
user_model_query_dependency = Annotated[type[User], Depends(get_user_model_query)]
@ -48,4 +51,5 @@ async def get_user_model_body(db: db_dependency, request_model: UserIDMixin):
return user_model
user_model_body_dependency = Annotated[type[User], Depends(get_user_model_body)]

View file

@ -4,6 +4,7 @@ Exceptions related to the user module
Exceptions:
- UserNotFoundException: Takes an optional user_id int
"""
from typing import Optional
from fastapi import HTTPException, status
@ -11,7 +12,11 @@ from fastapi import HTTPException, status
class UserNotFoundException(HTTPException):
def __init__(self, user_id: Optional[int] = None) -> None:
detail = "User not found" if user_id is None else f"User with ID '{user_id}' was not found."
detail = (
"User not found"
if user_id is None
else f"User with ID '{user_id}' was not found."
)
super().__init__(
status_code=status.HTTP_404_NOT_FOUND,
detail=detail,

View file

@ -9,6 +9,7 @@ Models:
- organisations: Calc property list of organisation_rel.name
- groups: Calc property dict of {group_rel.org_rel.name: group_rel.name}
"""
from collections import defaultdict
from sqlalchemy import Column, Integer, String
@ -18,29 +19,29 @@ from src.database import Base
class User(Base):
__tablename__ = "user"
__tablename__ = "user"
id = Column(Integer, primary_key=True)
email = Column(String)
first_name = Column(String)
last_name = Column(String)
oidc_id = Column(String, index=True, unique=True)
id = Column(Integer, primary_key=True)
email = Column(String)
first_name = Column(String)
last_name = Column(String)
oidc_id = Column(String, index=True, unique=True)
organisation_rel = relationship(
"Organisation", secondary="orgusers", back_populates="user_rel"
)
organisation_rel = relationship(
"Organisation", secondary="orgusers", back_populates="user_rel"
)
@property
def organisations(self):
return [{"name": org.name, "id": org.id} for org in self.organisation_rel]
@property
def organisations(self):
return [{"name": org.name, "id": org.id} for org in self.organisation_rel]
group_rel = relationship(
"Group", secondary="user_groups", back_populates="user_rel"
)
group_rel = relationship(
"Group", secondary="user_groups", back_populates="user_rel"
)
@property
def groups(self):
result = defaultdict(list)
for group in self.group_rel:
result[group.org_rel.name].append({"name": group.name, "id": group.id})
return dict(result)
@property
def groups(self):
result = defaultdict(list)
for group in self.group_rel:
result[group.org_rel.name].append({"name": group.name, "id": group.id})
return dict(result)

View file

@ -7,11 +7,16 @@ Endpoints:
- [GET](/user/): [super admin]: Returns user(id) details.
- [DELETE](/user/): [super admin]: Removes a User(id) from the hub database.
"""
from fastapi import APIRouter
from starlette import status
from src.user.schemas import UserResponse, OIDCClaims, UserDeleteUserRequest
from src.user.dependencies import user_model_claims_dependency, user_model_query_dependency, user_model_body_dependency
from src.user.dependencies import (
user_model_claims_dependency,
user_model_query_dependency,
user_model_body_dependency,
)
from src.auth.dependencies import super_admin_dependency
from src.auth.service import claims_dependency
@ -23,13 +28,15 @@ router = APIRouter(
)
@router.get("/self/claims",
summary="Get current user OIDC claims.",
response_model=OIDCClaims,
status_code=status.HTTP_200_OK,
responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
})
@router.get(
"/self/claims",
summary="Get current user OIDC claims.",
response_model=OIDCClaims,
status_code=status.HTTP_200_OK,
responses={
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
},
)
async def current_user_claims(user: claims_dependency):
"""
Returns the full OIDC claims associated with the currently logged-in user.
@ -38,14 +45,16 @@ async def current_user_claims(user: claims_dependency):
return user
@router.get("/self/db",
summary="Get current user hub details.",
response_model=UserResponse,
status_code=status.HTTP_200_OK,
responses={
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
})
@router.get(
"/self/db",
summary="Get current user hub details.",
response_model=UserResponse,
status_code=status.HTTP_200_OK,
responses={
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
},
)
async def current_user(user_model: user_model_claims_dependency):
"""
Returns the database details associated with the currently logged-in user.
@ -53,30 +62,40 @@ async def current_user(user_model: user_model_claims_dependency):
return user_model
@router.get("/",
summary="Get user hub details by ID.",
response_model=UserResponse,
status_code=status.HTTP_200_OK,
responses={
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
})
async def get_user_by_id(user_model: user_model_query_dependency, su: super_admin_dependency):
@router.get(
"/",
summary="Get user hub details by ID.",
response_model=UserResponse,
status_code=status.HTTP_200_OK,
responses={
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
status.HTTP_200_OK: {"description": "Successful retrieval from database"},
},
)
async def get_user_by_id(
user_model: user_model_query_dependency, su: super_admin_dependency
):
"""
Returns the database details associated with the provided user ID.
"""
return user_model
@router.delete("/",
summary="Delete user from hub by ID.",
status_code=status.HTTP_204_NO_CONTENT,
responses={
status.HTTP_204_NO_CONTENT: {"description": "User deleted"},
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
})
async def delete_user_by_id(db: db_dependency, user_model: user_model_body_dependency, su: super_admin_dependency,
request_model: UserDeleteUserRequest):
@router.delete(
"/",
summary="Delete user from hub by ID.",
status_code=status.HTTP_204_NO_CONTENT,
responses={
status.HTTP_204_NO_CONTENT: {"description": "User deleted"},
status.HTTP_404_NOT_FOUND: {"description": "User not found"},
},
)
async def delete_user_by_id(
db: db_dependency,
user_model: user_model_body_dependency,
su: super_admin_dependency,
request_model: UserDeleteUserRequest,
):
"""
Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login.
"""

View file

@ -1,6 +1,7 @@
"""
Pydantic models for the user module
"""
from typing import Optional
from pydantic import Field
@ -47,8 +48,8 @@ class UserResponse(CustomBaseModel):
first_name: str
last_name: str
email: str
organisations: list[Optional[dict[str, str|int]]]
groups: Optional[dict[str, list[dict[str, str|int]]]] = None
organisations: list[Optional[dict[str, str | int]]]
groups: Optional[dict[str, list[dict[str, str | int]]]] = None
class OrgResponse(CustomBaseModel):
@ -57,4 +58,4 @@ class OrgResponse(CustomBaseModel):
class UserDeleteUserRequest(UserIDMixin):
pass
pass

View file

@ -4,6 +4,7 @@ Module specific business logic for user module
Exports:
- add_user_to_db: Creates a User record from OIDC claims, or updates user details
"""
from typing import Any
from sqlalchemy.orm import Session
@ -16,7 +17,12 @@ from src.user.models import User
async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int:
try:
valid_user = OIDCUser(first_name=user_claims["given_name"], last_name=user_claims["family_name"], email=user_claims["email"], oidc_id=user_claims["sub"])
valid_user = OIDCUser(
first_name=user_claims["given_name"],
last_name=user_claims["family_name"],
email=user_claims["email"],
oidc_id=user_claims["sub"],
)
except Exception as e:
print(e)
raise UnprocessableContentException("Invalid or missing OIDC data")

View file

@ -1,3 +1,3 @@
"""
Non-business logic reusable functions and classes for the user module
"""
"""

View file

@ -37,11 +37,14 @@ def db_session():
async def default_client(db_session) -> AsyncGenerator[AsyncClient, None]:
def get_db_override():
return db_session
app.dependency_overrides[get_db] = get_db_override
app.dependency_overrides[get_current_user] = get_dev_user
app.dependency_overrides[get_super_admin_list] = testing_su_list
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://localhost:8000/api/v1") as ac:
async with AsyncClient(
transport=transport, base_url="http://localhost:8000/api/v1"
) as ac:
yield ac
app.dependency_overrides.clear()
@ -51,37 +54,58 @@ async def default_client(db_session) -> AsyncGenerator[AsyncClient, None]:
async def no_user_client(db_session) -> AsyncGenerator[AsyncClient, None]:
def get_db_override():
return db_session
app.dependency_overrides[get_db] = get_db_override
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://localhost:8000/api/v1") as ac:
async with AsyncClient(
transport=transport, base_url="http://localhost:8000/api/v1"
) as ac:
yield ac
app.dependency_overrides.clear()
@pytest.fixture
async def no_su_client(db_session) -> AsyncGenerator[AsyncClient, None]:
def get_db_override():
return db_session
app.dependency_overrides[get_db] = get_db_override
app.dependency_overrides[get_current_user] = get_dev_user
app.dependency_overrides[get_super_admin_list] = empty_su_list
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://localhost:8000/api/v1") as ac:
async with AsyncClient(
transport=transport, base_url="http://localhost:8000/api/v1"
) as ac:
yield ac
app.dependency_overrides.clear()
def _seed(db):
db.add(User(email="admin@test.com", first_name="Admin", last_name="Test", oidc_id="abcd-efgh-ijkl-mnop"))
db.add(
User(
email="admin@test.com",
first_name="Admin",
last_name="Test",
oidc_id="abcd-efgh-ijkl-mnop",
)
)
db.add(Contact(org_id=1, email="billing@test.org", phonenumber="07521539927"))
db.add(Contact(org_id=1, email="owner@test.org", phonenumber="07521539927"))
db.add(Contact(org_id=1, email="security@test.org", phonenumber="07521539927"))
db.flush()
db.add(Org(name="Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=2, security_contact_id=3,
status="approved", intake_questionnaire={"question_two": "answer two"}))
db.add(
Org(
name="Test Org",
root_user_id=1,
billing_contact_id=1,
owner_contact_id=2,
security_contact_id=3,
status="approved",
intake_questionnaire={"question_two": "answer two"},
)
)
db.add(Service(name="Test Service", api_key="123456789"))
db.add(Permission(service_id=1, resource="test_resource", action="read"))
db.add(Group(name="Test Group", org_id=1))
@ -131,6 +155,7 @@ def generate_query_and_status(params) -> list[tuple[str, int]]:
return query_and_status
# # Produces a text file with method and path for every endpoint in the API
# from fastapi.routing import APIRoute
#

View file

@ -3,6 +3,7 @@ This test module checks relevant endpoints to ensure only approved orgs get acce
Endpoints not checked here are endpoints that do not require an org check.
Delete endpoints are currently skipped because the testing system cannot use bodies in deletes.
"""
import pytest
from httpx import AsyncClient
@ -27,18 +28,27 @@ async def test_get_org_auth_approval(default_client: AsyncClient):
@pytest.mark.anyio
async def test_patch_org_questionnaire_auth_approval(default_client: AsyncClient):
resp = await default_client.patch("/org/questionnaire", json={"organisation_id": 1,
"intake_questionnaire": {"question_one": "new answer one",
"question_two": None,
"question_three": None},
"partial": True})
resp = await default_client.patch(
"/org/questionnaire",
json={
"organisation_id": 1,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": True,
},
)
assert resp.status_code != 422
assert resp.status_code == 200
@pytest.mark.anyio
async def test_patch_org_status_auth_approval(default_client: AsyncClient):
resp = await default_client.patch("/org/status", json={"organisation_id": 1, "status": "submitted"})
resp = await default_client.patch(
"/org/status", json={"organisation_id": 1, "status": "submitted"}
)
assert resp.status_code != 422
assert resp.status_code == 200
@ -52,22 +62,42 @@ async def test_get_org_users_auth_approval(default_client: AsyncClient):
@pytest.mark.anyio
async def test_post_org_user_auth_approval(default_client: AsyncClient, db_session):
db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
db_session.add(
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush()
resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 2})
resp = await default_client.post(
"/org/user", json={"organisation_id": 1, "user_id": 2}
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_patch_org_root_user_auth_approval(default_client: AsyncClient, db_session):
db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
async def test_patch_org_root_user_auth_approval(
default_client: AsyncClient, db_session
):
db_session.add(
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush()
db_session.add(OrgUsers(org_id=1, user_id=2))
db_session.flush()
resp = await default_client.patch("/org/root_user", json={"organisation_id": 1, "user_id": 2})
resp = await default_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 2}
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@ -88,8 +118,14 @@ async def test_get_org_contact_auth_approval(default_client: AsyncClient):
@pytest.mark.anyio
async def test_patch_org_contact_auth_approval(default_client: AsyncClient):
resp = await default_client.patch("/org/contact",
json={"organisation_id": 1, "contact_type": "billing", "email": "user@example.com"})
resp = await default_client.patch(
"/org/contact",
json={
"organisation_id": 1,
"contact_type": "billing",
"email": "user@example.com",
},
)
assert resp.status_code != 422
assert resp.status_code == 200
@ -117,26 +153,44 @@ async def test_get_iam_group_users_auth_approval(default_client: AsyncClient):
@pytest.mark.anyio
async def test_post_iam_group_auth_approval(default_client: AsyncClient):
resp = await default_client.post("/iam/group", json={"name": "New Group", "organisation_id": 1})
resp = await default_client.post(
"/iam/group", json={"name": "New Group", "organisation_id": 1}
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_put_iam_group_permission_auth_approval(default_client: AsyncClient, db_session):
async def test_put_iam_group_permission_auth_approval(
default_client: AsyncClient, db_session
):
db_session.add(Group(name="Test Group Two", org_id=1))
db_session.flush()
resp = await default_client.put("/iam/group/permission", json={"permission_id": 1, "group_id": 2, "organisation_id": 1})
resp = await default_client.put(
"/iam/group/permission",
json={"permission_id": 1, "group_id": 2, "organisation_id": 1},
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@pytest.mark.anyio
async def test_put_iam_group_user_auth_approval(default_client: AsyncClient, db_session):
db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
async def test_put_iam_group_user_auth_approval(
default_client: AsyncClient, db_session
):
db_session.add(
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush()
resp = await default_client.put("/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1})
resp = await default_client.put(
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1}
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]
@ -150,6 +204,8 @@ async def test_get_iam_permissions_auth_approval(default_client: AsyncClient):
@pytest.mark.anyio
async def test_post_iam_permissions_search_auth_approval(default_client: AsyncClient):
resp = await default_client.post("/iam/permissions/search", json={"organisation_id": 1, "action": "read"})
resp = await default_client.post(
"/iam/permissions/search", json={"organisation_id": 1, "action": "read"}
)
assert resp.status_code != 422
assert "has not been approved." in resp.json()["detail"]

View file

@ -1,5 +1,5 @@
"""
"""
""" """
import pytest
from httpx import AsyncClient
@ -10,11 +10,26 @@ from src.user.models import User
@pytest.mark.anyio
async def test_get_org_auth_root_su(default_client: AsyncClient, db_session):
# If a super admin can access a resource when not the root user
db_session.add(User(email="admin@test.org", first_name="Admin", last_name="Test", oidc_id="abcd-efgh-ijkl-4321"))
db_session.add(
User(
email="admin@test.org",
first_name="Admin",
last_name="Test",
oidc_id="abcd-efgh-ijkl-4321",
)
)
db_session.flush()
db_session.add(
Org(name="Test Org Two", root_user_id=2, billing_contact_id=1, owner_contact_id=2, security_contact_id=3,
status="approved", intake_questionnaire={}))
Org(
name="Test Org Two",
root_user_id=2,
billing_contact_id=1,
owner_contact_id=2,
security_contact_id=3,
status="approved",
intake_questionnaire={},
)
)
db_session.flush()
resp = await default_client.get("/org?org_id=2")

View file

@ -2,6 +2,7 @@
This module ensures root user only endpoints do return a correctly formatted 401 when user is not the root user for the org
DELETE endpoints are not tested
"""
import pytest
from httpx import AsyncClient
@ -12,10 +13,26 @@ from src.iam.models import Group
@pytest.fixture(autouse=True)
def add_second_org(db_session):
db_session.add(User(email="admin@test.org", first_name="Admin", last_name="Test", oidc_id="abcd-efgh-ijkl-4321"))
db_session.add(
User(
email="admin@test.org",
first_name="Admin",
last_name="Test",
oidc_id="abcd-efgh-ijkl-4321",
)
)
db_session.flush()
db_session.add(Org(name="Test Org Two", root_user_id=2, billing_contact_id=1, owner_contact_id=2, security_contact_id=3,
status="approved", intake_questionnaire={}))
db_session.add(
Org(
name="Test Org Two",
root_user_id=2,
billing_contact_id=1,
owner_contact_id=2,
security_contact_id=3,
status="approved",
intake_questionnaire={},
)
)
db_session.flush()
@ -29,11 +46,18 @@ async def test_get_org_auth_root(no_su_client: AsyncClient):
@pytest.mark.anyio
async def test_patch_org_questionnaire_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.patch("/org/questionnaire", json={"organisation_id": 2,
"intake_questionnaire": {"question_one": "new answer one",
"question_two": None,
"question_three": None},
"partial": True})
resp = await no_su_client.patch(
"/org/questionnaire",
json={
"organisation_id": 2,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": True,
},
)
assert resp.status_code != 422
assert resp.status_code == 401
assert "Must be the org's root user" in resp.json()["detail"]
@ -49,10 +73,19 @@ async def test_get_org_users_auth_root(no_su_client: AsyncClient):
@pytest.mark.anyio
async def test_post_org_user_auth_root(no_su_client: AsyncClient, db_session):
db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
db_session.add(
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush()
resp = await no_su_client.post("/org/user", json={"organisation_id": 2, "user_id": 2})
resp = await no_su_client.post(
"/org/user", json={"organisation_id": 2, "user_id": 2}
)
assert resp.status_code != 422
assert resp.status_code == 401
assert "Must be the org's root user" in resp.json()["detail"]
@ -76,8 +109,14 @@ async def test_get_org_contact_auth_root(no_su_client: AsyncClient):
@pytest.mark.anyio
async def test_patch_org_contact_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.patch("/org/contact",
json={"organisation_id": 2, "contact_type": "billing", "email": "user@example.com"})
resp = await no_su_client.patch(
"/org/contact",
json={
"organisation_id": 2,
"contact_type": "billing",
"email": "user@example.com",
},
)
assert resp.status_code != 422
assert resp.status_code == 401
assert "Must be the org's root user" in resp.json()["detail"]
@ -109,17 +148,24 @@ async def test_get_iam_group_users_auth_root(no_su_client: AsyncClient):
@pytest.mark.anyio
async def test_post_iam_group_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.post("/iam/group", json={"name": "New Group", "organisation_id": 2})
resp = await no_su_client.post(
"/iam/group", json={"name": "New Group", "organisation_id": 2}
)
assert resp.status_code != 422
assert resp.status_code == 401
assert "Must be the org's root user" in resp.json()["detail"]
@pytest.mark.anyio
async def test_put_iam_group_permission_auth_root(no_su_client: AsyncClient, db_session):
async def test_put_iam_group_permission_auth_root(
no_su_client: AsyncClient, db_session
):
db_session.add(Group(name="Test Group Two", org_id=2))
db_session.flush()
resp = await no_su_client.put("/iam/group/permission", json={"permission_id": 1, "group_id": 2, "organisation_id": 2})
resp = await no_su_client.put(
"/iam/group/permission",
json={"permission_id": 1, "group_id": 2, "organisation_id": 2},
)
assert resp.status_code != 422
assert resp.status_code == 401
assert "Must be the org's root user" in resp.json()["detail"]
@ -127,10 +173,19 @@ async def test_put_iam_group_permission_auth_root(no_su_client: AsyncClient, db_
@pytest.mark.anyio
async def test_put_iam_group_user_auth_root(no_su_client: AsyncClient, db_session):
db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
db_session.add(
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush()
resp = await no_su_client.put("/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2})
resp = await no_su_client.put(
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2}
)
assert resp.status_code != 422
assert resp.status_code == 401
assert "Must be the org's root user" in resp.json()["detail"]
@ -146,7 +201,9 @@ async def test_get_iam_permissions_auth_root(no_su_client: AsyncClient):
@pytest.mark.anyio
async def test_post_iam_permissions_search_auth_root(no_su_client: AsyncClient):
resp = await no_su_client.post("/iam/permissions/search", json={"organisation_id": 2, "action": "read"})
resp = await no_su_client.post(
"/iam/permissions/search", json={"organisation_id": 2, "action": "read"}
)
assert resp.status_code != 422
assert resp.status_code == 401
assert "Must be the org's root user" in resp.json()["detail"]

View file

@ -2,6 +2,7 @@
This module ensures super admin only endpoints do return a correctly formatted 401 when user is not a super admin
DELETE endpoints are not tested
"""
import pytest
from httpx import AsyncClient
@ -19,7 +20,9 @@ async def test_get_user_auth_su(no_su_client: AsyncClient):
@pytest.mark.anyio
async def test_patch_org_status_auth_su(no_su_client: AsyncClient):
resp = await no_su_client.patch("/org/status", json={"organisation_id": 1, "status": "submitted"})
resp = await no_su_client.patch(
"/org/status", json={"organisation_id": 1, "status": "submitted"}
)
assert resp.status_code != 422
assert resp.status_code == 401
assert resp.json()["detail"] == "Must be super admin"
@ -27,12 +30,21 @@ async def test_patch_org_status_auth_su(no_su_client: AsyncClient):
@pytest.mark.anyio
async def test_patch_org_root_user_auth_su(no_su_client: AsyncClient, db_session):
db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
db_session.add(
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush()
db_session.add(OrgUsers(org_id=1, user_id=2))
db_session.flush()
resp = await no_su_client.patch("/org/root_user", json={"organisation_id": 1, "user_id": 2})
resp = await no_su_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 2}
)
assert resp.status_code != 422
assert resp.status_code == 401
assert resp.json()["detail"] == "Must be super admin"
@ -56,7 +68,10 @@ async def test_post_service_auth_su(no_su_client: AsyncClient):
@pytest.mark.anyio
async def test_post_perm_success(no_su_client: AsyncClient, db_session):
resp = await no_su_client.post("/iam/permission", json={"service_id": 1, "resource": "test_resource", "action": "create"})
resp = await no_su_client.post(
"/iam/permission",
json={"service_id": 1, "resource": "test_resource", "action": "create"},
)
assert resp.status_code != 422
assert resp.status_code == 401
assert resp.json()["detail"] == "Must be super admin"

View file

@ -1,6 +1,7 @@
"""
This testing module removes the testing user override to verify that endpoints with only the user requirement return a 401 error when not logged in
"""
import pytest
from httpx import AsyncClient

View file

@ -4,7 +4,7 @@ from httpx import AsyncClient
@pytest.mark.anyio
async def test_healthcheck(default_client: AsyncClient):
resp = await default_client.get("/healthcheck")
resp = await default_client.get("/healthcheck")
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}

View file

@ -1,5 +1,5 @@
"""
"""
""" """
import pytest
from httpx import AsyncClient
@ -15,13 +15,15 @@ async def test_post_act_on_resource_endpoint_success(default_client: AsyncClient
body = {
"service": "Test Service",
"organisation": "Test Org",
"resource": "test_resource"
"resource": "test_resource",
}
headers = {
"Authorization": "Bearer not_checked_when_auth_is_disabled",
"X-API-Key": "123456789"
"X-API-Key": "123456789",
}
resp = await default_client.post("/iam/can_act_on_resource?action=read", json=body, headers=headers)
resp = await default_client.post(
"/iam/can_act_on_resource?action=read", json=body, headers=headers
)
data = resp.json()
assert resp.status_code == 200
@ -30,23 +32,20 @@ async def test_post_act_on_resource_endpoint_success(default_client: AsyncClient
@pytest.mark.parametrize(
"service, api_key",
[
("Test Service", "not_the_correct_key"),
("Test Service Two", "123456789")
],
[("Test Service", "not_the_correct_key"), ("Test Service Two", "123456789")],
)
@pytest.mark.anyio
async def test_act_on_resource_wrong_key(default_client: AsyncClient, db_session, service: str, api_key: str):
body = {
"service": service,
"organisation": "Test Org",
"resource": "test_resource"
}
async def test_act_on_resource_wrong_key(
default_client: AsyncClient, db_session, service: str, api_key: str
):
body = {"service": service, "organisation": "Test Org", "resource": "test_resource"}
headers = {
"Authorization": "Bearer not_checked_when_auth_is_disabled",
"X-API-Key": api_key
"X-API-Key": api_key,
}
resp = await default_client.post("/iam/can_act_on_resource?action=read", json=body, headers=headers)
resp = await default_client.post(
"/iam/can_act_on_resource?action=read", json=body, headers=headers
)
data = resp.json()
assert resp.status_code == 401
@ -58,12 +57,12 @@ async def test_act_on_resource_missing_key(default_client: AsyncClient):
body = {
"service": "Test Service",
"organisation": "Test Org",
"resource": "test_resource"
"resource": "test_resource",
}
headers = {
"Authorization": "Bearer not_checked_when_auth_is_disabled"
}
resp = await default_client.post("/iam/can_act_on_resource?action=read", json=body, headers=headers)
headers = {"Authorization": "Bearer not_checked_when_auth_is_disabled"}
resp = await default_client.post(
"/iam/can_act_on_resource?action=read", json=body, headers=headers
)
data = resp.json()
assert resp.status_code == 401
@ -82,18 +81,17 @@ async def test_act_on_resource_missing_key(default_client: AsyncClient):
],
)
@pytest.mark.anyio
async def test_act_on_resource_endpoint_status_checks(default_client: AsyncClient, service, org, resource, action,
expected_status: int):
body = {
"service": service,
"organisation": org,
"resource": resource
}
async def test_act_on_resource_endpoint_status_checks(
default_client: AsyncClient, service, org, resource, action, expected_status: int
):
body = {"service": service, "organisation": org, "resource": resource}
headers = {
"Authorization": "Bearer not_checked_when_auth_is_disabled",
"X-API-Key": "123456789"
"X-API-Key": "123456789",
}
resp = await default_client.post(f"/iam/can_act_on_resource?action={action}", json=body, headers=headers)
resp = await default_client.post(
f"/iam/can_act_on_resource?action={action}", json=body, headers=headers
)
assert resp.status_code == expected_status
@ -108,18 +106,23 @@ async def test_act_on_resource_endpoint_status_checks(default_client: AsyncClien
],
)
@pytest.mark.anyio
async def test_act_on_resource_logic(default_client: AsyncClient, db_session, service, org, resource, action,
expected_response: bool):
body = {
"service": service,
"organisation": org,
"resource": resource
}
async def test_act_on_resource_logic(
default_client: AsyncClient,
db_session,
service,
org,
resource,
action,
expected_response: bool,
):
body = {"service": service, "organisation": org, "resource": resource}
headers = {
"Authorization": "Bearer not_checked_when_auth_is_disabled",
"X-API-Key": "123456789"
"X-API-Key": "123456789",
}
resp = await default_client.post(f"/iam/can_act_on_resource?action={action}", json=body, headers=headers)
resp = await default_client.post(
f"/iam/can_act_on_resource?action={action}", json=body, headers=headers
)
data = resp.json()
assert resp.status_code == 200
@ -140,11 +143,12 @@ async def test_get_group_permissions_success(default_client: AsyncClient):
@pytest.mark.parametrize(
"query, expected_status",
generate_query_and_status(["group_id", "org_id"])
"query, expected_status", generate_query_and_status(["group_id", "org_id"])
)
@pytest.mark.anyio
async def test_get_group_permissions_status_checks(default_client: AsyncClient, db_session, query: str, expected_status: int):
async def test_get_group_permissions_status_checks(
default_client: AsyncClient, db_session, query: str, expected_status: int
):
resp = await default_client.get(f"/iam/group/permissions?{query}")
assert resp.status_code == expected_status
@ -158,8 +162,19 @@ async def test_get_group_permissions_status_checks(default_client: AsyncClient,
],
)
@pytest.mark.anyio
async def test_get_group_permissions_mismatch(default_client: AsyncClient, db_session, query: str):
db_session.add(Org(name="Another Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=2, security_contact_id=3, status="approved"))
async def test_get_group_permissions_mismatch(
default_client: AsyncClient, db_session, query: str
):
db_session.add(
Org(
name="Another Test Org",
root_user_id=1,
billing_contact_id=1,
owner_contact_id=2,
security_contact_id=3,
status="approved",
)
)
db_session.add(Group(name="Another Test Group", org_id=2))
db_session.flush()
resp = await default_client.get(f"/iam/group/permissions?{query}")
@ -183,11 +198,12 @@ async def test_get_group_users_success(default_client: AsyncClient):
@pytest.mark.parametrize(
"query, expected_status",
generate_query_and_status(["group_id", "org_id"])
"query, expected_status", generate_query_and_status(["group_id", "org_id"])
)
@pytest.mark.anyio
async def test_get_group_users_status_checks(default_client: AsyncClient, query: str, expected_status: int):
async def test_get_group_users_status_checks(
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/iam/group/users?{query}")
assert resp.status_code == expected_status
@ -201,8 +217,19 @@ async def test_get_group_users_status_checks(default_client: AsyncClient, query:
],
)
@pytest.mark.anyio
async def test_get_group_users_mismatch(default_client: AsyncClient, db_session, query: str):
db_session.add(Org(name="Another Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=2, security_contact_id=3, status="approved"))
async def test_get_group_users_mismatch(
default_client: AsyncClient, db_session, query: str
):
db_session.add(
Org(
name="Another Test Org",
root_user_id=1,
billing_contact_id=1,
owner_contact_id=2,
security_contact_id=3,
status="approved",
)
)
db_session.add(Group(name="Another Test Group", org_id=2))
db_session.flush()
resp = await default_client.get(f"/iam/group/users?{query}")
@ -213,7 +240,9 @@ async def test_get_group_users_mismatch(default_client: AsyncClient, db_session,
@pytest.mark.anyio
async def test_post_group_success(default_client: AsyncClient):
resp = await default_client.post("/iam/group", json={"name": "New Group", "organisation_id": 1})
resp = await default_client.post(
"/iam/group", json={"name": "New Group", "organisation_id": 1}
)
data = resp.json()
assert resp.status_code == 200
@ -227,10 +256,22 @@ async def test_post_group_success(default_client: AsyncClient):
"body, expected_status",
[
({"organisation_id": 1, "name": "Test Group"}, 409),
({"organisation_id": 2, "name": "new group"}, 404), # Non-existent organisation, valid name
({"organisation_id": "banana", "name": "new group"}, 422), # Invalid organisation ID, valid name
({"organisation_id": "", "name": "new group"}, 422), # Blank organisation ID, valid name
({"organisation_id": -1, "name": "new group"}, 422), # Negative organisation ID, valid name
(
{"organisation_id": 2, "name": "new group"},
404,
), # Non-existent organisation, valid name
(
{"organisation_id": "banana", "name": "new group"},
422,
), # Invalid organisation ID, valid name
(
{"organisation_id": "", "name": "new group"},
422,
), # Blank organisation ID, valid name
(
{"organisation_id": -1, "name": "new group"},
422,
), # Negative organisation ID, valid name
({"name": 1}, 422), # Only name
({}, 422), # Blank body
({"organisation_id": "", "name": ""}, 422), # Both blank
@ -241,7 +282,9 @@ async def test_post_group_success(default_client: AsyncClient):
],
)
@pytest.mark.anyio
async def test_post_group_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
async def test_post_group_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.post("/iam/group", json=body)
assert resp.status_code == expected_status
@ -251,7 +294,10 @@ async def test_post_group_status_checks(default_client: AsyncClient, body: dict[
async def test_put_group_perm_success(default_client: AsyncClient, db_session):
db_session.add(Group(name="Test Group Two", org_id=1))
db_session.flush()
resp = await default_client.put("/iam/group/permission", json={"permission_id": 1, "group_id": 2, "organisation_id": 1})
resp = await default_client.put(
"/iam/group/permission",
json={"permission_id": 1, "group_id": 2, "organisation_id": 1},
)
data = resp.json()
assert resp.status_code == 200
@ -270,36 +316,71 @@ async def test_put_group_perm_success(default_client: AsyncClient, db_session):
@pytest.mark.parametrize(
"body, expected_status",
[
({"organisation_id": 42, "group_id": 1, "permission_id": 1}, 404), # Non-existent organisation
({"organisation_id": "banana", "group_id": 1, "permission_id": 1}, 422), # Invalid organisation ID
({"organisation_id": "", "group_id": 1, "permission_id": 1}, 422), # Blank organisation ID
({"organisation_id": -1, "group_id": 1, "permission_id": 1}, 422), # Negative organisation ID
({"organisation_id": 1, "group_id": 42, "permission_id": 1}, 404), # Non-existent group
({"organisation_id": 1, "group_id": "banana", "permission_id": 1}, 422), # Invalid group ID
({"organisation_id": 1, "group_id": "", "permission_id": 1}, 422), # Blank group ID
({"organisation_id": 1, "group_id": -1, "permission_id": 1}, 422), # Negative group ID
({"organisation_id": 1, "group_id": 1, "permission_id": 42}, 404), # Non-existent permission
({"organisation_id": 1, "group_id": 1, "permission_id": "banana"}, 422), # Invalid permission ID
({"organisation_id": 1, "group_id": 1, "permission_id": ""}, 422), # Blank permission ID
({"organisation_id": 1, "group_id": 1, "permission_id": -1}, 422), # Negative permission ID
(
{"organisation_id": 42, "group_id": 1, "permission_id": 1},
404,
), # Non-existent organisation
(
{"organisation_id": "banana", "group_id": 1, "permission_id": 1},
422,
), # Invalid organisation ID
(
{"organisation_id": "", "group_id": 1, "permission_id": 1},
422,
), # Blank organisation ID
(
{"organisation_id": -1, "group_id": 1, "permission_id": 1},
422,
), # Negative organisation ID
(
{"organisation_id": 1, "group_id": 42, "permission_id": 1},
404,
), # Non-existent group
(
{"organisation_id": 1, "group_id": "banana", "permission_id": 1},
422,
), # Invalid group ID
(
{"organisation_id": 1, "group_id": "", "permission_id": 1},
422,
), # Blank group ID
(
{"organisation_id": 1, "group_id": -1, "permission_id": 1},
422,
), # Negative group ID
(
{"organisation_id": 1, "group_id": 1, "permission_id": 42},
404,
), # Non-existent permission
(
{"organisation_id": 1, "group_id": 1, "permission_id": "banana"},
422,
), # Invalid permission ID
(
{"organisation_id": 1, "group_id": 1, "permission_id": ""},
422,
), # Blank permission ID
(
{"organisation_id": 1, "group_id": 1, "permission_id": -1},
422,
), # Negative permission ID
({}, 422), # Blank body
({"permission_id": 1}, 422), # Only permission
({"organisation_id": 1}, 422), # Only organisation
({"group_id": 1}, 422), # Only group
({"organisation_id": 1, "permission_id": 1}, 422), # Missing group
({"group_id": 1, "permission_id": 1}, 422), # Missing organisation
({"organisation_id": 1, "group_id": 1}, 422), # Missing permission
({"organisation_id": 1, "group_id": 1, "permission_id": 1}, 409), # Permission already in group
(
{"organisation_id": 1, "group_id": 1, "permission_id": 1},
409,
), # Permission already in group
],
)
@pytest.mark.anyio
async def test_put_group_perm_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
async def test_put_group_perm_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.put("/iam/group/permission", json=body)
assert resp.status_code == expected_status
@ -313,8 +394,19 @@ async def test_put_group_perm_status_checks(default_client: AsyncClient, body: d
],
)
@pytest.mark.anyio
async def test_put_group_perm_mismatch(default_client: AsyncClient, db_session, body: dict):
db_session.add(Org(name="Another Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=2, security_contact_id=3, status="approved"))
async def test_put_group_perm_mismatch(
default_client: AsyncClient, db_session, body: dict
):
db_session.add(
Org(
name="Another Test Org",
root_user_id=1,
billing_contact_id=1,
owner_contact_id=2,
security_contact_id=3,
status="approved",
)
)
db_session.add(Group(name="Another Test Group", org_id=2))
db_session.flush()
resp = await default_client.put("/iam/group/permission", json=body)
@ -325,10 +417,19 @@ async def test_put_group_perm_mismatch(default_client: AsyncClient, db_session,
@pytest.mark.anyio
async def test_put_group_user_success(default_client: AsyncClient, db_session):
db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
db_session.add(
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush()
resp = await default_client.put("/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1})
resp = await default_client.put(
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1}
)
data = resp.json()
assert resp.status_code == 200
@ -348,34 +449,58 @@ async def test_put_group_user_success(default_client: AsyncClient, db_session):
@pytest.mark.parametrize(
"body, expected_status",
[
({"organisation_id": 42, "group_id": 1, "user_id": 1}, 404), # Non-existent organisation
({"organisation_id": "banana", "group_id": 1, "user_id": 1}, 422), # Invalid organisation ID
({"organisation_id": "", "group_id": 1, "user_id": 1}, 422), # Blank organisation ID
({"organisation_id": -1, "group_id": 1, "user_id": 1}, 422), # Negative organisation ID
({"organisation_id": 1, "group_id": 42, "user_id": 1}, 404), # Non-existent group
({"organisation_id": 1, "group_id": "banana", "user_id": 1}, 422), # Invalid group ID
(
{"organisation_id": 42, "group_id": 1, "user_id": 1},
404,
), # Non-existent organisation
(
{"organisation_id": "banana", "group_id": 1, "user_id": 1},
422,
), # Invalid organisation ID
(
{"organisation_id": "", "group_id": 1, "user_id": 1},
422,
), # Blank organisation ID
(
{"organisation_id": -1, "group_id": 1, "user_id": 1},
422,
), # Negative organisation ID
(
{"organisation_id": 1, "group_id": 42, "user_id": 1},
404,
), # Non-existent group
(
{"organisation_id": 1, "group_id": "banana", "user_id": 1},
422,
), # Invalid group ID
({"organisation_id": 1, "group_id": "", "user_id": 1}, 422), # Blank group ID
({"organisation_id": 1, "group_id": -1, "user_id": 1}, 422), # Negative group ID
({"organisation_id": 1, "group_id": 1, "user_id": 42}, 404), # Non-existent user
({"organisation_id": 1, "group_id": 1, "user_id": "banana"}, 422), # Invalid user ID
(
{"organisation_id": 1, "group_id": -1, "user_id": 1},
422,
), # Negative group ID
(
{"organisation_id": 1, "group_id": 1, "user_id": 42},
404,
), # Non-existent user
(
{"organisation_id": 1, "group_id": 1, "user_id": "banana"},
422,
), # Invalid user ID
({"organisation_id": 1, "group_id": 1, "user_id": ""}, 422), # Blank user ID
({"organisation_id": 1, "group_id": 1, "user_id": -1}, 422), # Negative user ID
({}, 422), # Blank body
({"user_id": 1}, 422), # Only user
({"organisation_id": 1}, 422), # Only organisation
({"group_id": 1}, 422), # Only group
({"organisation_id": 1, "user_id": 1}, 422), # Missing group
({"group_id": 1, "user_id": 1}, 422), # Missing organisation
({"organisation_id": 1, "group_id": 1}, 422), # Missing user
],
)
@pytest.mark.anyio
async def test_put_group_user_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
async def test_put_group_user_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.put("/iam/group/user", json=body)
assert resp.status_code == expected_status
@ -395,11 +520,12 @@ async def test_get_permissions_success(default_client: AsyncClient):
@pytest.mark.parametrize(
"query, expected_status",
generate_query_and_status(["org_id"])
"query, expected_status", generate_query_and_status(["org_id"])
)
@pytest.mark.anyio
async def test_get_permissions_status_checks(default_client: AsyncClient, query: str, expected_status: int):
async def test_get_permissions_status_checks(
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/iam/permissions?{query}")
assert resp.status_code == expected_status
@ -407,7 +533,10 @@ async def test_get_permissions_status_checks(default_client: AsyncClient, query:
@pytest.mark.anyio
async def test_post_perm_success(default_client: AsyncClient, db_session):
resp = await default_client.post("/iam/permission", json={"service_id": 1, "resource": "test_resource", "action": "create"})
resp = await default_client.post(
"/iam/permission",
json={"service_id": 1, "resource": "test_resource", "action": "create"},
)
data = resp.json()
assert resp.status_code == 200
@ -418,51 +547,70 @@ async def test_post_perm_success(default_client: AsyncClient, db_session):
@pytest.mark.parametrize(
"body, expected_status",
[
# service_id tests
({"service_id": 42, "resource": "test_resource", "action": "read"}, 404), # Non-existent service
({"service_id": "banana", "resource": "test_resource", "action": "read"}, 422), # Invalid service ID
({"service_id": "", "resource": "test_resource", "action": "read"}, 422), # Blank service ID
({"service_id": -1, "resource": "test_resource", "action": "read"}, 422), # Negative service ID
# resource tests
({"service_id": 1, "resource": 42, "action": "read"}, 422), # Invalid resource type
# action tests
({"service_id": 1, "resource": "test_resource", "action": 42}, 422), # Invalid action type
# missing/partial body tests
({}, 422), # Blank body
({"resource": "test_resource"}, 422), # Only resource
({"action": "read"}, 422), # Only action
({"service_id": 1}, 422), # Only service
({"service_id": 1, "action": "read"}, 422), # Missing resource
({"service_id": 1, "resource": "test_resource"}, 422), # Missing action
({"resource": "test_resource", "action": "read"}, 422), # Missing service
],
"body, expected_status",
[
# service_id tests
(
{"service_id": 42, "resource": "test_resource", "action": "read"},
404,
), # Non-existent service
(
{"service_id": "banana", "resource": "test_resource", "action": "read"},
422,
), # Invalid service ID
(
{"service_id": "", "resource": "test_resource", "action": "read"},
422,
), # Blank service ID
(
{"service_id": -1, "resource": "test_resource", "action": "read"},
422,
), # Negative service ID
# resource tests
(
{"service_id": 1, "resource": 42, "action": "read"},
422,
), # Invalid resource type
# action tests
(
{"service_id": 1, "resource": "test_resource", "action": 42},
422,
), # Invalid action type
# missing/partial body tests
({}, 422), # Blank body
({"resource": "test_resource"}, 422), # Only resource
({"action": "read"}, 422), # Only action
({"service_id": 1}, 422), # Only service
({"service_id": 1, "action": "read"}, 422), # Missing resource
({"service_id": 1, "resource": "test_resource"}, 422), # Missing action
({"resource": "test_resource", "action": "read"}, 422), # Missing service
],
)
@pytest.mark.anyio
async def test_post_perm_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
async def test_post_perm_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.post("/iam/permission", json=body)
assert resp.status_code == expected_status
@pytest.mark.parametrize(
"body",
[
{"organisation_id": 1, "service_id": 1, "resource": "test_resource", "action": "read"},
{"organisation_id": 1, "service_id": 1},
{"organisation_id": 1, "resource": "test_resource"},
{"organisation_id": 1, "action": "read"},
{"organisation_id": 1, "service_id": 1, "action": "read"},
{"organisation_id": 1, "service_id": 1, "resource": "test_resource"},
{"organisation_id": 1, "resource": "test_resource", "action": "read"},
],
"body",
[
{
"organisation_id": 1,
"service_id": 1,
"resource": "test_resource",
"action": "read",
},
{"organisation_id": 1, "service_id": 1},
{"organisation_id": 1, "resource": "test_resource"},
{"organisation_id": 1, "action": "read"},
{"organisation_id": 1, "service_id": 1, "action": "read"},
{"organisation_id": 1, "service_id": 1, "resource": "test_resource"},
{"organisation_id": 1, "resource": "test_resource", "action": "read"},
],
)
@pytest.mark.anyio
async def test_post_perm_search_success(default_client: AsyncClient, db_session, body):
@ -478,33 +626,96 @@ async def test_post_perm_search_success(default_client: AsyncClient, db_session,
@pytest.mark.parametrize(
"body, expected_status",
[
# organisation_id tests
({"organisation_id": 42, "service_id": 1, "resource": "test_resource", "action": "read"}, 404), # Non-existent organisation
({"organisation_id": "banana", "service_id": 1, "resource": "test_resource", "action": "read"}, 422), # Invalid organisation ID
({"organisation_id": "", "service_id": 1, "resource": "test_resource", "action": "read"}, 422), # Blank organisation ID
({"organisation_id": -1, "service_id": 1, "resource": "test_resource", "action": "read"}, 422), # Negative organisation ID
# service_id tests
({"organisation_id": 1, "service_id": "banana", "resource": "test_resource", "action": "read"}, 422), # Invalid service ID
({"organisation_id": 1, "service_id": "", "resource": "test_resource", "action": "read"}, 422), # Blank service ID
({"organisation_id": 1, "service_id": -1, "resource": "test_resource", "action": "read"}, 422), # Negative service ID
# resource tests
({"organisation_id": 1, "service_id": 1, "resource": 42, "action": "read"}, 422), # Invalid resource type
# action tests
({"organisation_id": 1, "service_id": 1, "resource": "test_resource", "action": 42}, 422), # Invalid action type
# missing/partial body tests
({}, 422), # Blank body
],
"body, expected_status",
[
# organisation_id tests
(
{
"organisation_id": 42,
"service_id": 1,
"resource": "test_resource",
"action": "read",
},
404,
), # Non-existent organisation
(
{
"organisation_id": "banana",
"service_id": 1,
"resource": "test_resource",
"action": "read",
},
422,
), # Invalid organisation ID
(
{
"organisation_id": "",
"service_id": 1,
"resource": "test_resource",
"action": "read",
},
422,
), # Blank organisation ID
(
{
"organisation_id": -1,
"service_id": 1,
"resource": "test_resource",
"action": "read",
},
422,
), # Negative organisation ID
# service_id tests
(
{
"organisation_id": 1,
"service_id": "banana",
"resource": "test_resource",
"action": "read",
},
422,
), # Invalid service ID
(
{
"organisation_id": 1,
"service_id": "",
"resource": "test_resource",
"action": "read",
},
422,
), # Blank service ID
(
{
"organisation_id": 1,
"service_id": -1,
"resource": "test_resource",
"action": "read",
},
422,
), # Negative service ID
# resource tests
(
{"organisation_id": 1, "service_id": 1, "resource": 42, "action": "read"},
422,
), # Invalid resource type
# action tests
(
{
"organisation_id": 1,
"service_id": 1,
"resource": "test_resource",
"action": 42,
},
422,
), # Invalid action type
# missing/partial body tests
({}, 422), # Blank body
],
)
@pytest.mark.anyio
async def test_post_perm_search_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
async def test_post_perm_search_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.post("/iam/permissions/search", json=body)
assert resp.status_code == expected_status

View file

@ -1,6 +1,7 @@
"""
[DELETE] /org/ is not tested because the testing client cannot attach a body to a delete request.
"""
import pytest
from httpx import AsyncClient
@ -24,11 +25,12 @@ async def test_get_org_success(default_client: AsyncClient):
@pytest.mark.parametrize(
"query, expected_status",
generate_query_and_status(["org_id"])
"query, expected_status", generate_query_and_status(["org_id"])
)
@pytest.mark.anyio
async def test_get_org_status_checks(default_client: AsyncClient, query: str, expected_status: int):
async def test_get_org_status_checks(
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/org?{query}")
assert resp.status_code == expected_status
@ -53,18 +55,33 @@ async def test_post_org_success(default_client: AsyncClient):
],
)
@pytest.mark.anyio
async def test_post_org_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
async def test_post_org_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.post("/org", json=body)
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_patch_org_questionnaire_partial_success(default_client: AsyncClient, db_session):
async def test_patch_org_questionnaire_partial_success(
default_client: AsyncClient, db_session
):
org_model = db_session.get(Organisation, 1)
org_model.status = "partial"
db_session.flush()
resp = await default_client.patch("/org/questionnaire", json={"organisation_id": 1, "intake_questionnaire": {"question_one": "new answer one", "question_two": None, "question_three": None}, "partial": True})
resp = await default_client.patch(
"/org/questionnaire",
json={
"organisation_id": 1,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": True,
},
)
data = resp.json()
assert resp.status_code == 200
@ -83,24 +100,56 @@ async def test_patch_org_questionnaire_partial_success(default_client: AsyncClie
({"organisation_id": "Test Org"}, 422),
({"organisation_id": ""}, 422),
({}, 422),
({"organisation_id": "1", "intake_questionnaire": {"question_one": 42}, "partial": True}, 422),
({"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}}, 422),
({"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}, "partial": 42}, 422),
(
{
"organisation_id": "1",
"intake_questionnaire": {"question_one": 42},
"partial": True,
},
422,
),
(
{"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}},
422,
),
(
{
"organisation_id": "1",
"intake_questionnaire": {"question_one": "valid"},
"partial": 42,
},
422,
),
],
)
@pytest.mark.anyio
async def test_patch_questionnaire_partial_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
async def test_patch_questionnaire_partial_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.patch("/org/questionnaire", json=body)
assert resp.status_code == expected_status
@pytest.mark.anyio
async def test_patch_org_questionnaire_submit_success(default_client: AsyncClient, db_session):
async def test_patch_org_questionnaire_submit_success(
default_client: AsyncClient, db_session
):
org_model = db_session.get(Organisation, 1)
org_model.status = "partial"
db_session.flush()
resp = await default_client.patch("/org/questionnaire", json={"organisation_id": 1, "intake_questionnaire": {"question_one": "new answer one", "question_two": None, "question_three": None}, "partial": False})
resp = await default_client.patch(
"/org/questionnaire",
json={
"organisation_id": 1,
"intake_questionnaire": {
"question_one": "new answer one",
"question_two": None,
"question_three": None,
},
"partial": False,
},
)
data = resp.json()
assert resp.status_code == 200
@ -113,12 +162,13 @@ async def test_patch_org_questionnaire_submit_success(default_client: AsyncClien
@pytest.mark.parametrize(
"status",
["partial", "submitted", "remediation", "approved", "rejected", "removed"]
"status", ["partial", "submitted", "remediation", "approved", "rejected", "removed"]
)
@pytest.mark.anyio
async def test_patch_org_status_success(default_client: AsyncClient, status: str):
resp = await default_client.patch("/org/status", json={"organisation_id": 1, "status": status})
resp = await default_client.patch(
"/org/status", json={"organisation_id": 1, "status": status}
)
data = resp.json()
assert resp.status_code == 200
@ -138,7 +188,9 @@ async def test_patch_org_status_success(default_client: AsyncClient, status: str
],
)
@pytest.mark.anyio
async def test_patch_org_status_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
async def test_patch_org_status_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.patch("/org/status", json=body)
assert resp.status_code == expected_status
@ -161,11 +213,12 @@ async def test_get_org_users_success(default_client: AsyncClient):
@pytest.mark.parametrize(
"query, expected_status",
generate_query_and_status(["org_id"])
"query, expected_status", generate_query_and_status(["org_id"])
)
@pytest.mark.anyio
async def test_get_org_users_status_checks(default_client: AsyncClient, query: str, expected_status: int):
async def test_get_org_users_status_checks(
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/org/users?{query}")
assert resp.status_code == expected_status
@ -173,10 +226,19 @@ async def test_get_org_users_status_checks(default_client: AsyncClient, query: s
@pytest.mark.anyio
async def test_post_org_user_success(default_client: AsyncClient, db_session):
db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
db_session.add(
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush()
resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 2})
resp = await default_client.post(
"/org/user", json={"organisation_id": 1, "user_id": 2}
)
data = resp.json()
assert resp.status_code == 200
@ -197,8 +259,17 @@ async def test_post_org_user_success(default_client: AsyncClient, db_session):
],
)
@pytest.mark.anyio
async def test_post_org_user_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int, db_session):
db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
async def test_post_org_user_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int, db_session
):
db_session.add(
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush()
resp = await default_client.post("/org/user", json=body)
@ -208,12 +279,21 @@ async def test_post_org_user_status_checks(default_client: AsyncClient, body: di
@pytest.mark.anyio
async def test_patch_org_root_user_success(default_client: AsyncClient, db_session):
db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
db_session.add(
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush()
db_session.add(OrgUsers(org_id=1, user_id=2))
db_session.flush()
resp = await default_client.patch("/org/root_user", json={"organisation_id": 1, "user_id": 2})
resp = await default_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 2}
)
data = resp.json()
assert resp.status_code == 200
@ -234,8 +314,17 @@ async def test_patch_org_root_user_success(default_client: AsyncClient, db_sessi
],
)
@pytest.mark.anyio
async def test_patch_root_user_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int, db_session):
db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
async def test_patch_root_user_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int, db_session
):
db_session.add(
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush()
db_session.add(OrgUsers(org_id=1, user_id=2))
db_session.flush()
@ -247,10 +336,19 @@ async def test_patch_root_user_status_checks(default_client: AsyncClient, body:
@pytest.mark.anyio
async def test_patch_org_root_user_non_member(default_client: AsyncClient, db_session):
db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234"))
db_session.add(
User(
email="user@test.org",
first_name="User",
last_name="Test",
oidc_id="abcd-efgh-ijkl-1234",
)
)
db_session.flush()
resp = await default_client.patch("/org/root_user", json={"organisation_id": 1, "user_id": 2})
resp = await default_client.patch(
"/org/root_user", json={"organisation_id": 1, "user_id": 2}
)
data = resp.json()
assert resp.status_code == 422
@ -269,23 +367,23 @@ async def test_get_org_groups_success(default_client: AsyncClient):
@pytest.mark.parametrize(
"query, expected_status",
generate_query_and_status(["org_id"])
"query, expected_status", generate_query_and_status(["org_id"])
)
@pytest.mark.anyio
async def test_get_org_groups_status_checks(default_client: AsyncClient, query: str, expected_status: int):
async def test_get_org_groups_status_checks(
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/org/groups?{query}")
assert resp.status_code == expected_status
@pytest.mark.parametrize(
"contact_type",
["billing", "security", "owner"]
)
@pytest.mark.parametrize("contact_type", ["billing", "security", "owner"])
@pytest.mark.anyio
async def test_get_org_contact_success(default_client: AsyncClient, contact_type: str):
resp = await default_client.get(f"/org/contact?org_id=1&contact_type={contact_type}")
resp = await default_client.get(
f"/org/contact?org_id=1&contact_type={contact_type}"
)
data = resp.json()
assert resp.status_code == 200
@ -327,7 +425,9 @@ async def test_get_org_contact_success(default_client: AsyncClient, contact_type
],
)
@pytest.mark.anyio
async def test_get_org_contact_status_checks(default_client: AsyncClient, query: str, expected_status: int):
async def test_get_org_contact_status_checks(
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/org/contact?{query}")
assert resp.status_code == expected_status
@ -348,11 +448,16 @@ async def test_get_org_contact_status_checks(default_client: AsyncClient, query:
("address_region", "Glasgow City"),
("country_code", "GB"),
("postal_code", "G1 1AA"),
]
],
)
@pytest.mark.anyio
async def test_patch_org_contact_success(default_client: AsyncClient, key: str, value: str):
resp = await default_client.patch("/org/contact", json={"organisation_id": 1, "contact_type": "billing", key: value})
async def test_patch_org_contact_success(
default_client: AsyncClient, key: str, value: str
):
resp = await default_client.patch(
"/org/contact",
json={"organisation_id": 1, "contact_type": "billing", key: value},
)
data = resp.json()
assert resp.status_code == 200
@ -379,7 +484,9 @@ async def test_patch_org_contact_success(default_client: AsyncClient, key: str,
],
)
@pytest.mark.anyio
async def test_patch_org_contact_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
async def test_patch_org_contact_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.patch("/org/contact", json=body)
assert resp.status_code == expected_status

View file

@ -1,6 +1,7 @@
"""
409 on [POST]/service/ not tested because SQLite throws a different error than Postgres
"""
import pytest
from httpx import AsyncClient
@ -19,11 +20,12 @@ async def test_get_services_success(default_client: AsyncClient):
@pytest.mark.parametrize(
"query, expected_status",
generate_query_and_status(["org_id"])
"query, expected_status", generate_query_and_status(["org_id"])
)
@pytest.mark.anyio
async def test_get_services_status_checks(default_client: AsyncClient, query: str, expected_status: int):
async def test_get_services_status_checks(
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/service/?{query}")
assert resp.status_code == expected_status
@ -49,7 +51,9 @@ async def test_post_service_success(default_client: AsyncClient):
],
)
@pytest.mark.anyio
async def test_post_services_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
async def test_post_services_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.post("/service/", json=body)
assert resp.status_code == expected_status
@ -77,7 +81,9 @@ async def test_patch_service_success(default_client: AsyncClient):
],
)
@pytest.mark.anyio
async def test_patch_services_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int):
async def test_patch_services_status_checks(
default_client: AsyncClient, body: dict[str, str], expected_status: int
):
resp = await default_client.patch("/service/key", json=body)
assert resp.status_code == expected_status

View file

@ -8,38 +8,40 @@ from httpx import AsyncClient
from .conftest import generate_query_and_status
@pytest.mark.anyio
async def test_get_self_db_success(default_client: AsyncClient):
resp = await default_client.get("/user/self/db")
data = resp.json()
resp = await default_client.get("/user/self/db")
data = resp.json()
assert resp.status_code == 200
assert data["first_name"] == "Admin"
assert data["last_name"] == "Test"
assert data["email"] == "admin@test.com"
assert "organisations" in data
assert "groups" in data
assert resp.status_code == 200
assert data["first_name"] == "Admin"
assert data["last_name"] == "Test"
assert data["email"] == "admin@test.com"
assert "organisations" in data
assert "groups" in data
@pytest.mark.anyio
async def test_get_user_success(default_client: AsyncClient):
resp = await default_client.get("/user/?user_id=1")
data = resp.json()
resp = await default_client.get("/user/?user_id=1")
data = resp.json()
assert resp.status_code == 200
assert data["first_name"] == "Admin"
assert data["last_name"] == "Test"
assert data["email"] == "admin@test.com"
assert "organisations" in data
assert "groups" in data
assert resp.status_code == 200
assert data["first_name"] == "Admin"
assert data["last_name"] == "Test"
assert data["email"] == "admin@test.com"
assert "organisations" in data
assert "groups" in data
@pytest.mark.anyio
@pytest.mark.parametrize(
"query, expected_status",
generate_query_and_status(["user_id"])
"query, expected_status", generate_query_and_status(["user_id"])
)
async def test_get_user_status_checks(default_client: AsyncClient, query: str, expected_status: int):
resp = await default_client.get(f"/user/?{query}")
async def test_get_user_status_checks(
default_client: AsyncClient, query: str, expected_status: int
):
resp = await default_client.get(f"/user/?{query}")
assert resp.status_code == expected_status
assert resp.status_code == expected_status