"""Tests for ``DaemonClient`` against an in-process fake daemon on a Unix socket.

The fake daemon is a single-threaded HTTP server bound to a Unix domain socket
in a temporary directory. Each test supplies ``on_get``/``on_post`` callbacks
that choose the status code and JSON body returned for every request.
"""

from __future__ import annotations

import json
import os
import socketserver
import tempfile
import threading
from collections.abc import Callable, Iterator
from contextlib import suppress
from dataclasses import dataclass
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler
from pathlib import Path
from typing import Any

import pytest

from buildbot_autoscale_ext.client import (
    DaemonClient,
    DaemonError,
    RetryPolicy,
    UnixSocketHTTPConnection,
)

# GET callback: (request path, 1-based GET count so far) -> (status, JSON body).
GetCallback = Callable[[str, int], tuple[int, dict[str, Any]]]
# POST callback: (request path, decoded JSON payload, 1-based POST count so far)
# -> (status, JSON body).
PostCallback = Callable[[str, dict[str, Any], int], tuple[int, dict[str, Any]]]


@dataclass
class ServerState:
    """Per-method request counters, incremented by the handler before dispatch."""

    post_count: int = 0
    get_count: int = 0


class _Handler(BaseHTTPRequestHandler):
    """Route GET/POST to the owning server's callbacks and reply with JSON."""

    server: _UnixHTTPServer

    def do_GET(self) -> None:  # noqa: N802
        self.server.state.get_count += 1
        status, body = self.server.on_get(self.path, self.server.state.get_count)
        self._send(status, body)

    def do_POST(self) -> None:  # noqa: N802
        self.server.state.post_count += 1
        size = int(self.headers.get("Content-Length", "0"))
        # A request without a body is treated as an empty JSON object.
        raw = self.rfile.read(size) if size else b"{}"
        payload = json.loads(raw.decode("utf-8"))
        status, body = self.server.on_post(self.path, payload, self.server.state.post_count)
        self._send(status, body)

    def log_message(self, format: str, *args: object) -> None:
        # Silence the default per-request stderr logging. The default
        # implementation also calls address_string(), which indexes
        # client_address; that is presumably empty for an unnamed
        # Unix-socket peer — confirm before removing this override.
        del format, args

    def _send(self, status: int, body: dict[str, Any]) -> None:
        """Write *body* as a JSON response with an explicit Content-Length."""
        encoded = json.dumps(body).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(encoded)))
        self.end_headers()
        self.wfile.write(encoded)


class _UnixHTTPServer(socketserver.UnixStreamServer):
    """Unix-socket server carrying the callbacks and counters ``_Handler`` uses."""

    def __init__(
        self,
        socket_path: str,
        *,
        on_get: GetCallback,
        on_post: PostCallback,
    ) -> None:
        self.on_get = on_get
        self.on_post = on_post
        self.state = ServerState()
        super().__init__(socket_path, _Handler)


class FakeDaemon:
    """Context manager that serves ``_UnixHTTPServer`` on a background thread.

    The socket is bound in ``__init__``. ``__enter__`` starts serving.
    ``__exit__`` stops the serve loop, joins the thread, closes the server,
    and removes the socket file.
    """

    def __init__(
        self,
        socket_path: str,
        *,
        on_get: GetCallback,
        on_post: PostCallback,
    ) -> None:
        self._socket_path = socket_path
        self._server = _UnixHTTPServer(socket_path, on_get=on_get, on_post=on_post)
        self._thread = threading.Thread(target=self._server.serve_forever, daemon=True)

    @property
    def state(self) -> ServerState:
        """Request counters recorded by the server; still readable after exit."""
        return self._server.state

    def __enter__(self) -> FakeDaemon:
        self._thread.start()
        return self

    def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
        del exc_type, exc, tb
        # shutdown() blocks until serve_forever() returns. Joining afterwards
        # guarantees the server thread is gone before the socket is closed.
        self._server.shutdown()
        self._thread.join()
        self._server.server_close()
        with suppress(FileNotFoundError):
            os.unlink(self._socket_path)


@pytest.fixture
def socket_path() -> Iterator[str]:
    """Yield a socket path inside a temporary directory deleted after the test."""
    with tempfile.TemporaryDirectory() as tmp:
        yield str(Path(tmp) / "daemon.sock")


