Compare commits
No commits in common. "0b521414b3defb45a423c5adcda449d33125ab23" and "294baadcb71ef5c409df297b57df4af1f265abd4" have entirely different histories.
0b521414b3
...
294baadcb7
5 changed files with 15 additions and 43 deletions
|
|
@ -27,12 +27,3 @@ class ConflictException(HTTPException):
|
||||||
status_code=status.HTTP_409_CONFLICT,
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
detail=detail,
|
detail=detail,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ForbiddenException(HTTPException):
|
|
||||||
def __init__(self, message: Optional[str] = None) -> None:
|
|
||||||
detail = "Forbidden" if not message else message
|
|
||||||
super().__init__(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=detail,
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ from src.iam.exceptions import GroupNotFoundException
|
||||||
from src.organisation.exceptions import OrgNotFoundException
|
from src.organisation.exceptions import OrgNotFoundException
|
||||||
from src.schemas import GroupSummary, OrgSummary
|
from src.schemas import GroupSummary, OrgSummary
|
||||||
from src.service.exceptions import ServiceNotFoundException
|
from src.service.exceptions import ServiceNotFoundException
|
||||||
from src.exceptions import ConflictException, ForbiddenException
|
from src.exceptions import ConflictException
|
||||||
from src.database import db_dependency
|
from src.database import db_dependency
|
||||||
from src.auth.exceptions import UnauthorizedException
|
from src.auth.exceptions import UnauthorizedException
|
||||||
from src.auth.service import claims_dependency
|
from src.auth.service import claims_dependency
|
||||||
|
|
@ -75,7 +75,7 @@ from src.iam.schemas import (
|
||||||
IAMPutGroupInvitationRequest,
|
IAMPutGroupInvitationRequest,
|
||||||
IAMPutGroupInvitationAcceptRequest,
|
IAMPutGroupInvitationAcceptRequest,
|
||||||
)
|
)
|
||||||
from src.utils import verify_email_token
|
from src.utils import decode_jwt
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
tags=["IAM"],
|
tags=["IAM"],
|
||||||
|
|
@ -211,11 +211,6 @@ async def add_group_user(
|
||||||
if user_model in group_model.user_rel:
|
if user_model in group_model.user_rel:
|
||||||
raise ConflictException("User already in group")
|
raise ConflictException("User already in group")
|
||||||
|
|
||||||
if user_model not in org_model.user_rel:
|
|
||||||
raise ForbiddenException(
|
|
||||||
"Adding users directly can only be done with org members. Use email invitation instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
group_model.user_rel.append(user_model)
|
group_model.user_rel.append(user_model)
|
||||||
db.flush()
|
db.flush()
|
||||||
response = IAMPutGroupUserResponse(
|
response = IAMPutGroupUserResponse(
|
||||||
|
|
@ -378,9 +373,11 @@ async def accept_invitation(
|
||||||
user_model: user_model_claims_dependency,
|
user_model: user_model_claims_dependency,
|
||||||
request_model: IAMPutGroupInvitationAcceptRequest,
|
request_model: IAMPutGroupInvitationAcceptRequest,
|
||||||
):
|
):
|
||||||
email_claims = await verify_email_token(
|
email_claims = await decode_jwt(request_model.jwt)
|
||||||
token=request_model.jwt, user_model=user_model
|
claimed_email = email_claims["email"]
|
||||||
)
|
|
||||||
|
if user_model.email != claimed_email:
|
||||||
|
raise UnauthorizedException("The logged in user and email do not match.")
|
||||||
|
|
||||||
org_model = db.get(Org, email_claims["org_id"])
|
org_model = db.get(Org, email_claims["org_id"])
|
||||||
if org_model is None:
|
if org_model is None:
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ Endpoints:
|
||||||
|
|
||||||
from fastapi import APIRouter, status, BackgroundTasks
|
from fastapi import APIRouter, status, BackgroundTasks
|
||||||
|
|
||||||
|
from src.auth.exceptions import UnauthorizedException
|
||||||
from src.organisation.exceptions import OrgNotFoundException
|
from src.organisation.exceptions import OrgNotFoundException
|
||||||
from src.user.schemas import (
|
from src.user.schemas import (
|
||||||
UserResponse,
|
UserResponse,
|
||||||
|
|
@ -31,7 +32,7 @@ from src.auth.dependencies import (
|
||||||
)
|
)
|
||||||
from src.auth.service import claims_dependency
|
from src.auth.service import claims_dependency
|
||||||
from src.database import db_dependency
|
from src.database import db_dependency
|
||||||
from src.utils import verify_email_token
|
from src.utils import decode_jwt
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/user",
|
prefix="/user",
|
||||||
|
|
@ -180,9 +181,11 @@ async def accept_invitation(
|
||||||
user_model: user_model_claims_dependency,
|
user_model: user_model_claims_dependency,
|
||||||
request_model: UserPostInvitationAcceptRequest,
|
request_model: UserPostInvitationAcceptRequest,
|
||||||
):
|
):
|
||||||
email_claims = await verify_email_token(
|
email_claims = await decode_jwt(request_model.jwt)
|
||||||
token=request_model.jwt, user_model=user_model
|
claimed_email = email_claims["email"]
|
||||||
)
|
|
||||||
|
if user_model.email != claimed_email:
|
||||||
|
raise UnauthorizedException("The logged in user and email do not match.")
|
||||||
|
|
||||||
org_model = db.get(Org, email_claims["org_id"])
|
org_model = db.get(Org, email_claims["org_id"])
|
||||||
if org_model is None:
|
if org_model is None:
|
||||||
|
|
|
||||||
17
src/utils.py
17
src/utils.py
|
|
@ -1,4 +1,3 @@
|
||||||
from datetime import datetime, timezone
|
|
||||||
from joserfc import jwt, jwk, errors
|
from joserfc import jwt, jwk, errors
|
||||||
|
|
||||||
from src.auth.exceptions import UnauthorizedException
|
from src.auth.exceptions import UnauthorizedException
|
||||||
|
|
@ -22,22 +21,6 @@ async def decode_jwt(encoded):
|
||||||
raise UnauthorizedException("Invalid JWS")
|
raise UnauthorizedException("Invalid JWS")
|
||||||
|
|
||||||
|
|
||||||
async def verify_email_token(user_model, token):
|
|
||||||
email_claims = await decode_jwt(token)
|
|
||||||
|
|
||||||
claimed_email = email_claims["email"]
|
|
||||||
|
|
||||||
expiry = datetime.fromtimestamp(email_claims["exp"], timezone.utc)
|
|
||||||
|
|
||||||
if expiry < datetime.now(timezone.utc):
|
|
||||||
raise UnauthorizedException("Invitation expired.")
|
|
||||||
|
|
||||||
if user_model.email != claimed_email:
|
|
||||||
raise UnauthorizedException("The logged in user and email do not match.")
|
|
||||||
|
|
||||||
return email_claims
|
|
||||||
|
|
||||||
|
|
||||||
async def send_email(recipient: str, subject: str, body: str):
|
async def send_email(recipient: str, subject: str, body: str):
|
||||||
print(recipient)
|
print(recipient)
|
||||||
print(subject)
|
print(subject)
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import pytest
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
|
||||||
from src.user.models import User
|
from src.user.models import User
|
||||||
from src.organisation.models import Organisation as Org, OrgUsers
|
from src.organisation.models import Organisation as Org
|
||||||
from src.iam.models import Group
|
from src.iam.models import Group
|
||||||
|
|
||||||
from .conftest import generate_query_and_status
|
from .conftest import generate_query_and_status
|
||||||
|
|
@ -468,8 +468,6 @@ async def test_put_group_user_success(default_client: AsyncClient, db_session):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
db_session.flush()
|
db_session.flush()
|
||||||
db_session.add(OrgUsers(user_id=2, org_id=1))
|
|
||||||
db_session.flush()
|
|
||||||
|
|
||||||
resp = await default_client.put(
|
resp = await default_client.put(
|
||||||
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1}
|
"/iam/group/user", json={"user_id": 2, "group_id": 1, "organisation_id": 1}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue