diff --git a/alembic.ini b/alembic.ini index ffe7e34..6a48b68 100644 --- a/alembic.ini +++ b/alembic.ini @@ -5,7 +5,7 @@ # this is typically a path given in POSIX (e.g. forward slashes) # format, relative to the token %(here)s which refers to the location of this # ini file -script_location = %(here)s/.alembic +script_location = %(here)s/alembic # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s # Uncomment the line below if you want the files to be prepended with date and time diff --git a/alembic/versions/2026-06-22_mapped_columns.py b/alembic/versions/2026-06-22_mapped_columns.py new file mode 100644 index 0000000..5410745 --- /dev/null +++ b/alembic/versions/2026-06-22_mapped_columns.py @@ -0,0 +1,100 @@ +"""mapped columns + +Revision ID: 869d48618a1c +Revises: 85edbf9a176c +Create Date: 2026-06-22 11:18:34.592199 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '869d48618a1c' +down_revision: Union[str, Sequence[str], None] = '85edbf9a176c' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('group', 'org_id', + existing_type=sa.INTEGER(), + nullable=False) + op.alter_column('organisation', 'name', + existing_type=sa.VARCHAR(), + nullable=False) + op.alter_column('organisation', 'status', + existing_type=sa.VARCHAR(), + nullable=False) + op.alter_column('organisation', 'root_user_id', + existing_type=sa.INTEGER(), + nullable=False) + op.drop_constraint(op.f('organisation_name_key'), 'organisation', type_='unique') + op.alter_column('permission', 'service_id', + existing_type=sa.INTEGER(), + nullable=False) + op.alter_column('service', 'name', + existing_type=sa.VARCHAR(), + nullable=False) + op.alter_column('service', 'api_key', + existing_type=sa.VARCHAR(), + nullable=False) + op.drop_constraint(op.f('service_api_key_key'), 'service', type_='unique') + op.alter_column('user', 'email', + existing_type=sa.VARCHAR(), + nullable=False) + op.alter_column('user', 'first_name', + existing_type=sa.VARCHAR(), + nullable=False) + op.alter_column('user', 'last_name', + existing_type=sa.VARCHAR(), + nullable=False) + op.alter_column('user', 'oidc_id', + existing_type=sa.VARCHAR(), + nullable=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('user', 'oidc_id', + existing_type=sa.VARCHAR(), + nullable=True) + op.alter_column('user', 'last_name', + existing_type=sa.VARCHAR(), + nullable=True) + op.alter_column('user', 'first_name', + existing_type=sa.VARCHAR(), + nullable=True) + op.alter_column('user', 'email', + existing_type=sa.VARCHAR(), + nullable=True) + op.create_unique_constraint(op.f('service_api_key_key'), 'service', ['api_key'], postgresql_nulls_not_distinct=False) + op.alter_column('service', 'api_key', + existing_type=sa.VARCHAR(), + nullable=True) + op.alter_column('service', 'name', + existing_type=sa.VARCHAR(), + nullable=True) + op.alter_column('permission', 'service_id', + existing_type=sa.INTEGER(), + nullable=True) + op.create_unique_constraint(op.f('organisation_name_key'), 'organisation', ['name'], postgresql_nulls_not_distinct=False) + op.alter_column('organisation', 'root_user_id', + existing_type=sa.INTEGER(), + nullable=True) + op.alter_column('organisation', 'status', + existing_type=sa.VARCHAR(), + nullable=True) + op.alter_column('organisation', 'name', + existing_type=sa.VARCHAR(), + nullable=True) + op.alter_column('group', 'org_id', + existing_type=sa.INTEGER(), + nullable=True) + # ### end Alembic commands ### diff --git a/pyproject.toml b/pyproject.toml index 0ae2891..08a9729 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,3 +43,6 @@ dev = [ "pytest>=9.0.3", "ty>=0.0.44,<0.0.45", ] + +[tool.ty.src] +exclude = ["alembic"] diff --git a/src/auth/dependencies.py b/src/auth/dependencies.py index aaabdbb..7cf4e9f 100644 --- a/src/auth/dependencies.py +++ b/src/auth/dependencies.py @@ -34,6 +34,33 @@ async def org_query_user_claims( org_query_user_claims_dependency = Annotated[bool, Depends(org_query_user_claims)] +def get_super_admin_list(): + return [] + + +def empty_su_list(): + return [] + + +def testing_su_list(): + return ["admin@test.com"] + + +su_list_dependency = Annotated[list[str | None], Depends(get_super_admin_list)] + + +async def user_model_super_admin( + user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency +): + if user_model.email in super_admin_emails: + return user_model + + raise ForbiddenException(message="Must be super admin") + + +super_admin_dependency = Annotated[User, Depends(user_model_super_admin)] + + async def org_query_root_claims( user_model: user_model_claims_dependency, org_model: org_model_query_dependency, @@ -54,9 +81,7 @@ async def org_query_root_claims( raise ForbiddenException(message="Must be the org's root user") -org_model_root_claim_query_dependency = Annotated[ - type[Org], Depends(org_query_root_claims) -] +org_model_root_claim_query_dependency = Annotated[Org, Depends(org_query_root_claims)] async def org_body_root_claims( @@ -79,33 +104,4 @@ async def org_body_root_claims( raise ForbiddenException(message="Must be the org's root user") -org_model_root_claim_body_dependency = Annotated[ - type[Org], Depends(org_body_root_claims) -] - - -def get_super_admin_list(): - return [] - - -def empty_su_list(): - return [] - - -def testing_su_list(): - return ["admin@test.com"] - - -su_list_dependency = Annotated[list[User], Depends(get_super_admin_list)] - - -async def user_model_super_admin( - user_model: user_model_claims_dependency, super_admin_emails: su_list_dependency -): - if user_model.email in super_admin_emails: - return user_model - - raise ForbiddenException(message="Must be super admin") - - -super_admin_dependency = Annotated[type[User], Depends(user_model_super_admin)] +org_model_root_claim_body_dependency = Annotated[Org, Depends(org_body_root_claims)] diff --git a/src/auth/service.py b/src/auth/service.py index aa1b060..1b90b8c 100644 --- a/src/auth/service.py +++ b/src/auth/service.py @@ -43,14 +43,11 @@ async def get_current_user( key_response = requests.get(jwks_uri) jwk_keys = KeySet.import_key_set(key_response.json()) - claims_options = { - "exp": {"essential": True}, - "iss": {"essential": True, "value": auth_settings.OIDC_ISSUER}, - } - token = jwt.decode(oidc_auth_string.replace("Bearer ", ""), jwk_keys) - claims_requests = jwt.JWTClaimsRegistry(**claims_options) + claims_requests = jwt.JWTClaimsRegistry( + exp={"essential": True}, iss={"essential": True, "value": auth_settings.OIDC_ISSUER} + ) try: claims_requests.validate(token.claims) diff --git a/src/config.py b/src/config.py index ddce0c8..e8b9ec2 100644 --- a/src/config.py +++ b/src/config.py @@ -24,7 +24,7 @@ class CustomBaseSettings(BaseSettings): class Config(CustomBaseSettings): APP_VERSION: str = "0.1" ENVIRONMENT: Environment = Environment.PRODUCTION - SECRET_KEY: SecretStr = "" + SECRET_KEY: SecretStr = SecretStr("") DISABLE_AUTH: bool = False CORS_ORIGINS: list[str] = ["*"] @@ -34,7 +34,7 @@ class Config(CustomBaseSettings): DATABASE_NAME: str = "fastapi-exp" DATABASE_PORT: str = "5432" DATABASE_HOSTNAME: str = "localhost" - DATABASE_CREDENTIALS: SecretStr = ":" + DATABASE_CREDENTIALS: SecretStr = SecretStr(":") settings = Config() @@ -44,9 +44,9 @@ DATABASE_PORT = settings.DATABASE_PORT DATABASE_HOSTNAME = settings.DATABASE_HOSTNAME DATABASE_CREDENTIALS = settings.DATABASE_CREDENTIALS.get_secret_value() # this will support special chars for credentials -_DATABASE_CREDENTIAL_USER, _DATABASE_CREDENTIAL_PASSWORD = str( - DATABASE_CREDENTIALS -).split(":") +_DATABASE_CREDENTIAL_USER, _DATABASE_CREDENTIAL_PASSWORD = str(DATABASE_CREDENTIALS).split( + ":" +) _QUOTED_DATABASE_PASSWORD = parse.quote_plus(str(_DATABASE_CREDENTIAL_PASSWORD)) SQLALCHEMY_DATABASE_URI = SecretStr( diff --git a/src/constants.py b/src/constants.py index b0725bf..f9237e4 100644 --- a/src/constants.py +++ b/src/constants.py @@ -13,10 +13,10 @@ class Environment(StrEnum): Enumeration of environments. Attributes: - LOCAL (str): Application is running locally - TESTING (str): Application is running in testing mode - STAGING (str): Application is running in staging mode (ie not testing) - PRODUCTION (str): Application is running in production mode + LOCAL (str): Application is running locally + TESTING (str): Application is running in testing mode + STAGING (str): Application is running in staging mode (ie not testing) + PRODUCTION (str): Application is running in production mode """ LOCAL = auto() diff --git a/src/contact/models.py b/src/contact/models.py index 9d0fdba..ca359cd 100644 --- a/src/contact/models.py +++ b/src/contact/models.py @@ -5,6 +5,7 @@ Models: - Contact: id[pk], email, first_name, last_name, phonenumber, vat_number street_address, street_address_line_2, post_office_box_number, address_locality, country_code, address_region, postal_code """ + from sqlalchemy import ForeignKey from sqlalchemy.orm import mapped_column, Mapped @@ -15,19 +16,19 @@ class Contact(CustomBase): __tablename__ = "contact" id: Mapped[int] = mapped_column(primary_key=True) - email: Mapped[str] - first_name: Mapped[str] - last_name: Mapped[str] - phonenumber: Mapped[str] - vat_number: Mapped[str | None] = mapped_column(default=None) + email: Mapped[str] = mapped_column(default=None, nullable=True) + first_name: Mapped[str] = mapped_column(default=None, nullable=True) + last_name: Mapped[str] = mapped_column(default=None, nullable=True) + phonenumber: Mapped[str] = mapped_column(default=None, nullable=True) + vat_number: Mapped[str | None] = mapped_column(default=None, nullable=True) - street_address : Mapped[str] - street_address_line_2 : Mapped[str] - post_office_box_number: Mapped[str | None] = mapped_column(default=None) - locality : Mapped[str] # Ie City - country_code : Mapped[str] # Eg GB - address_region: Mapped[str | None] = mapped_column(default=None) - postal_code : Mapped[str] + street_address: Mapped[str] = mapped_column(default=None, nullable=True) + street_address_line_2: Mapped[str] = mapped_column(default=None, nullable=True) + post_office_box_number: Mapped[str | None] = mapped_column(default=None, nullable=True) + locality: Mapped[str] = mapped_column(default=None, nullable=True) # Ie City + country_code: Mapped[str] = mapped_column(default=None, nullable=True) # Eg GB + address_region: Mapped[str | None] = mapped_column(default=None, nullable=True) + postal_code: Mapped[str] = mapped_column(default=None, nullable=True) org_id: Mapped[int] = mapped_column( ForeignKey("organisation.id", ondelete="CASCADE"), nullable=False diff --git a/src/database.py b/src/database.py index 8038905..fb29a41 100644 --- a/src/database.py +++ b/src/database.py @@ -5,6 +5,7 @@ Exports: - db_dependency - Base (sqlalchemy base model) """ + from typing import Annotated from sqlalchemy import create_engine, StaticPool from sqlalchemy.orm import sessionmaker, Session diff --git a/src/iam/dependencies.py b/src/iam/dependencies.py index 6ebc4e5..14fb4ad 100644 --- a/src/iam/dependencies.py +++ b/src/iam/dependencies.py @@ -20,7 +20,7 @@ from src.iam.schemas import GroupIDMixin, PermIDMixin def get_group_model_query( db: db_dependency, group_id: Annotated[int, Query(gt=0)] -) -> type[Group]: +) -> Group: group_model = db.get(Group, group_id) if group_model is None: raise GroupNotFoundException(group_id) @@ -28,12 +28,12 @@ def get_group_model_query( return group_model -group_model_query_dependency = Annotated[type[Group], Depends(get_group_model_query)] +group_model_query_dependency = Annotated[Group, Depends(get_group_model_query)] def get_group_model_body( db: db_dependency, request_model: Optional[GroupIDMixin] = None -) -> type[Group]: +) -> Group: group_id = getattr(request_model, "group_id", None) if group_id is None: raise GroupNotFoundException() @@ -44,12 +44,12 @@ def get_group_model_body( return group_model -group_model_body_dependency = Annotated[type[Group], Depends(get_group_model_body)] +group_model_body_dependency = Annotated[Group, Depends(get_group_model_body)] def get_perm_model_body( db: db_dependency, request_model: Optional[PermIDMixin] = None -) -> type[Permission]: +) -> Permission: perm_id = getattr(request_model, "permission_id", None) if perm_id is None: raise PermNotFoundException @@ -60,12 +60,12 @@ def get_perm_model_body( return perm_model -perm_model_body_dependency = Annotated[type[Permission], Depends(get_perm_model_body)] +perm_model_body_dependency = Annotated[Permission, Depends(get_perm_model_body)] def get_perm_model_query( db: db_dependency, perm_id: Annotated[int, Query(gt=0)] -) -> type[Permission]: +) -> Permission: perm_model = db.get(Permission, perm_id) if perm_model is None: raise PermNotFoundException(perm_id) @@ -73,4 +73,4 @@ def get_perm_model_query( return perm_model -perm_model_query_dependency = Annotated[type[Permission], Depends(get_perm_model_query)] +perm_model_query_dependency = Annotated[Permission, Depends(get_perm_model_query)] diff --git a/src/iam/models.py b/src/iam/models.py index 1860264..1f6d9ba 100644 --- a/src/iam/models.py +++ b/src/iam/models.py @@ -43,7 +43,9 @@ class Permission(CustomBase): ) service_rel = relationship( - "Service", back_populates="permission_rel", foreign_keys="Permission.service_id" + "Service", + back_populates="permission_rel", + foreign_keys="Permission.service_id", ) group_rel = relationship( diff --git a/src/iam/router.py b/src/iam/router.py index fbe1c59..af783a1 100644 --- a/src/iam/router.py +++ b/src/iam/router.py @@ -207,9 +207,7 @@ async def can_act_on_resource( "content": { "application/json": { "examples": { - "db_id": { - "summary": "User not found in db when checking claims." - }, + "db_id": {"summary": "User not found in db when checking claims."}, "user_model": {"summary": "User model not found in db."}, "org_model": {"summary": "Org model not found in db."}, "group_model": {"summary": "Group model not found in db."}, @@ -268,9 +266,7 @@ async def get_group_users( status_code=status.HTTP_201_CREATED, response_model=IAMPostGroupResponse, responses={ - status.HTTP_409_CONFLICT: { - "description": "Group with this name already exists" - }, + status.HTTP_409_CONFLICT: {"description": "Group with this name already exists"}, }, ) async def create_group( @@ -568,9 +564,7 @@ async def permissions_search( ) if not (request_model.resource is None or request_model.resource == ""): - permission_query = permission_query.filter( - Perm.resource == request_model.resource - ) + permission_query = permission_query.filter(Perm.resource == request_model.resource) if not (request_model.action is None or request_model.action == ""): permission_query = permission_query.filter(Perm.action == request_model.action) @@ -633,9 +627,7 @@ async def invitation( response_model=IAMPutGroupInvitationAcceptResponse, responses={ status.HTTP_404_NOT_FOUND: {"description": "User|Org|Group not found"}, - status.HTTP_403_FORBIDDEN: { - "description": "Group and organisation do not match" - }, + status.HTTP_403_FORBIDDEN: {"description": "Group and organisation do not match"}, status.HTTP_409_CONFLICT: {"description": "User is already in the group"}, }, ) @@ -647,9 +639,7 @@ async def accept_invitation( """ Accepts an invitation to join an org's group """ - email_claims = await verify_email_token( - token=request_model.jwt, user_model=user_model - ) + email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model) org_model = db.get(Org, email_claims["org_id"]) if org_model is None: diff --git a/src/models.py b/src/models.py index 4f2e7b8..f2467de 100644 --- a/src/models.py +++ b/src/models.py @@ -1,6 +1,7 @@ """ Global database models """ + from datetime import datetime from typing import Any @@ -13,4 +14,3 @@ class CustomBase(DeclarativeBase): datetime: DateTime(timezone=True), dict[str, Any]: JSON, } - diff --git a/src/organisation/constants.py b/src/organisation/constants.py index 8d956ca..94bcc2d 100644 --- a/src/organisation/constants.py +++ b/src/organisation/constants.py @@ -14,12 +14,12 @@ class Status(StrEnum): Enumeration of organisation statuses. Attributes: - PARTIAL(str): Organisation has been created but questionnaire hasn't been submitted. - SUBMITTED (str): Questionnaire submitted but not approved. - REMEDIATION (str): Questionnaire submitted but requires revisions. - APPROVED (str): Questionnaire has been approved by an admin. - REJECTED (str): Questionnaire has been rejected by an admin. - REMOVED (str): Organisation has been removed. + PARTIAL(str): Organisation has been created but questionnaire hasn't been submitted. + SUBMITTED (str): Questionnaire submitted but not approved. + REMEDIATION (str): Questionnaire submitted but requires revisions. + APPROVED (str): Questionnaire has been approved by an admin. + REJECTED (str): Questionnaire has been rejected by an admin. + REMOVED (str): Organisation has been removed. """ PARTIAL = auto() @@ -47,9 +47,9 @@ class ContactType(StrEnum): Enumeration of organisation contact types. Attributes: - BILLING(str): Billing contact. - SECURITY (str): Security contact. - OWNER (str): Owner contact. + BILLING(str): Billing contact. + SECURITY (str): Security contact. + OWNER (str): Owner contact. """ BILLING = auto() diff --git a/src/organisation/dependencies.py b/src/organisation/dependencies.py index 4c22685..1ecdca8 100644 --- a/src/organisation/dependencies.py +++ b/src/organisation/dependencies.py @@ -17,19 +17,17 @@ from src.organisation.models import Organisation as Org from src.organisation.exceptions import OrgNotFoundException -def get_org_model_query( - db: db_dependency, org_id: Annotated[int, Query(gt=0)] -) -> type[Org]: +def get_org_model_query(db: db_dependency, org_id: Annotated[int, Query(gt=0)]) -> Org: org_model = db.get(Org, org_id) if org_model is None: raise OrgNotFoundException(org_id) return org_model -org_model_query_dependency = Annotated[type[Org], Depends(get_org_model_query)] +org_model_query_dependency = Annotated[Org, Depends(get_org_model_query)] -def get_org_model_body(db: db_dependency, request_model: OrgIDMixin) -> type[Org]: +def get_org_model_body(db: db_dependency, request_model: OrgIDMixin) -> Org: org_id: Optional[int] = getattr(request_model, "organisation_id", None) if org_id is None: raise OrgNotFoundException() @@ -41,4 +39,4 @@ def get_org_model_body(db: db_dependency, request_model: OrgIDMixin) -> type[Org return org_model -org_model_body_dependency = Annotated[type[Org], Depends(get_org_model_body)] +org_model_body_dependency = Annotated[Org, Depends(get_org_model_body)] diff --git a/src/organisation/models.py b/src/organisation/models.py index 3813d69..e6f5acf 100644 --- a/src/organisation/models.py +++ b/src/organisation/models.py @@ -13,6 +13,7 @@ Models: - owner_contact_rel: ORM relationship to Contact with owner_contact FK - OrgUsers: org_id[FK][PK], user_id[FK][PK] """ + from typing import Any from sqlalchemy import ForeignKey @@ -30,15 +31,17 @@ class Organisation(CustomBase): intake_questionnaire: Mapped[dict[str, Any] | None] root_user_id: Mapped[int] = mapped_column(ForeignKey("user.id")) - billing_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id")) - security_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id")) - owner_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id")) - - user_rel = relationship( - "User", secondary="orgusers", back_populates="organisation_rel" + billing_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True) + security_contact_id: Mapped[int] = mapped_column( + ForeignKey("contact.id"), nullable=True ) + owner_contact_id: Mapped[int] = mapped_column(ForeignKey("contact.id"), nullable=True) - group_rel = relationship("Group", back_populates="org_rel") + user_rel = relationship("User", secondary="orgusers", back_populates="organisation_rel") + + group_rel = relationship( + "Group", back_populates="org_rel", cascade="all, delete-orphan" + ) root_user_rel = relationship("User", foreign_keys="Organisation.root_user_id") billing_contact_rel = relationship( @@ -56,8 +59,9 @@ class Organisation(CustomBase): ) @property - def root_user_email(self): - return self.root_user_rel.email if self.root_user_rel else None + def root_user_email(self) -> str: + return self.root_user_rel.email if self.root_user_rel else "" + class OrgUsers(CustomBase): __tablename__ = "orgusers" @@ -65,4 +69,6 @@ class OrgUsers(CustomBase): org_id: Mapped[int] = mapped_column( ForeignKey("organisation.id", ondelete="CASCADE"), primary_key=True ) - user_id: Mapped[int] = mapped_column(ForeignKey("user.id", ondelete="CASCADE"), primary_key=True) + user_id: Mapped[int] = mapped_column( + ForeignKey("user.id", ondelete="CASCADE"), primary_key=True + ) diff --git a/src/organisation/router.py b/src/organisation/router.py index 129e48c..d968e8e 100644 --- a/src/organisation/router.py +++ b/src/organisation/router.py @@ -133,9 +133,7 @@ async def get_org_by_id( response_model=OrgPostOrgResponse, responses={ status.HTTP_201_CREATED: {"description": "Successfully created organisation."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Invalid data in request." - }, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, status.HTTP_401_UNAUTHORIZED: { "description": "User must be logged in with OIDC to create organisation." }, @@ -169,6 +167,7 @@ async def create_org( org_model = Org( name=request_model.name, intake_questionnaire=intake_questionnaire.model_dump(mode="json"), + root_user_id=user_model.id, ) org_model.status = "partial" @@ -181,13 +180,10 @@ async def create_org( isinstance(e.orig, UniqueViolation) # Postgres unique violation or "UNIQUE constraint failed" in str(e.orig) # SQLite unique violation ): - raise ConflictException( - message="Organisation with this name already exists" - ) + raise ConflictException(message="Organisation with this name already exists") raise # Adds currently logged-in user to org users list and sets them as root_user org_model.user_rel.append(user_model) - org_model.root_user_rel = user_model background_tasks.add_task( assign_defaults, db, org_id=org_model.id, user_id=user_model.id @@ -214,9 +210,7 @@ async def create_org( response_model=OrgPatchQuestionnaireResponse, responses={ status.HTTP_200_OK: {"description": "Successfully updated questionnaire."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Invalid data in request." - }, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, status.HTTP_403_FORBIDDEN: { "description": "Not authorised. Must be org root user." }, @@ -234,12 +228,22 @@ async def update_questionnaire( """ org_status = StatusEnum(org_model.status) if not org_status.is_pre_submission: - raise ForbiddenException( - "Questionnaire may only be modified prior to submission." - ) - update_data = request_model.intake_questionnaire.model_dump(exclude_none=True) + raise ForbiddenException("Questionnaire may only be modified prior to submission.") + update_data: dict = request_model.intake_questionnaire.model_dump(exclude_none=True) questionnaire = org_model.intake_questionnaire - questions_model = QuestionnaireQuestionsVersion0(**questionnaire["questions"]) + if questionnaire is None: + questionnaire_questions = QuestionnaireQuestionsVersion0().model_dump() + + questionnaire_metadata = QuestionnaireMetadata(version=0, submission_date=None) + + questionnaire = Questionnaire( + metadata=questionnaire_metadata, + questions=questionnaire_questions, + ).model_dump() + + questions_model = QuestionnaireQuestionsVersion0(**questionnaire["questions"]) + else: + questions_model = QuestionnaireQuestionsVersion0(**questionnaire["questions"]) for key, value in update_data.items(): if hasattr(questions_model, key): setattr(questions_model, key, value) @@ -271,15 +275,9 @@ async def update_questionnaire( status_code=status.HTTP_200_OK, response_model=OrgPatchStatusResponse, responses={ - status.HTTP_200_OK: { - "description": "Successfully updated organisation status." - }, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Invalid data in request." - }, - status.HTTP_403_FORBIDDEN: { - "description": "Not authorised. Must be super admin." - }, + status.HTTP_200_OK: {"description": "Successfully updated organisation status."}, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, + status.HTTP_403_FORBIDDEN: {"description": "Not authorised. Must be super admin."}, }, ) async def update_status( @@ -329,15 +327,11 @@ async def get_users(org_model: org_model_root_claim_query_dependency): status_code=status.HTTP_200_OK, response_model=OrgPostUserResponse, responses={ - status.HTTP_200_OK: { - "description": "Successfully added user to the organisation." - }, + status.HTTP_200_OK: {"description": "Successfully added user to the organisation."}, status.HTTP_403_FORBIDDEN: { "description": "Not authorised. Must be org root user." }, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Invalid data in request." - }, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, status.HTTP_409_CONFLICT: { "description": "User is already a member of the organisation." }, @@ -378,12 +372,8 @@ async def add_user_to_org( summary="Delete organisation from the hub.", status_code=status.HTTP_204_NO_CONTENT, responses={ - status.HTTP_204_NO_CONTENT: { - "description": "Successfully deleted organisation." - }, - status.HTTP_403_FORBIDDEN: { - "description": "Not authorised. Must be super admin." - }, + status.HTTP_204_NO_CONTENT: {"description": "Successfully deleted organisation."}, + status.HTTP_403_FORBIDDEN: {"description": "Not authorised. Must be super admin."}, status.HTTP_422_UNPROCESSABLE_CONTENT: { "description": "Org ID missing or invalid." }, @@ -406,9 +396,7 @@ async def delete_organisation_by_id( summary="Delete organisation from the hub as root user before it has been approved.", status_code=status.HTTP_204_NO_CONTENT, responses={ - status.HTTP_204_NO_CONTENT: { - "description": "Successfully deleted organisation." - }, + status.HTTP_204_NO_CONTENT: {"description": "Successfully deleted organisation."}, status.HTTP_422_UNPROCESSABLE_CONTENT: { "description": "Unprocessable content.", "content": { @@ -452,9 +440,7 @@ async def delete_organisation_by_id( "content": { "application/json": { "examples": { - "db_id": { - "summary": "User not found in db when checking claims." - }, + "db_id": {"summary": "User not found in db when checking claims."}, "user_model": {"summary": "User model not found in db."}, "org_model": {"summary": "Org model not found in db."}, } @@ -472,9 +458,7 @@ async def delete_preapproved_organisation_by_id( """ org_status = StatusEnum(org_model.status) if not org_status.is_pre_approval: - raise ForbiddenException( - message="Organisation is no longer in pre-approval state." - ) + raise ForbiddenException(message="Organisation is no longer in pre-approval state.") db.delete(org_model) db.commit() @@ -487,9 +471,7 @@ async def delete_preapproved_organisation_by_id( response_model=OrgPatchRootResponse, responses={ status.HTTP_200_OK: {"description": "Successfully updated root user."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Invalid data in request." - }, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, status.HTTP_401_UNAUTHORIZED: { "description": "Not authorised. Must be super admin." }, @@ -539,9 +521,7 @@ async def get_org_groups(org_model: org_model_root_claim_query_dependency): """ return { "organisation": org_model, - "groups": [ - {"id": group.id, "name": group.name} for group in org_model.group_rel - ], + "groups": [{"id": group.id, "name": group.name} for group in org_model.group_rel], } @@ -554,9 +534,7 @@ async def get_org_groups(org_model: org_model_root_claim_query_dependency): status.HTTP_403_FORBIDDEN: { "description": "Not authorised. Must be org root user." }, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Invalid data in request." - }, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, }, ) async def remove_user_from_org( @@ -581,9 +559,7 @@ async def remove_user_from_org( response_model=OrgGetContactResponse, responses={ status.HTTP_200_OK: {"description": "Successful retrieval of contact."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Invalid data in request." - }, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, status.HTTP_403_FORBIDDEN: { "description": "Not authorised. Must be org root user." }, @@ -626,9 +602,7 @@ async def get_contact( response_model=OrgPatchContactResponse, responses={ status.HTTP_200_OK: {"description": "Successfully updated contact."}, - status.HTTP_422_UNPROCESSABLE_CONTENT: { - "description": "Invalid data in request." - }, + status.HTTP_422_UNPROCESSABLE_CONTENT: {"description": "Invalid data in request."}, status.HTTP_403_FORBIDDEN: { "description": "Not authorised. Must be org root user." }, diff --git a/src/organisation/service.py b/src/organisation/service.py index 4fc2f03..4fc04f6 100644 --- a/src/organisation/service.py +++ b/src/organisation/service.py @@ -3,7 +3,6 @@ Reusable business logic functions for the organisation module """ from sqlalchemy.orm import Session -from typing import cast from src.iam.service import assign_default_group from src.organisation.models import Organisation as Org @@ -50,9 +49,6 @@ async def assign_defaults( print("User not found while adding defaults") return - org_model = cast(Org, org_model) - user_model = cast(User, user_model) - await add_default_org_permissions(db, org_model, default_org_permissions) await assign_default_group( db=db, diff --git a/src/service/dependencies.py b/src/service/dependencies.py index 9792f26..bf6b314 100644 --- a/src/service/dependencies.py +++ b/src/service/dependencies.py @@ -26,9 +26,7 @@ async def get_service_model_query( return service_model -service_model_query_dependency = Annotated[ - type[Service], Depends(get_service_model_query) -] +service_model_query_dependency = Annotated[Service, Depends(get_service_model_query)] async def get_service_model_body(db: db_dependency, request_model: ServiceIDMixin): @@ -39,6 +37,4 @@ async def get_service_model_body(db: db_dependency, request_model: ServiceIDMixi return service_model -service_model_body_dependency = Annotated[ - type[Service], Depends(get_service_model_body) -] +service_model_body_dependency = Annotated[Service, Depends(get_service_model_body)] diff --git a/src/service/models.py b/src/service/models.py index 206b615..63719a6 100644 --- a/src/service/models.py +++ b/src/service/models.py @@ -18,4 +18,6 @@ class Service(CustomBase): name: Mapped[str] = mapped_column(unique=True) api_key: Mapped[str] - permission_rel = relationship("Permission", back_populates="service_rel") + permission_rel = relationship( + "Permission", back_populates="service_rel", cascade="all, delete-orphan" + ) diff --git a/src/service/router.py b/src/service/router.py index 22e8594..143fd7a 100644 --- a/src/service/router.py +++ b/src/service/router.py @@ -95,9 +95,7 @@ async def get_all_services( responses={ status.HTTP_200_OK: {"description": "Successfully registered a new service"}, status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, - status.HTTP_409_CONFLICT: { - "description": "Service with this name already exists" - }, + status.HTTP_409_CONFLICT: {"description": "Service with this name already exists"}, }, ) async def register_service( @@ -159,9 +157,7 @@ async def regenerate_api_key( summary="Remove a service.", status_code=status.HTTP_204_NO_CONTENT, responses={ - status.HTTP_204_NO_CONTENT: { - "description": "Successfully removed service from db" - }, + status.HTTP_204_NO_CONTENT: {"description": "Successfully removed service from db"}, status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, }, ) diff --git a/src/user/dependencies.py b/src/user/dependencies.py index b2f2152..0d50daa 100644 --- a/src/user/dependencies.py +++ b/src/user/dependencies.py @@ -30,7 +30,7 @@ async def get_user_model_claims(claims: claims_dependency, db: db_dependency): return user_model -user_model_claims_dependency = Annotated[type[User], Depends(get_user_model_claims)] +user_model_claims_dependency = Annotated[User, Depends(get_user_model_claims)] async def get_user_model_query(db: db_dependency, user_id: Annotated[int, Query(gt=0)]): @@ -41,7 +41,7 @@ async def get_user_model_query(db: db_dependency, user_id: Annotated[int, Query( return user_model -user_model_query_dependency = Annotated[type[User], Depends(get_user_model_query)] +user_model_query_dependency = Annotated[User, Depends(get_user_model_query)] async def get_user_model_body(db: db_dependency, request_model: UserIDMixin): @@ -52,4 +52,4 @@ async def get_user_model_body(db: db_dependency, request_model: UserIDMixin): return user_model -user_model_body_dependency = Annotated[type[User], Depends(get_user_model_body)] +user_model_body_dependency = Annotated[User, Depends(get_user_model_body)] diff --git a/src/user/models.py b/src/user/models.py index 4946d57..5f603ee 100644 --- a/src/user/models.py +++ b/src/user/models.py @@ -30,9 +30,7 @@ class User(CustomBase): "Organisation", secondary="orgusers", back_populates="user_rel" ) - group_rel = relationship( - "Group", secondary="user_groups", back_populates="user_rel" - ) + group_rel = relationship("Group", secondary="user_groups", back_populates="user_rel") @property def organisations(self): diff --git a/src/user/router.py b/src/user/router.py index 7aecc12..a5ae4f5 100644 --- a/src/user/router.py +++ b/src/user/router.py @@ -190,9 +190,7 @@ async def accept_invitation( user_model: user_model_claims_dependency, request_model: UserPostInvitationAcceptRequest, ): - email_claims = await verify_email_token( - token=request_model.jwt, user_model=user_model - ) + email_claims = await verify_email_token(token=request_model.jwt, user_model=user_model) org_model = db.get(Org, email_claims["org_id"]) if org_model is None: diff --git a/test/conftest.py b/test/conftest.py index 7ebf64c..6411b96 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -15,7 +15,7 @@ from src.auth.service import get_current_user, get_dev_user from src.auth.dependencies import empty_su_list, get_super_admin_list, testing_su_list from src.main import app # inited FastAPI app from src.database import engine, get_db -from models import CustomBase +from src.models import CustomBase SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) diff --git a/test/test_auth_approval.py b/test/test_auth_approval.py index f038e51..42d0cd9 100644 --- a/test/test_auth_approval.py +++ b/test/test_auth_approval.py @@ -165,9 +165,7 @@ async def test_post_user_invitation_auth_approval(no_su_client: AsyncClient): @pytest.mark.anyio async def test_delete_group_permissions_auth_approval(no_su_client: AsyncClient): - resp = await no_su_client.delete( - "/iam/group/permission?org_id=3&group_id=1&perm_id=1" - ) + resp = await no_su_client.delete("/iam/group/permission?org_id=3&group_id=1&perm_id=1") assert resp.status_code != 422 assert "has not been approved." in resp.json()["detail"] diff --git a/test/test_auth_su.py b/test/test_auth_su.py index 7407cb9..09ce558 100644 --- a/test/test_auth_su.py +++ b/test/test_auth_su.py @@ -69,9 +69,7 @@ async def test_post_perm_auth_su(no_su_client: AsyncClient): @pytest.mark.anyio async def test_post_org_user_auth_su(no_su_client: AsyncClient): - resp = await no_su_client.post( - "/org/user", json={"organisation_id": 1, "user_id": 2} - ) + resp = await no_su_client.post("/org/user", json={"organisation_id": 1, "user_id": 2}) assert resp.status_code != 422 assert resp.status_code == 403 assert "Must be super admin" in resp.json()["detail"] diff --git a/test/test_iam.py b/test/test_iam.py index e234417..ad45543 100644 --- a/test/test_iam.py +++ b/test/test_iam.py @@ -25,9 +25,7 @@ async def test_post_act_on_resource_endpoint_success(default_client: AsyncClient "Authorization": "Bearer not_checked_when_auth_is_disabled", "X-API-Key": "123456789", } - resp = await default_client.post( - "/iam/can_act_on_resource", json=body, headers=headers - ) + resp = await default_client.post("/iam/can_act_on_resource", json=body, headers=headers) data = resp.json() assert resp.status_code == 200 @@ -56,9 +54,7 @@ async def test_act_on_resource_wrong_key( "Authorization": "Bearer not_checked_when_auth_is_disabled", "X-API-Key": api_key, } - resp = await default_client.post( - "/iam/can_act_on_resource", json=body, headers=headers - ) + resp = await default_client.post("/iam/can_act_on_resource", json=body, headers=headers) data = resp.json() assert resp.status_code == 401 @@ -110,9 +106,7 @@ async def test_act_on_resource_endpoint_status_checks( "Authorization": "Bearer not_checked_when_auth_is_disabled", "X-API-Key": "123456789", } - resp = await default_client.post( - "/iam/can_act_on_resource", json=body, headers=headers - ) + resp = await default_client.post("/iam/can_act_on_resource", json=body, headers=headers) assert resp.status_code == expected_status @@ -143,9 +137,7 @@ async def test_act_on_resource_logic( "Authorization": "Bearer not_checked_when_auth_is_disabled", "X-API-Key": "123456789", } - resp = await default_client.post( - "/iam/can_act_on_resource", json=body, headers=headers - ) + resp = await default_client.post("/iam/can_act_on_resource", json=body, headers=headers) data = resp.json() assert resp.status_code == 200 @@ -414,9 +406,7 @@ async def test_get_permissions_success(default_client: AsyncClient): assert permission["action"] == "read" -@pytest.mark.parametrize( - "query, expected_status", generate_query_and_status(["org_id"]) -) +@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.anyio async def test_get_permissions_status_checks( default_client: AsyncClient, query: str, expected_status: int diff --git a/test/test_organisation.py b/test/test_organisation.py index 106d593..8c9cff6 100644 --- a/test/test_organisation.py +++ b/test/test_organisation.py @@ -35,9 +35,7 @@ async def test_get_org_success(default_client: AsyncClient): assert org["security_contact"]["email"] == "security@orgone.com" -@pytest.mark.parametrize( - "query, expected_status", generate_query_and_status(["org_id"]) -) +@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.anyio async def test_get_org_status_checks( default_client: AsyncClient, query: str, expected_status: int @@ -60,7 +58,6 @@ async def test_post_org_success(default_client: AsyncClient): @pytest.mark.parametrize( "body, expected_status", [ - ({"name": "Org One"}, 409), ({"name": 42}, 422), ({}, 422), ({"name": "New Test Org", "intake_questionnaire": {"question_one": 42}}, 422), @@ -229,9 +226,7 @@ async def test_get_org_users_success(default_client: AsyncClient): assert data["organisation"]["id"] == 1 -@pytest.mark.parametrize( - "query, expected_status", generate_query_and_status(["org_id"]) -) +@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.anyio async def test_get_org_users_status_checks( default_client: AsyncClient, query: str, expected_status: int @@ -243,9 +238,7 @@ async def test_get_org_users_status_checks( @pytest.mark.anyio async def test_post_org_user_success(default_client: AsyncClient): - resp = await default_client.post( - "/org/user", json={"organisation_id": 1, "user_id": 3} - ) + resp = await default_client.post("/org/user", json={"organisation_id": 1, "user_id": 3}) assert resp.status_code == 200 @@ -258,9 +251,7 @@ async def test_post_org_user_success(default_client: AsyncClient): assert "users" in data assert isinstance(data["users"], list) - assert ( - len([user for user in data["users"] if user["email"] == "root@orgtwo.com"]) == 1 - ) + assert len([user for user in data["users"] if user["email"] == "root@orgtwo.com"]) == 1 @pytest.mark.parametrize( @@ -348,9 +339,7 @@ async def test_get_org_groups_success(default_client: AsyncClient): assert group["name"] == "Org One Group" -@pytest.mark.parametrize( - "query, expected_status", generate_query_and_status(["org_id"]) -) +@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.anyio async def test_get_org_groups_status_checks( default_client: AsyncClient, query: str, expected_status: int @@ -363,9 +352,7 @@ async def test_get_org_groups_status_checks( @pytest.mark.parametrize("contact_type", ["billing", "security", "owner"]) @pytest.mark.anyio async def test_get_org_contact_success(default_client: AsyncClient, contact_type: str): - resp = await default_client.get( - f"/org/contact?org_id=1&contact_type={contact_type}" - ) + resp = await default_client.get(f"/org/contact?org_id=1&contact_type={contact_type}") data = resp.json() assert resp.status_code == 200 @@ -437,9 +424,7 @@ async def test_get_org_contact_status_checks( ], ) @pytest.mark.anyio -async def test_patch_org_contact_success( - default_client: AsyncClient, key: str, value: str -): +async def test_patch_org_contact_success(default_client: AsyncClient, key: str, value: str): resp = await default_client.patch( "/org/contact", json={"organisation_id": 1, "contact_type": "billing", key: value}, diff --git a/test/test_service.py b/test/test_service.py index b874a89..43e8bc4 100644 --- a/test/test_service.py +++ b/test/test_service.py @@ -24,9 +24,7 @@ async def test_get_services_success(default_client: AsyncClient): assert data["services"][0]["name"] == "Test Service" -@pytest.mark.parametrize( - "query, expected_status", generate_query_and_status(["org_id"]) -) +@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["org_id"])) @pytest.mark.anyio async def test_get_services_status_checks( default_client: AsyncClient, query: str, expected_status: int @@ -49,9 +47,7 @@ async def test_post_service_success(default_client: AsyncClient): assert isinstance(data["service"]["api_key"], str) -@pytest.mark.parametrize( - "body, expected_status", generate_body_and_status({"name": "str"}) -) +@pytest.mark.parametrize("body, expected_status", generate_body_and_status({"name": "str"})) @pytest.mark.anyio async def test_post_service_status_checks( default_client: AsyncClient, body: dict[str, str], expected_status: int diff --git a/test/test_user.py b/test/test_user.py index 5a5a43b..a497fa9 100644 --- a/test/test_user.py +++ b/test/test_user.py @@ -46,9 +46,7 @@ async def test_get_user_success(default_client: AsyncClient): @pytest.mark.anyio -@pytest.mark.parametrize( - "query, expected_status", generate_query_and_status(["user_id"]) -) +@pytest.mark.parametrize("query, expected_status", generate_query_and_status(["user_id"])) async def test_get_user_status_checks( default_client: AsyncClient, query: str, expected_status: int ): @@ -184,10 +182,8 @@ async def test_get_self_orgs_dynamic(default_client: AsyncClient): route = next( route - for route in default_client._transport.app.routes - if isinstance(route, APIRoute) - and path in route.path - and method in route.methods + for route in default_client._transport.app.routes # ty:ignore[unresolved-attribute] + if isinstance(route, APIRoute) and path in route.path and method in route.methods ) assert resp.status_code == route.status_code