nix-builder-autoscaler/buildbot-ext/buildbot_autoscale_ext/tests/test_client.py
Abel Luck 679b5c8d07
All checks were successful
buildbot/nix-eval Build done.
buildbot/nix-build Build done.
buildbot/nix-effects Build done.
Add remote autoscaler daemon endpoint support
2026-03-05 15:47:57 +01:00

313 lines
9.9 KiB
Python

from __future__ import annotations

import json
import os
import socketserver
import tempfile
import threading
from collections.abc import Callable, Iterator
from contextlib import suppress
from dataclasses import dataclass
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from typing import Any

import pytest

from buildbot_autoscale_ext.client import (
    DaemonClient,
    DaemonError,
    RetryPolicy,
    UnixSocketHTTPConnection,
)
@dataclass
class ServerState:
    """Mutable request counters shared by a fake server and its request handler.

    Handlers bump the matching counter before calling the server's callback and
    pass the new value along, so callbacks see a 1-based attempt number.
    """

    # How many POST requests the handler has counted so far.
    post_count: int = 0
    # How many GET requests the handler has counted so far.
    get_count: int = 0
class _Handler(BaseHTTPRequestHandler):
    """JSON request handler for :class:`_UnixHTTPServer`.

    Every request increments the matching counter in ``server.state`` and the
    new count is handed to ``server.on_get`` / ``server.on_post``. Whatever
    ``(status, body)`` the callback returns is sent back as JSON.
    """

    server: _UnixHTTPServer

    def do_GET(self) -> None:  # noqa: N802
        state = self.server.state
        state.get_count += 1
        code, reply = self.server.on_get(self.path, state.get_count)
        self._send(code, reply)

    def do_POST(self) -> None:  # noqa: N802
        state = self.server.state
        state.post_count += 1
        payload = self._read_json_payload()
        code, reply = self.server.on_post(self.path, payload, state.post_count)
        self._send(code, reply)

    def _read_json_payload(self) -> Any:
        """Decode the request body as UTF-8 JSON; a zero/absent Content-Length reads as ``{}``."""
        length = int(self.headers.get("Content-Length", "0"))
        data = self.rfile.read(length) if length else b"{}"
        return json.loads(data.decode("utf-8"))

    def log_message(self, format: str, *args: object) -> None:
        # Silence the default stderr access log. The base implementation also
        # calls address_string(), which indexes client_address; for Unix-socket
        # peers that address is presumably empty, so the default would fail.
        del format, args

    def _send(self, status: int, body: dict[str, Any]) -> None:
        """Write ``body`` as a JSON response carrying ``status``."""
        payload = json.dumps(body).encode("utf-8")
        self.send_response(status)
        headers = (
            ("Content-Type", "application/json"),
            ("Content-Length", str(len(payload))),
        )
        for name, value in headers:
            self.send_header(name, value)
        self.end_headers()
        self.wfile.write(payload)
class _UnixHTTPServer(socketserver.UnixStreamServer):
    """Synchronous Unix-domain-socket server whose requests are handled by :class:`_Handler`.

    It carries the response callbacks and a fresh :class:`ServerState`, which
    the handler reads through ``self.server``. The base initializer binds and
    listens on ``socket_path`` immediately.
    """

    def __init__(
        self,
        socket_path: str,
        *,
        on_get: Callable[[str, int], tuple[int, dict[str, Any]]],
        on_post: Callable[[str, dict[str, Any], int], tuple[int, dict[str, Any]]],
    ) -> None:
        # on_get(path, attempt) / on_post(path, payload, attempt) -> (status, json_body)
        self.on_get, self.on_post = on_get, on_post
        self.state = ServerState()
        super().__init__(socket_path, _Handler)
