diff --git a/pyproject.toml b/pyproject.toml index 297f21d..63a648b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,10 @@ exclude = [ ".alembic" ] +[tool.ruff.format] +quote-style = "double" +indent-style = "tab" + [project] name = "cloud-api" version = "0.1.0" diff --git a/src/_module_template/config.py b/src/_module_template/config.py index 45d3182..927e7bc 100644 --- a/src/_module_template/config.py +++ b/src/_module_template/config.py @@ -2,4 +2,4 @@ Configurations for the module Exports: -""" \ No newline at end of file +""" diff --git a/src/_module_template/constants.py b/src/_module_template/constants.py index cc72009..9e8da5b 100644 --- a/src/_module_template/constants.py +++ b/src/_module_template/constants.py @@ -2,4 +2,4 @@ Constants for the module Exports: -""" \ No newline at end of file +""" diff --git a/src/_module_template/dependencies.py b/src/_module_template/dependencies.py index 71750bc..c61b149 100644 --- a/src/_module_template/dependencies.py +++ b/src/_module_template/dependencies.py @@ -3,4 +3,4 @@ Dependencies related to the module Exports: - : : -""" \ No newline at end of file +""" diff --git a/src/_module_template/exceptions.py b/src/_module_template/exceptions.py index 402940a..976b6c3 100644 --- a/src/_module_template/exceptions.py +++ b/src/_module_template/exceptions.py @@ -3,4 +3,4 @@ Exceptions related to the modules Exceptions: - : Details e.g. optional params -""" \ No newline at end of file +""" diff --git a/src/_module_template/models.py b/src/_module_template/models.py index d03c882..d059461 100644 --- a/src/_module_template/models.py +++ b/src/_module_template/models.py @@ -6,4 +6,4 @@ Models: - - - -""" \ No newline at end of file +""" diff --git a/src/_module_template/router.py b/src/_module_template/router.py index c81fc26..09250df 100644 --- a/src/_module_template/router.py +++ b/src/_module_template/router.py @@ -17,6 +17,7 @@ Exports: - Dependencies should be used for db model get and validation where possible - Verify module level docstring is still accurate after updates """ + from fastapi import APIRouter diff --git a/src/_module_template/schemas.py b/src/_module_template/schemas.py index 71cfc07..f72482a 100644 --- a/src/_module_template/schemas.py +++ b/src/_module_template/schemas.py @@ -5,4 +5,4 @@ Models follow the nomenclature of: - Sub-models: "Schema" - Mixins: "Mixin" - Models: "" ie "" -""" \ No newline at end of file +""" diff --git a/src/_module_template/service.py b/src/_module_template/service.py index 139a237..39764da 100644 --- a/src/_module_template/service.py +++ b/src/_module_template/service.py @@ -2,4 +2,4 @@ Module specific business logic for the module Exports: -""" \ No newline at end of file +""" diff --git a/src/_module_template/utils.py b/src/_module_template/utils.py index 4e99ff6..5f52b1c 100644 --- a/src/_module_template/utils.py +++ b/src/_module_template/utils.py @@ -1,3 +1,3 @@ """ Non-business logic reusable functions and classes for the module -""" \ No newline at end of file +""" diff --git a/src/admin/config.py b/src/admin/config.py index 46e4142..1b96e18 100644 --- a/src/admin/config.py +++ b/src/admin/config.py @@ -1,3 +1,3 @@ """ Configurations for the admin module -""" \ No newline at end of file +""" diff --git a/src/admin/constants.py b/src/admin/constants.py index c75163f..d02c272 100644 --- a/src/admin/constants.py +++ b/src/admin/constants.py @@ -1,3 +1,3 @@ """ Constants for the admin module -""" \ No newline at end of file +""" diff --git a/src/admin/dependencies.py b/src/admin/dependencies.py index aff00b3..0b7fefb 100644 --- a/src/admin/dependencies.py +++ b/src/admin/dependencies.py @@ -1,3 +1,3 @@ """ Dependencies for the admin module -""" \ No newline at end of file +""" diff --git a/src/admin/exceptions.py b/src/admin/exceptions.py index 513805c..18dba86 100644 --- a/src/admin/exceptions.py +++ b/src/admin/exceptions.py @@ -1,3 +1,3 @@ """ Custom exceptions for the admin module -""" \ No newline at end of file +""" diff --git a/src/admin/models.py b/src/admin/models.py index 304e336..1faf06c 100644 --- a/src/admin/models.py +++ b/src/admin/models.py @@ -1,3 +1,3 @@ """ Database models for the admin module -""" \ No newline at end of file +""" diff --git a/src/admin/router.py b/src/admin/router.py index e0246a4..9fe91eb 100644 --- a/src/admin/router.py +++ b/src/admin/router.py @@ -4,6 +4,7 @@ Router endpoints for the admin module Exports: - router: fastapi.APIRouter """ + from fastapi import APIRouter router = APIRouter( diff --git a/src/admin/schemas.py b/src/admin/schemas.py index 1289bcb..5d65867 100644 --- a/src/admin/schemas.py +++ b/src/admin/schemas.py @@ -1,3 +1,3 @@ """ Pydantic models for the admin module -""" \ No newline at end of file +""" diff --git a/src/admin/service.py b/src/admin/service.py index 1db3599..9dbfebb 100644 --- a/src/admin/service.py +++ b/src/admin/service.py @@ -1,3 +1,3 @@ """ Module specific business logic for the admin module -""" \ No newline at end of file +""" diff --git a/src/admin/utils.py b/src/admin/utils.py index e570f14..161d101 100644 --- a/src/admin/utils.py +++ b/src/admin/utils.py @@ -1,3 +1,3 @@ """ Non-business logic reusable functions and classes for the admin module -""" \ No newline at end of file +""" diff --git a/src/api.py b/src/api.py index 67ee8d1..1461fe1 100644 --- a/src/api.py +++ b/src/api.py @@ -1,6 +1,7 @@ """ This module hooks the routers for the main endpoints into a single router for importing to the app. """ + from fastapi import APIRouter from src.auth.router import router as auth_router @@ -12,9 +13,7 @@ from src.iam.router import router as iam_router from src.service.router import router as service_router -api_router = APIRouter( - prefix="/api/v1" -) +api_router = APIRouter(prefix="/api/v1") api_router.include_router(auth_router) api_router.include_router(contact_router) @@ -27,5 +26,5 @@ api_router.include_router(iam_router) @api_router.get("/healthcheck", include_in_schema=False) def healthcheck(): - """Simple healthcheck endpoint.""" - return {"status": "ok"} + """Simple healthcheck endpoint.""" + return {"status": "ok"} diff --git a/src/auth/config.py b/src/auth/config.py index 979c0e5..030c36e 100644 --- a/src/auth/config.py +++ b/src/auth/config.py @@ -4,12 +4,14 @@ Configurations for the auth module Exports: - auth_settings: Contains OIDC information """ + from src.config import CustomBaseSettings + class AuthConfig(CustomBaseSettings): - OIDC_CONFIG: str = "" - OIDC_ISSUER: str = "" - CLIENT_ID: str = "" + OIDC_CONFIG: str = "" + OIDC_ISSUER: str = "" + CLIENT_ID: str = "" auth_settings = AuthConfig() diff --git a/src/auth/constants.py b/src/auth/constants.py index faabd82..382aac7 100644 --- a/src/auth/constants.py +++ b/src/auth/constants.py @@ -1,3 +1,3 @@ """ Constants for the auth module -""" \ No newline at end of file +""" diff --git a/src/auth/dependencies.py b/src/auth/dependencies.py index 959a830..e29b641 100644 --- a/src/auth/dependencies.py +++ b/src/auth/dependencies.py @@ -7,18 +7,24 @@ Exports: - org_model_root_claim_body_dependency: org_model: verifies org exists and user is either root or su, gets org from body - super_admin_dependency: user_model: verifies the user is a super admin """ + from typing import Annotated from fastapi import Depends from src.user.dependencies import user_model_claims_dependency from src.user.models import User -from src.organisation.dependencies import org_model_query_dependency, org_model_body_dependency +from src.organisation.dependencies import ( + org_model_query_dependency, + org_model_body_dependency, +) from src.organisation.models import Organisation as Org from src.auth.exceptions import UnauthorizedException -async def org_query_user_claims(org_model: org_model_query_dependency, user_model: user_model_claims_dependency): +async def org_query_user_claims( + org_model: org_model_query_dependency, user_model: user_model_claims_dependency +): if user_model in org_model.user_rel: return True @@ -28,7 +34,11 @@ async def org_query_user_claims(org_model: org_model_query_dependency, user_mode org_query_user_claims_dependency = Annotated[bool, Depends(org_query_user_claims)] -async def org_query_root_claims(user_model: user_model_claims_dependency, org_model: org_model_query_dependency, su_emails: su_list_dependency): +async def org_query_root_claims( + user_model: user_model_claims_dependency, + org_model: org_model_query_dependency, + su_emails: su_list_dependency, +): if org_model.root_user_id == user_model.id: return org_model @@ -41,10 +51,16 @@ async def org_query_root_claims(user_model: user_model_claims_dependency, org_mo raise UnauthorizedException(message="Must be the org's root user") -org_model_root_claim_query_dependency = Annotated[type[Org], Depends(org_query_root_claims)] +org_model_root_claim_query_dependency = Annotated[ + type[Org], Depends(org_query_root_claims) +] -async def org_body_root_claims(user_model: user_model_claims_dependency, org_model: org_model_body_dependency, su_emails: su_list_dependency): +async def org_body_root_claims( + user_model: user_model_claims_dependency, + org_model: org_model_body_dependency, + su_emails: su_list_dependency, +): if org_model.root_user_id == user_model.id: return org_model @@ -57,21 +73,29 @@ async def org_body_root_claims(user_model: user_model_claims_dependency, org_mod raise UnauthorizedException(message="Must be the org's root user") -org_model_root_claim_body_dependency = Annotated[type[Org], Depends(org_body_root_claims)] +org_model_root_claim_body_dependency = Annotated[ + type[Org], Depends(org_body_root_claims) +] def get_super_admin_list(): return [] + def empty_su_list(): return [] + def testing_su_list(): return ["admin@test.com"] + su_list_dependency = Annotated[list[User], Depends(get_super_admin_list)] -async def user_model_super_admin(user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency): + +async def user_model_super_admin( + user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency +): if user_model.email in super_admin_emails: return user_model diff --git a/src/auth/exceptions.py b/src/auth/exceptions.py index 613b166..f4a2cba 100644 --- a/src/auth/exceptions.py +++ b/src/auth/exceptions.py @@ -4,6 +4,7 @@ Module specific exceptions for the auth module Exceptions: - UnauthorizedException: Takes an optional message string """ + from typing import Optional from fastapi import HTTPException, status diff --git a/src/auth/models.py b/src/auth/models.py index 4717477..aaa8362 100644 --- a/src/auth/models.py +++ b/src/auth/models.py @@ -1,3 +1,3 @@ """ Database models for the auth module -""" \ No newline at end of file +""" diff --git a/src/auth/router.py b/src/auth/router.py index 9cd7fad..ee32033 100644 --- a/src/auth/router.py +++ b/src/auth/router.py @@ -4,8 +4,9 @@ Router endpoints for the auth module Exports: - router: fastapi.APIRouter """ + from fastapi import APIRouter router = APIRouter( tags=["auth"], -) \ No newline at end of file +) diff --git a/src/auth/schemas.py b/src/auth/schemas.py index 279bb1b..5f5ac35 100644 --- a/src/auth/schemas.py +++ b/src/auth/schemas.py @@ -1,3 +1,3 @@ """ Pydantic models for the auth module -""" \ No newline at end of file +""" diff --git a/src/auth/service.py b/src/auth/service.py index f156a9d..a27d421 100644 --- a/src/auth/service.py +++ b/src/auth/service.py @@ -4,6 +4,7 @@ Module specific business logic for the auth module Exports: - claims_dependency: Dict[str, Any] containing OIDC claims and database ID """ + import json import requests @@ -25,11 +26,14 @@ from src.database import db_dependency oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG) oidc_dependency = Annotated[str, Depends(oidc)] + def get_dev_user(): return {"db_id": 1} -async def get_current_user(oidc_auth_string: oidc_dependency, db: db_dependency) -> dict[str, Any]: +async def get_current_user( + oidc_auth_string: oidc_dependency, db: db_dependency +) -> dict[str, Any]: config_url = urlopen(auth_settings.OIDC_CONFIG) config = json.loads(config_url.read()) jwks_uri = config["jwks_uri"] @@ -41,10 +45,7 @@ async def get_current_user(oidc_auth_string: oidc_dependency, db: db_dependency) "iss": {"essential": True, "value": auth_settings.OIDC_ISSUER}, } - token = jwt.decode( - oidc_auth_string.replace("Bearer ", ""), - jwk_keys - ) + token = jwt.decode(oidc_auth_string.replace("Bearer ", ""), jwk_keys) claims_requests = jwt.JWTClaimsRegistry(**claims_options) diff --git a/src/auth/utils.py b/src/auth/utils.py index ed66e7c..178518a 100644 --- a/src/auth/utils.py +++ b/src/auth/utils.py @@ -1,3 +1,3 @@ """ Non-business logic reusable functions and classes for the auth module -""" \ No newline at end of file +""" diff --git a/src/config.py b/src/config.py index afecd8f..ddce0c8 100644 --- a/src/config.py +++ b/src/config.py @@ -16,25 +16,26 @@ from src.constants import Environment class CustomBaseSettings(BaseSettings): - model_config = SettingsConfigDict( - env_file=".env", env_file_encoding="utf-8", extra="ignore" - ) + model_config = SettingsConfigDict( + env_file=".env", env_file_encoding="utf-8", extra="ignore" + ) class Config(CustomBaseSettings): - APP_VERSION: str = "0.1" - ENVIRONMENT: Environment = Environment.PRODUCTION - SECRET_KEY: SecretStr = "" - DISABLE_AUTH: bool = False + APP_VERSION: str = "0.1" + ENVIRONMENT: Environment = Environment.PRODUCTION + SECRET_KEY: SecretStr = "" + DISABLE_AUTH: bool = False - CORS_ORIGINS: list[str] = ["*"] - CORS_ORIGINS_REGEX: str | None = None - CORS_HEADERS: list[str] = ["*"] + CORS_ORIGINS: list[str] = ["*"] + CORS_ORIGINS_REGEX: str | None = None + CORS_HEADERS: list[str] = ["*"] + + DATABASE_NAME: str = "fastapi-exp" + DATABASE_PORT: str = "5432" + DATABASE_HOSTNAME: str = "localhost" + DATABASE_CREDENTIALS: SecretStr = ":" - DATABASE_NAME: str = "fastapi-exp" - DATABASE_PORT: str = "5432" - DATABASE_HOSTNAME: str = "localhost" - DATABASE_CREDENTIALS: SecretStr = ":" settings = Config() @@ -43,17 +44,21 @@ DATABASE_PORT = settings.DATABASE_PORT DATABASE_HOSTNAME = settings.DATABASE_HOSTNAME DATABASE_CREDENTIALS = settings.DATABASE_CREDENTIALS.get_secret_value() # this will support special chars for credentials -_DATABASE_CREDENTIAL_USER, _DATABASE_CREDENTIAL_PASSWORD = str(DATABASE_CREDENTIALS).split(":") +_DATABASE_CREDENTIAL_USER, _DATABASE_CREDENTIAL_PASSWORD = str( + DATABASE_CREDENTIALS +).split(":") _QUOTED_DATABASE_PASSWORD = parse.quote_plus(str(_DATABASE_CREDENTIAL_PASSWORD)) -SQLALCHEMY_DATABASE_URI = SecretStr(f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}") +SQLALCHEMY_DATABASE_URI = SecretStr( + f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}" +) if settings.ENVIRONMENT == Environment.TESTING: - SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:") + SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:") app_configs: dict[str, Any] = {"title": "App API"} if settings.ENVIRONMENT.is_deployed: - app_configs["root_path"] = f"/v{settings.APP_VERSION}" + app_configs["root_path"] = f"/v{settings.APP_VERSION}" if not settings.ENVIRONMENT.is_debug: - app_configs["openapi_url"] = None # hide docs + app_configs["openapi_url"] = None # hide docs diff --git a/src/constants.py b/src/constants.py index ab33afb..b0725bf 100644 --- a/src/constants.py +++ b/src/constants.py @@ -4,6 +4,7 @@ Global constants Classes: - Environment(StrEnum): LOCAL, TESTING, STAGING, PRODUCTION """ + from enum import StrEnum, auto diff --git a/src/contact/config.py b/src/contact/config.py index 2253a68..7480691 100644 --- a/src/contact/config.py +++ b/src/contact/config.py @@ -1,3 +1,3 @@ """ Configurations for the contact module -""" \ No newline at end of file +""" diff --git a/src/contact/constants.py b/src/contact/constants.py index 41f6ded..ad08c0e 100644 --- a/src/contact/constants.py +++ b/src/contact/constants.py @@ -1,3 +1,3 @@ """ Constants for the contact module -""" \ No newline at end of file +""" diff --git a/src/contact/dependencies.py b/src/contact/dependencies.py index de1d404..0844fd3 100644 --- a/src/contact/dependencies.py +++ b/src/contact/dependencies.py @@ -1,3 +1,3 @@ """ Dependencies for the contact module -""" \ No newline at end of file +""" diff --git a/src/contact/exceptions.py b/src/contact/exceptions.py index 6710bf3..55e9e30 100644 --- a/src/contact/exceptions.py +++ b/src/contact/exceptions.py @@ -4,6 +4,7 @@ Exceptions related to the contact module Exports: - ContactNotFoundException: Takes an optional contact ID int """ + from typing import Optional from fastapi import HTTPException, status @@ -11,7 +12,11 @@ from fastapi import HTTPException, status class ContactNotFoundException(HTTPException): def __init__(self, contact_id: Optional[int] = None) -> None: - detail = "Contact not found" if contact_id is None else f"Contact with ID '{contact_id}' was not found." + detail = ( + "Contact not found" + if contact_id is None + else f"Contact with ID '{contact_id}' was not found." + ) super().__init__( status_code=status.HTTP_404_NOT_FOUND, detail=detail, diff --git a/src/contact/models.py b/src/contact/models.py index 3369501..d741bd2 100644 --- a/src/contact/models.py +++ b/src/contact/models.py @@ -5,6 +5,7 @@ Models: - Contact: id[pk], email, first_name, last_name, phonenumber, vat_number street_address, street_address_line_2, post_office_box_number, address_locality, country_code, address_region, postal_code """ + from sqlalchemy import Column, Integer, String, ForeignKey from src.database import Base @@ -23,9 +24,11 @@ class Contact(Base): street_address = Column(String) street_address_line_2 = Column(String) post_office_box_number = Column(String, default=None, nullable=True) - locality = Column(String) # Ie City - country_code = Column(String) # Eg GB + locality = Column(String) # Ie City + country_code = Column(String) # Eg GB address_region = Column(String, default=None, nullable=True) postal_code = Column(String) - org_id = Column(Integer, ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False) + org_id = Column( + Integer, ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False + ) diff --git a/src/contact/router.py b/src/contact/router.py index cdab37f..2e5f8f4 100644 --- a/src/contact/router.py +++ b/src/contact/router.py @@ -1,10 +1,11 @@ """ Router endpoints for the contact module """ + from fastapi import APIRouter router = APIRouter( prefix="/contact", tags=["contact"], -) \ No newline at end of file +) diff --git a/src/contact/schemas.py b/src/contact/schemas.py index b008739..b9cec61 100644 --- a/src/contact/schemas.py +++ b/src/contact/schemas.py @@ -5,6 +5,7 @@ Models: - ContactAddress - ContactModel: Contains ContactAddress as a property """ + from typing import Optional from pydantic import EmailStr, ConfigDict diff --git a/src/contact/service.py b/src/contact/service.py index e04866a..3223e70 100644 --- a/src/contact/service.py +++ b/src/contact/service.py @@ -1,3 +1,3 @@ """ Module specific business logic for the contact module -""" \ No newline at end of file +""" diff --git a/src/contact/utils.py b/src/contact/utils.py index 6a1d14a..daa2449 100644 --- a/src/contact/utils.py +++ b/src/contact/utils.py @@ -1,3 +1,3 @@ """ Non-business logic reusable functions and classes for the contact module -""" \ No newline at end of file +""" diff --git a/src/database.py b/src/database.py index a56f80d..3838098 100644 --- a/src/database.py +++ b/src/database.py @@ -5,6 +5,7 @@ Exports: - db_dependency - Base (sqlalchemy base model) """ + from typing import Annotated from sqlalchemy import create_engine, StaticPool from sqlalchemy.orm import DeclarativeBase, sessionmaker, Session @@ -16,7 +17,11 @@ from src.config import SQLALCHEMY_DATABASE_URI, settings as global_settings if global_settings.ENVIRONMENT == Environment.TESTING: connect_args = {"check_same_thread": False} - engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value(), connect_args=connect_args, poolclass=StaticPool) + engine = create_engine( + SQLALCHEMY_DATABASE_URI.get_secret_value(), + connect_args=connect_args, + poolclass=StaticPool, + ) else: engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value()) @@ -36,5 +41,7 @@ def get_db(): db_dependency = Annotated[Session, Depends(get_db)] + + class Base(DeclarativeBase): pass diff --git a/src/exceptions.py b/src/exceptions.py index 66507a4..8b3629c 100644 --- a/src/exceptions.py +++ b/src/exceptions.py @@ -5,6 +5,7 @@ Exports: - UnprocessableContentException - ConflictException """ + from typing import Optional from fastapi import HTTPException, status diff --git a/src/iam/config.py b/src/iam/config.py index 165dc07..8fef3ec 100644 --- a/src/iam/config.py +++ b/src/iam/config.py @@ -1,3 +1,3 @@ """ Configurations for the IAM module -""" \ No newline at end of file +""" diff --git a/src/iam/constants.py b/src/iam/constants.py index 0dc94e7..c2623ec 100644 --- a/src/iam/constants.py +++ b/src/iam/constants.py @@ -1,3 +1,3 @@ """ Constants for the IAM module -""" \ No newline at end of file +""" diff --git a/src/iam/dependencies.py b/src/iam/dependencies.py index 37b8e87..72cc683 100644 --- a/src/iam/dependencies.py +++ b/src/iam/dependencies.py @@ -6,6 +6,7 @@ Exports: - group_model_body_dependency: group_model: Gets group model from db, if it exists. Uses group_id from request body. - perm_model_body_dependency: perm_model: Gets perm model from db, if it exists. Uses perm_id from request body. """ + from typing import Annotated, Optional from fastapi import Depends, Query @@ -17,17 +18,22 @@ from src.iam.exceptions import GroupNotFoundException, PermNotFoundException from src.iam.schemas import GroupIDMixin, PermIDMixin -def get_group_model_query(db: db_dependency, group_id: Annotated[int, Query(gt=0)]) -> type[Group]: +def get_group_model_query( + db: db_dependency, group_id: Annotated[int, Query(gt=0)] +) -> type[Group]: group_model = db.get(Group, group_id) if group_model is None: raise GroupNotFoundException(group_id) return group_model + group_model_query_dependency = Annotated[type[Group], Depends(get_group_model_query)] -def get_group_model_body(db: db_dependency, request_model: Optional[GroupIDMixin] = None) -> type[Group]: +def get_group_model_body( + db: db_dependency, request_model: Optional[GroupIDMixin] = None +) -> type[Group]: group_id = getattr(request_model, "group_id", None) if group_id is None: raise GroupNotFoundException() @@ -37,10 +43,13 @@ def get_group_model_body(db: db_dependency, request_model: Optional[GroupIDMixin return group_model + group_model_body_dependency = Annotated[type[Group], Depends(get_group_model_body)] -def get_perm_model_body(db: db_dependency, request_model: Optional[PermIDMixin] = None) -> type[Permission]: +def get_perm_model_body( + db: db_dependency, request_model: Optional[PermIDMixin] = None +) -> type[Permission]: perm_id = getattr(request_model, "permission_id", None) if perm_id is None: raise PermNotFoundException @@ -50,4 +59,5 @@ def get_perm_model_body(db: db_dependency, request_model: Optional[PermIDMixin] return perm_model + perm_model_body_dependency = Annotated[type[Permission], Depends(get_perm_model_body)] diff --git a/src/iam/exceptions.py b/src/iam/exceptions.py index 84a77ed..503b844 100644 --- a/src/iam/exceptions.py +++ b/src/iam/exceptions.py @@ -5,6 +5,7 @@ Exceptions: - GroupNotFoundException: Takes an optional group_id int - PermNotFoundException: Takes an optional perm_id int """ + from typing import Optional from fastapi import HTTPException, status @@ -12,7 +13,11 @@ from fastapi import HTTPException, status class GroupNotFoundException(HTTPException): def __init__(self, group_id: Optional[int] = None) -> None: - detail = "Group not found" if group_id is None else f"User with ID '{group_id}' was not found." + detail = ( + "Group not found" + if group_id is None + else f"User with ID '{group_id}' was not found." + ) super().__init__( status_code=status.HTTP_404_NOT_FOUND, detail=detail, @@ -21,7 +26,11 @@ class GroupNotFoundException(HTTPException): class PermNotFoundException(HTTPException): def __init__(self, perm_id: Optional[int] = None) -> None: - detail = "Permission not found" if perm_id is None else f"User with ID '{perm_id}' was not found." + detail = ( + "Permission not found" + if perm_id is None + else f"User with ID '{perm_id}' was not found." + ) super().__init__( status_code=status.HTTP_404_NOT_FOUND, detail=detail, diff --git a/src/iam/models.py b/src/iam/models.py index f542b70..a35391f 100644 --- a/src/iam/models.py +++ b/src/iam/models.py @@ -17,6 +17,7 @@ Models: - UserGroups: - org_id[FK][PK], user_id[FK][PK], group_id[FK][PK] """ + from sqlalchemy import Column, Integer, String, ForeignKey, UniqueConstraint from sqlalchemy.orm import relationship @@ -32,7 +33,9 @@ class Permission(Base): service_id = Column(Integer, ForeignKey("service.id", ondelete="CASCADE")) - UniqueConstraint("service_id", "resource", "action", name="uniq_permission_resource_and_action") + UniqueConstraint( + "service_id", "resource", "action", name="uniq_permission_resource_and_action" + ) service_rel = relationship("Service", foreign_keys=[service_id]) @@ -41,13 +44,10 @@ class Permission(Base): return self.service_rel.name group_rel = relationship( - "Group", - secondary="group_permissions", - back_populates="permission_rel" + "Group", secondary="group_permissions", back_populates="permission_rel" ) - class Group(Base): __tablename__ = "group" id = Column(Integer, primary_key=True) @@ -55,28 +55,30 @@ class Group(Base): org_id = Column(Integer, ForeignKey("organisation.id", ondelete="CASCADE")) - user_rel = relationship( - "User", - secondary="user_groups", - back_populates="group_rel" - ) + user_rel = relationship("User", secondary="user_groups", back_populates="group_rel") org_rel = relationship("Organisation", back_populates="group_rel") permission_rel = relationship( - "Permission", - secondary="group_permissions", - back_populates="group_rel" + "Permission", secondary="group_permissions", back_populates="group_rel" ) class GroupPermissions(Base): __tablename__ = "group_permissions" - group_id = Column(Integer, ForeignKey("group.id", ondelete="CASCADE"), primary_key=True) - permission_id = Column(Integer, ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True) + group_id = Column( + Integer, ForeignKey("group.id", ondelete="CASCADE"), primary_key=True + ) + permission_id = Column( + Integer, ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True + ) class UserGroups(Base): __tablename__ = "user_groups" - user_id = Column(Integer, ForeignKey("user.id", ondelete="CASCADE"), primary_key=True) - group_id = Column(Integer, ForeignKey("group.id", ondelete="CASCADE"), primary_key=True) + user_id = Column( + Integer, ForeignKey("user.id", ondelete="CASCADE"), primary_key=True + ) + group_id = Column( + Integer, ForeignKey("group.id", ondelete="CASCADE"), primary_key=True + ) diff --git a/src/iam/router.py b/src/iam/router.py index 6ed3ff0..e6715e4 100644 --- a/src/iam/router.py +++ b/src/iam/router.py @@ -15,6 +15,7 @@ Endpoints: - [DELETE](/iam/permission): [super admin]: Removes a permission - [GET](/iam/permissions/search): [root user]: Returns a list of permissions matching a filter(service|resource|action) """ + from fastapi import APIRouter, status from sqlalchemy.exc import IntegrityError from psycopg import errors @@ -25,21 +26,49 @@ from src.database import db_dependency from src.schemas import ResourceName from src.auth.exceptions import UnauthorizedException from src.auth.service import claims_dependency -from src.auth.dependencies import org_model_root_claim_query_dependency, org_model_root_claim_body_dependency, \ - super_admin_dependency +from src.auth.dependencies import ( + org_model_root_claim_query_dependency, + org_model_root_claim_body_dependency, + super_admin_dependency, +) from src.user.models import User from src.user.dependencies import user_model_body_dependency from src.organisation.models import Organisation as Org from src.service.models import Service from src.iam.service import service_key_dependency -from src.iam.models import Permission as Perm, GroupPermissions as GPerms, Group, UserGroups -from src.iam.dependencies import group_model_query_dependency, group_model_body_dependency, perm_model_body_dependency -from src.iam.schemas import IAMGetGroupPermissionsResponse, IAMGetGroupUsersResponse, IAMPostGroupRequest, \ - GroupSchema, IAMPostGroupResponse, IAMPutGroupPermissionRequest, IAMPutGroupPermissionResponse, \ - IAMPutGroupUserRequest, IAMPutGroupUserResponse, IAMDeleteGroupPermissionRequest, IAMDeleteGroupPermissionResponse, \ - IAMDeleteGroupUserRequest, IAMDeleteGroupUserResponse, IAMGetPermissionsResponse, IAMPostPermissionRequest, \ - IAMPostPermissionResponse, IAMDeletePermissionRequest, IAMGetPermissionsSearchRequest, IAMGetPermissionsSearchResponse +from src.iam.models import ( + Permission as Perm, + GroupPermissions as GPerms, + Group, + UserGroups, +) +from src.iam.dependencies import ( + group_model_query_dependency, + group_model_body_dependency, + perm_model_body_dependency, +) +from src.iam.schemas import ( + IAMGetGroupPermissionsResponse, + IAMGetGroupUsersResponse, + IAMPostGroupRequest, + GroupSchema, + IAMPostGroupResponse, + IAMPutGroupPermissionRequest, + IAMPutGroupPermissionResponse, + IAMPutGroupUserRequest, + IAMPutGroupUserResponse, + IAMDeleteGroupPermissionRequest, + IAMDeleteGroupPermissionResponse, + IAMDeleteGroupUserRequest, + IAMDeleteGroupUserResponse, + IAMGetPermissionsResponse, + IAMPostPermissionRequest, + IAMPostPermissionResponse, + IAMDeletePermissionRequest, + IAMGetPermissionsSearchRequest, + IAMGetPermissionsSearchResponse, +) router = APIRouter( tags=["IAM"], @@ -48,26 +77,32 @@ router = APIRouter( @router.post("/can_act_on_resource") -async def can_act_on_resource(valid_key: service_key_dependency, db: db_dependency, user_claims: claims_dependency, - rn: ResourceName, action: str) -> bool: +async def can_act_on_resource( + valid_key: service_key_dependency, + db: db_dependency, + user_claims: claims_dependency, + rn: ResourceName, + action: str, +) -> bool: try: user_id = user_claims["db_id"] rn_org = rn.organisation rn_service = rn.service rn_resource = rn.resource - result = (db.query(Perm) - .join(Service, Service.id == Perm.service_id) - .join(GPerms, GPerms.permission_id == Perm.id) - .join(Group, Group.id == GPerms.group_id) - .join(Org, Org.id == Group.org_id) - .join(UserGroups, UserGroups.group_id == Group.id) - .join(User, User.id == UserGroups.user_id) - .filter(User.id == user_id) - .filter(Org.name == rn_org) - .filter(Service.name == rn_service) - .filter(Perm.resource == rn_resource) - .filter(Perm.action == action) + result = ( + db.query(Perm) + .join(Service, Service.id == Perm.service_id) + .join(GPerms, GPerms.permission_id == Perm.id) + .join(Group, Group.id == GPerms.group_id) + .join(Org, Org.id == Group.org_id) + .join(UserGroups, UserGroups.group_id == Group.id) + .join(User, User.id == UserGroups.user_id) + .filter(User.id == user_id) + .filter(Org.name == rn_org) + .filter(Service.name == rn_service) + .filter(Perm.resource == rn_resource) + .filter(Perm.action == action) ).first() if result: @@ -79,21 +114,31 @@ async def can_act_on_resource(valid_key: service_key_dependency, db: db_dependen @router.get("/group/permissions", response_model=IAMGetGroupPermissionsResponse) -async def get_group_permissions(group_model: group_model_query_dependency, org_model: org_model_root_claim_query_dependency): +async def get_group_permissions( + group_model: group_model_query_dependency, + org_model: org_model_root_claim_query_dependency, +): if group_model.org_id != org_model.id: raise UnauthorizedException("Group does not belong to this organization") return {"permissions": group_model.permission_rel} @router.get("/group/users", response_model=IAMGetGroupUsersResponse) -async def get_group_users(group_model: group_model_query_dependency, org_model: org_model_root_claim_query_dependency): +async def get_group_users( + group_model: group_model_query_dependency, + org_model: org_model_root_claim_query_dependency, +): if group_model.org_id != org_model.id: raise UnauthorizedException("Group does not belong to this organization") return {"users": group_model.user_rel} @router.post("/group", response_model=IAMPostGroupResponse) -async def create_group(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMPostGroupRequest): +async def create_group( + db: db_dependency, + org_model: org_model_root_claim_body_dependency, + request_model: IAMPostGroupRequest, +): group_model = Group(name=request_model.name, org_id=org_model.id) db.add(group_model) @@ -101,9 +146,9 @@ async def create_group(db: db_dependency, org_model: org_model_root_claim_body_d db.flush() except IntegrityError as e: if ( - getattr(e.orig, "pgcode", None) == "23505" # Postgres unique violation - or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation - ): + getattr(e.orig, "pgcode", None) == "23505" # Postgres unique violation + or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation + ): raise ConflictException("Group with this name already exists") response = GroupSchema(**group_model.__dict__) db.commit() @@ -111,7 +156,13 @@ async def create_group(db: db_dependency, org_model: org_model_root_claim_body_d @router.put("/group/permission", response_model=IAMPutGroupPermissionResponse) -async def add_group_permission(db: db_dependency, group_model: group_model_body_dependency, perm_model: perm_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMPutGroupPermissionRequest): +async def add_group_permission( + db: db_dependency, + group_model: group_model_body_dependency, + perm_model: perm_model_body_dependency, + org_model: org_model_root_claim_body_dependency, + request_model: IAMPutGroupPermissionRequest, +): if group_model.org_id != org_model.id: raise UnauthorizedException("Group does not belong to this organization") @@ -121,13 +172,22 @@ async def add_group_permission(db: db_dependency, group_model: group_model_body_ group_model.permission_rel.append(perm_model) db.flush() - response = IAMPutGroupPermissionResponse(group=GroupSchema(**group_model.__dict__), permissions=group_model.permission_rel) + response = IAMPutGroupPermissionResponse( + group=GroupSchema(**group_model.__dict__), + permissions=group_model.permission_rel, + ) db.commit() return response @router.put("/group/user", response_model=IAMPutGroupUserResponse) -async def add_group_user(db: db_dependency, group_model: group_model_body_dependency, user_model: user_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMPutGroupUserRequest): +async def add_group_user( + db: db_dependency, + group_model: group_model_body_dependency, + user_model: user_model_body_dependency, + org_model: org_model_root_claim_body_dependency, + request_model: IAMPutGroupUserRequest, +): if group_model.org_id != org_model.id: raise UnauthorizedException("Group does not belong to this organization") @@ -136,46 +196,70 @@ async def add_group_user(db: db_dependency, group_model: group_model_body_depend group_model.user_rel.append(user_model) db.flush() - response = IAMPutGroupUserResponse(group=GroupSchema(**group_model.__dict__), users=group_model.user_rel) + response = IAMPutGroupUserResponse( + group=GroupSchema(**group_model.__dict__), users=group_model.user_rel + ) db.commit() return response @router.delete("/group/permissions") -async def remove_group_permissions(db: db_dependency, group_model: group_model_body_dependency, perm_model: perm_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMDeleteGroupPermissionRequest): +async def remove_group_permissions( + db: db_dependency, + group_model: group_model_body_dependency, + perm_model: perm_model_body_dependency, + org_model: org_model_root_claim_body_dependency, + request_model: IAMDeleteGroupPermissionRequest, +): if group_model.org_id != org_model.id: raise UnauthorizedException("Group does not belong to this organization") group_model.permission_rel.remove(perm_model) db.flush() - response = IAMDeleteGroupPermissionResponse(group=GroupSchema(**group_model.__dict__), - permissions=group_model.permission_rel) + response = IAMDeleteGroupPermissionResponse( + group=GroupSchema(**group_model.__dict__), + permissions=group_model.permission_rel, + ) db.commit() return response @router.delete("/group/user") -async def remove_group_user(db: db_dependency, group_model: group_model_body_dependency, user_model: user_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMDeleteGroupUserRequest): +async def remove_group_user( + db: db_dependency, + group_model: group_model_body_dependency, + user_model: user_model_body_dependency, + org_model: org_model_root_claim_body_dependency, + request_model: IAMDeleteGroupUserRequest, +): if group_model.org_id != org_model.id: raise UnauthorizedException("Group does not belong to this organization") user_model.group_rel.remove(group_model) db.flush() - response = IAMDeleteGroupUserResponse(group=GroupSchema(**group_model.__dict__), users=group_model.user_rel) + response = IAMDeleteGroupUserResponse( + group=GroupSchema(**group_model.__dict__), users=group_model.user_rel + ) db.commit() return response @router.get("/permissions", response_model=IAMGetPermissionsResponse) -async def get_permissions(db: db_dependency, org_model: org_model_root_claim_query_dependency): +async def get_permissions( + db: db_dependency, org_model: org_model_root_claim_query_dependency +): permission_models = db.query(Perm).all() return {"permissions": permission_models} @router.post("/permission", response_model=IAMPostPermissionResponse) -async def create_new_permission(db: db_dependency, su: super_admin_dependency, request_model: IAMPostPermissionRequest): +async def create_new_permission( + db: db_dependency, + su: super_admin_dependency, + request_model: IAMPostPermissionRequest, +): service_model = db.get(Service, request_model.service_id) if service_model is None: raise ServiceNotFoundException(service_id=request_model.service_id) @@ -186,29 +270,46 @@ async def create_new_permission(db: db_dependency, su: super_admin_dependency, r if isinstance(e.orig, errors.UniqueViolation): raise ConflictException(message="Permission already exists") db.flush() - response = {"service_name": perm_model.service_name, "resource": perm_model.resource, "action": perm_model.action} + response = { + "service_name": perm_model.service_name, + "resource": perm_model.resource, + "action": perm_model.action, + } db.commit() return {"permission": response} @router.delete("/permission", status_code=status.HTTP_204_NO_CONTENT) -async def delete_permission(db: db_dependency, su: super_admin_dependency, perm_model: perm_model_body_dependency, request_model: IAMDeletePermissionRequest): +async def delete_permission( + db: db_dependency, + su: super_admin_dependency, + perm_model: perm_model_body_dependency, + request_model: IAMDeletePermissionRequest, +): db.delete(perm_model) db.commit() @router.post("/permissions/search", response_model=IAMGetPermissionsSearchResponse) -async def post_permissions(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMGetPermissionsSearchRequest): +async def post_permissions( + db: db_dependency, + org_model: org_model_root_claim_body_dependency, + request_model: IAMGetPermissionsSearchRequest, +): permission_query = db.query(Perm) if request_model.service_id is not None: - permission_query = permission_query.filter(Perm.service_id == request_model.service_id) + permission_query = permission_query.filter( + Perm.service_id == request_model.service_id + ) if request_model.resource is not None: - permission_query = permission_query.filter(Perm.resource == request_model.resource) + permission_query = permission_query.filter( + Perm.resource == request_model.resource + ) if request_model.action is not None: - permission_query = permission_query.filter(Perm.action == request_model. action) + permission_query = permission_query.filter(Perm.action == request_model.action) permission_models = permission_query.all() diff --git a/src/iam/schemas.py b/src/iam/schemas.py index 0d370b8..16aaa29 100644 --- a/src/iam/schemas.py +++ b/src/iam/schemas.py @@ -6,6 +6,7 @@ Models follow the nomenclature of: - Mixins: "Mixin" - Models: "" ie "IAMGetGroupPermissionsResponse" """ + from typing import Optional, Annotated from pydantic import EmailStr, ConfigDict, Field @@ -24,6 +25,7 @@ class UserSchema(CustomBaseModel): last_name: str email: EmailStr + class PermissionSchema(CustomBaseModel): model_config = ConfigDict(from_attributes=True, extra="ignore") @@ -31,73 +33,94 @@ class PermissionSchema(CustomBaseModel): resource: str action: str + class GroupSchema(CustomBaseModel): id: int name: str + class GroupIDMixin(CustomBaseModel): group_id: int = Field(gt=0) + class PermIDMixin(CustomBaseModel): permission_id: int = Field(gt=0) + class IAMGetGroupPermissionsResponse(CustomBaseModel): permissions: list[PermissionSchema] + class IAMGetGroupUsersResponse(CustomBaseModel): - users : list[UserSchema] + users: list[UserSchema] + class IAMPostGroupRequest(OrgIDMixin): name: str = Field(min_length=3) + class IAMPostGroupResponse(CustomBaseModel): group: GroupSchema + class IAMPutGroupPermissionRequest(GroupIDMixin, PermIDMixin, OrgIDMixin): pass + class IAMPutGroupPermissionResponse(CustomBaseModel): group: GroupSchema permissions: list[PermissionSchema] + class IAMPutGroupUserRequest(GroupIDMixin, UserIDMixin, OrgIDMixin): pass + class IAMPutGroupUserResponse(CustomBaseModel): group: GroupSchema users: list[UserSchema] + class IAMDeleteGroupPermissionRequest(GroupIDMixin, PermIDMixin): pass + class IAMDeleteGroupPermissionResponse(CustomBaseModel): group: GroupSchema permissions: list[PermissionSchema] + class IAMDeleteGroupUserRequest(GroupIDMixin, UserIDMixin): pass + class IAMDeleteGroupUserResponse(CustomBaseModel): group: GroupSchema users: list[UserSchema] + class IAMGetPermissionsResponse(CustomBaseModel): permissions: list[PermissionSchema] + class IAMPostPermissionRequest(ServiceIDMixin): resource: str action: str + class IAMPostPermissionResponse(CustomBaseModel): permission: PermissionSchema + class IAMDeletePermissionRequest(PermIDMixin): pass + class IAMGetPermissionsSearchRequest(OrgIDMixin): service_id: Annotated[int | None, Field(gt=0)] = None resource: Optional[str] = None action: Optional[str] = None + class IAMGetPermissionsSearchResponse(CustomBaseModel): permissions: list[PermissionSchema] diff --git a/src/iam/service.py b/src/iam/service.py index c6a1030..1e0dfe8 100644 --- a/src/iam/service.py +++ b/src/iam/service.py @@ -4,6 +4,7 @@ Business logic reusable functions related to IAM Exports: - service_key_dependency: bool: verifies request headers contain the correct api key for the service """ + from typing import Annotated from src.service.models import Service @@ -19,10 +20,16 @@ def valid_service_key(db: db_dependency, request: Request, rn: ResourceName) -> if not api_key: raise UnauthorizedException("Missing API key") service = rn.service - result = db.query(Service).filter(Service.name == service).filter(Service.api_key == api_key).first() + result = ( + db.query(Service) + .filter(Service.name == service) + .filter(Service.api_key == api_key) + .first() + ) if result is None: raise UnauthorizedException("Invalid API key") return True + service_key_dependency = Annotated[bool, Depends(valid_service_key)] diff --git a/src/main.py b/src/main.py index c421a94..26dc6e0 100644 --- a/src/main.py +++ b/src/main.py @@ -1,6 +1,7 @@ """ Application root file: Inits the FastAPI application """ + from contextlib import asynccontextmanager from typing import AsyncGenerator diff --git a/src/models.py b/src/models.py index fa198e4..87912aa 100644 --- a/src/models.py +++ b/src/models.py @@ -1,4 +1,3 @@ """ Global database models """ - diff --git a/src/organisation/config.py b/src/organisation/config.py index e24ca5b..7ce00f7 100644 --- a/src/organisation/config.py +++ b/src/organisation/config.py @@ -1,3 +1,3 @@ """ Configurations for the organisation module -""" \ No newline at end of file +""" diff --git a/src/organisation/constants.py b/src/organisation/constants.py index ced0682..79c22fd 100644 --- a/src/organisation/constants.py +++ b/src/organisation/constants.py @@ -5,6 +5,7 @@ Classes: - Status(StrEnum): PARTIAL, SUBMITTED, REMEDIATION, APPROVED, REJECTED, REMOVED - ContactType(StrEnum): BILLING, SECURITY, OWNER """ + from enum import StrEnum, auto diff --git a/src/organisation/dependencies.py b/src/organisation/dependencies.py index 728b8d0..35b09fc 100644 --- a/src/organisation/dependencies.py +++ b/src/organisation/dependencies.py @@ -5,6 +5,7 @@ Exports: - org_model_query_dependency: org_model: Gets org model from db, if it exists. Uses org_id from query param. Also verifies if the org has been approved. - org_model_body_dependency: org_model: Gets org model from db, if it exists. Uses org_id from request body. Also verifies if the org has been approved. """ + from typing import Annotated, Optional from sqlalchemy.orm import Session @@ -25,25 +26,40 @@ def get_org_model(db: Session, request: Request, org_id: int): root = "/api/v1" - pre_approval_endpoints = [f"PATCH{root}/org/status", f"PATCH{root}/org/questionnaire", f"GET{root}/org", f"GET{root}/org/contact", f"PATCH{root}/org/contact"] + pre_approval_endpoints = [ + f"PATCH{root}/org/status", + f"PATCH{root}/org/questionnaire", + f"GET{root}/org", + f"GET{root}/org/contact", + f"PATCH{root}/org/contact", + ] current_request = f"{request.method}{request.url.path}" - if current_request not in pre_approval_endpoints and org_model.status != OrgStatus.APPROVED: + if ( + current_request not in pre_approval_endpoints + and org_model.status != OrgStatus.APPROVED + ): raise AwaitingApprovalException(org_id) return org_model -def get_org_model_query(db: db_dependency, request: Request, org_id: Annotated[int, Query(gt=0)]) -> type[Org]: +def get_org_model_query( + db: db_dependency, request: Request, org_id: Annotated[int, Query(gt=0)] +) -> type[Org]: return get_org_model(db, request, org_id) + org_model_query_dependency = Annotated[type[Org], Depends(get_org_model_query)] -def get_org_model_body(db: db_dependency, request: Request, request_model: OrgIDMixin) -> type[Org]: +def get_org_model_body( + db: db_dependency, request: Request, request_model: OrgIDMixin +) -> type[Org]: org_id: Optional[int] = getattr(request_model, "organisation_id", None) if org_id is None: raise OrgNotFoundException return get_org_model(db, request, org_id) + org_model_body_dependency = Annotated[type[Org], Depends(get_org_model_body)] diff --git a/src/organisation/exceptions.py b/src/organisation/exceptions.py index 8fe61cc..a56b395 100644 --- a/src/organisation/exceptions.py +++ b/src/organisation/exceptions.py @@ -5,6 +5,7 @@ Exceptions: - OrgNotFoundException: Takes an optional org_id int - AwaitingApprovalException: Takes an optional org_id int """ + from typing import Optional from fastapi import HTTPException, status @@ -12,15 +13,24 @@ from fastapi import HTTPException, status class OrgNotFoundException(HTTPException): def __init__(self, org_id: Optional[int] = None) -> None: - detail = "Organisation not found" if org_id is None else f"Organisation with ID '{org_id}' was not found." + detail = ( + "Organisation not found" + if org_id is None + else f"Organisation with ID '{org_id}' was not found." + ) super().__init__( status_code=status.HTTP_404_NOT_FOUND, detail=detail, ) + class AwaitingApprovalException(HTTPException): def __init__(self, org_id: Optional[int] = None) -> None: - detail = "Organisation has not been approved." if org_id is None else f"Organisation with ID '{org_id}' has not been approved." + detail = ( + "Organisation has not been approved." + if org_id is None + else f"Organisation with ID '{org_id}' has not been approved." + ) super().__init__( status_code=status.HTTP_401_UNAUTHORIZED, detail=detail, diff --git a/src/organisation/models.py b/src/organisation/models.py index 3663f2c..e99d64f 100644 --- a/src/organisation/models.py +++ b/src/organisation/models.py @@ -13,6 +13,7 @@ Models: - owner_contact_rel: ORM relationship to Contact with owner_contact FK - OrgUsers: org_id[FK][PK], user_id[FK][PK] """ + from sqlalchemy import Column, Integer, String, ForeignKey, JSON from sqlalchemy.orm import relationship @@ -34,9 +35,7 @@ class Organisation(Base): owner_contact_id = Column(Integer, ForeignKey("contact.id")) user_rel = relationship( - "User", - secondary="orgusers", - back_populates="organisation_rel" + "User", secondary="orgusers", back_populates="organisation_rel" ) group_rel = relationship("Group", back_populates="org_rel") @@ -54,5 +53,9 @@ class Organisation(Base): class OrgUsers(Base): __tablename__ = "orgusers" - org_id = Column(Integer, ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True) - user_id = Column(Integer, ForeignKey("user.id", ondelete="CASCADE"), primary_key=True) + org_id = Column( + Integer, ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True + ) + user_id = Column( + Integer, ForeignKey("user.id", ondelete="CASCADE"), primary_key=True + ) diff --git a/src/organisation/router.py b/src/organisation/router.py index 90a3627..6d2bb69 100644 --- a/src/organisation/router.py +++ b/src/organisation/router.py @@ -15,6 +15,7 @@ Endpoints: - [GET](/org/contact): [root user]: Gets the (contact_type) contact for an org(id) - [PATCH](/org/contact): [root user]: Updates the (contact_type) contact for an org(id). Any number of details can be changed. """ + from typing import Annotated from fastapi import APIRouter, status @@ -28,17 +29,40 @@ from src.contact.models import Contact from src.contact.schemas import ContactAddress from src.contact.exceptions import ContactNotFoundException from src.database import db_dependency -from src.user.dependencies import user_model_body_dependency, user_model_claims_dependency -from src.auth.dependencies import super_admin_dependency, org_model_root_claim_query_dependency, org_model_root_claim_body_dependency +from src.user.dependencies import ( + user_model_body_dependency, + user_model_claims_dependency, +) +from src.auth.dependencies import ( + super_admin_dependency, + org_model_root_claim_query_dependency, + org_model_root_claim_body_dependency, +) from src.organisation.dependencies import org_model_body_dependency from src.organisation.constants import ContactType from src.organisation.models import Organisation as Org -from src.organisation.schemas import OrgPostOrgRequest, OrgPatchQuestionnaireRequest, OrgPatchStatusRequest, \ - OrgPatchContactRequest, \ - OrgPostUserRequest, OrgGetUserResponse, OrgGetContactResponse, OrgGetOrgResponse, OrgPatchRootRequest, \ - OrgGetGroupResponse, OrgDeleteUserRequest, OrgDeleteOrgRequest, OrgPostOrgResponse, OrgPatchQuestionnaireResponse, \ - OrgPatchStatusResponse, OrgPostUserResponse, OrgPatchRootResponse, Questionnaire, OrgPatchContactResponse +from src.organisation.schemas import ( + OrgPostOrgRequest, + OrgPatchQuestionnaireRequest, + OrgPatchStatusRequest, + OrgPatchContactRequest, + OrgPostUserRequest, + OrgGetUserResponse, + OrgGetContactResponse, + OrgGetOrgResponse, + OrgPatchRootRequest, + OrgGetGroupResponse, + OrgDeleteUserRequest, + OrgDeleteOrgRequest, + OrgPostOrgResponse, + OrgPatchQuestionnaireResponse, + OrgPatchStatusResponse, + OrgPostUserResponse, + OrgPatchRootResponse, + Questionnaire, + OrgPatchContactResponse, +) router = APIRouter( prefix="/org", @@ -46,16 +70,22 @@ router = APIRouter( ) -@router.get("", - summary="Get org details by ID.", - response_model=OrgGetOrgResponse, - status_code=status.HTTP_200_OK, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, - status.HTTP_404_NOT_FOUND: {"description": "Organisation not found"}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Missing or invalid org_id query parameter"}, - status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."}, - }) +@router.get( + "", + summary="Get org details by ID.", + response_model=OrgGetOrgResponse, + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, + status.HTTP_404_NOT_FOUND: {"description": "Organisation not found"}, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Missing or invalid org_id query parameter" + }, + status.HTTP_401_UNAUTHORIZED: { + "description": "Not authorised. Must be org root user." + }, + }, +) async def get_org_by_id(org_model: org_model_root_claim_query_dependency): """ Returns organisation details including key member email addresses @@ -68,23 +98,35 @@ async def get_org_by_id(org_model: org_model_root_claim_query_dependency): "billing_contact": org_model.billing_contact_rel.email, "security_contact": org_model.security_contact_rel.email, "root_user": org_model.root_user_email, - "intake_questionnaire": org_model.intake_questionnaire + "intake_questionnaire": org_model.intake_questionnaire, } return response -@router.post("", - summary="Create new organisation.", - status_code=status.HTTP_201_CREATED, - response_model=OrgPostOrgResponse, - responses={ - status.HTTP_201_CREATED: {"description": "Successfully created organisation."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - status.HTTP_401_UNAUTHORIZED: {"description": "User must be logged in with OIDC to create organisation."}, - status.HTTP_409_CONFLICT: {"description": "Organisation with this name already exists."}, - }) -async def create_org(db: db_dependency, user_model: user_model_claims_dependency, request_model: OrgPostOrgRequest): +@router.post( + "", + summary="Create new organisation.", + status_code=status.HTTP_201_CREATED, + response_model=OrgPostOrgResponse, + responses={ + status.HTTP_201_CREATED: {"description": "Successfully created organisation."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Invalid data in request." + }, + status.HTTP_401_UNAUTHORIZED: { + "description": "User must be logged in with OIDC to create organisation." + }, + status.HTTP_409_CONFLICT: { + "description": "Organisation with this name already exists." + }, + }, +) +async def create_org( + db: db_dependency, + user_model: user_model_claims_dependency, + request_model: OrgPostOrgRequest, +): """ Creates a new organisation with optional questionnaire (to be completed or submitted). ALl organisations are given the "partial" status on creation. See update_questionnaire() for more details. @@ -102,11 +144,17 @@ async def create_org(db: db_dependency, user_model: user_model_claims_dependency db.flush() except IntegrityError as e: if isinstance(e.orig, UniqueViolation): - raise ConflictException(message="Organisation with this name already exists") + raise ConflictException( + message="Organisation with this name already exists" + ) # Adds currently logged-in user to org users list and sets them as root_user org_model.user_rel.append(user_model) org_model.root_user_rel = user_model - for contact_type in ["billing_contact_id", "security_contact_id", "owner_contact_id"]: + for contact_type in [ + "billing_contact_id", + "security_contact_id", + "owner_contact_id", + ]: contact_model = Contact(org_id=org_model.id) db.add(contact_model) db.flush() @@ -116,16 +164,26 @@ async def create_org(db: db_dependency, user_model: user_model_claims_dependency return response -@router.patch("/questionnaire", - summary="Update questionnaire.", - status_code=status.HTTP_200_OK, - response_model=OrgPatchQuestionnaireResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully updated questionnaire."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."}, - }) -async def update_questionnaire(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: OrgPatchQuestionnaireRequest): +@router.patch( + "/questionnaire", + summary="Update questionnaire.", + status_code=status.HTTP_200_OK, + response_model=OrgPatchQuestionnaireResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully updated questionnaire."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Invalid data in request." + }, + status.HTTP_401_UNAUTHORIZED: { + "description": "Not authorised. Must be org root user." + }, + }, +) +async def update_questionnaire( + db: db_dependency, + org_model: org_model_root_claim_body_dependency, + request_model: OrgPatchQuestionnaireRequest, +): """ Route for updating questionnaire. The partial bool allows for submission of partially completed questionnaire and/or @@ -150,16 +208,29 @@ async def update_questionnaire(db: db_dependency, org_model: org_model_root_clai return response -@router.patch("/status", - summary="Update status of organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgPatchStatusResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully updated organisation status."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be super admin."}, - }) -async def update_status(db: db_dependency, org_model: org_model_body_dependency, su: super_admin_dependency, request_model: OrgPatchStatusRequest): +@router.patch( + "/status", + summary="Update status of organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgPatchStatusResponse, + responses={ + status.HTTP_200_OK: { + "description": "Successfully updated organisation status." + }, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Invalid data in request." + }, + status.HTTP_401_UNAUTHORIZED: { + "description": "Not authorised. Must be super admin." + }, + }, +) +async def update_status( + db: db_dependency, + org_model: org_model_body_dependency, + su: super_admin_dependency, + request_model: OrgPatchStatusRequest, +): """ Sets an organisation's status. This is the endpoint for approving or denying an organisation after reviewing the questionnaire. """ @@ -170,33 +241,57 @@ async def update_status(db: db_dependency, org_model: org_model_body_dependency, return response -@router.get("/users", - summary="Get email addresses of users of the organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgGetUserResponse, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval of users."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Org ID missing or invalid."}, - status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."}, - }) +@router.get( + "/users", + summary="Get email addresses of users of the organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgGetUserResponse, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval of users."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Org ID missing or invalid." + }, + status.HTTP_401_UNAUTHORIZED: { + "description": "Not authorised. Must be org root user." + }, + }, +) async def get_users(org_model: org_model_root_claim_query_dependency): """ Returns a list of the email addresses of all users of the organisation. """ - return {"users": [user.email for user in org_model.user_rel], "organisation": org_model} + return { + "users": [user.email for user in org_model.user_rel], + "organisation": org_model, + } -@router.post("/user", - summary="Add user to the organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgPostUserResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully added user to the organisation."}, - status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - status.HTTP_409_CONFLICT: {"description": "User is already a member of the organisation."}, - }) -async def add_user_to_org(db: db_dependency, org_model: org_model_root_claim_body_dependency, user_model: user_model_body_dependency, request_model: OrgPostUserRequest): +@router.post( + "/user", + summary="Add user to the organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgPostUserResponse, + responses={ + status.HTTP_200_OK: { + "description": "Successfully added user to the organisation." + }, + status.HTTP_401_UNAUTHORIZED: { + "description": "Not authorised. Must be org root user." + }, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Invalid data in request." + }, + status.HTTP_409_CONFLICT: { + "description": "User is already a member of the organisation." + }, + }, +) +async def add_user_to_org( + db: db_dependency, + org_model: org_model_root_claim_body_dependency, + user_model: user_model_body_dependency, + request_model: OrgPostUserRequest, +): """ Adds a user to the organisation. """ @@ -209,15 +304,28 @@ async def add_user_to_org(db: db_dependency, org_model: org_model_root_claim_bod return response -@router.delete("", - summary="Delete organisation from the hub.", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - status.HTTP_204_NO_CONTENT: {"description": "Successfully deleted organisation."}, - status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be super admin."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Org ID missing or invalid."}, - }) -async def delete_organisation_by_id(db: db_dependency, org_model: org_model_body_dependency, su: super_admin_dependency, request_model: OrgDeleteOrgRequest): +@router.delete( + "", + summary="Delete organisation from the hub.", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + status.HTTP_204_NO_CONTENT: { + "description": "Successfully deleted organisation." + }, + status.HTTP_401_UNAUTHORIZED: { + "description": "Not authorised. Must be super admin." + }, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Org ID missing or invalid." + }, + }, +) +async def delete_organisation_by_id( + db: db_dependency, + org_model: org_model_body_dependency, + su: super_admin_dependency, + request_model: OrgDeleteOrgRequest, +): """ Removes an organisation from the hub. """ @@ -225,37 +333,59 @@ async def delete_organisation_by_id(db: db_dependency, org_model: org_model_body db.commit() -@router.patch("/root_user", - summary="Update the root user of the organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgPatchRootResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully updated root user."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be super admin."}, - }) -async def update_root_user(db: db_dependency, org_model: org_model_body_dependency, user_model: user_model_body_dependency, su: super_admin_dependency, request_model: OrgPatchRootRequest): +@router.patch( + "/root_user", + summary="Update the root user of the organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgPatchRootResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully updated root user."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Invalid data in request." + }, + status.HTTP_401_UNAUTHORIZED: { + "description": "Not authorised. Must be super admin." + }, + }, +) +async def update_root_user( + db: db_dependency, + org_model: org_model_body_dependency, + user_model: user_model_body_dependency, + su: super_admin_dependency, + request_model: OrgPatchRootRequest, +): """ Promotes an existing organisation user to the root user, giving them full control of the org. """ if user_model not in org_model.user_rel: - raise UnprocessableContentException(message="This user does not belong to your organisation.") + raise UnprocessableContentException( + message="This user does not belong to your organisation." + ) org_model.root_user_rel = user_model db.flush() - response = OrgPatchRootResponse(name=org_model.name, root_user_email=org_model.root_user_email) + response = OrgPatchRootResponse( + name=org_model.name, root_user_email=org_model.root_user_email + ) db.commit() return response -@router.get("/groups", - summary="Get all organisation IAM groups.", - status_code=status.HTTP_200_OK, - response_model=OrgGetGroupResponse, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval of IAM groups."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Org ID missing or invalid."}, - status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."}, - }) +@router.get( + "/groups", + summary="Get all organisation IAM groups.", + status_code=status.HTTP_200_OK, + response_model=OrgGetGroupResponse, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval of IAM groups."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Org ID missing or invalid." + }, + status.HTTP_401_UNAUTHORIZED: { + "description": "Not authorised. Must be org root user." + }, + }, +) async def get_org_groups(org_model: org_model_root_claim_query_dependency): """ Returns a list of the names of all IAM groups created by the organisation. @@ -263,15 +393,26 @@ async def get_org_groups(org_model: org_model_root_claim_query_dependency): return {"groups": [group.name for group in org_model.group_rel]} -@router.delete("/user", - summary="Remove user from organisation.", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - status.HTTP_204_NO_CONTENT: {"description": "Successfully removed user."}, - status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - }) -async def remove_user_from_org(db: db_dependency, org_model: org_model_root_claim_body_dependency, user_model: user_model_body_dependency, request_model: OrgDeleteUserRequest): +@router.delete( + "/user", + summary="Remove user from organisation.", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + status.HTTP_204_NO_CONTENT: {"description": "Successfully removed user."}, + status.HTTP_401_UNAUTHORIZED: { + "description": "Not authorised. Must be org root user." + }, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Invalid data in request." + }, + }, +) +async def remove_user_from_org( + db: db_dependency, + org_model: org_model_root_claim_body_dependency, + user_model: user_model_body_dependency, + request_model: OrgDeleteUserRequest, +): """ Revokes a user's membership in an organisation. """ @@ -282,16 +423,27 @@ async def remove_user_from_org(db: db_dependency, org_model: org_model_root_clai db.commit() -@router.get("/contact", - summary="Get contact for organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgGetContactResponse, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval of contact."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."}, - }) -async def get_contact(org_model: org_model_root_claim_query_dependency, contact_type: Annotated[ContactType, Query(description="Must be billing|security|owner")]): +@router.get( + "/contact", + summary="Get contact for organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgGetContactResponse, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval of contact."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Invalid data in request." + }, + status.HTTP_401_UNAUTHORIZED: { + "description": "Not authorised. Must be org root user." + }, + }, +) +async def get_contact( + org_model: org_model_root_claim_query_dependency, + contact_type: Annotated[ + ContactType, Query(description="Must be billing|security|owner") + ], +): """ Gets full details for a contact point at an organisation. """ @@ -309,21 +461,33 @@ async def get_contact(org_model: org_model_root_claim_query_dependency, contact_ raise ContactNotFoundException() address = ContactAddress.model_validate(contact_model) - contact_response = ContactModel.model_construct(**contact_model.__dict__, address=address) + contact_response = ContactModel.model_construct( + **contact_model.__dict__, address=address + ) return {"contact": contact_response, "organisation": org_model} -@router.patch("/contact", - summary="Update contact for organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgPatchContactResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully updated contact."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, - status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."}, - }) -async def update_contact(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: OrgPatchContactRequest): +@router.patch( + "/contact", + summary="Update contact for organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgPatchContactResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully updated contact."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: { + "description": "Invalid data in request." + }, + status.HTTP_401_UNAUTHORIZED: { + "description": "Not authorised. Must be org root user." + }, + }, +) +async def update_contact( + db: db_dependency, + org_model: org_model_root_claim_body_dependency, + request_model: OrgPatchContactRequest, +): """ Updates details for a contact point at an organisation. """ @@ -351,7 +515,9 @@ async def update_contact(db: db_dependency, org_model: org_model_root_claim_body db.flush() address = ContactAddress.model_validate(contact_model) - contact_response = ContactModel.model_construct(**contact_model.__dict__, address=address) + contact_response = ContactModel.model_construct( + **contact_model.__dict__, address=address + ) db.commit() diff --git a/src/organisation/schemas.py b/src/organisation/schemas.py index 305c6f7..31e59b9 100644 --- a/src/organisation/schemas.py +++ b/src/organisation/schemas.py @@ -6,6 +6,7 @@ Models follow the nomenclature of: - Mixins: "Mixin" - Models: "" ie "OrgPostOrgRequest" """ + from typing import Optional from pydantic import EmailStr, ConfigDict, Field @@ -20,11 +21,13 @@ from src.organisation.constants import Status, ContactType class OrgIDMixin(CustomBaseModel): organisation_id: int = Field(gt=0) + class Questionnaire(CustomBaseModel): question_one: Optional[str] = None question_two: Optional[str] = None question_three: Optional[str] = None + class OrgSchema(CustomBaseModel): id: int name: str @@ -34,26 +37,32 @@ class OrgPostOrgRequest(CustomBaseModel): name: str intake_questionnaire: Optional[Questionnaire] = None + class OrgPostOrgResponse(CustomBaseModel): name: str status: Status + class OrgPatchQuestionnaireRequest(OrgIDMixin): intake_questionnaire: Questionnaire partial: bool + class OrgPatchQuestionnaireResponse(CustomBaseModel): name: str intake_questionnaire: Questionnaire status: Status + class OrgPatchStatusRequest(OrgIDMixin): status: Status + class OrgPatchStatusResponse(CustomBaseModel): name: str status: Status + class OrgPatchContactRequest(OrgIDMixin): contact_type: ContactType @@ -70,41 +79,51 @@ class OrgPatchContactRequest(OrgIDMixin): country_code: Optional[str] = None postal_code: Optional[str] = None + class OrgPostUserRequest(OrgIDMixin, UserIDMixin): pass + class OrgPostUserResponse(CustomBaseModel): users: list[str] + class OrgDeleteUserRequest(OrgIDMixin, UserIDMixin): pass + class OrgPatchRootRequest(OrgIDMixin, UserIDMixin): pass + class OrgPatchRootResponse(CustomBaseModel): name: str root_user_email: str + class OrgGetUserResponse(CustomBaseModel): users: list[str] organisation: OrgSchema + class OrgGetGroupResponse(CustomBaseModel): groups: list[str] + class OrgGetContactResponse(CustomBaseModel): model_config = ConfigDict(from_attributes=True, extra="ignore") contact: ContactModel organisation: OrgSchema + class OrgPatchContactResponse(CustomBaseModel): model_config = ConfigDict(from_attributes=True, extra="ignore") contact: ContactModel organisation: OrgSchema + class OrgGetOrgResponse(CustomBaseModel): id: int name: str @@ -115,5 +134,6 @@ class OrgGetOrgResponse(CustomBaseModel): security_contact: Optional[str] = None intake_questionnaire: Optional[Questionnaire] = None + class OrgDeleteOrgRequest(OrgIDMixin): pass diff --git a/src/organisation/service.py b/src/organisation/service.py index 6d73399..cfe3925 100644 --- a/src/organisation/service.py +++ b/src/organisation/service.py @@ -1,3 +1,3 @@ """ Reusable business logic functions for the organisation module -""" \ No newline at end of file +""" diff --git a/src/organisation/utils.py b/src/organisation/utils.py index ead22ca..0337df5 100644 --- a/src/organisation/utils.py +++ b/src/organisation/utils.py @@ -1,3 +1,3 @@ """ Non-business logic reusable functions and classes for the organisation module -""" \ No newline at end of file +""" diff --git a/src/schemas.py b/src/schemas.py index 812b574..484031b 100644 --- a/src/schemas.py +++ b/src/schemas.py @@ -5,6 +5,7 @@ Exports: - CustomBaseModel: Schema used for all other Pydantic models - ResourceName """ + from pydantic import BaseModel from typing import Optional diff --git a/src/service/config.py b/src/service/config.py index 5d4fd3b..b23303f 100644 --- a/src/service/config.py +++ b/src/service/config.py @@ -1,3 +1,3 @@ """ Configurations for the services module -""" \ No newline at end of file +""" diff --git a/src/service/constants.py b/src/service/constants.py index 52a8701..0007ba0 100644 --- a/src/service/constants.py +++ b/src/service/constants.py @@ -1,3 +1,3 @@ """ Constants for the services module -""" \ No newline at end of file +""" diff --git a/src/service/dependencies.py b/src/service/dependencies.py index cda625a..9792f26 100644 --- a/src/service/dependencies.py +++ b/src/service/dependencies.py @@ -5,6 +5,7 @@ Exports: - service_model_query_dependency: service_model: Gets service model from db, if it exists. Uses service_id from query param. - service_model_body_dependency: service_model: Gets service model from db, if it exists. Uses service_id from request body. """ + from typing import Annotated from fastapi import Depends, Query @@ -15,14 +16,19 @@ from src.service.models import Service from src.service.schemas import ServiceIDMixin -async def get_service_model_query(db: db_dependency, service_id: Annotated[int, Query(gt=0)]): +async def get_service_model_query( + db: db_dependency, service_id: Annotated[int, Query(gt=0)] +): service_model = db.get(Service, service_id) if service_model is None: raise ServiceNotFoundException(service_id=service_id) return service_model -service_model_query_dependency = Annotated[type[Service], Depends(get_service_model_query)] + +service_model_query_dependency = Annotated[ + type[Service], Depends(get_service_model_query) +] async def get_service_model_body(db: db_dependency, request_model: ServiceIDMixin): @@ -32,4 +38,7 @@ async def get_service_model_body(db: db_dependency, request_model: ServiceIDMixi return service_model -service_model_body_dependency = Annotated[type[Service], Depends(get_service_model_body)] + +service_model_body_dependency = Annotated[ + type[Service], Depends(get_service_model_body) +] diff --git a/src/service/exceptions.py b/src/service/exceptions.py index 8a1a2e3..36a927d 100644 --- a/src/service/exceptions.py +++ b/src/service/exceptions.py @@ -4,6 +4,7 @@ Exceptions related to the services module Exceptions: - ServiceNotFoundException: Takes an optional service_id int """ + from typing import Optional from fastapi import HTTPException, status @@ -11,7 +12,11 @@ from fastapi import HTTPException, status class ServiceNotFoundException(HTTPException): def __init__(self, service_id: Optional[int] = None) -> None: - detail = "Service not found" if service_id is None else f"Service with ID '{service_id}' was not found." + detail = ( + "Service not found" + if service_id is None + else f"Service with ID '{service_id}' was not found." + ) super().__init__( status_code=status.HTTP_404_NOT_FOUND, detail=detail, diff --git a/src/service/models.py b/src/service/models.py index 68fd020..82bdba1 100644 --- a/src/service/models.py +++ b/src/service/models.py @@ -5,6 +5,7 @@ Models: - Service: - id[PK], name[U], api_key[U] """ + from sqlalchemy import Column, Integer, String from src.database import Base diff --git a/src/service/router.py b/src/service/router.py index a8f93ea..69a9369 100644 --- a/src/service/router.py +++ b/src/service/router.py @@ -7,19 +7,30 @@ Endpoints: - [PATCH](/key): [super_admin]: Refreshes the API key for a service(id), returning a new one. - [DELETE](/): [super_admin]: Removes a service(id) from the hub. """ + from fastapi import APIRouter, status from psycopg.errors import UniqueViolation from sqlalchemy.exc import IntegrityError from src.exceptions import ConflictException from src.database import db_dependency -from src.auth.dependencies import super_admin_dependency, org_model_root_claim_query_dependency +from src.auth.dependencies import ( + super_admin_dependency, + org_model_root_claim_query_dependency, +) from src.service.models import Service from src.service.utils import generate_api_key from src.service.dependencies import service_model_body_dependency -from src.service.schemas import ServiceGetServiceResponse, ServicePostServiceRequest, ServicePostServiceResponse, \ - ServiceWithKeySchema, ServicePatchKeyResponse, ServicePatchKeyRequest, ServiceDeleteServiceRequest +from src.service.schemas import ( + ServiceGetServiceResponse, + ServicePostServiceRequest, + ServicePostServiceResponse, + ServiceWithKeySchema, + ServicePatchKeyResponse, + ServicePatchKeyRequest, + ServiceDeleteServiceRequest, +) router = APIRouter( tags=["Service"], @@ -27,15 +38,19 @@ router = APIRouter( ) -@router.get("/", - summary="Get all services", - status_code=status.HTTP_200_OK, - response_model=ServiceGetServiceResponse, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, - status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, - }) -async def get_all_services(db: db_dependency, org_model: org_model_root_claim_query_dependency): +@router.get( + "/", + summary="Get all services", + status_code=status.HTTP_200_OK, + response_model=ServiceGetServiceResponse, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + }, +) +async def get_all_services( + db: db_dependency, org_model: org_model_root_claim_query_dependency +): """ Returns the ID and name of all services registered to the hub. """ @@ -44,16 +59,24 @@ async def get_all_services(db: db_dependency, org_model: org_model_root_claim_qu return {"services": permission_models} -@router.post("/", - summary="Register a new service.", - status_code=status.HTTP_200_OK, - response_model=ServicePostServiceResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully registered a new service"}, - status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, - status.HTTP_409_CONFLICT: {"description": "Service with this name already exists"}, - }) -async def register_service(db: db_dependency, su: super_admin_dependency, request_model: ServicePostServiceRequest): +@router.post( + "/", + summary="Register a new service.", + status_code=status.HTTP_200_OK, + response_model=ServicePostServiceResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully registered a new service"}, + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + status.HTTP_409_CONFLICT: { + "description": "Service with this name already exists" + }, + }, +) +async def register_service( + db: db_dependency, + su: super_admin_dependency, + request_model: ServicePostServiceRequest, +): """ Registers a new service to the hub, generating and returning an API key for it. """ @@ -71,16 +94,22 @@ async def register_service(db: db_dependency, su: super_admin_dependency, reques return {"service": response} -@router.patch("/key", - summary="Regenerate service API key.", - status_code=status.HTTP_200_OK, - response_model=ServicePatchKeyResponse, - responses={ - status.HTTP_200_OK: {"description": "Successful update of API key"}, - status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, - }) -async def regenerate_api_key(db: db_dependency, su: super_admin_dependency, - service_model: service_model_body_dependency, request_model: ServicePatchKeyRequest): +@router.patch( + "/key", + summary="Regenerate service API key.", + status_code=status.HTTP_200_OK, + response_model=ServicePatchKeyResponse, + responses={ + status.HTTP_200_OK: {"description": "Successful update of API key"}, + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + }, +) +async def regenerate_api_key( + db: db_dependency, + su: super_admin_dependency, + service_model: service_model_body_dependency, + request_model: ServicePatchKeyRequest, +): """ Generates and returns a new API key for the service to access the hub. """ @@ -93,15 +122,23 @@ async def regenerate_api_key(db: db_dependency, su: super_admin_dependency, return {"service": response} -@router.delete("/", - summary="Remove a service.", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - status.HTTP_204_NO_CONTENT: {"description": "Successfully removed service from db"}, - status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, - }) -async def remove_service(db: db_dependency, service_model: service_model_body_dependency, su: super_admin_dependency, - request_model: ServiceDeleteServiceRequest): +@router.delete( + "/", + summary="Remove a service.", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + status.HTTP_204_NO_CONTENT: { + "description": "Successfully removed service from db" + }, + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + }, +) +async def remove_service( + db: db_dependency, + service_model: service_model_body_dependency, + su: super_admin_dependency, + request_model: ServiceDeleteServiceRequest, +): """ Removes a service from the hub. """ diff --git a/src/service/schemas.py b/src/service/schemas.py index 50e7c35..9c04840 100644 --- a/src/service/schemas.py +++ b/src/service/schemas.py @@ -6,36 +6,46 @@ Models follow the nomenclature of: - Mixins: "Mixin" - Models: "" ie "ServiceGetServiceResponse" """ + from pydantic import ConfigDict, Field from src.schemas import CustomBaseModel + class ServiceIDMixin(CustomBaseModel): service_id: int = Field(gt=0) + class ServiceSchema(CustomBaseModel): model_config = ConfigDict(from_attributes=True, extra="ignore") id: int name: str + class ServiceWithKeySchema(ServiceSchema): api_key: str + class ServiceGetServiceResponse(CustomBaseModel): services: list[ServiceSchema] + class ServicePostServiceRequest(CustomBaseModel): name: str + class ServicePostServiceResponse(CustomBaseModel): service: ServiceWithKeySchema + class ServicePatchKeyRequest(ServiceIDMixin): pass + class ServicePatchKeyResponse(CustomBaseModel): service: ServiceWithKeySchema + class ServiceDeleteServiceRequest(ServiceIDMixin): pass diff --git a/src/service/service.py b/src/service/service.py index 2609565..2f59b44 100644 --- a/src/service/service.py +++ b/src/service/service.py @@ -1,3 +1,3 @@ """ Business logic for the services module -""" \ No newline at end of file +""" diff --git a/src/service/utils.py b/src/service/utils.py index 8920a5f..79bb91f 100644 --- a/src/service/utils.py +++ b/src/service/utils.py @@ -4,6 +4,7 @@ Non-business logic reusable functions and classes for the services module Exports: - generate_api_key(): returns a new UUID """ + import uuid diff --git a/src/user/config.py b/src/user/config.py index 9bbcbc4..a25018d 100644 --- a/src/user/config.py +++ b/src/user/config.py @@ -1,3 +1,3 @@ """ Configurations for the user module -""" \ No newline at end of file +""" diff --git a/src/user/constants.py b/src/user/constants.py index fc6a780..2adb24d 100644 --- a/src/user/constants.py +++ b/src/user/constants.py @@ -1,3 +1,3 @@ """ Constants for the user module -""" \ No newline at end of file +""" diff --git a/src/user/dependencies.py b/src/user/dependencies.py index dc22429..9d2fe02 100644 --- a/src/user/dependencies.py +++ b/src/user/dependencies.py @@ -6,6 +6,7 @@ Exports: - user_model_query_dependency: user_model: Gets user model from db, if it exists. Uses user_id from query param - user_model_body_dependency: user_model: Gets user model from db, if it exists. Uses user_id from request body. """ + from typing import Annotated from fastapi import Depends, Query @@ -28,6 +29,7 @@ async def get_user_model_claims(claims: claims_dependency, db: db_dependency): return user_model + user_model_claims_dependency = Annotated[type[User], Depends(get_user_model_claims)] @@ -38,6 +40,7 @@ async def get_user_model_query(db: db_dependency, user_id: Annotated[int, Query( return user_model + user_model_query_dependency = Annotated[type[User], Depends(get_user_model_query)] @@ -48,4 +51,5 @@ async def get_user_model_body(db: db_dependency, request_model: UserIDMixin): return user_model + user_model_body_dependency = Annotated[type[User], Depends(get_user_model_body)] diff --git a/src/user/exceptions.py b/src/user/exceptions.py index 2e03b05..fa4db2f 100644 --- a/src/user/exceptions.py +++ b/src/user/exceptions.py @@ -4,6 +4,7 @@ Exceptions related to the user module Exceptions: - UserNotFoundException: Takes an optional user_id int """ + from typing import Optional from fastapi import HTTPException, status @@ -11,7 +12,11 @@ from fastapi import HTTPException, status class UserNotFoundException(HTTPException): def __init__(self, user_id: Optional[int] = None) -> None: - detail = "User not found" if user_id is None else f"User with ID '{user_id}' was not found." + detail = ( + "User not found" + if user_id is None + else f"User with ID '{user_id}' was not found." + ) super().__init__( status_code=status.HTTP_404_NOT_FOUND, detail=detail, diff --git a/src/user/models.py b/src/user/models.py index 97c0f22..62c7ef6 100644 --- a/src/user/models.py +++ b/src/user/models.py @@ -9,6 +9,7 @@ Models: - organisations: Calc property list of organisation_rel.name - groups: Calc property dict of {group_rel.org_rel.name: group_rel.name} """ + from collections import defaultdict from sqlalchemy import Column, Integer, String @@ -18,29 +19,29 @@ from src.database import Base class User(Base): - __tablename__ = "user" + __tablename__ = "user" - id = Column(Integer, primary_key=True) - email = Column(String) - first_name = Column(String) - last_name = Column(String) - oidc_id = Column(String, index=True, unique=True) + id = Column(Integer, primary_key=True) + email = Column(String) + first_name = Column(String) + last_name = Column(String) + oidc_id = Column(String, index=True, unique=True) - organisation_rel = relationship( - "Organisation", secondary="orgusers", back_populates="user_rel" - ) + organisation_rel = relationship( + "Organisation", secondary="orgusers", back_populates="user_rel" + ) - @property - def organisations(self): - return [{"name": org.name, "id": org.id} for org in self.organisation_rel] + @property + def organisations(self): + return [{"name": org.name, "id": org.id} for org in self.organisation_rel] - group_rel = relationship( - "Group", secondary="user_groups", back_populates="user_rel" - ) + group_rel = relationship( + "Group", secondary="user_groups", back_populates="user_rel" + ) - @property - def groups(self): - result = defaultdict(list) - for group in self.group_rel: - result[group.org_rel.name].append({"name": group.name, "id": group.id}) - return dict(result) + @property + def groups(self): + result = defaultdict(list) + for group in self.group_rel: + result[group.org_rel.name].append({"name": group.name, "id": group.id}) + return dict(result) diff --git a/src/user/router.py b/src/user/router.py index 0966850..f380fb4 100644 --- a/src/user/router.py +++ b/src/user/router.py @@ -7,11 +7,16 @@ Endpoints: - [GET](/user/): [super admin]: Returns user(id) details. - [DELETE](/user/): [super admin]: Removes a User(id) from the hub database. """ + from fastapi import APIRouter from starlette import status from src.user.schemas import UserResponse, OIDCClaims, UserDeleteUserRequest -from src.user.dependencies import user_model_claims_dependency, user_model_query_dependency, user_model_body_dependency +from src.user.dependencies import ( + user_model_claims_dependency, + user_model_query_dependency, + user_model_body_dependency, +) from src.auth.dependencies import super_admin_dependency from src.auth.service import claims_dependency @@ -23,13 +28,15 @@ router = APIRouter( ) -@router.get("/self/claims", - summary="Get current user OIDC claims.", - response_model=OIDCClaims, - status_code=status.HTTP_200_OK, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, - }) +@router.get( + "/self/claims", + summary="Get current user OIDC claims.", + response_model=OIDCClaims, + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, + }, +) async def current_user_claims(user: claims_dependency): """ Returns the full OIDC claims associated with the currently logged-in user. @@ -38,14 +45,16 @@ async def current_user_claims(user: claims_dependency): return user -@router.get("/self/db", - summary="Get current user hub details.", - response_model=UserResponse, - status_code=status.HTTP_200_OK, - responses={ - status.HTTP_404_NOT_FOUND: {"description": "User not found"}, - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, - }) +@router.get( + "/self/db", + summary="Get current user hub details.", + response_model=UserResponse, + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_404_NOT_FOUND: {"description": "User not found"}, + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, + }, +) async def current_user(user_model: user_model_claims_dependency): """ Returns the database details associated with the currently logged-in user. @@ -53,30 +62,40 @@ async def current_user(user_model: user_model_claims_dependency): return user_model -@router.get("/", - summary="Get user hub details by ID.", - response_model=UserResponse, - status_code=status.HTTP_200_OK, - responses={ - status.HTTP_404_NOT_FOUND: {"description": "User not found"}, - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, - }) -async def get_user_by_id(user_model: user_model_query_dependency, su: super_admin_dependency): +@router.get( + "/", + summary="Get user hub details by ID.", + response_model=UserResponse, + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_404_NOT_FOUND: {"description": "User not found"}, + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, + }, +) +async def get_user_by_id( + user_model: user_model_query_dependency, su: super_admin_dependency +): """ Returns the database details associated with the provided user ID. """ return user_model -@router.delete("/", - summary="Delete user from hub by ID.", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - status.HTTP_204_NO_CONTENT: {"description": "User deleted"}, - status.HTTP_404_NOT_FOUND: {"description": "User not found"}, - }) -async def delete_user_by_id(db: db_dependency, user_model: user_model_body_dependency, su: super_admin_dependency, - request_model: UserDeleteUserRequest): +@router.delete( + "/", + summary="Delete user from hub by ID.", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + status.HTTP_204_NO_CONTENT: {"description": "User deleted"}, + status.HTTP_404_NOT_FOUND: {"description": "User not found"}, + }, +) +async def delete_user_by_id( + db: db_dependency, + user_model: user_model_body_dependency, + su: super_admin_dependency, + request_model: UserDeleteUserRequest, +): """ Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login. """ diff --git a/src/user/schemas.py b/src/user/schemas.py index 211004c..5fd44df 100644 --- a/src/user/schemas.py +++ b/src/user/schemas.py @@ -1,6 +1,7 @@ """ Pydantic models for the user module """ + from typing import Optional from pydantic import Field @@ -47,8 +48,8 @@ class UserResponse(CustomBaseModel): first_name: str last_name: str email: str - organisations: list[Optional[dict[str, str|int]]] - groups: Optional[dict[str, list[dict[str, str|int]]]] = None + organisations: list[Optional[dict[str, str | int]]] + groups: Optional[dict[str, list[dict[str, str | int]]]] = None class OrgResponse(CustomBaseModel): @@ -57,4 +58,4 @@ class OrgResponse(CustomBaseModel): class UserDeleteUserRequest(UserIDMixin): - pass \ No newline at end of file + pass diff --git a/src/user/service.py b/src/user/service.py index e072251..49ab238 100644 --- a/src/user/service.py +++ b/src/user/service.py @@ -4,6 +4,7 @@ Module specific business logic for user module Exports: - add_user_to_db: Creates a User record from OIDC claims, or updates user details """ + from typing import Any from sqlalchemy.orm import Session @@ -16,7 +17,12 @@ from src.user.models import User async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int: try: - valid_user = OIDCUser(first_name=user_claims["given_name"], last_name=user_claims["family_name"], email=user_claims["email"], oidc_id=user_claims["sub"]) + valid_user = OIDCUser( + first_name=user_claims["given_name"], + last_name=user_claims["family_name"], + email=user_claims["email"], + oidc_id=user_claims["sub"], + ) except Exception as e: print(e) raise UnprocessableContentException("Invalid or missing OIDC data") diff --git a/src/user/utils.py b/src/user/utils.py index 35fcc1a..2f10b91 100644 --- a/src/user/utils.py +++ b/src/user/utils.py @@ -1,3 +1,3 @@ """ Non-business logic reusable functions and classes for the user module -""" \ No newline at end of file +""" diff --git a/test/conftest.py b/test/conftest.py index e16e799..e06b3d0 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -37,11 +37,14 @@ def db_session(): async def default_client(db_session) -> AsyncGenerator[AsyncClient, None]: def get_db_override(): return db_session + app.dependency_overrides[get_db] = get_db_override app.dependency_overrides[get_current_user] = get_dev_user app.dependency_overrides[get_super_admin_list] = testing_su_list transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://localhost:8000/api/v1") as ac: + async with AsyncClient( + transport=transport, base_url="http://localhost:8000/api/v1" + ) as ac: yield ac app.dependency_overrides.clear() @@ -51,37 +54,58 @@ async def default_client(db_session) -> AsyncGenerator[AsyncClient, None]: async def no_user_client(db_session) -> AsyncGenerator[AsyncClient, None]: def get_db_override(): return db_session + app.dependency_overrides[get_db] = get_db_override transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://localhost:8000/api/v1") as ac: + async with AsyncClient( + transport=transport, base_url="http://localhost:8000/api/v1" + ) as ac: yield ac app.dependency_overrides.clear() - @pytest.fixture async def no_su_client(db_session) -> AsyncGenerator[AsyncClient, None]: def get_db_override(): return db_session + app.dependency_overrides[get_db] = get_db_override app.dependency_overrides[get_current_user] = get_dev_user app.dependency_overrides[get_super_admin_list] = empty_su_list transport = ASGITransport(app=app) - async with AsyncClient(transport=transport, base_url="http://localhost:8000/api/v1") as ac: + async with AsyncClient( + transport=transport, base_url="http://localhost:8000/api/v1" + ) as ac: yield ac app.dependency_overrides.clear() def _seed(db): - db.add(User(email="admin@test.com", first_name="Admin", last_name="Test", oidc_id="abcd-efgh-ijkl-mnop")) + db.add( + User( + email="admin@test.com", + first_name="Admin", + last_name="Test", + oidc_id="abcd-efgh-ijkl-mnop", + ) + ) db.add(Contact(org_id=1, email="billing@test.org", phonenumber="07521539927")) db.add(Contact(org_id=1, email="owner@test.org", phonenumber="07521539927")) db.add(Contact(org_id=1, email="security@test.org", phonenumber="07521539927")) db.flush() - db.add(Org(name="Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=2, security_contact_id=3, - status="approved", intake_questionnaire={"question_two": "answer two"})) + db.add( + Org( + name="Test Org", + root_user_id=1, + billing_contact_id=1, + owner_contact_id=2, + security_contact_id=3, + status="approved", + intake_questionnaire={"question_two": "answer two"}, + ) + ) db.add(Service(name="Test Service", api_key="123456789")) db.add(Permission(service_id=1, resource="test_resource", action="read")) db.add(Group(name="Test Group", org_id=1)) @@ -131,6 +155,7 @@ def generate_query_and_status(params) -> list[tuple[str, int]]: return query_and_status + # # Produces a text file with method and path for every endpoint in the API # from fastapi.routing import APIRoute # diff --git a/test/test_auth_approval.py b/test/test_auth_approval.py index f395865..69c8a25 100644 --- a/test/test_auth_approval.py +++ b/test/test_auth_approval.py @@ -3,6 +3,7 @@ This test module checks relevant endpoints to ensure only approved orgs get acce Endpoints not checked here are endpoints that do not require an org check. Delete endpoints are currently skipped because the testing system cannot use bodies in deletes. """ + import pytest from httpx import AsyncClient @@ -27,18 +28,27 @@ async def test_get_org_auth_approval(default_client: AsyncClient): @pytest.mark.anyio async def test_patch_org_questionnaire_auth_approval(default_client: AsyncClient): - resp = await default_client.patch("/org/questionnaire", json={"organisation_id": 1, - "intake_questionnaire": {"question_one": "new answer one", - "question_two": None, - "question_three": None}, - "partial": True}) + resp = await default_client.patch( + "/org/questionnaire", + json={ + "organisation_id": 1, + "intake_questionnaire": { + "question_one": "new answer one", + "question_two": None, + "question_three": None, + }, + "partial": True, + }, + ) assert resp.status_code != 422 assert resp.status_code == 200 @pytest.mark.anyio async def test_patch_org_status_auth_approval(default_client: AsyncClient): - resp = await default_client.patch("/org/status", json={"organisation_id": 1, "status": "submitted"}) + resp = await default_client.patch( + "/org/status", json={"organisation_id": 1, "status": "submitted"} + ) assert resp.status_code != 422 assert resp.status_code == 200 @@ -52,22 +62,42 @@ async def test_get_org_users_auth_approval(default_client: AsyncClient): @pytest.mark.anyio async def test_post_org_user_auth_approval(default_client: AsyncClient, db_session): - db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) + db_session.add( + User( + email="user@test.org", + first_name="User", + last_name="Test", + oidc_id="abcd-efgh-ijkl-1234", + ) + ) db_session.flush() - resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 2}) + resp = await default_client.post( + "/org/user", json={"organisation_id": 1, "user_id": 2} + ) assert resp.status_code != 422 assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio -async def test_patch_org_root_user_auth_approval(default_client: AsyncClient, db_session): - db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) +async def test_patch_org_root_user_auth_approval( + default_client: AsyncClient, db_session +): + db_session.add( + User( + email="user@test.org", + first_name="User", + last_name="Test", + oidc_id="abcd-efgh-ijkl-1234", + ) + ) db_session.flush() db_session.add(OrgUsers(org_id=1, user_id=2)) db_session.flush() - resp = await default_client.patch("/org/root_user", json={"organisation_id": 1, "user_id": 2}) + resp = await default_client.patch( + "/org/root_user", json={"organisation_id": 1, "user_id": 2} + ) assert resp.status_code != 422 assert "has not been approved." in resp.json()["detail"] @@ -88,8 +118,14 @@ async def test_get_org_contact_auth_approval(default_client: AsyncClient): @pytest.mark.anyio async def test_patch_org_contact_auth_approval(default_client: AsyncClient): - resp = await default_client.patch("/org/contact", - json={"organisation_id": 1, "contact_type": "billing", "email": "user@example.com"}) + resp = await default_client.patch( + "/org/contact", + json={ + "organisation_id": 1, + "contact_type": "billing", + "email": "user@example.com", + }, + ) assert resp.status_code != 422 assert resp.status_code == 200 @@ -117,26 +153,44 @@ async def test_get_iam_group_users_auth_approval(default_client: AsyncClient): @pytest.mark.anyio async def test_post_iam_group_auth_approval(default_client: AsyncClient): - resp = await default_client.post("/iam/group", json={"name": "New Group", "organisation_id": 1}) + resp = await default_client.post( + "/iam/group", json={"name": "New Group", "organisation_id": 1} + ) assert resp.status_code != 422 assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio -async def test_put_iam_group_permission_auth_approval(default_client: AsyncClient, db_session): +async def test_put_iam_group_permission_auth_approval( + default_client: AsyncClient, db_session +): db_session.add(Group(name="Test Group Two", org_id=1)) db_session.flush() - resp = await default_client.put("/iam/group/permission", json={"permission_id": 1, "group_id": 2, "organisation_id": 1}) + resp = await default_client.put( + "/iam/group/permission", + json={"permission_id": 1, "group_id": 2, "organisation_id": 1}, + ) assert resp.status_code != 422 assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio -async def test_put_iam_group_user_auth_approval(default_client: AsyncClient, db_session): - db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) +async def test_put_iam_group_user_auth_approval( + default_client: AsyncClient, db_session +): + db_session.add( + User( + email="user@test.org", + first_name="User", + last_name="Test", + oidc_id="abcd-efgh-ijkl-1234", + ) + ) db_session.flush() - resp = await default_client.put("/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1}) + resp = await default_client.put( + "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1} + ) assert resp.status_code != 422 assert "has not been approved." in resp.json()["detail"] @@ -150,6 +204,8 @@ async def test_get_iam_permissions_auth_approval(default_client: AsyncClient): @pytest.mark.anyio async def test_post_iam_permissions_search_auth_approval(default_client: AsyncClient): - resp = await default_client.post("/iam/permissions/search", json={"organisation_id": 1, "action": "read"}) + resp = await default_client.post( + "/iam/permissions/search", json={"organisation_id": 1, "action": "read"} + ) assert resp.status_code != 422 assert "has not been approved." in resp.json()["detail"] diff --git a/test/test_auth_general.py b/test/test_auth_general.py index 6aeba96..0547824 100644 --- a/test/test_auth_general.py +++ b/test/test_auth_general.py @@ -1,5 +1,5 @@ -""" -""" +""" """ + import pytest from httpx import AsyncClient @@ -10,11 +10,26 @@ from src.user.models import User @pytest.mark.anyio async def test_get_org_auth_root_su(default_client: AsyncClient, db_session): # If a super admin can access a resource when not the root user - db_session.add(User(email="admin@test.org", first_name="Admin", last_name="Test", oidc_id="abcd-efgh-ijkl-4321")) + db_session.add( + User( + email="admin@test.org", + first_name="Admin", + last_name="Test", + oidc_id="abcd-efgh-ijkl-4321", + ) + ) db_session.flush() db_session.add( - Org(name="Test Org Two", root_user_id=2, billing_contact_id=1, owner_contact_id=2, security_contact_id=3, - status="approved", intake_questionnaire={})) + Org( + name="Test Org Two", + root_user_id=2, + billing_contact_id=1, + owner_contact_id=2, + security_contact_id=3, + status="approved", + intake_questionnaire={}, + ) + ) db_session.flush() resp = await default_client.get("/org?org_id=2") diff --git a/test/test_auth_root.py b/test/test_auth_root.py index 16e3afa..e67bc6a 100644 --- a/test/test_auth_root.py +++ b/test/test_auth_root.py @@ -2,6 +2,7 @@ This module ensures root user only endpoints do return a correctly formatted 401 when user is not the root user for the org DELETE endpoints are not tested """ + import pytest from httpx import AsyncClient @@ -12,10 +13,26 @@ from src.iam.models import Group @pytest.fixture(autouse=True) def add_second_org(db_session): - db_session.add(User(email="admin@test.org", first_name="Admin", last_name="Test", oidc_id="abcd-efgh-ijkl-4321")) + db_session.add( + User( + email="admin@test.org", + first_name="Admin", + last_name="Test", + oidc_id="abcd-efgh-ijkl-4321", + ) + ) db_session.flush() - db_session.add(Org(name="Test Org Two", root_user_id=2, billing_contact_id=1, owner_contact_id=2, security_contact_id=3, - status="approved", intake_questionnaire={})) + db_session.add( + Org( + name="Test Org Two", + root_user_id=2, + billing_contact_id=1, + owner_contact_id=2, + security_contact_id=3, + status="approved", + intake_questionnaire={}, + ) + ) db_session.flush() @@ -29,11 +46,18 @@ async def test_get_org_auth_root(no_su_client: AsyncClient): @pytest.mark.anyio async def test_patch_org_questionnaire_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.patch("/org/questionnaire", json={"organisation_id": 2, - "intake_questionnaire": {"question_one": "new answer one", - "question_two": None, - "question_three": None}, - "partial": True}) + resp = await no_su_client.patch( + "/org/questionnaire", + json={ + "organisation_id": 2, + "intake_questionnaire": { + "question_one": "new answer one", + "question_two": None, + "question_three": None, + }, + "partial": True, + }, + ) assert resp.status_code != 422 assert resp.status_code == 401 assert "Must be the org's root user" in resp.json()["detail"] @@ -49,10 +73,19 @@ async def test_get_org_users_auth_root(no_su_client: AsyncClient): @pytest.mark.anyio async def test_post_org_user_auth_root(no_su_client: AsyncClient, db_session): - db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) + db_session.add( + User( + email="user@test.org", + first_name="User", + last_name="Test", + oidc_id="abcd-efgh-ijkl-1234", + ) + ) db_session.flush() - resp = await no_su_client.post("/org/user", json={"organisation_id": 2, "user_id": 2}) + resp = await no_su_client.post( + "/org/user", json={"organisation_id": 2, "user_id": 2} + ) assert resp.status_code != 422 assert resp.status_code == 401 assert "Must be the org's root user" in resp.json()["detail"] @@ -76,8 +109,14 @@ async def test_get_org_contact_auth_root(no_su_client: AsyncClient): @pytest.mark.anyio async def test_patch_org_contact_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.patch("/org/contact", - json={"organisation_id": 2, "contact_type": "billing", "email": "user@example.com"}) + resp = await no_su_client.patch( + "/org/contact", + json={ + "organisation_id": 2, + "contact_type": "billing", + "email": "user@example.com", + }, + ) assert resp.status_code != 422 assert resp.status_code == 401 assert "Must be the org's root user" in resp.json()["detail"] @@ -109,17 +148,24 @@ async def test_get_iam_group_users_auth_root(no_su_client: AsyncClient): @pytest.mark.anyio async def test_post_iam_group_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.post("/iam/group", json={"name": "New Group", "organisation_id": 2}) + resp = await no_su_client.post( + "/iam/group", json={"name": "New Group", "organisation_id": 2} + ) assert resp.status_code != 422 assert resp.status_code == 401 assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio -async def test_put_iam_group_permission_auth_root(no_su_client: AsyncClient, db_session): +async def test_put_iam_group_permission_auth_root( + no_su_client: AsyncClient, db_session +): db_session.add(Group(name="Test Group Two", org_id=2)) db_session.flush() - resp = await no_su_client.put("/iam/group/permission", json={"permission_id": 1, "group_id": 2, "organisation_id": 2}) + resp = await no_su_client.put( + "/iam/group/permission", + json={"permission_id": 1, "group_id": 2, "organisation_id": 2}, + ) assert resp.status_code != 422 assert resp.status_code == 401 assert "Must be the org's root user" in resp.json()["detail"] @@ -127,10 +173,19 @@ async def test_put_iam_group_permission_auth_root(no_su_client: AsyncClient, db_ @pytest.mark.anyio async def test_put_iam_group_user_auth_root(no_su_client: AsyncClient, db_session): - db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) + db_session.add( + User( + email="user@test.org", + first_name="User", + last_name="Test", + oidc_id="abcd-efgh-ijkl-1234", + ) + ) db_session.flush() - resp = await no_su_client.put("/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2}) + resp = await no_su_client.put( + "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2} + ) assert resp.status_code != 422 assert resp.status_code == 401 assert "Must be the org's root user" in resp.json()["detail"] @@ -146,7 +201,9 @@ async def test_get_iam_permissions_auth_root(no_su_client: AsyncClient): @pytest.mark.anyio async def test_post_iam_permissions_search_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.post("/iam/permissions/search", json={"organisation_id": 2, "action": "read"}) + resp = await no_su_client.post( + "/iam/permissions/search", json={"organisation_id": 2, "action": "read"} + ) assert resp.status_code != 422 assert resp.status_code == 401 assert "Must be the org's root user" in resp.json()["detail"] diff --git a/test/test_auth_su.py b/test/test_auth_su.py index 9cef61f..18319a4 100644 --- a/test/test_auth_su.py +++ b/test/test_auth_su.py @@ -2,6 +2,7 @@ This module ensures super admin only endpoints do return a correctly formatted 401 when user is not a super admin DELETE endpoints are not tested """ + import pytest from httpx import AsyncClient @@ -19,7 +20,9 @@ async def test_get_user_auth_su(no_su_client: AsyncClient): @pytest.mark.anyio async def test_patch_org_status_auth_su(no_su_client: AsyncClient): - resp = await no_su_client.patch("/org/status", json={"organisation_id": 1, "status": "submitted"}) + resp = await no_su_client.patch( + "/org/status", json={"organisation_id": 1, "status": "submitted"} + ) assert resp.status_code != 422 assert resp.status_code == 401 assert resp.json()["detail"] == "Must be super admin" @@ -27,12 +30,21 @@ async def test_patch_org_status_auth_su(no_su_client: AsyncClient): @pytest.mark.anyio async def test_patch_org_root_user_auth_su(no_su_client: AsyncClient, db_session): - db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) + db_session.add( + User( + email="user@test.org", + first_name="User", + last_name="Test", + oidc_id="abcd-efgh-ijkl-1234", + ) + ) db_session.flush() db_session.add(OrgUsers(org_id=1, user_id=2)) db_session.flush() - resp = await no_su_client.patch("/org/root_user", json={"organisation_id": 1, "user_id": 2}) + resp = await no_su_client.patch( + "/org/root_user", json={"organisation_id": 1, "user_id": 2} + ) assert resp.status_code != 422 assert resp.status_code == 401 assert resp.json()["detail"] == "Must be super admin" @@ -56,7 +68,10 @@ async def test_post_service_auth_su(no_su_client: AsyncClient): @pytest.mark.anyio async def test_post_perm_success(no_su_client: AsyncClient, db_session): - resp = await no_su_client.post("/iam/permission", json={"service_id": 1, "resource": "test_resource", "action": "create"}) + resp = await no_su_client.post( + "/iam/permission", + json={"service_id": 1, "resource": "test_resource", "action": "create"}, + ) assert resp.status_code != 422 assert resp.status_code == 401 assert resp.json()["detail"] == "Must be super admin" diff --git a/test/test_auth_user.py b/test/test_auth_user.py index 6f1c655..3cd66d0 100644 --- a/test/test_auth_user.py +++ b/test/test_auth_user.py @@ -1,6 +1,7 @@ """ This testing module removes the testing user override to verify that endpoints with only the user requirement return a 401 error when not logged in """ + import pytest from httpx import AsyncClient diff --git a/test/test_healthcheck.py b/test/test_healthcheck.py index 6fdb9be..47a3993 100644 --- a/test/test_healthcheck.py +++ b/test/test_healthcheck.py @@ -4,7 +4,7 @@ from httpx import AsyncClient @pytest.mark.anyio async def test_healthcheck(default_client: AsyncClient): - resp = await default_client.get("/healthcheck") + resp = await default_client.get("/healthcheck") - assert resp.status_code == 200 - assert resp.json() == {"status": "ok"} + assert resp.status_code == 200 + assert resp.json() == {"status": "ok"} diff --git a/test/test_iam.py b/test/test_iam.py index acc8358..05cbabb 100644 --- a/test/test_iam.py +++ b/test/test_iam.py @@ -1,5 +1,5 @@ -""" -""" +""" """ + import pytest from httpx import AsyncClient @@ -15,13 +15,15 @@ async def test_post_act_on_resource_endpoint_success(default_client: AsyncClient body = { "service": "Test Service", "organisation": "Test Org", - "resource": "test_resource" + "resource": "test_resource", } headers = { "Authorization": "Bearer not_checked_when_auth_is_disabled", - "X-API-Key": "123456789" + "X-API-Key": "123456789", } - resp = await default_client.post("/iam/can_act_on_resource?action=read", json=body, headers=headers) + resp = await default_client.post( + "/iam/can_act_on_resource?action=read", json=body, headers=headers + ) data = resp.json() assert resp.status_code == 200 @@ -30,23 +32,20 @@ async def test_post_act_on_resource_endpoint_success(default_client: AsyncClient @pytest.mark.parametrize( "service, api_key", - [ - ("Test Service", "not_the_correct_key"), - ("Test Service Two", "123456789") - ], + [("Test Service", "not_the_correct_key"), ("Test Service Two", "123456789")], ) @pytest.mark.anyio -async def test_act_on_resource_wrong_key(default_client: AsyncClient, db_session, service: str, api_key: str): - body = { - "service": service, - "organisation": "Test Org", - "resource": "test_resource" - } +async def test_act_on_resource_wrong_key( + default_client: AsyncClient, db_session, service: str, api_key: str +): + body = {"service": service, "organisation": "Test Org", "resource": "test_resource"} headers = { "Authorization": "Bearer not_checked_when_auth_is_disabled", - "X-API-Key": api_key + "X-API-Key": api_key, } - resp = await default_client.post("/iam/can_act_on_resource?action=read", json=body, headers=headers) + resp = await default_client.post( + "/iam/can_act_on_resource?action=read", json=body, headers=headers + ) data = resp.json() assert resp.status_code == 401 @@ -58,12 +57,12 @@ async def test_act_on_resource_missing_key(default_client: AsyncClient): body = { "service": "Test Service", "organisation": "Test Org", - "resource": "test_resource" + "resource": "test_resource", } - headers = { - "Authorization": "Bearer not_checked_when_auth_is_disabled" - } - resp = await default_client.post("/iam/can_act_on_resource?action=read", json=body, headers=headers) + headers = {"Authorization": "Bearer not_checked_when_auth_is_disabled"} + resp = await default_client.post( + "/iam/can_act_on_resource?action=read", json=body, headers=headers + ) data = resp.json() assert resp.status_code == 401 @@ -82,18 +81,17 @@ async def test_act_on_resource_missing_key(default_client: AsyncClient): ], ) @pytest.mark.anyio -async def test_act_on_resource_endpoint_status_checks(default_client: AsyncClient, service, org, resource, action, - expected_status: int): - body = { - "service": service, - "organisation": org, - "resource": resource - } +async def test_act_on_resource_endpoint_status_checks( + default_client: AsyncClient, service, org, resource, action, expected_status: int +): + body = {"service": service, "organisation": org, "resource": resource} headers = { "Authorization": "Bearer not_checked_when_auth_is_disabled", - "X-API-Key": "123456789" + "X-API-Key": "123456789", } - resp = await default_client.post(f"/iam/can_act_on_resource?action={action}", json=body, headers=headers) + resp = await default_client.post( + f"/iam/can_act_on_resource?action={action}", json=body, headers=headers + ) assert resp.status_code == expected_status @@ -108,18 +106,23 @@ async def test_act_on_resource_endpoint_status_checks(default_client: AsyncClien ], ) @pytest.mark.anyio -async def test_act_on_resource_logic(default_client: AsyncClient, db_session, service, org, resource, action, - expected_response: bool): - body = { - "service": service, - "organisation": org, - "resource": resource - } +async def test_act_on_resource_logic( + default_client: AsyncClient, + db_session, + service, + org, + resource, + action, + expected_response: bool, +): + body = {"service": service, "organisation": org, "resource": resource} headers = { "Authorization": "Bearer not_checked_when_auth_is_disabled", - "X-API-Key": "123456789" + "X-API-Key": "123456789", } - resp = await default_client.post(f"/iam/can_act_on_resource?action={action}", json=body, headers=headers) + resp = await default_client.post( + f"/iam/can_act_on_resource?action={action}", json=body, headers=headers + ) data = resp.json() assert resp.status_code == 200 @@ -140,11 +143,12 @@ async def test_get_group_permissions_success(default_client: AsyncClient): @pytest.mark.parametrize( - "query, expected_status", - generate_query_and_status(["group_id", "org_id"]) + "query, expected_status", generate_query_and_status(["group_id", "org_id"]) ) @pytest.mark.anyio -async def test_get_group_permissions_status_checks(default_client: AsyncClient, db_session, query: str, expected_status: int): +async def test_get_group_permissions_status_checks( + default_client: AsyncClient, db_session, query: str, expected_status: int +): resp = await default_client.get(f"/iam/group/permissions?{query}") assert resp.status_code == expected_status @@ -158,8 +162,19 @@ async def test_get_group_permissions_status_checks(default_client: AsyncClient, ], ) @pytest.mark.anyio -async def test_get_group_permissions_mismatch(default_client: AsyncClient, db_session, query: str): - db_session.add(Org(name="Another Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=2, security_contact_id=3, status="approved")) +async def test_get_group_permissions_mismatch( + default_client: AsyncClient, db_session, query: str +): + db_session.add( + Org( + name="Another Test Org", + root_user_id=1, + billing_contact_id=1, + owner_contact_id=2, + security_contact_id=3, + status="approved", + ) + ) db_session.add(Group(name="Another Test Group", org_id=2)) db_session.flush() resp = await default_client.get(f"/iam/group/permissions?{query}") @@ -183,11 +198,12 @@ async def test_get_group_users_success(default_client: AsyncClient): @pytest.mark.parametrize( - "query, expected_status", - generate_query_and_status(["group_id", "org_id"]) + "query, expected_status", generate_query_and_status(["group_id", "org_id"]) ) @pytest.mark.anyio -async def test_get_group_users_status_checks(default_client: AsyncClient, query: str, expected_status: int): +async def test_get_group_users_status_checks( + default_client: AsyncClient, query: str, expected_status: int +): resp = await default_client.get(f"/iam/group/users?{query}") assert resp.status_code == expected_status @@ -201,8 +217,19 @@ async def test_get_group_users_status_checks(default_client: AsyncClient, query: ], ) @pytest.mark.anyio -async def test_get_group_users_mismatch(default_client: AsyncClient, db_session, query: str): - db_session.add(Org(name="Another Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=2, security_contact_id=3, status="approved")) +async def test_get_group_users_mismatch( + default_client: AsyncClient, db_session, query: str +): + db_session.add( + Org( + name="Another Test Org", + root_user_id=1, + billing_contact_id=1, + owner_contact_id=2, + security_contact_id=3, + status="approved", + ) + ) db_session.add(Group(name="Another Test Group", org_id=2)) db_session.flush() resp = await default_client.get(f"/iam/group/users?{query}") @@ -213,7 +240,9 @@ async def test_get_group_users_mismatch(default_client: AsyncClient, db_session, @pytest.mark.anyio async def test_post_group_success(default_client: AsyncClient): - resp = await default_client.post("/iam/group", json={"name": "New Group", "organisation_id": 1}) + resp = await default_client.post( + "/iam/group", json={"name": "New Group", "organisation_id": 1} + ) data = resp.json() assert resp.status_code == 200 @@ -227,10 +256,22 @@ async def test_post_group_success(default_client: AsyncClient): "body, expected_status", [ ({"organisation_id": 1, "name": "Test Group"}, 409), - ({"organisation_id": 2, "name": "new group"}, 404), # Non-existent organisation, valid name - ({"organisation_id": "banana", "name": "new group"}, 422), # Invalid organisation ID, valid name - ({"organisation_id": "", "name": "new group"}, 422), # Blank organisation ID, valid name - ({"organisation_id": -1, "name": "new group"}, 422), # Negative organisation ID, valid name + ( + {"organisation_id": 2, "name": "new group"}, + 404, + ), # Non-existent organisation, valid name + ( + {"organisation_id": "banana", "name": "new group"}, + 422, + ), # Invalid organisation ID, valid name + ( + {"organisation_id": "", "name": "new group"}, + 422, + ), # Blank organisation ID, valid name + ( + {"organisation_id": -1, "name": "new group"}, + 422, + ), # Negative organisation ID, valid name ({"name": 1}, 422), # Only name ({}, 422), # Blank body ({"organisation_id": "", "name": ""}, 422), # Both blank @@ -241,7 +282,9 @@ async def test_post_group_success(default_client: AsyncClient): ], ) @pytest.mark.anyio -async def test_post_group_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): +async def test_post_group_status_checks( + default_client: AsyncClient, body: dict[str, str], expected_status: int +): resp = await default_client.post("/iam/group", json=body) assert resp.status_code == expected_status @@ -251,7 +294,10 @@ async def test_post_group_status_checks(default_client: AsyncClient, body: dict[ async def test_put_group_perm_success(default_client: AsyncClient, db_session): db_session.add(Group(name="Test Group Two", org_id=1)) db_session.flush() - resp = await default_client.put("/iam/group/permission", json={"permission_id": 1, "group_id": 2, "organisation_id": 1}) + resp = await default_client.put( + "/iam/group/permission", + json={"permission_id": 1, "group_id": 2, "organisation_id": 1}, + ) data = resp.json() assert resp.status_code == 200 @@ -270,36 +316,71 @@ async def test_put_group_perm_success(default_client: AsyncClient, db_session): @pytest.mark.parametrize( "body, expected_status", [ - ({"organisation_id": 42, "group_id": 1, "permission_id": 1}, 404), # Non-existent organisation - ({"organisation_id": "banana", "group_id": 1, "permission_id": 1}, 422), # Invalid organisation ID - ({"organisation_id": "", "group_id": 1, "permission_id": 1}, 422), # Blank organisation ID - ({"organisation_id": -1, "group_id": 1, "permission_id": 1}, 422), # Negative organisation ID - - ({"organisation_id": 1, "group_id": 42, "permission_id": 1}, 404), # Non-existent group - ({"organisation_id": 1, "group_id": "banana", "permission_id": 1}, 422), # Invalid group ID - ({"organisation_id": 1, "group_id": "", "permission_id": 1}, 422), # Blank group ID - ({"organisation_id": 1, "group_id": -1, "permission_id": 1}, 422), # Negative group ID - - ({"organisation_id": 1, "group_id": 1, "permission_id": 42}, 404), # Non-existent permission - ({"organisation_id": 1, "group_id": 1, "permission_id": "banana"}, 422), # Invalid permission ID - ({"organisation_id": 1, "group_id": 1, "permission_id": ""}, 422), # Blank permission ID - ({"organisation_id": 1, "group_id": 1, "permission_id": -1}, 422), # Negative permission ID - + ( + {"organisation_id": 42, "group_id": 1, "permission_id": 1}, + 404, + ), # Non-existent organisation + ( + {"organisation_id": "banana", "group_id": 1, "permission_id": 1}, + 422, + ), # Invalid organisation ID + ( + {"organisation_id": "", "group_id": 1, "permission_id": 1}, + 422, + ), # Blank organisation ID + ( + {"organisation_id": -1, "group_id": 1, "permission_id": 1}, + 422, + ), # Negative organisation ID + ( + {"organisation_id": 1, "group_id": 42, "permission_id": 1}, + 404, + ), # Non-existent group + ( + {"organisation_id": 1, "group_id": "banana", "permission_id": 1}, + 422, + ), # Invalid group ID + ( + {"organisation_id": 1, "group_id": "", "permission_id": 1}, + 422, + ), # Blank group ID + ( + {"organisation_id": 1, "group_id": -1, "permission_id": 1}, + 422, + ), # Negative group ID + ( + {"organisation_id": 1, "group_id": 1, "permission_id": 42}, + 404, + ), # Non-existent permission + ( + {"organisation_id": 1, "group_id": 1, "permission_id": "banana"}, + 422, + ), # Invalid permission ID + ( + {"organisation_id": 1, "group_id": 1, "permission_id": ""}, + 422, + ), # Blank permission ID + ( + {"organisation_id": 1, "group_id": 1, "permission_id": -1}, + 422, + ), # Negative permission ID ({}, 422), # Blank body ({"permission_id": 1}, 422), # Only permission ({"organisation_id": 1}, 422), # Only organisation ({"group_id": 1}, 422), # Only group - ({"organisation_id": 1, "permission_id": 1}, 422), # Missing group ({"group_id": 1, "permission_id": 1}, 422), # Missing organisation ({"organisation_id": 1, "group_id": 1}, 422), # Missing permission - - ({"organisation_id": 1, "group_id": 1, "permission_id": 1}, 409), # Permission already in group - + ( + {"organisation_id": 1, "group_id": 1, "permission_id": 1}, + 409, + ), # Permission already in group ], ) @pytest.mark.anyio -async def test_put_group_perm_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): +async def test_put_group_perm_status_checks( + default_client: AsyncClient, body: dict[str, str], expected_status: int +): resp = await default_client.put("/iam/group/permission", json=body) assert resp.status_code == expected_status @@ -313,8 +394,19 @@ async def test_put_group_perm_status_checks(default_client: AsyncClient, body: d ], ) @pytest.mark.anyio -async def test_put_group_perm_mismatch(default_client: AsyncClient, db_session, body: dict): - db_session.add(Org(name="Another Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=2, security_contact_id=3, status="approved")) +async def test_put_group_perm_mismatch( + default_client: AsyncClient, db_session, body: dict +): + db_session.add( + Org( + name="Another Test Org", + root_user_id=1, + billing_contact_id=1, + owner_contact_id=2, + security_contact_id=3, + status="approved", + ) + ) db_session.add(Group(name="Another Test Group", org_id=2)) db_session.flush() resp = await default_client.put("/iam/group/permission", json=body) @@ -325,10 +417,19 @@ async def test_put_group_perm_mismatch(default_client: AsyncClient, db_session, @pytest.mark.anyio async def test_put_group_user_success(default_client: AsyncClient, db_session): - db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) + db_session.add( + User( + email="user@test.org", + first_name="User", + last_name="Test", + oidc_id="abcd-efgh-ijkl-1234", + ) + ) db_session.flush() - resp = await default_client.put("/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1}) + resp = await default_client.put( + "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1} + ) data = resp.json() assert resp.status_code == 200 @@ -348,34 +449,58 @@ async def test_put_group_user_success(default_client: AsyncClient, db_session): @pytest.mark.parametrize( "body, expected_status", [ - ({"organisation_id": 42, "group_id": 1, "user_id": 1}, 404), # Non-existent organisation - ({"organisation_id": "banana", "group_id": 1, "user_id": 1}, 422), # Invalid organisation ID - ({"organisation_id": "", "group_id": 1, "user_id": 1}, 422), # Blank organisation ID - ({"organisation_id": -1, "group_id": 1, "user_id": 1}, 422), # Negative organisation ID - - ({"organisation_id": 1, "group_id": 42, "user_id": 1}, 404), # Non-existent group - ({"organisation_id": 1, "group_id": "banana", "user_id": 1}, 422), # Invalid group ID + ( + {"organisation_id": 42, "group_id": 1, "user_id": 1}, + 404, + ), # Non-existent organisation + ( + {"organisation_id": "banana", "group_id": 1, "user_id": 1}, + 422, + ), # Invalid organisation ID + ( + {"organisation_id": "", "group_id": 1, "user_id": 1}, + 422, + ), # Blank organisation ID + ( + {"organisation_id": -1, "group_id": 1, "user_id": 1}, + 422, + ), # Negative organisation ID + ( + {"organisation_id": 1, "group_id": 42, "user_id": 1}, + 404, + ), # Non-existent group + ( + {"organisation_id": 1, "group_id": "banana", "user_id": 1}, + 422, + ), # Invalid group ID ({"organisation_id": 1, "group_id": "", "user_id": 1}, 422), # Blank group ID - ({"organisation_id": 1, "group_id": -1, "user_id": 1}, 422), # Negative group ID - - ({"organisation_id": 1, "group_id": 1, "user_id": 42}, 404), # Non-existent user - ({"organisation_id": 1, "group_id": 1, "user_id": "banana"}, 422), # Invalid user ID + ( + {"organisation_id": 1, "group_id": -1, "user_id": 1}, + 422, + ), # Negative group ID + ( + {"organisation_id": 1, "group_id": 1, "user_id": 42}, + 404, + ), # Non-existent user + ( + {"organisation_id": 1, "group_id": 1, "user_id": "banana"}, + 422, + ), # Invalid user ID ({"organisation_id": 1, "group_id": 1, "user_id": ""}, 422), # Blank user ID ({"organisation_id": 1, "group_id": 1, "user_id": -1}, 422), # Negative user ID - ({}, 422), # Blank body ({"user_id": 1}, 422), # Only user ({"organisation_id": 1}, 422), # Only organisation ({"group_id": 1}, 422), # Only group - ({"organisation_id": 1, "user_id": 1}, 422), # Missing group ({"group_id": 1, "user_id": 1}, 422), # Missing organisation ({"organisation_id": 1, "group_id": 1}, 422), # Missing user - ], ) @pytest.mark.anyio -async def test_put_group_user_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): +async def test_put_group_user_status_checks( + default_client: AsyncClient, body: dict[str, str], expected_status: int +): resp = await default_client.put("/iam/group/user", json=body) assert resp.status_code == expected_status @@ -395,11 +520,12 @@ async def test_get_permissions_success(default_client: AsyncClient): @pytest.mark.parametrize( - "query, expected_status", - generate_query_and_status(["org_id"]) + "query, expected_status", generate_query_and_status(["org_id"]) ) @pytest.mark.anyio -async def test_get_permissions_status_checks(default_client: AsyncClient, query: str, expected_status: int): +async def test_get_permissions_status_checks( + default_client: AsyncClient, query: str, expected_status: int +): resp = await default_client.get(f"/iam/permissions?{query}") assert resp.status_code == expected_status @@ -407,7 +533,10 @@ async def test_get_permissions_status_checks(default_client: AsyncClient, query: @pytest.mark.anyio async def test_post_perm_success(default_client: AsyncClient, db_session): - resp = await default_client.post("/iam/permission", json={"service_id": 1, "resource": "test_resource", "action": "create"}) + resp = await default_client.post( + "/iam/permission", + json={"service_id": 1, "resource": "test_resource", "action": "create"}, + ) data = resp.json() assert resp.status_code == 200 @@ -418,51 +547,70 @@ async def test_post_perm_success(default_client: AsyncClient, db_session): @pytest.mark.parametrize( - "body, expected_status", - [ - # service_id tests - ({"service_id": 42, "resource": "test_resource", "action": "read"}, 404), # Non-existent service - ({"service_id": "banana", "resource": "test_resource", "action": "read"}, 422), # Invalid service ID - ({"service_id": "", "resource": "test_resource", "action": "read"}, 422), # Blank service ID - ({"service_id": -1, "resource": "test_resource", "action": "read"}, 422), # Negative service ID - - # resource tests - ({"service_id": 1, "resource": 42, "action": "read"}, 422), # Invalid resource type - - # action tests - ({"service_id": 1, "resource": "test_resource", "action": 42}, 422), # Invalid action type - - # missing/partial body tests - ({}, 422), # Blank body - ({"resource": "test_resource"}, 422), # Only resource - ({"action": "read"}, 422), # Only action - ({"service_id": 1}, 422), # Only service - - ({"service_id": 1, "action": "read"}, 422), # Missing resource - ({"service_id": 1, "resource": "test_resource"}, 422), # Missing action - ({"resource": "test_resource", "action": "read"}, 422), # Missing service - - ], + "body, expected_status", + [ + # service_id tests + ( + {"service_id": 42, "resource": "test_resource", "action": "read"}, + 404, + ), # Non-existent service + ( + {"service_id": "banana", "resource": "test_resource", "action": "read"}, + 422, + ), # Invalid service ID + ( + {"service_id": "", "resource": "test_resource", "action": "read"}, + 422, + ), # Blank service ID + ( + {"service_id": -1, "resource": "test_resource", "action": "read"}, + 422, + ), # Negative service ID + # resource tests + ( + {"service_id": 1, "resource": 42, "action": "read"}, + 422, + ), # Invalid resource type + # action tests + ( + {"service_id": 1, "resource": "test_resource", "action": 42}, + 422, + ), # Invalid action type + # missing/partial body tests + ({}, 422), # Blank body + ({"resource": "test_resource"}, 422), # Only resource + ({"action": "read"}, 422), # Only action + ({"service_id": 1}, 422), # Only service + ({"service_id": 1, "action": "read"}, 422), # Missing resource + ({"service_id": 1, "resource": "test_resource"}, 422), # Missing action + ({"resource": "test_resource", "action": "read"}, 422), # Missing service + ], ) @pytest.mark.anyio -async def test_post_perm_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): +async def test_post_perm_status_checks( + default_client: AsyncClient, body: dict[str, str], expected_status: int +): resp = await default_client.post("/iam/permission", json=body) assert resp.status_code == expected_status @pytest.mark.parametrize( - "body", - [ - {"organisation_id": 1, "service_id": 1, "resource": "test_resource", "action": "read"}, - - {"organisation_id": 1, "service_id": 1}, - {"organisation_id": 1, "resource": "test_resource"}, - {"organisation_id": 1, "action": "read"}, - {"organisation_id": 1, "service_id": 1, "action": "read"}, - {"organisation_id": 1, "service_id": 1, "resource": "test_resource"}, - {"organisation_id": 1, "resource": "test_resource", "action": "read"}, - ], + "body", + [ + { + "organisation_id": 1, + "service_id": 1, + "resource": "test_resource", + "action": "read", + }, + {"organisation_id": 1, "service_id": 1}, + {"organisation_id": 1, "resource": "test_resource"}, + {"organisation_id": 1, "action": "read"}, + {"organisation_id": 1, "service_id": 1, "action": "read"}, + {"organisation_id": 1, "service_id": 1, "resource": "test_resource"}, + {"organisation_id": 1, "resource": "test_resource", "action": "read"}, + ], ) @pytest.mark.anyio async def test_post_perm_search_success(default_client: AsyncClient, db_session, body): @@ -478,33 +626,96 @@ async def test_post_perm_search_success(default_client: AsyncClient, db_session, @pytest.mark.parametrize( - "body, expected_status", - [ - # organisation_id tests - ({"organisation_id": 42, "service_id": 1, "resource": "test_resource", "action": "read"}, 404), # Non-existent organisation - ({"organisation_id": "banana", "service_id": 1, "resource": "test_resource", "action": "read"}, 422), # Invalid organisation ID - ({"organisation_id": "", "service_id": 1, "resource": "test_resource", "action": "read"}, 422), # Blank organisation ID - ({"organisation_id": -1, "service_id": 1, "resource": "test_resource", "action": "read"}, 422), # Negative organisation ID - - # service_id tests - ({"organisation_id": 1, "service_id": "banana", "resource": "test_resource", "action": "read"}, 422), # Invalid service ID - ({"organisation_id": 1, "service_id": "", "resource": "test_resource", "action": "read"}, 422), # Blank service ID - ({"organisation_id": 1, "service_id": -1, "resource": "test_resource", "action": "read"}, 422), # Negative service ID - - # resource tests - ({"organisation_id": 1, "service_id": 1, "resource": 42, "action": "read"}, 422), # Invalid resource type - - # action tests - ({"organisation_id": 1, "service_id": 1, "resource": "test_resource", "action": 42}, 422), # Invalid action type - - # missing/partial body tests - ({}, 422), # Blank body - ], + "body, expected_status", + [ + # organisation_id tests + ( + { + "organisation_id": 42, + "service_id": 1, + "resource": "test_resource", + "action": "read", + }, + 404, + ), # Non-existent organisation + ( + { + "organisation_id": "banana", + "service_id": 1, + "resource": "test_resource", + "action": "read", + }, + 422, + ), # Invalid organisation ID + ( + { + "organisation_id": "", + "service_id": 1, + "resource": "test_resource", + "action": "read", + }, + 422, + ), # Blank organisation ID + ( + { + "organisation_id": -1, + "service_id": 1, + "resource": "test_resource", + "action": "read", + }, + 422, + ), # Negative organisation ID + # service_id tests + ( + { + "organisation_id": 1, + "service_id": "banana", + "resource": "test_resource", + "action": "read", + }, + 422, + ), # Invalid service ID + ( + { + "organisation_id": 1, + "service_id": "", + "resource": "test_resource", + "action": "read", + }, + 422, + ), # Blank service ID + ( + { + "organisation_id": 1, + "service_id": -1, + "resource": "test_resource", + "action": "read", + }, + 422, + ), # Negative service ID + # resource tests + ( + {"organisation_id": 1, "service_id": 1, "resource": 42, "action": "read"}, + 422, + ), # Invalid resource type + # action tests + ( + { + "organisation_id": 1, + "service_id": 1, + "resource": "test_resource", + "action": 42, + }, + 422, + ), # Invalid action type + # missing/partial body tests + ({}, 422), # Blank body + ], ) @pytest.mark.anyio -async def test_post_perm_search_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): +async def test_post_perm_search_status_checks( + default_client: AsyncClient, body: dict[str, str], expected_status: int +): resp = await default_client.post("/iam/permissions/search", json=body) assert resp.status_code == expected_status - - diff --git a/test/test_organisation.py b/test/test_organisation.py index 4b60aa7..f0fdf5d 100644 --- a/test/test_organisation.py +++ b/test/test_organisation.py @@ -1,6 +1,7 @@ """ [DELETE] /org/ is not tested because the testing client cannot attach a body to a delete request. """ + import pytest from httpx import AsyncClient @@ -24,11 +25,12 @@ async def test_get_org_success(default_client: AsyncClient): @pytest.mark.parametrize( - "query, expected_status", - generate_query_and_status(["org_id"]) + "query, expected_status", generate_query_and_status(["org_id"]) ) @pytest.mark.anyio -async def test_get_org_status_checks(default_client: AsyncClient, query: str, expected_status: int): +async def test_get_org_status_checks( + default_client: AsyncClient, query: str, expected_status: int +): resp = await default_client.get(f"/org?{query}") assert resp.status_code == expected_status @@ -53,18 +55,33 @@ async def test_post_org_success(default_client: AsyncClient): ], ) @pytest.mark.anyio -async def test_post_org_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): +async def test_post_org_status_checks( + default_client: AsyncClient, body: dict[str, str], expected_status: int +): resp = await default_client.post("/org", json=body) assert resp.status_code == expected_status @pytest.mark.anyio -async def test_patch_org_questionnaire_partial_success(default_client: AsyncClient, db_session): +async def test_patch_org_questionnaire_partial_success( + default_client: AsyncClient, db_session +): org_model = db_session.get(Organisation, 1) org_model.status = "partial" db_session.flush() - resp = await default_client.patch("/org/questionnaire", json={"organisation_id": 1, "intake_questionnaire": {"question_one": "new answer one", "question_two": None, "question_three": None}, "partial": True}) + resp = await default_client.patch( + "/org/questionnaire", + json={ + "organisation_id": 1, + "intake_questionnaire": { + "question_one": "new answer one", + "question_two": None, + "question_three": None, + }, + "partial": True, + }, + ) data = resp.json() assert resp.status_code == 200 @@ -83,24 +100,56 @@ async def test_patch_org_questionnaire_partial_success(default_client: AsyncClie ({"organisation_id": "Test Org"}, 422), ({"organisation_id": ""}, 422), ({}, 422), - ({"organisation_id": "1", "intake_questionnaire": {"question_one": 42}, "partial": True}, 422), - ({"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}}, 422), - ({"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}, "partial": 42}, 422), + ( + { + "organisation_id": "1", + "intake_questionnaire": {"question_one": 42}, + "partial": True, + }, + 422, + ), + ( + {"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}}, + 422, + ), + ( + { + "organisation_id": "1", + "intake_questionnaire": {"question_one": "valid"}, + "partial": 42, + }, + 422, + ), ], ) @pytest.mark.anyio -async def test_patch_questionnaire_partial_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): +async def test_patch_questionnaire_partial_status_checks( + default_client: AsyncClient, body: dict[str, str], expected_status: int +): resp = await default_client.patch("/org/questionnaire", json=body) assert resp.status_code == expected_status @pytest.mark.anyio -async def test_patch_org_questionnaire_submit_success(default_client: AsyncClient, db_session): +async def test_patch_org_questionnaire_submit_success( + default_client: AsyncClient, db_session +): org_model = db_session.get(Organisation, 1) org_model.status = "partial" db_session.flush() - resp = await default_client.patch("/org/questionnaire", json={"organisation_id": 1, "intake_questionnaire": {"question_one": "new answer one", "question_two": None, "question_three": None}, "partial": False}) + resp = await default_client.patch( + "/org/questionnaire", + json={ + "organisation_id": 1, + "intake_questionnaire": { + "question_one": "new answer one", + "question_two": None, + "question_three": None, + }, + "partial": False, + }, + ) data = resp.json() assert resp.status_code == 200 @@ -113,12 +162,13 @@ async def test_patch_org_questionnaire_submit_success(default_client: AsyncClien @pytest.mark.parametrize( - "status", - ["partial", "submitted", "remediation", "approved", "rejected", "removed"] + "status", ["partial", "submitted", "remediation", "approved", "rejected", "removed"] ) @pytest.mark.anyio async def test_patch_org_status_success(default_client: AsyncClient, status: str): - resp = await default_client.patch("/org/status", json={"organisation_id": 1, "status": status}) + resp = await default_client.patch( + "/org/status", json={"organisation_id": 1, "status": status} + ) data = resp.json() assert resp.status_code == 200 @@ -138,7 +188,9 @@ async def test_patch_org_status_success(default_client: AsyncClient, status: str ], ) @pytest.mark.anyio -async def test_patch_org_status_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): +async def test_patch_org_status_status_checks( + default_client: AsyncClient, body: dict[str, str], expected_status: int +): resp = await default_client.patch("/org/status", json=body) assert resp.status_code == expected_status @@ -161,11 +213,12 @@ async def test_get_org_users_success(default_client: AsyncClient): @pytest.mark.parametrize( - "query, expected_status", - generate_query_and_status(["org_id"]) + "query, expected_status", generate_query_and_status(["org_id"]) ) @pytest.mark.anyio -async def test_get_org_users_status_checks(default_client: AsyncClient, query: str, expected_status: int): +async def test_get_org_users_status_checks( + default_client: AsyncClient, query: str, expected_status: int +): resp = await default_client.get(f"/org/users?{query}") assert resp.status_code == expected_status @@ -173,10 +226,19 @@ async def test_get_org_users_status_checks(default_client: AsyncClient, query: s @pytest.mark.anyio async def test_post_org_user_success(default_client: AsyncClient, db_session): - db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) + db_session.add( + User( + email="user@test.org", + first_name="User", + last_name="Test", + oidc_id="abcd-efgh-ijkl-1234", + ) + ) db_session.flush() - resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 2}) + resp = await default_client.post( + "/org/user", json={"organisation_id": 1, "user_id": 2} + ) data = resp.json() assert resp.status_code == 200 @@ -197,8 +259,17 @@ async def test_post_org_user_success(default_client: AsyncClient, db_session): ], ) @pytest.mark.anyio -async def test_post_org_user_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int, db_session): - db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) +async def test_post_org_user_status_checks( + default_client: AsyncClient, body: dict[str, str], expected_status: int, db_session +): + db_session.add( + User( + email="user@test.org", + first_name="User", + last_name="Test", + oidc_id="abcd-efgh-ijkl-1234", + ) + ) db_session.flush() resp = await default_client.post("/org/user", json=body) @@ -208,12 +279,21 @@ async def test_post_org_user_status_checks(default_client: AsyncClient, body: di @pytest.mark.anyio async def test_patch_org_root_user_success(default_client: AsyncClient, db_session): - db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) + db_session.add( + User( + email="user@test.org", + first_name="User", + last_name="Test", + oidc_id="abcd-efgh-ijkl-1234", + ) + ) db_session.flush() db_session.add(OrgUsers(org_id=1, user_id=2)) db_session.flush() - resp = await default_client.patch("/org/root_user", json={"organisation_id": 1, "user_id": 2}) + resp = await default_client.patch( + "/org/root_user", json={"organisation_id": 1, "user_id": 2} + ) data = resp.json() assert resp.status_code == 200 @@ -234,8 +314,17 @@ async def test_patch_org_root_user_success(default_client: AsyncClient, db_sessi ], ) @pytest.mark.anyio -async def test_patch_root_user_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int, db_session): - db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) +async def test_patch_root_user_status_checks( + default_client: AsyncClient, body: dict[str, str], expected_status: int, db_session +): + db_session.add( + User( + email="user@test.org", + first_name="User", + last_name="Test", + oidc_id="abcd-efgh-ijkl-1234", + ) + ) db_session.flush() db_session.add(OrgUsers(org_id=1, user_id=2)) db_session.flush() @@ -247,10 +336,19 @@ async def test_patch_root_user_status_checks(default_client: AsyncClient, body: @pytest.mark.anyio async def test_patch_org_root_user_non_member(default_client: AsyncClient, db_session): - db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) + db_session.add( + User( + email="user@test.org", + first_name="User", + last_name="Test", + oidc_id="abcd-efgh-ijkl-1234", + ) + ) db_session.flush() - resp = await default_client.patch("/org/root_user", json={"organisation_id": 1, "user_id": 2}) + resp = await default_client.patch( + "/org/root_user", json={"organisation_id": 1, "user_id": 2} + ) data = resp.json() assert resp.status_code == 422 @@ -269,23 +367,23 @@ async def test_get_org_groups_success(default_client: AsyncClient): @pytest.mark.parametrize( - "query, expected_status", - generate_query_and_status(["org_id"]) + "query, expected_status", generate_query_and_status(["org_id"]) ) @pytest.mark.anyio -async def test_get_org_groups_status_checks(default_client: AsyncClient, query: str, expected_status: int): +async def test_get_org_groups_status_checks( + default_client: AsyncClient, query: str, expected_status: int +): resp = await default_client.get(f"/org/groups?{query}") assert resp.status_code == expected_status -@pytest.mark.parametrize( - "contact_type", - ["billing", "security", "owner"] -) +@pytest.mark.parametrize("contact_type", ["billing", "security", "owner"]) @pytest.mark.anyio async def test_get_org_contact_success(default_client: AsyncClient, contact_type: str): - resp = await default_client.get(f"/org/contact?org_id=1&contact_type={contact_type}") + resp = await default_client.get( + f"/org/contact?org_id=1&contact_type={contact_type}" + ) data = resp.json() assert resp.status_code == 200 @@ -327,7 +425,9 @@ async def test_get_org_contact_success(default_client: AsyncClient, contact_type ], ) @pytest.mark.anyio -async def test_get_org_contact_status_checks(default_client: AsyncClient, query: str, expected_status: int): +async def test_get_org_contact_status_checks( + default_client: AsyncClient, query: str, expected_status: int +): resp = await default_client.get(f"/org/contact?{query}") assert resp.status_code == expected_status @@ -348,11 +448,16 @@ async def test_get_org_contact_status_checks(default_client: AsyncClient, query: ("address_region", "Glasgow City"), ("country_code", "GB"), ("postal_code", "G1 1AA"), - ] + ], ) @pytest.mark.anyio -async def test_patch_org_contact_success(default_client: AsyncClient, key: str, value: str): - resp = await default_client.patch("/org/contact", json={"organisation_id": 1, "contact_type": "billing", key: value}) +async def test_patch_org_contact_success( + default_client: AsyncClient, key: str, value: str +): + resp = await default_client.patch( + "/org/contact", + json={"organisation_id": 1, "contact_type": "billing", key: value}, + ) data = resp.json() assert resp.status_code == 200 @@ -379,7 +484,9 @@ async def test_patch_org_contact_success(default_client: AsyncClient, key: str, ], ) @pytest.mark.anyio -async def test_patch_org_contact_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): +async def test_patch_org_contact_status_checks( + default_client: AsyncClient, body: dict[str, str], expected_status: int +): resp = await default_client.patch("/org/contact", json=body) assert resp.status_code == expected_status diff --git a/test/test_service.py b/test/test_service.py index c13eb80..6bb28b3 100644 --- a/test/test_service.py +++ b/test/test_service.py @@ -1,6 +1,7 @@ """ 409 on [POST]/service/ not tested because SQLite throws a different error than Postgres """ + import pytest from httpx import AsyncClient @@ -19,11 +20,12 @@ async def test_get_services_success(default_client: AsyncClient): @pytest.mark.parametrize( - "query, expected_status", - generate_query_and_status(["org_id"]) + "query, expected_status", generate_query_and_status(["org_id"]) ) @pytest.mark.anyio -async def test_get_services_status_checks(default_client: AsyncClient, query: str, expected_status: int): +async def test_get_services_status_checks( + default_client: AsyncClient, query: str, expected_status: int +): resp = await default_client.get(f"/service/?{query}") assert resp.status_code == expected_status @@ -49,7 +51,9 @@ async def test_post_service_success(default_client: AsyncClient): ], ) @pytest.mark.anyio -async def test_post_services_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): +async def test_post_services_status_checks( + default_client: AsyncClient, body: dict[str, str], expected_status: int +): resp = await default_client.post("/service/", json=body) assert resp.status_code == expected_status @@ -77,7 +81,9 @@ async def test_patch_service_success(default_client: AsyncClient): ], ) @pytest.mark.anyio -async def test_patch_services_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): +async def test_patch_services_status_checks( + default_client: AsyncClient, body: dict[str, str], expected_status: int +): resp = await default_client.patch("/service/key", json=body) assert resp.status_code == expected_status diff --git a/test/test_user.py b/test/test_user.py index a50e654..996c67a 100644 --- a/test/test_user.py +++ b/test/test_user.py @@ -8,38 +8,40 @@ from httpx import AsyncClient from .conftest import generate_query_and_status + @pytest.mark.anyio async def test_get_self_db_success(default_client: AsyncClient): - resp = await default_client.get("/user/self/db") - data = resp.json() + resp = await default_client.get("/user/self/db") + data = resp.json() - assert resp.status_code == 200 - assert data["first_name"] == "Admin" - assert data["last_name"] == "Test" - assert data["email"] == "admin@test.com" - assert "organisations" in data - assert "groups" in data + assert resp.status_code == 200 + assert data["first_name"] == "Admin" + assert data["last_name"] == "Test" + assert data["email"] == "admin@test.com" + assert "organisations" in data + assert "groups" in data @pytest.mark.anyio async def test_get_user_success(default_client: AsyncClient): - resp = await default_client.get("/user/?user_id=1") - data = resp.json() + resp = await default_client.get("/user/?user_id=1") + data = resp.json() - assert resp.status_code == 200 - assert data["first_name"] == "Admin" - assert data["last_name"] == "Test" - assert data["email"] == "admin@test.com" - assert "organisations" in data - assert "groups" in data + assert resp.status_code == 200 + assert data["first_name"] == "Admin" + assert data["last_name"] == "Test" + assert data["email"] == "admin@test.com" + assert "organisations" in data + assert "groups" in data @pytest.mark.anyio @pytest.mark.parametrize( - "query, expected_status", - generate_query_and_status(["user_id"]) + "query, expected_status", generate_query_and_status(["user_id"]) ) -async def test_get_user_status_checks(default_client: AsyncClient, query: str, expected_status: int): - resp = await default_client.get(f"/user/?{query}") +async def test_get_user_status_checks( + default_client: AsyncClient, query: str, expected_status: int +): + resp = await default_client.get(f"/user/?{query}") - assert resp.status_code == expected_status + assert resp.status_code == expected_status