Compare commits

..

No commits in common. "88a64d204766e9afe0406aaa3bf71fb4a2570d9d" and "fc835dc98231c076fc368a1b794bbd0a76cfe22f" have entirely different histories.

9 changed files with 117 additions and 138 deletions

View file

@ -19,13 +19,6 @@ from src.organisation.models import Organisation as Org
from src.auth.exceptions import UnauthorizedException from src.auth.exceptions import UnauthorizedException
def is_super_admin(user_model) -> bool:
super_admin_emails = ["chris@sr2.uk"]
if user_model.email not in super_admin_emails:
raise UnauthorizedException()
return True
async def org_query_user_claims(org_model: org_model_query_dependency, user_model: user_model_claims_dependency): async def org_query_user_claims(org_model: org_model_query_dependency, user_model: user_model_claims_dependency):
if user_model in org_model.user_rel: if user_model in org_model.user_rel:
return True return True
@ -40,9 +33,6 @@ async def org_query_root_claims(user_model: user_model_claims_dependency, org_mo
if org_model.root_user_id == user_model.id: if org_model.root_user_id == user_model.id:
return org_model return org_model
if is_super_admin(user_model):
return org_model
raise UnauthorizedException() raise UnauthorizedException()
@ -53,20 +43,17 @@ async def org_body_root_claims(user_model: user_model_claims_dependency, org_mod
if org_model.root_user_id == user_model.id: if org_model.root_user_id == user_model.id:
return org_model return org_model
if is_super_admin(user_model):
return org_model
raise UnauthorizedException() raise UnauthorizedException()
org_model_root_claim_body_dependency = Annotated[type[Org], Depends(org_body_root_claims)] org_model_root_claim_body_dependency = Annotated[type[Org], Depends(org_body_root_claims)]
async def user_model_super_admin(user_model: user_model_claims_dependency): async def is_super_admin(user_model: user_model_claims_dependency):
if is_super_admin(user_model): super_admin_emails = []
return user_model if user_model.email not in super_admin_emails:
raise UnauthorizedException()
raise UnauthorizedException() return True
super_admin_dependency = Annotated[bool, Depends(user_model_super_admin)] super_admin_dependency = Annotated[bool, Depends(is_super_admin)]

View file

@ -9,6 +9,7 @@ from typing import Optional
from pydantic import EmailStr, ConfigDict from pydantic import EmailStr, ConfigDict
from src.organisation.constants import ContactType
from src.schemas import CustomBaseModel from src.schemas import CustomBaseModel
@ -24,11 +25,50 @@ class ContactAddress(CustomBaseModel):
postal_code: Optional[str] = None postal_code: Optional[str] = None
class ContactModel(CustomBaseModel): class ContactContactGetResponse(CustomBaseModel):
email: str
first_name: str
last_name: str
phonenumber: str
vat_number: Optional[str] = None
class ContactAddressGetResponse(CustomBaseModel):
post_office_box_number: Optional[str] = None
street_address: Optional[str] = None # If using a PO box, there would be no street address
street_address_line_2: Optional[str] = None
locality: str
address_region: Optional[str] = None
country_code: str
postal_code: str
class ContactContactPostRequest(CustomBaseModel):
email: EmailStr
first_name: str
last_name: str
phonenumber: str
vat_number: Optional[str] = None
post_office_box_number: Optional[str] = None
street_address: Optional[str] = None
street_address_line_2: Optional[str] = None
locality: str
address_region: Optional[str] = None
country_code: str
postal_code: str
class ContactUpdateRequest(CustomBaseModel):
email: Optional[EmailStr] = None email: Optional[EmailStr] = None
first_name: Optional[str] = None first_name: Optional[str] = None
last_name: Optional[str] = None last_name: Optional[str] = None
phonenumber: Optional[str] = None phonenumber: Optional[str] = None
vat_number: Optional[str] = None vat_number: Optional[str] = None
post_office_box_number: Optional[str] = None
street_address: Optional[str] = None
street_address_line_2: Optional[str] = None
locality: Optional[str] = None
address_region: Optional[str] = None
country_code: Optional[str] = None
postal_code: Optional[str] = None
address: ContactAddress class ContactOrgGetResponse(CustomBaseModel):
name: str
contact_types: list[ContactType]

View file

@ -13,12 +13,3 @@ class UnprocessableContent(HTTPException):
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT, status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
detail=detail, detail=detail,
) )
class Conflict(HTTPException):
def __init__(self, message: Optional[str] = None) -> None:
detail = "Conflict" if not message else message
super().__init__(
status_code=status.HTTP_409_CONFLICT,
detail=detail,
)

