agent: complete plan05 closeout

This commit is contained in:
Abel Luck 2026-02-27 13:48:52 +01:00
parent 33ba248c49
commit 2f0fffa905
12 changed files with 1347 additions and 313 deletions

View file

@ -6,6 +6,7 @@ import argparse
import logging import logging
import signal import signal
import threading import threading
import time
from pathlib import Path from pathlib import Path
from types import FrameType from types import FrameType
@ -25,6 +26,29 @@ from .state_db import StateDB
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class LoopHealth:
    """Thread-safe tracker of each daemon loop's most recent successful tick."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._last_success: dict[str, float] = {}

    def mark_success(self, loop_name: str) -> None:
        """Record that *loop_name* just completed a successful tick."""
        with self._lock:
            self._last_success[loop_name] = time.monotonic()

    def is_fresh(self, loop_name: str, max_age_seconds: float) -> bool:
        """Return True if *loop_name* succeeded within the last *max_age_seconds*."""
        with self._lock:
            last = self._last_success.get(loop_name)
        # A loop that has never succeeded is never fresh.
        if last is None:
            return False
        return (time.monotonic() - last) <= max_age_seconds
def _max_staleness(interval_seconds: float) -> float:
return max(interval_seconds * 3.0, 15.0)
def _scheduler_loop( def _scheduler_loop(
db: StateDB, db: StateDB,
runtime: EC2Runtime, runtime: EC2Runtime,
@ -32,10 +56,12 @@ def _scheduler_loop(
clock: SystemClock, clock: SystemClock,
metrics: MetricsRegistry, metrics: MetricsRegistry,
stop_event: threading.Event, stop_event: threading.Event,
loop_health: LoopHealth,
) -> None: ) -> None:
while not stop_event.is_set(): while not stop_event.is_set():
try: try:
scheduling_tick(db, runtime, config, clock, metrics) scheduling_tick(db, runtime, config, clock, metrics)
loop_health.mark_success("scheduler")
except Exception: except Exception:
log.exception("scheduler_tick_failed") log.exception("scheduler_tick_failed")
stop_event.wait(config.scheduler.tick_seconds) stop_event.wait(config.scheduler.tick_seconds)
@ -45,15 +71,36 @@ def _reconciler_loop(
reconciler: Reconciler, reconciler: Reconciler,
config: AppConfig, config: AppConfig,
stop_event: threading.Event, stop_event: threading.Event,
loop_health: LoopHealth,
reconcile_lock: threading.Lock,
) -> None: ) -> None:
while not stop_event.is_set(): while not stop_event.is_set():
try: try:
with reconcile_lock:
reconciler.tick() reconciler.tick()
loop_health.mark_success("reconciler")
except Exception: except Exception:
log.exception("reconciler_tick_failed") log.exception("reconciler_tick_failed")
stop_event.wait(config.scheduler.reconcile_seconds) stop_event.wait(config.scheduler.reconcile_seconds)
def _metrics_health_loop(
    metrics: MetricsRegistry,
    stop_event: threading.Event,
    loop_health: LoopHealth,
    interval_seconds: float,
) -> None:
    """Periodically publish per-loop liveness gauges and record own health."""
    loops = ("scheduler", "reconciler", "metrics")
    while not stop_event.is_set():
        try:
            # Publish a static "up" gauge for each daemon loop.
            for name in loops:
                metrics.gauge("autoscaler_loop_up", {"loop": name}, 1.0)
            loop_health.mark_success("metrics")
        except Exception:
            log.exception("metrics_health_tick_failed")
        stop_event.wait(interval_seconds)
def _parse_args() -> argparse.Namespace: def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
prog="nix-builder-autoscaler", prog="nix-builder-autoscaler",
@ -92,7 +139,11 @@ def main() -> None:
stop_event = threading.Event() stop_event = threading.Event()
scheduler_thread: threading.Thread | None = None scheduler_thread: threading.Thread | None = None
reconciler_thread: threading.Thread | None = None reconciler_thread: threading.Thread | None = None
metrics_thread: threading.Thread | None = None
server: uvicorn.Server | None = None server: uvicorn.Server | None = None
loop_health = LoopHealth()
reconcile_lock = threading.Lock()
metrics_interval = 5.0
def scheduler_running() -> bool: def scheduler_running() -> bool:
return scheduler_thread is not None and scheduler_thread.is_alive() return scheduler_thread is not None and scheduler_thread.is_alive()
@ -100,6 +151,32 @@ def main() -> None:
def reconciler_running() -> bool: def reconciler_running() -> bool:
return reconciler_thread is not None and reconciler_thread.is_alive() return reconciler_thread is not None and reconciler_thread.is_alive()
def metrics_running() -> bool:
    # True while the metrics-health thread exists and is alive.
    return metrics_thread is not None and metrics_thread.is_alive()

def ready_check() -> bool:
    """Readiness probe: every daemon loop must be alive AND recently successful.

    The freshness window for each loop is derived from its tick interval
    via _max_staleness (3x interval, floored at 15 seconds).
    """
    checks = [
        ("scheduler", scheduler_running(), _max_staleness(config.scheduler.tick_seconds)),
        (
            "reconciler",
            reconciler_running(),
            _max_staleness(config.scheduler.reconcile_seconds),
        ),
        ("metrics", metrics_running(), _max_staleness(metrics_interval)),
    ]
    for loop_name, alive, max_age in checks:
        # A dead thread fails readiness immediately.
        if not alive:
            return False
        # An alive-but-stuck loop (no recent success) also fails readiness.
        if not loop_health.is_fresh(loop_name, max_age):
            return False
    return True
def reconcile_now() -> dict[str, object]:
    """Run one reconciler tick synchronously (admin-triggered).

    Serializes with the background reconciler via reconcile_lock so the
    two never tick concurrently; a successful tick also refreshes the
    reconciler's readiness timestamp.
    """
    with reconcile_lock:
        reconciler.tick()
        loop_health.mark_success("reconciler")
    return {"triggered": True}
app = create_app( app = create_app(
db, db,
config, config,
@ -109,23 +186,36 @@ def main() -> None:
haproxy=haproxy, haproxy=haproxy,
scheduler_running=scheduler_running, scheduler_running=scheduler_running,
reconciler_running=reconciler_running, reconciler_running=reconciler_running,
ready_check=ready_check,
reconcile_now=reconcile_now,
) )
loop_health.mark_success("scheduler")
loop_health.mark_success("reconciler")
loop_health.mark_success("metrics")
scheduler_thread = threading.Thread( scheduler_thread = threading.Thread(
target=_scheduler_loop, target=_scheduler_loop,
name="autoscaler-scheduler", name="autoscaler-scheduler",
args=(db, runtime, config, clock, metrics, stop_event), args=(db, runtime, config, clock, metrics, stop_event, loop_health),
daemon=True, daemon=True,
) )
reconciler_thread = threading.Thread( reconciler_thread = threading.Thread(
target=_reconciler_loop, target=_reconciler_loop,
name="autoscaler-reconciler", name="autoscaler-reconciler",
args=(reconciler, config, stop_event), args=(reconciler, config, stop_event, loop_health, reconcile_lock),
daemon=True,
)
metrics_thread = threading.Thread(
target=_metrics_health_loop,
name="autoscaler-metrics-health",
args=(metrics, stop_event, loop_health, metrics_interval),
daemon=True, daemon=True,
) )
scheduler_thread.start() scheduler_thread.start()
reconciler_thread.start() reconciler_thread.start()
metrics_thread.start()
socket_path = Path(config.server.socket_path) socket_path = Path(config.server.socket_path)
socket_path.parent.mkdir(parents=True, exist_ok=True) socket_path.parent.mkdir(parents=True, exist_ok=True)
@ -156,6 +246,8 @@ def main() -> None:
scheduler_thread.join(timeout=10) scheduler_thread.join(timeout=10)
if reconciler_thread is not None: if reconciler_thread is not None:
reconciler_thread.join(timeout=10) reconciler_thread.join(timeout=10)
if metrics_thread is not None:
metrics_thread.join(timeout=10)
db.close() db.close()

View file

@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, NoReturn
from fastapi import FastAPI, HTTPException, Request, Response from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from pydantic import BaseModel
from .models import ( from .models import (
CapacityHint, CapacityHint,
@ -35,6 +36,12 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class SlotAdminRequest(BaseModel):
    """Admin action request body that targets a single slot."""

    # Identifier of the target slot.
    slot_id: str
def _parse_required_dt(value: str) -> datetime: def _parse_required_dt(value: str) -> datetime:
return datetime.fromisoformat(value) return datetime.fromisoformat(value)
@ -95,6 +102,8 @@ def create_app(
haproxy: HAProxyRuntime | None = None, haproxy: HAProxyRuntime | None = None,
scheduler_running: Callable[[], bool] | None = None, scheduler_running: Callable[[], bool] | None = None,
reconciler_running: Callable[[], bool] | None = None, reconciler_running: Callable[[], bool] | None = None,
ready_check: Callable[[], bool] | None = None,
reconcile_now: Callable[[], dict[str, object] | None] | None = None,
) -> FastAPI: ) -> FastAPI:
"""Create the FastAPI application.""" """Create the FastAPI application."""
app = FastAPI(title="nix-builder-autoscaler", version="0.1.0") app = FastAPI(title="nix-builder-autoscaler", version="0.1.0")
@ -191,6 +200,11 @@ def create_app(
@app.get("/health/ready", response_model=HealthResponse) @app.get("/health/ready", response_model=HealthResponse)
def health_ready() -> HealthResponse: def health_ready() -> HealthResponse:
if ready_check is not None and not ready_check():
return JSONResponse( # type: ignore[return-value]
status_code=503,
content=HealthResponse(status="degraded").model_dump(mode="json"),
)
if scheduler_running is not None and not scheduler_running(): if scheduler_running is not None and not scheduler_running():
return JSONResponse( # type: ignore[return-value] return JSONResponse( # type: ignore[return-value]
status_code=503, status_code=503,
@ -207,4 +221,83 @@ def create_app(
def metrics_endpoint() -> Response: def metrics_endpoint() -> Response:
return Response(content=metrics.render(), media_type="text/plain") return Response(content=metrics.render(), media_type="text/plain")
@app.post("/v1/admin/drain")
def admin_drain(body: SlotAdminRequest, request: Request) -> dict[str, str]:
    """Transition a slot to DRAINING via admin request.

    Idempotent: a slot already DRAINING or TERMINATING is reported as
    accepted without another state transition.
    """
    slot = db.get_slot(body.slot_id)
    if slot is None:
        # NOTE(review): assumes _error_response raises (module imports
        # NoReturn and code continues without `return`) — confirm.
        _error_response(request, 404, "not_found", "Slot not found")
    state = str(slot["state"])
    if state == SlotState.DRAINING.value or state == SlotState.TERMINATING.value:
        # Already winding down — report success (idempotent).
        return {"status": "accepted", "slot_id": body.slot_id, "state": state}
    # Only provisioning/serving slots may be drained.
    allowed_states = {
        SlotState.READY.value,
        SlotState.BINDING.value,
        SlotState.BOOTING.value,
        SlotState.LAUNCHING.value,
    }
    if state not in allowed_states:
        _error_response(
            request,
            409,
            "invalid_state",
            f"Cannot drain slot from state {state}",
        )
    # Manual drain clears any pending spot-interruption flag.
    db.update_slot_state(body.slot_id, SlotState.DRAINING, interruption_pending=0)
    return {"status": "accepted", "slot_id": body.slot_id, "state": SlotState.DRAINING.value}
@app.post("/v1/admin/unquarantine")
def admin_unquarantine(body: SlotAdminRequest, request: Request) -> dict[str, str]:
    """Reset an ERROR (quarantined) slot back to EMPTY.

    Clears all instance bookkeeping so the scheduler can reuse the slot.
    """
    slot = db.get_slot(body.slot_id)
    if slot is None:
        # NOTE(review): assumes _error_response raises (NoReturn) — confirm.
        _error_response(request, 404, "not_found", "Slot not found")
    state = str(slot["state"])
    # Only ERROR slots may be unquarantined.
    if state != SlotState.ERROR.value:
        _error_response(
            request,
            409,
            "invalid_state",
            f"Cannot unquarantine slot from state {state}",
        )
    # Wipe instance fields so the slot is indistinguishable from a fresh one.
    db.update_slot_state(
        body.slot_id,
        SlotState.EMPTY,
        instance_id=None,
        instance_ip=None,
        instance_launch_time=None,
        lease_count=0,
        cooldown_until=None,
        interruption_pending=0,
    )
    return {"status": "accepted", "slot_id": body.slot_id, "state": SlotState.EMPTY.value}
@app.post("/v1/admin/reconcile-now")
def admin_reconcile_now(request: Request) -> dict[str, object]:
    """Trigger an immediate synchronous reconcile tick.

    Responds 503 when no trigger callback was wired in and 500 when the
    tick raises; otherwise merges the callback's dict result (if any)
    into an "accepted" payload.
    """
    if reconcile_now is None:
        _error_response(
            request,
            503,
            "not_configured",
            "Reconcile trigger not configured",
            retryable=True,
        )
    try:
        result = reconcile_now()
    except Exception:
        log.exception("admin_reconcile_now_failed")
        # NOTE(review): assumes _error_response raises (NoReturn); otherwise
        # `result` would be unbound below — confirm.
        _error_response(
            request,
            500,
            "reconcile_failed",
            "Reconcile tick failed",
            retryable=True,
        )
    payload: dict[str, object] = {"status": "accepted"}
    # Callbacks may return extra detail (e.g. from the daemon's wrapper).
    if isinstance(result, dict):
        payload.update(result)
    return payload
return app return app

View file

@ -41,13 +41,22 @@ def render_userdata(slot_id: str, region: str, ssm_param: str = "/nix-builder/ts
printf '%s' "$TS_AUTHKEY" > /run/credentials/tailscale-auth-key printf '%s' "$TS_AUTHKEY" > /run/credentials/tailscale-auth-key
chmod 600 /run/credentials/tailscale-auth-key chmod 600 /run/credentials/tailscale-auth-key
# --- Resolve instance identity from IMDSv2 for unique hostname ---
IMDS_TOKEN=$(curl -fsS -X PUT "http://169.254.169.254/latest/api/token" \\
-H "X-aws-ec2-metadata-token-ttl-seconds: 21600" || true)
INSTANCE_ID=$(curl -fsS -H "X-aws-ec2-metadata-token: $IMDS_TOKEN" \\
"http://169.254.169.254/latest/meta-data/instance-id" || true)
if [ -z "$INSTANCE_ID" ]; then
INSTANCE_ID="unknown"
fi
# --- Write tailscale-autoconnect config --- # --- Write tailscale-autoconnect config ---
mkdir -p /etc/tailscale mkdir -p /etc/tailscale
cat > /etc/tailscale/autoconnect.conf <<TSCONF cat > /etc/tailscale/autoconnect.conf <<TSCONF
TS_AUTHKEY_FILE=/run/credentials/tailscale-auth-key TS_AUTHKEY_FILE=/run/credentials/tailscale-auth-key
TS_AUTHKEY_EPHEMERAL=true TS_AUTHKEY_EPHEMERAL=true
TS_AUTHKEY_PREAUTHORIZED=true TS_AUTHKEY_PREAUTHORIZED=true
TS_HOSTNAME=nix-builder-$SLOT_ID TS_HOSTNAME=nix-builder-$SLOT_ID-$INSTANCE_ID
TS_EXTRA_ARGS="--ssh --advertise-tags=tag:nix-builder" TS_EXTRA_ARGS="--ssh --advertise-tags=tag:nix-builder"
TSCONF TSCONF

View file

@ -106,14 +106,11 @@ def _parse_args() -> argparse.Namespace:
subparsers.add_parser("slots", help="List slots") subparsers.add_parser("slots", help="List slots")
subparsers.add_parser("reservations", help="List reservations") subparsers.add_parser("reservations", help="List reservations")
parser_drain = subparsers.add_parser("drain", help="Drain a slot (not implemented)") parser_drain = subparsers.add_parser("drain", help="Drain a slot")
parser_drain.add_argument("slot_id") parser_drain.add_argument("slot_id")
parser_unq = subparsers.add_parser( parser_unq = subparsers.add_parser("unquarantine", help="Unquarantine a slot")
"unquarantine",
help="Unquarantine a slot (not implemented)",
)
parser_unq.add_argument("slot_id") parser_unq.add_argument("slot_id")
subparsers.add_parser("reconcile-now", help="Run reconciler now (not implemented)") subparsers.add_parser("reconcile-now", help="Trigger immediate reconcile tick")
return parser.parse_args() return parser.parse_args()
@ -130,19 +127,31 @@ def main() -> None:
if not args.command: if not args.command:
raise SystemExit(1) raise SystemExit(1)
if args.command in {"drain", "unquarantine", "reconcile-now"}: method = "GET"
print(f"{args.command}: not yet implemented in API v1") path = ""
raise SystemExit(0) body: dict[str, Any] | None = None
if args.command == "status":
endpoint_map = { path = "/v1/state/summary"
"status": "/v1/state/summary", elif args.command == "slots":
"slots": "/v1/slots", path = "/v1/slots"
"reservations": "/v1/reservations", elif args.command == "reservations":
} path = "/v1/reservations"
path = endpoint_map[args.command] elif args.command == "drain":
method = "POST"
path = "/v1/admin/drain"
body = {"slot_id": args.slot_id}
elif args.command == "unquarantine":
method = "POST"
path = "/v1/admin/unquarantine"
body = {"slot_id": args.slot_id}
elif args.command == "reconcile-now":
method = "POST"
path = "/v1/admin/reconcile-now"
else:
raise SystemExit(1)
try: try:
status, data = _uds_request(args.socket, "GET", path) status, data = _uds_request(args.socket, method, path, body=body)
except OSError as err: except OSError as err:
print(f"Error: cannot connect to daemon at {args.socket}") print(f"Error: cannot connect to daemon at {args.socket}")
raise SystemExit(1) from err raise SystemExit(1) from err
@ -151,7 +160,7 @@ def main() -> None:
_print_error(data) _print_error(data)
raise SystemExit(1) raise SystemExit(1)
if args.command == "status": if args.command in {"status", "drain", "unquarantine", "reconcile-now"}:
print(json.dumps(data, indent=2)) print(json.dumps(data, indent=2))
elif args.command == "slots": elif args.command == "slots":
if isinstance(data, list): if isinstance(data, list):

View file

@ -68,7 +68,7 @@ class Reconciler:
# 2. Query HAProxy # 2. Query HAProxy
try: try:
haproxy_health = self._haproxy.read_slot_health() haproxy_health = self._haproxy_read_slot_health()
except HAProxyError: except HAProxyError:
log.warning("haproxy_stat_failed", exc_info=True) log.warning("haproxy_stat_failed", exc_info=True)
haproxy_health = {} haproxy_health = {}
@ -142,8 +142,8 @@ class Reconciler:
if tailscale_ip is not None: if tailscale_ip is not None:
self._db.update_slot_state(slot["slot_id"], SlotState.BINDING, instance_ip=tailscale_ip) self._db.update_slot_state(slot["slot_id"], SlotState.BINDING, instance_ip=tailscale_ip)
try: try:
self._haproxy.set_slot_addr(slot["slot_id"], tailscale_ip) self._haproxy_set_slot_addr(slot["slot_id"], tailscale_ip)
self._haproxy.enable_slot(slot["slot_id"]) self._haproxy_enable_slot(slot["slot_id"])
except HAProxyError: except HAProxyError:
log.warning( log.warning(
"haproxy_binding_setup_failed", "haproxy_binding_setup_failed",
@ -169,8 +169,8 @@ class Reconciler:
ip = slot.get("instance_ip") ip = slot.get("instance_ip")
if ip: if ip:
try: try:
self._haproxy.set_slot_addr(slot_id, ip) self._haproxy_set_slot_addr(slot_id, ip)
self._haproxy.enable_slot(slot_id) self._haproxy_enable_slot(slot_id)
except HAProxyError: except HAProxyError:
pass pass
@ -204,7 +204,7 @@ class Reconciler:
# Disable HAProxy (idempotent) # Disable HAProxy (idempotent)
with contextlib.suppress(HAProxyError): with contextlib.suppress(HAProxyError):
self._haproxy.disable_slot(slot_id) self._haproxy_disable_slot(slot_id)
now = self._clock.now() now = self._clock.now()
last_change = datetime.fromisoformat(slot["last_state_change"]) last_change = datetime.fromisoformat(slot["last_state_change"])
@ -216,8 +216,17 @@ class Reconciler:
if instance_id: if instance_id:
try: try:
self._runtime.terminate_instance(instance_id) self._runtime.terminate_instance(instance_id)
self._metrics.counter("autoscaler_ec2_terminate_total", {}, 1.0) self._metrics.counter(
"autoscaler_ec2_terminate_total",
{"result": "success"},
1.0,
)
except Exception: except Exception:
self._metrics.counter(
"autoscaler_ec2_terminate_total",
{"result": "error"},
1.0,
)
log.warning( log.warning(
"terminate_failed", "terminate_failed",
extra={"slot_id": slot_id, "instance_id": instance_id}, extra={"slot_id": slot_id, "instance_id": instance_id},
@ -252,7 +261,70 @@ class Reconciler:
"""Emit reconciler metrics.""" """Emit reconciler metrics."""
summary = self._db.get_state_summary() summary = self._db.get_state_summary()
for state, count in summary["slots"].items(): for state, count in summary["slots"].items():
if state == "total": self._metrics.gauge("autoscaler_slots_total", {"state": state}, float(count))
continue self._metrics.histogram_observe("autoscaler_reconcile_duration_seconds", {}, tick_duration)
self._metrics.gauge("autoscaler_slots", {"state": state}, float(count))
self._metrics.histogram_observe("autoscaler_reconciler_tick_seconds", {}, tick_duration) def _haproxy_set_slot_addr(self, slot_id: str, ip: str) -> None:
try:
self._haproxy.set_slot_addr(slot_id, ip)
self._metrics.counter(
"autoscaler_haproxy_command_total",
{"cmd": "set_slot_addr", "result": "success"},
1.0,
)
except HAProxyError:
self._metrics.counter(
"autoscaler_haproxy_command_total",
{"cmd": "set_slot_addr", "result": "error"},
1.0,
)
raise
def _haproxy_enable_slot(self, slot_id: str) -> None:
    """Enable *slot_id* in HAProxy, counting the command outcome.

    Re-raises HAProxyError after recording the error counter.
    """
    # Keep the try body minimal: only the HAProxy call can legitimately
    # raise HAProxyError; a failure inside the metrics counter must not
    # be miscounted as an HAProxy command error.
    try:
        self._haproxy.enable_slot(slot_id)
    except HAProxyError:
        self._metrics.counter(
            "autoscaler_haproxy_command_total",
            {"cmd": "enable_slot", "result": "error"},
            1.0,
        )
        raise
    self._metrics.counter(
        "autoscaler_haproxy_command_total",
        {"cmd": "enable_slot", "result": "success"},
        1.0,
    )
def _haproxy_disable_slot(self, slot_id: str) -> None:
    """Disable *slot_id* in HAProxy, counting the command outcome.

    Re-raises HAProxyError after recording the error counter.
    """
    # Narrow try body: only the HAProxy call should be counted as an
    # HAProxy command failure, not a metrics-emission failure.
    try:
        self._haproxy.disable_slot(slot_id)
    except HAProxyError:
        self._metrics.counter(
            "autoscaler_haproxy_command_total",
            {"cmd": "disable_slot", "result": "error"},
            1.0,
        )
        raise
    self._metrics.counter(
        "autoscaler_haproxy_command_total",
        {"cmd": "disable_slot", "result": "success"},
        1.0,
    )
def _haproxy_read_slot_health(self) -> dict:
    """Read per-slot health from HAProxy, counting the show_stat outcome.

    Returns the health mapping; re-raises HAProxyError after recording
    the error counter.
    """
    # Narrow try body so a metrics-emission failure is not misattributed
    # to the HAProxy call itself.
    try:
        health = self._haproxy.read_slot_health()
    except HAProxyError:
        self._metrics.counter(
            "autoscaler_haproxy_command_total",
            {"cmd": "show_stat", "result": "error"},
            1.0,
        )
        raise
    self._metrics.counter(
        "autoscaler_haproxy_command_total",
        {"cmd": "show_stat", "result": "success"},
        1.0,
    )
    return health

View file

@ -2,8 +2,11 @@
from __future__ import annotations from __future__ import annotations
import http.client
import json
import logging import logging
import random import random
import socket
import time import time
from typing import Any from typing import Any
@ -26,6 +29,18 @@ _ERROR_CATEGORIES: dict[str, str] = {
_RETRYABLE_CODES: frozenset[str] = frozenset({"RequestLimitExceeded"}) _RETRYABLE_CODES: frozenset[str] = frozenset({"RequestLimitExceeded"})
class _UnixSocketHTTPConnection(http.client.HTTPConnection):
"""HTTP connection over a Unix domain socket."""
def __init__(self, socket_path: str, timeout: float = 1.0) -> None:
super().__init__("local-tailscaled.sock", timeout=timeout)
self._socket_path = socket_path
def connect(self) -> None:
self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.sock.connect(self._socket_path)
class EC2Runtime(RuntimeAdapter): class EC2Runtime(RuntimeAdapter):
"""EC2 Spot instance runtime adapter. """EC2 Spot instance runtime adapter.
@ -41,6 +56,7 @@ class EC2Runtime(RuntimeAdapter):
environment: str = "dev", environment: str = "dev",
*, *,
_client: Any = None, _client: Any = None,
_tailscale_socket_path: str = "/run/tailscale/tailscaled.sock",
) -> None: ) -> None:
self._client: Any = _client or boto3.client("ec2", region_name=config.region) self._client: Any = _client or boto3.client("ec2", region_name=config.region)
self._launch_template_id = config.launch_template_id self._launch_template_id = config.launch_template_id
@ -49,6 +65,7 @@ class EC2Runtime(RuntimeAdapter):
self._instance_profile_arn = config.instance_profile_arn self._instance_profile_arn = config.instance_profile_arn
self._environment = environment self._environment = environment
self._subnet_index = 0 self._subnet_index = 0
self._tailscale_socket_path = _tailscale_socket_path
def launch_spot(self, slot_id: str, user_data: str) -> str: def launch_spot(self, slot_id: str, user_data: str) -> str:
"""Launch a spot instance for *slot_id*. Return instance ID.""" """Launch a spot instance for *slot_id*. Return instance ID."""
@ -103,10 +120,17 @@ class EC2Runtime(RuntimeAdapter):
return {"state": "terminated", "tailscale_ip": None, "launch_time": None} return {"state": "terminated", "tailscale_ip": None, "launch_time": None}
inst = reservations[0]["Instances"][0] inst = reservations[0]["Instances"][0]
tags = inst.get("Tags", [])
slot_id = self._get_tag(tags, "AutoscalerSlot")
state = inst["State"]["Name"]
tailscale_ip: str | None = None
if state == "running" and slot_id:
tailscale_ip = self._discover_tailscale_ip(slot_id, instance_id)
launch_time = inst.get("LaunchTime") launch_time = inst.get("LaunchTime")
return { return {
"state": inst["State"]["Name"], "state": state,
"tailscale_ip": None, "tailscale_ip": tailscale_ip,
"launch_time": launch_time.isoformat() if launch_time else None, "launch_time": launch_time.isoformat() if launch_time else None,
} }
@ -166,6 +190,98 @@ class EC2Runtime(RuntimeAdapter):
msg = "Retries exhausted" msg = "Retries exhausted"
raise RuntimeAdapterError(msg, category="unknown") raise RuntimeAdapterError(msg, category="unknown")
def _discover_tailscale_ip(self, slot_id: str, instance_id: str) -> str | None:
    """Resolve the Tailscale IP for a slot's instance via the local tailscaled LocalAPI.

    Matching strategy, most to least specific:
      1. Exact identity "nix-builder-<slot>-<instance>" as a hostname substring.
      2. Fallback "nix-builder-<slot>" substring only.
    An ambiguous match (more than one online peer) logs a warning and
    returns None rather than guessing.
    """
    status = self._read_tailscale_status()
    if status is None:
        return None
    peers_obj = status.get("Peer")
    if not isinstance(peers_obj, dict):
        return None
    # Collect (hostname, ip) for every online peer with usable data.
    online_candidates: list[tuple[str, str]] = []
    for peer in peers_obj.values():
        if not isinstance(peer, dict):
            continue
        if not self._peer_is_online(peer):
            continue
        hostname = self._peer_hostname(peer)
        if hostname is None:
            continue
        ip = self._peer_tailscale_ip(peer)
        if ip is None:
            continue
        online_candidates.append((hostname, ip))
    # Hostnames are lowercased by _peer_hostname, so compare lowercase.
    identity = f"nix-builder-{slot_id}-{instance_id}".lower()
    identity_matches = [ip for host, ip in online_candidates if identity in host]
    if len(identity_matches) == 1:
        return identity_matches[0]
    if len(identity_matches) > 1:
        log.warning(
            "tailscale_identity_ambiguous",
            extra={"slot_id": slot_id, "instance_id": instance_id},
        )
        return None
    # Fall back to matching on the slot prefix alone (covers hosts named
    # without the instance suffix).
    slot_identity = f"nix-builder-{slot_id}".lower()
    slot_matches = [ip for host, ip in online_candidates if slot_identity in host]
    if len(slot_matches) == 1:
        return slot_matches[0]
    if len(slot_matches) > 1:
        log.warning("tailscale_slot_ambiguous", extra={"slot_id": slot_id})
        return None
    return None
def _read_tailscale_status(self) -> dict[str, Any] | None:
    """Query tailscaled's LocalAPI status endpoint over its Unix socket.

    Returns the parsed status dict, or None on any transport, HTTP, or
    decoding failure — callers treat "no status" as "no IP discovered".
    """
    conn = _UnixSocketHTTPConnection(self._tailscale_socket_path, timeout=1.0)
    try:
        conn.request(
            "GET",
            "/localapi/v0/status",
            headers={"Host": "local-tailscaled.sock", "Accept": "application/json"},
        )
        response = conn.getresponse()
        if response.status != 200:
            return None
        payload = response.read()
        parsed = json.loads(payload.decode())
        # LocalAPI status should be a JSON object; anything else is malformed.
        return parsed if isinstance(parsed, dict) else None
    # PermissionError and TimeoutError are OSError subclasses, so OSError
    # already covers them; HTTPException covers protocol-level failures.
    except (OSError, http.client.HTTPException, json.JSONDecodeError, UnicodeDecodeError):
        return None
    finally:
        conn.close()
@staticmethod
def _peer_is_online(peer: dict[str, Any]) -> bool:
return bool(peer.get("Online") or peer.get("Active"))
@staticmethod
def _peer_hostname(peer: dict[str, Any]) -> str | None:
host = peer.get("HostName") or peer.get("DNSName")
if not isinstance(host, str) or not host:
return None
return host.strip(".").lower()
@staticmethod
def _peer_tailscale_ip(peer: dict[str, Any]) -> str | None:
ips = peer.get("TailscaleIPs")
if not isinstance(ips, list):
return None
ipv4 = [ip for ip in ips if isinstance(ip, str) and "." in ip]
if ipv4:
return ipv4[0]
for ip in ips:
if isinstance(ip, str) and ip:
return ip
return None
@staticmethod @staticmethod
def _get_tag(tags: list[dict[str, str]], key: str) -> str | None: def _get_tag(tags: list[dict[str, str]], key: str) -> str | None:
"""Extract a tag value from an EC2 tag list.""" """Extract a tag value from an EC2 tag list."""

View file

@ -217,12 +217,13 @@ def _launch_slot(
"""Launch a single slot. Transition to LAUNCHING on success, ERROR on failure.""" """Launch a single slot. Transition to LAUNCHING on success, ERROR on failure."""
slot_id = slot["slot_id"] slot_id = slot["slot_id"]
user_data = render_userdata(slot_id, config.aws.region) user_data = render_userdata(slot_id, config.aws.region)
metrics.counter("autoscaler_ec2_launch_total", {}, 1.0)
try: try:
instance_id = runtime.launch_spot(slot_id, user_data) instance_id = runtime.launch_spot(slot_id, user_data)
metrics.counter("autoscaler_ec2_launch_total", {"result": "success"}, 1.0)
db.update_slot_state(slot_id, SlotState.LAUNCHING, instance_id=instance_id) db.update_slot_state(slot_id, SlotState.LAUNCHING, instance_id=instance_id)
log.info("slot_launched", extra={"slot_id": slot_id, "instance_id": instance_id}) log.info("slot_launched", extra={"slot_id": slot_id, "instance_id": instance_id})
except RuntimeAdapterError as exc: except RuntimeAdapterError as exc:
metrics.counter("autoscaler_ec2_launch_total", {"result": exc.category}, 1.0)
db.update_slot_state(slot_id, SlotState.ERROR) db.update_slot_state(slot_id, SlotState.ERROR)
log.warning( log.warning(
"slot_launch_failed", "slot_launch_failed",
@ -257,11 +258,9 @@ def _update_metrics(db: StateDB, metrics: MetricsRegistry, tick_duration: float)
summary = db.get_state_summary() summary = db.get_state_summary()
for state, count in summary["slots"].items(): for state, count in summary["slots"].items():
if state == "total": metrics.gauge("autoscaler_slots_total", {"state": state}, float(count))
continue
metrics.gauge("autoscaler_slots", {"state": state}, float(count))
for phase, count in summary["reservations"].items(): for phase, count in summary["reservations"].items():
metrics.gauge("autoscaler_reservations", {"phase": phase}, float(count)) metrics.gauge("autoscaler_reservations_total", {"phase": phase}, float(count))
metrics.histogram_observe("autoscaler_scheduler_tick_seconds", {}, tick_duration) metrics.histogram_observe("autoscaler_scheduler_tick_duration_seconds", {}, tick_duration)

View file

@ -7,6 +7,7 @@ from __future__ import annotations
import json import json
import sqlite3 import sqlite3
import threading
import uuid import uuid
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
from pathlib import Path from pathlib import Path
@ -75,13 +76,16 @@ class StateDB:
self._conn.execute("PRAGMA journal_mode=WAL") self._conn.execute("PRAGMA journal_mode=WAL")
self._conn.execute("PRAGMA busy_timeout=5000") self._conn.execute("PRAGMA busy_timeout=5000")
self._clock = clock self._clock = clock
self._lock = threading.RLock()
def init_schema(self) -> None: def init_schema(self) -> None:
"""Create tables if they don't exist.""" """Create tables if they don't exist."""
with self._lock:
self._conn.executescript(_SCHEMA) self._conn.executescript(_SCHEMA)
def init_slots(self, slot_prefix: str, slot_count: int, system: str, backend: str) -> None: def init_slots(self, slot_prefix: str, slot_count: int, system: str, backend: str) -> None:
"""Ensure all expected slots exist, creating missing ones as empty.""" """Ensure all expected slots exist, creating missing ones as empty."""
with self._lock:
now = _now_iso(self._clock) now = _now_iso(self._clock)
for i in range(1, slot_count + 1): for i in range(1, slot_count + 1):
slot_id = f"{slot_prefix}{i:03d}" slot_id = f"{slot_prefix}{i:03d}"
@ -98,6 +102,7 @@ class StateDB:
def get_slot(self, slot_id: str) -> dict | None: def get_slot(self, slot_id: str) -> dict | None:
"""Return a slot row as dict, or None.""" """Return a slot row as dict, or None."""
with self._lock:
cur = self._conn.execute("SELECT * FROM slots WHERE slot_id = ?", (slot_id,)) cur = self._conn.execute("SELECT * FROM slots WHERE slot_id = ?", (slot_id,))
row = cur.fetchone() row = cur.fetchone()
if row is None: if row is None:
@ -106,6 +111,7 @@ class StateDB:
def list_slots(self, state: SlotState | None = None) -> list[dict]: def list_slots(self, state: SlotState | None = None) -> list[dict]:
"""List slots, optionally filtered by state.""" """List slots, optionally filtered by state."""
with self._lock:
if state is not None: if state is not None:
cur = self._conn.execute( cur = self._conn.execute(
"SELECT * FROM slots WHERE state = ? ORDER BY slot_id", (state.value,) "SELECT * FROM slots WHERE state = ? ORDER BY slot_id", (state.value,)
@ -119,6 +125,7 @@ class StateDB:
Additional fields (instance_id, instance_ip, etc.) can be passed as kwargs. Additional fields (instance_id, instance_ip, etc.) can be passed as kwargs.
""" """
with self._lock:
now = _now_iso(self._clock) now = _now_iso(self._clock)
set_parts = ["state = ?", "last_state_change = ?"] set_parts = ["state = ?", "last_state_change = ?"]
params: list[object] = [new_state.value, now] params: list[object] = [new_state.value, now]
@ -159,6 +166,7 @@ class StateDB:
Uses BEGIN IMMEDIATE. Allowed fields: instance_id, instance_ip, Uses BEGIN IMMEDIATE. Allowed fields: instance_id, instance_ip,
instance_launch_time, lease_count, cooldown_until, interruption_pending. instance_launch_time, lease_count, cooldown_until, interruption_pending.
""" """
with self._lock:
allowed = { allowed = {
"instance_id", "instance_id",
"instance_ip", "instance_ip",
@ -204,6 +212,7 @@ class StateDB:
ttl_seconds: int, ttl_seconds: int,
) -> dict: ) -> dict:
"""Create a new pending reservation. Returns the reservation row as dict.""" """Create a new pending reservation. Returns the reservation row as dict."""
with self._lock:
now = _now_iso(self._clock) now = _now_iso(self._clock)
if self._clock is not None: if self._clock is not None:
expires = (self._clock.now() + timedelta(seconds=ttl_seconds)).isoformat() expires = (self._clock.now() + timedelta(seconds=ttl_seconds)).isoformat()
@ -218,7 +227,16 @@ class StateDB:
(reservation_id, system, phase, created_at, updated_at, (reservation_id, system, phase, created_at, updated_at,
expires_at, reason, build_id) expires_at, reason, build_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
(rid, system, ReservationPhase.PENDING.value, now, now, expires, reason, build_id), (
rid,
system,
ReservationPhase.PENDING.value,
now,
now,
expires,
reason,
build_id,
),
) )
self._record_event_inner( self._record_event_inner(
"reservation_created", "reservation_created",
@ -233,6 +251,7 @@ class StateDB:
def get_reservation(self, reservation_id: str) -> dict | None: def get_reservation(self, reservation_id: str) -> dict | None:
"""Return a reservation row as dict, or None.""" """Return a reservation row as dict, or None."""
with self._lock:
cur = self._conn.execute( cur = self._conn.execute(
"SELECT * FROM reservations WHERE reservation_id = ?", (reservation_id,) "SELECT * FROM reservations WHERE reservation_id = ?", (reservation_id,)
) )
@ -243,6 +262,7 @@ class StateDB:
def list_reservations(self, phase: ReservationPhase | None = None) -> list[dict]: def list_reservations(self, phase: ReservationPhase | None = None) -> list[dict]:
"""List reservations, optionally filtered by phase.""" """List reservations, optionally filtered by phase."""
with self._lock:
if phase is not None: if phase is not None:
cur = self._conn.execute( cur = self._conn.execute(
"SELECT * FROM reservations WHERE phase = ? ORDER BY created_at", "SELECT * FROM reservations WHERE phase = ? ORDER BY created_at",
@ -258,6 +278,7 @@ class StateDB:
Atomically: update reservation phase to ready, set slot_id/instance_id, Atomically: update reservation phase to ready, set slot_id/instance_id,
and increment slot lease_count. and increment slot lease_count.
""" """
with self._lock:
now = _now_iso(self._clock) now = _now_iso(self._clock)
self._conn.execute("BEGIN IMMEDIATE") self._conn.execute("BEGIN IMMEDIATE")
@ -294,6 +315,7 @@ class StateDB:
def release_reservation(self, reservation_id: str) -> dict | None: def release_reservation(self, reservation_id: str) -> dict | None:
"""Release a reservation, decrementing the slot lease count.""" """Release a reservation, decrementing the slot lease count."""
with self._lock:
now = _now_iso(self._clock) now = _now_iso(self._clock)
self._conn.execute("BEGIN IMMEDIATE") self._conn.execute("BEGIN IMMEDIATE")
@ -338,6 +360,7 @@ class StateDB:
def expire_reservations(self, now: datetime) -> list[str]: def expire_reservations(self, now: datetime) -> list[str]:
"""Expire all reservations past their expires_at. Returns expired IDs.""" """Expire all reservations past their expires_at. Returns expired IDs."""
with self._lock:
now_iso = now.isoformat() now_iso = now.isoformat()
expired_ids: list[str] = [] expired_ids: list[str] = []
@ -378,6 +401,7 @@ class StateDB:
def record_event(self, kind: str, payload: dict) -> None: # type: ignore[type-arg] def record_event(self, kind: str, payload: dict) -> None: # type: ignore[type-arg]
"""Record an audit event.""" """Record an audit event."""
with self._lock:
self._conn.execute("BEGIN IMMEDIATE") self._conn.execute("BEGIN IMMEDIATE")
try: try:
self._record_event_inner(kind, payload) self._record_event_inner(kind, payload)
@ -388,6 +412,7 @@ class StateDB:
def _record_event_inner(self, kind: str, payload: dict) -> None: # type: ignore[type-arg] def _record_event_inner(self, kind: str, payload: dict) -> None: # type: ignore[type-arg]
"""Insert an event row (must be called inside an active transaction).""" """Insert an event row (must be called inside an active transaction)."""
with self._lock:
now = _now_iso(self._clock) now = _now_iso(self._clock)
self._conn.execute( self._conn.execute(
"INSERT INTO events (ts, kind, payload_json) VALUES (?, ?, ?)", "INSERT INTO events (ts, kind, payload_json) VALUES (?, ?, ?)",
@ -398,6 +423,7 @@ class StateDB:
def get_state_summary(self) -> dict: def get_state_summary(self) -> dict:
"""Return aggregate slot and reservation counts.""" """Return aggregate slot and reservation counts."""
with self._lock:
slot_counts: dict[str, int] = {} slot_counts: dict[str, int] = {}
cur = self._conn.execute("SELECT state, COUNT(*) FROM slots GROUP BY state") cur = self._conn.execute("SELECT state, COUNT(*) FROM slots GROUP BY state")
for state_val, count in cur.fetchall(): for state_val, count in cur.fetchall():
@ -438,4 +464,5 @@ class StateDB:
def close(self) -> None: def close(self) -> None:
"""Close the database connection.""" """Close the database connection."""
with self._lock:
self._conn.close() self._conn.close()

View file

@ -1 +1,407 @@
"""End-to-end integration tests with FakeRuntime — Plan 05.""" """End-to-end integration tests with FakeRuntime and a fake HAProxy socket."""
from __future__ import annotations
import socket
import threading
import time
from pathlib import Path
from fastapi.testclient import TestClient
from nix_builder_autoscaler.api import create_app
from nix_builder_autoscaler.config import (
AppConfig,
AwsConfig,
CapacityConfig,
HaproxyConfig,
SchedulerConfig,
)
from nix_builder_autoscaler.metrics import MetricsRegistry
from nix_builder_autoscaler.models import SlotState
from nix_builder_autoscaler.providers.clock import FakeClock
from nix_builder_autoscaler.providers.haproxy import HAProxyRuntime
from nix_builder_autoscaler.reconciler import Reconciler
from nix_builder_autoscaler.runtime.fake import FakeRuntime
from nix_builder_autoscaler.scheduler import scheduling_tick
from nix_builder_autoscaler.state_db import StateDB
class FakeHAProxySocketServer:
    """Tiny fake HAProxy runtime socket server for integration tests."""

    def __init__(self, socket_path: Path, backend: str, slot_ids: list[str]) -> None:
        self._socket_path = socket_path
        self._backend = backend
        self._slot_ids = slot_ids
        self._stop_event = threading.Event()
        self._thread: threading.Thread | None = None
        self._lock = threading.Lock()
        # Per-slot server state mirrored into "show stat" output; every slot
        # starts disabled in MAINT, like a freshly configured backend.
        self._state: dict[str, dict[str, object]] = {
            slot_id: {
                "enabled": False,
                "addr": "0.0.0.0",
                "port": 22,
                "status": "MAINT",
                "scur": 0,
                "qcur": 0,
            }
            for slot_id in slot_ids
        }

    def start(self) -> None:
        """Start the listener thread and wait until the socket file exists."""
        self._thread = threading.Thread(target=self._serve, name="fake-haproxy", daemon=True)
        self._thread.start()
        deadline = time.time() + 2.0
        while time.time() < deadline:
            if self._socket_path.exists():
                return
            time.sleep(0.01)
        msg = f"fake haproxy socket not created: {self._socket_path}"
        raise RuntimeError(msg)

    def stop(self) -> None:
        """Signal shutdown, poke the listener awake, and join the thread."""
        self._stop_event.set()
        # A throwaway connection unblocks accept() faster than waiting out its
        # 0.2s timeout; connection failures are fine — the timeout is the fallback.
        try:
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
                sock.connect(str(self._socket_path))
                sock.sendall(b"\n")
        except OSError:
            pass
        if self._thread is not None:
            self._thread.join(timeout=2.0)
        if self._socket_path.exists():
            self._socket_path.unlink()

    def _serve(self) -> None:
        """Accept loop: read one command per connection, write one response."""
        # Remove a stale socket file left by a previous run before binding.
        if self._socket_path.exists():
            self._socket_path.unlink()
        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as server:
            server.bind(str(self._socket_path))
            server.listen(16)
            # Short accept timeout so the stop event is polled regularly.
            server.settimeout(0.2)
            while not self._stop_event.is_set():
                try:
                    conn, _ = server.accept()
                except TimeoutError:  # socket.timeout is an alias on Python 3.10+
                    continue
                except OSError:
                    if self._stop_event.is_set():
                        break
                    continue
                with conn:
                    # Read until EOF: the client sends one command then closes
                    # (or shuts down) its write side.
                    payload = b""
                    while True:
                        chunk = conn.recv(4096)
                        if not chunk:
                            break
                        payload += chunk
                    command = payload.decode().strip()
                    response = self._handle_command(command)
                    try:
                        conn.sendall(response.encode())
                    except BrokenPipeError:
                        # Client went away mid-reply; serve the next connection.
                        continue

    def _handle_command(self, command: str) -> str:
        """Dispatch one HAProxy runtime-API command and return its reply text."""
        if command == "show stat":
            return self._render_show_stat()
        parts = command.split()
        if not parts:
            return "\n"
        # "set server <backend>/<slot> addr <ip> port <n>"
        if parts[0:2] == ["set", "server"] and len(parts) >= 7:
            slot_id = self._parse_slot(parts[2])
            if slot_id is None:
                return "No such server.\n"
            with self._lock:
                slot_state = self._state[slot_id]
                slot_state["addr"] = parts[4]
                slot_state["port"] = int(parts[6])
                # Re-addressing a disabled server marks it DOWN (not MAINT).
                slot_state["status"] = "UP" if slot_state["enabled"] else "DOWN"
            return "\n"
        if parts[0:2] == ["enable", "server"] and len(parts) >= 3:
            slot_id = self._parse_slot(parts[2])
            if slot_id is None:
                return "No such server.\n"
            with self._lock:
                slot_state = self._state[slot_id]
                slot_state["enabled"] = True
                slot_state["status"] = "UP"
            return "\n"
        if parts[0:2] == ["disable", "server"] and len(parts) >= 3:
            slot_id = self._parse_slot(parts[2])
            if slot_id is None:
                return "No such server.\n"
            with self._lock:
                slot_state = self._state[slot_id]
                slot_state["enabled"] = False
                slot_state["status"] = "MAINT"
            return "\n"
        return "Unknown command.\n"

    def _parse_slot(self, backend_slot: str) -> str | None:
        """Split "backend/slot" and return the slot id iff backend and slot are known."""
        backend, _, slot_id = backend_slot.partition("/")
        if backend != self._backend or slot_id not in self._state:
            return None
        return slot_id

    def _render_show_stat(self) -> str:
        """Render a minimal CSV "show stat": one BACKEND row plus one row per slot."""
        header = "# pxname,svname,qcur,qmax,scur,smax,slim,stot,status\n"
        rows = [f"{self._backend},BACKEND,0,0,0,0,0,0,UP\n"]
        with self._lock:
            for slot_id in self._slot_ids:
                slot_state = self._state[slot_id]
                rows.append(
                    f"{self._backend},{slot_id},{slot_state['qcur']},0,"
                    f"{slot_state['scur']},0,50,0,{slot_state['status']}\n"
                )
        return header + "".join(rows)
class DaemonHarness:
    """In-process threaded harness for scheduler/reconciler/API integration."""

    def __init__(
        self,
        root: Path,
        *,
        db_path: Path | None = None,
        runtime: FakeRuntime | None = None,
        max_slots: int = 3,
        min_slots: int = 0,
        idle_scale_down_seconds: int = 1,
        drain_timeout_seconds: int = 120,
    ) -> None:
        root.mkdir(parents=True, exist_ok=True)
        self.clock = FakeClock()
        self.metrics = MetricsRegistry()
        # Small launch/IP latencies so tests exercise the intermediate slot
        # states without slowing the suite down.
        self.runtime = runtime or FakeRuntime(launch_latency_ticks=2, ip_delay_ticks=1)
        self._stop_event = threading.Event()
        self._threads: list[threading.Thread] = []
        # Serializes runtime/reconciler ticks between the background loop and
        # on-demand reconcile_now() calls from the API.
        self._reconcile_lock = threading.Lock()
        self._db_path = db_path or (root / "state.db")
        self._socket_path = root / "haproxy.sock"
        self._slot_ids = [f"slot{i:03d}" for i in range(1, 4)]
        self.config = AppConfig(
            aws=AwsConfig(region="us-east-1"),
            haproxy=HaproxyConfig(
                runtime_socket=str(self._socket_path),
                backend="all",
                slot_prefix="slot",
                slot_count=3,
                check_ready_up_count=1,
            ),
            capacity=CapacityConfig(
                default_system="x86_64-linux",
                max_slots=max_slots,
                min_slots=min_slots,
                max_leases_per_slot=1,
                target_warm_slots=0,
                reservation_ttl_seconds=1200,
                idle_scale_down_seconds=idle_scale_down_seconds,
                drain_timeout_seconds=drain_timeout_seconds,
            ),
            scheduler=SchedulerConfig(tick_seconds=0.05, reconcile_seconds=0.05),
        )
        self.db = StateDB(str(self._db_path), clock=self.clock)
        self.db.init_schema()
        self.db.init_slots("slot", 3, "x86_64-linux", "all")
        self.haproxy_server = FakeHAProxySocketServer(self._socket_path, "all", self._slot_ids)
        self.haproxy = HAProxyRuntime(str(self._socket_path), "all", "slot")
        self.reconciler = Reconciler(
            self.db,
            self.runtime,
            self.haproxy,
            self.config,
            self.clock,
            self.metrics,
        )
        app = create_app(
            self.db,
            self.config,
            self.clock,
            self.metrics,
            reconcile_now=self.reconcile_now,
        )
        self.client = TestClient(app)

    def start(self) -> None:
        """Start the fake HAProxy socket, then the scheduler/reconciler threads."""
        self.haproxy_server.start()
        # One synchronous reconcile pass before the loops start, so restart
        # tests observe recovered state promptly.
        with self._reconcile_lock:
            self.runtime.tick()
            self.reconciler.tick()
        self._threads = [
            threading.Thread(target=self._scheduler_loop, name="sched", daemon=True),
            threading.Thread(target=self._reconciler_loop, name="recon", daemon=True),
        ]
        for thread in self._threads:
            thread.start()

    def stop(self) -> None:
        """Stop the loops, then close API client, fake HAProxy, and the DB."""
        self._stop_event.set()
        for thread in self._threads:
            thread.join(timeout=2.0)
        self.client.close()
        self.haproxy_server.stop()
        self.db.close()

    def create_reservation(self, reason: str) -> str:
        """POST a reservation for the default system and return its id."""
        response = self.client.post(
            "/v1/reservations",
            json={"system": "x86_64-linux", "reason": reason},
        )
        assert response.status_code == 200
        return str(response.json()["reservation_id"])

    def release_reservation(self, reservation_id: str) -> None:
        """Release a reservation via the API, asserting success."""
        response = self.client.post(f"/v1/reservations/{reservation_id}/release")
        assert response.status_code == 200

    def reservation(self, reservation_id: str) -> dict:
        """Fetch a reservation via the API, asserting success."""
        response = self.client.get(f"/v1/reservations/{reservation_id}")
        assert response.status_code == 200
        return response.json()

    def wait_for(self, predicate, timeout: float = 6.0) -> None:  # noqa: ANN001
        """Poll *predicate* (zero-arg callable) until truthy; AssertionError on timeout."""
        deadline = time.time() + timeout
        while time.time() < deadline:
            if predicate():
                return
            time.sleep(0.02)
        raise AssertionError("condition not met before timeout")

    def reconcile_now(self) -> dict[str, bool]:
        """Run one runtime+reconciler tick on demand (wired into the API app)."""
        with self._reconcile_lock:
            self.runtime.tick()
            self.reconciler.tick()
        return {"triggered": True}

    def _scheduler_loop(self) -> None:
        # Background scheduling; runs unlocked — presumably StateDB provides
        # its own synchronization. TODO(review): confirm.
        while not self._stop_event.is_set():
            scheduling_tick(self.db, self.runtime, self.config, self.clock, self.metrics)
            self._stop_event.wait(self.config.scheduler.tick_seconds)

    def _reconciler_loop(self) -> None:
        # Background reconciliation, serialized with reconcile_now().
        while not self._stop_event.is_set():
            with self._reconcile_lock:
                self.runtime.tick()
                self.reconciler.tick()
            self._stop_event.wait(self.config.scheduler.reconcile_seconds)
def test_cold_start_reservation_launch_bind_ready(tmp_path: Path) -> None:
    """A single reservation on a cold pool launches, binds, and becomes ready."""
    daemon = DaemonHarness(tmp_path)
    daemon.start()
    try:
        rid = daemon.create_reservation("cold-start")
        daemon.wait_for(lambda: daemon.reservation(rid)["phase"] == "ready")
        resv = daemon.reservation(rid)
        assert resv["slot"] is not None
        bound_slot = daemon.db.get_slot(resv["slot"])
        assert bound_slot is not None
        assert bound_slot["state"] == SlotState.READY.value
        assert bound_slot["instance_ip"] is not None
    finally:
        daemon.stop()
def test_burst_three_concurrent_reservations(tmp_path: Path) -> None:
    """Three simultaneous reservations each end up on a distinct slot."""
    daemon = DaemonHarness(tmp_path, max_slots=3)
    daemon.start()
    try:
        rids = [daemon.create_reservation(f"burst-{i}") for i in range(3)]

        def _all_ready() -> bool:
            return all(daemon.reservation(rid)["phase"] == "ready" for rid in rids)

        daemon.wait_for(_all_ready, timeout=8.0)
        assigned_slots = {daemon.reservation(rid)["slot"] for rid in rids}
        assert len(assigned_slots) == 3
    finally:
        daemon.stop()
def test_scale_down_after_release_and_idle_timeout(tmp_path: Path) -> None:
    """After release plus the idle timeout, the slot is scaled back down to EMPTY."""
    daemon = DaemonHarness(tmp_path, idle_scale_down_seconds=1, drain_timeout_seconds=0)
    daemon.start()
    try:
        rid = daemon.create_reservation("scale-down")
        daemon.wait_for(lambda: daemon.reservation(rid)["phase"] == "ready")
        slot_id = str(daemon.reservation(rid)["slot"])
        daemon.release_reservation(rid)
        # Step the fake clock past the 1s idle threshold.
        daemon.clock.advance(2)

        def _slot_empty() -> bool:
            row = daemon.db.get_slot(slot_id)
            return row is not None and row["state"] == SlotState.EMPTY.value

        daemon.wait_for(_slot_empty)
    finally:
        daemon.stop()
def test_restart_recovery_midflight(tmp_path: Path) -> None:
    """Daemon restart mid-launch: a second process resumes the pending reservation."""
    db_path = tmp_path / "state.db"
    # Slow launch so the first daemon can be stopped while a slot is still LAUNCHING.
    runtime = FakeRuntime(launch_latency_ticks=6, ip_delay_ticks=2)
    first = DaemonHarness(tmp_path / "run1", db_path=db_path, runtime=runtime)
    first.start()
    reservation_id = first.create_reservation("restart-midflight")
    first.wait_for(
        lambda: len(first.db.list_slots(SlotState.LAUNCHING)) > 0,
        timeout=4.0,
    )
    first.stop()
    # The second harness shares the same state DB and fake runtime,
    # simulating a daemon restart over persisted state.
    second = DaemonHarness(tmp_path / "run2", db_path=db_path, runtime=runtime)
    second.start()
    try:
        second.wait_for(lambda: second.reservation(reservation_id)["phase"] == "ready", timeout=8.0)
    finally:
        second.stop()
def test_interruption_recovery_pending_reservation_resolves(tmp_path: Path) -> None:
    """A spot interruption drains the affected slot and a queued reservation still resolves."""
    harness = DaemonHarness(tmp_path, max_slots=2, idle_scale_down_seconds=60)
    harness.start()
    try:
        first_reservation = harness.create_reservation("baseline")
        harness.wait_for(lambda: harness.reservation(first_reservation)["phase"] == "ready")
        slot_id = str(harness.reservation(first_reservation)["slot"])
        instance_id = str(harness.reservation(first_reservation)["instance_id"])
        # Queue a second reservation before the interruption hits.
        second_reservation = harness.create_reservation("post-interruption")
        harness.release_reservation(first_reservation)
        harness.runtime.inject_interruption(instance_id)
        # NOTE(review): reaches into FakeRuntime private state to force the
        # instance into "shutting-down" — consider a public helper on FakeRuntime.
        harness.runtime._instances[instance_id].state = "shutting-down"
        # The slot should leave READY; any teardown phase counts as progress.
        harness.wait_for(
            lambda: (
                harness.db.get_slot(slot_id) is not None
                and harness.db.get_slot(slot_id)["state"]
                in {
                    SlotState.DRAINING.value,
                    SlotState.TERMINATING.value,
                    SlotState.EMPTY.value,
                }
            ),
            timeout=6.0,
        )
        harness.wait_for(
            lambda: harness.reservation(second_reservation)["phase"] == "ready",
            timeout=10.0,
        )
    finally:
        harness.stop()

View file

@ -3,24 +3,29 @@
from __future__ import annotations from __future__ import annotations
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from nix_builder_autoscaler.api import create_app from nix_builder_autoscaler.api import create_app
from nix_builder_autoscaler.config import AppConfig, CapacityConfig from nix_builder_autoscaler.config import AppConfig, CapacityConfig
from nix_builder_autoscaler.metrics import MetricsRegistry from nix_builder_autoscaler.metrics import MetricsRegistry
from nix_builder_autoscaler.models import SlotState
from nix_builder_autoscaler.providers.clock import FakeClock from nix_builder_autoscaler.providers.clock import FakeClock
from nix_builder_autoscaler.state_db import StateDB from nix_builder_autoscaler.state_db import StateDB
def _make_client(
    *,
    reconcile_now: Any = None,  # noqa: ANN401
) -> tuple[TestClient, StateDB, FakeClock, MetricsRegistry]:
    """Build a TestClient over a fresh in-memory StateDB with three slots.

    ``reconcile_now`` is forwarded to ``create_app`` so admin-endpoint tests
    can inject a callback; ``None`` leaves the endpoint unconfigured.
    """
    clock = FakeClock()
    db = StateDB(":memory:", clock=clock)
    db.init_schema()
    db.init_slots("slot", 3, "x86_64-linux", "all")
    config = AppConfig(capacity=CapacityConfig(reservation_ttl_seconds=1200))
    metrics = MetricsRegistry()
    app = create_app(db, config, clock, metrics, reconcile_now=reconcile_now)
    return TestClient(app), db, clock, metrics
@ -120,6 +125,20 @@ def test_health_ready_returns_ok_when_no_checks() -> None:
assert response.json()["status"] == "ok" assert response.json()["status"] == "ok"
def test_health_ready_degraded_when_ready_check_fails() -> None:
    """/health/ready returns 503 with status "degraded" when ready_check fails."""
    fake_clock = FakeClock()
    state_db = StateDB(":memory:", clock=fake_clock)
    state_db.init_schema()
    state_db.init_slots("slot", 3, "x86_64-linux", "all")
    registry = MetricsRegistry()
    cfg = AppConfig(capacity=CapacityConfig(reservation_ttl_seconds=1200))
    # Cannot use _make_client here: it does not expose the ready_check hook.
    app = create_app(state_db, cfg, fake_clock, registry, ready_check=lambda: False)
    client = TestClient(app)
    resp = client.get("/health/ready")
    assert resp.status_code == 503
    assert resp.json()["status"] == "degraded"
def test_metrics_returns_prometheus_text() -> None: def test_metrics_returns_prometheus_text() -> None:
client, _, _, metrics = _make_client() client, _, _, metrics = _make_client()
metrics.counter("autoscaler_test_counter", {}, 1.0) metrics.counter("autoscaler_test_counter", {}, 1.0)
@ -150,3 +169,67 @@ def test_release_nonexistent_returns_404() -> None:
response = client.post("/v1/reservations/resv_nonexistent/release") response = client.post("/v1/reservations/resv_nonexistent/release")
assert response.status_code == 404 assert response.status_code == 404
assert response.json()["error"]["code"] == "not_found" assert response.json()["error"]["code"] == "not_found"
def test_admin_drain_success() -> None:
    """Draining a READY slot succeeds and persists the DRAINING state."""
    client, db, _, _ = _make_client()
    # Walk slot001 through the legal lifecycle up to READY.
    for state, extra in (
        (SlotState.LAUNCHING, {"instance_id": "i-test"}),
        (SlotState.BOOTING, {}),
        (SlotState.BINDING, {"instance_ip": "100.64.0.1"}),
        (SlotState.READY, {}),
    ):
        db.update_slot_state("slot001", state, **extra)
    resp = client.post("/v1/admin/drain", json={"slot_id": "slot001"})
    assert resp.status_code == 200
    assert resp.json()["state"] == "draining"
    row = db.get_slot("slot001")
    assert row is not None
    assert row["state"] == SlotState.DRAINING.value
def test_admin_drain_invalid_state_returns_409() -> None:
    """Draining a slot that is not READY is rejected with invalid_state."""
    client, _, _, _ = _make_client()
    resp = client.post("/v1/admin/drain", json={"slot_id": "slot001"})
    assert resp.status_code == 409
    body = resp.json()
    assert body["error"]["code"] == "invalid_state"
def test_admin_unquarantine_success() -> None:
    """Unquarantining an ERROR slot resets it to EMPTY and clears its instance."""
    client, db, _, _ = _make_client()
    db.update_slot_state("slot001", SlotState.ERROR, instance_id="i-bad")
    resp = client.post("/v1/admin/unquarantine", json={"slot_id": "slot001"})
    assert resp.status_code == 200
    assert resp.json()["state"] == "empty"
    row = db.get_slot("slot001")
    assert row is not None
    assert row["state"] == SlotState.EMPTY.value
    assert row["instance_id"] is None
def test_admin_unquarantine_invalid_state_returns_409() -> None:
    """Unquarantining a slot that is not in ERROR is rejected with invalid_state."""
    client, _, _, _ = _make_client()
    resp = client.post("/v1/admin/unquarantine", json={"slot_id": "slot001"})
    assert resp.status_code == 409
    body = resp.json()
    assert body["error"]["code"] == "invalid_state"
def test_admin_reconcile_now_not_configured_returns_503() -> None:
    """Without a reconcile_now callback the endpoint reports not_configured."""
    client, _, _, _ = _make_client()
    resp = client.post("/v1/admin/reconcile-now")
    assert resp.status_code == 503
    body = resp.json()
    assert body["error"]["code"] == "not_configured"
def test_admin_reconcile_now_success() -> None:
    """The endpoint invokes the injected callback and merges its result."""
    calls: list[bool] = []

    def _reconcile_now() -> dict[str, object]:
        calls.append(True)
        return {"triggered": True}

    client, _, _, _ = _make_client(reconcile_now=_reconcile_now)
    resp = client.post("/v1/admin/reconcile-now")
    assert resp.status_code == 200
    payload = resp.json()
    assert payload["status"] == "accepted"
    assert payload["triggered"] is True
    assert calls

View file

@ -130,6 +130,135 @@ class TestDescribeInstance:
assert info["tailscale_ip"] is None assert info["tailscale_ip"] is None
assert info["launch_time"] == launch_time.isoformat() assert info["launch_time"] == launch_time.isoformat()
    @patch.object(
        EC2Runtime,
        "_read_tailscale_status",
        return_value={
            "Peer": {
                "peer1": {
                    # HostName embeds the slot id ("slot001") — presumably the
                    # key EC2Runtime matches on; confirm against _read_tailscale_status callers.
                    "HostName": "nix-builder-slot001-i-running1",
                    "Online": True,
                    "TailscaleIPs": ["100.64.0.10"],
                }
            }
        },
    )
    def test_discovers_tailscale_ip_from_localapi(self, _mock_status):
        """A single online peer matching the instance's slot yields its Tailscale IP."""
        ec2_client = boto3.client("ec2", region_name="us-east-1")
        stubber = Stubber(ec2_client)
        launch_time = datetime(2026, 1, 15, 12, 30, 0, tzinfo=UTC)
        # Stubbed instance carries only the AutoscalerSlot tag — IP discovery
        # must come from the mocked Tailscale LocalAPI status.
        response = {
            "Reservations": [
                {
                    "Instances": [
                        {
                            "InstanceId": "i-running1",
                            "State": {"Code": 16, "Name": "running"},
                            "LaunchTime": launch_time,
                            "Tags": [{"Key": "AutoscalerSlot", "Value": "slot001"}],
                        }
                    ],
                }
            ],
        }
        stubber.add_response(
            "describe_instances",
            response,
            {"InstanceIds": ["i-running1"]},
        )
        runtime = _make_runtime(stubber, ec2_client)
        info = runtime.describe_instance("i-running1")
        assert info["tailscale_ip"] == "100.64.0.10"
    @patch.object(EC2Runtime, "_read_tailscale_status", return_value={"Peer": {}})
    def test_discovery_unavailable_returns_none(self, _mock_status):
        """With no Tailscale peers visible, tailscale_ip degrades to None."""
        ec2_client = boto3.client("ec2", region_name="us-east-1")
        stubber = Stubber(ec2_client)
        launch_time = datetime(2026, 1, 15, 12, 30, 0, tzinfo=UTC)
        response = {
            "Reservations": [
                {
                    "Instances": [
                        {
                            "InstanceId": "i-running1",
                            "State": {"Code": 16, "Name": "running"},
                            "LaunchTime": launch_time,
                            "Tags": [{"Key": "AutoscalerSlot", "Value": "slot001"}],
                        }
                    ],
                }
            ],
        }
        stubber.add_response(
            "describe_instances",
            response,
            {"InstanceIds": ["i-running1"]},
        )
        runtime = _make_runtime(stubber, ec2_client)
        info = runtime.describe_instance("i-running1")
        assert info["tailscale_ip"] is None
    @patch.object(
        EC2Runtime,
        "_read_tailscale_status",
        return_value={
            "Peer": {
                # Two online peers whose hostnames both reference slot001 —
                # the match is ambiguous, so no IP should be chosen.
                "peer1": {
                    "HostName": "nix-builder-slot001-old",
                    "Online": True,
                    "TailscaleIPs": ["100.64.0.10"],
                },
                "peer2": {
                    "HostName": "nix-builder-slot001-new",
                    "Online": True,
                    "TailscaleIPs": ["100.64.0.11"],
                },
            }
        },
    )
    def test_ambiguous_slot_match_returns_none(self, _mock_status):
        """Multiple peers matching the same slot yield tailscale_ip=None, not a guess."""
        ec2_client = boto3.client("ec2", region_name="us-east-1")
        stubber = Stubber(ec2_client)
        launch_time = datetime(2026, 1, 15, 12, 30, 0, tzinfo=UTC)
        response = {
            "Reservations": [
                {
                    "Instances": [
                        {
                            "InstanceId": "i-running1",
                            "State": {"Code": 16, "Name": "running"},
                            "LaunchTime": launch_time,
                            "Tags": [{"Key": "AutoscalerSlot", "Value": "slot001"}],
                        }
                    ],
                }
            ],
        }
        stubber.add_response(
            "describe_instances",
            response,
            {"InstanceIds": ["i-running1"]},
        )
        runtime = _make_runtime(stubber, ec2_client)
        info = runtime.describe_instance("i-running1")
        assert info["tailscale_ip"] is None
    def test_localapi_permission_error_returns_none(self):
        """A PermissionError connecting to the LocalAPI unix socket degrades to None."""
        ec2_client = boto3.client("ec2", region_name="us-east-1")
        runtime = EC2Runtime(_make_config(), _client=ec2_client)
        # Patch the socket connect itself so the error surfaces from inside
        # _read_tailscale_status rather than from a mocked return value.
        with patch(
            "nix_builder_autoscaler.runtime.ec2._UnixSocketHTTPConnection.connect",
            side_effect=PermissionError,
        ):
            assert runtime._read_tailscale_status() is None
def test_missing_instance_returns_terminated(self): def test_missing_instance_returns_terminated(self):
ec2_client = boto3.client("ec2", region_name="us-east-1") ec2_client = boto3.client("ec2", region_name="us-east-1")
stubber = Stubber(ec2_client) stubber = Stubber(ec2_client)

View file

@ -178,8 +178,7 @@
checkPhase = '' checkPhase = ''
runHook preCheck runHook preCheck
export HOME=$(mktemp -d) export HOME=$(mktemp -d)
# Exit code 5 means no tests collected — tolerate until integration tests are written pytest nix_builder_autoscaler/tests/integration/ -v
pytest nix_builder_autoscaler/tests/integration/ -v || test $? -eq 5
runHook postCheck runHook postCheck
''; '';
doCheck = true; doCheck = true;