from __future__ import annotations

import http.client
import json
import random
import socket
import time
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class RetryPolicy:
    """Retry and backoff settings used by ``DaemonClient``."""

    # Total number of attempts per request, including the first one.
    # DaemonClient always makes at least one attempt, even if this is < 1.
    max_attempts: int
    # Upper bound of the random delay before the first retry; the bound
    # doubles for every further retry.
    base_seconds: float
    # Cap on the random-delay bound, however many retries have happened.
    max_seconds: float


class DaemonError(RuntimeError):
    """Raised for any failed daemon request.

    Attributes:
        path: Request path that failed.
        status: HTTP status code of the response, or None when no response
            was received (transport error).
        response: Decoded JSON object body of an HTTP error response
            (``{}`` for an empty body), or None when the body was not a
            JSON object or no response was received.
        cause: Underlying transport/decoding exception, if any.
    """

    def __init__(
        self,
        message: str,
        *,
        path: str,
        status: int | None = None,
        response: dict[str, Any] | None = None,
        cause: Exception | None = None,
    ) -> None:
        super().__init__(message)
        self.path = path
        self.status = status
        self.response = response
        self.cause = cause


class UnixSocketHTTPConnection(http.client.HTTPConnection):
    """``HTTPConnection`` that talks to a Unix domain socket instead of TCP.

    The ``localhost``/port values handed to the base class are placeholders;
    ``connect`` ignores them and dials ``socket_path`` instead.
    """

    def __init__(self, socket_path: str, timeout: float) -> None:
        super().__init__(host="localhost", port=0, timeout=timeout)
        self._socket_path = socket_path

    def connect(self) -> None:
        """Open the Unix socket, closing it again if the connect fails."""
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            sock.settimeout(self.timeout)
            sock.connect(self._socket_path)
        except BaseException:
            # Don't rely on the caller calling close() to release the fd.
            sock.close()
            raise
        self.sock = sock


