# nix-builder-autoscaler/agent/nix_builder_autoscaler/reconciler.py

"""Reconciler — advances slots through the state machine.

Each tick queries EC2 and HAProxy, then processes each slot according to
its current state: launching→booting→binding→ready on the way up, and
draining→terminating→empty for teardown. A launching, booting or binding
slot whose instance is lost, or that exceeds its configured timeout, is
sent straight to teardown.
"""
from __future__ import annotations

import contextlib
import logging
import time
from datetime import datetime
from typing import TYPE_CHECKING, Any

from .models import SlotState
from .providers.haproxy import HAProxyError

if TYPE_CHECKING:
    from collections.abc import Callable

    from .config import AppConfig
    from .metrics import MetricsRegistry
    from .providers.clock import Clock
    from .providers.haproxy import HAProxyRuntime
    from .runtime.base import RuntimeAdapter
    from .state_db import StateDB
log: logging.Logger = logging.getLogger(__name__)

# EC2 instance states treated as "instance lost": a launching, booting or
# binding slot whose instance reports one of these is torn down.
_TERMINAL_OR_STOPPED_STATES = (
    "terminated",
    "shutting-down",
    "stopping",
    "stopped",
)
class Reconciler:
    """Advance slots through the state machine by polling EC2 and HAProxy.

    Each :meth:`tick` lists the managed EC2 instances, reads per-slot HAProxy
    health, and dispatches every slot to the handler for its current state.
    Consecutive HAProxy "UP" observations for binding slots are counted across
    ticks in ``_binding_up_counts``.
    """

    def __init__(
        self,
        db: StateDB,
        runtime: RuntimeAdapter,
        haproxy: HAProxyRuntime,
        config: AppConfig,
        clock: Clock,
        metrics: MetricsRegistry,
    ) -> None:
        self._db = db
        self._runtime = runtime
        self._haproxy = haproxy
        self._config = config
        self._clock = clock
        self._metrics = metrics
        # slot_id -> consecutive ticks HAProxy reported the slot UP while binding.
        self._binding_up_counts: dict[str, int] = {}

    def tick(self) -> None:
        """Execute one reconciliation tick.

        Steps: list EC2 instances, read HAProxy health, reconcile every slot,
        drop binding counters for slots that were not binding, emit metrics.
        An exception while reconciling one slot is logged and does not stop
        the remaining slots, the counter cleanup, or the metrics.
        """
        t0 = time.monotonic()

        # 1. Query EC2. ``None`` records that the listing failed, so handlers
        #    can tell "instance missing" apart from "we could not look".
        ec2_by_slot: dict[str, dict] | None
        try:
            managed = self._runtime.list_managed_instances()
        except Exception:
            log.exception("ec2_list_failed")
            ec2_by_slot = None
        else:
            ec2_by_slot = {}
            for inst in managed:
                sid = inst.get("slot_id")
                if sid:
                    ec2_by_slot[sid] = inst

        # 2. Query HAProxy
        try:
            haproxy_health = self._haproxy_read_slot_health()
        except HAProxyError:
            log.warning("haproxy_stat_failed", exc_info=True)
            haproxy_health = {}

        # 3. Process each slot. Failures are isolated per slot (e.g. an
        #    exception from describe_instance) so one slot cannot stall the rest.
        all_slots = self._db.list_slots()
        for slot in all_slots:
            try:
                self._reconcile_slot(slot, ec2_by_slot, haproxy_health)
            except Exception:
                log.exception(
                    "slot_reconcile_failed",
                    extra={"slot_id": slot.get("slot_id")},
                )

        # 4. Clean stale binding counters (slots not binding at the start of this tick).
        binding_ids = {s["slot_id"] for s in all_slots if s["state"] == SlotState.BINDING.value}
        for stale_id in [k for k in self._binding_up_counts if k not in binding_ids]:
            del self._binding_up_counts[stale_id]

        # 5. Emit metrics
        self._update_metrics(time.monotonic() - t0)

    def _reconcile_slot(
        self,
        slot: dict,
        ec2_by_slot: dict[str, dict] | None,
        haproxy_health: dict,
    ) -> None:
        """Dispatch *slot* to the handler for its state; unlisted states are left as-is."""
        state = slot["state"]
        if state == SlotState.LAUNCHING.value:
            self._handle_launching(slot)
        elif state == SlotState.BOOTING.value:
            self._handle_booting(slot)
        elif state == SlotState.BINDING.value:
            self._handle_binding(slot, haproxy_health)
        elif state == SlotState.READY.value:
            self._handle_ready(slot, ec2_by_slot)
        elif state == SlotState.DRAINING.value:
            self._handle_draining(slot)
        elif state == SlotState.TERMINATING.value:
            self._handle_terminating(slot, ec2_by_slot)

    def _handle_launching(self, slot: dict) -> None:
        """Promote a launching slot to BOOTING once its instance is running.

        A slot without an instance id goes to ERROR. An instance in a terminal
        or stopped EC2 state, or one that has not reached ``running`` within
        ``capacity.launch_timeout_seconds``, is torn down.
        """
        slot_id = slot["slot_id"]
        instance_id = slot["instance_id"]
        if not instance_id:
            self._db.update_slot_state(slot_id, SlotState.ERROR)
            return

        ec2_state = self._runtime.describe_instance(instance_id)["state"]
        if ec2_state == "running":
            self._db.update_slot_state(slot_id, SlotState.BOOTING)
            log.info("slot_booting", extra={"slot_id": slot_id})
            return
        if ec2_state in _TERMINAL_OR_STOPPED_STATES:
            self._begin_termination(
                slot,
                reason="slot_launch_lost",
                extra={"ec2_state": ec2_state},
            )
            return

        age_seconds = self._slot_state_age_seconds(slot)
        timeout = self._config.capacity.launch_timeout_seconds
        if age_seconds >= timeout:
            self._begin_termination(
                slot,
                reason="slot_launch_timeout",
                extra={
                    "ec2_state": ec2_state,
                    "age_seconds": age_seconds,
                    "timeout_seconds": timeout,
                },
            )

    def _handle_booting(self, slot: dict) -> None:
        """Move a booting slot to BINDING once its instance reports a Tailscale IP.

        On transition the slot's HAProxy server address is set and enabled;
        HAProxy failures there are logged and retried by the binding handler.
        Lost instances and slots exceeding ``capacity.boot_timeout_seconds``
        are torn down.
        """
        slot_id = slot["slot_id"]
        instance_id = slot["instance_id"]
        if not instance_id:
            self._db.update_slot_state(slot_id, SlotState.ERROR)
            return

        info = self._runtime.describe_instance(instance_id)
        ec2_state = info["state"]
        if ec2_state in _TERMINAL_OR_STOPPED_STATES:
            self._begin_termination(
                slot,
                reason="slot_boot_lost",
                extra={"ec2_state": ec2_state},
            )
            return

        # An empty string is treated as "no IP yet", matching the truthiness
        # check the binding handler applies before retrying HAProxy setup.
        tailscale_ip = info.get("tailscale_ip")
        if tailscale_ip:
            self._db.update_slot_state(slot_id, SlotState.BINDING, instance_ip=tailscale_ip)
            try:
                self._haproxy_set_slot_addr(slot_id, tailscale_ip)
                self._haproxy_enable_slot(slot_id)
            except HAProxyError:
                log.warning(
                    "haproxy_binding_setup_failed",
                    extra={"slot_id": slot_id},
                    exc_info=True,
                )
            return

        age_seconds = self._slot_state_age_seconds(slot)
        timeout = self._config.capacity.boot_timeout_seconds
        if age_seconds >= timeout:
            self._begin_termination(
                slot,
                reason="slot_boot_timeout",
                extra={
                    "ec2_state": ec2_state,
                    "age_seconds": age_seconds,
                    "timeout_seconds": timeout,
                },
            )

    def _handle_binding(self, slot: dict, haproxy_health: dict) -> None:
        """Promote a binding slot to READY after enough consecutive HAProxy UP checks.

        Requires ``haproxy.check_ready_up_count`` consecutive UP observations.
        Lost instances and slots exceeding ``capacity.binding_timeout_seconds``
        are torn down. While the slot is not UP, the HAProxy address/enable
        commands are re-sent each tick.
        """
        slot_id = slot["slot_id"]
        instance_id = slot.get("instance_id")
        if instance_id:
            ec2_state = self._runtime.describe_instance(instance_id)["state"]
            if ec2_state in _TERMINAL_OR_STOPPED_STATES:
                self._begin_termination(
                    slot,
                    reason="slot_binding_lost",
                    extra={"ec2_state": ec2_state},
                )
                self._binding_up_counts.pop(slot_id, None)
                return

        age_seconds = self._slot_state_age_seconds(slot)
        timeout = self._config.capacity.binding_timeout_seconds
        if age_seconds >= timeout:
            self._begin_termination(
                slot,
                reason="slot_binding_timeout",
                extra={"age_seconds": age_seconds, "timeout_seconds": timeout},
            )
            self._binding_up_counts.pop(slot_id, None)
            return

        health = haproxy_health.get(slot_id)
        if health is not None and health.status == "UP":
            count = self._binding_up_counts.get(slot_id, 0) + 1
            self._binding_up_counts[slot_id] = count
            if count >= self._config.haproxy.check_ready_up_count:
                self._db.update_slot_state(slot_id, SlotState.READY)
                self._binding_up_counts.pop(slot_id, None)
                log.info("slot_ready", extra={"slot_id": slot_id})
            return

        # Not (yet) UP: restart the streak and re-send HAProxy setup in case an
        # earlier attempt failed.
        self._binding_up_counts[slot_id] = 0
        ip = slot.get("instance_ip")
        if ip:
            try:
                self._haproxy_set_slot_addr(slot_id, ip)
                self._haproxy_enable_slot(slot_id)
            except HAProxyError:
                # Best effort: the wrapper already counted the error and the
                # next tick retries; keep a trace instead of swallowing silently.
                log.debug(
                    "haproxy_binding_retry_failed",
                    extra={"slot_id": slot_id},
                    exc_info=True,
                )

    def _handle_ready(self, slot: dict, ec2_by_slot: dict[str, dict] | None) -> None:
        """Verify the EC2 instance backing a ready slot is still alive.

        Args:
            slot: Slot row; ``slot_id`` and, if present, ``interruption_pending``
                are read.
            ec2_by_slot: Managed instances keyed by slot id, or ``None`` when the
                EC2 listing failed this tick. In that case nothing is concluded,
                because a missing entry would be indistinguishable from a
                vanished instance.
        """
        if ec2_by_slot is None:
            return
        slot_id = slot["slot_id"]
        ec2_info = ec2_by_slot.get(slot_id)
        if ec2_info is None or ec2_info["state"] == "terminated":
            self._db.update_slot_state(slot_id, SlotState.ERROR, instance_id=None, instance_ip=None)
            log.warning("slot_ready_instance_gone", extra={"slot_id": slot_id})
        elif ec2_info["state"] == "shutting-down" and not slot.get("interruption_pending"):
            # Record the interruption once instead of rewriting it every tick.
            self._db.update_slot_fields(slot_id, interruption_pending=1)
            log.info("slot_interruption_detected", extra={"slot_id": slot_id})

    def _handle_draining(self, slot: dict) -> None:
        """Disable the slot in HAProxy, then tear it down once drained.

        The slot moves to TERMINATING (after a best-effort terminate request)
        when it has no leases or ``capacity.drain_timeout_seconds`` has elapsed
        since it entered the draining state.
        """
        slot_id = slot["slot_id"]
        # Disable HAProxy (idempotent); a failure is counted by the wrapper.
        with contextlib.suppress(HAProxyError):
            self._haproxy_disable_slot(slot_id)

        drain_duration = self._slot_state_age_seconds(slot)
        drain_timeout = self._config.capacity.drain_timeout_seconds
        if slot["lease_count"] != 0 and drain_duration < drain_timeout:
            return

        instance_id = slot.get("instance_id")
        if instance_id:
            self._terminate_instance_best_effort(slot_id, instance_id)
        self._db.update_slot_state(slot_id, SlotState.TERMINATING)
        log.info(
            "slot_terminating",
            extra={"slot_id": slot_id, "drain_duration": drain_duration},
        )

    def _handle_terminating(self, slot: dict, ec2_by_slot: dict[str, dict] | None) -> None:
        """Wait for EC2 to confirm termination, then reset the slot to EMPTY.

        If the instance is still not ``terminated`` after
        ``capacity.terminating_timeout_seconds``, termination is re-requested.
        ``ec2_by_slot`` is not consulted; the instance is queried directly.
        """
        slot_id = slot["slot_id"]
        instance_id = slot.get("instance_id")
        if not instance_id:
            self._reset_slot_to_empty(slot_id)
            log.info("slot_emptied", extra={"slot_id": slot_id})
            return

        state = self._runtime.describe_instance(instance_id)["state"]
        if state == "terminated":
            self._reset_slot_to_empty(slot_id)
            log.info("slot_emptied", extra={"slot_id": slot_id})
            return

        age_seconds = self._slot_state_age_seconds(slot)
        timeout = self._config.capacity.terminating_timeout_seconds
        if age_seconds >= timeout:
            self._terminate_instance_best_effort(slot_id, instance_id)
            # Re-write the state after a retry so last_state_change is reset and
            # repeated retries are paced by the timeout.
            self._db.update_slot_state(slot_id, SlotState.TERMINATING)
            log.warning(
                "slot_termination_timeout_retry",
                extra={
                    "slot_id": slot_id,
                    "instance_id": instance_id,
                    "ec2_state": state,
                    "age_seconds": age_seconds,
                    "timeout_seconds": timeout,
                },
            )

    def _reset_slot_to_empty(self, slot_id: str) -> None:
        """Return a slot to EMPTY, clearing its instance binding and lease count."""
        self._db.update_slot_state(
            slot_id, SlotState.EMPTY, instance_id=None, instance_ip=None, lease_count=0
        )

    def _slot_state_age_seconds(self, slot: dict) -> float:
        """Seconds elapsed between ``slot["last_state_change"]`` (ISO-8601) and the clock's now."""
        last_change = datetime.fromisoformat(slot["last_state_change"])
        return (self._clock.now() - last_change).total_seconds()

    def _begin_termination(self, slot: dict, reason: str, extra: dict | None = None) -> None:
        """Abandon a slot, logging a warning whose message is *reason*.

        With an instance: request termination (best effort) and move the slot to
        TERMINATING so the terminating handler confirms it. Without one: reset
        the slot directly to EMPTY. *extra* is merged into the log payload.
        """
        slot_id = slot["slot_id"]
        instance_id = slot.get("instance_id")
        payload: dict = {"slot_id": slot_id}
        if instance_id:
            self._terminate_instance_best_effort(slot_id, instance_id)
            self._db.update_slot_state(slot_id, SlotState.TERMINATING)
            payload["instance_id"] = instance_id
        else:
            self._reset_slot_to_empty(slot_id)
        if extra:
            payload.update(extra)
        log.warning(reason, extra=payload)

    def _terminate_instance_best_effort(self, slot_id: str, instance_id: str) -> None:
        """Request termination of *instance_id*; failures are counted and logged, never raised."""
        try:
            self._runtime.terminate_instance(instance_id)
            self._metrics.counter(
                "autoscaler_ec2_terminate_total",
                {"result": "success"},
                1.0,
            )
        except Exception:
            self._metrics.counter(
                "autoscaler_ec2_terminate_total",
                {"result": "error"},
                1.0,
            )
            log.warning(
                "terminate_failed",
                extra={"slot_id": slot_id, "instance_id": instance_id},
                exc_info=True,
            )

    def _update_metrics(self, tick_duration: float) -> None:
        """Emit per-state slot gauges and the reconcile-duration histogram."""
        summary = self._db.get_state_summary()
        # Zero-fill every known state so a state that currently has no slots is
        # reported as 0 even if the summary omits it (otherwise its gauge would
        # keep the last value it was set to).
        counts: dict = {state.value: 0 for state in SlotState}
        counts.update(summary["slots"])
        for state, count in counts.items():
            self._metrics.gauge("autoscaler_slots_total", {"state": state}, float(count))
        self._metrics.histogram_observe("autoscaler_reconcile_duration_seconds", {}, tick_duration)

    def _haproxy_command(self, cmd: str, call: Callable[..., Any], *args: str) -> Any:
        """Run one HAProxy runtime call and count its outcome.

        Increments ``autoscaler_haproxy_command_total`` labelled with *cmd* and
        ``result`` ("success" or "error"). ``HAProxyError`` is re-raised after
        being counted; the call's return value is passed through.
        """
        try:
            result = call(*args)
        except HAProxyError:
            self._metrics.counter(
                "autoscaler_haproxy_command_total",
                {"cmd": cmd, "result": "error"},
                1.0,
            )
            raise
        self._metrics.counter(
            "autoscaler_haproxy_command_total",
            {"cmd": cmd, "result": "success"},
            1.0,
        )
        return result

    def _haproxy_set_slot_addr(self, slot_id: str, ip: str) -> None:
        """Point the slot's HAProxy server at *ip* (metered; raises ``HAProxyError``)."""
        self._haproxy_command("set_slot_addr", self._haproxy.set_slot_addr, slot_id, ip)

    def _haproxy_enable_slot(self, slot_id: str) -> None:
        """Enable the slot's HAProxy server (metered; raises ``HAProxyError``)."""
        self._haproxy_command("enable_slot", self._haproxy.enable_slot, slot_id)

    def _haproxy_disable_slot(self, slot_id: str) -> None:
        """Disable the slot's HAProxy server (metered; raises ``HAProxyError``)."""
        self._haproxy_command("disable_slot", self._haproxy.disable_slot, slot_id)

    def _haproxy_read_slot_health(self) -> dict:
        """Return per-slot HAProxy health (metered as ``show_stat``; raises ``HAProxyError``)."""
        return self._haproxy_command("show_stat", self._haproxy.read_slot_health)