from __future__ import annotations

import http.client
import json
import random
import socket
import time
from dataclasses import dataclass
from typing import Any
from urllib.parse import urlparse

# Upper bound on the backoff exponent: keeps ``base_seconds * 2**exponent``
# representable as a float no matter how large ``attempt`` grows.
_MAX_BACKOFF_EXPONENT = 62


@dataclass(frozen=True)
class RetryPolicy:
    """Retry and backoff settings used by :class:`DaemonClient`.

    Attributes:
        max_attempts: Total attempts per request, including the first one.
        base_seconds: Jitter ceiling before the first retry; the ceiling
            doubles on each further retry.
        max_seconds: Hard cap on the jitter ceiling.

    Raises:
        ValueError: If ``max_attempts < 1`` or either delay is negative.
    """

    max_attempts: int
    base_seconds: float
    max_seconds: float

    def __post_init__(self) -> None:
        # Fail at construction time: a zero attempt budget previously surfaced
        # as an AssertionError on the first request, and negative delays made
        # time.sleep() raise in the middle of a retry loop.
        if self.max_attempts < 1:
            raise ValueError("max_attempts must be at least 1")
        if self.base_seconds < 0 or self.max_seconds < 0:
            raise ValueError("base_seconds and max_seconds must be non-negative")


class DaemonError(RuntimeError):
    """Raised when a request to the daemon fails.

    Attributes:
        path: Request path as passed by the caller.
        status: HTTP status code, or None when no usable response was received.
        response: JSON object decoded from the response body, or None when the
            body was not a JSON object.
        cause: Underlying transport or decoding exception, if any.
    """

    def __init__(
        self,
        message: str,
        *,
        path: str,
        status: int | None = None,
        response: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        super().__init__(message)
        self.path = path
        self.status = status
        self.response = response
        self.cause = cause


class UnixSocketHTTPConnection(http.client.HTTPConnection):
    """HTTP/1.1 connection that talks over a Unix domain socket instead of TCP."""

    def __init__(self, socket_path: str, timeout: float) -> None:
        # host/port are placeholders required by HTTPConnection; connect()
        # ignores them and dials the socket path instead.
        super().__init__(host="localhost", port=0, timeout=timeout)
        self._socket_path = socket_path

    def connect(self) -> None:
        """Open the Unix socket, applying the connection timeout."""
        self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.sock.settimeout(self.timeout)
        self.sock.connect(self._socket_path)


class DaemonClient:
    """JSON-over-HTTP client for the daemon, reachable by Unix socket or URL.

    Exactly one of ``socket_path`` or ``base_url`` must be given. Requests are
    retried according to ``retry_policy`` on transport failures and on
    caller-supplied retryable HTTP statuses, with full-jitter exponential
    backoff between attempts.
    """

    def __init__(
        self,
        *,
        retry_policy: RetryPolicy,
        socket_path: str | None = None,
        base_url: str | None = None,
        auth_token: str | None = None,
    ) -> None:
        """Validate the endpoint configuration and precompute connection details.

        Raises:
            ValueError: If both or neither of ``socket_path``/``base_url`` are
                set, if ``base_url`` is not an http(s) URL with a hostname, or
                if ``auth_token`` contains a line break.
        """
        if (socket_path is None) == (base_url is None):
            raise ValueError("exactly one of socket_path or base_url must be set")
        self._socket_path = socket_path
        self._base_url = base_url

        token = auth_token.strip() if auth_token is not None else ""
        # http.client rejects CR/LF in header values on every request; fail
        # once, here, instead.
        if "\r" in token or "\n" in token:
            raise ValueError("auth_token must not contain line breaks")
        # An empty or whitespace-only token would send "Authorization: Bearer "
        # with no credential; treat it as "no token".
        self._auth_token = token or None

        self._retry = retry_policy
        self._base_path = ""
        self._http_scheme = "http"
        self._http_host = "localhost"
        self._http_port = 80
        if base_url is not None:
            parsed = urlparse(base_url)
            if parsed.scheme not in {"http", "https"}:
                raise ValueError("base_url must use http or https scheme")
            if parsed.hostname is None:
                raise ValueError("base_url must include a hostname")
            self._http_scheme = parsed.scheme
            self._http_host = parsed.hostname
            if parsed.port is not None:
                self._http_port = parsed.port
            elif parsed.scheme == "https":
                self._http_port = 443
            # Prefix prepended to every request path; trailing "/" dropped so
            # joining with a leading-"/" path never doubles the slash.
            self._base_path = parsed.path.rstrip("/")

    def post_json(
        self,
        path: str,
        body: dict[str, Any],
        timeout_seconds: float,
        retryable_statuses: set[int],
    ) -> dict[str, Any]:
        """POST ``body`` as JSON to ``path`` and return the decoded JSON object.

        Note: transport failures are retried for POST as well, so the daemon
        may observe the same request more than once.

        Raises:
            DaemonError: On transport failure after all attempts, on a
                non-retryable or exhausted error status, or on an invalid
                JSON response to a 2xx status.
        """
        return self._request_json(
            method="POST",
            path=path,
            timeout_seconds=timeout_seconds,
            retryable_statuses=retryable_statuses,
            body=body,
        )

    def get_json(
        self,
        path: str,
        timeout_seconds: float,
        retryable_statuses: set[int],
    ) -> dict[str, Any]:
        """GET ``path`` and return the decoded JSON object.

        Raises:
            DaemonError: Same conditions as :meth:`post_json`.
        """
        return self._request_json(
            method="GET",
            path=path,
            timeout_seconds=timeout_seconds,
            retryable_statuses=retryable_statuses,
            body=None,
        )

    def _request_json(
        self,
        *,
        method: str,
        path: str,
        timeout_seconds: float,
        retryable_statuses: set[int],
        body: dict[str, Any] | None,
    ) -> dict[str, Any]:
        """Run the retry loop around :meth:`_raw_request`.

        A 2xx response is decoded and returned immediately. A status in
        ``retryable_statuses`` or a transport error (``OSError`` or
        ``http.client.HTTPException``) is retried after a backoff; any other
        status raises at once. After the last attempt the most recent error is
        raised, chained to its underlying cause.
        """
        # Encode once: the payload is identical for every attempt.
        payload = json.dumps(body).encode("utf-8") if body is not None else None
        last_error: DaemonError | None = None
        for attempt in range(1, self._retry.max_attempts + 1):
            try:
                response_body, status = self._raw_request(
                    method=method,
                    path=path,
                    timeout_seconds=timeout_seconds,
                    payload=payload,
                )
            except (OSError, http.client.HTTPException) as exc:
                # OSError covers refused/missing socket, timeouts and TLS errors;
                # HTTPException covers protocol failures such as IncompleteRead.
                last_error = DaemonError(
                    f"daemon transport error during {method} {path}: {exc}",
                    path=path,
                    cause=exc,
                )
            else:
                if 200 <= status < 300:
                    return self._parse_json(response_body, path, status=status)
                # Error bodies are often HTML/plain text from a proxy; decode
                # leniently so the status (and retryability) is never masked by
                # a JSON parsing failure.
                err = DaemonError(
                    f"daemon returned HTTP {status} for {method} {path}",
                    path=path,
                    status=status,
                    response=self._parse_error_body(response_body),
                )
                if status not in retryable_statuses:
                    raise err
                last_error = err
            if attempt < self._retry.max_attempts:
                self._sleep_backoff(attempt)
        if last_error is None:
            # Only reachable with a policy object that bypassed RetryPolicy's
            # validation; an explicit raise survives `python -O`, unlike assert.
            raise ValueError("retry_policy.max_attempts must be at least 1")
        raise last_error from last_error.cause

    def _raw_request(
        self,
        *,
        method: str,
        path: str,
        timeout_seconds: float,
        payload: bytes | None,
    ) -> tuple[bytes, int]:
        """Send a single HTTP request and return ``(body_bytes, status)``.

        A fresh connection is opened per call and always closed afterwards.
        """
        request_path = path if path.startswith("/") else f"/{path}"
        if self._base_path != "":
            request_path = f"{self._base_path}{request_path}"
        headers = {"Accept": "application/json"}
        if payload is not None:
            headers["Content-Type"] = "application/json"
        if self._auth_token is not None:
            headers["Authorization"] = f"Bearer {self._auth_token}"

        conn: http.client.HTTPConnection
        if self._socket_path is not None:
            conn = UnixSocketHTTPConnection(self._socket_path, timeout=timeout_seconds)
        elif self._http_scheme == "https":
            conn = http.client.HTTPSConnection(
                self._http_host, self._http_port, timeout=timeout_seconds
            )
        else:
            conn = http.client.HTTPConnection(
                self._http_host, self._http_port, timeout=timeout_seconds
            )
        try:
            conn.request(method=method, url=request_path, body=payload, headers=headers)
            response = conn.getresponse()
            data = response.read()
            return data, response.status
        finally:
            conn.close()

    @staticmethod
    def _parse_json(
        raw: bytes, path: str, status: int | None = None
    ) -> dict[str, Any]:
        """Decode a UTF-8 JSON object body; an empty body yields ``{}``.

        Raises:
            DaemonError: If the body is not valid UTF-8, not valid JSON, or
                not a JSON object. ``status`` is attached when given.
        """
        if not raw:
            return {}
        try:
            data = json.loads(raw.decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError) as exc:
            raise DaemonError(
                f"daemon returned invalid JSON for {path}",
                path=path,
                status=status,
                cause=exc,
            ) from exc
        if not isinstance(data, dict):
            raise DaemonError(
                f"daemon returned non-object JSON for {path}",
                path=path,
                status=status,
            )
        return data

    @staticmethod
    def _parse_error_body(raw: bytes) -> dict[str, Any] | None:
        """Best-effort decode of a non-2xx body.

        Returns ``{}`` for an empty body, the JSON object when the body is one,
        and None for anything else (invalid UTF-8/JSON or a non-object value).
        """
        if not raw:
            return {}
        try:
            data = json.loads(raw.decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError):
            return None
        return data if isinstance(data, dict) else None

    def _sleep_backoff(self, attempt: int) -> None:
        """Sleep a uniformly random time in ``[0, ceiling]`` ("full jitter").

        The ceiling is ``base_seconds * 2**(attempt - 1)`` capped at
        ``max_seconds``; the exponent is also capped so the product cannot
        overflow a float for very large attempt numbers.
        """
        exponent = min(attempt - 1, _MAX_BACKOFF_EXPONENT)
        ceiling = min(self._retry.max_seconds, self._retry.base_seconds * (2**exponent))
        time.sleep(random.uniform(0.0, ceiling))