Add remote autoscaler daemon endpoint support
All checks were successful
buildbot/nix-eval Build done.
buildbot/nix-build Build done.
buildbot/nix-effects Build done.

This commit is contained in:
Abel Luck 2026-03-05 15:47:57 +01:00
parent 95021a4253
commit 679b5c8d07
11 changed files with 291 additions and 22 deletions

View file

@ -218,16 +218,23 @@ def main() -> None:
reconciler_thread.start()
metrics_thread.start()
socket_path = Path(config.server.socket_path)
socket_path.parent.mkdir(parents=True, exist_ok=True)
if socket_path.exists():
socket_path.unlink()
uvicorn_config = uvicorn.Config(
app=app,
uds=config.server.socket_path,
log_level=config.server.log_level.lower(),
)
if config.server.listen_port > 0:
uvicorn_config = uvicorn.Config(
app=app,
host=config.server.listen_host,
port=config.server.listen_port,
log_level=config.server.log_level.lower(),
)
else:
socket_path = Path(config.server.socket_path)
socket_path.parent.mkdir(parents=True, exist_ok=True)
if socket_path.exists():
socket_path.unlink()
uvicorn_config = uvicorn.Config(
app=app,
uds=config.server.socket_path,
log_level=config.server.log_level.lower(),
)
server = uvicorn.Server(uvicorn_config)
def _handle_signal(signum: int, _: FrameType | None) -> None:

View file

