From 2f0fffa905d94fbc15d95ed39e0d36c5f4dd23e8 Mon Sep 17 00:00:00 2001 From: Abel Luck Date: Fri, 27 Feb 2026 13:48:52 +0100 Subject: [PATCH] agent: complete plan05 closeout --- agent/nix_builder_autoscaler/__main__.py | 98 ++- agent/nix_builder_autoscaler/api.py | 93 +++ .../bootstrap/userdata.py | 11 +- agent/nix_builder_autoscaler/cli.py | 45 +- agent/nix_builder_autoscaler/reconciler.py | 94 ++- agent/nix_builder_autoscaler/runtime/ec2.py | 120 +++- agent/nix_builder_autoscaler/scheduler.py | 11 +- agent/nix_builder_autoscaler/state_db.py | 561 +++++++++--------- .../tests/integration/test_end_to_end_fake.py | 408 ++++++++++++- .../tests/test_reservations_api.py | 87 ++- .../tests/test_runtime_ec2.py | 129 ++++ flake.nix | 3 +- 12 files changed, 1347 insertions(+), 313 deletions(-) diff --git a/agent/nix_builder_autoscaler/__main__.py b/agent/nix_builder_autoscaler/__main__.py index 0bf3f32..8a32dbc 100644 --- a/agent/nix_builder_autoscaler/__main__.py +++ b/agent/nix_builder_autoscaler/__main__.py @@ -6,6 +6,7 @@ import argparse import logging import signal import threading +import time from pathlib import Path from types import FrameType @@ -25,6 +26,29 @@ from .state_db import StateDB log = logging.getLogger(__name__) +class LoopHealth: + """Thread-safe last-success timestamps for daemon loops.""" + + def __init__(self) -> None: + self._lock = threading.Lock() + self._last_success: dict[str, float] = {} + + def mark_success(self, loop_name: str) -> None: + with self._lock: + self._last_success[loop_name] = time.monotonic() + + def is_fresh(self, loop_name: str, max_age_seconds: float) -> bool: + with self._lock: + last = self._last_success.get(loop_name) + if last is None: + return False + return (time.monotonic() - last) <= max_age_seconds + + +def _max_staleness(interval_seconds: float) -> float: + return max(interval_seconds * 3.0, 15.0) + + def _scheduler_loop( db: StateDB, runtime: EC2Runtime, @@ -32,10 +56,12 @@ def _scheduler_loop( clock: SystemClock, metrics: MetricsRegistry, stop_event: threading.Event, + loop_health: LoopHealth, ) -> None: while not stop_event.is_set(): try: scheduling_tick(db, runtime, config, clock, metrics) + loop_health.mark_success("scheduler") except Exception: log.exception("scheduler_tick_failed") stop_event.wait(config.scheduler.tick_seconds) @@ -45,15 +71,36 @@ def _reconciler_loop( reconciler: Reconciler, config: AppConfig, stop_event: threading.Event, + loop_health: LoopHealth, + reconcile_lock: threading.Lock, ) -> None: while not stop_event.is_set(): try: - reconciler.tick() + with reconcile_lock: + reconciler.tick() + loop_health.mark_success("reconciler") except Exception: log.exception("reconciler_tick_failed") stop_event.wait(config.scheduler.reconcile_seconds) +def _metrics_health_loop( + metrics: MetricsRegistry, + stop_event: threading.Event, + loop_health: LoopHealth, + interval_seconds: float, +) -> None: + while not stop_event.is_set(): + try: + metrics.gauge("autoscaler_loop_up", {"loop": "scheduler"}, 1.0) + metrics.gauge("autoscaler_loop_up", {"loop": "reconciler"}, 1.0) + metrics.gauge("autoscaler_loop_up", {"loop": "metrics"}, 1.0) + loop_health.mark_success("metrics") + except Exception: + log.exception("metrics_health_tick_failed") + stop_event.wait(interval_seconds) + + def _parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( prog="nix-builder-autoscaler", @@ -92,7 +139,11 @@ def main() -> None: stop_event = threading.Event() scheduler_thread: threading.Thread | None = None reconciler_thread: threading.Thread | None = None + metrics_thread: threading.Thread | None = None server: uvicorn.Server | None = None + loop_health = LoopHealth() + reconcile_lock = threading.Lock() + metrics_interval = 5.0 def scheduler_running() -> bool: return scheduler_thread is not None and scheduler_thread.is_alive() @@ -100,6 +151,32 @@ def main() -> None: def reconciler_running() -> bool: return reconciler_thread is not None and reconciler_thread.is_alive() + def metrics_running() -> bool: + return metrics_thread is not None and metrics_thread.is_alive() + + def ready_check() -> bool: + checks = [ + ("scheduler", scheduler_running(), _max_staleness(config.scheduler.tick_seconds)), + ( + "reconciler", + reconciler_running(), + _max_staleness(config.scheduler.reconcile_seconds), + ), + ("metrics", metrics_running(), _max_staleness(metrics_interval)), + ] + for loop_name, alive, max_age in checks: + if not alive: + return False + if not loop_health.is_fresh(loop_name, max_age): + return False + return True + + def reconcile_now() -> dict[str, object]: + with reconcile_lock: + reconciler.tick() + loop_health.mark_success("reconciler") + return {"triggered": True} + app = create_app( db, config, @@ -109,23 +186,36 @@ def main() -> None: haproxy=haproxy, scheduler_running=scheduler_running, reconciler_running=reconciler_running, + ready_check=ready_check, + reconcile_now=reconcile_now, ) + loop_health.mark_success("scheduler") + loop_health.mark_success("reconciler") + loop_health.mark_success("metrics") + scheduler_thread = threading.Thread( target=_scheduler_loop, name="autoscaler-scheduler", - args=(db, runtime, config, clock, metrics, stop_event), + args=(db, runtime, config, clock, metrics, stop_event, loop_health), daemon=True, ) reconciler_thread = threading.Thread( target=_reconciler_loop, name="autoscaler-reconciler", - args=(reconciler, config, stop_event), + args=(reconciler, config, stop_event, loop_health, reconcile_lock), + daemon=True, + ) + metrics_thread = threading.Thread( + target=_metrics_health_loop, + name="autoscaler-metrics-health", + args=(metrics, stop_event, loop_health, metrics_interval), daemon=True, ) scheduler_thread.start() reconciler_thread.start() + metrics_thread.start() socket_path = Path(config.server.socket_path) socket_path.parent.mkdir(parents=True, exist_ok=True) @@ -156,6 +246,8 @@ def main() -> None: scheduler_thread.join(timeout=10) if reconciler_thread is not None: reconciler_thread.join(timeout=10) + if metrics_thread is not None: + metrics_thread.join(timeout=10) db.close() diff --git a/agent/nix_builder_autoscaler/api.py b/agent/nix_builder_autoscaler/api.py index dae8074..3df95f5 100644 --- a/agent/nix_builder_autoscaler/api.py +++ b/agent/nix_builder_autoscaler/api.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, NoReturn from fastapi import FastAPI, HTTPException, Request, Response from fastapi.responses import JSONResponse +from pydantic import BaseModel from .models import ( CapacityHint, @@ -35,6 +36,12 @@ if TYPE_CHECKING: log = logging.getLogger(__name__) +class SlotAdminRequest(BaseModel): + """Admin action request that targets a slot.""" + + slot_id: str + + def _parse_required_dt(value: str) -> datetime: return datetime.fromisoformat(value) @@ -95,6 +102,8 @@ def create_app( haproxy: HAProxyRuntime | None = None, scheduler_running: Callable[[], bool] | None = None, reconciler_running: Callable[[], bool] | None = None, + ready_check: Callable[[], bool] | None = None, + reconcile_now: Callable[[], dict[str, object] | None] | None = None, ) -> FastAPI: """Create the FastAPI application.""" app = FastAPI(title="nix-builder-autoscaler", version="0.1.0") @@ -191,6 +200,11 @@ def create_app( @app.get("/health/ready", response_model=HealthResponse) def health_ready() -> HealthResponse: + if ready_check is not None and not ready_check(): + return JSONResponse( # type: ignore[return-value] + status_code=503, + content=HealthResponse(status="degraded").model_dump(mode="json"), + ) if scheduler_running is not None and not scheduler_running(): return JSONResponse( # type: ignore[return-value] status_code=503, @@ -207,4 +221,83 @@ def create_app( def metrics_endpoint() -> Response: return Response(content=metrics.render(), media_type="text/plain") + @app.post("/v1/admin/drain") + def admin_drain(body: SlotAdminRequest, request: Request) -> dict[str, str]: + slot = db.get_slot(body.slot_id) + if slot is None: + _error_response(request, 404, "not_found", "Slot not found") + state = str(slot["state"]) + if state == SlotState.DRAINING.value or state == SlotState.TERMINATING.value: + return {"status": "accepted", "slot_id": body.slot_id, "state": state} + + allowed_states = { + SlotState.READY.value, + SlotState.BINDING.value, + SlotState.BOOTING.value, + SlotState.LAUNCHING.value, + } + if state not in allowed_states: + _error_response( + request, + 409, + "invalid_state", + f"Cannot drain slot from state {state}", + ) + db.update_slot_state(body.slot_id, SlotState.DRAINING, interruption_pending=0) + return {"status": "accepted", "slot_id": body.slot_id, "state": SlotState.DRAINING.value} + + @app.post("/v1/admin/unquarantine") + def admin_unquarantine(body: SlotAdminRequest, request: Request) -> dict[str, str]: + slot = db.get_slot(body.slot_id) + if slot is None: + _error_response(request, 404, "not_found", "Slot not found") + + state = str(slot["state"]) + if state != SlotState.ERROR.value: + _error_response( + request, + 409, + "invalid_state", + f"Cannot unquarantine slot from state {state}", + ) + + db.update_slot_state( + body.slot_id, + SlotState.EMPTY, + instance_id=None, + instance_ip=None, + instance_launch_time=None, + lease_count=0, + cooldown_until=None, + interruption_pending=0, + ) + return {"status": "accepted", "slot_id": body.slot_id, "state": SlotState.EMPTY.value} + + @app.post("/v1/admin/reconcile-now") + def admin_reconcile_now(request: Request) -> dict[str, object]: + if reconcile_now is None: + _error_response( + request, + 503, + "not_configured", + "Reconcile trigger not configured", + retryable=True, + ) + try: + result = reconcile_now() + except Exception: + log.exception("admin_reconcile_now_failed") + _error_response( + request, + 500, + "reconcile_failed", + "Reconcile tick failed", + retryable=True, + ) + + payload: dict[str, object] = {"status": "accepted"} + if isinstance(result, dict): + payload.update(result) + return payload + return app diff --git a/agent/nix_builder_autoscaler/bootstrap/userdata.py b/agent/nix_builder_autoscaler/bootstrap/userdata.py index 11b72a7..91cefe9 100644 --- a/agent/nix_builder_autoscaler/bootstrap/userdata.py +++ b/agent/nix_builder_autoscaler/bootstrap/userdata.py @@ -41,13 +41,22 @@ def render_userdata(slot_id: str, region: str, ssm_param: str = "/nix-builder/ts printf '%s' "$TS_AUTHKEY" > /run/credentials/tailscale-auth-key chmod 600 /run/credentials/tailscale-auth-key + # --- Resolve instance identity from IMDSv2 for unique hostname --- + IMDS_TOKEN=$(curl -fsS -X PUT "http://169.254.169.254/latest/api/token" \\ + -H "X-aws-ec2-metadata-token-ttl-seconds: 21600" || true) + INSTANCE_ID=$(curl -fsS -H "X-aws-ec2-metadata-token: $IMDS_TOKEN" \\ + "http://169.254.169.254/latest/meta-data/instance-id" || true) + if [ -z "$INSTANCE_ID" ]; then + INSTANCE_ID="unknown" + fi + # --- Write tailscale-autoconnect config --- mkdir -p /etc/tailscale cat > /etc/tailscale/autoconnect.conf < argparse.Namespace: subparsers.add_parser("slots", help="List slots") subparsers.add_parser("reservations", help="List reservations") - parser_drain = subparsers.add_parser("drain", help="Drain a slot (not implemented)") + parser_drain = subparsers.add_parser("drain", help="Drain a slot") parser_drain.add_argument("slot_id") - parser_unq = subparsers.add_parser( - "unquarantine", - help="Unquarantine a slot (not implemented)", - ) + parser_unq = subparsers.add_parser("unquarantine", help="Unquarantine a slot") parser_unq.add_argument("slot_id") - subparsers.add_parser("reconcile-now", help="Run reconciler now (not implemented)") + subparsers.add_parser("reconcile-now", help="Trigger immediate reconcile tick") return parser.parse_args() @@ -130,19 +127,31 @@ def main() -> None: if not args.command: raise SystemExit(1) - if args.command in {"drain", "unquarantine", "reconcile-now"}: - print(f"{args.command}: not yet implemented in API v1") - raise SystemExit(0) - - endpoint_map = { - "status": "/v1/state/summary", - "slots": "/v1/slots", - "reservations": "/v1/reservations", - } - path = endpoint_map[args.command] + method = "GET" + path = "" + body: dict[str, Any] | None = None + if args.command == "status": + path = "/v1/state/summary" + elif args.command == "slots": + path = "/v1/slots" + elif args.command == "reservations": + path = "/v1/reservations" + elif args.command == "drain": + method = "POST" + path = "/v1/admin/drain" + body = {"slot_id": args.slot_id} + elif args.command == "unquarantine": + method = "POST" + path = "/v1/admin/unquarantine" + body = {"slot_id": args.slot_id} + elif args.command == "reconcile-now": + method = "POST" + path = "/v1/admin/reconcile-now" + else: + raise SystemExit(1) try: - status, data = _uds_request(args.socket, "GET", path) + status, data = _uds_request(args.socket, method, path, body=body) except OSError as err: print(f"Error: cannot connect to daemon at {args.socket}") raise SystemExit(1) from err @@ -151,7 +160,7 @@ def main() -> None: _print_error(data) raise SystemExit(1) - if args.command == "status": + if args.command in {"status", "drain", "unquarantine", "reconcile-now"}: print(json.dumps(data, indent=2)) elif args.command == "slots": if isinstance(data, list): diff --git a/agent/nix_builder_autoscaler/reconciler.py b/agent/nix_builder_autoscaler/reconciler.py index 9607462..448f92d 100644 --- a/agent/nix_builder_autoscaler/reconciler.py +++ b/agent/nix_builder_autoscaler/reconciler.py @@ -68,7 +68,7 @@ class Reconciler: # 2. Query HAProxy try: - haproxy_health = self._haproxy.read_slot_health() + haproxy_health = self._haproxy_read_slot_health() except HAProxyError: log.warning("haproxy_stat_failed", exc_info=True) haproxy_health = {} @@ -142,8 +142,8 @@ class Reconciler: if tailscale_ip is not None: self._db.update_slot_state(slot["slot_id"], SlotState.BINDING, instance_ip=tailscale_ip) try: - self._haproxy.set_slot_addr(slot["slot_id"], tailscale_ip) - self._haproxy.enable_slot(slot["slot_id"]) + self._haproxy_set_slot_addr(slot["slot_id"], tailscale_ip) + self._haproxy_enable_slot(slot["slot_id"]) except HAProxyError: log.warning( "haproxy_binding_setup_failed", @@ -169,8 +169,8 @@ class Reconciler: ip = slot.get("instance_ip") if ip: try: - self._haproxy.set_slot_addr(slot_id, ip) - self._haproxy.enable_slot(slot_id) + self._haproxy_set_slot_addr(slot_id, ip) + self._haproxy_enable_slot(slot_id) except HAProxyError: pass @@ -204,7 +204,7 @@ class Reconciler: # Disable HAProxy (idempotent) with contextlib.suppress(HAProxyError): - self._haproxy.disable_slot(slot_id) + self._haproxy_disable_slot(slot_id) now = self._clock.now() last_change = datetime.fromisoformat(slot["last_state_change"]) @@ -216,8 +216,17 @@ class Reconciler: if instance_id: try: self._runtime.terminate_instance(instance_id) - self._metrics.counter("autoscaler_ec2_terminate_total", {}, 1.0) + self._metrics.counter( + "autoscaler_ec2_terminate_total", + {"result": "success"}, + 1.0, + ) except Exception: + self._metrics.counter( + "autoscaler_ec2_terminate_total", + {"result": "error"}, + 1.0, + ) log.warning( "terminate_failed", extra={"slot_id": slot_id, "instance_id": instance_id}, @@ -252,7 +261,70 @@ class Reconciler: """Emit reconciler metrics.""" summary = self._db.get_state_summary() for state, count in summary["slots"].items(): - if state == "total": - continue - self._metrics.gauge("autoscaler_slots", {"state": state}, float(count)) - self._metrics.histogram_observe("autoscaler_reconciler_tick_seconds", {}, tick_duration) + self._metrics.gauge("autoscaler_slots_total", {"state": state}, float(count)) + self._metrics.histogram_observe("autoscaler_reconcile_duration_seconds", {}, tick_duration) + + def _haproxy_set_slot_addr(self, slot_id: str, ip: str) -> None: + try: + self._haproxy.set_slot_addr(slot_id, ip) + self._metrics.counter( + "autoscaler_haproxy_command_total", + {"cmd": "set_slot_addr", "result": "success"}, + 1.0, + ) + except HAProxyError: + self._metrics.counter( + "autoscaler_haproxy_command_total", + {"cmd": "set_slot_addr", "result": "error"}, + 1.0, + ) + raise + + def _haproxy_enable_slot(self, slot_id: str) -> None: + try: + self._haproxy.enable_slot(slot_id) + self._metrics.counter( + "autoscaler_haproxy_command_total", + {"cmd": "enable_slot", "result": "success"}, + 1.0, + ) + except HAProxyError: + self._metrics.counter( + "autoscaler_haproxy_command_total", + {"cmd": "enable_slot", "result": "error"}, + 1.0, + ) + raise + + def _haproxy_disable_slot(self, slot_id: str) -> None: + try: + self._haproxy.disable_slot(slot_id) + self._metrics.counter( + "autoscaler_haproxy_command_total", + {"cmd": "disable_slot", "result": "success"}, + 1.0, + ) + except HAProxyError: + self._metrics.counter( + "autoscaler_haproxy_command_total", + {"cmd": "disable_slot", "result": "error"}, + 1.0, + ) + raise + + def _haproxy_read_slot_health(self) -> dict: + try: + health = self._haproxy.read_slot_health() + self._metrics.counter( + "autoscaler_haproxy_command_total", + {"cmd": "show_stat", "result": "success"}, + 1.0, + ) + return health + except HAProxyError: + self._metrics.counter( + "autoscaler_haproxy_command_total", + {"cmd": "show_stat", "result": "error"}, + 1.0, + ) + raise diff --git a/agent/nix_builder_autoscaler/runtime/ec2.py b/agent/nix_builder_autoscaler/runtime/ec2.py index d134c40..f20dcd2 100644 --- a/agent/nix_builder_autoscaler/runtime/ec2.py +++ b/agent/nix_builder_autoscaler/runtime/ec2.py @@ -2,8 +2,11 @@ from __future__ import annotations +import http.client +import json import logging import random +import socket import time from typing import Any @@ -26,6 +29,18 @@ _ERROR_CATEGORIES: dict[str, str] = { _RETRYABLE_CODES: frozenset[str] = frozenset({"RequestLimitExceeded"}) +class _UnixSocketHTTPConnection(http.client.HTTPConnection): + """HTTP connection over a Unix domain socket.""" + + def __init__(self, socket_path: str, timeout: float = 1.0) -> None: + super().__init__("local-tailscaled.sock", timeout=timeout) + self._socket_path = socket_path + + def connect(self) -> None: + self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.sock.connect(self._socket_path) + + class EC2Runtime(RuntimeAdapter): """EC2 Spot instance runtime adapter. @@ -41,6 +56,7 @@ class EC2Runtime(RuntimeAdapter): environment: str = "dev", *, _client: Any = None, + _tailscale_socket_path: str = "/run/tailscale/tailscaled.sock", ) -> None: self._client: Any = _client or boto3.client("ec2", region_name=config.region) self._launch_template_id = config.launch_template_id @@ -49,6 +65,7 @@ class EC2Runtime(RuntimeAdapter): self._instance_profile_arn = config.instance_profile_arn self._environment = environment self._subnet_index = 0 + self._tailscale_socket_path = _tailscale_socket_path def launch_spot(self, slot_id: str, user_data: str) -> str: """Launch a spot instance for *slot_id*. Return instance ID.""" @@ -103,10 +120,17 @@ class EC2Runtime(RuntimeAdapter): return {"state": "terminated", "tailscale_ip": None, "launch_time": None} inst = reservations[0]["Instances"][0] + tags = inst.get("Tags", []) + slot_id = self._get_tag(tags, "AutoscalerSlot") + state = inst["State"]["Name"] + tailscale_ip: str | None = None + if state == "running" and slot_id: + tailscale_ip = self._discover_tailscale_ip(slot_id, instance_id) + launch_time = inst.get("LaunchTime") return { - "state": inst["State"]["Name"], - "tailscale_ip": None, + "state": state, + "tailscale_ip": tailscale_ip, "launch_time": launch_time.isoformat() if launch_time else None, } @@ -166,6 +190,98 @@ class EC2Runtime(RuntimeAdapter): msg = "Retries exhausted" raise RuntimeAdapterError(msg, category="unknown") + def _discover_tailscale_ip(self, slot_id: str, instance_id: str) -> str | None: + """Resolve Tailscale IP for instance identity via local tailscaled LocalAPI.""" + status = self._read_tailscale_status() + if status is None: + return None + + peers_obj = status.get("Peer") + if not isinstance(peers_obj, dict): + return None + + online_candidates: list[tuple[str, str]] = [] + for peer in peers_obj.values(): + if not isinstance(peer, dict): + continue + if not self._peer_is_online(peer): + continue + hostname = self._peer_hostname(peer) + if hostname is None: + continue + ip = self._peer_tailscale_ip(peer) + if ip is None: + continue + online_candidates.append((hostname, ip)) + + identity = f"nix-builder-{slot_id}-{instance_id}".lower() + identity_matches = [ip for host, ip in online_candidates if identity in host] + if len(identity_matches) == 1: + return identity_matches[0] + if len(identity_matches) > 1: + log.warning( + "tailscale_identity_ambiguous", + extra={"slot_id": slot_id, "instance_id": instance_id}, + ) + return None + + slot_identity = f"nix-builder-{slot_id}".lower() + slot_matches = [ip for host, ip in online_candidates if slot_identity in host] + if len(slot_matches) == 1: + return slot_matches[0] + if len(slot_matches) > 1: + log.warning("tailscale_slot_ambiguous", extra={"slot_id": slot_id}) + return None + return None + + def _read_tailscale_status(self) -> dict[str, Any] | None: + """Query local tailscaled LocalAPI status endpoint over Unix socket.""" + conn = _UnixSocketHTTPConnection(self._tailscale_socket_path, timeout=1.0) + try: + conn.request( + "GET", + "/localapi/v0/status", + headers={"Host": "local-tailscaled.sock", "Accept": "application/json"}, + ) + response = conn.getresponse() + if response.status != 200: + return None + payload = response.read() + parsed = json.loads(payload.decode()) + if isinstance(parsed, dict): + return parsed + return None + except (OSError, PermissionError, TimeoutError, json.JSONDecodeError, UnicodeDecodeError): + return None + except http.client.HTTPException: + return None + finally: + conn.close() + + @staticmethod + def _peer_is_online(peer: dict[str, Any]) -> bool: + return bool(peer.get("Online") or peer.get("Active")) + + @staticmethod + def _peer_hostname(peer: dict[str, Any]) -> str | None: + host = peer.get("HostName") or peer.get("DNSName") + if not isinstance(host, str) or not host: + return None + return host.strip(".").lower() + + @staticmethod + def _peer_tailscale_ip(peer: dict[str, Any]) -> str | None: + ips = peer.get("TailscaleIPs") + if not isinstance(ips, list): + return None + ipv4 = [ip for ip in ips if isinstance(ip, str) and "." in ip] + if ipv4: + return ipv4[0] + for ip in ips: + if isinstance(ip, str) and ip: + return ip + return None + @staticmethod def _get_tag(tags: list[dict[str, str]], key: str) -> str | None: """Extract a tag value from an EC2 tag list.""" diff --git a/agent/nix_builder_autoscaler/scheduler.py b/agent/nix_builder_autoscaler/scheduler.py index 94baf61..7a835a2 100644 --- a/agent/nix_builder_autoscaler/scheduler.py +++ b/agent/nix_builder_autoscaler/scheduler.py @@ -217,12 +217,13 @@ def _launch_slot( """Launch a single slot. Transition to LAUNCHING on success, ERROR on failure.""" slot_id = slot["slot_id"] user_data = render_userdata(slot_id, config.aws.region) - metrics.counter("autoscaler_ec2_launch_total", {}, 1.0) try: instance_id = runtime.launch_spot(slot_id, user_data) + metrics.counter("autoscaler_ec2_launch_total", {"result": "success"}, 1.0) db.update_slot_state(slot_id, SlotState.LAUNCHING, instance_id=instance_id) log.info("slot_launched", extra={"slot_id": slot_id, "instance_id": instance_id}) except RuntimeAdapterError as exc: + metrics.counter("autoscaler_ec2_launch_total", {"result": exc.category}, 1.0) db.update_slot_state(slot_id, SlotState.ERROR) log.warning( "slot_launch_failed", @@ -257,11 +258,9 @@ def _update_metrics(db: StateDB, metrics: MetricsRegistry, tick_duration: float) summary = db.get_state_summary() for state, count in summary["slots"].items(): - if state == "total": - continue - metrics.gauge("autoscaler_slots", {"state": state}, float(count)) + metrics.gauge("autoscaler_slots_total", {"state": state}, float(count)) for phase, count in summary["reservations"].items(): - metrics.gauge("autoscaler_reservations", {"phase": phase}, float(count)) + metrics.gauge("autoscaler_reservations_total", {"phase": phase}, float(count)) - metrics.histogram_observe("autoscaler_scheduler_tick_seconds", {}, tick_duration) + metrics.histogram_observe("autoscaler_scheduler_tick_duration_seconds", {}, tick_duration) diff --git a/agent/nix_builder_autoscaler/state_db.py b/agent/nix_builder_autoscaler/state_db.py index bb11156..cf5b8d9 100644 --- a/agent/nix_builder_autoscaler/state_db.py +++ b/agent/nix_builder_autoscaler/state_db.py @@ -7,6 +7,7 @@ from __future__ import annotations import json import sqlite3 +import threading import uuid from datetime import UTC, datetime, timedelta from pathlib import Path @@ -75,83 +76,89 @@ class StateDB: self._conn.execute("PRAGMA journal_mode=WAL") self._conn.execute("PRAGMA busy_timeout=5000") self._clock = clock + self._lock = threading.RLock() def init_schema(self) -> None: """Create tables if they don't exist.""" - self._conn.executescript(_SCHEMA) + with self._lock: + self._conn.executescript(_SCHEMA) def init_slots(self, slot_prefix: str, slot_count: int, system: str, backend: str) -> None: """Ensure all expected slots exist, creating missing ones as empty.""" - now = _now_iso(self._clock) - for i in range(1, slot_count + 1): - slot_id = f"{slot_prefix}{i:03d}" - bound = f"{backend}/{slot_id}" - self._conn.execute( - """INSERT OR IGNORE INTO slots - (slot_id, system, state, bound_backend, lease_count, last_state_change) - VALUES (?, ?, ?, ?, 0, ?)""", - (slot_id, system, SlotState.EMPTY.value, bound, now), - ) - self._conn.commit() + with self._lock: + now = _now_iso(self._clock) + for i in range(1, slot_count + 1): + slot_id = f"{slot_prefix}{i:03d}" + bound = f"{backend}/{slot_id}" + self._conn.execute( + """INSERT OR IGNORE INTO slots + (slot_id, system, state, bound_backend, lease_count, last_state_change) + VALUES (?, ?, ?, ?, 0, ?)""", + (slot_id, system, SlotState.EMPTY.value, bound, now), + ) + self._conn.commit() # -- Slot operations ---------------------------------------------------- def get_slot(self, slot_id: str) -> dict | None: """Return a slot row as dict, or None.""" - cur = self._conn.execute("SELECT * FROM slots WHERE slot_id = ?", (slot_id,)) - row = cur.fetchone() - if row is None: - return None - return _row_to_dict(cur, row) + with self._lock: + cur = self._conn.execute("SELECT * FROM slots WHERE slot_id = ?", (slot_id,)) + row = cur.fetchone() + if row is None: + return None + return _row_to_dict(cur, row) def list_slots(self, state: SlotState | None = None) -> list[dict]: """List slots, optionally filtered by state.""" - if state is not None: - cur = self._conn.execute( - "SELECT * FROM slots WHERE state = ? ORDER BY slot_id", (state.value,) - ) - else: - cur = self._conn.execute("SELECT * FROM slots ORDER BY slot_id") - return [_row_to_dict(cur, row) for row in cur.fetchall()] + with self._lock: + if state is not None: + cur = self._conn.execute( + "SELECT * FROM slots WHERE state = ? ORDER BY slot_id", (state.value,) + ) + else: + cur = self._conn.execute("SELECT * FROM slots ORDER BY slot_id") + return [_row_to_dict(cur, row) for row in cur.fetchall()] def update_slot_state(self, slot_id: str, new_state: SlotState, **fields: object) -> None: """Atomically transition a slot to a new state and record an event. Additional fields (instance_id, instance_ip, etc.) can be passed as kwargs. """ - now = _now_iso(self._clock) - set_parts = ["state = ?", "last_state_change = ?"] - params: list[object] = [new_state.value, now] + with self._lock: + now = _now_iso(self._clock) + set_parts = ["state = ?", "last_state_change = ?"] + params: list[object] = [new_state.value, now] - allowed = { - "instance_id", - "instance_ip", - "instance_launch_time", - "lease_count", - "cooldown_until", - "interruption_pending", - } - for k, v in fields.items(): - if k not in allowed: - msg = f"Unknown slot field: {k}" - raise ValueError(msg) - set_parts.append(f"{k} = ?") - params.append(v) + allowed = { + "instance_id", + "instance_ip", + "instance_launch_time", + "lease_count", + "cooldown_until", + "interruption_pending", + } + for k, v in fields.items(): + if k not in allowed: + msg = f"Unknown slot field: {k}" + raise ValueError(msg) + set_parts.append(f"{k} = ?") + params.append(v) - params.append(slot_id) - sql = f"UPDATE slots SET {', '.join(set_parts)} WHERE slot_id = ?" + params.append(slot_id) + sql = f"UPDATE slots SET {', '.join(set_parts)} WHERE slot_id = ?" - self._conn.execute("BEGIN IMMEDIATE") - try: - self._conn.execute(sql, params) - self._record_event_inner( - "slot_state_change", - {"slot_id": slot_id, "new_state": new_state.value, **fields}, - ) - self._conn.execute("COMMIT") - except Exception: - self._conn.execute("ROLLBACK") - raise + self._conn.execute("BEGIN IMMEDIATE") + try: + self._conn.execute(sql, params) + self._record_event_inner( + "slot_state_change", + {"slot_id": slot_id, "new_state": new_state.value, **fields}, + ) + self._conn.execute("COMMIT") + except Exception: + self._conn.execute("ROLLBACK") + raise def update_slot_fields(self, slot_id: str, **fields: object) -> None: """Update specific slot columns without changing state or last_state_change. @@ -159,40 +166,41 @@ class StateDB: Uses BEGIN IMMEDIATE. Allowed fields: instance_id, instance_ip, instance_launch_time, lease_count, cooldown_until, interruption_pending. """ - allowed = { - "instance_id", - "instance_ip", - "instance_launch_time", - "lease_count", - "cooldown_until", - "interruption_pending", - } - if not fields: - return + with self._lock: + allowed = { + "instance_id", + "instance_ip", + "instance_launch_time", + "lease_count", + "cooldown_until", + "interruption_pending", + } + if not fields: + return - set_parts: list[str] = [] - params: list[object] = [] - for k, v in fields.items(): - if k not in allowed: - msg = f"Unknown slot field: {k}" - raise ValueError(msg) - set_parts.append(f"{k} = ?") - params.append(v) + set_parts: list[str] = [] + params: list[object] = [] + for k, v in fields.items(): + if k not in allowed: + msg = f"Unknown slot field: {k}" + raise ValueError(msg) + set_parts.append(f"{k} = ?") + params.append(v) - params.append(slot_id) - sql = f"UPDATE slots SET {', '.join(set_parts)} WHERE slot_id = ?" + params.append(slot_id) + sql = f"UPDATE slots SET {', '.join(set_parts)} WHERE slot_id = ?" - self._conn.execute("BEGIN IMMEDIATE") - try: - self._conn.execute(sql, params) - self._record_event_inner( - "slot_fields_updated", - {"slot_id": slot_id, **fields}, - ) - self._conn.execute("COMMIT") - except Exception: - self._conn.execute("ROLLBACK") - raise + self._conn.execute("BEGIN IMMEDIATE") + try: + self._conn.execute(sql, params) + self._record_event_inner( + "slot_fields_updated", + {"slot_id": slot_id, **fields}, + ) + self._conn.execute("COMMIT") + except Exception: + self._conn.execute("ROLLBACK") + raise # -- Reservation operations --------------------------------------------- @@ -204,53 +212,65 @@ class StateDB: ttl_seconds: int, ) -> dict: """Create a new pending reservation. Returns the reservation row as dict.""" - now = _now_iso(self._clock) - if self._clock is not None: - expires = (self._clock.now() + timedelta(seconds=ttl_seconds)).isoformat() - else: - expires = (datetime.now(UTC) + timedelta(seconds=ttl_seconds)).isoformat() - rid = f"resv_{uuid.uuid4().hex}" + with self._lock: + now = _now_iso(self._clock) + if self._clock is not None: + expires = (self._clock.now() + timedelta(seconds=ttl_seconds)).isoformat() + else: + expires = (datetime.now(UTC) + timedelta(seconds=ttl_seconds)).isoformat() + rid = f"resv_{uuid.uuid4().hex}" - self._conn.execute("BEGIN IMMEDIATE") - try: - self._conn.execute( - """INSERT INTO reservations - (reservation_id, system, phase, created_at, updated_at, - expires_at, reason, build_id) - VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", - (rid, system, ReservationPhase.PENDING.value, now, now, expires, reason, build_id), - ) - self._record_event_inner( - "reservation_created", - {"reservation_id": rid, "system": system, "reason": reason}, - ) - self._conn.execute("COMMIT") - except Exception: - self._conn.execute("ROLLBACK") - raise + self._conn.execute("BEGIN IMMEDIATE") + try: + self._conn.execute( + """INSERT INTO reservations + (reservation_id, system, phase, created_at, updated_at, + expires_at, reason, build_id) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + ( + rid, + system, + ReservationPhase.PENDING.value, + now, + now, + expires, + reason, + build_id, + ), + ) + self._record_event_inner( + "reservation_created", + {"reservation_id": rid, "system": system, "reason": reason}, + ) + self._conn.execute("COMMIT") + except Exception: + self._conn.execute("ROLLBACK") + raise - return self.get_reservation(rid) # type: ignore[return-value] + return self.get_reservation(rid) # type: ignore[return-value] def get_reservation(self, reservation_id: str) -> dict | None: """Return a reservation row as dict, or None.""" - cur = self._conn.execute( - "SELECT * FROM reservations WHERE reservation_id = ?", (reservation_id,) - ) - row = cur.fetchone() - if row is None: - return None - return _row_to_dict(cur, row) + with self._lock: + cur = self._conn.execute( + "SELECT * FROM reservations WHERE reservation_id = ?", (reservation_id,) + ) + row = cur.fetchone() + if row is None: + return None + return _row_to_dict(cur, row) def list_reservations(self, phase: ReservationPhase | None = None) -> list[dict]: """List reservations, optionally filtered by phase.""" - if phase is not None: - cur = self._conn.execute( - "SELECT * FROM reservations WHERE phase = ? ORDER BY created_at", - (phase.value,), - ) - else: - cur = self._conn.execute("SELECT * FROM reservations ORDER BY created_at") - return [_row_to_dict(cur, row) for row in cur.fetchall()] + with self._lock: + if phase is not None: + cur = self._conn.execute( + "SELECT * FROM reservations WHERE phase = ? ORDER BY created_at", + (phase.value,), + ) + else: + cur = self._conn.execute("SELECT * FROM reservations ORDER BY created_at") + return [_row_to_dict(cur, row) for row in cur.fetchall()] def assign_reservation(self, reservation_id: str, slot_id: str, instance_id: str) -> None: """Assign a pending reservation to a ready slot. @@ -258,184 +278,191 @@ class StateDB: Atomically: update reservation phase to ready, set slot_id/instance_id, and increment slot lease_count. """ - now = _now_iso(self._clock) + with self._lock: + now = _now_iso(self._clock) - self._conn.execute("BEGIN IMMEDIATE") - try: - self._conn.execute( - """UPDATE reservations - SET phase = ?, slot_id = ?, instance_id = ?, updated_at = ? - WHERE reservation_id = ? AND phase = ?""", - ( - ReservationPhase.READY.value, - slot_id, - instance_id, - now, - reservation_id, - ReservationPhase.PENDING.value, - ), - ) - self._conn.execute( - "UPDATE slots SET lease_count = lease_count + 1 WHERE slot_id = ?", - (slot_id,), - ) - self._record_event_inner( - "reservation_assigned", - { - "reservation_id": reservation_id, - "slot_id": slot_id, - "instance_id": instance_id, - }, - ) - self._conn.execute("COMMIT") - except Exception: - self._conn.execute("ROLLBACK") - raise + self._conn.execute("BEGIN IMMEDIATE") + try: + self._conn.execute( + """UPDATE reservations + SET phase = ?, slot_id = ?, instance_id = ?, updated_at = ? + WHERE reservation_id = ? AND phase = ?""", + ( + ReservationPhase.READY.value, + slot_id, + instance_id, + now, + reservation_id, + ReservationPhase.PENDING.value, + ), + ) + self._conn.execute( + "UPDATE slots SET lease_count = lease_count + 1 WHERE slot_id = ?", + (slot_id,), + ) + self._record_event_inner( + "reservation_assigned", + { + "reservation_id": reservation_id, + "slot_id": slot_id, + "instance_id": instance_id, + }, + ) + self._conn.execute("COMMIT") + except Exception: + self._conn.execute("ROLLBACK") + raise def release_reservation(self, reservation_id: str) -> dict | None: """Release a reservation, decrementing the slot lease count.""" - now = _now_iso(self._clock) + with self._lock: + now = _now_iso(self._clock) - self._conn.execute("BEGIN IMMEDIATE") - try: - cur = self._conn.execute( - "SELECT * FROM reservations WHERE reservation_id = ?", - (reservation_id,), - ) - row = cur.fetchone() - if row is None: - self._conn.execute("ROLLBACK") - return None - - resv = _row_to_dict(cur, row) - old_phase = resv["phase"] - - if old_phase in (ReservationPhase.RELEASED.value, ReservationPhase.EXPIRED.value): - self._conn.execute("ROLLBACK") - return resv - - self._conn.execute( - """UPDATE reservations - SET phase = ?, released_at = ?, updated_at = ? - WHERE reservation_id = ?""", - (ReservationPhase.RELEASED.value, now, now, reservation_id), - ) - - if resv["slot_id"] and old_phase == ReservationPhase.READY.value: - self._conn.execute( - """UPDATE slots SET lease_count = MAX(lease_count - 1, 0) - WHERE slot_id = ?""", - (resv["slot_id"],), + self._conn.execute("BEGIN IMMEDIATE") + try: + cur = self._conn.execute( + "SELECT * FROM reservations WHERE reservation_id = ?", + (reservation_id,), ) + row = cur.fetchone() + if row is None: + self._conn.execute("ROLLBACK") + return None - self._record_event_inner("reservation_released", {"reservation_id": reservation_id}) - self._conn.execute("COMMIT") - except Exception: - self._conn.execute("ROLLBACK") - raise + resv = _row_to_dict(cur, row) + old_phase = resv["phase"] - return self.get_reservation(reservation_id) + if old_phase in (ReservationPhase.RELEASED.value, ReservationPhase.EXPIRED.value): + self._conn.execute("ROLLBACK") + return resv - def expire_reservations(self, now: datetime) -> list[str]: - """Expire all reservations past their expires_at. Returns expired IDs.""" - now_iso = now.isoformat() - expired_ids: list[str] = [] - - self._conn.execute("BEGIN IMMEDIATE") - try: - cur = self._conn.execute( - """SELECT reservation_id, slot_id, phase FROM reservations - WHERE phase IN (?, ?) AND expires_at <= ?""", - (ReservationPhase.PENDING.value, ReservationPhase.READY.value, now_iso), - ) - rows = cur.fetchall() - - for row in rows: - rid, slot_id, phase = row - expired_ids.append(rid) self._conn.execute( """UPDATE reservations - SET phase = ?, updated_at = ? + SET phase = ?, released_at = ?, updated_at = ? WHERE reservation_id = ?""", - (ReservationPhase.EXPIRED.value, now_iso, rid), + (ReservationPhase.RELEASED.value, now, now, reservation_id), ) - if slot_id and phase == ReservationPhase.READY.value: + + if resv["slot_id"] and old_phase == ReservationPhase.READY.value: self._conn.execute( """UPDATE slots SET lease_count = MAX(lease_count - 1, 0) WHERE slot_id = ?""", - (slot_id,), + (resv["slot_id"],), ) - self._record_event_inner("reservation_expired", {"reservation_id": rid}) - self._conn.execute("COMMIT") - except Exception: - self._conn.execute("ROLLBACK") - raise + self._record_event_inner("reservation_released", {"reservation_id": reservation_id}) + self._conn.execute("COMMIT") + except Exception: + self._conn.execute("ROLLBACK") + raise - return expired_ids + return self.get_reservation(reservation_id) + + def expire_reservations(self, now: datetime) -> list[str]: + """Expire all reservations past their expires_at. Returns expired IDs.""" + with self._lock: + now_iso = now.isoformat() + expired_ids: list[str] = [] + + self._conn.execute("BEGIN IMMEDIATE") + try: + cur = self._conn.execute( + """SELECT reservation_id, slot_id, phase FROM reservations + WHERE phase IN (?, ?) AND expires_at <= ?""", + (ReservationPhase.PENDING.value, ReservationPhase.READY.value, now_iso), + ) + rows = cur.fetchall() + + for row in rows: + rid, slot_id, phase = row + expired_ids.append(rid) + self._conn.execute( + """UPDATE reservations + SET phase = ?, updated_at = ? + WHERE reservation_id = ?""", + (ReservationPhase.EXPIRED.value, now_iso, rid), + ) + if slot_id and phase == ReservationPhase.READY.value: + self._conn.execute( + """UPDATE slots SET lease_count = MAX(lease_count - 1, 0) + WHERE slot_id = ?""", + (slot_id,), + ) + self._record_event_inner("reservation_expired", {"reservation_id": rid}) + + self._conn.execute("COMMIT") + except Exception: + self._conn.execute("ROLLBACK") + raise + + return expired_ids # -- Events ------------------------------------------------------------- def record_event(self, kind: str, payload: dict) -> None: # type: ignore[type-arg] """Record an audit event.""" - self._conn.execute("BEGIN IMMEDIATE") - try: - self._record_event_inner(kind, payload) - self._conn.execute("COMMIT") - except Exception: - self._conn.execute("ROLLBACK") - raise + with self._lock: + self._conn.execute("BEGIN IMMEDIATE") + try: + self._record_event_inner(kind, payload) + self._conn.execute("COMMIT") + except Exception: + self._conn.execute("ROLLBACK") + raise def _record_event_inner(self, kind: str, payload: dict) -> None: # type: ignore[type-arg] """Insert an event row (must be called inside an active transaction).""" - now = _now_iso(self._clock) - self._conn.execute( - "INSERT INTO events (ts, kind, payload_json) VALUES (?, ?, ?)", - (now, kind, json.dumps(payload, default=str)), - ) + with self._lock: + now = _now_iso(self._clock) + self._conn.execute( + "INSERT INTO events (ts, kind, payload_json) VALUES (?, ?, ?)", + (now, kind, json.dumps(payload, default=str)), + ) # -- Summaries ---------------------------------------------------------- def get_state_summary(self) -> dict: """Return aggregate slot and reservation counts.""" - slot_counts: dict[str, int] = {} - cur = self._conn.execute("SELECT state, COUNT(*) FROM slots GROUP BY state") - for state_val, count in cur.fetchall(): - slot_counts[state_val] = count + with self._lock: + slot_counts: dict[str, int] = {} + cur = self._conn.execute("SELECT state, COUNT(*) FROM slots GROUP BY state") + for state_val, count in cur.fetchall(): + slot_counts[state_val] = count - total_slots = sum(slot_counts.values()) + total_slots = sum(slot_counts.values()) - resv_counts: dict[str, int] = {} - cur = self._conn.execute( - "SELECT phase, COUNT(*) FROM reservations WHERE phase IN (?, ?, ?) GROUP BY phase", - ( - ReservationPhase.PENDING.value, - ReservationPhase.READY.value, - ReservationPhase.FAILED.value, - ), - ) - for phase_val, count in cur.fetchall(): - resv_counts[phase_val] = count + resv_counts: dict[str, int] = {} + cur = self._conn.execute( + "SELECT phase, COUNT(*) FROM reservations WHERE phase IN (?, ?, ?) GROUP BY phase", + ( + ReservationPhase.PENDING.value, + ReservationPhase.READY.value, + ReservationPhase.FAILED.value, + ), + ) + for phase_val, count in cur.fetchall(): + resv_counts[phase_val] = count - return { - "slots": { - "total": total_slots, - "ready": slot_counts.get("ready", 0), - "launching": slot_counts.get("launching", 0), - "booting": slot_counts.get("booting", 0), - "binding": slot_counts.get("binding", 0), - "draining": slot_counts.get("draining", 0), - "terminating": slot_counts.get("terminating", 0), - "empty": slot_counts.get("empty", 0), - "error": slot_counts.get("error", 0), - }, - "reservations": { - "pending": resv_counts.get("pending", 0), - "ready": resv_counts.get("ready", 0), - "failed": resv_counts.get("failed", 0), - }, - } + return { + "slots": { + "total": total_slots, + "ready": slot_counts.get("ready", 0), + "launching": slot_counts.get("launching", 0), + "booting": slot_counts.get("booting", 0), + "binding": slot_counts.get("binding", 0), + "draining": slot_counts.get("draining", 0), + "terminating": slot_counts.get("terminating", 0), + "empty": slot_counts.get("empty", 0), + "error": slot_counts.get("error", 0), + }, + "reservations": { + "pending": resv_counts.get("pending", 0), + "ready": resv_counts.get("ready", 0), + "failed": resv_counts.get("failed", 0), + }, + } def close(self) -> None: """Close the database connection.""" - self._conn.close() + with self._lock: + self._conn.close() diff --git a/agent/nix_builder_autoscaler/tests/integration/test_end_to_end_fake.py b/agent/nix_builder_autoscaler/tests/integration/test_end_to_end_fake.py index 8ae9e78..1bc3beb 100644 --- a/agent/nix_builder_autoscaler/tests/integration/test_end_to_end_fake.py +++ b/agent/nix_builder_autoscaler/tests/integration/test_end_to_end_fake.py @@ -1 +1,407 @@ -"""End-to-end integration tests with FakeRuntime — Plan 05.""" +"""End-to-end integration tests with FakeRuntime and a fake HAProxy socket.""" + +from __future__ import annotations + +import socket +import threading +import time +from pathlib import Path + +from fastapi.testclient import TestClient + +from nix_builder_autoscaler.api import create_app +from nix_builder_autoscaler.config import ( + AppConfig, + AwsConfig, + CapacityConfig, + HaproxyConfig, + SchedulerConfig, +) +from nix_builder_autoscaler.metrics import MetricsRegistry +from nix_builder_autoscaler.models import SlotState +from nix_builder_autoscaler.providers.clock import FakeClock +from nix_builder_autoscaler.providers.haproxy import HAProxyRuntime +from nix_builder_autoscaler.reconciler import Reconciler +from nix_builder_autoscaler.runtime.fake import FakeRuntime +from nix_builder_autoscaler.scheduler import scheduling_tick +from nix_builder_autoscaler.state_db import StateDB + + +class FakeHAProxySocketServer: + """Tiny fake HAProxy runtime socket server for integration tests.""" + + def __init__(self, socket_path: Path, backend: str, slot_ids: list[str]) -> None: + self._socket_path = socket_path + self._backend = backend + self._slot_ids = slot_ids + self._stop_event = threading.Event() + self._thread: threading.Thread | None = None + self._lock = threading.Lock() + self._state: dict[str, dict[str, object]] = { + slot_id: { + "enabled": False, + "addr": "0.0.0.0", + "port": 22, + "status": "MAINT", + "scur": 0, + "qcur": 0, + } + for slot_id in slot_ids + } + + def start(self) -> None: + self._thread = threading.Thread(target=self._serve, name="fake-haproxy", daemon=True) + self._thread.start() + deadline = time.time() + 2.0 + while time.time() < deadline: + if self._socket_path.exists(): + return + time.sleep(0.01) + msg = f"fake haproxy socket not created: {self._socket_path}" + raise RuntimeError(msg) + + def stop(self) -> None: + self._stop_event.set() + try: + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock: + sock.connect(str(self._socket_path)) + sock.sendall(b"\n") + except OSError: + pass + if self._thread is not None: + self._thread.join(timeout=2.0) + if self._socket_path.exists(): + self._socket_path.unlink() + + def _serve(self) -> None: + if self._socket_path.exists(): + self._socket_path.unlink() + + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as server: + server.bind(str(self._socket_path)) + server.listen(16) + server.settimeout(0.2) + while not self._stop_event.is_set(): + try: + conn, _ = server.accept() + except TimeoutError: + continue + except OSError: + if self._stop_event.is_set(): + break + continue + with conn: + payload = b"" + while True: + chunk = conn.recv(4096) + if not chunk: + break + payload += chunk + command = payload.decode().strip() + response = self._handle_command(command) + try: + conn.sendall(response.encode()) + except BrokenPipeError: + continue + + def _handle_command(self, command: str) -> str: + if command == "show stat": + return self._render_show_stat() + + parts = command.split() + if not parts: + return "\n" + + if parts[0:2] == ["set", "server"] and len(parts) >= 7: + slot_id = self._parse_slot(parts[2]) + if slot_id is None: + return "No such server.\n" + with self._lock: + slot_state = self._state[slot_id] + slot_state["addr"] = parts[4] + slot_state["port"] = int(parts[6]) + slot_state["status"] = "UP" if slot_state["enabled"] else "DOWN" + return "\n" + + if parts[0:2] == ["enable", "server"] and len(parts) >= 3: + slot_id = self._parse_slot(parts[2]) + if slot_id is None: + return "No such server.\n" + with self._lock: + slot_state = self._state[slot_id] + slot_state["enabled"] = True + slot_state["status"] = "UP" + return "\n" + + if parts[0:2] == ["disable", "server"] and len(parts) >= 3: + slot_id = self._parse_slot(parts[2]) + if slot_id is None: + return "No such server.\n" + with self._lock: + slot_state = self._state[slot_id] + slot_state["enabled"] = False + slot_state["status"] = "MAINT" + return "\n" + + return "Unknown command.\n" + + def _parse_slot(self, backend_slot: str) -> str | None: + backend, _, slot_id = backend_slot.partition("/") + if backend != self._backend or slot_id not in self._state: + return None + return slot_id + + def _render_show_stat(self) -> str: + header = "# pxname,svname,qcur,qmax,scur,smax,slim,stot,status\n" + rows = [f"{self._backend},BACKEND,0,0,0,0,0,0,UP\n"] + with self._lock: + for slot_id in self._slot_ids: + slot_state = self._state[slot_id] + rows.append( + f"{self._backend},{slot_id},{slot_state['qcur']},0," + f"{slot_state['scur']},0,50,0,{slot_state['status']}\n" + ) + return header + "".join(rows) + + +class DaemonHarness: + """In-process threaded harness for scheduler/reconciler/API integration.""" + + def __init__( + self, + root: Path, + *, + db_path: Path | None = None, + runtime: FakeRuntime | None = None, + max_slots: int = 3, + min_slots: int = 0, + idle_scale_down_seconds: int = 1, + drain_timeout_seconds: int = 120, + ) -> None: + root.mkdir(parents=True, exist_ok=True) + self.clock = FakeClock() + self.metrics = MetricsRegistry() + self.runtime = runtime or FakeRuntime(launch_latency_ticks=2, ip_delay_ticks=1) + self._stop_event = threading.Event() + self._threads: list[threading.Thread] = [] + self._reconcile_lock = threading.Lock() + + self._db_path = db_path or (root / "state.db") + self._socket_path = root / "haproxy.sock" + self._slot_ids = [f"slot{i:03d}" for i in range(1, 4)] + + self.config = AppConfig( + aws=AwsConfig(region="us-east-1"), + haproxy=HaproxyConfig( + runtime_socket=str(self._socket_path), + backend="all", + slot_prefix="slot", + slot_count=3, + check_ready_up_count=1, + ), + capacity=CapacityConfig( + default_system="x86_64-linux", + max_slots=max_slots, + min_slots=min_slots, + max_leases_per_slot=1, + target_warm_slots=0, + reservation_ttl_seconds=1200, + idle_scale_down_seconds=idle_scale_down_seconds, + drain_timeout_seconds=drain_timeout_seconds, + ), + scheduler=SchedulerConfig(tick_seconds=0.05, reconcile_seconds=0.05), + ) + + self.db = StateDB(str(self._db_path), clock=self.clock) + self.db.init_schema() + self.db.init_slots("slot", 3, "x86_64-linux", "all") + + self.haproxy_server = FakeHAProxySocketServer(self._socket_path, "all", self._slot_ids) + self.haproxy = HAProxyRuntime(str(self._socket_path), "all", "slot") + self.reconciler = Reconciler( + self.db, + self.runtime, + self.haproxy, + self.config, + self.clock, + self.metrics, + ) + + app = create_app( + self.db, + self.config, + self.clock, + self.metrics, + reconcile_now=self.reconcile_now, + ) + self.client = TestClient(app) + + def start(self) -> None: + self.haproxy_server.start() + with self._reconcile_lock: + self.runtime.tick() + self.reconciler.tick() + self._threads = [ + threading.Thread(target=self._scheduler_loop, name="sched", daemon=True), + threading.Thread(target=self._reconciler_loop, name="recon", daemon=True), + ] + for thread in self._threads: + thread.start() + + def stop(self) -> None: + self._stop_event.set() + for thread in self._threads: + thread.join(timeout=2.0) + self.client.close() + self.haproxy_server.stop() + self.db.close() + + def create_reservation(self, reason: str) -> str: + response = self.client.post( + "/v1/reservations", + json={"system": "x86_64-linux", "reason": reason}, + ) + assert response.status_code == 200 + return str(response.json()["reservation_id"]) + + def release_reservation(self, reservation_id: str) -> None: + response = self.client.post(f"/v1/reservations/{reservation_id}/release") + assert response.status_code == 200 + + def reservation(self, reservation_id: str) -> dict: + response = self.client.get(f"/v1/reservations/{reservation_id}") + assert response.status_code == 200 + return response.json() + + def wait_for(self, predicate, timeout: float = 6.0) -> None: # noqa: ANN001 + deadline = time.time() + timeout + while time.time() < deadline: + if predicate(): + return + time.sleep(0.02) + raise AssertionError("condition not met before timeout") + + def reconcile_now(self) -> dict[str, bool]: + with self._reconcile_lock: + self.runtime.tick() + self.reconciler.tick() + return {"triggered": True} + + def _scheduler_loop(self) -> None: + while not self._stop_event.is_set(): + scheduling_tick(self.db, self.runtime, self.config, self.clock, self.metrics) + self._stop_event.wait(self.config.scheduler.tick_seconds) + + def _reconciler_loop(self) -> None: + while not self._stop_event.is_set(): + with self._reconcile_lock: + self.runtime.tick() + self.reconciler.tick() + self._stop_event.wait(self.config.scheduler.reconcile_seconds) + + +def test_cold_start_reservation_launch_bind_ready(tmp_path: Path) -> None: + harness = DaemonHarness(tmp_path) + harness.start() + try: + reservation_id = harness.create_reservation("cold-start") + harness.wait_for(lambda: harness.reservation(reservation_id)["phase"] == "ready") + reservation = harness.reservation(reservation_id) + assert reservation["slot"] is not None + slot = harness.db.get_slot(reservation["slot"]) + assert slot is not None + assert slot["state"] == SlotState.READY.value + assert slot["instance_ip"] is not None + finally: + harness.stop() + + +def test_burst_three_concurrent_reservations(tmp_path: Path) -> None: + harness = DaemonHarness(tmp_path, max_slots=3) + harness.start() + try: + reservation_ids = [harness.create_reservation(f"burst-{i}") for i in range(3)] + harness.wait_for( + lambda: all(harness.reservation(rid)["phase"] == "ready" for rid in reservation_ids), + timeout=8.0, + ) + slots = [harness.reservation(rid)["slot"] for rid in reservation_ids] + assert len(set(slots)) == 3 + finally: + harness.stop() + + +def test_scale_down_after_release_and_idle_timeout(tmp_path: Path) -> None: + harness = DaemonHarness(tmp_path, idle_scale_down_seconds=1, drain_timeout_seconds=0) + harness.start() + try: + reservation_id = harness.create_reservation("scale-down") + harness.wait_for(lambda: harness.reservation(reservation_id)["phase"] == "ready") + slot_id = str(harness.reservation(reservation_id)["slot"]) + + harness.release_reservation(reservation_id) + harness.clock.advance(2) + harness.wait_for( + lambda: ( + harness.db.get_slot(slot_id) is not None + and harness.db.get_slot(slot_id)["state"] == SlotState.EMPTY.value + ) + ) + finally: + harness.stop() + + +def test_restart_recovery_midflight(tmp_path: Path) -> None: + db_path = tmp_path / "state.db" + runtime = FakeRuntime(launch_latency_ticks=6, ip_delay_ticks=2) + + first = DaemonHarness(tmp_path / "run1", db_path=db_path, runtime=runtime) + first.start() + reservation_id = first.create_reservation("restart-midflight") + first.wait_for( + lambda: len(first.db.list_slots(SlotState.LAUNCHING)) > 0, + timeout=4.0, + ) + first.stop() + + second = DaemonHarness(tmp_path / "run2", db_path=db_path, runtime=runtime) + second.start() + try: + second.wait_for(lambda: second.reservation(reservation_id)["phase"] == "ready", timeout=8.0) + finally: + second.stop() + + +def test_interruption_recovery_pending_reservation_resolves(tmp_path: Path) -> None: + harness = DaemonHarness(tmp_path, max_slots=2, idle_scale_down_seconds=60) + harness.start() + try: + first_reservation = harness.create_reservation("baseline") + harness.wait_for(lambda: harness.reservation(first_reservation)["phase"] == "ready") + slot_id = str(harness.reservation(first_reservation)["slot"]) + instance_id = str(harness.reservation(first_reservation)["instance_id"]) + + second_reservation = harness.create_reservation("post-interruption") + harness.release_reservation(first_reservation) + + harness.runtime.inject_interruption(instance_id) + harness.runtime._instances[instance_id].state = "shutting-down" + + harness.wait_for( + lambda: ( + harness.db.get_slot(slot_id) is not None + and harness.db.get_slot(slot_id)["state"] + in { + SlotState.DRAINING.value, + SlotState.TERMINATING.value, + SlotState.EMPTY.value, + } + ), + timeout=6.0, + ) + harness.wait_for( + lambda: harness.reservation(second_reservation)["phase"] == "ready", + timeout=10.0, + ) + finally: + harness.stop() diff --git a/agent/nix_builder_autoscaler/tests/test_reservations_api.py b/agent/nix_builder_autoscaler/tests/test_reservations_api.py index 8a11db1..2d95282 100644 --- a/agent/nix_builder_autoscaler/tests/test_reservations_api.py +++ b/agent/nix_builder_autoscaler/tests/test_reservations_api.py @@ -3,24 +3,29 @@ from __future__ import annotations from datetime import UTC, datetime +from typing import Any from fastapi.testclient import TestClient from nix_builder_autoscaler.api import create_app from nix_builder_autoscaler.config import AppConfig, CapacityConfig from nix_builder_autoscaler.metrics import MetricsRegistry +from nix_builder_autoscaler.models import SlotState from nix_builder_autoscaler.providers.clock import FakeClock from nix_builder_autoscaler.state_db import StateDB -def _make_client() -> tuple[TestClient, StateDB, FakeClock, MetricsRegistry]: +def _make_client( + *, + reconcile_now: Any = None, # noqa: ANN401 +) -> tuple[TestClient, StateDB, FakeClock, MetricsRegistry]: clock = FakeClock() db = StateDB(":memory:", clock=clock) db.init_schema() db.init_slots("slot", 3, "x86_64-linux", "all") config = AppConfig(capacity=CapacityConfig(reservation_ttl_seconds=1200)) metrics = MetricsRegistry() - app = create_app(db, config, clock, metrics) + app = create_app(db, config, clock, metrics, reconcile_now=reconcile_now) return TestClient(app), db, clock, metrics @@ -120,6 +125,20 @@ def test_health_ready_returns_ok_when_no_checks() -> None: assert response.json()["status"] == "ok" +def test_health_ready_degraded_when_ready_check_fails() -> None: + clock = FakeClock() + db = StateDB(":memory:", clock=clock) + db.init_schema() + db.init_slots("slot", 3, "x86_64-linux", "all") + config = AppConfig(capacity=CapacityConfig(reservation_ttl_seconds=1200)) + metrics = MetricsRegistry() + app = create_app(db, config, clock, metrics, ready_check=lambda: False) + client = TestClient(app) + response = client.get("/health/ready") + assert response.status_code == 503 + assert response.json()["status"] == "degraded" + + def test_metrics_returns_prometheus_text() -> None: client, _, _, metrics = _make_client() metrics.counter("autoscaler_test_counter", {}, 1.0) @@ -150,3 +169,67 @@ def test_release_nonexistent_returns_404() -> None: response = client.post("/v1/reservations/resv_nonexistent/release") assert response.status_code == 404 assert response.json()["error"]["code"] == "not_found" + + +def test_admin_drain_success() -> None: + client, db, _, _ = _make_client() + db.update_slot_state("slot001", SlotState.LAUNCHING, instance_id="i-test") + db.update_slot_state("slot001", SlotState.BOOTING) + db.update_slot_state("slot001", SlotState.BINDING, instance_ip="100.64.0.1") + db.update_slot_state("slot001", SlotState.READY) + + response = client.post("/v1/admin/drain", json={"slot_id": "slot001"}) + assert response.status_code == 200 + assert response.json()["state"] == "draining" + slot = db.get_slot("slot001") + assert slot is not None + assert slot["state"] == SlotState.DRAINING.value + + +def test_admin_drain_invalid_state_returns_409() -> None: + client, _, _, _ = _make_client() + response = client.post("/v1/admin/drain", json={"slot_id": "slot001"}) + assert response.status_code == 409 + assert response.json()["error"]["code"] == "invalid_state" + + +def test_admin_unquarantine_success() -> None: + client, db, _, _ = _make_client() + db.update_slot_state("slot001", SlotState.ERROR, instance_id="i-bad") + + response = client.post("/v1/admin/unquarantine", json={"slot_id": "slot001"}) + assert response.status_code == 200 + assert response.json()["state"] == "empty" + slot = db.get_slot("slot001") + assert slot is not None + assert slot["state"] == SlotState.EMPTY.value + assert slot["instance_id"] is None + + +def test_admin_unquarantine_invalid_state_returns_409() -> None: + client, _, _, _ = _make_client() + response = client.post("/v1/admin/unquarantine", json={"slot_id": "slot001"}) + assert response.status_code == 409 + assert response.json()["error"]["code"] == "invalid_state" + + +def test_admin_reconcile_now_not_configured_returns_503() -> None: + client, _, _, _ = _make_client() + response = client.post("/v1/admin/reconcile-now") + assert response.status_code == 503 + assert response.json()["error"]["code"] == "not_configured" + + +def test_admin_reconcile_now_success() -> None: + called = {"value": False} + + def _reconcile_now() -> dict[str, object]: + called["value"] = True + return {"triggered": True} + + client, _, _, _ = _make_client(reconcile_now=_reconcile_now) + response = client.post("/v1/admin/reconcile-now") + assert response.status_code == 200 + assert response.json()["status"] == "accepted" + assert response.json()["triggered"] is True + assert called["value"] is True diff --git a/agent/nix_builder_autoscaler/tests/test_runtime_ec2.py b/agent/nix_builder_autoscaler/tests/test_runtime_ec2.py index 40c35b3..a8d9ffe 100644 --- a/agent/nix_builder_autoscaler/tests/test_runtime_ec2.py +++ b/agent/nix_builder_autoscaler/tests/test_runtime_ec2.py @@ -130,6 +130,135 @@ class TestDescribeInstance: assert info["tailscale_ip"] is None assert info["launch_time"] == launch_time.isoformat() + @patch.object( + EC2Runtime, + "_read_tailscale_status", + return_value={ + "Peer": { + "peer1": { + "HostName": "nix-builder-slot001-i-running1", + "Online": True, + "TailscaleIPs": ["100.64.0.10"], + } + } + }, + ) + def test_discovers_tailscale_ip_from_localapi(self, _mock_status): + ec2_client = boto3.client("ec2", region_name="us-east-1") + stubber = Stubber(ec2_client) + + launch_time = datetime(2026, 1, 15, 12, 30, 0, tzinfo=UTC) + response = { + "Reservations": [ + { + "Instances": [ + { + "InstanceId": "i-running1", + "State": {"Code": 16, "Name": "running"}, + "LaunchTime": launch_time, + "Tags": [{"Key": "AutoscalerSlot", "Value": "slot001"}], + } + ], + } + ], + } + stubber.add_response( + "describe_instances", + response, + {"InstanceIds": ["i-running1"]}, + ) + runtime = _make_runtime(stubber, ec2_client) + + info = runtime.describe_instance("i-running1") + assert info["tailscale_ip"] == "100.64.0.10" + + @patch.object(EC2Runtime, "_read_tailscale_status", return_value={"Peer": {}}) + def test_discovery_unavailable_returns_none(self, _mock_status): + ec2_client = boto3.client("ec2", region_name="us-east-1") + stubber = Stubber(ec2_client) + + launch_time = datetime(2026, 1, 15, 12, 30, 0, tzinfo=UTC) + response = { + "Reservations": [ + { + "Instances": [ + { + "InstanceId": "i-running1", + "State": {"Code": 16, "Name": "running"}, + "LaunchTime": launch_time, + "Tags": [{"Key": "AutoscalerSlot", "Value": "slot001"}], + } + ], + } + ], + } + stubber.add_response( + "describe_instances", + response, + {"InstanceIds": ["i-running1"]}, + ) + runtime = _make_runtime(stubber, ec2_client) + + info = runtime.describe_instance("i-running1") + assert info["tailscale_ip"] is None + + @patch.object( + EC2Runtime, + "_read_tailscale_status", + return_value={ + "Peer": { + "peer1": { + "HostName": "nix-builder-slot001-old", + "Online": True, + "TailscaleIPs": ["100.64.0.10"], + }, + "peer2": { + "HostName": "nix-builder-slot001-new", + "Online": True, + "TailscaleIPs": ["100.64.0.11"], + }, + } + }, + ) + def test_ambiguous_slot_match_returns_none(self, _mock_status): + ec2_client = boto3.client("ec2", region_name="us-east-1") + stubber = Stubber(ec2_client) + + launch_time = datetime(2026, 1, 15, 12, 30, 0, tzinfo=UTC) + response = { + "Reservations": [ + { + "Instances": [ + { + "InstanceId": "i-running1", + "State": {"Code": 16, "Name": "running"}, + "LaunchTime": launch_time, + "Tags": [{"Key": "AutoscalerSlot", "Value": "slot001"}], + } + ], + } + ], + } + stubber.add_response( + "describe_instances", + response, + {"InstanceIds": ["i-running1"]}, + ) + runtime = _make_runtime(stubber, ec2_client) + + info = runtime.describe_instance("i-running1") + assert info["tailscale_ip"] is None + + def test_localapi_permission_error_returns_none(self): + ec2_client = boto3.client("ec2", region_name="us-east-1") + runtime = EC2Runtime(_make_config(), _client=ec2_client) + + with patch( + "nix_builder_autoscaler.runtime.ec2._UnixSocketHTTPConnection.connect", + side_effect=PermissionError, + ): + assert runtime._read_tailscale_status() is None + def test_missing_instance_returns_terminated(self): ec2_client = boto3.client("ec2", region_name="us-east-1") stubber = Stubber(ec2_client) diff --git a/flake.nix b/flake.nix index 7721148..c066364 100644 --- a/flake.nix +++ b/flake.nix @@ -178,8 +178,7 @@ checkPhase = '' runHook preCheck export HOME=$(mktemp -d) - # Exit code 5 means no tests collected — tolerate until integration tests are written - pytest nix_builder_autoscaler/tests/integration/ -v || test $? -eq 5 + pytest nix_builder_autoscaler/tests/integration/ -v runHook postCheck ''; doCheck = true;