From 3be933f16bf5a5bea908a91830354789bf97364e Mon Sep 17 00:00:00 2001 From: Abel Luck Date: Fri, 27 Feb 2026 15:40:39 +0100 Subject: [PATCH] add timeout safeguards for all slot lifecycle stages --- agent/nix_builder_autoscaler/config.py | 4 + agent/nix_builder_autoscaler/reconciler.py | 150 +++++++++++-- .../tests/test_reconciler.py | 197 ++++++++++++++++++ .../nixos/services/nix-builder-autoscaler.nix | 28 +++ 4 files changed, 356 insertions(+), 23 deletions(-) create mode 100644 agent/nix_builder_autoscaler/tests/test_reconciler.py diff --git a/agent/nix_builder_autoscaler/config.py b/agent/nix_builder_autoscaler/config.py index 7d25d43..b465d7a 100644 --- a/agent/nix_builder_autoscaler/config.py +++ b/agent/nix_builder_autoscaler/config.py @@ -64,6 +64,10 @@ class CapacityConfig: reservation_ttl_seconds: int = 1200 idle_scale_down_seconds: int = 900 drain_timeout_seconds: int = 120 + launch_timeout_seconds: int = 300 + boot_timeout_seconds: int = 300 + binding_timeout_seconds: int = 180 + terminating_timeout_seconds: int = 300 @dataclass diff --git a/agent/nix_builder_autoscaler/reconciler.py b/agent/nix_builder_autoscaler/reconciler.py index 448f92d..676c1b6 100644 --- a/agent/nix_builder_autoscaler/reconciler.py +++ b/agent/nix_builder_autoscaler/reconciler.py @@ -25,6 +25,7 @@ if TYPE_CHECKING: from .state_db import StateDB log = logging.getLogger(__name__) +_TERMINAL_OR_STOPPED_STATES = ("terminated", "shutting-down", "stopping", "stopped") class Reconciler: @@ -113,11 +114,26 @@ class Reconciler: if ec2_state == "running": self._db.update_slot_state(slot["slot_id"], SlotState.BOOTING) log.info("slot_booting", extra={"slot_id": slot["slot_id"]}) - elif ec2_state in ("terminated", "shutting-down"): - self._db.update_slot_state(slot["slot_id"], SlotState.ERROR) - log.warning( - "slot_launch_terminated", - extra={"slot_id": slot["slot_id"], "ec2_state": ec2_state}, + return + + if ec2_state in _TERMINAL_OR_STOPPED_STATES: + self._begin_termination( + slot, + reason="slot_launch_lost", + extra={"ec2_state": ec2_state}, + ) + return + + age_seconds = self._slot_state_age_seconds(slot) + if age_seconds >= self._config.capacity.launch_timeout_seconds: + self._begin_termination( + slot, + reason="slot_launch_timeout", + extra={ + "ec2_state": ec2_state, + "age_seconds": age_seconds, + "timeout_seconds": self._config.capacity.launch_timeout_seconds, + }, ) def _handle_booting(self, slot: dict) -> None: @@ -130,11 +146,11 @@ class Reconciler: info = self._runtime.describe_instance(instance_id) ec2_state = info["state"] - if ec2_state in ("terminated", "shutting-down"): - self._db.update_slot_state(slot["slot_id"], SlotState.ERROR) - log.warning( - "slot_boot_terminated", - extra={"slot_id": slot["slot_id"], "ec2_state": ec2_state}, + if ec2_state in _TERMINAL_OR_STOPPED_STATES: + self._begin_termination( + slot, + reason="slot_boot_lost", + extra={"ec2_state": ec2_state}, ) return @@ -150,10 +166,49 @@ class Reconciler: extra={"slot_id": slot["slot_id"]}, exc_info=True, ) + return + + age_seconds = self._slot_state_age_seconds(slot) + if age_seconds >= self._config.capacity.boot_timeout_seconds: + self._begin_termination( + slot, + reason="slot_boot_timeout", + extra={ + "ec2_state": ec2_state, + "age_seconds": age_seconds, + "timeout_seconds": self._config.capacity.boot_timeout_seconds, + }, + ) def _handle_binding(self, slot: dict, haproxy_health: dict) -> None: """Check HAProxy health to determine when slot is ready.""" slot_id = slot["slot_id"] + instance_id = slot.get("instance_id") + if instance_id: + info = self._runtime.describe_instance(instance_id) + ec2_state = info["state"] + if ec2_state in _TERMINAL_OR_STOPPED_STATES: + self._begin_termination( + slot, + reason="slot_binding_lost", + extra={"ec2_state": ec2_state}, + ) + self._binding_up_counts.pop(slot_id, None) + return + + age_seconds = self._slot_state_age_seconds(slot) + if age_seconds >= self._config.capacity.binding_timeout_seconds: + self._begin_termination( + slot, + reason="slot_binding_timeout", + extra={ + "age_seconds": age_seconds, + "timeout_seconds": self._config.capacity.binding_timeout_seconds, + }, + ) + self._binding_up_counts.pop(slot_id, None) + return + health = haproxy_health.get(slot_id) if health is not None and health.status == "UP": @@ -174,18 +229,6 @@ class Reconciler: except HAProxyError: pass - # Check if instance is still alive - instance_id = slot.get("instance_id") - if instance_id: - info = self._runtime.describe_instance(instance_id) - if info["state"] in ("terminated", "shutting-down"): - self._db.update_slot_state(slot_id, SlotState.ERROR) - self._binding_up_counts.pop(slot_id, None) - log.warning( - "slot_binding_terminated", - extra={"slot_id": slot_id}, - ) - def _handle_ready(self, slot: dict, ec2_by_slot: dict[str, dict]) -> None: """Verify EC2 instance is still alive for ready slots.""" slot_id = slot["slot_id"] @@ -251,11 +294,72 @@ class Reconciler: return info = self._runtime.describe_instance(instance_id) - if info["state"] == "terminated": + state = info["state"] + if state == "terminated": self._db.update_slot_state( slot_id, SlotState.EMPTY, instance_id=None, instance_ip=None, lease_count=0 ) log.info("slot_emptied", extra={"slot_id": slot_id}) + return + + age_seconds = self._slot_state_age_seconds(slot) + if age_seconds >= self._config.capacity.terminating_timeout_seconds: + self._terminate_instance_best_effort(slot_id, instance_id) + # Reset last_state_change after a retry so repeated retries are paced. + self._db.update_slot_state(slot_id, SlotState.TERMINATING) + log.warning( + "slot_termination_timeout_retry", + extra={ + "slot_id": slot_id, + "instance_id": instance_id, + "ec2_state": state, + "age_seconds": age_seconds, + "timeout_seconds": self._config.capacity.terminating_timeout_seconds, + }, + ) + + def _slot_state_age_seconds(self, slot: dict) -> float: + now = self._clock.now() + last_change = datetime.fromisoformat(slot["last_state_change"]) + return (now - last_change).total_seconds() + + def _begin_termination(self, slot: dict, reason: str, extra: dict | None = None) -> None: + slot_id = slot["slot_id"] + instance_id = slot.get("instance_id") + if instance_id: + self._terminate_instance_best_effort(slot_id, instance_id) + self._db.update_slot_state(slot_id, SlotState.TERMINATING) + else: + self._db.update_slot_state( + slot_id, SlotState.EMPTY, instance_id=None, instance_ip=None, lease_count=0 + ) + + payload = {"slot_id": slot_id} + if instance_id: + payload["instance_id"] = instance_id + if extra: + payload.update(extra) + log.warning(reason, extra=payload) + + def _terminate_instance_best_effort(self, slot_id: str, instance_id: str) -> None: + try: + self._runtime.terminate_instance(instance_id) + self._metrics.counter( + "autoscaler_ec2_terminate_total", + {"result": "success"}, + 1.0, + ) + except Exception: + self._metrics.counter( + "autoscaler_ec2_terminate_total", + {"result": "error"}, + 1.0, + ) + log.warning( + "terminate_failed", + extra={"slot_id": slot_id, "instance_id": instance_id}, + exc_info=True, + ) def _update_metrics(self, tick_duration: float) -> None: """Emit reconciler metrics.""" diff --git a/agent/nix_builder_autoscaler/tests/test_reconciler.py b/agent/nix_builder_autoscaler/tests/test_reconciler.py new file mode 100644 index 0000000..34af294 --- /dev/null +++ b/agent/nix_builder_autoscaler/tests/test_reconciler.py @@ -0,0 +1,197 @@ +"""Unit tests for reconciler timeout and failure safeguards.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from nix_builder_autoscaler.config import AppConfig, AwsConfig, CapacityConfig +from nix_builder_autoscaler.metrics import MetricsRegistry +from nix_builder_autoscaler.models import SlotState +from nix_builder_autoscaler.providers.clock import FakeClock +from nix_builder_autoscaler.providers.haproxy import SlotHealth +from nix_builder_autoscaler.reconciler import Reconciler +from nix_builder_autoscaler.state_db import StateDB + + +@dataclass +class _Instance: + state: str + slot_id: str + tailscale_ip: str | None = None + + +class _RuntimeStub: + def __init__(self) -> None: + self.instances: dict[str, _Instance] = {} + self.terminate_calls: list[str] = [] + + def list_managed_instances(self) -> list[dict]: + return [ + { + "instance_id": iid, + "state": inst.state, + "slot_id": inst.slot_id, + } + for iid, inst in self.instances.items() + if inst.state != "terminated" + ] + + def describe_instance(self, instance_id: str) -> dict: + inst = self.instances.get(instance_id) + if inst is None: + return {"state": "terminated", "tailscale_ip": None, "launch_time": None} + return {"state": inst.state, "tailscale_ip": inst.tailscale_ip, "launch_time": None} + + def terminate_instance(self, instance_id: str) -> None: + self.terminate_calls.append(instance_id) + if instance_id in self.instances: + self.instances[instance_id].state = "terminated" + + +class _HAProxyStub: + def __init__(self, health: dict[str, SlotHealth] | None = None) -> None: + self.health = health or {} + + def set_slot_addr(self, slot_id: str, ip: str, port: int = 22) -> None: # noqa: ARG002 + return + + def enable_slot(self, slot_id: str) -> None: # noqa: ARG002 + return + + def disable_slot(self, slot_id: str) -> None: # noqa: ARG002 + return + + def read_slot_health(self) -> dict[str, SlotHealth]: + return self.health + + +def _make_env( + *, + launch_timeout=300, + boot_timeout=300, + binding_timeout=180, + terminating_timeout=300, +): + clock = FakeClock() + db = StateDB(":memory:", clock=clock) + db.init_schema() + db.init_slots("slot", 1, "x86_64-linux", "all") + + runtime = _RuntimeStub() + haproxy = _HAProxyStub() + config = AppConfig( + capacity=CapacityConfig( + launch_timeout_seconds=launch_timeout, + boot_timeout_seconds=boot_timeout, + binding_timeout_seconds=binding_timeout, + terminating_timeout_seconds=terminating_timeout, + ), + aws=AwsConfig(region="us-east-1"), + ) + metrics = MetricsRegistry() + reconciler = Reconciler(db, runtime, haproxy, config, clock, metrics) + return db, runtime, reconciler, clock + + +def test_launching_timeout_moves_slot_to_terminating() -> None: + db, runtime, reconciler, clock = _make_env(launch_timeout=10) + runtime.instances["i-1"] = _Instance(state="pending", slot_id="slot001") + db.update_slot_state("slot001", SlotState.LAUNCHING, instance_id="i-1") + + clock.advance(11) + reconciler.tick() + + slot = db.get_slot("slot001") + assert slot is not None + assert slot["state"] == SlotState.TERMINATING.value + assert runtime.terminate_calls == ["i-1"] + + +def test_launching_stopped_state_begins_termination() -> None: + db, runtime, reconciler, _ = _make_env() + runtime.instances["i-1"] = _Instance(state="stopped", slot_id="slot001") + db.update_slot_state("slot001", SlotState.LAUNCHING, instance_id="i-1") + + reconciler.tick() + + slot = db.get_slot("slot001") + assert slot is not None + assert slot["state"] == SlotState.TERMINATING.value + assert runtime.terminate_calls == ["i-1"] + + +def test_booting_timeout_moves_slot_to_terminating() -> None: + db, runtime, reconciler, clock = _make_env(boot_timeout=15) + runtime.instances["i-2"] = _Instance(state="running", slot_id="slot001", tailscale_ip=None) + db.update_slot_state("slot001", SlotState.BOOTING, instance_id="i-2") + + clock.advance(16) + reconciler.tick() + + slot = db.get_slot("slot001") + assert slot is not None + assert slot["state"] == SlotState.TERMINATING.value + assert runtime.terminate_calls == ["i-2"] + + +def test_binding_timeout_moves_slot_to_terminating() -> None: + db, runtime, reconciler, clock = _make_env(binding_timeout=8) + runtime.instances["i-3"] = _Instance( + state="running", + slot_id="slot001", + tailscale_ip="100.64.0.3", + ) + db.update_slot_state( + "slot001", + SlotState.BINDING, + instance_id="i-3", + instance_ip="100.64.0.3", + ) + + clock.advance(9) + reconciler.tick() + + slot = db.get_slot("slot001") + assert slot is not None + assert slot["state"] == SlotState.TERMINATING.value + assert runtime.terminate_calls == ["i-3"] + + +def test_binding_stopped_state_begins_termination() -> None: + db, runtime, reconciler, _ = _make_env() + runtime.instances["i-4"] = _Instance( + state="stopping", + slot_id="slot001", + tailscale_ip="100.64.0.4", + ) + db.update_slot_state( + "slot001", + SlotState.BINDING, + instance_id="i-4", + instance_ip="100.64.0.4", + ) + + reconciler.tick() + + slot = db.get_slot("slot001") + assert slot is not None + assert slot["state"] == SlotState.TERMINATING.value + assert runtime.terminate_calls == ["i-4"] + + +def test_terminating_timeout_reissues_terminate_with_pacing() -> None: + db, runtime, reconciler, clock = _make_env(terminating_timeout=5) + runtime.instances["i-5"] = _Instance(state="shutting-down", slot_id="slot001") + db.update_slot_state("slot001", SlotState.TERMINATING, instance_id="i-5") + + clock.advance(6) + reconciler.tick() + + slot = db.get_slot("slot001") + assert slot is not None + assert slot["state"] == SlotState.TERMINATING.value + assert runtime.terminate_calls == ["i-5"] + + # Immediate next tick should not retry yet because last_state_change was refreshed. + reconciler.tick() + assert runtime.terminate_calls == ["i-5"] diff --git a/nix/modules/nixos/services/nix-builder-autoscaler.nix b/nix/modules/nixos/services/nix-builder-autoscaler.nix index 14b80dc..803b7a3 100644 --- a/nix/modules/nixos/services/nix-builder-autoscaler.nix +++ b/nix/modules/nixos/services/nix-builder-autoscaler.nix @@ -181,6 +181,30 @@ in description = "Drain timeout before force termination."; }; + launchTimeoutSeconds = lib.mkOption { + type = lib.types.int; + default = 300; + description = "Max seconds a slot may remain launching before forced termination."; + }; + + bootTimeoutSeconds = lib.mkOption { + type = lib.types.int; + default = 300; + description = "Max seconds a slot may remain booting before forced termination."; + }; + + bindingTimeoutSeconds = lib.mkOption { + type = lib.types.int; + default = 180; + description = "Max seconds a slot may remain binding before forced termination."; + }; + + terminatingTimeoutSeconds = lib.mkOption { + type = lib.types.int; + default = 300; + description = "Max seconds between terminate retries while slot is terminating."; + }; + launchBatchSize = lib.mkOption { type = lib.types.int; default = 1; @@ -301,6 +325,10 @@ in reservation_ttl_seconds = ${toString cfg.capacity.reservationTtlSeconds} idle_scale_down_seconds = ${toString cfg.capacity.idleScaleDownSeconds} drain_timeout_seconds = ${toString cfg.capacity.drainTimeoutSeconds} + launch_timeout_seconds = ${toString cfg.capacity.launchTimeoutSeconds} + boot_timeout_seconds = ${toString cfg.capacity.bootTimeoutSeconds} + binding_timeout_seconds = ${toString cfg.capacity.bindingTimeoutSeconds} + terminating_timeout_seconds = ${toString cfg.capacity.terminatingTimeoutSeconds} [security] socket_mode = "${cfg.security.socketMode}"