"""Authentication-mode configuration and trusted-header identity parsing."""

from __future__ import annotations

import os
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Literal, cast

AUTH_MODE_DISABLED = "disabled"
AUTH_MODE_TRUSTED_HEADERS = "trusted-headers"
# Environment variable consulted by load_auth_mode().
AUTH_MODE_ENV = "REPUBLISHER_AUTH_MODE"

AuthMode = Literal["disabled", "trusted-headers"]
AuthRole = Literal["admin", "publisher"]

# Request headers read by load_trusted_identity().
ROLE_HEADER = "X-Republisher-Auth-Role"
PROVIDER_HEADER = "X-Republisher-Auth-Provider"
USER_HEADER = "X-Republisher-Auth-User"
EMAIL_HEADER = "X-Republisher-Auth-Email"
PREFERRED_USERNAME_HEADER = "X-Republisher-Auth-Preferred-Username"
GROUPS_HEADER = "X-Republisher-Auth-Groups"

# Runtime counterpart of AuthRole; keep the two in sync.
VALID_ROLES = frozenset({"admin", "publisher"})


@dataclass(frozen=True)
class TrustedIdentity:
    """Immutable identity assembled from trusted authentication headers."""

    role: AuthRole
    provider: str
    user: str
    email: str
    # Equals ``user`` when the preferred-username header is absent or blank.
    preferred_username: str
    # Comma-separated groups header: parts stripped, empty parts dropped,
    # original order kept.
    groups: tuple[str, ...]


def load_auth_mode(environ: Mapping[str, str] | None = None) -> AuthMode:
    """Return the configured authentication mode.

    Args:
        environ: Mapping to read ``REPUBLISHER_AUTH_MODE`` from. ``None``
            (the default) means ``os.environ``; any other mapping --
            including an empty one -- is used as-is.

    Returns:
        ``"disabled"`` when the variable is unset, otherwise its
        whitespace-stripped value.

    Raises:
        ValueError: If the (stripped) value is not a supported mode.
    """
    # Compare against None rather than relying on truthiness: an explicitly
    # supplied empty mapping must not silently fall back to the process
    # environment.
    source = os.environ if environ is None else environ
    raw_mode = source.get(AUTH_MODE_ENV, AUTH_MODE_DISABLED).strip()
    if raw_mode in {AUTH_MODE_DISABLED, AUTH_MODE_TRUSTED_HEADERS}:
        return cast(AuthMode, raw_mode)
    raise ValueError(
        f"Unsupported {AUTH_MODE_ENV}: {raw_mode!r}. "
        f"Expected {AUTH_MODE_DISABLED!r} or {AUTH_MODE_TRUSTED_HEADERS!r}."
    )


def load_trusted_identity(headers: Mapping[str, str]) -> TrustedIdentity | None:
    """Build a TrustedIdentity from request headers, or return None.

    All header values are whitespace-stripped; a blank value counts as
    missing.

    Returns:
        None if the role header is missing, blank, or not in VALID_ROLES, or
        if any of the provider, user, or email headers is missing or blank.
        Otherwise, a TrustedIdentity built from those headers.

    NOTE(review): lookups use the exact header names defined above. HTTP
    header names are case-insensitive, so confirm that callers pass a
    case-insensitive mapping rather than a plain dict.
    """
    role = _read_header(headers, ROLE_HEADER)
    # None (missing/blank) is never in VALID_ROLES, so this also rejects it.
    if role not in VALID_ROLES:
        return None

    provider = _read_header(headers, PROVIDER_HEADER)
    user = _read_header(headers, USER_HEADER)
    email = _read_header(headers, EMAIL_HEADER)
    if provider is None or user is None or email is None:
        return None

    preferred_username = _read_header(headers, PREFERRED_USERNAME_HEADER) or user
    return TrustedIdentity(
        role=cast(AuthRole, role),
        provider=provider,
        user=user,
        email=email,
        preferred_username=preferred_username,
        # Route through _read_header like every other field, so that a
        # missing header yields an empty groups tuple.
        groups=_read_groups(_read_header(headers, GROUPS_HEADER) or ""),
    )


def _read_header(headers: Mapping[str, str], name: str) -> str | None:
    """Return the stripped value of header *name*, or None if absent/blank."""
    value = headers.get(name)
    if value is None:
        return None
    stripped = value.strip()
    return stripped or None


def _read_groups(value: str) -> tuple[str, ...]:
    """Split a comma-separated list, stripping parts and dropping empties."""
    return tuple(
        group for group in (part.strip() for part in value.split(",")) if group
    )