from __future__ import annotations import json import os import socketserver import tempfile import threading from collections.abc import Callable from contextlib import suppress from dataclasses import dataclass from http import HTTPStatus from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from pathlib import Path from typing import Any import pytest from buildbot_autoscale_ext.client import ( DaemonClient, DaemonError, RetryPolicy, UnixSocketHTTPConnection, ) @dataclass class ServerState: post_count: int = 0 get_count: int = 0 class _Handler(BaseHTTPRequestHandler): server: _UnixHTTPServer def do_GET(self) -> None: # noqa: N802 self.server.state.get_count += 1 status, body = self.server.on_get(self.path, self.server.state.get_count) self._send(status, body) def do_POST(self) -> None: # noqa: N802 self.server.state.post_count += 1 size = int(self.headers.get("Content-Length", "0")) raw = self.rfile.read(size) if size else b"{}" payload = json.loads(raw.decode("utf-8")) status, body = self.server.on_post(self.path, payload, self.server.state.post_count) self._send(status, body) def log_message(self, format: str, *args: object) -> None: del format, args def _send(self, status: int, body: dict[str, Any]) -> None: encoded = json.dumps(body).encode("utf-8") self.send_response(status) self.send_header("Content-Type", "application/json") self.send_header("Content-Length", str(len(encoded))) self.end_headers() self.wfile.write(encoded) class _UnixHTTPServer(socketserver.UnixStreamServer): def __init__( self, socket_path: str, *, on_get: Callable[[str, int], tuple[int, dict[str, Any]]], on_post: Callable[[str, dict[str, Any], int], tuple[int, dict[str, Any]]], ) -> None: self.on_get = on_get self.on_post = on_post self.state = ServerState() super().__init__(socket_path, _Handler) class FakeDaemon: def __init__( self, socket_path: str, *, on_get: Callable[[str, int], tuple[int, dict[str, Any]]], on_post: Callable[[str, dict[str, Any], int], tuple[int, dict[str, Any]]], ) -> None: self._socket_path = socket_path self._server = _UnixHTTPServer(socket_path, on_get=on_get, on_post=on_post) self._thread = threading.Thread(target=self._server.serve_forever, daemon=True) def __enter__(self) -> FakeDaemon: self._thread.start() return self def __exit__(self, exc_type: object, exc: object, tb: object) -> None: del exc_type, exc, tb self._server.shutdown() self._server.server_close() with suppress(FileNotFoundError): os.unlink(self._socket_path) class _TCPHTTPServer(ThreadingHTTPServer): def __init__( self, *, on_get: Callable[[str, int], tuple[int, dict[str, Any]]], on_post: Callable[[str, dict[str, Any], int], tuple[int, dict[str, Any]]], expected_auth: str | None = None, ) -> None: self.on_get = on_get self.on_post = on_post self.expected_auth = expected_auth self.state = ServerState() super().__init__(("127.0.0.1", 0), _TCPHandler) class _TCPHandler(BaseHTTPRequestHandler): server: _TCPHTTPServer def _authorize(self) -> bool: expected = self.server.expected_auth if expected is None: return True return self.headers.get("Authorization") == f"Bearer {expected}" def do_GET(self) -> None: # noqa: N802 if not self._authorize(): self._send(HTTPStatus.UNAUTHORIZED, {"error": "unauthorized"}) return self.server.state.get_count += 1 status, body = self.server.on_get(self.path, self.server.state.get_count) self._send(status, body) def do_POST(self) -> None: # noqa: N802 if not self._authorize(): self._send(HTTPStatus.UNAUTHORIZED, {"error": "unauthorized"}) return self.server.state.post_count += 1 size = int(self.headers.get("Content-Length", "0")) raw = self.rfile.read(size) if size else b"{}" payload = json.loads(raw.decode("utf-8")) status, body = self.server.on_post(self.path, payload, self.server.state.post_count) self._send(status, body) def log_message(self, format: str, *args: object) -> None: del format, args def _send(self, status: int, body: dict[str, Any]) -> None: encoded = json.dumps(body).encode("utf-8") self.send_response(status) self.send_header("Content-Type", "application/json") self.send_header("Content-Length", str(len(encoded))) self.end_headers() self.wfile.write(encoded) class FakeTCPDaemon: def __init__( self, *, on_get: Callable[[str, int], tuple[int, dict[str, Any]]], on_post: Callable[[str, dict[str, Any], int], tuple[int, dict[str, Any]]], expected_auth: str | None = None, ) -> None: self._server = _TCPHTTPServer(on_get=on_get, on_post=on_post, expected_auth=expected_auth) self._thread = threading.Thread(target=self._server.serve_forever, daemon=True) @property def base_url(self) -> str: host, port = self._server.server_address return f"http://{host}:{port}" def __enter__(self) -> FakeTCPDaemon: self._thread.start() return self def __exit__(self, exc_type: object, exc: object, tb: object) -> None: del exc_type, exc, tb self._server.shutdown() self._server.server_close() @pytest.fixture def socket_path() -> str: with tempfile.TemporaryDirectory() as tmp: yield str(Path(tmp) / "daemon.sock") def _client(socket_path: str, attempts: int = 3) -> DaemonClient: return DaemonClient( socket_path=socket_path, retry_policy=RetryPolicy(max_attempts=attempts, base_seconds=0.001, max_seconds=0.01), ) def test_post_json_success(socket_path: str) -> None: with FakeDaemon( socket_path, on_get=lambda _p, _a: (HTTPStatus.OK, {}), on_post=lambda _p, payload, _a: (HTTPStatus.OK, {"echo": payload["system"]}), ): response = _client(socket_path).post_json( "/v1/reservations", {"system": "x86_64-linux"}, timeout_seconds=1.0, retryable_statuses={429, 500, 503}, ) assert response == {"echo": "x86_64-linux"} def test_get_json_success(socket_path: str) -> None: with FakeDaemon( socket_path, on_get=lambda _p, _a: (HTTPStatus.OK, {"phase": "ready"}), on_post=lambda _p, _payload, _a: (HTTPStatus.OK, {}), ): response = _client(socket_path).get_json( "/v1/reservations/r1", timeout_seconds=1.0, retryable_statuses={429, 500, 503}, ) assert response == {"phase": "ready"} def test_transient_503_retries_then_raises(socket_path: str) -> None: with ( FakeDaemon( socket_path, on_get=lambda _p, _a: (HTTPStatus.SERVICE_UNAVAILABLE, {"error": "busy"}), on_post=lambda _p, _payload, _a: (HTTPStatus.OK, {}), ) as daemon, pytest.raises(DaemonError) as exc, ): _client(socket_path, attempts=3).get_json( "/v1/reservations/r1", timeout_seconds=1.0, retryable_statuses={429, 500, 502, 503, 504}, ) assert exc.value.status == HTTPStatus.SERVICE_UNAVAILABLE assert daemon._server.state.get_count == 3 def test_400_not_retried(socket_path: str) -> None: with ( FakeDaemon( socket_path, on_get=lambda _p, _a: (HTTPStatus.BAD_REQUEST, {"error": "bad"}), on_post=lambda _p, _payload, _a: (HTTPStatus.OK, {}), ) as daemon, pytest.raises(DaemonError) as exc, ): _client(socket_path, attempts=5).get_json( "/v1/reservations/r1", timeout_seconds=1.0, retryable_statuses={429, 500, 502, 503, 504}, ) assert exc.value.status == HTTPStatus.BAD_REQUEST assert daemon._server.state.get_count == 1 def test_connection_refused_retries_then_raises( socket_path: str, monkeypatch: pytest.MonkeyPatch, ) -> None: def _boom(self: UnixSocketHTTPConnection) -> None: raise ConnectionRefusedError("refused") monkeypatch.setattr(UnixSocketHTTPConnection, "connect", _boom) with pytest.raises(DaemonError): _client(socket_path, attempts=3).get_json( "/v1/reservations/r1", timeout_seconds=1.0, retryable_statuses={429, 500, 502, 503, 504}, ) def test_backoff_attempts_at_least_two(socket_path: str) -> None: with ( FakeDaemon( socket_path, on_get=lambda _p, _a: (HTTPStatus.SERVICE_UNAVAILABLE, {"error": "busy"}), on_post=lambda _p, _payload, _a: (HTTPStatus.OK, {}), ) as daemon, pytest.raises(DaemonError), ): _client(socket_path, attempts=2).get_json( "/v1/reservations/r1", timeout_seconds=1.0, retryable_statuses={429, 500, 502, 503, 504}, ) assert daemon._server.state.get_count >= 2 def test_get_json_success_over_http_with_auth() -> None: with FakeTCPDaemon( on_get=lambda _p, _a: (HTTPStatus.OK, {"phase": "ready"}), on_post=lambda _p, _payload, _a: (HTTPStatus.OK, {}), expected_auth="test-token", ) as daemon: client = DaemonClient( base_url=daemon.base_url, auth_token="test-token", retry_policy=RetryPolicy(max_attempts=2, base_seconds=0.001, max_seconds=0.01), ) response = client.get_json( "/v1/reservations/r1", timeout_seconds=1.0, retryable_statuses={429, 500, 503}, ) assert response == {"phase": "ready"}