View file

@ -6,10 +6,8 @@ Endpoints:
- Endpoints: Description - Endpoints: Description
""" """
from fastapi import APIRouter, status from fastapi import APIRouter, status
from sqlalchemy.exc import IntegrityError
from psycopg import errors
from src.exceptions import Conflict
from src.database import db_dependency from src.database import db_dependency
from src.schemas import ResourceName from src.schemas import ResourceName
from src.auth.exceptions import UnauthorizedException from src.auth.exceptions import UnauthorizedException
@ -86,11 +84,7 @@ async def create_group(db: db_dependency, org_model: org_model_root_claim_body_d
group_model = Group(name=request_model.name, org_id=org_model.id) group_model = Group(name=request_model.name, org_id=org_model.id)
db.add(group_model) db.add(group_model)
try: db.flush()
db.flush()
except IntegrityError as e:
if isinstance(e.orig, errors.UniqueViolation):
raise Conflict("Group with this name already exists")
response = GroupResponse(**group_model.__dict__) response = GroupResponse(**group_model.__dict__)
db.commit() db.commit()
return {"group": response} return {"group": response}
@ -101,9 +95,6 @@ async def add_group_permission(db: db_dependency, group_model: group_model_body_
if group_model.org_id == org_model.id: if group_model.org_id == org_model.id:
raise UnauthorizedException() raise UnauthorizedException()
if perm_model in group_model.permission_rel:
raise Conflict("Group already has this permission")
group_model.permission_rel.append(perm_model) group_model.permission_rel.append(perm_model)
db.flush() db.flush()
@ -117,9 +108,6 @@ async def add_group_user(db: db_dependency, group_model: group_model_body_depend
if group_model.org_id == org_model.id: if group_model.org_id == org_model.id:
raise UnauthorizedException() raise UnauthorizedException()
if user_model in group_model.user_rel:
raise Conflict("User already in group")
group_model.user_rel.append(user_model) group_model.user_rel.append(user_model)
db.flush() db.flush()
response = IAMPutGroupUserResponse(group=GroupResponse(**group_model.__dict__), users=group_model.user_rel) response = IAMPutGroupUserResponse(group=GroupResponse(**group_model.__dict__), users=group_model.user_rel)
@ -163,11 +151,8 @@ async def get_permissions(db: db_dependency, org_model: org_model_root_claim_bod
@router.post("/permission") @router.post("/permission")
async def create_new_permission(db: db_dependency, su: super_admin_dependency, request_mode: IAMPostPermissionRequest): async def create_new_permission(db: db_dependency, su: super_admin_dependency, request_mode: IAMPostPermissionRequest):
perm_model = Perm(**request_mode.__dict__) perm_model = Perm(**request_mode.__dict__)
try:
db.add(perm_model) db.add(perm_model)
except IntegrityError as e:
if isinstance(e.orig, errors.UniqueViolation):
raise Conflict(message="Permission already exists")
db.flush() db.flush()
response = IAMPostPermissionResponse(permission=PermissionResponse(**perm_model.__dict__)) response = IAMPostPermissionResponse(permission=PermissionResponse(**perm_model.__dict__))
db.commit() db.commit()

View file

@ -9,42 +9,35 @@ Functions:
- List: Description - List: Description
- Functions: Description - Functions: Description
""" """
from typing import Annotated, Optional from typing import Annotated
from fastapi import Depends, Query, Request from fastapi import Depends, Query
from src.database import db_dependency from src.database import db_dependency
from src.organisation.schemas import OrgIDMixin from src.organisation.schemas import OrgIDMixin
from src.organisation.models import Organisation as Org from src.organisation.models import Organisation as Org
from src.organisation.exceptions import OrgNotFoundException, AwaitingApprovalException from src.organisation.exceptions import OrgNotFoundException
from src.organisation.constants import Status as OrgStatus
def get_org_model(db, request: Request, org_id: int): def get_org_model_query(db: db_dependency, org_id: Annotated[int, Query(gt=0)]) -> type[Org]:
org_model = db.get(Org, org_id) org_model = db.get(Org, org_id)
if org_model is None: if org_model is None:
raise OrgNotFoundException(org_id) raise OrgNotFoundException(org_id)
pre_approval_endpoints = ["PATCH/org/status", "PATCH/org/questionnaire", "GET/org/id"]
current_request = f"{request.method}{request.url.path}"
if current_request not in pre_approval_endpoints and org_model.status != OrgStatus.APPROVED:
raise AwaitingApprovalException(org_id)
return org_model return org_model
def get_org_model_query(db: db_dependency, request: Request, org_id: Annotated[int, Query(gt=0)]) -> type[Org]:
return get_org_model(db, request, org_id)
org_model_query_dependency = Annotated[type[Org], Depends(get_org_model_query)] org_model_query_dependency = Annotated[type[Org], Depends(get_org_model_query)]
def get_org_model_body(db: db_dependency, request: Request, request_model: OrgIDMixin) -> type[Org]: def get_org_model_body(db: db_dependency, request_model: OrgIDMixin) -> type[Org]:
org_id: Optional[int] = getattr(request_model, "organisation_id", None) org_id = getattr(request_model, "organisation_id", None)
if org_id is None: if org_id is None:
raise OrgNotFoundException raise OrgNotFoundException
org_model = db.get(Org, org_id)
if org_model is None:
raise OrgNotFoundException(org_id)
return get_org_model(db, request, org_id) return org_model
org_model_body_dependency = Annotated[type[Org], Depends(get_org_model_body)] org_model_body_dependency = Annotated[type[Org], Depends(get_org_model_body)]

View file

@ -17,11 +17,3 @@ class OrgNotFoundException(HTTPException):
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail=detail, detail=detail,
) )
class AwaitingApprovalException(HTTPException):
def __init__(self, org_id: Optional[int] = None) -> None:
detail = "Organisation has not been approved." if org_id is None else f"Organisation with ID '{org_id}' has not been approved."
super().__init__(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=detail,
)

View file

@ -16,11 +16,8 @@ from typing import Annotated, Optional
from fastapi import APIRouter, status from fastapi import APIRouter, status
from fastapi.params import Query from fastapi.params import Query
from psycopg.errors import UniqueViolation
from sqlalchemy.exc import IntegrityError
from contact.schemas import ContactModel from src.exceptions import UnprocessableContent
from src.exceptions import UnprocessableContent, Conflict
from src.contact.models import Contact from src.contact.models import Contact
from src.contact.schemas import ContactAddress from src.contact.schemas import ContactAddress
from src.contact.exceptions import ContactNotFoundException from src.contact.exceptions import ContactNotFoundException
@ -32,10 +29,10 @@ from src.auth.dependencies import super_admin_dependency, org_model_root_claim_q
from src.organisation.dependencies import org_model_body_dependency from src.organisation.dependencies import org_model_body_dependency
from src.organisation.constants import ContactType from src.organisation.constants import ContactType
from src.organisation.models import Organisation as Org from src.organisation.models import Organisation as Org
from src.organisation.schemas import OrgPostOrgRequest, OrgPatchQuestionnaireRequest, OrgPatchStatusRequest, \ from src.organisation.schemas import OrgOrgPostRequest, OrgQuestionnairePatchRequest, OrgStatusPatchRequest, \
OrgPatchContactRequest, \ OrgContactPatchRequest, \
OrgPostUserRequest, OrgGetUserResponse, OrgGetContactResponse, OrgGetOrgResponse, OrgPatchRootRequest, \ OrgUserPostRequest, OrgUserGetResponse, OrgContactGetResponse, OrgOrgGetResponse, OrgRootPatchRequest, \
OrgGetGroupResponse, OrgDeleteUserRequest, OrgDeleteOrgRequest OrgGroupGetResponse, OrgUserDeleteRequest, OrgDeleteOrgRequest
router = APIRouter( router = APIRouter(
@ -44,7 +41,7 @@ router = APIRouter(
) )
@router.get("/id", response_model=OrgGetOrgResponse) @router.get("/id", response_model=OrgOrgGetResponse)
async def get_org_by_id(org_model: org_model_root_claim_query_dependency): async def get_org_by_id(org_model: org_model_root_claim_query_dependency):
response = { response = {
"name": org_model.name, "name": org_model.name,
@ -59,7 +56,7 @@ async def get_org_by_id(org_model: org_model_root_claim_query_dependency):
@router.post("/") @router.post("/")
async def create_org(db: db_dependency, user_model: user_model_claims_dependency, request_model: OrgPostOrgRequest): async def create_org(db: db_dependency, user_model: user_model_claims_dependency, request_model: OrgOrgPostRequest):
if request_model.intake_questionnaire: if request_model.intake_questionnaire:
intake_questionnaire = request_model.intake_questionnaire.model_dump() intake_questionnaire = request_model.intake_questionnaire.model_dump()
else: else:
@ -69,12 +66,9 @@ async def create_org(db: db_dependency, user_model: user_model_claims_dependency
org_model.status = "partial" # Status is always set to partial at first, see update_questionnaire() doc org_model.status = "partial" # Status is always set to partial at first, see update_questionnaire() doc
db.add(org_model) db.add(org_model)
try: db.flush()
db.flush()
except IntegrityError as e:
if isinstance(e.orig, UniqueViolation):
raise Conflict(message="Organisation with this name already exists")
# Adds currently logged-in user to org users list and sets them as root_user # Adds currently logged-in user to org users list and sets them as root_user
user_model = db.get(User, db_id)
org_model.user_rel.append(user_model) org_model.user_rel.append(user_model)
org_model.root_user_rel = user_model org_model.root_user_rel = user_model
for contact_type in ["billing_contact_id", "security_contact_id", "owner_contact_id"]: for contact_type in ["billing_contact_id", "security_contact_id", "owner_contact_id"]:
@ -86,7 +80,7 @@ async def create_org(db: db_dependency, user_model: user_model_claims_dependency
@router.patch("/questionnaire") @router.patch("/questionnaire")
async def update_questionnaire(db: db_dependency, org_model: org_model_root_claim_query_dependency, request_model: OrgPatchQuestionnaireRequest): async def update_questionnaire(db: db_dependency, org_model: org_model_root_claim_query_dependency, request_model: OrgQuestionnairePatchRequest):
""" """
Route for updating questionnaire. Route for updating questionnaire.
The partial bool allows for submission of partially completed questionnaire and/or The partial bool allows for submission of partially completed questionnaire and/or
@ -102,21 +96,21 @@ async def update_questionnaire(db: db_dependency, org_model: org_model_root_clai
@router.patch("/status") @router.patch("/status")
async def update_status(db: db_dependency, org_model: org_model_body_dependency, su: super_admin_dependency, request_model: OrgPatchStatusRequest): async def update_status(db: db_dependency, org_model: org_model_body_dependency, su: super_admin_dependency, request_model: OrgStatusPatchRequest):
org_model.status = request_model.status org_model.status = request_model.status
db.commit() db.commit()
@router.get("/users", response_model=OrgGetUserResponse) @router.get("/users", response_model=OrgUserGetResponse)
async def get_users(org_model: org_model_root_claim_query_dependency): async def get_users(org_model: org_model_root_claim_query_dependency):
return {"users": [user.email for user in org_model.user_rel]} return {"users": [user.email for user in org_model.user_rel]}
@router.post("/users") @router.post("/users")
async def add_user_to_org(db: db_dependency, org_model: org_model_root_claim_body_dependency, user_model: user_model_body_dependency, request_model: OrgPostUserRequest): async def add_user_to_org(db: db_dependency, org_model: org_model_root_claim_body_dependency, user_model: user_model_body_dependency, request_model: OrgUserPostRequest):
if user_model in org_model.user_rel: if user_model in org_model.user_rel:
raise Conflict(message="User already a part of this organisation") return
org_model.user_rel.append(user_model) org_model.user_rel.append(user_model)
db.commit() db.commit()
@ -128,18 +122,18 @@ async def delete_organisation_by_id(db: db_dependency, org_model: org_model_body
@router.patch("/root_user", status_code=status.HTTP_204_NO_CONTENT) @router.patch("/root_user", status_code=status.HTTP_204_NO_CONTENT)
async def update_root_user(db: db_dependency, org_model: org_model_body_dependency, user_model: user_model_body_dependency, su: super_admin_dependency, request_model: OrgPatchRootRequest): async def update_root_user(db: db_dependency, org_model: org_model_body_dependency, user_model: user_model_body_dependency, su: super_admin_dependency, request_model: OrgRootPatchRequest):
org_model.root_user_rel = user_model org_model.root_user_rel = user_model
db.commit() db.commit()
@router.get("/groups", response_model=OrgGetGroupResponse) @router.get("/groups", response_model=OrgGroupGetResponse)
async def get_org_groups(org_model: org_model_root_claim_query_dependency): async def get_org_groups(org_model: org_model_root_claim_query_dependency):
return {"groups": [group.name for group in org_model.group_rel]} return {"groups": [group.name for group in org_model.group_rel]}
@router.delete("/user", status_code=status.HTTP_204_NO_CONTENT) @router.delete("/user", status_code=status.HTTP_204_NO_CONTENT)
async def remove_user_from_org(db: db_dependency, org_model: org_model_root_claim_body_dependency, user_model: user_model_body_dependency, request_model: OrgDeleteUserRequest): async def remove_user_from_org(db: db_dependency, org_model: org_model_root_claim_body_dependency, user_model: user_model_body_dependency, request_model: OrgUserDeleteRequest):
if user_model not in org_model.user_rel: if user_model not in org_model.user_rel:
return return
@ -147,7 +141,7 @@ async def remove_user_from_org(db: db_dependency, org_model: org_model_root_clai
db.commit() db.commit()
@router.get("/contact", response_model=OrgGetContactResponse) @router.get("/contact", response_model=OrgContactGetResponse)
async def get_contact(org_model: org_model_root_claim_query_dependency, contact_type: Annotated[ContactType, Query()]): async def get_contact(org_model: org_model_root_claim_query_dependency, contact_type: Annotated[ContactType, Query()]):
match contact_type: match contact_type:
case "billing": case "billing":
@ -162,14 +156,14 @@ async def get_contact(org_model: org_model_root_claim_query_dependency, contact_
if contact_model is None: if contact_model is None:
raise ContactNotFoundException() raise ContactNotFoundException()
address = ContactAddress.model_validate(contact_model) return OrgContactGetResponse.model_construct(
contact_response = ContactModel.model_construct(**contact_model.__dict__, address=address) **contact_model.__dict__,
address=ContactAddress.model_validate(contact_model)
return {"contact": contact_response} )
@router.patch("/contact", response_model=OrgGetContactResponse) @router.patch("/contact", response_model=OrgContactGetResponse)
async def update_contact(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: OrgPatchContactRequest): async def update_contact(db: db_dependency, org_model: org_model_root_claim_body_dependency, request_model: OrgContactPatchRequest):
match request_model.contact_type: match request_model.contact_type:
case "billing": case "billing":
contact_model = org_model.billing_contact_rel contact_model = org_model.billing_contact_rel
@ -191,9 +185,11 @@ async def update_contact(db: db_dependency, org_model: org_model_root_claim_body
raise UnprocessableContent("Invalid keys in update request") raise UnprocessableContent("Invalid keys in update request")
db.flush() db.flush()
address = ContactAddress.model_validate(contact_model) response = OrgContactGetResponse.model_construct(
contact_response = ContactModel.model_construct(**contact_model.__dict__, address=address) **contact_model.__dict__,
address=ContactAddress.model_validate(contact_model)
)
db.commit() db.commit()
return {"contact": contact_response} return response

View file

@ -10,11 +10,8 @@ from typing import Optional
from pydantic import EmailStr, ConfigDict from pydantic import EmailStr, ConfigDict
from src.schemas import CustomBaseModel from src.schemas import CustomBaseModel
from src.contact.schemas import ContactModel
from src.user.schemas import UserIDMixin
from src.organisation.constants import Status, ContactType from src.organisation.constants import Status, ContactType
from src.contact.schemas import ContactAddress
class OrgQuestionnaire(CustomBaseModel): class OrgQuestionnaire(CustomBaseModel):
question_one: str question_one: str
@ -24,19 +21,18 @@ class OrgQuestionnaire(CustomBaseModel):
class OrgIDMixin(CustomBaseModel): class OrgIDMixin(CustomBaseModel):
organisation_id: int organisation_id: int
class OrgOrgPostRequest(CustomBaseModel):
class OrgPostOrgRequest(CustomBaseModel):
name: str name: str
intake_questionnaire: Optional[OrgQuestionnaire] = None intake_questionnaire: Optional[OrgQuestionnaire] = None
class OrgPatchQuestionnaireRequest(OrgIDMixin): class OrgQuestionnairePatchRequest(OrgIDMixin):
intake_questionnaire: OrgQuestionnaire intake_questionnaire: OrgQuestionnaire
partial: bool partial: bool
class OrgPatchStatusRequest(OrgIDMixin): class OrgStatusPatchRequest(OrgIDMixin):
status: Status status: Status
class OrgPatchContactRequest(OrgIDMixin): class OrgContactPatchRequest(OrgIDMixin):
contact_type: ContactType contact_type: ContactType
email: Optional[EmailStr] = None email: Optional[EmailStr] = None
@ -52,27 +48,33 @@ class OrgPatchContactRequest(OrgIDMixin):
country_code: Optional[str] = None country_code: Optional[str] = None
postal_code: Optional[str] = None postal_code: Optional[str] = None
class OrgPostUserRequest(OrgIDMixin, UserIDMixin): class OrgUserPostRequest(OrgIDMixin):
pass user_id: int
class OrgDeleteUserRequest(OrgIDMixin, UserIDMixin): class OrgUserDeleteRequest(OrgIDMixin):
pass user_id: int
class OrgPatchRootRequest(OrgIDMixin, UserIDMixin): class OrgRootPatchRequest(OrgIDMixin):
pass user_id: int
class OrgGetUserResponse(CustomBaseModel): class OrgUserGetResponse(CustomBaseModel):
users: list[str] users: list[str]
class OrgGetGroupResponse(CustomBaseModel): class OrgGroupGetResponse(CustomBaseModel):
groups: list[str] groups: list[str]
class OrgGetContactResponse(CustomBaseModel): class OrgContactGetResponse(CustomBaseModel):
model_config = ConfigDict(from_attributes=True, extra="ignore") model_config = ConfigDict(from_attributes=True, extra="ignore")
contact: ContactModel email: Optional[str] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
phonenumber: Optional[str] = None
vat_number: Optional[str] = None
class OrgGetOrgResponse(CustomBaseModel): address: ContactAddress
class OrgOrgGetResponse(CustomBaseModel):
name: str name: str
status: Status status: Status
root_user: Optional[str] = None root_user: Optional[str] = None

View file

@ -6,12 +6,10 @@ Endpoints:
- Endpoints: Description - Endpoints: Description
""" """
from fastapi import APIRouter, status from fastapi import APIRouter, status
from psycopg.errors import UniqueViolation
from sqlalchemy.exc import IntegrityError
from src.exceptions import Conflict
from src.database import db_dependency from src.database import db_dependency
from src.auth.dependencies import super_admin_dependency, org_model_root_claim_query_dependency from src.auth.service import claims_dependency
from src.auth.dependencies import super_admin_dependency, org_model_root_claim_query_dependency, org_model_root_claim_body_dependency
from src.service.models import Service from src.service.models import Service
from src.service.utils import generate_api_key from src.service.utils import generate_api_key
@ -36,12 +34,7 @@ async def register_service(db: db_dependency, su: super_admin_dependency, servic
service_model = Service(name=service_request.name, api_key=key) service_model = Service(name=service_request.name, api_key=key)
db.add(service_model) db.add(service_model)
try: db.flush()
db.flush()
except IntegrityError as e:
if isinstance(e.orig, UniqueViolation):
raise Conflict(message="Service with this name already exists")
db.commit()
response = ServiceWithKeyResponse(**service_model.__dict__) response = ServiceWithKeyResponse(**service_model.__dict__)
db.commit() db.commit()
return {"service": response} return {"service": response}