feat: sensical user invitation
Some checks failed
ci / lint_and_test (push) Failing after 8s

Users can now be invited to an org by email.

"Email" for now is "print to stdout"

Resolves #12
This commit is contained in:
Chris Milne 2026-06-09 12:22:36 +01:00
parent 1012947b67
commit 62c43ce883
5 changed files with 173 additions and 5 deletions

View file

@ -8,18 +8,30 @@ Endpoints:
- [DELETE](/user/): [super admin]: Removes a User(id) from the hub database.
"""
from fastapi import APIRouter
from starlette import status
from fastapi import APIRouter, status, BackgroundTasks
from src.user.schemas import UserResponse, OIDCClaims
from auth.exceptions import UnauthorizedException
from organisation.exceptions import OrgNotFoundException
from src.user.schemas import (
UserResponse,
OIDCClaims,
UserPostInvitationRequest,
UserPostInvitationAcceptRequest,
)
from src.user.dependencies import (
user_model_claims_dependency,
user_model_query_dependency,
)
from src.user.service import send_invitation
from src.organisation.models import Organisation as Org
from src.auth.dependencies import super_admin_dependency
from src.auth.dependencies import (
super_admin_dependency,
org_model_root_claim_body_dependency,
)
from src.auth.service import claims_dependency
from src.database import db_dependency
from src.utils import decode_jwt
router = APIRouter(
prefix="/user",
@ -99,3 +111,50 @@ async def delete_user_by_id(
"""
db.delete(user_model)
db.commit()
@router.post(
"/invitation",
summary="Send an email invitation for a user to join an org",
status_code=status.HTTP_200_OK,
)
async def invitation(
background_tasks: BackgroundTasks,
org_model: org_model_root_claim_body_dependency,
request_model: UserPostInvitationRequest,
):
org_id = org_model.id
org_name = org_model.name
user_email = request_model.user_email
background_tasks.add_task(
send_invitation, org_id=org_id, org_name=org_name, user_email=user_email
)
return "Invitation sent"
@router.post(
"/invitation/accept",
summary="Accept email invitation to join an org",
status_code=status.HTTP_200_OK,
)
async def accept_invitation(
db: db_dependency,
user_model: user_model_claims_dependency,
request_model: UserPostInvitationAcceptRequest,
):
email_claims = await decode_jwt(request_model.jwt)
claimed_email = email_claims["user_email"]
if user_model.email != claimed_email:
raise UnauthorizedException("The logged in user and email do not match.")
org_model = db.get(Org, email_claims["org_id"])
if org_model is None:
raise OrgNotFoundException()
org_model.user_rel.append(user_model)
db.commit()
return "Invitation accepted"

View file

@ -3,7 +3,7 @@ Pydantic models for the user module
"""
from typing import Optional
from pydantic import Field
from pydantic import Field, EmailStr
from src.schemas import CustomBaseModel
@ -55,3 +55,12 @@ class UserResponse(CustomBaseModel):
class OrgResponse(CustomBaseModel):
org_id: int
name: str
class UserPostInvitationRequest(CustomBaseModel):
organisation_id: int
user_email: EmailStr
class UserPostInvitationAcceptRequest(CustomBaseModel):
jwt: str

View file

@ -6,10 +6,12 @@ Exports:
"""
from typing import Any
from datetime import datetime, timedelta, timezone
from sqlalchemy.orm import Session
from src.exceptions import UnprocessableContentException
from src.utils import send_email, generate_jwt
from src.user.schemas import OIDCUser
from src.user.models import User
@ -48,3 +50,24 @@ async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int:
db.add(db_user)
db.commit()
return user_id
async def send_invitation(user_email: str, org_name: str, org_id: int):
expiry_delta = timedelta(hours=24)
expiry = datetime.now(timezone.utc) + expiry_delta
claims = {
"email": user_email,
"org_id": org_id,
"exp": expiry,
"type": "org_invite",
}
token = await generate_jwt(claims)
subject = f"You have been invited to join {org_name}"
body = f"You have been invited to join {org_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
await send_email(
recipient=user_email,
subject=subject,
body=body,
)

27
src/utils.py Normal file
View file

@ -0,0 +1,27 @@
from joserfc import jwt, jwk, errors
from auth.exceptions import UnauthorizedException
from src.config import settings
KEY = jwk.import_key(settings.SECRET_KEY.get_secret_value(), "oct")
async def generate_jwt(claims):
jwt_token = jwt.encode(header={"alg": "HS256"}, key=KEY, claims=claims)
return jwt_token
async def decode_jwt(encoded):
try:
token = jwt.decode(encoded, key=KEY)
return token.claims
except errors.DecodeError:
raise UnauthorizedException("Invalid JWS")
async def send_email(recipient: str, subject: str, body: str):
print(recipient)
print(subject)
print(body)

View file

@ -52,3 +52,53 @@ async def test_delete_user_success(default_client: AsyncClient):
resp = await default_client.delete("/user/?user_id=1")
assert resp.status_code == 204
@pytest.mark.anyio
async def test_post_user_invitation_success(default_client: AsyncClient):
body = {"user_email": "admin@test.com", "organisation_id": 1}
resp = await default_client.post("/user/invitation", json=body)
assert resp.status_code == 200
assert resp.json() == "Invitation sent"
@pytest.mark.parametrize(
"body, expected_status",
[
({"organisation_id": 42, "user_email": "admin@test.com"}, 404),
({"organisation_id": "Test Org", "user_email": "admin@test.com"}, 422),
({"organisation_id": "", "user_email": "admin@test.com"}, 422),
({}, 422),
({"user_email": 42}, 422),
({"organisation_id": 1, "user_email": "Test User"}, 422),
],
)
@pytest.mark.anyio
async def test_post_user_invitation_status_checks(
default_client: AsyncClient, body, expected_status
):
resp = await default_client.post("/user/invitation", json=body)
assert resp.status_code == expected_status
@pytest.mark.parametrize(
"body, expected_status",
[
({"jwt": "invalid"}, 401),
({"jwt": ""}, 401),
({"jwt": None}, 422),
({"jwt": 42}, 422),
],
)
@pytest.mark.anyio
async def test_post_user_invitation_accept_status_checks(
default_client: AsyncClient, body, expected_status
):
resp = await default_client.post("/user/invitation/accept", json=body)
assert resp.status_code == expected_status
if resp.status_code == 401:
assert resp.json()["detail"] == "Invalid JWS"