@ -2,6 +2,7 @@
from __future__ import annotations
import hmac
import logging
import uuid
from collections.abc import Callable
@ -118,6 +119,8 @@ def create_app(
app.state.runtime = runtime
app.state.haproxy = haproxy
auth_token = config.server.auth_token.strip()
@app.middleware("http")
async def request_id_middleware(request: Request, call_next: Callable) -> Response:
request.state.request_id = str(uuid.uuid4())
@ -125,6 +128,25 @@ def create_app(
response.headers["x-request-id"] = request.state.request_id
return response
@app.middleware("http")
async def auth_middleware(request: Request, call_next: Callable) -> Response:
    """Reject requests to protected routes that lack the configured bearer token.

    Enforcement is active only when ``auth_token`` is non-empty, and only for
    paths under ``/v1/`` plus the exact path ``/metrics``. Every other request
    is forwarded to ``call_next`` unchanged.

    Returns:
        The downstream response, or a 401 JSON ``ErrorResponse`` with code
        ``"unauthorized"`` when the ``Authorization`` header does not match.
    """
    path = request.url.path
    is_protected = path.startswith("/v1/") or path == "/metrics"
    if not auth_token or not is_protected:
        return await call_next(request)

    expected = f"Bearer {auth_token}"
    provided = request.headers.get("authorization", "")
    # hmac.compare_digest() raises TypeError when a *str* argument contains
    # non-ASCII characters, which would turn a garbage header into a 500.
    # Comparing bytes keeps the constant-time check and always yields a bool.
    if hmac.compare_digest(provided.encode("utf-8"), expected.encode("utf-8")):
        return await call_next(request)

    # request.state.request_id may not be populated yet at this point, so fall
    # back to a fresh id and echo it in the header ourselves so that rejected
    # requests stay traceable like every other response.
    request_id = getattr(request.state, "request_id", str(uuid.uuid4()))
    payload = ErrorResponse(
        error=ErrorDetail(
            code="unauthorized",
            message="Missing or invalid bearer token",
            retryable=False,
        ),
        request_id=request_id,
    )
    return JSONResponse(
        status_code=401,
        content=payload.model_dump(mode="json"),
        headers={"x-request-id": request_id},
    )
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
detail = exc.detail

View file

@ -13,6 +13,9 @@ class ServerConfig:
"""[server] section."""
socket_path: str = "/run/nix-builder-autoscaler/daemon.sock"
listen_host: str = "127.0.0.1"
listen_port: int = 0
auth_token: str = ""
log_level: str = "info"
db_path: str = "/var/lib/nix-builder-autoscaler/state.db"

View file

@ -8,7 +8,7 @@ from typing import Any
from fastapi.testclient import TestClient
from nix_builder_autoscaler.api import create_app
from nix_builder_autoscaler.config import AppConfig, CapacityConfig
from nix_builder_autoscaler.config import AppConfig, CapacityConfig, ServerConfig
from nix_builder_autoscaler.metrics import MetricsRegistry
from nix_builder_autoscaler.models import SlotState
from nix_builder_autoscaler.providers.clock import FakeClock
@ -18,12 +18,16 @@ from nix_builder_autoscaler.state_db import StateDB
def _make_client(
    *,
    reconcile_now: Any = None,  # noqa: ANN401
    auth_token: str = "",
) -> tuple[TestClient, StateDB, FakeClock, MetricsRegistry]:
    """Build a ``TestClient`` around a freshly created app with in-memory state.

    Args:
        reconcile_now: Forwarded to ``create_app`` as its ``reconcile_now`` argument.
        auth_token: Value placed in ``ServerConfig.auth_token`` for the app under test.

    Returns:
        The test client together with the database, fake clock and metrics
        registry that back it, so tests can inspect or manipulate them.
    """
    clock = FakeClock()
    db = StateDB(":memory:", clock=clock)
    db.init_schema()
    db.init_slots("slot", 3, "x86_64-linux", "all")
    # Single config construction; an earlier capacity-only AppConfig was
    # immediately overwritten and has been dropped as a dead store.
    config = AppConfig(
        server=ServerConfig(auth_token=auth_token),
        capacity=CapacityConfig(reservation_ttl_seconds=1200),
    )
    metrics = MetricsRegistry()
    app = create_app(db, config, clock, metrics, reconcile_now=reconcile_now)
    return TestClient(app), db, clock, metrics
@ -245,3 +249,20 @@ def test_admin_reconcile_now_success() -> None:
assert response.json()["status"] == "accepted"
assert response.json()["triggered"] is True
assert called["value"] is True
def test_auth_token_required_for_v1_when_configured() -> None:
    """A /v1 request without an Authorization header gets a 401 "unauthorized" error."""
    client = _make_client(auth_token="test-token")[0]
    body = {"system": "x86_64-linux", "reason": "test"}

    response = client.post("/v1/reservations", json=body)

    assert response.status_code == 401
    error = response.json()["error"]
    assert error["code"] == "unauthorized"
def test_auth_token_allows_v1_when_header_matches() -> None:
    """A /v1 request carrying the configured bearer token is accepted with 200."""
    client = _make_client(auth_token="test-token")[0]
    body = {"system": "x86_64-linux", "reason": "test"}
    auth_headers = {"Authorization": "Bearer test-token"}

    response = client.post("/v1/reservations", json=body, headers=auth_headers)

    assert response.status_code == 200

View file

@ -7,6 +7,7 @@ import socket
import time
from dataclasses import dataclass
from typing import Any
from urllib.parse import urlparse
@dataclass(frozen=True)
@ -45,10 +46,39 @@ class UnixSocketHTTPConnection(http.client.HTTPConnection):
class DaemonClient:
def __init__(self, socket_path: str, retry_policy: RetryPolicy) -> None:
def __init__(
self,
*,
retry_policy: RetryPolicy,
socket_path: str | None = None,
base_url: str | None = None,
auth_token: str | None = None,
) -> None:
if (socket_path is None) == (base_url is None):
raise ValueError("exactly one of socket_path or base_url must be set")
self._socket_path = socket_path
self._base_url = base_url
self._auth_token = auth_token.strip() if auth_token is not None else None
self._retry = retry_policy
self._base_path = ""
self._http_scheme = "http"
self._http_host = "localhost"
self._http_port = 80
if base_url is not None:
parsed = urlparse(base_url)
if parsed.scheme not in {"http", "https"}:
raise ValueError("base_url must use http or https scheme")
if parsed.hostname is None:
raise ValueError("base_url must include a hostname")
self._http_scheme = parsed.scheme
self._http_host = parsed.hostname
if parsed.port is not None:
self._http_port = parsed.port
elif parsed.scheme == "https":
self._http_port = 443
self._base_path = parsed.path.rstrip("/")
def post_json(
self,
path: str,
@ -136,12 +166,31 @@ class DaemonClient:
timeout_seconds: float,
payload: bytes | None,
) -> tuple[bytes, int]:
conn = UnixSocketHTTPConnection(self._socket_path, timeout=timeout_seconds)
request_path = path if path.startswith("/") else f"/{path}"
if self._base_path != "":
request_path = f"{self._base_path}{request_path}"
headers = {"Accept": "application/json"}
if payload is not None:
headers["Content-Type"] = "application/json"
if self._auth_token is not None:
headers["Authorization"] = f"Bearer {self._auth_token}"
conn: http.client.HTTPConnection
if self._socket_path is not None:
conn = UnixSocketHTTPConnection(self._socket_path, timeout=timeout_seconds)
else:
conn = (
http.client.HTTPSConnection(
self._http_host, self._http_port, timeout=timeout_seconds
)
if self._http_scheme == "https"
else http.client.HTTPConnection(
self._http_host, self._http_port, timeout=timeout_seconds
)
)
try:
conn.request(method=method, url=path, body=payload, headers=headers)
conn.request(method=method, url=request_path, body=payload, headers=headers)
response = conn.getresponse()
data = response.read()
return data, response.status

View file

@ -34,6 +34,8 @@ class AutoscaleConfigurator(ConfiguratorBase):
gate = CapacityGateStep(
name="Ensure remote builder capacity",
daemon_socket=self.settings.daemon_socket,
daemon_url=self.settings.daemon_url,
daemon_auth_token=self.settings.daemon_auth_token,
system_property=self.settings.system_property,
default_system=self.settings.default_system,
reserve_timeout_seconds=self.settings.reserve_timeout_seconds,
@ -52,6 +54,8 @@ class AutoscaleConfigurator(ConfiguratorBase):
CapacityReleaseStep(
name="Release autoscaler reservation",
daemon_socket=self.settings.daemon_socket,
daemon_url=self.settings.daemon_url,
daemon_auth_token=self.settings.daemon_auth_token,
retry_max_attempts=self.settings.retry_max_attempts,
retry_base_seconds=self.settings.retry_base_seconds,
retry_max_seconds=self.settings.retry_max_seconds,

View file

@ -3,7 +3,9 @@ from dataclasses import dataclass
@dataclass(frozen=True)
class AutoscaleSettings:
daemon_socket: str
daemon_socket: str | None = "/run/nix-builder-autoscaler/daemon.sock"
daemon_url: str | None = None
daemon_auth_token: str | None = None
system_property: str = "system"
default_system: str = "x86_64-linux"
reserve_timeout_seconds: int = 900

View file

@ -20,7 +20,9 @@ class CapacityGateStep(buildstep.BuildStep):
def __init__(
self,
*,
daemon_socket: str,
daemon_socket: str | None = None,
daemon_url: str | None = None,
daemon_auth_token: str | None = None,
system_property: str,
default_system: str,
reserve_timeout_seconds: int,
@ -36,12 +38,14 @@ class CapacityGateStep(buildstep.BuildStep):
self._reserve_timeout_seconds = reserve_timeout_seconds
self._poll_interval_seconds = poll_interval_seconds
self._client = DaemonClient(
socket_path=daemon_socket,
retry_policy=RetryPolicy(
max_attempts=retry_max_attempts,
base_seconds=retry_base_seconds,
max_seconds=retry_max_seconds,
),
socket_path=daemon_socket,
base_url=daemon_url,
auth_token=daemon_auth_token,
)
def _determine_system(self) -> str:
@ -155,7 +159,9 @@ class CapacityReleaseStep(buildstep.BuildStep):
def __init__(
self,
*,
daemon_socket: str,
daemon_socket: str | None = None,
daemon_url: str | None = None,
daemon_auth_token: str | None = None,
retry_max_attempts: int,
retry_base_seconds: float,
retry_max_seconds: float = 5.0,
@ -163,12 +169,14 @@ class CapacityReleaseStep(buildstep.BuildStep):
) -> None:
super().__init__(**kwargs)
self._client = DaemonClient(
socket_path=daemon_socket,
retry_policy=RetryPolicy(
max_attempts=retry_max_attempts,
base_seconds=retry_base_seconds,
max_seconds=retry_max_seconds,
),
socket_path=daemon_socket,
base_url=daemon_url,
auth_token=daemon_auth_token,
)
def run(self) -> defer.Deferred[int]:

View file

@ -9,7 +9,7 @@ from collections.abc import Callable
from contextlib import suppress
from dataclasses import dataclass
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from typing import Any
@ -95,6 +95,87 @@ class FakeDaemon:
os.unlink(self._socket_path)
class _TCPHTTPServer(ThreadingHTTPServer):
    """Threaded HTTP server bound to an OS-assigned port on 127.0.0.1.

    Carries the test callbacks, the optional expected bearer token and a
    ``ServerState`` counter object for ``_TCPHandler`` to use.
    """

    def __init__(
        self,
        *,
        on_get: Callable[[str, int], tuple[int, dict[str, Any]]],
        on_post: Callable[[str, dict[str, Any], int], tuple[int, dict[str, Any]]],
        expected_auth: str | None = None,
    ) -> None:
        # Handler-visible attributes are set up before the base class binds
        # the listening socket.
        self.on_get = on_get
        self.on_post = on_post
        self.expected_auth = expected_auth
        self.state = ServerState()
        # Port 0 asks the OS for any free port.
        bind_address = ("127.0.0.1", 0)
        super().__init__(bind_address, _TCPHandler)
class _TCPHandler(BaseHTTPRequestHandler):
    """Request handler that delegates GET/POST to the owning server's callbacks."""

    server: _TCPHTTPServer

    def _authorize(self) -> bool:
        """Return True when no token is expected or the bearer header matches it."""
        token = self.server.expected_auth
        if token is None:
            return True
        return self.headers.get("Authorization") == f"Bearer {token}"

    def do_GET(self) -> None:  # noqa: N802
        """Count the request and answer with whatever ``on_get`` returns."""
        if not self._authorize():
            self._send(HTTPStatus.UNAUTHORIZED, {"error": "unauthorized"})
            return
        counters = self.server.state
        counters.get_count += 1
        status, body = self.server.on_get(self.path, counters.get_count)
        self._send(status, body)

    def do_POST(self) -> None:  # noqa: N802
        """Count the request, decode its JSON body and answer via ``on_post``."""
        if not self._authorize():
            self._send(HTTPStatus.UNAUTHORIZED, {"error": "unauthorized"})
            return
        counters = self.server.state
        counters.post_count += 1
        length = int(self.headers.get("Content-Length", "0"))
        # A missing or zero-length body is treated as an empty JSON object.
        raw = self.rfile.read(length) if length else b"{}"
        payload = json.loads(raw.decode("utf-8"))
        status, body = self.server.on_post(self.path, payload, counters.post_count)
        self._send(status, body)

    def log_message(self, format: str, *args: object) -> None:
        """Discard the base class's per-request logging."""

    def _send(self, status: int, body: dict[str, Any]) -> None:
        """Write ``body`` as a JSON response with the given status code."""
        data = json.dumps(body).encode("utf-8")
        self.send_response(status)
        response_headers = (
            ("Content-Type", "application/json"),
            ("Content-Length", str(len(data))),
        )
        for name, value in response_headers:
            self.send_header(name, value)
        self.end_headers()
        self.wfile.write(data)
class FakeTCPDaemon:
    """Context manager that serves a ``_TCPHTTPServer`` from a background thread."""

    def __init__(
        self,
        *,
        on_get: Callable[[str, int], tuple[int, dict[str, Any]]],
        on_post: Callable[[str, dict[str, Any], int], tuple[int, dict[str, Any]]],
        expected_auth: str | None = None,
    ) -> None:
        server = _TCPHTTPServer(on_get=on_get, on_post=on_post, expected_auth=expected_auth)
        self._server = server
        # The thread is only created here; it starts on context entry.
        self._thread = threading.Thread(target=server.serve_forever, daemon=True)

    @property
    def base_url(self) -> str:
        """The ``http://host:port`` URL the server is bound to."""
        bound_host, bound_port = self._server.server_address
        return f"http://{bound_host}:{bound_port}"

    def __enter__(self) -> FakeTCPDaemon:
        """Start serving and return this object."""
        self._thread.start()
        return self

    def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
        """Stop the serve loop and close the listening socket."""
        server = self._server
        server.shutdown()
        server.server_close()
@pytest.fixture
def socket_path() -> str:
with tempfile.TemporaryDirectory() as tmp:
@ -210,3 +291,23 @@ def test_backoff_attempts_at_least_two(socket_path: str) -> None:
)
assert daemon._server.state.get_count >= 2
def test_get_json_success_over_http_with_auth() -> None:
    """GET over plain HTTP succeeds when the client sends the expected bearer token."""
    token = "test-token"
    fake = FakeTCPDaemon(
        on_get=lambda _p, _a: (HTTPStatus.OK, {"phase": "ready"}),
        on_post=lambda _p, _payload, _a: (HTTPStatus.OK, {}),
        expected_auth=token,
    )
    with fake as daemon:
        policy = RetryPolicy(max_attempts=2, base_seconds=0.001, max_seconds=0.01)
        client = DaemonClient(
            base_url=daemon.base_url,
            auth_token=token,
            retry_policy=policy,
        )
        response = client.get_json(
            "/v1/reservations/r1",
            timeout_seconds=1.0,
            retryable_statuses={429, 500, 503},
        )
        assert response == {"phase": "ready"}

View file

@ -25,6 +25,18 @@ in
description = "Autoscaler daemon Unix socket path for Buildbot gate/release steps.";
};
daemonUrl = lib.mkOption {
type = lib.types.nullOr lib.types.str;
default = null;
description = "Optional autoscaler daemon HTTP(S) endpoint URL for remote gate/release calls.";
};
daemonAuthTokenFile = lib.mkOption {
type = lib.types.nullOr lib.types.str;
default = null;
description = "Optional file containing bearer token for authenticated daemon API calls.";
};
defaultSystem = lib.mkOption {
type = lib.types.str;
default = "x86_64-linux";
@ -131,6 +143,10 @@ in
assertion = cfg.builderClusterHost != null;
message = "services.buildbot-nix.nix-build-autoscaler.builderClusterHost must be set.";
}
{
assertion = cfg.daemonUrl != null || cfg.daemonSocket != "";
message = "services.buildbot-nix.nix-build-autoscaler requires either daemonUrl or daemonSocket.";
}
];
services.buildbot-master.pythonPackages = ps: [
@ -149,6 +165,7 @@ in
];
services.buildbot-master.extraImports = ''
import pathlib
from buildbot_autoscale_ext.configurator import AutoscaleConfigurator
from buildbot_autoscale_ext.settings import AutoscaleSettings
'';
@ -157,7 +174,14 @@ in
''
AutoscaleConfigurator(
AutoscaleSettings(
daemon_socket="${cfg.daemonSocket}",
daemon_socket=${if cfg.daemonUrl == null then ''"${cfg.daemonSocket}"'' else "None"},
daemon_url=${if cfg.daemonUrl != null then ''"${cfg.daemonUrl}"'' else "None"},
daemon_auth_token=${
if cfg.daemonAuthTokenFile != null then
''pathlib.Path("${cfg.daemonAuthTokenFile}").read_text(encoding="utf-8").strip()''
else
"None"
},
default_system="${cfg.defaultSystem}",
reserve_timeout_seconds=${toString cfg.reserveTimeoutSeconds},
poll_interval_seconds=${toString cfg.pollIntervalSeconds},

View file

@ -45,6 +45,24 @@ in
description = "Unix socket path exposed by the autoscaler API server.";
};
listenHost = lib.mkOption {
type = lib.types.str;
default = "127.0.0.1";
description = "TCP listen host for the autoscaler API server when listenPort is set.";
};
listenPort = lib.mkOption {
  # types.port limits the value to a 16-bit port number when the option is
  # type-checked, so out-of-range integers are rejected before any assertion
  # or config generation runs. Null still selects Unix socket mode.
  type = lib.types.nullOr lib.types.port;
  default = null;
  description = "Optional TCP listen port for the autoscaler API server. Null keeps Unix socket mode.";
};
authTokenFile = lib.mkOption {
type = lib.types.nullOr lib.types.str;
default = null;
description = "Optional file containing bearer token required for /v1 and /metrics API requests.";
};
logLevel = lib.mkOption {
type = lib.types.str;
default = "info";
@ -287,6 +305,10 @@ in
assertion = !cfg.capacity.nestedVirtualization || cfg.aws.onDemandLaunchTemplateIdFile != null;
message = "services.nix-builder-autoscaler.aws.onDemandLaunchTemplateIdFile must be set when capacity.nestedVirtualization is true.";
}
{
assertion = cfg.listenPort == null || (cfg.listenPort >= 1 && cfg.listenPort <= 65535);
message = "services.nix-builder-autoscaler.listenPort must be null or a TCP port between 1 and 65535.";
}
];
environment.systemPackages = [ cfg.package ];
@ -338,10 +360,16 @@ in
${lib.optionalString (cfg.aws.assumeRoleArnFile != null) ''
assume_role_arn="$(tr -d '\n' < ${lib.escapeShellArg cfg.aws.assumeRoleArnFile})"
''}
${lib.optionalString (cfg.authTokenFile != null) ''
auth_token="$(tr -d '\n' < ${lib.escapeShellArg cfg.authTokenFile})"
''}
cat > ${generatedConfigPath} <<EOF
[server]
socket_path = "${cfg.socketPath}"
listen_host = "${cfg.listenHost}"
listen_port = ${toString (if cfg.listenPort != null then cfg.listenPort else 0)}
${lib.optionalString (cfg.authTokenFile != null) ''auth_token = "$auth_token"''}
log_level = "${cfg.logLevel}"
db_path = "${cfg.dbPath}"