def _client(socket_path: str, attempts: int = 3) -> DaemonClient:
    """Build a client with tiny backoff delays so retry tests stay fast."""
    return DaemonClient(
        socket_path=socket_path,
        retry_policy=RetryPolicy(max_attempts=attempts, base_seconds=0.001, max_seconds=0.01),
    )


def test_post_json_success(socket_path: str) -> None:
    """A 200 POST returns the decoded JSON body produced by the daemon."""
    with FakeDaemon(
        socket_path,
        on_get=lambda _p, _a: (HTTPStatus.OK, {}),
        on_post=lambda _p, payload, _a: (HTTPStatus.OK, {"echo": payload["system"]}),
    ):
        response = _client(socket_path).post_json(
            "/v1/reservations",
            {"system": "x86_64-linux"},
            timeout_seconds=1.0,
            retryable_statuses={429, 500, 503},
        )
    assert response == {"echo": "x86_64-linux"}


def test_get_json_success(socket_path: str) -> None:
    """A 200 GET returns the decoded JSON body produced by the daemon."""
    with FakeDaemon(
        socket_path,
        on_get=lambda _p, _a: (HTTPStatus.OK, {"phase": "ready"}),
        on_post=lambda _p, _payload, _a: (HTTPStatus.OK, {}),
    ):
        response = _client(socket_path).get_json(
            "/v1/reservations/r1",
            timeout_seconds=1.0,
            retryable_statuses={429, 500, 503},
        )
    assert response == {"phase": "ready"}


def test_transient_503_retries_then_raises(socket_path: str) -> None:
    """A persistent retryable status is retried max_attempts times, then raised."""
    with (
        FakeDaemon(
            socket_path,
            on_get=lambda _p, _a: (HTTPStatus.SERVICE_UNAVAILABLE, {"error": "busy"}),
            on_post=lambda _p, _payload, _a: (HTTPStatus.OK, {}),
        ) as daemon,
        pytest.raises(DaemonError) as exc,
    ):
        _client(socket_path, attempts=3).get_json(
            "/v1/reservations/r1",
            timeout_seconds=1.0,
            retryable_statuses={429, 500, 502, 503, 504},
        )
    assert exc.value.status == HTTPStatus.SERVICE_UNAVAILABLE
    assert daemon.state.get_count == 3


def test_400_not_retried(socket_path: str) -> None:
    """A non-retryable status is raised after a single request."""
    with (
        FakeDaemon(
            socket_path,
            on_get=lambda _p, _a: (HTTPStatus.BAD_REQUEST, {"error": "bad"}),
            on_post=lambda _p, _payload, _a: (HTTPStatus.OK, {}),
        ) as daemon,
        pytest.raises(DaemonError) as exc,
    ):
        _client(socket_path, attempts=5).get_json(
            "/v1/reservations/r1",
            timeout_seconds=1.0,
            retryable_statuses={429, 500, 502, 503, 504},
        )
    assert exc.value.status == HTTPStatus.BAD_REQUEST
    assert daemon.state.get_count == 1


def test_connection_refused_retries_then_raises(
    socket_path: str,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Refused connections are retried max_attempts times, then raised."""
    connect_calls: list[UnixSocketHTTPConnection] = []

    def _boom(self: UnixSocketHTTPConnection) -> None:
        connect_calls.append(self)
        raise ConnectionRefusedError("refused")

    monkeypatch.setattr(UnixSocketHTTPConnection, "connect", _boom)
    with pytest.raises(DaemonError):
        _client(socket_path, attempts=3).get_json(
            "/v1/reservations/r1",
            timeout_seconds=1.0,
            retryable_statuses={429, 500, 502, 503, 504},
        )
    # One connect() per attempt: proves the refusal was actually retried.
    assert len(connect_calls) == 3


def test_backoff_attempts_at_least_two(socket_path: str) -> None:
    """With max_attempts=2 a retryable failure is requested at least twice."""
    with (
        FakeDaemon(
            socket_path,
            on_get=lambda _p, _a: (HTTPStatus.SERVICE_UNAVAILABLE, {"error": "busy"}),
            on_post=lambda _p, _payload, _a: (HTTPStatus.OK, {}),
        ) as daemon,
        pytest.raises(DaemonError),
    ):
        _client(socket_path, attempts=2).get_json(
            "/v1/reservations/r1",
            timeout_seconds=1.0,
            retryable_statuses={429, 500, 502, 503, 504},
        )
    assert daemon.state.get_count >= 2