From 62c43ce88324b3937823ba5e320c5cb0f77a53b7 Mon Sep 17 00:00:00 2001 From: luxferre Date: Tue, 9 Jun 2026 12:22:36 +0100 Subject: [PATCH] feat: sensical user invitation Users can now be invited to an org by email. "Email" for now is "print to stdout" Resolves #12 --- src/user/router.py | 67 ++++++++++++++++++++++++++++++++++++++++++--- src/user/schemas.py | 11 +++++++- src/user/service.py | 23 ++++++++++++++++ src/utils.py | 27 ++++++++++++++++++ test/test_user.py | 50 +++++++++++++++++++++++++++++++++ 5 files changed, 173 insertions(+), 5 deletions(-) create mode 100644 src/utils.py diff --git a/src/user/router.py b/src/user/router.py index b2effc2..64d4de8 100644 --- a/src/user/router.py +++ b/src/user/router.py @@ -8,18 +8,30 @@ Endpoints: - [DELETE](/user/): [super admin]: Removes a User(id) from the hub database. """ -from fastapi import APIRouter -from starlette import status +from fastapi import APIRouter, status, BackgroundTasks -from src.user.schemas import UserResponse, OIDCClaims +from auth.exceptions import UnauthorizedException +from organisation.exceptions import OrgNotFoundException +from src.user.schemas import ( + UserResponse, + OIDCClaims, + UserPostInvitationRequest, + UserPostInvitationAcceptRequest, +) from src.user.dependencies import ( user_model_claims_dependency, user_model_query_dependency, ) +from src.user.service import send_invitation +from src.organisation.models import Organisation as Org -from src.auth.dependencies import super_admin_dependency +from src.auth.dependencies import ( + super_admin_dependency, + org_model_root_claim_body_dependency, +) from src.auth.service import claims_dependency from src.database import db_dependency +from src.utils import decode_jwt router = APIRouter( prefix="/user", @@ -99,3 +111,50 @@ async def delete_user_by_id( """ db.delete(user_model) db.commit() + + +@router.post( + "/invitation", + summary="Send an email invitation for a user to join an org", + status_code=status.HTTP_200_OK, +) +async def invitation( + background_tasks: BackgroundTasks, + org_model: org_model_root_claim_body_dependency, + request_model: UserPostInvitationRequest, +): + org_id = org_model.id + org_name = org_model.name + user_email = request_model.user_email + + background_tasks.add_task( + send_invitation, org_id=org_id, org_name=org_name, user_email=user_email + ) + + return "Invitation sent" + + +@router.post( + "/invitation/accept", + summary="Accept email invitation to join an org", + status_code=status.HTTP_200_OK, +) +async def accept_invitation( + db: db_dependency, + user_model: user_model_claims_dependency, + request_model: UserPostInvitationAcceptRequest, +): + email_claims = await decode_jwt(request_model.jwt) + claimed_email = email_claims["user_email"] + + if user_model.email != claimed_email: + raise UnauthorizedException("The logged in user and email do not match.") + + org_model = db.get(Org, email_claims["org_id"]) + if org_model is None: + raise OrgNotFoundException() + + org_model.user_rel.append(user_model) + db.commit() + + return "Invitation accepted" diff --git a/src/user/schemas.py b/src/user/schemas.py index 8ef46df..74b58dd 100644 --- a/src/user/schemas.py +++ b/src/user/schemas.py @@ -3,7 +3,7 @@ Pydantic models for the user module """ from typing import Optional -from pydantic import Field +from pydantic import Field, EmailStr from src.schemas import CustomBaseModel @@ -55,3 +55,12 @@ class UserResponse(CustomBaseModel): class OrgResponse(CustomBaseModel): org_id: int name: str + + +class UserPostInvitationRequest(CustomBaseModel): + organisation_id: int + user_email: EmailStr + + +class UserPostInvitationAcceptRequest(CustomBaseModel): + jwt: str diff --git a/src/user/service.py b/src/user/service.py index 49ab238..ff1da8b 100644 --- a/src/user/service.py +++ b/src/user/service.py @@ -6,10 +6,12 @@ Exports: """ from typing import Any +from datetime import datetime, timedelta, timezone from sqlalchemy.orm import Session from src.exceptions import UnprocessableContentException +from src.utils import send_email, generate_jwt from src.user.schemas import OIDCUser from src.user.models import User @@ -48,3 +50,24 @@ async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int: db.add(db_user) db.commit() return user_id + + +async def send_invitation(user_email: str, org_name: str, org_id: int): + expiry_delta = timedelta(hours=24) + expiry = datetime.now(timezone.utc) + expiry_delta + claims = { + "email": user_email, + "org_id": org_id, + "exp": expiry, + "type": "org_invite", + } + + token = await generate_jwt(claims) + subject = f"You have been invited to join {org_name}" + body = f"You have been invited to join {org_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}" + + await send_email( + recipient=user_email, + subject=subject, + body=body, + ) diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..ebf3ec1 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,27 @@ +from joserfc import jwt, jwk, errors + +from auth.exceptions import UnauthorizedException +from src.config import settings + + +KEY = jwk.import_key(settings.SECRET_KEY.get_secret_value(), "oct") + + +async def generate_jwt(claims): + jwt_token = jwt.encode(header={"alg": "HS256"}, key=KEY, claims=claims) + + return jwt_token + + +async def decode_jwt(encoded): + try: + token = jwt.decode(encoded, key=KEY) + return token.claims + except errors.DecodeError: + raise UnauthorizedException("Invalid JWS") + + +async def send_email(recipient: str, subject: str, body: str): + print(recipient) + print(subject) + print(body) diff --git a/test/test_user.py b/test/test_user.py index 4eadc3c..0266c7a 100644 --- a/test/test_user.py +++ b/test/test_user.py @@ -52,3 +52,53 @@ async def test_delete_user_success(default_client: AsyncClient): resp = await default_client.delete("/user/?user_id=1") assert resp.status_code == 204 + + +@pytest.mark.anyio +async def test_post_user_invitation_success(default_client: AsyncClient): + body = {"user_email": "admin@test.com", "organisation_id": 1} + resp = await default_client.post("/user/invitation", json=body) + + assert resp.status_code == 200 + assert resp.json() == "Invitation sent" + + +@pytest.mark.parametrize( + "body, expected_status", + [ + ({"organisation_id": 42, "user_email": "admin@test.com"}, 404), + ({"organisation_id": "Test Org", "user_email": "admin@test.com"}, 422), + ({"organisation_id": "", "user_email": "admin@test.com"}, 422), + ({}, 422), + ({"user_email": 42}, 422), + ({"organisation_id": 1, "user_email": "Test User"}, 422), + ], +) +@pytest.mark.anyio +async def test_post_user_invitation_status_checks( + default_client: AsyncClient, body, expected_status +): + resp = await default_client.post("/user/invitation", json=body) + + assert resp.status_code == expected_status + + +@pytest.mark.parametrize( + "body, expected_status", + [ + ({"jwt": "invalid"}, 401), + ({"jwt": ""}, 401), + ({"jwt": None}, 422), + ({"jwt": 42}, 422), + ], +) +@pytest.mark.anyio +async def test_post_user_invitation_accept_status_checks( + default_client: AsyncClient, body, expected_status +): + resp = await default_client.post("/user/invitation/accept", json=body) + + assert resp.status_code == expected_status + + if resp.status_code == 401: + assert resp.json()["detail"] == "Invalid JWS"