diff --git a/.alembic/versions/2026-06-08_fix_permission_unique.py b/.alembic/versions/2026-06-08_fix_permission_unique.py deleted file mode 100644 index f00b7b8..0000000 --- a/.alembic/versions/2026-06-08_fix_permission_unique.py +++ /dev/null @@ -1,32 +0,0 @@ -"""fix permission unique - -Revision ID: b6c8614ef799 -Revises: d9dc6986fe38 -Create Date: 2026-06-08 16:00:27.533099 - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision: str = 'b6c8614ef799' -down_revision: Union[str, Sequence[str], None] = 'd9dc6986fe38' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - """Upgrade schema.""" - # ### commands auto generated by Alembic - please adjust! ### - op.create_unique_constraint('uniq_permission_resource_and_action', 'permission', ['service_id', 'resource', 'action']) - # ### end Alembic commands ### - - -def downgrade() -> None: - """Downgrade schema.""" - # ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint('uniq_permission_resource_and_action', 'permission', type_='unique') - # ### end Alembic commands ### diff --git a/pyproject.toml b/pyproject.toml index 63a648b..297f21d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,10 +7,6 @@ exclude = [ ".alembic" ] -[tool.ruff.format] -quote-style = "double" -indent-style = "tab" - [project] name = "cloud-api" version = "0.1.0" diff --git a/src/_module_template/config.py b/src/_module_template/config.py index 927e7bc..45d3182 100644 --- a/src/_module_template/config.py +++ b/src/_module_template/config.py @@ -2,4 +2,4 @@ Configurations for the module Exports: -""" +""" \ No newline at end of file diff --git a/src/_module_template/constants.py b/src/_module_template/constants.py index 9e8da5b..cc72009 100644 --- a/src/_module_template/constants.py +++ b/src/_module_template/constants.py @@ -2,4 +2,4 @@ Constants for the module Exports: -""" +""" \ No newline at end of file diff --git a/src/_module_template/dependencies.py b/src/_module_template/dependencies.py index c61b149..71750bc 100644 --- a/src/_module_template/dependencies.py +++ b/src/_module_template/dependencies.py @@ -3,4 +3,4 @@ Dependencies related to the module Exports: - : : -""" +""" \ No newline at end of file diff --git a/src/_module_template/exceptions.py b/src/_module_template/exceptions.py index 976b6c3..402940a 100644 --- a/src/_module_template/exceptions.py +++ b/src/_module_template/exceptions.py @@ -3,4 +3,4 @@ Exceptions related to the modules Exceptions: - : Details e.g. optional params -""" +""" \ No newline at end of file diff --git a/src/_module_template/models.py b/src/_module_template/models.py index d059461..d03c882 100644 --- a/src/_module_template/models.py +++ b/src/_module_template/models.py @@ -6,4 +6,4 @@ Models: - - - -""" +""" \ No newline at end of file diff --git a/src/_module_template/router.py b/src/_module_template/router.py index 09250df..c81fc26 100644 --- a/src/_module_template/router.py +++ b/src/_module_template/router.py @@ -17,7 +17,6 @@ Exports: - Dependencies should be used for db model get and validation where possible - Verify module level docstring is still accurate after updates """ - from fastapi import APIRouter diff --git a/src/_module_template/schemas.py b/src/_module_template/schemas.py index f72482a..71cfc07 100644 --- a/src/_module_template/schemas.py +++ b/src/_module_template/schemas.py @@ -5,4 +5,4 @@ Models follow the nomenclature of: - Sub-models: "Schema" - Mixins: "Mixin" - Models: "" ie "" -""" +""" \ No newline at end of file diff --git a/src/_module_template/service.py b/src/_module_template/service.py index 39764da..139a237 100644 --- a/src/_module_template/service.py +++ b/src/_module_template/service.py @@ -2,4 +2,4 @@ Module specific business logic for the module Exports: -""" +""" \ No newline at end of file diff --git a/src/_module_template/utils.py b/src/_module_template/utils.py index 5f52b1c..4e99ff6 100644 --- a/src/_module_template/utils.py +++ b/src/_module_template/utils.py @@ -1,3 +1,3 @@ """ Non-business logic reusable functions and classes for the module -""" +""" \ No newline at end of file diff --git a/src/admin/config.py b/src/admin/config.py index 1b96e18..46e4142 100644 --- a/src/admin/config.py +++ b/src/admin/config.py @@ -1,3 +1,3 @@ """ Configurations for the admin module -""" +""" \ No newline at end of file diff --git a/src/admin/constants.py b/src/admin/constants.py index d02c272..c75163f 100644 --- a/src/admin/constants.py +++ b/src/admin/constants.py @@ -1,3 +1,3 @@ """ Constants for the admin module -""" +""" \ No newline at end of file diff --git a/src/admin/dependencies.py b/src/admin/dependencies.py index 0b7fefb..aff00b3 100644 --- a/src/admin/dependencies.py +++ b/src/admin/dependencies.py @@ -1,3 +1,3 @@ """ Dependencies for the admin module -""" +""" \ No newline at end of file diff --git a/src/admin/exceptions.py b/src/admin/exceptions.py index 18dba86..513805c 100644 --- a/src/admin/exceptions.py +++ b/src/admin/exceptions.py @@ -1,3 +1,3 @@ """ Custom exceptions for the admin module -""" +""" \ No newline at end of file diff --git a/src/admin/models.py b/src/admin/models.py index 1faf06c..304e336 100644 --- a/src/admin/models.py +++ b/src/admin/models.py @@ -1,3 +1,3 @@ """ Database models for the admin module -""" +""" \ No newline at end of file diff --git a/src/admin/router.py b/src/admin/router.py index 9fe91eb..e0246a4 100644 --- a/src/admin/router.py +++ b/src/admin/router.py @@ -4,7 +4,6 @@ Router endpoints for the admin module Exports: - router: fastapi.APIRouter """ - from fastapi import APIRouter router = APIRouter( diff --git a/src/admin/schemas.py b/src/admin/schemas.py index 5d65867..1289bcb 100644 --- a/src/admin/schemas.py +++ b/src/admin/schemas.py @@ -1,3 +1,3 @@ """ Pydantic models for the admin module -""" +""" \ No newline at end of file diff --git a/src/admin/service.py b/src/admin/service.py index 9dbfebb..1db3599 100644 --- a/src/admin/service.py +++ b/src/admin/service.py @@ -1,3 +1,3 @@ """ Module specific business logic for the admin module -""" +""" \ No newline at end of file diff --git a/src/admin/utils.py b/src/admin/utils.py index 161d101..e570f14 100644 --- a/src/admin/utils.py +++ b/src/admin/utils.py @@ -1,3 +1,3 @@ """ Non-business logic reusable functions and classes for the admin module -""" +""" \ No newline at end of file diff --git a/src/api.py b/src/api.py index 1461fe1..67ee8d1 100644 --- a/src/api.py +++ b/src/api.py @@ -1,7 +1,6 @@ """ This module hooks the routers for the main endpoints into a single router for importing to the app. """ - from fastapi import APIRouter from src.auth.router import router as auth_router @@ -13,7 +12,9 @@ from src.iam.router import router as iam_router from src.service.router import router as service_router -api_router = APIRouter(prefix="/api/v1") +api_router = APIRouter( + prefix="/api/v1" +) api_router.include_router(auth_router) api_router.include_router(contact_router) @@ -26,5 +27,5 @@ api_router.include_router(iam_router) @api_router.get("/healthcheck", include_in_schema=False) def healthcheck(): - """Simple healthcheck endpoint.""" - return {"status": "ok"} + """Simple healthcheck endpoint.""" + return {"status": "ok"} diff --git a/src/auth/config.py b/src/auth/config.py index 030c36e..979c0e5 100644 --- a/src/auth/config.py +++ b/src/auth/config.py @@ -4,14 +4,12 @@ Configurations for the auth module Exports: - auth_settings: Contains OIDC information """ - from src.config import CustomBaseSettings - class AuthConfig(CustomBaseSettings): - OIDC_CONFIG: str = "" - OIDC_ISSUER: str = "" - CLIENT_ID: str = "" + OIDC_CONFIG: str = "" + OIDC_ISSUER: str = "" + CLIENT_ID: str = "" auth_settings = AuthConfig() diff --git a/src/auth/constants.py b/src/auth/constants.py index 382aac7..faabd82 100644 --- a/src/auth/constants.py +++ b/src/auth/constants.py @@ -1,3 +1,3 @@ """ Constants for the auth module -""" +""" \ No newline at end of file diff --git a/src/auth/dependencies.py b/src/auth/dependencies.py index e29b641..959a830 100644 --- a/src/auth/dependencies.py +++ b/src/auth/dependencies.py @@ -7,24 +7,18 @@ Exports: - org_model_root_claim_body_dependency: org_model: verifies org exists and user is either root or su, gets org from body - super_admin_dependency: user_model: verifies the user is a super admin """ - from typing import Annotated from fastapi import Depends from src.user.dependencies import user_model_claims_dependency from src.user.models import User -from src.organisation.dependencies import ( - org_model_query_dependency, - org_model_body_dependency, -) +from src.organisation.dependencies import org_model_query_dependency, org_model_body_dependency from src.organisation.models import Organisation as Org from src.auth.exceptions import UnauthorizedException -async def org_query_user_claims( - org_model: org_model_query_dependency, user_model: user_model_claims_dependency -): +async def org_query_user_claims(org_model: org_model_query_dependency, user_model: user_model_claims_dependency): if user_model in org_model.user_rel: return True @@ -34,11 +28,7 @@ async def org_query_user_claims( org_query_user_claims_dependency = Annotated[bool, Depends(org_query_user_claims)] -async def org_query_root_claims( - user_model: user_model_claims_dependency, - org_model: org_model_query_dependency, - su_emails: su_list_dependency, -): +async def org_query_root_claims(user_model: user_model_claims_dependency, org_model: org_model_query_dependency, su_emails: su_list_dependency): if org_model.root_user_id == user_model.id: return org_model @@ -51,16 +41,10 @@ async def org_query_root_claims( raise UnauthorizedException(message="Must be the org's root user") -org_model_root_claim_query_dependency = Annotated[ - type[Org], Depends(org_query_root_claims) -] +org_model_root_claim_query_dependency = Annotated[type[Org], Depends(org_query_root_claims)] -async def org_body_root_claims( - user_model: user_model_claims_dependency, - org_model: org_model_body_dependency, - su_emails: su_list_dependency, -): +async def org_body_root_claims(user_model: user_model_claims_dependency, org_model: org_model_body_dependency, su_emails: su_list_dependency): if org_model.root_user_id == user_model.id: return org_model @@ -73,29 +57,21 @@ async def org_body_root_claims( raise UnauthorizedException(message="Must be the org's root user") -org_model_root_claim_body_dependency = Annotated[ - type[Org], Depends(org_body_root_claims) -] +org_model_root_claim_body_dependency = Annotated[type[Org], Depends(org_body_root_claims)] def get_super_admin_list(): return [] - def empty_su_list(): return [] - def testing_su_list(): return ["admin@test.com"] - su_list_dependency = Annotated[list[User], Depends(get_super_admin_list)] - -async def user_model_super_admin( - user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency -): +async def user_model_super_admin(user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency): if user_model.email in super_admin_emails: return user_model diff --git a/src/auth/exceptions.py b/src/auth/exceptions.py index f4a2cba..613b166 100644 --- a/src/auth/exceptions.py +++ b/src/auth/exceptions.py @@ -4,7 +4,6 @@ Module specific exceptions for the auth module Exceptions: - UnauthorizedException: Takes an optional message string """ - from typing import Optional from fastapi import HTTPException, status diff --git a/src/auth/models.py b/src/auth/models.py index aaa8362..4717477 100644 --- a/src/auth/models.py +++ b/src/auth/models.py @@ -1,3 +1,3 @@ """ Database models for the auth module -""" +""" \ No newline at end of file diff --git a/src/auth/router.py b/src/auth/router.py index ee32033..9cd7fad 100644 --- a/src/auth/router.py +++ b/src/auth/router.py @@ -4,9 +4,8 @@ Router endpoints for the auth module Exports: - router: fastapi.APIRouter """ - from fastapi import APIRouter router = APIRouter( tags=["auth"], -) +) \ No newline at end of file diff --git a/src/auth/schemas.py b/src/auth/schemas.py index 5f5ac35..279bb1b 100644 --- a/src/auth/schemas.py +++ b/src/auth/schemas.py @@ -1,3 +1,3 @@ """ Pydantic models for the auth module -""" +""" \ No newline at end of file diff --git a/src/auth/service.py b/src/auth/service.py index a27d421..f156a9d 100644 --- a/src/auth/service.py +++ b/src/auth/service.py @@ -4,7 +4,6 @@ Module specific business logic for the auth module Exports: - claims_dependency: Dict[str, Any] containing OIDC claims and database ID """ - import json import requests @@ -26,14 +25,11 @@ from src.database import db_dependency oidc = OpenIdConnect(openIdConnectUrl=auth_settings.OIDC_CONFIG) oidc_dependency = Annotated[str, Depends(oidc)] - def get_dev_user(): return {"db_id": 1} -async def get_current_user( - oidc_auth_string: oidc_dependency, db: db_dependency -) -> dict[str, Any]: +async def get_current_user(oidc_auth_string: oidc_dependency, db: db_dependency) -> dict[str, Any]: config_url = urlopen(auth_settings.OIDC_CONFIG) config = json.loads(config_url.read()) jwks_uri = config["jwks_uri"] @@ -45,7 +41,10 @@ async def get_current_user( "iss": {"essential": True, "value": auth_settings.OIDC_ISSUER}, } - token = jwt.decode(oidc_auth_string.replace("Bearer ", ""), jwk_keys) + token = jwt.decode( + oidc_auth_string.replace("Bearer ", ""), + jwk_keys + ) claims_requests = jwt.JWTClaimsRegistry(**claims_options) diff --git a/src/auth/utils.py b/src/auth/utils.py index 178518a..ed66e7c 100644 --- a/src/auth/utils.py +++ b/src/auth/utils.py @@ -1,3 +1,3 @@ """ Non-business logic reusable functions and classes for the auth module -""" +""" \ No newline at end of file diff --git a/src/config.py b/src/config.py index ddce0c8..afecd8f 100644 --- a/src/config.py +++ b/src/config.py @@ -16,26 +16,25 @@ from src.constants import Environment class CustomBaseSettings(BaseSettings): - model_config = SettingsConfigDict( - env_file=".env", env_file_encoding="utf-8", extra="ignore" - ) + model_config = SettingsConfigDict( + env_file=".env", env_file_encoding="utf-8", extra="ignore" + ) class Config(CustomBaseSettings): - APP_VERSION: str = "0.1" - ENVIRONMENT: Environment = Environment.PRODUCTION - SECRET_KEY: SecretStr = "" - DISABLE_AUTH: bool = False + APP_VERSION: str = "0.1" + ENVIRONMENT: Environment = Environment.PRODUCTION + SECRET_KEY: SecretStr = "" + DISABLE_AUTH: bool = False - CORS_ORIGINS: list[str] = ["*"] - CORS_ORIGINS_REGEX: str | None = None - CORS_HEADERS: list[str] = ["*"] - - DATABASE_NAME: str = "fastapi-exp" - DATABASE_PORT: str = "5432" - DATABASE_HOSTNAME: str = "localhost" - DATABASE_CREDENTIALS: SecretStr = ":" + CORS_ORIGINS: list[str] = ["*"] + CORS_ORIGINS_REGEX: str | None = None + CORS_HEADERS: list[str] = ["*"] + DATABASE_NAME: str = "fastapi-exp" + DATABASE_PORT: str = "5432" + DATABASE_HOSTNAME: str = "localhost" + DATABASE_CREDENTIALS: SecretStr = ":" settings = Config() @@ -44,21 +43,17 @@ DATABASE_PORT = settings.DATABASE_PORT DATABASE_HOSTNAME = settings.DATABASE_HOSTNAME DATABASE_CREDENTIALS = settings.DATABASE_CREDENTIALS.get_secret_value() # this will support special chars for credentials -_DATABASE_CREDENTIAL_USER, _DATABASE_CREDENTIAL_PASSWORD = str( - DATABASE_CREDENTIALS -).split(":") +_DATABASE_CREDENTIAL_USER, _DATABASE_CREDENTIAL_PASSWORD = str(DATABASE_CREDENTIALS).split(":") _QUOTED_DATABASE_PASSWORD = parse.quote_plus(str(_DATABASE_CREDENTIAL_PASSWORD)) -SQLALCHEMY_DATABASE_URI = SecretStr( - f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}" -) +SQLALCHEMY_DATABASE_URI = SecretStr(f"postgresql+psycopg://{_DATABASE_CREDENTIAL_USER}:{_QUOTED_DATABASE_PASSWORD}@{DATABASE_HOSTNAME}:{DATABASE_PORT}/{DATABASE_NAME}") if settings.ENVIRONMENT == Environment.TESTING: - SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:") + SQLALCHEMY_DATABASE_URI = SecretStr("sqlite:///:memory:") app_configs: dict[str, Any] = {"title": "App API"} if settings.ENVIRONMENT.is_deployed: - app_configs["root_path"] = f"/v{settings.APP_VERSION}" + app_configs["root_path"] = f"/v{settings.APP_VERSION}" if not settings.ENVIRONMENT.is_debug: - app_configs["openapi_url"] = None # hide docs + app_configs["openapi_url"] = None # hide docs diff --git a/src/constants.py b/src/constants.py index b0725bf..ab33afb 100644 --- a/src/constants.py +++ b/src/constants.py @@ -4,7 +4,6 @@ Global constants Classes: - Environment(StrEnum): LOCAL, TESTING, STAGING, PRODUCTION """ - from enum import StrEnum, auto diff --git a/src/contact/config.py b/src/contact/config.py index 7480691..2253a68 100644 --- a/src/contact/config.py +++ b/src/contact/config.py @@ -1,3 +1,3 @@ """ Configurations for the contact module -""" +""" \ No newline at end of file diff --git a/src/contact/constants.py b/src/contact/constants.py index ad08c0e..41f6ded 100644 --- a/src/contact/constants.py +++ b/src/contact/constants.py @@ -1,3 +1,3 @@ """ Constants for the contact module -""" +""" \ No newline at end of file diff --git a/src/contact/dependencies.py b/src/contact/dependencies.py index 0844fd3..de1d404 100644 --- a/src/contact/dependencies.py +++ b/src/contact/dependencies.py @@ -1,3 +1,3 @@ """ Dependencies for the contact module -""" +""" \ No newline at end of file diff --git a/src/contact/exceptions.py b/src/contact/exceptions.py index 55e9e30..6710bf3 100644 --- a/src/contact/exceptions.py +++ b/src/contact/exceptions.py @@ -4,7 +4,6 @@ Exceptions related to the contact module Exports: - ContactNotFoundException: Takes an optional contact ID int """ - from typing import Optional from fastapi import HTTPException, status @@ -12,11 +11,7 @@ from fastapi import HTTPException, status class ContactNotFoundException(HTTPException): def __init__(self, contact_id: Optional[int] = None) -> None: - detail = ( - "Contact not found" - if contact_id is None - else f"Contact with ID '{contact_id}' was not found." - ) + detail = "Contact not found" if contact_id is None else f"Contact with ID '{contact_id}' was not found." super().__init__( status_code=status.HTTP_404_NOT_FOUND, detail=detail, diff --git a/src/contact/models.py b/src/contact/models.py index d741bd2..3369501 100644 --- a/src/contact/models.py +++ b/src/contact/models.py @@ -5,7 +5,6 @@ Models: - Contact: id[pk], email, first_name, last_name, phonenumber, vat_number street_address, street_address_line_2, post_office_box_number, address_locality, country_code, address_region, postal_code """ - from sqlalchemy import Column, Integer, String, ForeignKey from src.database import Base @@ -24,11 +23,9 @@ class Contact(Base): street_address = Column(String) street_address_line_2 = Column(String) post_office_box_number = Column(String, default=None, nullable=True) - locality = Column(String) # Ie City - country_code = Column(String) # Eg GB + locality = Column(String) # Ie City + country_code = Column(String) # Eg GB address_region = Column(String, default=None, nullable=True) postal_code = Column(String) - org_id = Column( - Integer, ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False - ) + org_id = Column(Integer, ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False) diff --git a/src/contact/router.py b/src/contact/router.py index 2e5f8f4..cdab37f 100644 --- a/src/contact/router.py +++ b/src/contact/router.py @@ -1,11 +1,10 @@ """ Router endpoints for the contact module """ - from fastapi import APIRouter router = APIRouter( prefix="/contact", tags=["contact"], -) +) \ No newline at end of file diff --git a/src/contact/schemas.py b/src/contact/schemas.py index b9cec61..b008739 100644 --- a/src/contact/schemas.py +++ b/src/contact/schemas.py @@ -5,7 +5,6 @@ Models: - ContactAddress - ContactModel: Contains ContactAddress as a property """ - from typing import Optional from pydantic import EmailStr, ConfigDict diff --git a/src/contact/service.py b/src/contact/service.py index 3223e70..e04866a 100644 --- a/src/contact/service.py +++ b/src/contact/service.py @@ -1,3 +1,3 @@ """ Module specific business logic for the contact module -""" +""" \ No newline at end of file diff --git a/src/contact/utils.py b/src/contact/utils.py index daa2449..6a1d14a 100644 --- a/src/contact/utils.py +++ b/src/contact/utils.py @@ -1,3 +1,3 @@ """ Non-business logic reusable functions and classes for the contact module -""" +""" \ No newline at end of file diff --git a/src/database.py b/src/database.py index 3838098..a56f80d 100644 --- a/src/database.py +++ b/src/database.py @@ -5,7 +5,6 @@ Exports: - db_dependency - Base (sqlalchemy base model) """ - from typing import Annotated from sqlalchemy import create_engine, StaticPool from sqlalchemy.orm import DeclarativeBase, sessionmaker, Session @@ -17,11 +16,7 @@ from src.config import SQLALCHEMY_DATABASE_URI, settings as global_settings if global_settings.ENVIRONMENT == Environment.TESTING: connect_args = {"check_same_thread": False} - engine = create_engine( - SQLALCHEMY_DATABASE_URI.get_secret_value(), - connect_args=connect_args, - poolclass=StaticPool, - ) + engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value(), connect_args=connect_args, poolclass=StaticPool) else: engine = create_engine(SQLALCHEMY_DATABASE_URI.get_secret_value()) @@ -41,7 +36,5 @@ def get_db(): db_dependency = Annotated[Session, Depends(get_db)] - - class Base(DeclarativeBase): pass diff --git a/src/exceptions.py b/src/exceptions.py index 8b3629c..66507a4 100644 --- a/src/exceptions.py +++ b/src/exceptions.py @@ -5,7 +5,6 @@ Exports: - UnprocessableContentException - ConflictException """ - from typing import Optional from fastapi import HTTPException, status diff --git a/src/iam/config.py b/src/iam/config.py index 8fef3ec..165dc07 100644 --- a/src/iam/config.py +++ b/src/iam/config.py @@ -1,3 +1,3 @@ """ Configurations for the IAM module -""" +""" \ No newline at end of file diff --git a/src/iam/constants.py b/src/iam/constants.py index c2623ec..0dc94e7 100644 --- a/src/iam/constants.py +++ b/src/iam/constants.py @@ -1,3 +1,3 @@ """ Constants for the IAM module -""" +""" \ No newline at end of file diff --git a/src/iam/dependencies.py b/src/iam/dependencies.py index 6ebc4e5..37b8e87 100644 --- a/src/iam/dependencies.py +++ b/src/iam/dependencies.py @@ -6,7 +6,6 @@ Exports: - group_model_body_dependency: group_model: Gets group model from db, if it exists. Uses group_id from request body. - perm_model_body_dependency: perm_model: Gets perm model from db, if it exists. Uses perm_id from request body. """ - from typing import Annotated, Optional from fastapi import Depends, Query @@ -18,22 +17,17 @@ from src.iam.exceptions import GroupNotFoundException, PermNotFoundException from src.iam.schemas import GroupIDMixin, PermIDMixin -def get_group_model_query( - db: db_dependency, group_id: Annotated[int, Query(gt=0)] -) -> type[Group]: +def get_group_model_query(db: db_dependency, group_id: Annotated[int, Query(gt=0)]) -> type[Group]: group_model = db.get(Group, group_id) if group_model is None: raise GroupNotFoundException(group_id) return group_model - group_model_query_dependency = Annotated[type[Group], Depends(get_group_model_query)] -def get_group_model_body( - db: db_dependency, request_model: Optional[GroupIDMixin] = None -) -> type[Group]: +def get_group_model_body(db: db_dependency, request_model: Optional[GroupIDMixin] = None) -> type[Group]: group_id = getattr(request_model, "group_id", None) if group_id is None: raise GroupNotFoundException() @@ -43,13 +37,10 @@ def get_group_model_body( return group_model - group_model_body_dependency = Annotated[type[Group], Depends(get_group_model_body)] -def get_perm_model_body( - db: db_dependency, request_model: Optional[PermIDMixin] = None -) -> type[Permission]: +def get_perm_model_body(db: db_dependency, request_model: Optional[PermIDMixin] = None) -> type[Permission]: perm_id = getattr(request_model, "permission_id", None) if perm_id is None: raise PermNotFoundException @@ -59,18 +50,4 @@ def get_perm_model_body( return perm_model - perm_model_body_dependency = Annotated[type[Permission], Depends(get_perm_model_body)] - - -def get_perm_model_query( - db: db_dependency, perm_id: Annotated[int, Query(gt=0)] -) -> type[Permission]: - perm_model = db.get(Permission, perm_id) - if perm_model is None: - raise PermNotFoundException(perm_id) - - return perm_model - - -perm_model_query_dependency = Annotated[type[Permission], Depends(get_perm_model_query)] diff --git a/src/iam/exceptions.py b/src/iam/exceptions.py index 503b844..84a77ed 100644 --- a/src/iam/exceptions.py +++ b/src/iam/exceptions.py @@ -5,7 +5,6 @@ Exceptions: - GroupNotFoundException: Takes an optional group_id int - PermNotFoundException: Takes an optional perm_id int """ - from typing import Optional from fastapi import HTTPException, status @@ -13,11 +12,7 @@ from fastapi import HTTPException, status class GroupNotFoundException(HTTPException): def __init__(self, group_id: Optional[int] = None) -> None: - detail = ( - "Group not found" - if group_id is None - else f"User with ID '{group_id}' was not found." - ) + detail = "Group not found" if group_id is None else f"User with ID '{group_id}' was not found." super().__init__( status_code=status.HTTP_404_NOT_FOUND, detail=detail, @@ -26,11 +21,7 @@ class GroupNotFoundException(HTTPException): class PermNotFoundException(HTTPException): def __init__(self, perm_id: Optional[int] = None) -> None: - detail = ( - "Permission not found" - if perm_id is None - else f"User with ID '{perm_id}' was not found." - ) + detail = "Permission not found" if perm_id is None else f"User with ID '{perm_id}' was not found." super().__init__( status_code=status.HTTP_404_NOT_FOUND, detail=detail, diff --git a/src/iam/models.py b/src/iam/models.py index 0087abc..f542b70 100644 --- a/src/iam/models.py +++ b/src/iam/models.py @@ -17,7 +17,6 @@ Models: - UserGroups: - org_id[FK][PK], user_id[FK][PK], group_id[FK][PK] """ - from sqlalchemy import Column, Integer, String, ForeignKey, UniqueConstraint from sqlalchemy.orm import relationship @@ -33,14 +32,7 @@ class Permission(Base): service_id = Column(Integer, ForeignKey("service.id", ondelete="CASCADE")) - __table_args__ = ( - UniqueConstraint( - "service_id", - "resource", - "action", - name="uniq_permission_resource_and_action", - ), - ) + UniqueConstraint("service_id", "resource", "action", name="uniq_permission_resource_and_action") service_rel = relationship("Service", foreign_keys=[service_id]) @@ -49,10 +41,13 @@ class Permission(Base): return self.service_rel.name group_rel = relationship( - "Group", secondary="group_permissions", back_populates="permission_rel" + "Group", + secondary="group_permissions", + back_populates="permission_rel" ) + class Group(Base): __tablename__ = "group" id = Column(Integer, primary_key=True) @@ -60,30 +55,28 @@ class Group(Base): org_id = Column(Integer, ForeignKey("organisation.id", ondelete="CASCADE")) - user_rel = relationship("User", secondary="user_groups", back_populates="group_rel") + user_rel = relationship( + "User", + secondary="user_groups", + back_populates="group_rel" + ) org_rel = relationship("Organisation", back_populates="group_rel") permission_rel = relationship( - "Permission", secondary="group_permissions", back_populates="group_rel" + "Permission", + secondary="group_permissions", + back_populates="group_rel" ) class GroupPermissions(Base): __tablename__ = "group_permissions" - group_id = Column( - Integer, ForeignKey("group.id", ondelete="CASCADE"), primary_key=True - ) - permission_id = Column( - Integer, ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True - ) + group_id = Column(Integer, ForeignKey("group.id", ondelete="CASCADE"), primary_key=True) + permission_id = Column(Integer, ForeignKey("permission.id", ondelete="CASCADE"), primary_key=True) class UserGroups(Base): __tablename__ = "user_groups" - user_id = Column( - Integer, ForeignKey("user.id", ondelete="CASCADE"), primary_key=True - ) - group_id = Column( - Integer, ForeignKey("group.id", ondelete="CASCADE"), primary_key=True - ) + user_id = Column(Integer, ForeignKey("user.id", ondelete="CASCADE"), primary_key=True) + group_id = Column(Integer, ForeignKey("group.id", ondelete="CASCADE"), primary_key=True) diff --git a/src/iam/router.py b/src/iam/router.py index d428e80..016cebc 100644 --- a/src/iam/router.py +++ b/src/iam/router.py @@ -15,9 +15,9 @@ Endpoints: - [DELETE](/iam/permission): [super admin]: Removes a permission - [GET](/iam/permissions/search): [root user]: Returns a list of permissions matching a filter(service|resource|action) """ - from fastapi import APIRouter, status from sqlalchemy.exc import IntegrityError +from psycopg import errors from src.service.exceptions import ServiceNotFoundException from src.exceptions import ConflictException @@ -25,50 +25,21 @@ from src.database import db_dependency from src.schemas import ResourceName from src.auth.exceptions import UnauthorizedException from src.auth.service import claims_dependency -from src.auth.dependencies import ( - org_model_root_claim_query_dependency, - org_model_root_claim_body_dependency, - super_admin_dependency, -) +from src.auth.dependencies import org_model_root_claim_query_dependency, org_model_root_claim_body_dependency, \ + super_admin_dependency from src.user.models import User -from src.user.dependencies import ( - user_model_body_dependency, - user_model_query_dependency, -) +from src.user.dependencies import user_model_body_dependency from src.organisation.models import Organisation as Org from src.service.models import Service from src.iam.service import service_key_dependency -from src.iam.models import ( - Permission as Perm, - GroupPermissions as GPerms, - Group, - UserGroups, -) -from src.iam.dependencies import ( - group_model_query_dependency, - group_model_body_dependency, - perm_model_body_dependency, - perm_model_query_dependency, -) -from src.iam.schemas import ( - IAMGetGroupPermissionsResponse, - IAMGetGroupUsersResponse, - IAMPostGroupRequest, - GroupSchema, - IAMPostGroupResponse, - IAMPutGroupPermissionRequest, - IAMPutGroupPermissionResponse, - IAMPutGroupUserRequest, - IAMPutGroupUserResponse, - IAMDeleteGroupPermissionResponse, - IAMDeleteGroupUserResponse, - IAMGetPermissionsResponse, - IAMPostPermissionRequest, - IAMPostPermissionResponse, - IAMGetPermissionsSearchRequest, - IAMGetPermissionsSearchResponse, -) +from src.iam.models import Permission as Perm, GroupPermissions as GPerms, Group, UserGroups +from src.iam.dependencies import group_model_query_dependency, group_model_body_dependency, perm_model_body_dependency +from src.iam.schemas import IAMGetGroupPermissionsResponse, IAMGetGroupUsersResponse, IAMPostGroupRequest, \ + GroupSchema, IAMPostGroupResponse, IAMPutGroupPermissionRequest, IAMPutGroupPermissionResponse, \ + IAMPutGroupUserRequest, IAMPutGroupUserResponse, IAMDeleteGroupPermissionRequest, IAMDeleteGroupPermissionResponse, \ + IAMDeleteGroupUserRequest, IAMDeleteGroupUserResponse, IAMGetPermissionsResponse, IAMPostPermissionRequest, \ + IAMPostPermissionResponse, IAMDeletePermissionRequest, IAMGetPermissionsSearchRequest, IAMGetPermissionsSearchResponse router = APIRouter( tags=["IAM"], @@ -77,32 +48,26 @@ router = APIRouter( @router.post("/can_act_on_resource") -async def can_act_on_resource( - valid_key: service_key_dependency, - db: db_dependency, - user_claims: claims_dependency, - rn: ResourceName, - action: str, -) -> bool: +async def can_act_on_resource(valid_key: service_key_dependency, db: db_dependency, user_claims: claims_dependency, + rn: ResourceName, action: str) -> bool: try: user_id = user_claims["db_id"] rn_org = rn.organisation rn_service = rn.service rn_resource = rn.resource - result = ( - db.query(Perm) - .join(Service, Service.id == Perm.service_id) - .join(GPerms, GPerms.permission_id == Perm.id) - .join(Group, Group.id == GPerms.group_id) - .join(Org, Org.id == Group.org_id) - .join(UserGroups, UserGroups.group_id == Group.id) - .join(User, User.id == UserGroups.user_id) - .filter(User.id == user_id) - .filter(Org.name == rn_org) - .filter(Service.name == rn_service) - .filter(Perm.resource == rn_resource) - .filter(Perm.action == action) + result = (db.query(Perm) + .join(Service, Service.id == Perm.service_id) + .join(GPerms, GPerms.permission_id == Perm.id) + .join(Group, Group.id == GPerms.group_id) + .join(Org, Org.id == Group.org_id) + .join(UserGroups, UserGroups.group_id == Group.id) + .join(User, User.id == UserGroups.user_id) + .filter(User.id == user_id) + .filter(Org.name == rn_org) + .filter(Service.name == rn_service) + .filter(Perm.resource == rn_resource) + .filter(Perm.action == action) ).first() if result: @@ -114,41 +79,28 @@ async def can_act_on_resource( @router.get("/group/permissions", response_model=IAMGetGroupPermissionsResponse) -async def get_group_permissions( - group_model: group_model_query_dependency, - org_model: org_model_root_claim_query_dependency, -): +async def get_group_permissions(group_model: group_model_query_dependency, org_model: org_model_root_claim_query_dependency): if group_model.org_id != org_model.id: raise UnauthorizedException("Group does not belong to this organization") return {"permissions": group_model.permission_rel} @router.get("/group/users", response_model=IAMGetGroupUsersResponse) -async def get_group_users( - group_model: group_model_query_dependency, - org_model: org_model_root_claim_query_dependency, -): +async def get_group_users(group_model: group_model_query_dependency, org_model: org_model_root_claim_query_dependency): if group_model.org_id != org_model.id: raise UnauthorizedException("Group does not belong to this organization") return {"users": group_model.user_rel} @router.post("/group", response_model=IAMPostGroupResponse) -async def create_group( - db: db_dependency, - org_model: org_model_root_claim_body_dependency, - request_model: IAMPostGroupRequest, -): +async def create_group(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMPostGroupRequest): group_model = Group(name=request_model.name, org_id=org_model.id) db.add(group_model) try: db.flush() except IntegrityError as e: - if ( - getattr(e.orig, "pgcode", None) == "23505" # Postgres unique violation - or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation - ): + if isinstance(e.orig, errors.UniqueViolation): raise ConflictException("Group with this name already exists") response = GroupSchema(**group_model.__dict__) db.commit() @@ -156,13 +108,7 @@ async def create_group( @router.put("/group/permission", response_model=IAMPutGroupPermissionResponse) -async def add_group_permission( - db: db_dependency, - group_model: group_model_body_dependency, - perm_model: perm_model_body_dependency, - org_model: org_model_root_claim_body_dependency, - request_model: IAMPutGroupPermissionRequest, -): +async def add_group_permission(db: db_dependency, group_model: group_model_body_dependency, perm_model: perm_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMPutGroupPermissionRequest): if group_model.org_id != org_model.id: raise UnauthorizedException("Group does not belong to this organization") @@ -172,22 +118,13 @@ async def add_group_permission( group_model.permission_rel.append(perm_model) db.flush() - response = IAMPutGroupPermissionResponse( - group=GroupSchema(**group_model.__dict__), - permissions=group_model.permission_rel, - ) + response = IAMPutGroupPermissionResponse(group=GroupSchema(**group_model.__dict__), permissions=group_model.permission_rel) db.commit() return response @router.put("/group/user", response_model=IAMPutGroupUserResponse) -async def add_group_user( - db: db_dependency, - group_model: group_model_body_dependency, - user_model: user_model_body_dependency, - org_model: org_model_root_claim_body_dependency, - request_model: IAMPutGroupUserRequest, -): +async def add_group_user(db: db_dependency, group_model: group_model_body_dependency, user_model: user_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMPutGroupUserRequest): if group_model.org_id != org_model.id: raise UnauthorizedException("Group does not belong to this organization") @@ -196,120 +133,79 @@ async def add_group_user( group_model.user_rel.append(user_model) db.flush() - response = IAMPutGroupUserResponse( - group=GroupSchema(**group_model.__dict__), users=group_model.user_rel - ) + response = IAMPutGroupUserResponse(group=GroupSchema(**group_model.__dict__), users=group_model.user_rel) db.commit() return response @router.delete("/group/permissions") -async def remove_group_permissions( - db: db_dependency, - group_model: group_model_query_dependency, - perm_model: perm_model_query_dependency, - org_model: org_model_root_claim_query_dependency, -): +async def remove_group_permissions(db: db_dependency, group_model: group_model_body_dependency, perm_model: perm_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMDeleteGroupPermissionRequest): if group_model.org_id != org_model.id: raise UnauthorizedException("Group does not belong to this organization") group_model.permission_rel.remove(perm_model) db.flush() - response = IAMDeleteGroupPermissionResponse( - group=GroupSchema(**group_model.__dict__), - permissions=group_model.permission_rel, - ) + response = IAMDeleteGroupPermissionResponse(group=GroupSchema(**group_model.__dict__), + permissions=group_model.permission_rel) db.commit() return response @router.delete("/group/user") -async def remove_group_user( - db: db_dependency, - group_model: group_model_query_dependency, - user_model: user_model_query_dependency, - org_model: org_model_root_claim_query_dependency, -): +async def remove_group_user(db: db_dependency, group_model: group_model_body_dependency, user_model: user_model_body_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMDeleteGroupUserRequest): if group_model.org_id != org_model.id: raise UnauthorizedException("Group does not belong to this organization") user_model.group_rel.remove(group_model) db.flush() - response = IAMDeleteGroupUserResponse( - group=GroupSchema(**group_model.__dict__), users=group_model.user_rel - ) + response = IAMDeleteGroupUserResponse(group=GroupSchema(**group_model.__dict__), users=group_model.user_rel) db.commit() return response @router.get("/permissions", response_model=IAMGetPermissionsResponse) -async def get_permissions( - db: db_dependency, org_model: org_model_root_claim_query_dependency -): +async def get_permissions(db: db_dependency, org_model: org_model_root_claim_query_dependency): permission_models = db.query(Perm).all() return {"permissions": permission_models} @router.post("/permission", response_model=IAMPostPermissionResponse) -async def create_new_permission( - db: db_dependency, - su: super_admin_dependency, - request_model: IAMPostPermissionRequest, -): +async def create_new_permission(db: db_dependency, su: super_admin_dependency, request_model: IAMPostPermissionRequest): service_model = db.get(Service, request_model.service_id) if service_model is None: raise ServiceNotFoundException(service_id=request_model.service_id) perm_model = Perm(**request_model.__dict__) - db.add(perm_model) try: - db.flush() + db.add(perm_model) except IntegrityError as e: - if ( - getattr(e.orig, "pgcode", None) == "23505" # Postgres unique violation - or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation - ): + if isinstance(e.orig, errors.UniqueViolation): raise ConflictException(message="Permission already exists") - response = { - "service_name": perm_model.service_name, - "resource": perm_model.resource, - "action": perm_model.action, - } + db.flush() + response = {"service_name": perm_model.service_name, "resource": perm_model.resource, "action": perm_model.action} db.commit() return {"permission": response} @router.delete("/permission", status_code=status.HTTP_204_NO_CONTENT) -async def delete_permission( - db: db_dependency, - su: super_admin_dependency, - perm_model: perm_model_query_dependency, -): +async def delete_permission(db: db_dependency, su: super_admin_dependency, perm_model: perm_model_body_dependency, request_model: IAMDeletePermissionRequest): db.delete(perm_model) db.commit() @router.post("/permissions/search", response_model=IAMGetPermissionsSearchResponse) -async def post_permissions( - db: db_dependency, - org_model: org_model_root_claim_body_dependency, - request_model: IAMGetPermissionsSearchRequest, -): +async def post_permissions(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: IAMGetPermissionsSearchRequest): permission_query = db.query(Perm) if request_model.service_id is not None: - permission_query = permission_query.filter( - Perm.service_id == request_model.service_id - ) + permission_query = permission_query.filter(Perm.service_id == request_model.service_id) if request_model.resource is not None: - permission_query = permission_query.filter( - Perm.resource == request_model.resource - ) + permission_query = permission_query.filter(Perm.resource == request_model.resource) if request_model.action is not None: - permission_query = permission_query.filter(Perm.action == request_model.action) + permission_query = permission_query.filter(Perm.action == request_model. action) permission_models = permission_query.all() diff --git a/src/iam/schemas.py b/src/iam/schemas.py index fca76de..0d370b8 100644 --- a/src/iam/schemas.py +++ b/src/iam/schemas.py @@ -6,7 +6,6 @@ Models follow the nomenclature of: - Mixins: "Mixin" - Models: "" ie "IAMGetGroupPermissionsResponse" """ - from typing import Optional, Annotated from pydantic import EmailStr, ConfigDict, Field @@ -25,7 +24,6 @@ class UserSchema(CustomBaseModel): last_name: str email: EmailStr - class PermissionSchema(CustomBaseModel): model_config = ConfigDict(from_attributes=True, extra="ignore") @@ -33,82 +31,73 @@ class PermissionSchema(CustomBaseModel): resource: str action: str - class GroupSchema(CustomBaseModel): id: int name: str - class GroupIDMixin(CustomBaseModel): group_id: int = Field(gt=0) - class PermIDMixin(CustomBaseModel): permission_id: int = Field(gt=0) - class IAMGetGroupPermissionsResponse(CustomBaseModel): permissions: list[PermissionSchema] - class IAMGetGroupUsersResponse(CustomBaseModel): - users: list[UserSchema] - + users : list[UserSchema] class IAMPostGroupRequest(OrgIDMixin): name: str = Field(min_length=3) - class IAMPostGroupResponse(CustomBaseModel): group: GroupSchema - class IAMPutGroupPermissionRequest(GroupIDMixin, PermIDMixin, OrgIDMixin): pass - class IAMPutGroupPermissionResponse(CustomBaseModel): group: GroupSchema permissions: list[PermissionSchema] - class IAMPutGroupUserRequest(GroupIDMixin, UserIDMixin, OrgIDMixin): pass - class IAMPutGroupUserResponse(CustomBaseModel): group: GroupSchema users: list[UserSchema] +class IAMDeleteGroupPermissionRequest(GroupIDMixin, PermIDMixin): + pass class IAMDeleteGroupPermissionResponse(CustomBaseModel): group: GroupSchema permissions: list[PermissionSchema] +class IAMDeleteGroupUserRequest(GroupIDMixin, UserIDMixin): + pass class IAMDeleteGroupUserResponse(CustomBaseModel): group: GroupSchema users: list[UserSchema] - class IAMGetPermissionsResponse(CustomBaseModel): permissions: list[PermissionSchema] - class IAMPostPermissionRequest(ServiceIDMixin): resource: str action: str - class IAMPostPermissionResponse(CustomBaseModel): permission: PermissionSchema +class IAMDeletePermissionRequest(PermIDMixin): + pass class IAMGetPermissionsSearchRequest(OrgIDMixin): service_id: Annotated[int | None, Field(gt=0)] = None resource: Optional[str] = None action: Optional[str] = None - class IAMGetPermissionsSearchResponse(CustomBaseModel): permissions: list[PermissionSchema] diff --git a/src/iam/service.py b/src/iam/service.py index 1e0dfe8..c6a1030 100644 --- a/src/iam/service.py +++ b/src/iam/service.py @@ -4,7 +4,6 @@ Business logic reusable functions related to IAM Exports: - service_key_dependency: bool: verifies request headers contain the correct api key for the service """ - from typing import Annotated from src.service.models import Service @@ -20,16 +19,10 @@ def valid_service_key(db: db_dependency, request: Request, rn: ResourceName) -> if not api_key: raise UnauthorizedException("Missing API key") service = rn.service - result = ( - db.query(Service) - .filter(Service.name == service) - .filter(Service.api_key == api_key) - .first() - ) + result = db.query(Service).filter(Service.name == service).filter(Service.api_key == api_key).first() if result is None: raise UnauthorizedException("Invalid API key") return True - service_key_dependency = Annotated[bool, Depends(valid_service_key)] diff --git a/src/main.py b/src/main.py index 26dc6e0..c421a94 100644 --- a/src/main.py +++ b/src/main.py @@ -1,7 +1,6 @@ """ Application root file: Inits the FastAPI application """ - from contextlib import asynccontextmanager from typing import AsyncGenerator diff --git a/src/models.py b/src/models.py index 87912aa..fa198e4 100644 --- a/src/models.py +++ b/src/models.py @@ -1,3 +1,4 @@ """ Global database models """ + diff --git a/src/organisation/config.py b/src/organisation/config.py index 7ce00f7..e24ca5b 100644 --- a/src/organisation/config.py +++ b/src/organisation/config.py @@ -1,3 +1,3 @@ """ Configurations for the organisation module -""" +""" \ No newline at end of file diff --git a/src/organisation/constants.py b/src/organisation/constants.py index 79c22fd..ced0682 100644 --- a/src/organisation/constants.py +++ b/src/organisation/constants.py @@ -5,7 +5,6 @@ Classes: - Status(StrEnum): PARTIAL, SUBMITTED, REMEDIATION, APPROVED, REJECTED, REMOVED - ContactType(StrEnum): BILLING, SECURITY, OWNER """ - from enum import StrEnum, auto diff --git a/src/organisation/dependencies.py b/src/organisation/dependencies.py index 35b09fc..728b8d0 100644 --- a/src/organisation/dependencies.py +++ b/src/organisation/dependencies.py @@ -5,7 +5,6 @@ Exports: - org_model_query_dependency: org_model: Gets org model from db, if it exists. Uses org_id from query param. Also verifies if the org has been approved. - org_model_body_dependency: org_model: Gets org model from db, if it exists. Uses org_id from request body. Also verifies if the org has been approved. """ - from typing import Annotated, Optional from sqlalchemy.orm import Session @@ -26,40 +25,25 @@ def get_org_model(db: Session, request: Request, org_id: int): root = "/api/v1" - pre_approval_endpoints = [ - f"PATCH{root}/org/status", - f"PATCH{root}/org/questionnaire", - f"GET{root}/org", - f"GET{root}/org/contact", - f"PATCH{root}/org/contact", - ] + pre_approval_endpoints = [f"PATCH{root}/org/status", f"PATCH{root}/org/questionnaire", f"GET{root}/org", f"GET{root}/org/contact", f"PATCH{root}/org/contact"] current_request = f"{request.method}{request.url.path}" - if ( - current_request not in pre_approval_endpoints - and org_model.status != OrgStatus.APPROVED - ): + if current_request not in pre_approval_endpoints and org_model.status != OrgStatus.APPROVED: raise AwaitingApprovalException(org_id) return org_model -def get_org_model_query( - db: db_dependency, request: Request, org_id: Annotated[int, Query(gt=0)] -) -> type[Org]: +def get_org_model_query(db: db_dependency, request: Request, org_id: Annotated[int, Query(gt=0)]) -> type[Org]: return get_org_model(db, request, org_id) - org_model_query_dependency = Annotated[type[Org], Depends(get_org_model_query)] -def get_org_model_body( - db: db_dependency, request: Request, request_model: OrgIDMixin -) -> type[Org]: +def get_org_model_body(db: db_dependency, request: Request, request_model: OrgIDMixin) -> type[Org]: org_id: Optional[int] = getattr(request_model, "organisation_id", None) if org_id is None: raise OrgNotFoundException return get_org_model(db, request, org_id) - org_model_body_dependency = Annotated[type[Org], Depends(get_org_model_body)] diff --git a/src/organisation/exceptions.py b/src/organisation/exceptions.py index a56b395..8fe61cc 100644 --- a/src/organisation/exceptions.py +++ b/src/organisation/exceptions.py @@ -5,7 +5,6 @@ Exceptions: - OrgNotFoundException: Takes an optional org_id int - AwaitingApprovalException: Takes an optional org_id int """ - from typing import Optional from fastapi import HTTPException, status @@ -13,24 +12,15 @@ from fastapi import HTTPException, status class OrgNotFoundException(HTTPException): def __init__(self, org_id: Optional[int] = None) -> None: - detail = ( - "Organisation not found" - if org_id is None - else f"Organisation with ID '{org_id}' was not found." - ) + detail = "Organisation not found" if org_id is None else f"Organisation with ID '{org_id}' was not found." super().__init__( status_code=status.HTTP_404_NOT_FOUND, detail=detail, ) - class AwaitingApprovalException(HTTPException): def __init__(self, org_id: Optional[int] = None) -> None: - detail = ( - "Organisation has not been approved." - if org_id is None - else f"Organisation with ID '{org_id}' has not been approved." - ) + detail = "Organisation has not been approved." if org_id is None else f"Organisation with ID '{org_id}' has not been approved." super().__init__( status_code=status.HTTP_401_UNAUTHORIZED, detail=detail, diff --git a/src/organisation/models.py b/src/organisation/models.py index e99d64f..3663f2c 100644 --- a/src/organisation/models.py +++ b/src/organisation/models.py @@ -13,7 +13,6 @@ Models: - owner_contact_rel: ORM relationship to Contact with owner_contact FK - OrgUsers: org_id[FK][PK], user_id[FK][PK] """ - from sqlalchemy import Column, Integer, String, ForeignKey, JSON from sqlalchemy.orm import relationship @@ -35,7 +34,9 @@ class Organisation(Base): owner_contact_id = Column(Integer, ForeignKey("contact.id")) user_rel = relationship( - "User", secondary="orgusers", back_populates="organisation_rel" + "User", + secondary="orgusers", + back_populates="organisation_rel" ) group_rel = relationship("Group", back_populates="org_rel") @@ -53,9 +54,5 @@ class Organisation(Base): class OrgUsers(Base): __tablename__ = "orgusers" - org_id = Column( - Integer, ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True - ) - user_id = Column( - Integer, ForeignKey("user.id", ondelete="CASCADE"), primary_key=True - ) + org_id = Column(Integer, ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True) + user_id = Column(Integer, ForeignKey("user.id", ondelete="CASCADE"), primary_key=True) diff --git a/src/organisation/router.py b/src/organisation/router.py index 3f4fcbc..90a3627 100644 --- a/src/organisation/router.py +++ b/src/organisation/router.py @@ -15,11 +15,11 @@ Endpoints: - [GET](/org/contact): [root user]: Gets the (contact_type) contact for an org(id) - [PATCH](/org/contact): [root user]: Updates the (contact_type) contact for an org(id). Any number of details can be changed. """ - from typing import Annotated from fastapi import APIRouter, status from fastapi.params import Query +from psycopg.errors import UniqueViolation from sqlalchemy.exc import IntegrityError from src.contact.schemas import ContactModel @@ -28,42 +28,17 @@ from src.contact.models import Contact from src.contact.schemas import ContactAddress from src.contact.exceptions import ContactNotFoundException from src.database import db_dependency -from src.user.dependencies import ( - user_model_body_dependency, - user_model_claims_dependency, - user_model_query_dependency, -) -from src.auth.dependencies import ( - super_admin_dependency, - org_model_root_claim_query_dependency, - org_model_root_claim_body_dependency, -) +from src.user.dependencies import user_model_body_dependency, user_model_claims_dependency +from src.auth.dependencies import super_admin_dependency, org_model_root_claim_query_dependency, org_model_root_claim_body_dependency -from src.organisation.dependencies import ( - org_model_body_dependency, - org_model_query_dependency, -) +from src.organisation.dependencies import org_model_body_dependency from src.organisation.constants import ContactType from src.organisation.models import Organisation as Org -from src.organisation.schemas import ( - OrgPostOrgRequest, - OrgPatchQuestionnaireRequest, - OrgPatchStatusRequest, - OrgPatchContactRequest, - OrgPostUserRequest, - OrgGetUserResponse, - OrgGetContactResponse, - OrgGetOrgResponse, - OrgPatchRootRequest, - OrgGetGroupResponse, - OrgPostOrgResponse, - OrgPatchQuestionnaireResponse, - OrgPatchStatusResponse, - OrgPostUserResponse, - OrgPatchRootResponse, - Questionnaire, - OrgPatchContactResponse, -) +from src.organisation.schemas import OrgPostOrgRequest, OrgPatchQuestionnaireRequest, OrgPatchStatusRequest, \ + OrgPatchContactRequest, \ + OrgPostUserRequest, OrgGetUserResponse, OrgGetContactResponse, OrgGetOrgResponse, OrgPatchRootRequest, \ + OrgGetGroupResponse, OrgDeleteUserRequest, OrgDeleteOrgRequest, OrgPostOrgResponse, OrgPatchQuestionnaireResponse, \ + OrgPatchStatusResponse, OrgPostUserResponse, OrgPatchRootResponse, Questionnaire, OrgPatchContactResponse router = APIRouter( prefix="/org", @@ -71,22 +46,16 @@ router = APIRouter( ) -@router.get( - "", - summary="Get org details by ID.", - response_model=OrgGetOrgResponse, - status_code=status.HTTP_200_OK, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, - status.HTTP_404_NOT_FOUND: {"description": "Organisation not found"}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Missing or invalid org_id query parameter" - }, - status.HTTP_401_UNAUTHORIZED: { - "description": "Not authorised. Must be org root user." - }, - }, -) +@router.get("", + summary="Get org details by ID.", + response_model=OrgGetOrgResponse, + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, + status.HTTP_404_NOT_FOUND: {"description": "Organisation not found"}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Missing or invalid org_id query parameter"}, + status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."}, + }) async def get_org_by_id(org_model: org_model_root_claim_query_dependency): """ Returns organisation details including key member email addresses @@ -99,35 +68,23 @@ async def get_org_by_id(org_model: org_model_root_claim_query_dependency): "billing_contact": org_model.billing_contact_rel.email, "security_contact": org_model.security_contact_rel.email, "root_user": org_model.root_user_email, - "intake_questionnaire": org_model.intake_questionnaire, + "intake_questionnaire": org_model.intake_questionnaire } return response -@router.post( - "", - summary="Create new organisation.", - status_code=status.HTTP_201_CREATED, - response_model=OrgPostOrgResponse, - responses={ - status.HTTP_201_CREATED: {"description": "Successfully created organisation."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Invalid data in request." - }, - status.HTTP_401_UNAUTHORIZED: { - "description": "User must be logged in with OIDC to create organisation." - }, - status.HTTP_409_CONFLICT: { - "description": "Organisation with this name already exists." - }, - }, -) -async def create_org( - db: db_dependency, - user_model: user_model_claims_dependency, - request_model: OrgPostOrgRequest, -): +@router.post("", + summary="Create new organisation.", + status_code=status.HTTP_201_CREATED, + response_model=OrgPostOrgResponse, + responses={ + status.HTTP_201_CREATED: {"description": "Successfully created organisation."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_401_UNAUTHORIZED: {"description": "User must be logged in with OIDC to create organisation."}, + status.HTTP_409_CONFLICT: {"description": "Organisation with this name already exists."}, + }) +async def create_org(db: db_dependency, user_model: user_model_claims_dependency, request_model: OrgPostOrgRequest): """ Creates a new organisation with optional questionnaire (to be completed or submitted). ALl organisations are given the "partial" status on creation. See update_questionnaire() for more details. @@ -144,21 +101,12 @@ async def create_org( try: db.flush() except IntegrityError as e: - if ( - getattr(e.orig, "pgcode", None) == "23505" # Postgres unique violation - or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation - ): - raise ConflictException( - message="Organisation with this name already exists" - ) + if isinstance(e.orig, UniqueViolation): + raise ConflictException(message="Organisation with this name already exists") # Adds currently logged-in user to org users list and sets them as root_user org_model.user_rel.append(user_model) org_model.root_user_rel = user_model - for contact_type in [ - "billing_contact_id", - "security_contact_id", - "owner_contact_id", - ]: + for contact_type in ["billing_contact_id", "security_contact_id", "owner_contact_id"]: contact_model = Contact(org_id=org_model.id) db.add(contact_model) db.flush() @@ -168,26 +116,16 @@ async def create_org( return response -@router.patch( - "/questionnaire", - summary="Update questionnaire.", - status_code=status.HTTP_200_OK, - response_model=OrgPatchQuestionnaireResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully updated questionnaire."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Invalid data in request." - }, - status.HTTP_401_UNAUTHORIZED: { - "description": "Not authorised. Must be org root user." - }, - }, -) -async def update_questionnaire( - db: db_dependency, - org_model: org_model_root_claim_body_dependency, - request_model: OrgPatchQuestionnaireRequest, -): +@router.patch("/questionnaire", + summary="Update questionnaire.", + status_code=status.HTTP_200_OK, + response_model=OrgPatchQuestionnaireResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully updated questionnaire."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."}, + }) +async def update_questionnaire(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: OrgPatchQuestionnaireRequest): """ Route for updating questionnaire. The partial bool allows for submission of partially completed questionnaire and/or @@ -212,29 +150,16 @@ async def update_questionnaire( return response -@router.patch( - "/status", - summary="Update status of organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgPatchStatusResponse, - responses={ - status.HTTP_200_OK: { - "description": "Successfully updated organisation status." - }, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Invalid data in request." - }, - status.HTTP_401_UNAUTHORIZED: { - "description": "Not authorised. Must be super admin." - }, - }, -) -async def update_status( - db: db_dependency, - org_model: org_model_body_dependency, - su: super_admin_dependency, - request_model: OrgPatchStatusRequest, -): +@router.patch("/status", + summary="Update status of organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgPatchStatusResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully updated organisation status."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be super admin."}, + }) +async def update_status(db: db_dependency, org_model: org_model_body_dependency, su: super_admin_dependency, request_model: OrgPatchStatusRequest): """ Sets an organisation's status. This is the endpoint for approving or denying an organisation after reviewing the questionnaire. """ @@ -245,57 +170,33 @@ async def update_status( return response -@router.get( - "/users", - summary="Get email addresses of users of the organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgGetUserResponse, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval of users."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Org ID missing or invalid." - }, - status.HTTP_401_UNAUTHORIZED: { - "description": "Not authorised. Must be org root user." - }, - }, -) +@router.get("/users", + summary="Get email addresses of users of the organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgGetUserResponse, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval of users."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Org ID missing or invalid."}, + status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."}, + }) async def get_users(org_model: org_model_root_claim_query_dependency): """ Returns a list of the email addresses of all users of the organisation. """ - return { - "users": [user.email for user in org_model.user_rel], - "organisation": org_model, - } + return {"users": [user.email for user in org_model.user_rel], "organisation": org_model} -@router.post( - "/user", - summary="Add user to the organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgPostUserResponse, - responses={ - status.HTTP_200_OK: { - "description": "Successfully added user to the organisation." - }, - status.HTTP_401_UNAUTHORIZED: { - "description": "Not authorised. Must be org root user." - }, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Invalid data in request." - }, - status.HTTP_409_CONFLICT: { - "description": "User is already a member of the organisation." - }, - }, -) -async def add_user_to_org( - db: db_dependency, - org_model: org_model_root_claim_body_dependency, - user_model: user_model_body_dependency, - request_model: OrgPostUserRequest, -): +@router.post("/user", + summary="Add user to the organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgPostUserResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully added user to the organisation."}, + status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_409_CONFLICT: {"description": "User is already a member of the organisation."}, + }) +async def add_user_to_org(db: db_dependency, org_model: org_model_root_claim_body_dependency, user_model: user_model_body_dependency, request_model: OrgPostUserRequest): """ Adds a user to the organisation. """ @@ -308,27 +209,15 @@ async def add_user_to_org( return response -@router.delete( - "", - summary="Delete organisation from the hub.", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - status.HTTP_204_NO_CONTENT: { - "description": "Successfully deleted organisation." - }, - status.HTTP_401_UNAUTHORIZED: { - "description": "Not authorised. Must be super admin." - }, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Org ID missing or invalid." - }, - }, -) -async def delete_organisation_by_id( - db: db_dependency, - org_model: org_model_query_dependency, - su: super_admin_dependency, -): +@router.delete("", + summary="Delete organisation from the hub.", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + status.HTTP_204_NO_CONTENT: {"description": "Successfully deleted organisation."}, + status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be super admin."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Org ID missing or invalid."}, + }) +async def delete_organisation_by_id(db: db_dependency, org_model: org_model_body_dependency, su: super_admin_dependency, request_model: OrgDeleteOrgRequest): """ Removes an organisation from the hub. """ @@ -336,59 +225,37 @@ async def delete_organisation_by_id( db.commit() -@router.patch( - "/root_user", - summary="Update the root user of the organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgPatchRootResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully updated root user."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Invalid data in request." - }, - status.HTTP_401_UNAUTHORIZED: { - "description": "Not authorised. Must be super admin." - }, - }, -) -async def update_root_user( - db: db_dependency, - org_model: org_model_body_dependency, - user_model: user_model_body_dependency, - su: super_admin_dependency, - request_model: OrgPatchRootRequest, -): +@router.patch("/root_user", + summary="Update the root user of the organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgPatchRootResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully updated root user."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be super admin."}, + }) +async def update_root_user(db: db_dependency, org_model: org_model_body_dependency, user_model: user_model_body_dependency, su: super_admin_dependency, request_model: OrgPatchRootRequest): """ Promotes an existing organisation user to the root user, giving them full control of the org. """ if user_model not in org_model.user_rel: - raise UnprocessableContentException( - message="This user does not belong to your organisation." - ) + raise UnprocessableContentException(message="This user does not belong to your organisation.") org_model.root_user_rel = user_model db.flush() - response = OrgPatchRootResponse( - name=org_model.name, root_user_email=org_model.root_user_email - ) + response = OrgPatchRootResponse(name=org_model.name, root_user_email=org_model.root_user_email) db.commit() return response -@router.get( - "/groups", - summary="Get all organisation IAM groups.", - status_code=status.HTTP_200_OK, - response_model=OrgGetGroupResponse, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval of IAM groups."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Org ID missing or invalid." - }, - status.HTTP_401_UNAUTHORIZED: { - "description": "Not authorised. Must be org root user." - }, - }, -) +@router.get("/groups", + summary="Get all organisation IAM groups.", + status_code=status.HTTP_200_OK, + response_model=OrgGetGroupResponse, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval of IAM groups."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Org ID missing or invalid."}, + status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."}, + }) async def get_org_groups(org_model: org_model_root_claim_query_dependency): """ Returns a list of the names of all IAM groups created by the organisation. @@ -396,25 +263,15 @@ async def get_org_groups(org_model: org_model_root_claim_query_dependency): return {"groups": [group.name for group in org_model.group_rel]} -@router.delete( - "/user", - summary="Remove user from organisation.", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - status.HTTP_204_NO_CONTENT: {"description": "Successfully removed user."}, - status.HTTP_401_UNAUTHORIZED: { - "description": "Not authorised. Must be org root user." - }, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Invalid data in request." - }, - }, -) -async def remove_user_from_org( - db: db_dependency, - org_model: org_model_root_claim_query_dependency, - user_model: user_model_query_dependency, -): +@router.delete("/user", + summary="Remove user from organisation.", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + status.HTTP_204_NO_CONTENT: {"description": "Successfully removed user."}, + status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + }) +async def remove_user_from_org(db: db_dependency, org_model: org_model_root_claim_body_dependency, user_model: user_model_body_dependency, request_model: OrgDeleteUserRequest): """ Revokes a user's membership in an organisation. """ @@ -425,27 +282,16 @@ async def remove_user_from_org( db.commit() -@router.get( - "/contact", - summary="Get contact for organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgGetContactResponse, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval of contact."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Invalid data in request." - }, - status.HTTP_401_UNAUTHORIZED: { - "description": "Not authorised. Must be org root user." - }, - }, -) -async def get_contact( - org_model: org_model_root_claim_query_dependency, - contact_type: Annotated[ - ContactType, Query(description="Must be billing|security|owner") - ], -): +@router.get("/contact", + summary="Get contact for organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgGetContactResponse, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval of contact."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."}, + }) +async def get_contact(org_model: org_model_root_claim_query_dependency, contact_type: Annotated[ContactType, Query(description="Must be billing|security|owner")]): """ Gets full details for a contact point at an organisation. """ @@ -463,33 +309,21 @@ async def get_contact( raise ContactNotFoundException() address = ContactAddress.model_validate(contact_model) - contact_response = ContactModel.model_construct( - **contact_model.__dict__, address=address - ) + contact_response = ContactModel.model_construct(**contact_model.__dict__, address=address) return {"contact": contact_response, "organisation": org_model} -@router.patch( - "/contact", - summary="Update contact for organisation.", - status_code=status.HTTP_200_OK, - response_model=OrgPatchContactResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully updated contact."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Invalid data in request." - }, - status.HTTP_401_UNAUTHORIZED: { - "description": "Not authorised. Must be org root user." - }, - }, -) -async def update_contact( - db: db_dependency, - org_model: org_model_root_claim_body_dependency, - request_model: OrgPatchContactRequest, -): +@router.patch("/contact", + summary="Update contact for organisation.", + status_code=status.HTTP_200_OK, + response_model=OrgPatchContactResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully updated contact."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_401_UNAUTHORIZED: {"description": "Not authorised. Must be org root user."}, + }) +async def update_contact(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: OrgPatchContactRequest): """ Updates details for a contact point at an organisation. """ @@ -517,9 +351,7 @@ async def update_contact( db.flush() address = ContactAddress.model_validate(contact_model) - contact_response = ContactModel.model_construct( - **contact_model.__dict__, address=address - ) + contact_response = ContactModel.model_construct(**contact_model.__dict__, address=address) db.commit() diff --git a/src/organisation/schemas.py b/src/organisation/schemas.py index 9f120f3..305c6f7 100644 --- a/src/organisation/schemas.py +++ b/src/organisation/schemas.py @@ -6,7 +6,6 @@ Models follow the nomenclature of: - Mixins: "Mixin" - Models: "" ie "OrgPostOrgRequest" """ - from typing import Optional from pydantic import EmailStr, ConfigDict, Field @@ -21,13 +20,11 @@ from src.organisation.constants import Status, ContactType class OrgIDMixin(CustomBaseModel): organisation_id: int = Field(gt=0) - class Questionnaire(CustomBaseModel): question_one: Optional[str] = None question_two: Optional[str] = None question_three: Optional[str] = None - class OrgSchema(CustomBaseModel): id: int name: str @@ -37,32 +34,26 @@ class OrgPostOrgRequest(CustomBaseModel): name: str intake_questionnaire: Optional[Questionnaire] = None - class OrgPostOrgResponse(CustomBaseModel): name: str status: Status - class OrgPatchQuestionnaireRequest(OrgIDMixin): intake_questionnaire: Questionnaire partial: bool - class OrgPatchQuestionnaireResponse(CustomBaseModel): name: str intake_questionnaire: Questionnaire status: Status - class OrgPatchStatusRequest(OrgIDMixin): status: Status - class OrgPatchStatusResponse(CustomBaseModel): name: str status: Status - class OrgPatchContactRequest(OrgIDMixin): contact_type: ContactType @@ -79,47 +70,41 @@ class OrgPatchContactRequest(OrgIDMixin): country_code: Optional[str] = None postal_code: Optional[str] = None - class OrgPostUserRequest(OrgIDMixin, UserIDMixin): pass - class OrgPostUserResponse(CustomBaseModel): users: list[str] +class OrgDeleteUserRequest(OrgIDMixin, UserIDMixin): + pass class OrgPatchRootRequest(OrgIDMixin, UserIDMixin): pass - class OrgPatchRootResponse(CustomBaseModel): name: str root_user_email: str - class OrgGetUserResponse(CustomBaseModel): users: list[str] organisation: OrgSchema - class OrgGetGroupResponse(CustomBaseModel): groups: list[str] - class OrgGetContactResponse(CustomBaseModel): model_config = ConfigDict(from_attributes=True, extra="ignore") contact: ContactModel organisation: OrgSchema - class OrgPatchContactResponse(CustomBaseModel): model_config = ConfigDict(from_attributes=True, extra="ignore") contact: ContactModel organisation: OrgSchema - class OrgGetOrgResponse(CustomBaseModel): id: int name: str @@ -129,3 +114,6 @@ class OrgGetOrgResponse(CustomBaseModel): billing_contact: Optional[str] = None security_contact: Optional[str] = None intake_questionnaire: Optional[Questionnaire] = None + +class OrgDeleteOrgRequest(OrgIDMixin): + pass diff --git a/src/organisation/service.py b/src/organisation/service.py index cfe3925..6d73399 100644 --- a/src/organisation/service.py +++ b/src/organisation/service.py @@ -1,3 +1,3 @@ """ Reusable business logic functions for the organisation module -""" +""" \ No newline at end of file diff --git a/src/organisation/utils.py b/src/organisation/utils.py index 0337df5..ead22ca 100644 --- a/src/organisation/utils.py +++ b/src/organisation/utils.py @@ -1,3 +1,3 @@ """ Non-business logic reusable functions and classes for the organisation module -""" +""" \ No newline at end of file diff --git a/src/schemas.py b/src/schemas.py index 484031b..812b574 100644 --- a/src/schemas.py +++ b/src/schemas.py @@ -5,7 +5,6 @@ Exports: - CustomBaseModel: Schema used for all other Pydantic models - ResourceName """ - from pydantic import BaseModel from typing import Optional diff --git a/src/service/config.py b/src/service/config.py index b23303f..5d4fd3b 100644 --- a/src/service/config.py +++ b/src/service/config.py @@ -1,3 +1,3 @@ """ Configurations for the services module -""" +""" \ No newline at end of file diff --git a/src/service/constants.py b/src/service/constants.py index 0007ba0..52a8701 100644 --- a/src/service/constants.py +++ b/src/service/constants.py @@ -1,3 +1,3 @@ """ Constants for the services module -""" +""" \ No newline at end of file diff --git a/src/service/dependencies.py b/src/service/dependencies.py index 9792f26..cda625a 100644 --- a/src/service/dependencies.py +++ b/src/service/dependencies.py @@ -5,7 +5,6 @@ Exports: - service_model_query_dependency: service_model: Gets service model from db, if it exists. Uses service_id from query param. - service_model_body_dependency: service_model: Gets service model from db, if it exists. Uses service_id from request body. """ - from typing import Annotated from fastapi import Depends, Query @@ -16,19 +15,14 @@ from src.service.models import Service from src.service.schemas import ServiceIDMixin -async def get_service_model_query( - db: db_dependency, service_id: Annotated[int, Query(gt=0)] -): +async def get_service_model_query(db: db_dependency, service_id: Annotated[int, Query(gt=0)]): service_model = db.get(Service, service_id) if service_model is None: raise ServiceNotFoundException(service_id=service_id) return service_model - -service_model_query_dependency = Annotated[ - type[Service], Depends(get_service_model_query) -] +service_model_query_dependency = Annotated[type[Service], Depends(get_service_model_query)] async def get_service_model_body(db: db_dependency, request_model: ServiceIDMixin): @@ -38,7 +32,4 @@ async def get_service_model_body(db: db_dependency, request_model: ServiceIDMixi return service_model - -service_model_body_dependency = Annotated[ - type[Service], Depends(get_service_model_body) -] +service_model_body_dependency = Annotated[type[Service], Depends(get_service_model_body)] diff --git a/src/service/exceptions.py b/src/service/exceptions.py index 36a927d..8a1a2e3 100644 --- a/src/service/exceptions.py +++ b/src/service/exceptions.py @@ -4,7 +4,6 @@ Exceptions related to the services module Exceptions: - ServiceNotFoundException: Takes an optional service_id int """ - from typing import Optional from fastapi import HTTPException, status @@ -12,11 +11,7 @@ from fastapi import HTTPException, status class ServiceNotFoundException(HTTPException): def __init__(self, service_id: Optional[int] = None) -> None: - detail = ( - "Service not found" - if service_id is None - else f"Service with ID '{service_id}' was not found." - ) + detail = "Service not found" if service_id is None else f"Service with ID '{service_id}' was not found." super().__init__( status_code=status.HTTP_404_NOT_FOUND, detail=detail, diff --git a/src/service/models.py b/src/service/models.py index 82bdba1..68fd020 100644 --- a/src/service/models.py +++ b/src/service/models.py @@ -5,7 +5,6 @@ Models: - Service: - id[PK], name[U], api_key[U] """ - from sqlalchemy import Column, Integer, String from src.database import Base diff --git a/src/service/router.py b/src/service/router.py index bfd282e..a8f93ea 100644 --- a/src/service/router.py +++ b/src/service/router.py @@ -7,31 +7,19 @@ Endpoints: - [PATCH](/key): [super_admin]: Refreshes the API key for a service(id), returning a new one. - [DELETE](/): [super_admin]: Removes a service(id) from the hub. """ - from fastapi import APIRouter, status +from psycopg.errors import UniqueViolation from sqlalchemy.exc import IntegrityError from src.exceptions import ConflictException from src.database import db_dependency -from src.auth.dependencies import ( - super_admin_dependency, - org_model_root_claim_query_dependency, -) +from src.auth.dependencies import super_admin_dependency, org_model_root_claim_query_dependency from src.service.models import Service from src.service.utils import generate_api_key -from src.service.dependencies import ( - service_model_body_dependency, - service_model_query_dependency, -) -from src.service.schemas import ( - ServiceGetServiceResponse, - ServicePostServiceRequest, - ServicePostServiceResponse, - ServiceWithKeySchema, - ServicePatchKeyResponse, - ServicePatchKeyRequest, -) +from src.service.dependencies import service_model_body_dependency +from src.service.schemas import ServiceGetServiceResponse, ServicePostServiceRequest, ServicePostServiceResponse, \ + ServiceWithKeySchema, ServicePatchKeyResponse, ServicePatchKeyRequest, ServiceDeleteServiceRequest router = APIRouter( tags=["Service"], @@ -39,19 +27,15 @@ router = APIRouter( ) -@router.get( - "/", - summary="Get all services", - status_code=status.HTTP_200_OK, - response_model=ServiceGetServiceResponse, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, - status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, - }, -) -async def get_all_services( - db: db_dependency, org_model: org_model_root_claim_query_dependency -): +@router.get("/", + summary="Get all services", + status_code=status.HTTP_200_OK, + response_model=ServiceGetServiceResponse, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + }) +async def get_all_services(db: db_dependency, org_model: org_model_root_claim_query_dependency): """ Returns the ID and name of all services registered to the hub. """ @@ -60,24 +44,16 @@ async def get_all_services( return {"services": permission_models} -@router.post( - "/", - summary="Register a new service.", - status_code=status.HTTP_200_OK, - response_model=ServicePostServiceResponse, - responses={ - status.HTTP_200_OK: {"description": "Successfully registered a new service"}, - status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, - status.HTTP_409_CONFLICT: { - "description": "Service with this name already exists" - }, - }, -) -async def register_service( - db: db_dependency, - su: super_admin_dependency, - request_model: ServicePostServiceRequest, -): +@router.post("/", + summary="Register a new service.", + status_code=status.HTTP_200_OK, + response_model=ServicePostServiceResponse, + responses={ + status.HTTP_200_OK: {"description": "Successfully registered a new service"}, + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + status.HTTP_409_CONFLICT: {"description": "Service with this name already exists"}, + }) +async def register_service(db: db_dependency, su: super_admin_dependency, request_model: ServicePostServiceRequest): """ Registers a new service to the hub, generating and returning an API key for it. """ @@ -88,32 +64,23 @@ async def register_service( try: db.flush() except IntegrityError as e: - if ( - getattr(e.orig, "pgcode", None) == "23505" # Postgres unique violation - or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation - ): + if isinstance(e.orig, UniqueViolation): raise ConflictException(message="Service with this name already exists") response = ServiceWithKeySchema(**service_model.__dict__) db.commit() return {"service": response} -@router.patch( - "/key", - summary="Regenerate service API key.", - status_code=status.HTTP_200_OK, - response_model=ServicePatchKeyResponse, - responses={ - status.HTTP_200_OK: {"description": "Successful update of API key"}, - status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, - }, -) -async def regenerate_api_key( - db: db_dependency, - su: super_admin_dependency, - service_model: service_model_body_dependency, - request_model: ServicePatchKeyRequest, -): +@router.patch("/key", + summary="Regenerate service API key.", + status_code=status.HTTP_200_OK, + response_model=ServicePatchKeyResponse, + responses={ + status.HTTP_200_OK: {"description": "Successful update of API key"}, + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + }) +async def regenerate_api_key(db: db_dependency, su: super_admin_dependency, + service_model: service_model_body_dependency, request_model: ServicePatchKeyRequest): """ Generates and returns a new API key for the service to access the hub. """ @@ -126,22 +93,15 @@ async def regenerate_api_key( return {"service": response} -@router.delete( - "/", - summary="Remove a service.", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - status.HTTP_204_NO_CONTENT: { - "description": "Successfully removed service from db" - }, - status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, - }, -) -async def remove_service( - db: db_dependency, - service_model: service_model_query_dependency, - su: super_admin_dependency, -): +@router.delete("/", + summary="Remove a service.", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + status.HTTP_204_NO_CONTENT: {"description": "Successfully removed service from db"}, + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + }) +async def remove_service(db: db_dependency, service_model: service_model_body_dependency, su: super_admin_dependency, + request_model: ServiceDeleteServiceRequest): """ Removes a service from the hub. """ diff --git a/src/service/schemas.py b/src/service/schemas.py index 47d47f5..50e7c35 100644 --- a/src/service/schemas.py +++ b/src/service/schemas.py @@ -6,42 +6,36 @@ Models follow the nomenclature of: - Mixins: "Mixin" - Models: "" ie "ServiceGetServiceResponse" """ - from pydantic import ConfigDict, Field from src.schemas import CustomBaseModel - class ServiceIDMixin(CustomBaseModel): service_id: int = Field(gt=0) - class ServiceSchema(CustomBaseModel): model_config = ConfigDict(from_attributes=True, extra="ignore") id: int name: str - class ServiceWithKeySchema(ServiceSchema): api_key: str - class ServiceGetServiceResponse(CustomBaseModel): services: list[ServiceSchema] - class ServicePostServiceRequest(CustomBaseModel): name: str - class ServicePostServiceResponse(CustomBaseModel): service: ServiceWithKeySchema - class ServicePatchKeyRequest(ServiceIDMixin): pass - class ServicePatchKeyResponse(CustomBaseModel): service: ServiceWithKeySchema + +class ServiceDeleteServiceRequest(ServiceIDMixin): + pass diff --git a/src/service/service.py b/src/service/service.py index 2f59b44..2609565 100644 --- a/src/service/service.py +++ b/src/service/service.py @@ -1,3 +1,3 @@ """ Business logic for the services module -""" +""" \ No newline at end of file diff --git a/src/service/utils.py b/src/service/utils.py index 79bb91f..8920a5f 100644 --- a/src/service/utils.py +++ b/src/service/utils.py @@ -4,7 +4,6 @@ Non-business logic reusable functions and classes for the services module Exports: - generate_api_key(): returns a new UUID """ - import uuid diff --git a/src/user/config.py b/src/user/config.py index a25018d..9bbcbc4 100644 --- a/src/user/config.py +++ b/src/user/config.py @@ -1,3 +1,3 @@ """ Configurations for the user module -""" +""" \ No newline at end of file diff --git a/src/user/constants.py b/src/user/constants.py index 2adb24d..fc6a780 100644 --- a/src/user/constants.py +++ b/src/user/constants.py @@ -1,3 +1,3 @@ """ Constants for the user module -""" +""" \ No newline at end of file diff --git a/src/user/dependencies.py b/src/user/dependencies.py index 9d2fe02..dc22429 100644 --- a/src/user/dependencies.py +++ b/src/user/dependencies.py @@ -6,7 +6,6 @@ Exports: - user_model_query_dependency: user_model: Gets user model from db, if it exists. Uses user_id from query param - user_model_body_dependency: user_model: Gets user model from db, if it exists. Uses user_id from request body. """ - from typing import Annotated from fastapi import Depends, Query @@ -29,7 +28,6 @@ async def get_user_model_claims(claims: claims_dependency, db: db_dependency): return user_model - user_model_claims_dependency = Annotated[type[User], Depends(get_user_model_claims)] @@ -40,7 +38,6 @@ async def get_user_model_query(db: db_dependency, user_id: Annotated[int, Query( return user_model - user_model_query_dependency = Annotated[type[User], Depends(get_user_model_query)] @@ -51,5 +48,4 @@ async def get_user_model_body(db: db_dependency, request_model: UserIDMixin): return user_model - user_model_body_dependency = Annotated[type[User], Depends(get_user_model_body)] diff --git a/src/user/exceptions.py b/src/user/exceptions.py index fa4db2f..2e03b05 100644 --- a/src/user/exceptions.py +++ b/src/user/exceptions.py @@ -4,7 +4,6 @@ Exceptions related to the user module Exceptions: - UserNotFoundException: Takes an optional user_id int """ - from typing import Optional from fastapi import HTTPException, status @@ -12,11 +11,7 @@ from fastapi import HTTPException, status class UserNotFoundException(HTTPException): def __init__(self, user_id: Optional[int] = None) -> None: - detail = ( - "User not found" - if user_id is None - else f"User with ID '{user_id}' was not found." - ) + detail = "User not found" if user_id is None else f"User with ID '{user_id}' was not found." super().__init__( status_code=status.HTTP_404_NOT_FOUND, detail=detail, diff --git a/src/user/models.py b/src/user/models.py index 62c7ef6..97c0f22 100644 --- a/src/user/models.py +++ b/src/user/models.py @@ -9,7 +9,6 @@ Models: - organisations: Calc property list of organisation_rel.name - groups: Calc property dict of {group_rel.org_rel.name: group_rel.name} """ - from collections import defaultdict from sqlalchemy import Column, Integer, String @@ -19,29 +18,29 @@ from src.database import Base class User(Base): - __tablename__ = "user" + __tablename__ = "user" - id = Column(Integer, primary_key=True) - email = Column(String) - first_name = Column(String) - last_name = Column(String) - oidc_id = Column(String, index=True, unique=True) + id = Column(Integer, primary_key=True) + email = Column(String) + first_name = Column(String) + last_name = Column(String) + oidc_id = Column(String, index=True, unique=True) - organisation_rel = relationship( - "Organisation", secondary="orgusers", back_populates="user_rel" - ) + organisation_rel = relationship( + "Organisation", secondary="orgusers", back_populates="user_rel" + ) - @property - def organisations(self): - return [{"name": org.name, "id": org.id} for org in self.organisation_rel] + @property + def organisations(self): + return [{"name": org.name, "id": org.id} for org in self.organisation_rel] - group_rel = relationship( - "Group", secondary="user_groups", back_populates="user_rel" - ) + group_rel = relationship( + "Group", secondary="user_groups", back_populates="user_rel" + ) - @property - def groups(self): - result = defaultdict(list) - for group in self.group_rel: - result[group.org_rel.name].append({"name": group.name, "id": group.id}) - return dict(result) + @property + def groups(self): + result = defaultdict(list) + for group in self.group_rel: + result[group.org_rel.name].append({"name": group.name, "id": group.id}) + return dict(result) diff --git a/src/user/router.py b/src/user/router.py index 8613d47..0966850 100644 --- a/src/user/router.py +++ b/src/user/router.py @@ -7,16 +7,11 @@ Endpoints: - [GET](/user/): [super admin]: Returns user(id) details. - [DELETE](/user/): [super admin]: Removes a User(id) from the hub database. """ - from fastapi import APIRouter from starlette import status -from src.user.schemas import UserResponse, OIDCClaims -from src.user.dependencies import ( - user_model_claims_dependency, - user_model_query_dependency, - user_model_body_dependency, -) +from src.user.schemas import UserResponse, OIDCClaims, UserDeleteUserRequest +from src.user.dependencies import user_model_claims_dependency, user_model_query_dependency, user_model_body_dependency from src.auth.dependencies import super_admin_dependency from src.auth.service import claims_dependency @@ -28,15 +23,13 @@ router = APIRouter( ) -@router.get( - "/self/claims", - summary="Get current user OIDC claims.", - response_model=OIDCClaims, - status_code=status.HTTP_200_OK, - responses={ - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, - }, -) +@router.get("/self/claims", + summary="Get current user OIDC claims.", + response_model=OIDCClaims, + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, + }) async def current_user_claims(user: claims_dependency): """ Returns the full OIDC claims associated with the currently logged-in user. @@ -45,16 +38,14 @@ async def current_user_claims(user: claims_dependency): return user -@router.get( - "/self/db", - summary="Get current user hub details.", - response_model=UserResponse, - status_code=status.HTTP_200_OK, - responses={ - status.HTTP_404_NOT_FOUND: {"description": "User not found"}, - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, - }, -) +@router.get("/self/db", + summary="Get current user hub details.", + response_model=UserResponse, + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_404_NOT_FOUND: {"description": "User not found"}, + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, + }) async def current_user(user_model: user_model_claims_dependency): """ Returns the database details associated with the currently logged-in user. @@ -62,39 +53,30 @@ async def current_user(user_model: user_model_claims_dependency): return user_model -@router.get( - "/", - summary="Get user hub details by ID.", - response_model=UserResponse, - status_code=status.HTTP_200_OK, - responses={ - status.HTTP_404_NOT_FOUND: {"description": "User not found"}, - status.HTTP_200_OK: {"description": "Successful retrieval from database"}, - }, -) -async def get_user_by_id( - user_model: user_model_query_dependency, su: super_admin_dependency -): +@router.get("/", + summary="Get user hub details by ID.", + response_model=UserResponse, + status_code=status.HTTP_200_OK, + responses={ + status.HTTP_404_NOT_FOUND: {"description": "User not found"}, + status.HTTP_200_OK: {"description": "Successful retrieval from database"}, + }) +async def get_user_by_id(user_model: user_model_query_dependency, su: super_admin_dependency): """ Returns the database details associated with the provided user ID. """ return user_model -@router.delete( - "/", - summary="Delete user from hub by ID.", - status_code=status.HTTP_204_NO_CONTENT, - responses={ - status.HTTP_204_NO_CONTENT: {"description": "User deleted"}, - status.HTTP_404_NOT_FOUND: {"description": "User not found"}, - }, -) -async def delete_user_by_id( - db: db_dependency, - user_model: user_model_query_dependency, - su: super_admin_dependency, -): +@router.delete("/", + summary="Delete user from hub by ID.", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + status.HTTP_204_NO_CONTENT: {"description": "User deleted"}, + status.HTTP_404_NOT_FOUND: {"description": "User not found"}, + }) +async def delete_user_by_id(db: db_dependency, user_model: user_model_body_dependency, su: super_admin_dependency, + request_model: UserDeleteUserRequest): """ Deletes the user with the provided ID from the database. This will not remove them from OIDC, and they will be automatically readded on next login. """ diff --git a/src/user/schemas.py b/src/user/schemas.py index 8ef46df..211004c 100644 --- a/src/user/schemas.py +++ b/src/user/schemas.py @@ -1,7 +1,6 @@ """ Pydantic models for the user module """ - from typing import Optional from pydantic import Field @@ -48,10 +47,14 @@ class UserResponse(CustomBaseModel): first_name: str last_name: str email: str - organisations: list[Optional[dict[str, str | int]]] - groups: Optional[dict[str, list[dict[str, str | int]]]] = None + organisations: list[Optional[dict[str, str|int]]] + groups: Optional[dict[str, list[dict[str, str|int]]]] = None class OrgResponse(CustomBaseModel): org_id: int name: str + + +class UserDeleteUserRequest(UserIDMixin): + pass \ No newline at end of file diff --git a/src/user/service.py b/src/user/service.py index 49ab238..e072251 100644 --- a/src/user/service.py +++ b/src/user/service.py @@ -4,7 +4,6 @@ Module specific business logic for user module Exports: - add_user_to_db: Creates a User record from OIDC claims, or updates user details """ - from typing import Any from sqlalchemy.orm import Session @@ -17,12 +16,7 @@ from src.user.models import User async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int: try: - valid_user = OIDCUser( - first_name=user_claims["given_name"], - last_name=user_claims["family_name"], - email=user_claims["email"], - oidc_id=user_claims["sub"], - ) + valid_user = OIDCUser(first_name=user_claims["given_name"], last_name=user_claims["family_name"], email=user_claims["email"], oidc_id=user_claims["sub"]) except Exception as e: print(e) raise UnprocessableContentException("Invalid or missing OIDC data") diff --git a/src/user/utils.py b/src/user/utils.py index 2f10b91..35fcc1a 100644 --- a/src/user/utils.py +++ b/src/user/utils.py @@ -1,3 +1,3 @@ """ Non-business logic reusable functions and classes for the user module -""" +""" \ No newline at end of file diff --git a/test/conftest.py b/test/conftest.py index e06b3d0..e16e799 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -37,14 +37,11 @@ def db_session(): async def default_client(db_session) -> AsyncGenerator[AsyncClient, None]: def get_db_override(): return db_session - app.dependency_overrides[get_db] = get_db_override app.dependency_overrides[get_current_user] = get_dev_user app.dependency_overrides[get_super_admin_list] = testing_su_list transport = ASGITransport(app=app) - async with AsyncClient( - transport=transport, base_url="http://localhost:8000/api/v1" - ) as ac: + async with AsyncClient(transport=transport, base_url="http://localhost:8000/api/v1") as ac: yield ac app.dependency_overrides.clear() @@ -54,58 +51,37 @@ async def default_client(db_session) -> AsyncGenerator[AsyncClient, None]: async def no_user_client(db_session) -> AsyncGenerator[AsyncClient, None]: def get_db_override(): return db_session - app.dependency_overrides[get_db] = get_db_override transport = ASGITransport(app=app) - async with AsyncClient( - transport=transport, base_url="http://localhost:8000/api/v1" - ) as ac: + async with AsyncClient(transport=transport, base_url="http://localhost:8000/api/v1") as ac: yield ac app.dependency_overrides.clear() + @pytest.fixture async def no_su_client(db_session) -> AsyncGenerator[AsyncClient, None]: def get_db_override(): return db_session - app.dependency_overrides[get_db] = get_db_override app.dependency_overrides[get_current_user] = get_dev_user app.dependency_overrides[get_super_admin_list] = empty_su_list transport = ASGITransport(app=app) - async with AsyncClient( - transport=transport, base_url="http://localhost:8000/api/v1" - ) as ac: + async with AsyncClient(transport=transport, base_url="http://localhost:8000/api/v1") as ac: yield ac app.dependency_overrides.clear() def _seed(db): - db.add( - User( - email="admin@test.com", - first_name="Admin", - last_name="Test", - oidc_id="abcd-efgh-ijkl-mnop", - ) - ) + db.add(User(email="admin@test.com", first_name="Admin", last_name="Test", oidc_id="abcd-efgh-ijkl-mnop")) db.add(Contact(org_id=1, email="billing@test.org", phonenumber="07521539927")) db.add(Contact(org_id=1, email="owner@test.org", phonenumber="07521539927")) db.add(Contact(org_id=1, email="security@test.org", phonenumber="07521539927")) db.flush() - db.add( - Org( - name="Test Org", - root_user_id=1, - billing_contact_id=1, - owner_contact_id=2, - security_contact_id=3, - status="approved", - intake_questionnaire={"question_two": "answer two"}, - ) - ) + db.add(Org(name="Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=2, security_contact_id=3, + status="approved", intake_questionnaire={"question_two": "answer two"})) db.add(Service(name="Test Service", api_key="123456789")) db.add(Permission(service_id=1, resource="test_resource", action="read")) db.add(Group(name="Test Group", org_id=1)) @@ -155,7 +131,6 @@ def generate_query_and_status(params) -> list[tuple[str, int]]: return query_and_status - # # Produces a text file with method and path for every endpoint in the API # from fastapi.routing import APIRoute # diff --git a/test/test_auth_approval.py b/test/test_auth_approval.py index 69c8a25..f395865 100644 --- a/test/test_auth_approval.py +++ b/test/test_auth_approval.py @@ -3,7 +3,6 @@ This test module checks relevant endpoints to ensure only approved orgs get acce Endpoints not checked here are endpoints that do not require an org check. Delete endpoints are currently skipped because the testing system cannot use bodies in deletes. """ - import pytest from httpx import AsyncClient @@ -28,27 +27,18 @@ async def test_get_org_auth_approval(default_client: AsyncClient): @pytest.mark.anyio async def test_patch_org_questionnaire_auth_approval(default_client: AsyncClient): - resp = await default_client.patch( - "/org/questionnaire", - json={ - "organisation_id": 1, - "intake_questionnaire": { - "question_one": "new answer one", - "question_two": None, - "question_three": None, - }, - "partial": True, - }, - ) + resp = await default_client.patch("/org/questionnaire", json={"organisation_id": 1, + "intake_questionnaire": {"question_one": "new answer one", + "question_two": None, + "question_three": None}, + "partial": True}) assert resp.status_code != 422 assert resp.status_code == 200 @pytest.mark.anyio async def test_patch_org_status_auth_approval(default_client: AsyncClient): - resp = await default_client.patch( - "/org/status", json={"organisation_id": 1, "status": "submitted"} - ) + resp = await default_client.patch("/org/status", json={"organisation_id": 1, "status": "submitted"}) assert resp.status_code != 422 assert resp.status_code == 200 @@ -62,42 +52,22 @@ async def test_get_org_users_auth_approval(default_client: AsyncClient): @pytest.mark.anyio async def test_post_org_user_auth_approval(default_client: AsyncClient, db_session): - db_session.add( - User( - email="user@test.org", - first_name="User", - last_name="Test", - oidc_id="abcd-efgh-ijkl-1234", - ) - ) + db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) db_session.flush() - resp = await default_client.post( - "/org/user", json={"organisation_id": 1, "user_id": 2} - ) + resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 2}) assert resp.status_code != 422 assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio -async def test_patch_org_root_user_auth_approval( - default_client: AsyncClient, db_session -): - db_session.add( - User( - email="user@test.org", - first_name="User", - last_name="Test", - oidc_id="abcd-efgh-ijkl-1234", - ) - ) +async def test_patch_org_root_user_auth_approval(default_client: AsyncClient, db_session): + db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) db_session.flush() db_session.add(OrgUsers(org_id=1, user_id=2)) db_session.flush() - resp = await default_client.patch( - "/org/root_user", json={"organisation_id": 1, "user_id": 2} - ) + resp = await default_client.patch("/org/root_user", json={"organisation_id": 1, "user_id": 2}) assert resp.status_code != 422 assert "has not been approved." in resp.json()["detail"] @@ -118,14 +88,8 @@ async def test_get_org_contact_auth_approval(default_client: AsyncClient): @pytest.mark.anyio async def test_patch_org_contact_auth_approval(default_client: AsyncClient): - resp = await default_client.patch( - "/org/contact", - json={ - "organisation_id": 1, - "contact_type": "billing", - "email": "user@example.com", - }, - ) + resp = await default_client.patch("/org/contact", + json={"organisation_id": 1, "contact_type": "billing", "email": "user@example.com"}) assert resp.status_code != 422 assert resp.status_code == 200 @@ -153,44 +117,26 @@ async def test_get_iam_group_users_auth_approval(default_client: AsyncClient): @pytest.mark.anyio async def test_post_iam_group_auth_approval(default_client: AsyncClient): - resp = await default_client.post( - "/iam/group", json={"name": "New Group", "organisation_id": 1} - ) + resp = await default_client.post("/iam/group", json={"name": "New Group", "organisation_id": 1}) assert resp.status_code != 422 assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio -async def test_put_iam_group_permission_auth_approval( - default_client: AsyncClient, db_session -): +async def test_put_iam_group_permission_auth_approval(default_client: AsyncClient, db_session): db_session.add(Group(name="Test Group Two", org_id=1)) db_session.flush() - resp = await default_client.put( - "/iam/group/permission", - json={"permission_id": 1, "group_id": 2, "organisation_id": 1}, - ) + resp = await default_client.put("/iam/group/permission", json={"permission_id": 1, "group_id": 2, "organisation_id": 1}) assert resp.status_code != 422 assert "has not been approved." in resp.json()["detail"] @pytest.mark.anyio -async def test_put_iam_group_user_auth_approval( - default_client: AsyncClient, db_session -): - db_session.add( - User( - email="user@test.org", - first_name="User", - last_name="Test", - oidc_id="abcd-efgh-ijkl-1234", - ) - ) +async def test_put_iam_group_user_auth_approval(default_client: AsyncClient, db_session): + db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) db_session.flush() - resp = await default_client.put( - "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1} - ) + resp = await default_client.put("/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1}) assert resp.status_code != 422 assert "has not been approved." in resp.json()["detail"] @@ -204,8 +150,6 @@ async def test_get_iam_permissions_auth_approval(default_client: AsyncClient): @pytest.mark.anyio async def test_post_iam_permissions_search_auth_approval(default_client: AsyncClient): - resp = await default_client.post( - "/iam/permissions/search", json={"organisation_id": 1, "action": "read"} - ) + resp = await default_client.post("/iam/permissions/search", json={"organisation_id": 1, "action": "read"}) assert resp.status_code != 422 assert "has not been approved." in resp.json()["detail"] diff --git a/test/test_auth_general.py b/test/test_auth_general.py index 0547824..6aeba96 100644 --- a/test/test_auth_general.py +++ b/test/test_auth_general.py @@ -1,5 +1,5 @@ -""" """ - +""" +""" import pytest from httpx import AsyncClient @@ -10,26 +10,11 @@ from src.user.models import User @pytest.mark.anyio async def test_get_org_auth_root_su(default_client: AsyncClient, db_session): # If a super admin can access a resource when not the root user - db_session.add( - User( - email="admin@test.org", - first_name="Admin", - last_name="Test", - oidc_id="abcd-efgh-ijkl-4321", - ) - ) + db_session.add(User(email="admin@test.org", first_name="Admin", last_name="Test", oidc_id="abcd-efgh-ijkl-4321")) db_session.flush() db_session.add( - Org( - name="Test Org Two", - root_user_id=2, - billing_contact_id=1, - owner_contact_id=2, - security_contact_id=3, - status="approved", - intake_questionnaire={}, - ) - ) + Org(name="Test Org Two", root_user_id=2, billing_contact_id=1, owner_contact_id=2, security_contact_id=3, + status="approved", intake_questionnaire={})) db_session.flush() resp = await default_client.get("/org?org_id=2") diff --git a/test/test_auth_root.py b/test/test_auth_root.py index e67bc6a..16e3afa 100644 --- a/test/test_auth_root.py +++ b/test/test_auth_root.py @@ -2,7 +2,6 @@ This module ensures root user only endpoints do return a correctly formatted 401 when user is not the root user for the org DELETE endpoints are not tested """ - import pytest from httpx import AsyncClient @@ -13,26 +12,10 @@ from src.iam.models import Group @pytest.fixture(autouse=True) def add_second_org(db_session): - db_session.add( - User( - email="admin@test.org", - first_name="Admin", - last_name="Test", - oidc_id="abcd-efgh-ijkl-4321", - ) - ) + db_session.add(User(email="admin@test.org", first_name="Admin", last_name="Test", oidc_id="abcd-efgh-ijkl-4321")) db_session.flush() - db_session.add( - Org( - name="Test Org Two", - root_user_id=2, - billing_contact_id=1, - owner_contact_id=2, - security_contact_id=3, - status="approved", - intake_questionnaire={}, - ) - ) + db_session.add(Org(name="Test Org Two", root_user_id=2, billing_contact_id=1, owner_contact_id=2, security_contact_id=3, + status="approved", intake_questionnaire={})) db_session.flush() @@ -46,18 +29,11 @@ async def test_get_org_auth_root(no_su_client: AsyncClient): @pytest.mark.anyio async def test_patch_org_questionnaire_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.patch( - "/org/questionnaire", - json={ - "organisation_id": 2, - "intake_questionnaire": { - "question_one": "new answer one", - "question_two": None, - "question_three": None, - }, - "partial": True, - }, - ) + resp = await no_su_client.patch("/org/questionnaire", json={"organisation_id": 2, + "intake_questionnaire": {"question_one": "new answer one", + "question_two": None, + "question_three": None}, + "partial": True}) assert resp.status_code != 422 assert resp.status_code == 401 assert "Must be the org's root user" in resp.json()["detail"] @@ -73,19 +49,10 @@ async def test_get_org_users_auth_root(no_su_client: AsyncClient): @pytest.mark.anyio async def test_post_org_user_auth_root(no_su_client: AsyncClient, db_session): - db_session.add( - User( - email="user@test.org", - first_name="User", - last_name="Test", - oidc_id="abcd-efgh-ijkl-1234", - ) - ) + db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) db_session.flush() - resp = await no_su_client.post( - "/org/user", json={"organisation_id": 2, "user_id": 2} - ) + resp = await no_su_client.post("/org/user", json={"organisation_id": 2, "user_id": 2}) assert resp.status_code != 422 assert resp.status_code == 401 assert "Must be the org's root user" in resp.json()["detail"] @@ -109,14 +76,8 @@ async def test_get_org_contact_auth_root(no_su_client: AsyncClient): @pytest.mark.anyio async def test_patch_org_contact_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.patch( - "/org/contact", - json={ - "organisation_id": 2, - "contact_type": "billing", - "email": "user@example.com", - }, - ) + resp = await no_su_client.patch("/org/contact", + json={"organisation_id": 2, "contact_type": "billing", "email": "user@example.com"}) assert resp.status_code != 422 assert resp.status_code == 401 assert "Must be the org's root user" in resp.json()["detail"] @@ -148,24 +109,17 @@ async def test_get_iam_group_users_auth_root(no_su_client: AsyncClient): @pytest.mark.anyio async def test_post_iam_group_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.post( - "/iam/group", json={"name": "New Group", "organisation_id": 2} - ) + resp = await no_su_client.post("/iam/group", json={"name": "New Group", "organisation_id": 2}) assert resp.status_code != 422 assert resp.status_code == 401 assert "Must be the org's root user" in resp.json()["detail"] @pytest.mark.anyio -async def test_put_iam_group_permission_auth_root( - no_su_client: AsyncClient, db_session -): +async def test_put_iam_group_permission_auth_root(no_su_client: AsyncClient, db_session): db_session.add(Group(name="Test Group Two", org_id=2)) db_session.flush() - resp = await no_su_client.put( - "/iam/group/permission", - json={"permission_id": 1, "group_id": 2, "organisation_id": 2}, - ) + resp = await no_su_client.put("/iam/group/permission", json={"permission_id": 1, "group_id": 2, "organisation_id": 2}) assert resp.status_code != 422 assert resp.status_code == 401 assert "Must be the org's root user" in resp.json()["detail"] @@ -173,19 +127,10 @@ async def test_put_iam_group_permission_auth_root( @pytest.mark.anyio async def test_put_iam_group_user_auth_root(no_su_client: AsyncClient, db_session): - db_session.add( - User( - email="user@test.org", - first_name="User", - last_name="Test", - oidc_id="abcd-efgh-ijkl-1234", - ) - ) + db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) db_session.flush() - resp = await no_su_client.put( - "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2} - ) + resp = await no_su_client.put("/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 2}) assert resp.status_code != 422 assert resp.status_code == 401 assert "Must be the org's root user" in resp.json()["detail"] @@ -201,9 +146,7 @@ async def test_get_iam_permissions_auth_root(no_su_client: AsyncClient): @pytest.mark.anyio async def test_post_iam_permissions_search_auth_root(no_su_client: AsyncClient): - resp = await no_su_client.post( - "/iam/permissions/search", json={"organisation_id": 2, "action": "read"} - ) + resp = await no_su_client.post("/iam/permissions/search", json={"organisation_id": 2, "action": "read"}) assert resp.status_code != 422 assert resp.status_code == 401 assert "Must be the org's root user" in resp.json()["detail"] diff --git a/test/test_auth_su.py b/test/test_auth_su.py index 18319a4..9cef61f 100644 --- a/test/test_auth_su.py +++ b/test/test_auth_su.py @@ -2,7 +2,6 @@ This module ensures super admin only endpoints do return a correctly formatted 401 when user is not a super admin DELETE endpoints are not tested """ - import pytest from httpx import AsyncClient @@ -20,9 +19,7 @@ async def test_get_user_auth_su(no_su_client: AsyncClient): @pytest.mark.anyio async def test_patch_org_status_auth_su(no_su_client: AsyncClient): - resp = await no_su_client.patch( - "/org/status", json={"organisation_id": 1, "status": "submitted"} - ) + resp = await no_su_client.patch("/org/status", json={"organisation_id": 1, "status": "submitted"}) assert resp.status_code != 422 assert resp.status_code == 401 assert resp.json()["detail"] == "Must be super admin" @@ -30,21 +27,12 @@ async def test_patch_org_status_auth_su(no_su_client: AsyncClient): @pytest.mark.anyio async def test_patch_org_root_user_auth_su(no_su_client: AsyncClient, db_session): - db_session.add( - User( - email="user@test.org", - first_name="User", - last_name="Test", - oidc_id="abcd-efgh-ijkl-1234", - ) - ) + db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) db_session.flush() db_session.add(OrgUsers(org_id=1, user_id=2)) db_session.flush() - resp = await no_su_client.patch( - "/org/root_user", json={"organisation_id": 1, "user_id": 2} - ) + resp = await no_su_client.patch("/org/root_user", json={"organisation_id": 1, "user_id": 2}) assert resp.status_code != 422 assert resp.status_code == 401 assert resp.json()["detail"] == "Must be super admin" @@ -68,10 +56,7 @@ async def test_post_service_auth_su(no_su_client: AsyncClient): @pytest.mark.anyio async def test_post_perm_success(no_su_client: AsyncClient, db_session): - resp = await no_su_client.post( - "/iam/permission", - json={"service_id": 1, "resource": "test_resource", "action": "create"}, - ) + resp = await no_su_client.post("/iam/permission", json={"service_id": 1, "resource": "test_resource", "action": "create"}) assert resp.status_code != 422 assert resp.status_code == 401 assert resp.json()["detail"] == "Must be super admin" diff --git a/test/test_auth_user.py b/test/test_auth_user.py index 3cd66d0..6f1c655 100644 --- a/test/test_auth_user.py +++ b/test/test_auth_user.py @@ -1,7 +1,6 @@ """ This testing module removes the testing user override to verify that endpoints with only the user requirement return a 401 error when not logged in """ - import pytest from httpx import AsyncClient diff --git a/test/test_healthcheck.py b/test/test_healthcheck.py index 47a3993..6fdb9be 100644 --- a/test/test_healthcheck.py +++ b/test/test_healthcheck.py @@ -4,7 +4,7 @@ from httpx import AsyncClient @pytest.mark.anyio async def test_healthcheck(default_client: AsyncClient): - resp = await default_client.get("/healthcheck") + resp = await default_client.get("/healthcheck") - assert resp.status_code == 200 - assert resp.json() == {"status": "ok"} + assert resp.status_code == 200 + assert resp.json() == {"status": "ok"} diff --git a/test/test_iam.py b/test/test_iam.py index 1599d7e..24a6478 100644 --- a/test/test_iam.py +++ b/test/test_iam.py @@ -1,5 +1,5 @@ -""" """ - +""" +""" import pytest from httpx import AsyncClient @@ -15,15 +15,13 @@ async def test_post_act_on_resource_endpoint_success(default_client: AsyncClient body = { "service": "Test Service", "organisation": "Test Org", - "resource": "test_resource", + "resource": "test_resource" } headers = { "Authorization": "Bearer not_checked_when_auth_is_disabled", - "X-API-Key": "123456789", + "X-API-Key": "123456789" } - resp = await default_client.post( - "/iam/can_act_on_resource?action=read", json=body, headers=headers - ) + resp = await default_client.post("/iam/can_act_on_resource?action=read", json=body, headers=headers) data = resp.json() assert resp.status_code == 200 @@ -32,20 +30,23 @@ async def test_post_act_on_resource_endpoint_success(default_client: AsyncClient @pytest.mark.parametrize( "service, api_key", - [("Test Service", "not_the_correct_key"), ("Test Service Two", "123456789")], + [ + ("Test Service", "not_the_correct_key"), + ("Test Service Two", "123456789") + ], ) @pytest.mark.anyio -async def test_act_on_resource_wrong_key( - default_client: AsyncClient, db_session, service: str, api_key: str -): - body = {"service": service, "organisation": "Test Org", "resource": "test_resource"} +async def test_act_on_resource_wrong_key(default_client: AsyncClient, db_session, service: str, api_key: str): + body = { + "service": service, + "organisation": "Test Org", + "resource": "test_resource" + } headers = { "Authorization": "Bearer not_checked_when_auth_is_disabled", - "X-API-Key": api_key, + "X-API-Key": api_key } - resp = await default_client.post( - "/iam/can_act_on_resource?action=read", json=body, headers=headers - ) + resp = await default_client.post("/iam/can_act_on_resource?action=read", json=body, headers=headers) data = resp.json() assert resp.status_code == 401 @@ -57,12 +58,12 @@ async def test_act_on_resource_missing_key(default_client: AsyncClient): body = { "service": "Test Service", "organisation": "Test Org", - "resource": "test_resource", + "resource": "test_resource" } - headers = {"Authorization": "Bearer not_checked_when_auth_is_disabled"} - resp = await default_client.post( - "/iam/can_act_on_resource?action=read", json=body, headers=headers - ) + headers = { + "Authorization": "Bearer not_checked_when_auth_is_disabled" + } + resp = await default_client.post("/iam/can_act_on_resource?action=read", json=body, headers=headers) data = resp.json() assert resp.status_code == 401 @@ -81,17 +82,18 @@ async def test_act_on_resource_missing_key(default_client: AsyncClient): ], ) @pytest.mark.anyio -async def test_act_on_resource_endpoint_status_checks( - default_client: AsyncClient, service, org, resource, action, expected_status: int -): - body = {"service": service, "organisation": org, "resource": resource} +async def test_act_on_resource_endpoint_status_checks(default_client: AsyncClient, service, org, resource, action, + expected_status: int): + body = { + "service": service, + "organisation": org, + "resource": resource + } headers = { "Authorization": "Bearer not_checked_when_auth_is_disabled", - "X-API-Key": "123456789", + "X-API-Key": "123456789" } - resp = await default_client.post( - f"/iam/can_act_on_resource?action={action}", json=body, headers=headers - ) + resp = await default_client.post(f"/iam/can_act_on_resource?action={action}", json=body, headers=headers) assert resp.status_code == expected_status @@ -106,23 +108,18 @@ async def test_act_on_resource_endpoint_status_checks( ], ) @pytest.mark.anyio -async def test_act_on_resource_logic( - default_client: AsyncClient, - db_session, - service, - org, - resource, - action, - expected_response: bool, -): - body = {"service": service, "organisation": org, "resource": resource} +async def test_act_on_resource_logic(default_client: AsyncClient, db_session, service, org, resource, action, + expected_response: bool): + body = { + "service": service, + "organisation": org, + "resource": resource + } headers = { "Authorization": "Bearer not_checked_when_auth_is_disabled", - "X-API-Key": "123456789", + "X-API-Key": "123456789" } - resp = await default_client.post( - f"/iam/can_act_on_resource?action={action}", json=body, headers=headers - ) + resp = await default_client.post(f"/iam/can_act_on_resource?action={action}", json=body, headers=headers) data = resp.json() assert resp.status_code == 200 @@ -143,12 +140,11 @@ async def test_get_group_permissions_success(default_client: AsyncClient): @pytest.mark.parametrize( - "query, expected_status", generate_query_and_status(["group_id", "org_id"]) + "query, expected_status", + generate_query_and_status(["group_id", "org_id"]) ) @pytest.mark.anyio -async def test_get_group_permissions_status_checks( - default_client: AsyncClient, db_session, query: str, expected_status: int -): +async def test_get_group_permissions_status_checks(default_client: AsyncClient, db_session, query: str, expected_status: int): resp = await default_client.get(f"/iam/group/permissions?{query}") assert resp.status_code == expected_status @@ -162,19 +158,8 @@ async def test_get_group_permissions_status_checks( ], ) @pytest.mark.anyio -async def test_get_group_permissions_mismatch( - default_client: AsyncClient, db_session, query: str -): - db_session.add( - Org( - name="Another Test Org", - root_user_id=1, - billing_contact_id=1, - owner_contact_id=2, - security_contact_id=3, - status="approved", - ) - ) +async def test_get_group_permissions_mismatch(default_client: AsyncClient, db_session, query: str): + db_session.add(Org(name="Another Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=2, security_contact_id=3, status="approved")) db_session.add(Group(name="Another Test Group", org_id=2)) db_session.flush() resp = await default_client.get(f"/iam/group/permissions?{query}") @@ -198,12 +183,11 @@ async def test_get_group_users_success(default_client: AsyncClient): @pytest.mark.parametrize( - "query, expected_status", generate_query_and_status(["group_id", "org_id"]) + "query, expected_status", + generate_query_and_status(["group_id", "org_id"]) ) @pytest.mark.anyio -async def test_get_group_users_status_checks( - default_client: AsyncClient, query: str, expected_status: int -): +async def test_get_group_users_status_checks(default_client: AsyncClient, query: str, expected_status: int): resp = await default_client.get(f"/iam/group/users?{query}") assert resp.status_code == expected_status @@ -217,19 +201,8 @@ async def test_get_group_users_status_checks( ], ) @pytest.mark.anyio -async def test_get_group_users_mismatch( - default_client: AsyncClient, db_session, query: str -): - db_session.add( - Org( - name="Another Test Org", - root_user_id=1, - billing_contact_id=1, - owner_contact_id=2, - security_contact_id=3, - status="approved", - ) - ) +async def test_get_group_users_mismatch(default_client: AsyncClient, db_session, query: str): + db_session.add(Org(name="Another Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=2, security_contact_id=3, status="approved")) db_session.add(Group(name="Another Test Group", org_id=2)) db_session.flush() resp = await default_client.get(f"/iam/group/users?{query}") @@ -240,9 +213,7 @@ async def test_get_group_users_mismatch( @pytest.mark.anyio async def test_post_group_success(default_client: AsyncClient): - resp = await default_client.post( - "/iam/group", json={"name": "New Group", "organisation_id": 1} - ) + resp = await default_client.post("/iam/group", json={"name": "New Group", "organisation_id": 1}) data = resp.json() assert resp.status_code == 200 @@ -255,23 +226,10 @@ async def test_post_group_success(default_client: AsyncClient): @pytest.mark.parametrize( "body, expected_status", [ - ({"organisation_id": 1, "name": "Test Group"}, 409), - ( - {"organisation_id": 2, "name": "new group"}, - 404, - ), # Non-existent organisation, valid name - ( - {"organisation_id": "banana", "name": "new group"}, - 422, - ), # Invalid organisation ID, valid name - ( - {"organisation_id": "", "name": "new group"}, - 422, - ), # Blank organisation ID, valid name - ( - {"organisation_id": -1, "name": "new group"}, - 422, - ), # Negative organisation ID, valid name + ({"organisation_id": 2, "name": "new group"}, 404), # Non-existent organisation, valid name + ({"organisation_id": "banana", "name": "new group"}, 422), # Invalid organisation ID, valid name + ({"organisation_id": "", "name": "new group"}, 422), # Blank organisation ID, valid name + ({"organisation_id": -1, "name": "new group"}, 422), # Negative organisation ID, valid name ({"name": 1}, 422), # Only name ({}, 422), # Blank body ({"organisation_id": "", "name": ""}, 422), # Both blank @@ -282,9 +240,7 @@ async def test_post_group_success(default_client: AsyncClient): ], ) @pytest.mark.anyio -async def test_post_group_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int -): +async def test_post_group_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): resp = await default_client.post("/iam/group", json=body) assert resp.status_code == expected_status @@ -294,10 +250,7 @@ async def test_post_group_status_checks( async def test_put_group_perm_success(default_client: AsyncClient, db_session): db_session.add(Group(name="Test Group Two", org_id=1)) db_session.flush() - resp = await default_client.put( - "/iam/group/permission", - json={"permission_id": 1, "group_id": 2, "organisation_id": 1}, - ) + resp = await default_client.put("/iam/group/permission", json={"permission_id": 1, "group_id": 2, "organisation_id": 1}) data = resp.json() assert resp.status_code == 200 @@ -316,71 +269,36 @@ async def test_put_group_perm_success(default_client: AsyncClient, db_session): @pytest.mark.parametrize( "body, expected_status", [ - ( - {"organisation_id": 42, "group_id": 1, "permission_id": 1}, - 404, - ), # Non-existent organisation - ( - {"organisation_id": "banana", "group_id": 1, "permission_id": 1}, - 422, - ), # Invalid organisation ID - ( - {"organisation_id": "", "group_id": 1, "permission_id": 1}, - 422, - ), # Blank organisation ID - ( - {"organisation_id": -1, "group_id": 1, "permission_id": 1}, - 422, - ), # Negative organisation ID - ( - {"organisation_id": 1, "group_id": 42, "permission_id": 1}, - 404, - ), # Non-existent group - ( - {"organisation_id": 1, "group_id": "banana", "permission_id": 1}, - 422, - ), # Invalid group ID - ( - {"organisation_id": 1, "group_id": "", "permission_id": 1}, - 422, - ), # Blank group ID - ( - {"organisation_id": 1, "group_id": -1, "permission_id": 1}, - 422, - ), # Negative group ID - ( - {"organisation_id": 1, "group_id": 1, "permission_id": 42}, - 404, - ), # Non-existent permission - ( - {"organisation_id": 1, "group_id": 1, "permission_id": "banana"}, - 422, - ), # Invalid permission ID - ( - {"organisation_id": 1, "group_id": 1, "permission_id": ""}, - 422, - ), # Blank permission ID - ( - {"organisation_id": 1, "group_id": 1, "permission_id": -1}, - 422, - ), # Negative permission ID + ({"organisation_id": 42, "group_id": 1, "permission_id": 1}, 404), # Non-existent organisation + ({"organisation_id": "banana", "group_id": 1, "permission_id": 1}, 422), # Invalid organisation ID + ({"organisation_id": "", "group_id": 1, "permission_id": 1}, 422), # Blank organisation ID + ({"organisation_id": -1, "group_id": 1, "permission_id": 1}, 422), # Negative organisation ID + + ({"organisation_id": 1, "group_id": 42, "permission_id": 1}, 404), # Non-existent group + ({"organisation_id": 1, "group_id": "banana", "permission_id": 1}, 422), # Invalid group ID + ({"organisation_id": 1, "group_id": "", "permission_id": 1}, 422), # Blank group ID + ({"organisation_id": 1, "group_id": -1, "permission_id": 1}, 422), # Negative group ID + + ({"organisation_id": 1, "group_id": 1, "permission_id": 42}, 404), # Non-existent permission + ({"organisation_id": 1, "group_id": 1, "permission_id": "banana"}, 422), # Invalid permission ID + ({"organisation_id": 1, "group_id": 1, "permission_id": ""}, 422), # Blank permission ID + ({"organisation_id": 1, "group_id": 1, "permission_id": -1}, 422), # Negative permission ID + ({}, 422), # Blank body ({"permission_id": 1}, 422), # Only permission ({"organisation_id": 1}, 422), # Only organisation ({"group_id": 1}, 422), # Only group + ({"organisation_id": 1, "permission_id": 1}, 422), # Missing group ({"group_id": 1, "permission_id": 1}, 422), # Missing organisation ({"organisation_id": 1, "group_id": 1}, 422), # Missing permission - ( - {"organisation_id": 1, "group_id": 1, "permission_id": 1}, - 409, - ), # Permission already in group + + ({"organisation_id": 1, "group_id": 1, "permission_id": 1}, 409), # Permission already in group + ], ) @pytest.mark.anyio -async def test_put_group_perm_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int -): +async def test_put_group_perm_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): resp = await default_client.put("/iam/group/permission", json=body) assert resp.status_code == expected_status @@ -394,19 +312,8 @@ async def test_put_group_perm_status_checks( ], ) @pytest.mark.anyio -async def test_put_group_perm_mismatch( - default_client: AsyncClient, db_session, body: dict -): - db_session.add( - Org( - name="Another Test Org", - root_user_id=1, - billing_contact_id=1, - owner_contact_id=2, - security_contact_id=3, - status="approved", - ) - ) +async def test_put_group_perm_mismatch(default_client: AsyncClient, db_session, body: dict): + db_session.add(Org(name="Another Test Org", root_user_id=1, billing_contact_id=1, owner_contact_id=2, security_contact_id=3, status="approved")) db_session.add(Group(name="Another Test Group", org_id=2)) db_session.flush() resp = await default_client.put("/iam/group/permission", json=body) @@ -417,19 +324,10 @@ async def test_put_group_perm_mismatch( @pytest.mark.anyio async def test_put_group_user_success(default_client: AsyncClient, db_session): - db_session.add( - User( - email="user@test.org", - first_name="User", - last_name="Test", - oidc_id="abcd-efgh-ijkl-1234", - ) - ) + db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) db_session.flush() - resp = await default_client.put( - "/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1} - ) + resp = await default_client.put("/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1}) data = resp.json() assert resp.status_code == 200 @@ -449,58 +347,34 @@ async def test_put_group_user_success(default_client: AsyncClient, db_session): @pytest.mark.parametrize( "body, expected_status", [ - ( - {"organisation_id": 42, "group_id": 1, "user_id": 1}, - 404, - ), # Non-existent organisation - ( - {"organisation_id": "banana", "group_id": 1, "user_id": 1}, - 422, - ), # Invalid organisation ID - ( - {"organisation_id": "", "group_id": 1, "user_id": 1}, - 422, - ), # Blank organisation ID - ( - {"organisation_id": -1, "group_id": 1, "user_id": 1}, - 422, - ), # Negative organisation ID - ( - {"organisation_id": 1, "group_id": 42, "user_id": 1}, - 404, - ), # Non-existent group - ( - {"organisation_id": 1, "group_id": "banana", "user_id": 1}, - 422, - ), # Invalid group ID + ({"organisation_id": 42, "group_id": 1, "user_id": 1}, 404), # Non-existent organisation + ({"organisation_id": "banana", "group_id": 1, "user_id": 1}, 422), # Invalid organisation ID + ({"organisation_id": "", "group_id": 1, "user_id": 1}, 422), # Blank organisation ID + ({"organisation_id": -1, "group_id": 1, "user_id": 1}, 422), # Negative organisation ID + + ({"organisation_id": 1, "group_id": 42, "user_id": 1}, 404), # Non-existent group + ({"organisation_id": 1, "group_id": "banana", "user_id": 1}, 422), # Invalid group ID ({"organisation_id": 1, "group_id": "", "user_id": 1}, 422), # Blank group ID - ( - {"organisation_id": 1, "group_id": -1, "user_id": 1}, - 422, - ), # Negative group ID - ( - {"organisation_id": 1, "group_id": 1, "user_id": 42}, - 404, - ), # Non-existent user - ( - {"organisation_id": 1, "group_id": 1, "user_id": "banana"}, - 422, - ), # Invalid user ID + ({"organisation_id": 1, "group_id": -1, "user_id": 1}, 422), # Negative group ID + + ({"organisation_id": 1, "group_id": 1, "user_id": 42}, 404), # Non-existent user + ({"organisation_id": 1, "group_id": 1, "user_id": "banana"}, 422), # Invalid user ID ({"organisation_id": 1, "group_id": 1, "user_id": ""}, 422), # Blank user ID ({"organisation_id": 1, "group_id": 1, "user_id": -1}, 422), # Negative user ID + ({}, 422), # Blank body ({"user_id": 1}, 422), # Only user ({"organisation_id": 1}, 422), # Only organisation ({"group_id": 1}, 422), # Only group + ({"organisation_id": 1, "user_id": 1}, 422), # Missing group ({"group_id": 1, "user_id": 1}, 422), # Missing organisation ({"organisation_id": 1, "group_id": 1}, 422), # Missing user + ], ) @pytest.mark.anyio -async def test_put_group_user_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int -): +async def test_put_group_user_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): resp = await default_client.put("/iam/group/user", json=body) assert resp.status_code == expected_status @@ -520,12 +394,11 @@ async def test_get_permissions_success(default_client: AsyncClient): @pytest.mark.parametrize( - "query, expected_status", generate_query_and_status(["org_id"]) + "query, expected_status", + generate_query_and_status(["org_id"]) ) @pytest.mark.anyio -async def test_get_permissions_status_checks( - default_client: AsyncClient, query: str, expected_status: int -): +async def test_get_permissions_status_checks(default_client: AsyncClient, query: str, expected_status: int): resp = await default_client.get(f"/iam/permissions?{query}") assert resp.status_code == expected_status @@ -533,10 +406,7 @@ async def test_get_permissions_status_checks( @pytest.mark.anyio async def test_post_perm_success(default_client: AsyncClient, db_session): - resp = await default_client.post( - "/iam/permission", - json={"service_id": 1, "resource": "test_resource", "action": "create"}, - ) + resp = await default_client.post("/iam/permission", json={"service_id": 1, "resource": "test_resource", "action": "create"}) data = resp.json() assert resp.status_code == 200 @@ -547,74 +417,51 @@ async def test_post_perm_success(default_client: AsyncClient, db_session): @pytest.mark.parametrize( - "body, expected_status", - [ - ( - {"service_id": 1, "resource": "test_resource", "action": "read"}, - 409, - ), - # service_id tests - ( - {"service_id": 42, "resource": "test_resource", "action": "read"}, - 404, - ), # Non-existent service - ( - {"service_id": "banana", "resource": "test_resource", "action": "read"}, - 422, - ), # Invalid service ID - ( - {"service_id": "", "resource": "test_resource", "action": "read"}, - 422, - ), # Blank service ID - ( - {"service_id": -1, "resource": "test_resource", "action": "read"}, - 422, - ), # Negative service ID - # resource tests - ( - {"service_id": 1, "resource": 42, "action": "read"}, - 422, - ), # Invalid resource type - # action tests - ( - {"service_id": 1, "resource": "test_resource", "action": 42}, - 422, - ), # Invalid action type - # missing/partial body tests - ({}, 422), # Blank body - ({"resource": "test_resource"}, 422), # Only resource - ({"action": "read"}, 422), # Only action - ({"service_id": 1}, 422), # Only service - ({"service_id": 1, "action": "read"}, 422), # Missing resource - ({"service_id": 1, "resource": "test_resource"}, 422), # Missing action - ({"resource": "test_resource", "action": "read"}, 422), # Missing service - ], + "body, expected_status", + [ + # service_id tests + ({"service_id": 42, "resource": "test_resource", "action": "read"}, 404), # Non-existent service + ({"service_id": "banana", "resource": "test_resource", "action": "read"}, 422), # Invalid service ID + ({"service_id": "", "resource": "test_resource", "action": "read"}, 422), # Blank service ID + ({"service_id": -1, "resource": "test_resource", "action": "read"}, 422), # Negative service ID + + # resource tests + ({"service_id": 1, "resource": 42, "action": "read"}, 422), # Invalid resource type + + # action tests + ({"service_id": 1, "resource": "test_resource", "action": 42}, 422), # Invalid action type + + # missing/partial body tests + ({}, 422), # Blank body + ({"resource": "test_resource"}, 422), # Only resource + ({"action": "read"}, 422), # Only action + ({"service_id": 1}, 422), # Only service + + ({"service_id": 1, "action": "read"}, 422), # Missing resource + ({"service_id": 1, "resource": "test_resource"}, 422), # Missing action + ({"resource": "test_resource", "action": "read"}, 422), # Missing service + + ], ) @pytest.mark.anyio -async def test_post_perm_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int -): +async def test_post_perm_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): resp = await default_client.post("/iam/permission", json=body) assert resp.status_code == expected_status @pytest.mark.parametrize( - "body", - [ - { - "organisation_id": 1, - "service_id": 1, - "resource": "test_resource", - "action": "read", - }, - {"organisation_id": 1, "service_id": 1}, - {"organisation_id": 1, "resource": "test_resource"}, - {"organisation_id": 1, "action": "read"}, - {"organisation_id": 1, "service_id": 1, "action": "read"}, - {"organisation_id": 1, "service_id": 1, "resource": "test_resource"}, - {"organisation_id": 1, "resource": "test_resource", "action": "read"}, - ], + "body", + [ + {"organisation_id": 1, "service_id": 1, "resource": "test_resource", "action": "read"}, + + {"organisation_id": 1, "service_id": 1}, + {"organisation_id": 1, "resource": "test_resource"}, + {"organisation_id": 1, "action": "read"}, + {"organisation_id": 1, "service_id": 1, "action": "read"}, + {"organisation_id": 1, "service_id": 1, "resource": "test_resource"}, + {"organisation_id": 1, "resource": "test_resource", "action": "read"}, + ], ) @pytest.mark.anyio async def test_post_perm_search_success(default_client: AsyncClient, db_session, body): @@ -630,133 +477,31 @@ async def test_post_perm_search_success(default_client: AsyncClient, db_session, @pytest.mark.parametrize( - "body, expected_status", - [ - # organisation_id tests - ( - { - "organisation_id": 42, - "service_id": 1, - "resource": "test_resource", - "action": "read", - }, - 404, - ), # Non-existent organisation - ( - { - "organisation_id": "banana", - "service_id": 1, - "resource": "test_resource", - "action": "read", - }, - 422, - ), # Invalid organisation ID - ( - { - "organisation_id": "", - "service_id": 1, - "resource": "test_resource", - "action": "read", - }, - 422, - ), # Blank organisation ID - ( - { - "organisation_id": -1, - "service_id": 1, - "resource": "test_resource", - "action": "read", - }, - 422, - ), # Negative organisation ID - # service_id tests - ( - { - "organisation_id": 1, - "service_id": "banana", - "resource": "test_resource", - "action": "read", - }, - 422, - ), # Invalid service ID - ( - { - "organisation_id": 1, - "service_id": "", - "resource": "test_resource", - "action": "read", - }, - 422, - ), # Blank service ID - ( - { - "organisation_id": 1, - "service_id": -1, - "resource": "test_resource", - "action": "read", - }, - 422, - ), # Negative service ID - # resource tests - ( - {"organisation_id": 1, "service_id": 1, "resource": 42, "action": "read"}, - 422, - ), # Invalid resource type - # action tests - ( - { - "organisation_id": 1, - "service_id": 1, - "resource": "test_resource", - "action": 42, - }, - 422, - ), # Invalid action type - # missing/partial body tests - ({}, 422), # Blank body - ], + "body, expected_status", + [ + # organisation_id tests + ({"organisation_id": 42, "service_id": 1, "resource": "test_resource", "action": "read"}, 404), # Non-existent organisation + ({"organisation_id": "banana", "service_id": 1, "resource": "test_resource", "action": "read"}, 422), # Invalid organisation ID + ({"organisation_id": "", "service_id": 1, "resource": "test_resource", "action": "read"}, 422), # Blank organisation ID + ({"organisation_id": -1, "service_id": 1, "resource": "test_resource", "action": "read"}, 422), # Negative organisation ID + + # service_id tests + ({"organisation_id": 1, "service_id": "banana", "resource": "test_resource", "action": "read"}, 422), # Invalid service ID + ({"organisation_id": 1, "service_id": "", "resource": "test_resource", "action": "read"}, 422), # Blank service ID + ({"organisation_id": 1, "service_id": -1, "resource": "test_resource", "action": "read"}, 422), # Negative service ID + + # resource tests + ({"organisation_id": 1, "service_id": 1, "resource": 42, "action": "read"}, 422), # Invalid resource type + + # action tests + ({"organisation_id": 1, "service_id": 1, "resource": "test_resource", "action": 42}, 422), # Invalid action type + + # missing/partial body tests + ({}, 422), # Blank body + ], ) @pytest.mark.anyio -async def test_post_perm_search_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int -): +async def test_post_perm_search_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): resp = await default_client.post("/iam/permissions/search", json=body) assert resp.status_code == expected_status - - -@pytest.mark.anyio -async def test_delete_group_permissions_success(default_client: AsyncClient): - resp = await default_client.delete( - "/iam/group/permissions?org_id=1&group_id=1&perm_id=1" - ) - data = resp.json() - - assert resp.status_code == 200 - assert "permissions" in data - assert isinstance(data["permissions"], list) - assert len(data["permissions"]) == 0 - assert "group" in data - assert data["group"]["id"] == 1 - assert data["group"]["name"] == "Test Group" - - -@pytest.mark.anyio -async def test_delete_permissions_success(default_client: AsyncClient): - resp = await default_client.delete("/iam/permission?perm_id=1") - - assert resp.status_code == 204 - - -@pytest.mark.anyio -async def test_delete_group_users_success(default_client: AsyncClient): - resp = await default_client.delete("/iam/group/user?org_id=1&group_id=1&user_id=1") - data = resp.json() - - assert resp.status_code == 200 - assert "users" in data - assert isinstance(data["users"], list) - assert len(data["users"]) == 0 - assert "group" in data - assert data["group"]["id"] == 1 - assert data["group"]["name"] == "Test Group" diff --git a/test/test_organisation.py b/test/test_organisation.py index a8708ea..4b60aa7 100644 --- a/test/test_organisation.py +++ b/test/test_organisation.py @@ -1,7 +1,6 @@ """ [DELETE] /org/ is not tested because the testing client cannot attach a body to a delete request. """ - import pytest from httpx import AsyncClient @@ -25,12 +24,11 @@ async def test_get_org_success(default_client: AsyncClient): @pytest.mark.parametrize( - "query, expected_status", generate_query_and_status(["org_id"]) + "query, expected_status", + generate_query_and_status(["org_id"]) ) @pytest.mark.anyio -async def test_get_org_status_checks( - default_client: AsyncClient, query: str, expected_status: int -): +async def test_get_org_status_checks(default_client: AsyncClient, query: str, expected_status: int): resp = await default_client.get(f"/org?{query}") assert resp.status_code == expected_status @@ -49,40 +47,24 @@ async def test_post_org_success(default_client: AsyncClient): @pytest.mark.parametrize( "body, expected_status", [ - ({"name": "Test Org"}, 409), ({"name": 42}, 422), ({}, 422), ({"name": "New Test Org", "intake_questionnaire": {"question_one": 42}}, 422), ], ) @pytest.mark.anyio -async def test_post_org_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int -): +async def test_post_org_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): resp = await default_client.post("/org", json=body) assert resp.status_code == expected_status @pytest.mark.anyio -async def test_patch_org_questionnaire_partial_success( - default_client: AsyncClient, db_session -): +async def test_patch_org_questionnaire_partial_success(default_client: AsyncClient, db_session): org_model = db_session.get(Organisation, 1) org_model.status = "partial" db_session.flush() - resp = await default_client.patch( - "/org/questionnaire", - json={ - "organisation_id": 1, - "intake_questionnaire": { - "question_one": "new answer one", - "question_two": None, - "question_three": None, - }, - "partial": True, - }, - ) + resp = await default_client.patch("/org/questionnaire", json={"organisation_id": 1, "intake_questionnaire": {"question_one": "new answer one", "question_two": None, "question_three": None}, "partial": True}) data = resp.json() assert resp.status_code == 200 @@ -101,56 +83,24 @@ async def test_patch_org_questionnaire_partial_success( ({"organisation_id": "Test Org"}, 422), ({"organisation_id": ""}, 422), ({}, 422), - ( - { - "organisation_id": "1", - "intake_questionnaire": {"question_one": 42}, - "partial": True, - }, - 422, - ), - ( - {"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}}, - 422, - ), - ( - { - "organisation_id": "1", - "intake_questionnaire": {"question_one": "valid"}, - "partial": 42, - }, - 422, - ), + ({"organisation_id": "1", "intake_questionnaire": {"question_one": 42}, "partial": True}, 422), + ({"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}}, 422), + ({"organisation_id": "1", "intake_questionnaire": {"question_one": "valid"}, "partial": 42}, 422), ], ) @pytest.mark.anyio -async def test_patch_questionnaire_partial_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int -): +async def test_patch_questionnaire_partial_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): resp = await default_client.patch("/org/questionnaire", json=body) assert resp.status_code == expected_status @pytest.mark.anyio -async def test_patch_org_questionnaire_submit_success( - default_client: AsyncClient, db_session -): +async def test_patch_org_questionnaire_submit_success(default_client: AsyncClient, db_session): org_model = db_session.get(Organisation, 1) org_model.status = "partial" db_session.flush() - resp = await default_client.patch( - "/org/questionnaire", - json={ - "organisation_id": 1, - "intake_questionnaire": { - "question_one": "new answer one", - "question_two": None, - "question_three": None, - }, - "partial": False, - }, - ) + resp = await default_client.patch("/org/questionnaire", json={"organisation_id": 1, "intake_questionnaire": {"question_one": "new answer one", "question_two": None, "question_three": None}, "partial": False}) data = resp.json() assert resp.status_code == 200 @@ -163,13 +113,12 @@ async def test_patch_org_questionnaire_submit_success( @pytest.mark.parametrize( - "status", ["partial", "submitted", "remediation", "approved", "rejected", "removed"] + "status", + ["partial", "submitted", "remediation", "approved", "rejected", "removed"] ) @pytest.mark.anyio async def test_patch_org_status_success(default_client: AsyncClient, status: str): - resp = await default_client.patch( - "/org/status", json={"organisation_id": 1, "status": status} - ) + resp = await default_client.patch("/org/status", json={"organisation_id": 1, "status": status}) data = resp.json() assert resp.status_code == 200 @@ -189,9 +138,7 @@ async def test_patch_org_status_success(default_client: AsyncClient, status: str ], ) @pytest.mark.anyio -async def test_patch_org_status_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int -): +async def test_patch_org_status_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): resp = await default_client.patch("/org/status", json=body) assert resp.status_code == expected_status @@ -214,12 +161,11 @@ async def test_get_org_users_success(default_client: AsyncClient): @pytest.mark.parametrize( - "query, expected_status", generate_query_and_status(["org_id"]) + "query, expected_status", + generate_query_and_status(["org_id"]) ) @pytest.mark.anyio -async def test_get_org_users_status_checks( - default_client: AsyncClient, query: str, expected_status: int -): +async def test_get_org_users_status_checks(default_client: AsyncClient, query: str, expected_status: int): resp = await default_client.get(f"/org/users?{query}") assert resp.status_code == expected_status @@ -227,19 +173,10 @@ async def test_get_org_users_status_checks( @pytest.mark.anyio async def test_post_org_user_success(default_client: AsyncClient, db_session): - db_session.add( - User( - email="user@test.org", - first_name="User", - last_name="Test", - oidc_id="abcd-efgh-ijkl-1234", - ) - ) + db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) db_session.flush() - resp = await default_client.post( - "/org/user", json={"organisation_id": 1, "user_id": 2} - ) + resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 2}) data = resp.json() assert resp.status_code == 200 @@ -260,17 +197,8 @@ async def test_post_org_user_success(default_client: AsyncClient, db_session): ], ) @pytest.mark.anyio -async def test_post_org_user_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int, db_session -): - db_session.add( - User( - email="user@test.org", - first_name="User", - last_name="Test", - oidc_id="abcd-efgh-ijkl-1234", - ) - ) +async def test_post_org_user_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int, db_session): + db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) db_session.flush() resp = await default_client.post("/org/user", json=body) @@ -280,21 +208,12 @@ async def test_post_org_user_status_checks( @pytest.mark.anyio async def test_patch_org_root_user_success(default_client: AsyncClient, db_session): - db_session.add( - User( - email="user@test.org", - first_name="User", - last_name="Test", - oidc_id="abcd-efgh-ijkl-1234", - ) - ) + db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) db_session.flush() db_session.add(OrgUsers(org_id=1, user_id=2)) db_session.flush() - resp = await default_client.patch( - "/org/root_user", json={"organisation_id": 1, "user_id": 2} - ) + resp = await default_client.patch("/org/root_user", json={"organisation_id": 1, "user_id": 2}) data = resp.json() assert resp.status_code == 200 @@ -315,17 +234,8 @@ async def test_patch_org_root_user_success(default_client: AsyncClient, db_sessi ], ) @pytest.mark.anyio -async def test_patch_root_user_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int, db_session -): - db_session.add( - User( - email="user@test.org", - first_name="User", - last_name="Test", - oidc_id="abcd-efgh-ijkl-1234", - ) - ) +async def test_patch_root_user_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int, db_session): + db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) db_session.flush() db_session.add(OrgUsers(org_id=1, user_id=2)) db_session.flush() @@ -337,19 +247,10 @@ async def test_patch_root_user_status_checks( @pytest.mark.anyio async def test_patch_org_root_user_non_member(default_client: AsyncClient, db_session): - db_session.add( - User( - email="user@test.org", - first_name="User", - last_name="Test", - oidc_id="abcd-efgh-ijkl-1234", - ) - ) + db_session.add(User(email="user@test.org", first_name="User", last_name="Test", oidc_id="abcd-efgh-ijkl-1234")) db_session.flush() - resp = await default_client.patch( - "/org/root_user", json={"organisation_id": 1, "user_id": 2} - ) + resp = await default_client.patch("/org/root_user", json={"organisation_id": 1, "user_id": 2}) data = resp.json() assert resp.status_code == 422 @@ -368,23 +269,23 @@ async def test_get_org_groups_success(default_client: AsyncClient): @pytest.mark.parametrize( - "query, expected_status", generate_query_and_status(["org_id"]) + "query, expected_status", + generate_query_and_status(["org_id"]) ) @pytest.mark.anyio -async def test_get_org_groups_status_checks( - default_client: AsyncClient, query: str, expected_status: int -): +async def test_get_org_groups_status_checks(default_client: AsyncClient, query: str, expected_status: int): resp = await default_client.get(f"/org/groups?{query}") assert resp.status_code == expected_status -@pytest.mark.parametrize("contact_type", ["billing", "security", "owner"]) +@pytest.mark.parametrize( + "contact_type", + ["billing", "security", "owner"] +) @pytest.mark.anyio async def test_get_org_contact_success(default_client: AsyncClient, contact_type: str): - resp = await default_client.get( - f"/org/contact?org_id=1&contact_type={contact_type}" - ) + resp = await default_client.get(f"/org/contact?org_id=1&contact_type={contact_type}") data = resp.json() assert resp.status_code == 200 @@ -426,9 +327,7 @@ async def test_get_org_contact_success(default_client: AsyncClient, contact_type ], ) @pytest.mark.anyio -async def test_get_org_contact_status_checks( - default_client: AsyncClient, query: str, expected_status: int -): +async def test_get_org_contact_status_checks(default_client: AsyncClient, query: str, expected_status: int): resp = await default_client.get(f"/org/contact?{query}") assert resp.status_code == expected_status @@ -449,16 +348,11 @@ async def test_get_org_contact_status_checks( ("address_region", "Glasgow City"), ("country_code", "GB"), ("postal_code", "G1 1AA"), - ], + ] ) @pytest.mark.anyio -async def test_patch_org_contact_success( - default_client: AsyncClient, key: str, value: str -): - resp = await default_client.patch( - "/org/contact", - json={"organisation_id": 1, "contact_type": "billing", key: value}, - ) +async def test_patch_org_contact_success(default_client: AsyncClient, key: str, value: str): + resp = await default_client.patch("/org/contact", json={"organisation_id": 1, "contact_type": "billing", key: value}) data = resp.json() assert resp.status_code == 200 @@ -485,32 +379,7 @@ async def test_patch_org_contact_success( ], ) @pytest.mark.anyio -async def test_patch_org_contact_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int -): +async def test_patch_org_contact_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): resp = await default_client.patch("/org/contact", json=body) assert resp.status_code == expected_status - - -@pytest.mark.anyio -async def test_delete_org_success(default_client: AsyncClient): - resp = await default_client.delete("/org?org_id=1") - - assert resp.status_code == 204 - - -@pytest.mark.anyio -async def test_delete_org_users_success(db_session, default_client: AsyncClient): - db_session.add( - User( - email="user@test.org", - first_name="User", - last_name="Test", - oidc_id="abcd-efgh-ijkl-1234", - ) - ) - db_session.flush() - resp = await default_client.delete("/org/user?org_id=1&user_id=2") - - assert resp.status_code == 204 diff --git a/test/test_service.py b/test/test_service.py index bd8fe88..c13eb80 100644 --- a/test/test_service.py +++ b/test/test_service.py @@ -1,7 +1,6 @@ """ 409 on [POST]/service/ not tested because SQLite throws a different error than Postgres """ - import pytest from httpx import AsyncClient @@ -20,12 +19,11 @@ async def test_get_services_success(default_client: AsyncClient): @pytest.mark.parametrize( - "query, expected_status", generate_query_and_status(["org_id"]) + "query, expected_status", + generate_query_and_status(["org_id"]) ) @pytest.mark.anyio -async def test_get_services_status_checks( - default_client: AsyncClient, query: str, expected_status: int -): +async def test_get_services_status_checks(default_client: AsyncClient, query: str, expected_status: int): resp = await default_client.get(f"/service/?{query}") assert resp.status_code == expected_status @@ -46,15 +44,12 @@ async def test_post_service_success(default_client: AsyncClient): @pytest.mark.parametrize( "body, expected_status", [ - ({"name": "Test Service"}, 409), ({"name": 42}, 422), ({}, 422), ], ) @pytest.mark.anyio -async def test_post_service_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int -): +async def test_post_services_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): resp = await default_client.post("/service/", json=body) assert resp.status_code == expected_status @@ -82,16 +77,7 @@ async def test_patch_service_success(default_client: AsyncClient): ], ) @pytest.mark.anyio -async def test_patch_services_status_checks( - default_client: AsyncClient, body: dict[str, str], expected_status: int -): +async def test_patch_services_status_checks(default_client: AsyncClient, body: dict[str, str], expected_status: int): resp = await default_client.patch("/service/key", json=body) assert resp.status_code == expected_status - - -@pytest.mark.anyio -async def test_delete_service_success(default_client: AsyncClient): - resp = await default_client.delete("/service/?service_id=1") - - assert resp.status_code == 204 diff --git a/test/test_user.py b/test/test_user.py index 4eadc3c..a50e654 100644 --- a/test/test_user.py +++ b/test/test_user.py @@ -8,47 +8,38 @@ from httpx import AsyncClient from .conftest import generate_query_and_status - @pytest.mark.anyio async def test_get_self_db_success(default_client: AsyncClient): - resp = await default_client.get("/user/self/db") - data = resp.json() + resp = await default_client.get("/user/self/db") + data = resp.json() - assert resp.status_code == 200 - assert data["first_name"] == "Admin" - assert data["last_name"] == "Test" - assert data["email"] == "admin@test.com" - assert "organisations" in data - assert "groups" in data + assert resp.status_code == 200 + assert data["first_name"] == "Admin" + assert data["last_name"] == "Test" + assert data["email"] == "admin@test.com" + assert "organisations" in data + assert "groups" in data @pytest.mark.anyio async def test_get_user_success(default_client: AsyncClient): - resp = await default_client.get("/user/?user_id=1") - data = resp.json() + resp = await default_client.get("/user/?user_id=1") + data = resp.json() - assert resp.status_code == 200 - assert data["first_name"] == "Admin" - assert data["last_name"] == "Test" - assert data["email"] == "admin@test.com" - assert "organisations" in data - assert "groups" in data + assert resp.status_code == 200 + assert data["first_name"] == "Admin" + assert data["last_name"] == "Test" + assert data["email"] == "admin@test.com" + assert "organisations" in data + assert "groups" in data @pytest.mark.anyio @pytest.mark.parametrize( - "query, expected_status", generate_query_and_status(["user_id"]) + "query, expected_status", + generate_query_and_status(["user_id"]) ) -async def test_get_user_status_checks( - default_client: AsyncClient, query: str, expected_status: int -): - resp = await default_client.get(f"/user/?{query}") +async def test_get_user_status_checks(default_client: AsyncClient, query: str, expected_status: int): + resp = await default_client.get(f"/user/?{query}") - assert resp.status_code == expected_status - - -@pytest.mark.anyio -async def test_delete_user_success(default_client: AsyncClient): - resp = await default_client.delete("/user/?user_id=1") - - assert resp.status_code == 204 + assert resp.status_code == expected_status