class DaemonClient:
    """JSON client for a daemon reachable over a Unix domain socket.

    Each attempt opens a fresh connection and closes it afterwards. Transport
    failures and HTTP statuses listed in ``retryable_statuses`` are retried
    according to the ``RetryPolicy``; every failure surfaces as DaemonError.

    NOTE(review): transport errors are retried for POST as well as GET. If
    the daemon already processed a POST before the connection failed (e.g. a
    read timeout), the retry repeats it -- confirm POST endpoints are
    idempotent.
    """

    def __init__(self, socket_path: str, retry_policy: RetryPolicy) -> None:
        self._socket_path = socket_path
        self._retry = retry_policy

    def post_json(
        self,
        path: str,
        body: dict[str, Any],
        timeout_seconds: float,
        retryable_statuses: set[int],
    ) -> dict[str, Any]:
        """POST ``body`` as JSON to ``path`` and return the decoded JSON object.

        Args:
            path: Request path (e.g. ``"/v1/jobs"``).
            body: JSON-serializable request payload.
            timeout_seconds: Socket timeout applied to each attempt.
            retryable_statuses: Non-2xx statuses that should be retried.

        Raises:
            DaemonError: on transport failure, a non-2xx status, or an
                undecodable/non-object response body.
        """
        return self._request_json(
            method="POST",
            path=path,
            timeout_seconds=timeout_seconds,
            retryable_statuses=retryable_statuses,
            body=body,
        )

    def get_json(
        self,
        path: str,
        timeout_seconds: float,
        retryable_statuses: set[int],
    ) -> dict[str, Any]:
        """GET ``path`` and return the decoded JSON object.

        Args and raises are as for ``post_json`` (without ``body``).
        """
        return self._request_json(
            method="GET",
            path=path,
            timeout_seconds=timeout_seconds,
            retryable_statuses=retryable_statuses,
            body=None,
        )

    def _request_json(
        self,
        *,
        method: str,
        path: str,
        timeout_seconds: float,
        retryable_statuses: set[int],
        body: dict[str, Any] | None,
    ) -> dict[str, Any]:
        """Send one logical request, retrying per the policy.

        Returns:
            The decoded JSON object of the first 2xx response (``{}`` for an
            empty body).

        Raises:
            DaemonError: immediately for a non-retryable status or an
                undecodable 2xx body; otherwise the last error once all
                attempts are used up.
        """
        # The payload is identical for every attempt; serialize it once.
        payload = json.dumps(body).encode("utf-8") if body is not None else None
        # With max_attempts < 1 the loop would never run and there would be no
        # error to report, so always make at least one attempt.
        attempts = max(1, self._retry.max_attempts)
        last_error: DaemonError | None = None

        for attempt in range(1, attempts + 1):
            try:
                response_body, status = self._raw_request(
                    method=method,
                    path=path,
                    timeout_seconds=timeout_seconds,
                    payload=payload,
                )
            except (OSError, http.client.HTTPException) as exc:
                # OSError covers ConnectionRefusedError, FileNotFoundError
                # (missing socket file) and TimeoutError; HTTPException covers
                # malformed or truncated responses such as IncompleteRead.
                last_error = DaemonError(
                    f"daemon transport error during {method} {path}: {exc}",
                    path=path,
                    cause=exc,
                )
            else:
                if 200 <= status < 300:
                    return self._parse_json(response_body, path, status=status)
                # Error bodies are parsed best-effort: a proxy/daemon returning
                # plain text with a retryable status must still be retried.
                last_error = DaemonError(
                    f"daemon returned HTTP {status} for {method} {path}",
                    path=path,
                    status=status,
                    response=self._parse_error_body(response_body, path),
                )
                if status not in retryable_statuses:
                    raise last_error

            if attempt < attempts:
                self._sleep_backoff(attempt)

        # Invariant: at least one attempt ran and every non-returning path
        # assigned last_error.
        assert last_error is not None
        raise last_error from last_error.cause

    def _raw_request(
        self,
        *,
        method: str,
        path: str,
        timeout_seconds: float,
        payload: bytes | None,
    ) -> tuple[bytes, int]:
        """Perform one HTTP exchange and return ``(body_bytes, status)``.

        The connection is always closed, whether or not the exchange succeeds.
        """
        conn = UnixSocketHTTPConnection(self._socket_path, timeout=timeout_seconds)
        headers = {"Accept": "application/json"}
        if payload is not None:
            headers["Content-Type"] = "application/json"
        try:
            conn.request(method=method, url=path, body=payload, headers=headers)
            response = conn.getresponse()
            data = response.read()
            return data, response.status
        finally:
            conn.close()

    @staticmethod
    def _parse_json(
        raw: bytes, path: str, *, status: int | None = None
    ) -> dict[str, Any]:
        """Decode a UTF-8 JSON object body; an empty body decodes to ``{}``.

        Raises:
            DaemonError: if the body is not valid UTF-8 JSON or is not an
                object. ``status`` is attached to the error when given.
        """
        if not raw:
            return {}
        try:
            data = json.loads(raw.decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError) as exc:
            raise DaemonError(
                f"daemon returned invalid JSON for {path}",
                path=path,
                status=status,
                cause=exc,
            ) from exc
        if not isinstance(data, dict):
            raise DaemonError(
                f"daemon returned non-object JSON for {path}",
                path=path,
                status=status,
            )
        return data

    def _parse_error_body(self, raw: bytes, path: str) -> dict[str, Any] | None:
        """Best-effort decode of a non-2xx body; None if it isn't a JSON object."""
        try:
            return self._parse_json(raw, path)
        except DaemonError:
            return None

    def _sleep_backoff(self, attempt: int) -> None:
        """Sleep before retry number ``attempt`` (1-based) using full jitter.

        The delay is uniform in ``[0, min(max_seconds, base_seconds * 2**(attempt-1))]``.
        """
        ceiling = min(
            self._retry.max_seconds,
            self._retry.base_seconds * (2 ** (attempt - 1)),
        )
        # A misconfigured negative bound would make time.sleep raise.
        time.sleep(random.uniform(0.0, max(0.0, ceiling)))