Add remote autoscaler daemon endpoint support
All checks were successful
buildbot/nix-eval Build done.
buildbot/nix-build Build done.
buildbot/nix-effects Build done.

This commit is contained in:
Abel Luck 2026-03-05 15:47:57 +01:00
parent 95021a4253
commit 679b5c8d07
11 changed files with 291 additions and 22 deletions

View file

@ -9,7 +9,7 @@ from collections.abc import Callable
from contextlib import suppress
from dataclasses import dataclass
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from typing import Any
@ -95,6 +95,87 @@ class FakeDaemon:
os.unlink(self._socket_path)
class _TCPHTTPServer(ThreadingHTTPServer):
def __init__(
self,
*,
on_get: Callable[[str, int], tuple[int, dict[str, Any]]],
on_post: Callable[[str, dict[str, Any], int], tuple[int, dict[str, Any]]],
expected_auth: str | None = None,
) -> None:
self.on_get = on_get
self.on_post = on_post
self.expected_auth = expected_auth
self.state = ServerState()
super().__init__(("127.0.0.1", 0), _TCPHandler)
class _TCPHandler(BaseHTTPRequestHandler):
server: _TCPHTTPServer
def _authorize(self) -> bool:
expected = self.server.expected_auth
if expected is None:
return True
return self.headers.get("Authorization") == f"Bearer {expected}"
def do_GET(self) -> None: # noqa: N802
if not self._authorize():
self._send(HTTPStatus.UNAUTHORIZED, {"error": "unauthorized"})
return
self.server.state.get_count += 1
status, body = self.server.on_get(self.path, self.server.state.get_count)
self._send(status, body)
def do_POST(self) -> None: # noqa: N802
if not self._authorize():
self._send(HTTPStatus.UNAUTHORIZED, {"error": "unauthorized"})
return
self.server.state.post_count += 1
size = int(self.headers.get("Content-Length", "0"))
raw = self.rfile.read(size) if size else b"{}"
payload = json.loads(raw.decode("utf-8"))
status, body = self.server.on_post(self.path, payload, self.server.state.post_count)
self._send(status, body)
def log_message(self, format: str, *args: object) -> None:
del format, args
def _send(self, status: int, body: dict[str, Any]) -> None:
encoded = json.dumps(body).encode("utf-8")
self.send_response(status)
self.send_header("Content-Type", "application/json")
self.send_header("Content-Length", str(len(encoded)))
self.end_headers()
self.wfile.write(encoded)
class FakeTCPDaemon:
def __init__(
self,
*,
on_get: Callable[[str, int], tuple[int, dict[str, Any]]],
on_post: Callable[[str, dict[str, Any], int], tuple[int, dict[str, Any]]],
expected_auth: str | None = None,
) -> None:
self._server = _TCPHTTPServer(on_get=on_get, on_post=on_post, expected_auth=expected_auth)
self._thread = threading.Thread(target=self._server.serve_forever, daemon=True)
@property
def base_url(self) -> str:
host, port = self._server.server_address
return f"http://{host}:{port}"
def __enter__(self) -> FakeTCPDaemon:
self._thread.start()
return self
def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
del exc_type, exc, tb
self._server.shutdown()
self._server.server_close()
@pytest.fixture
def socket_path() -> str:
with tempfile.TemporaryDirectory() as tmp:
@ -210,3 +291,23 @@ def test_backoff_attempts_at_least_two(socket_path: str) -> None:
)
assert daemon._server.state.get_count >= 2
def test_get_json_success_over_http_with_auth() -> None:
with FakeTCPDaemon(
on_get=lambda _p, _a: (HTTPStatus.OK, {"phase": "ready"}),
on_post=lambda _p, _payload, _a: (HTTPStatus.OK, {}),
expected_auth="test-token",
) as daemon:
client = DaemonClient(
base_url=daemon.base_url,
auth_token="test-token",
retry_policy=RetryPolicy(max_attempts=2, base_seconds=0.001, max_seconds=0.01),
)
response = client.get_json(
"/v1/reservations/r1",
timeout_seconds=1.0,
retryable_statuses={429, 500, 503},
)
assert response == {"phase": "ready"}