77 lines
2.3 KiB
Python
77 lines
2.3 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
from collections.abc import Mapping
|
|
from dataclasses import dataclass
|
|
from typing import Literal, cast
|
|
|
|
# --- Authentication mode ----------------------------------------------------
# Environment variable consulted by load_auth_mode(), and the two values it
# accepts.
AUTH_MODE_ENV = "REPUBLISHER_AUTH_MODE"
AUTH_MODE_DISABLED = "disabled"
AUTH_MODE_TRUSTED_HEADERS = "trusted-headers"

# --- Static typing aliases --------------------------------------------------
# Literal types spelling out the accepted mode and role strings.
AuthMode = Literal["disabled", "trusted-headers"]
AuthRole = Literal["admin", "publisher"]

# Runtime set holding the same strings as AuthRole, for membership checks.
VALID_ROLES = frozenset(("admin", "publisher"))

# --- Request headers read by load_trusted_identity() ------------------------
ROLE_HEADER = "X-Republisher-Auth-Role"
PROVIDER_HEADER = "X-Republisher-Auth-Provider"
USER_HEADER = "X-Republisher-Auth-User"
EMAIL_HEADER = "X-Republisher-Auth-Email"
PREFERRED_USERNAME_HEADER = "X-Republisher-Auth-Preferred-Username"
GROUPS_HEADER = "X-Republisher-Auth-Groups"
|
|
|
|
|
|
@dataclass(frozen=True)
class TrustedIdentity:
    """Immutable bundle of the identity values read from request headers.

    Instances are produced by ``load_trusted_identity``. Each field holds
    the whitespace-stripped value of the corresponding
    ``X-Republisher-Auth-*`` header.
    """

    # Value of ROLE_HEADER; load_trusted_identity only builds an instance
    # when this is a member of VALID_ROLES.
    role: AuthRole
    # Value of PROVIDER_HEADER.
    provider: str
    # Value of USER_HEADER.
    user: str
    # Value of EMAIL_HEADER.
    email: str
    # Value of PREFERRED_USERNAME_HEADER, or a copy of ``user`` when that
    # header is missing or blank.
    preferred_username: str
    # Non-empty entries of the comma-separated GROUPS_HEADER, in header order.
    groups: tuple[str, ...]
|
|
|
|
|
|
def load_auth_mode(environ: Mapping[str, str] | None = None) -> AuthMode:
    """Return the authentication mode selected by ``AUTH_MODE_ENV``.

    Args:
        environ: Mapping to read the variable from. ``None`` (the default)
            means the process environment, ``os.environ``. Any mapping that
            is passed explicitly, including an empty one, is used as given.

    Returns:
        ``AUTH_MODE_DISABLED`` when the variable is absent, otherwise the
        variable's value with surrounding whitespace removed.

    Raises:
        ValueError: If the stripped value is neither ``AUTH_MODE_DISABLED``
            nor ``AUTH_MODE_TRUSTED_HEADERS`` (this includes a value that is
            empty or whitespace-only).
    """
    # Test for None explicitly instead of relying on truthiness: an empty
    # mapping is falsy, so `environ or os.environ` would ignore the mapping
    # the caller supplied and read the real process environment instead.
    source = os.environ if environ is None else environ
    raw_mode = source.get(AUTH_MODE_ENV, AUTH_MODE_DISABLED).strip()
    if raw_mode in {AUTH_MODE_DISABLED, AUTH_MODE_TRUSTED_HEADERS}:
        return cast(AuthMode, raw_mode)
    raise ValueError(
        f"Unsupported {AUTH_MODE_ENV}: {raw_mode!r}. "
        f"Expected {AUTH_MODE_DISABLED!r} or {AUTH_MODE_TRUSTED_HEADERS!r}."
    )
|
|
|
|
|
|
def load_trusted_identity(headers: Mapping[str, str]) -> TrustedIdentity | None:
    """Build a ``TrustedIdentity`` from the ``X-Republisher-Auth-*`` headers.

    Args:
        headers: Mapping of header name to header value. Names are looked up
            exactly as spelled in the ``*_HEADER`` constants.
            NOTE(review): whether lookups are case-insensitive depends on
            the mapping type the caller passes; confirm at the call site.

    Returns:
        ``None`` when the role header is missing, blank, or not a member of
        ``VALID_ROLES``, or when any of the provider, user or email headers
        is missing or blank. Otherwise a ``TrustedIdentity`` whose
        ``preferred_username`` falls back to ``user`` when its own header is
        missing or blank, and whose ``groups`` come from the comma-separated
        groups header (empty tuple when the header is absent).
    """
    claimed_role = _read_header(headers, ROLE_HEADER)
    if claimed_role not in VALID_ROLES:
        return None

    provider_name = _read_header(headers, PROVIDER_HEADER)
    user_id = _read_header(headers, USER_HEADER)
    email_address = _read_header(headers, EMAIL_HEADER)

    if (
        provider_name is not None
        and user_id is not None
        and email_address is not None
    ):
        # The preferred username is optional; reuse the user value without it.
        display_name = _read_header(headers, PREFERRED_USERNAME_HEADER)
        if not display_name:
            display_name = user_id

        raw_groups = headers.get(GROUPS_HEADER, "")
        return TrustedIdentity(
            role=cast(AuthRole, claimed_role),
            provider=provider_name,
            user=user_id,
            email=email_address,
            preferred_username=display_name,
            groups=_read_groups(raw_groups),
        )

    # At least one mandatory header was missing or blank.
    return None
|
|
|
|
|
|
def _read_header(headers: Mapping[str, str], name: str) -> str | None:
    """Return the stripped value of header ``name``, or ``None`` if unusable.

    Args:
        headers: Mapping of header name to header value.
        name: Key to look up in ``headers``.

    Returns:
        The value with leading and trailing whitespace removed, or ``None``
        when the key is absent or the value is empty after stripping.
    """
    raw = headers.get(name)
    if raw is not None:
        trimmed = raw.strip()
        if trimmed:
            return trimmed
    # Absent and blank headers are reported the same way.
    return None
|
|
|
|
|
|
def _read_groups(value: str) -> tuple[str, ...]:
    """Split a comma-separated string into a tuple of trimmed, non-empty items.

    Args:
        value: Raw comma-separated text.

    Returns:
        The items in their original order with surrounding whitespace
        removed; entries that are empty after stripping are dropped, and
        duplicates are kept.
    """
    trimmed_parts = map(str.strip, value.split(","))
    # filter(None, ...) keeps only truthy items, i.e. non-empty strings.
    return tuple(filter(None, trimmed_parts))
|