class FakeDaemon:
    """Context manager that runs a fake autoscaler daemon on a Unix socket.

    The socket is bound on construction. ``serve_forever`` runs in a daemon
    thread from ``__enter__`` until ``__exit__``, which stops the loop, closes
    the server and deletes the socket file.
    """

    def __init__(
        self,
        socket_path: str,
        *,
        on_get: Callable[[str, int], tuple[int, dict[str, Any]]],
        on_post: Callable[[str, dict[str, Any], int], tuple[int, dict[str, Any]]],
    ) -> None:
        self._socket_path = socket_path
        self._server = _UnixHTTPServer(socket_path, on_get=on_get, on_post=on_post)
        self._thread = threading.Thread(target=self._server.serve_forever, daemon=True)

    def __enter__(self) -> FakeDaemon:
        self._thread.start()
        return self

    def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
        del exc_type, exc, tb
        server = self._server
        server.shutdown()  # stops serve_forever and waits for it to return
        server.server_close()
        # Closing a Unix stream server leaves its socket file on disk; remove it.
        with suppress(FileNotFoundError):
            os.unlink(self._socket_path)
class _TCPHTTPServer(ThreadingHTTPServer):
    """Thread-per-request HTTP server on an ephemeral 127.0.0.1 port, handled by :class:`_TCPHandler`.

    ``expected_auth``, when not None, is the bearer token the handler demands.
    NOTE(review): ``state`` counters are bumped without a lock; that is fine
    for the sequential requests these tests make.
    """

    def __init__(
        self,
        *,
        on_get: Callable[[str, int], tuple[int, dict[str, Any]]],
        on_post: Callable[[str, dict[str, Any], int], tuple[int, dict[str, Any]]],
        expected_auth: str | None = None,
    ) -> None:
        self.state = ServerState()
        self.expected_auth = expected_auth
        self.on_post = on_post
        self.on_get = on_get
        # Port 0 lets the OS choose a free port; read it back from server_address.
        super().__init__(("127.0.0.1", 0), _TCPHandler)
class _TCPHandler(BaseHTTPRequestHandler):
    """JSON request handler for :class:`_TCPHTTPServer` with optional bearer-token auth.

    When ``server.expected_auth`` is set, a request without a matching
    ``Authorization: Bearer <token>`` header gets a 401 JSON error and is not
    counted in ``server.state``. Other requests are counted and delegated to
    ``server.on_get`` / ``server.on_post``.
    """

    server: _TCPHTTPServer

    def _authorize(self) -> bool:
        """Return True when no token is configured or the header equals ``Bearer <token>``."""
        expected = self.server.expected_auth
        if expected is None:
            return True
        return self.headers.get("Authorization") == f"Bearer {expected}"

    def _read_raw_body(self) -> bytes:
        """Read ``Content-Length`` bytes of request body; ``b"{}"`` when the length is 0/absent."""
        size = int(self.headers.get("Content-Length", "0"))
        return self.rfile.read(size) if size else b"{}"

    def do_GET(self) -> None:  # noqa: N802
        if not self._authorize():
            self._send(HTTPStatus.UNAUTHORIZED, {"error": "unauthorized"})
            return
        self.server.state.get_count += 1
        status, body = self.server.on_get(self.path, self.server.state.get_count)
        self._send(status, body)

    def do_POST(self) -> None:  # noqa: N802
        # Consume the request body before any early return. Closing a TCP
        # connection with unread bytes still in the receive buffer can make
        # the kernel send RST, and the client may then see a connection reset
        # instead of the 401 response.
        raw = self._read_raw_body()
        if not self._authorize():
            self._send(HTTPStatus.UNAUTHORIZED, {"error": "unauthorized"})
            return
        self.server.state.post_count += 1
        payload = json.loads(raw.decode("utf-8"))
        status, body = self.server.on_post(self.path, payload, self.server.state.post_count)
        self._send(status, body)

    def log_message(self, format: str, *args: object) -> None:
        """Suppress the default per-request stderr logging."""
        del format, args

    def _send(self, status: int, body: dict[str, Any]) -> None:
        """Send ``body`` as a JSON response with explicit Content-Type/Content-Length."""
        encoded = json.dumps(body).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(encoded)))
        self.end_headers()
        self.wfile.write(encoded)
class FakeTCPDaemon:
    """Context manager that runs a fake daemon over HTTP on 127.0.0.1 (ephemeral port).

    Point a client at :attr:`base_url`. When ``expected_auth`` is given, the
    handler rejects requests lacking ``Authorization: Bearer <expected_auth>``.
    """

    def __init__(
        self,
        *,
        on_get: Callable[[str, int], tuple[int, dict[str, Any]]],
        on_post: Callable[[str, dict[str, Any], int], tuple[int, dict[str, Any]]],
        expected_auth: str | None = None,
    ) -> None:
        server = _TCPHTTPServer(on_get=on_get, on_post=on_post, expected_auth=expected_auth)
        self._server = server
        self._thread = threading.Thread(target=server.serve_forever, daemon=True)

    @property
    def base_url(self) -> str:
        """``http://<host>:<port>`` for the bound listening socket."""
        host, port = self._server.server_address
        return f"http://{host}:{port}"

    def __enter__(self) -> FakeTCPDaemon:
        self._thread.start()
        return self

    def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
        del exc_type, exc, tb
        server = self._server
        server.shutdown()  # stops serve_forever and waits for it to return
        server.server_close()
@pytest.fixture
def socket_path() -> Iterator[str]:
    """Yield a Unix socket path inside a fresh temporary directory.

    This is a yield fixture, so it is annotated as an iterator rather than
    ``str``. The directory and anything left in it are removed after the test.
    AF_UNIX paths are length-limited (about 108 bytes on Linux, 104 on macOS),
    so keep this path short.
    """
    with tempfile.TemporaryDirectory() as tmp:
        yield str(Path(tmp) / "daemon.sock")
def _client(socket_path: str, attempts: int = 3) -> DaemonClient:
    """Build a Unix-socket ``DaemonClient`` with tiny retry delays so retry tests stay fast.

    ``attempts`` becomes ``RetryPolicy.max_attempts``. The delays are 1 ms base
    and 10 ms max, assuming the ``*_seconds`` fields are in seconds as named.
    """
    fast_retries = RetryPolicy(max_attempts=attempts, base_seconds=0.001, max_seconds=0.01)
    return DaemonClient(socket_path=socket_path, retry_policy=fast_retries)
def test_post_json_success(socket_path: str) -> None:
    """A POSTed JSON body reaches the daemon and its JSON reply is returned."""

    def echo_system(
        _path: str, payload: dict[str, Any], _attempt: int
    ) -> tuple[int, dict[str, Any]]:
        return HTTPStatus.OK, {"echo": payload["system"]}

    daemon = FakeDaemon(
        socket_path,
        on_get=lambda _path, _attempt: (HTTPStatus.OK, {}),
        on_post=echo_system,
    )
    with daemon:
        client = _client(socket_path)
        response = client.post_json(
            "/v1/reservations",
            {"system": "x86_64-linux"},
            timeout_seconds=1.0,
            retryable_statuses={429, 500, 503},
        )
    assert response == {"echo": "x86_64-linux"}
def test_get_json_success(socket_path: str) -> None:
    """A 200 GET reply is decoded and returned unchanged."""
    fake = FakeDaemon(
        socket_path,
        on_get=lambda _path, _attempt: (HTTPStatus.OK, {"phase": "ready"}),
        on_post=lambda _path, _payload, _attempt: (HTTPStatus.OK, {}),
    )
    with fake:
        result = _client(socket_path).get_json(
            "/v1/reservations/r1",
            timeout_seconds=1.0,
            retryable_statuses={429, 500, 503},
        )
    assert result == {"phase": "ready"}
def test_transient_503_retries_then_raises(socket_path: str) -> None:
    """A daemon that keeps answering 503 is retried max_attempts times, then DaemonError(503)."""
    always_busy = FakeDaemon(
        socket_path,
        on_get=lambda _path, _attempt: (HTTPStatus.SERVICE_UNAVAILABLE, {"error": "busy"}),
        on_post=lambda _path, _payload, _attempt: (HTTPStatus.OK, {}),
    )
    retryable = {429, 500, 502, 503, 504}
    with always_busy as daemon, pytest.raises(DaemonError) as exc:
        _client(socket_path, attempts=3).get_json(
            "/v1/reservations/r1",
            timeout_seconds=1.0,
            retryable_statuses=retryable,
        )
    assert exc.value.status == HTTPStatus.SERVICE_UNAVAILABLE
    # Every configured attempt must have reached the server.
    assert daemon._server.state.get_count == 3
def test_400_not_retried(socket_path: str) -> None:
    """A 400 is not in the retryable set, so it surfaces after a single request."""
    rejecting = FakeDaemon(
        socket_path,
        on_get=lambda _path, _attempt: (HTTPStatus.BAD_REQUEST, {"error": "bad"}),
        on_post=lambda _path, _payload, _attempt: (HTTPStatus.OK, {}),
    )
    retryable = {429, 500, 502, 503, 504}
    with rejecting as daemon, pytest.raises(DaemonError) as exc:
        _client(socket_path, attempts=5).get_json(
            "/v1/reservations/r1",
            timeout_seconds=1.0,
            retryable_statuses=retryable,
        )
    assert exc.value.status == HTTPStatus.BAD_REQUEST
    # Despite 5 allowed attempts, only the first request was sent.
    assert daemon._server.state.get_count == 1
def test_connection_refused_retries_then_raises(
    socket_path: str,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A refused connection is retried up to max_attempts times and then raises DaemonError."""
    connect_calls = 0

    def _boom(self: UnixSocketHTTPConnection) -> None:
        nonlocal connect_calls
        del self
        connect_calls += 1
        raise ConnectionRefusedError("refused")

    monkeypatch.setattr(UnixSocketHTTPConnection, "connect", _boom)
    with pytest.raises(DaemonError):
        _client(socket_path, attempts=3).get_json(
            "/v1/reservations/r1",
            timeout_seconds=1.0,
            retryable_statuses={429, 500, 502, 503, 504},
        )
    # Without this check the test would also pass if the client gave up after
    # the first refusal, contradicting the test's name.
    assert connect_calls == 3
def test_backoff_attempts_at_least_two(socket_path: str) -> None:
    """With max_attempts=2 against a persistent 503, the daemon sees at least two GETs."""
    busy = FakeDaemon(
        socket_path,
        on_get=lambda _path, _attempt: (HTTPStatus.SERVICE_UNAVAILABLE, {"error": "busy"}),
        on_post=lambda _path, _payload, _attempt: (HTTPStatus.OK, {}),
    )
    retryable = {429, 500, 502, 503, 504}
    with busy as daemon, pytest.raises(DaemonError):
        _client(socket_path, attempts=2).get_json(
            "/v1/reservations/r1",
            timeout_seconds=1.0,
            retryable_statuses=retryable,
        )
    assert daemon._server.state.get_count >= 2
def test_get_json_success_over_http_with_auth() -> None:
    """A base_url client presenting the expected bearer token gets the daemon's JSON reply."""
    token = "test-token"
    fake = FakeTCPDaemon(
        on_get=lambda _path, _attempt: (HTTPStatus.OK, {"phase": "ready"}),
        on_post=lambda _path, _payload, _attempt: (HTTPStatus.OK, {}),
        expected_auth=token,
    )
    with fake as daemon:
        policy = RetryPolicy(max_attempts=2, base_seconds=0.001, max_seconds=0.01)
        client = DaemonClient(base_url=daemon.base_url, auth_token=token, retry_policy=policy)
        response = client.get_json(
            "/v1/reservations/r1",
            timeout_seconds=1.0,
            retryable_statuses={429, 500, 503},
        )
    assert response == {"phase": "ready"}