Users can now be invited to an org by email. "Email" for now is "print to stdout" Resolves #12
This commit is contained in:
parent
1012947b67
commit
62c43ce883
5 changed files with 173 additions and 5 deletions
|
|
@ -8,18 +8,30 @@ Endpoints:
|
|||
- [DELETE](/user/): [super admin]: Removes a User(id) from the hub database.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from starlette import status
|
||||
from fastapi import APIRouter, status, BackgroundTasks
|
||||
|
||||
from src.user.schemas import UserResponse, OIDCClaims
|
||||
from auth.exceptions import UnauthorizedException
|
||||
from organisation.exceptions import OrgNotFoundException
|
||||
from src.user.schemas import (
|
||||
UserResponse,
|
||||
OIDCClaims,
|
||||
UserPostInvitationRequest,
|
||||
UserPostInvitationAcceptRequest,
|
||||
)
|
||||
from src.user.dependencies import (
|
||||
user_model_claims_dependency,
|
||||
user_model_query_dependency,
|
||||
)
|
||||
from src.user.service import send_invitation
|
||||
from src.organisation.models import Organisation as Org
|
||||
|
||||
from src.auth.dependencies import super_admin_dependency
|
||||
from src.auth.dependencies import (
|
||||
super_admin_dependency,
|
||||
org_model_root_claim_body_dependency,
|
||||
)
|
||||
from src.auth.service import claims_dependency
|
||||
from src.database import db_dependency
|
||||
from src.utils import decode_jwt
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/user",
|
||||
|
|
@ -99,3 +111,50 @@ async def delete_user_by_id(
|
|||
"""
|
||||
db.delete(user_model)
|
||||
db.commit()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invitation",
|
||||
summary="Send an email invitation for a user to join an org",
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
async def invitation(
|
||||
background_tasks: BackgroundTasks,
|
||||
org_model: org_model_root_claim_body_dependency,
|
||||
request_model: UserPostInvitationRequest,
|
||||
):
|
||||
org_id = org_model.id
|
||||
org_name = org_model.name
|
||||
user_email = request_model.user_email
|
||||
|
||||
background_tasks.add_task(
|
||||
send_invitation, org_id=org_id, org_name=org_name, user_email=user_email
|
||||
)
|
||||
|
||||
return "Invitation sent"
|
||||
|
||||
|
||||
@router.post(
|
||||
"/invitation/accept",
|
||||
summary="Accept email invitation to join an org",
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
async def accept_invitation(
|
||||
db: db_dependency,
|
||||
user_model: user_model_claims_dependency,
|
||||
request_model: UserPostInvitationAcceptRequest,
|
||||
):
|
||||
email_claims = await decode_jwt(request_model.jwt)
|
||||
claimed_email = email_claims["user_email"]
|
||||
|
||||
if user_model.email != claimed_email:
|
||||
raise UnauthorizedException("The logged in user and email do not match.")
|
||||
|
||||
org_model = db.get(Org, email_claims["org_id"])
|
||||
if org_model is None:
|
||||
raise OrgNotFoundException()
|
||||
|
||||
org_model.user_rel.append(user_model)
|
||||
db.commit()
|
||||
|
||||
return "Invitation accepted"
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ Pydantic models for the user module
|
|||
"""
|
||||
|
||||
from typing import Optional
|
||||
from pydantic import Field
|
||||
from pydantic import Field, EmailStr
|
||||
|
||||
from src.schemas import CustomBaseModel
|
||||
|
||||
|
|
@ -55,3 +55,12 @@ class UserResponse(CustomBaseModel):
|
|||
class OrgResponse(CustomBaseModel):
|
||||
org_id: int
|
||||
name: str
|
||||
|
||||
|
||||
class UserPostInvitationRequest(CustomBaseModel):
|
||||
organisation_id: int
|
||||
user_email: EmailStr
|
||||
|
||||
|
||||
class UserPostInvitationAcceptRequest(CustomBaseModel):
|
||||
jwt: str
|
||||
|
|
|
|||
|
|
@ -6,10 +6,12 @@ Exports:
|
|||
"""
|
||||
|
||||
from typing import Any
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from src.exceptions import UnprocessableContentException
|
||||
from src.utils import send_email, generate_jwt
|
||||
|
||||
from src.user.schemas import OIDCUser
|
||||
from src.user.models import User
|
||||
|
|
@ -48,3 +50,24 @@ async def add_user_to_db(db: Session, user_claims: dict[str, Any]) -> int:
|
|||
db.add(db_user)
|
||||
db.commit()
|
||||
return user_id
|
||||
|
||||
|
||||
async def send_invitation(user_email: str, org_name: str, org_id: int):
|
||||
expiry_delta = timedelta(hours=24)
|
||||
expiry = datetime.now(timezone.utc) + expiry_delta
|
||||
claims = {
|
||||
"email": user_email,
|
||||
"org_id": org_id,
|
||||
"exp": expiry,
|
||||
"type": "org_invite",
|
||||
}
|
||||
|
||||
token = await generate_jwt(claims)
|
||||
subject = f"You have been invited to join {org_name}"
|
||||
body = f"You have been invited to join {org_name}.\nClick the link to accept.\nfrontend.capture/send/to/endpoint/{token}"
|
||||
|
||||
await send_email(
|
||||
recipient=user_email,
|
||||
subject=subject,
|
||||
body=body,
|
||||
)
|
||||
|
|
|
|||
27
src/utils.py
Normal file
27
src/utils.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
from joserfc import jwt, jwk, errors
|
||||
|
||||
from auth.exceptions import UnauthorizedException
|
||||
from src.config import settings
|
||||
|
||||
|
||||
KEY = jwk.import_key(settings.SECRET_KEY.get_secret_value(), "oct")
|
||||
|
||||
|
||||
async def generate_jwt(claims):
|
||||
jwt_token = jwt.encode(header={"alg": "HS256"}, key=KEY, claims=claims)
|
||||
|
||||
return jwt_token
|
||||
|
||||
|
||||
async def decode_jwt(encoded):
|
||||
try:
|
||||
token = jwt.decode(encoded, key=KEY)
|
||||
return token.claims
|
||||
except errors.DecodeError:
|
||||
raise UnauthorizedException("Invalid JWS")
|
||||
|
||||
|
||||
async def send_email(recipient: str, subject: str, body: str):
|
||||
print(recipient)
|
||||
print(subject)
|
||||
print(body)
|
||||
|
|
@ -52,3 +52,53 @@ async def test_delete_user_success(default_client: AsyncClient):
|
|||
resp = await default_client.delete("/user/?user_id=1")
|
||||
|
||||
assert resp.status_code == 204
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_post_user_invitation_success(default_client: AsyncClient):
|
||||
body = {"user_email": "admin@test.com", "organisation_id": 1}
|
||||
resp = await default_client.post("/user/invitation", json=body)
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == "Invitation sent"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"body, expected_status",
|
||||
[
|
||||
({"organisation_id": 42, "user_email": "admin@test.com"}, 404),
|
||||
({"organisation_id": "Test Org", "user_email": "admin@test.com"}, 422),
|
||||
({"organisation_id": "", "user_email": "admin@test.com"}, 422),
|
||||
({}, 422),
|
||||
({"user_email": 42}, 422),
|
||||
({"organisation_id": 1, "user_email": "Test User"}, 422),
|
||||
],
|
||||
)
|
||||
@pytest.mark.anyio
|
||||
async def test_post_user_invitation_status_checks(
|
||||
default_client: AsyncClient, body, expected_status
|
||||
):
|
||||
resp = await default_client.post("/user/invitation", json=body)
|
||||
|
||||
assert resp.status_code == expected_status
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"body, expected_status",
|
||||
[
|
||||
({"jwt": "invalid"}, 401),
|
||||
({"jwt": ""}, 401),
|
||||
({"jwt": None}, 422),
|
||||
({"jwt": 42}, 422),
|
||||
],
|
||||
)
|
||||
@pytest.mark.anyio
|
||||
async def test_post_user_invitation_accept_status_checks(
|
||||
default_client: AsyncClient, body, expected_status
|
||||
):
|
||||
resp = await default_client.post("/user/invitation/accept", json=body)
|
||||
|
||||
assert resp.status_code == expected_status
|
||||
|
||||
if resp.status_code == 401:
|
||||
assert resp.json()["detail"] == "Invalid JWS"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue