"""Reconciler — advances slots through the state machine. Each tick queries EC2 and HAProxy, then processes each slot according to its current state: launching→booting→binding→ready, with draining and terminating paths for teardown. """ from __future__ import annotations import contextlib import logging import time from datetime import datetime, timedelta from typing import TYPE_CHECKING from .models import SlotState from .providers.haproxy import HAProxyError if TYPE_CHECKING: from .config import AppConfig from .metrics import MetricsRegistry from .providers.clock import Clock from .providers.haproxy import HAProxyRuntime from .runtime.base import RuntimeAdapter from .state_db import StateDB log = logging.getLogger(__name__) _TERMINAL_OR_STOPPED_STATES = ("terminated", "shutting-down", "stopping", "stopped") class Reconciler: """Advances slots through the state machine by polling EC2 and HAProxy. Maintains binding health-check counters between ticks. """ def __init__( self, db: StateDB, runtime: RuntimeAdapter, haproxy: HAProxyRuntime, config: AppConfig, clock: Clock, metrics: MetricsRegistry, ) -> None: self._db = db self._runtime = runtime self._haproxy = haproxy self._config = config self._clock = clock self._metrics = metrics self._binding_up_counts: dict[str, int] = {} self._termination_cooldown_until: datetime | None = None def tick(self) -> None: """Execute one reconciliation tick.""" t0 = time.monotonic() # 1. Query EC2 try: managed = self._runtime.list_managed_instances() except Exception: log.exception("ec2_list_failed") managed = [] ec2_by_slot: dict[str, dict] = {} for inst in managed: sid = inst.get("slot_id") if sid: ec2_by_slot[sid] = inst # 2. Query HAProxy try: haproxy_health = self._haproxy_read_slot_health() except HAProxyError: log.warning("haproxy_stat_failed", exc_info=True) haproxy_health = {} # 3. Process each slot all_slots = self._db.list_slots() for slot in all_slots: state = slot["state"] if state == SlotState.LAUNCHING.value: self._handle_launching(slot) elif state == SlotState.BOOTING.value: self._handle_booting(slot) elif state == SlotState.BINDING.value: self._handle_binding(slot, haproxy_health) elif state == SlotState.READY.value: self._handle_ready(slot, ec2_by_slot) elif state == SlotState.DRAINING.value: self._handle_draining(slot) elif state == SlotState.TERMINATING.value: self._handle_terminating(slot, ec2_by_slot) # 4. Clean stale binding counters binding_ids = {s["slot_id"] for s in all_slots if s["state"] == SlotState.BINDING.value} stale = [k for k in self._binding_up_counts if k not in binding_ids] for k in stale: del self._binding_up_counts[k] # 5. Emit metrics tick_duration = time.monotonic() - t0 self._update_metrics(tick_duration) def _handle_launching(self, slot: dict) -> None: """Check if launching instance has reached running state.""" instance_id = slot["instance_id"] if not instance_id: self._db.update_slot_state(slot["slot_id"], SlotState.ERROR) return info = self._runtime.describe_instance(instance_id) ec2_state = info["state"] if ec2_state == "running": self._db.update_slot_state(slot["slot_id"], SlotState.BOOTING) log.info("slot_booting", extra={"slot_id": slot["slot_id"]}) return if ec2_state in _TERMINAL_OR_STOPPED_STATES: self._begin_termination( slot, reason="slot_launch_lost", extra={"ec2_state": ec2_state}, ) return age_seconds = self._slot_state_age_seconds(slot) if age_seconds >= self._config.capacity.launch_timeout_seconds: self._begin_termination( slot, reason="slot_launch_timeout", extra={ "ec2_state": ec2_state, "age_seconds": age_seconds, "timeout_seconds": self._config.capacity.launch_timeout_seconds, }, ) def _handle_booting(self, slot: dict) -> None: """Check if booting instance has a Tailscale IP yet.""" instance_id = slot["instance_id"] if not instance_id: self._db.update_slot_state(slot["slot_id"], SlotState.ERROR) return info = self._runtime.describe_instance(instance_id) ec2_state = info["state"] if ec2_state in _TERMINAL_OR_STOPPED_STATES: self._begin_termination( slot, reason="slot_boot_lost", extra={"ec2_state": ec2_state}, ) return tailscale_ip = info.get("tailscale_ip") if tailscale_ip is not None: self._db.update_slot_state(slot["slot_id"], SlotState.BINDING, instance_ip=tailscale_ip) try: self._haproxy_set_slot_addr(slot["slot_id"], tailscale_ip) self._haproxy_enable_slot(slot["slot_id"]) except HAProxyError: log.warning( "haproxy_binding_setup_failed", extra={"slot_id": slot["slot_id"]}, exc_info=True, ) return age_seconds = self._slot_state_age_seconds(slot) if age_seconds >= self._config.capacity.boot_timeout_seconds: self._begin_termination( slot, reason="slot_boot_timeout", extra={ "ec2_state": ec2_state, "age_seconds": age_seconds, "timeout_seconds": self._config.capacity.boot_timeout_seconds, }, ) def _handle_binding(self, slot: dict, haproxy_health: dict) -> None: """Check HAProxy health to determine when slot is ready.""" slot_id = slot["slot_id"] instance_id = slot.get("instance_id") if instance_id: info = self._runtime.describe_instance(instance_id) ec2_state = info["state"] if ec2_state in _TERMINAL_OR_STOPPED_STATES: self._begin_termination( slot, reason="slot_binding_lost", extra={"ec2_state": ec2_state}, ) self._binding_up_counts.pop(slot_id, None) return age_seconds = self._slot_state_age_seconds(slot) if age_seconds >= self._config.capacity.binding_timeout_seconds: self._begin_termination( slot, reason="slot_binding_timeout", extra={ "age_seconds": age_seconds, "timeout_seconds": self._config.capacity.binding_timeout_seconds, }, ) self._binding_up_counts.pop(slot_id, None) return health = haproxy_health.get(slot_id) if health is not None and health.status == "UP": count = self._binding_up_counts.get(slot_id, 0) + 1 self._binding_up_counts[slot_id] = count if count >= self._config.haproxy.check_ready_up_count: self._db.update_slot_state(slot_id, SlotState.READY) self._binding_up_counts.pop(slot_id, None) log.info("slot_ready", extra={"slot_id": slot_id}) else: self._binding_up_counts[slot_id] = 0 # Retry HAProxy setup ip = slot.get("instance_ip") if ip: try: self._haproxy_set_slot_addr(slot_id, ip) self._haproxy_enable_slot(slot_id) except HAProxyError: pass def _handle_ready(self, slot: dict, ec2_by_slot: dict[str, dict]) -> None: """Verify EC2 instance is still alive for ready slots.""" slot_id = slot["slot_id"] ec2_info = ec2_by_slot.get(slot_id) if ec2_info is None or ec2_info["state"] == "terminated": self._db.update_slot_state(slot_id, SlotState.ERROR, instance_id=None, instance_ip=None) log.warning("slot_ready_instance_gone", extra={"slot_id": slot_id}) elif ec2_info["state"] == "shutting-down": self._db.update_slot_fields(slot_id, interruption_pending=1) log.info("slot_interruption_detected", extra={"slot_id": slot_id}) def _handle_draining(self, slot: dict) -> None: """Disable HAProxy and terminate when drain conditions are met.""" slot_id = slot["slot_id"] # Disable HAProxy (idempotent) with contextlib.suppress(HAProxyError): self._haproxy_disable_slot(slot_id) now = self._clock.now() last_change = datetime.fromisoformat(slot["last_state_change"]) drain_duration = (now - last_change).total_seconds() drain_timeout = self._config.capacity.drain_timeout_seconds if slot["lease_count"] == 0 or drain_duration >= drain_timeout: if not self._can_start_termination(): return instance_id = slot.get("instance_id") if instance_id: try: self._runtime.terminate_instance(instance_id) self._metrics.counter( "autoscaler_ec2_terminate_total", {"result": "success"}, 1.0, ) except Exception: self._metrics.counter( "autoscaler_ec2_terminate_total", {"result": "error"}, 1.0, ) log.warning( "terminate_failed", extra={"slot_id": slot_id, "instance_id": instance_id}, exc_info=True, ) self._db.update_slot_state(slot_id, SlotState.TERMINATING) self._mark_termination_started() log.info( "slot_terminating", extra={"slot_id": slot_id, "drain_duration": drain_duration}, ) def _handle_terminating(self, slot: dict, ec2_by_slot: dict[str, dict]) -> None: """Wait for EC2 to confirm termination, then reset slot to empty.""" slot_id = slot["slot_id"] instance_id = slot.get("instance_id") if not instance_id: self._db.update_slot_state( slot_id, SlotState.EMPTY, instance_id=None, instance_ip=None, lease_count=0 ) log.info("slot_emptied", extra={"slot_id": slot_id}) return info = self._runtime.describe_instance(instance_id) state = info["state"] if state == "terminated": self._db.update_slot_state( slot_id, SlotState.EMPTY, instance_id=None, instance_ip=None, lease_count=0 ) log.info("slot_emptied", extra={"slot_id": slot_id}) return age_seconds = self._slot_state_age_seconds(slot) if age_seconds >= self._config.capacity.terminating_timeout_seconds: self._terminate_instance_best_effort(slot_id, instance_id) # Reset last_state_change after a retry so repeated retries are paced. self._db.update_slot_state(slot_id, SlotState.TERMINATING) log.warning( "slot_termination_timeout_retry", extra={ "slot_id": slot_id, "instance_id": instance_id, "ec2_state": state, "age_seconds": age_seconds, "timeout_seconds": self._config.capacity.terminating_timeout_seconds, }, ) def _slot_state_age_seconds(self, slot: dict) -> float: now = self._clock.now() last_change = datetime.fromisoformat(slot["last_state_change"]) return (now - last_change).total_seconds() def _begin_termination(self, slot: dict, reason: str, extra: dict | None = None) -> None: if not self._can_start_termination(): return slot_id = slot["slot_id"] instance_id = slot.get("instance_id") started_terminating = False if instance_id: self._terminate_instance_best_effort(slot_id, instance_id) self._db.update_slot_state(slot_id, SlotState.TERMINATING) started_terminating = True else: self._db.update_slot_state( slot_id, SlotState.EMPTY, instance_id=None, instance_ip=None, lease_count=0 ) payload = {"slot_id": slot_id} if instance_id: payload["instance_id"] = instance_id if extra: payload.update(extra) log.warning(reason, extra=payload) if started_terminating: self._mark_termination_started() def _can_start_termination(self) -> bool: cooldown_seconds = self._config.capacity.termination_cooldown_seconds if cooldown_seconds <= 0: return True if self._termination_cooldown_until is None: return True return self._clock.now() >= self._termination_cooldown_until def _mark_termination_started(self) -> None: cooldown_seconds = self._config.capacity.termination_cooldown_seconds if cooldown_seconds <= 0: self._termination_cooldown_until = None return self._termination_cooldown_until = self._clock.now() + timedelta(seconds=cooldown_seconds) def _terminate_instance_best_effort(self, slot_id: str, instance_id: str) -> None: try: self._runtime.terminate_instance(instance_id) self._metrics.counter( "autoscaler_ec2_terminate_total", {"result": "success"}, 1.0, ) except Exception: self._metrics.counter( "autoscaler_ec2_terminate_total", {"result": "error"}, 1.0, ) log.warning( "terminate_failed", extra={"slot_id": slot_id, "instance_id": instance_id}, exc_info=True, ) def _update_metrics(self, tick_duration: float) -> None: """Emit reconciler metrics.""" summary = self._db.get_state_summary() for state, count in summary["slots"].items(): self._metrics.gauge("autoscaler_slots_total", {"state": state}, float(count)) self._metrics.histogram_observe("autoscaler_reconcile_duration_seconds", {}, tick_duration) def _haproxy_set_slot_addr(self, slot_id: str, ip: str) -> None: try: self._haproxy.set_slot_addr(slot_id, ip) self._metrics.counter( "autoscaler_haproxy_command_total", {"cmd": "set_slot_addr", "result": "success"}, 1.0, ) except HAProxyError: self._metrics.counter( "autoscaler_haproxy_command_total", {"cmd": "set_slot_addr", "result": "error"}, 1.0, ) raise def _haproxy_enable_slot(self, slot_id: str) -> None: try: self._haproxy.enable_slot(slot_id) self._metrics.counter( "autoscaler_haproxy_command_total", {"cmd": "enable_slot", "result": "success"}, 1.0, ) except HAProxyError: self._metrics.counter( "autoscaler_haproxy_command_total", {"cmd": "enable_slot", "result": "error"}, 1.0, ) raise def _haproxy_disable_slot(self, slot_id: str) -> None: try: self._haproxy.disable_slot(slot_id) self._metrics.counter( "autoscaler_haproxy_command_total", {"cmd": "disable_slot", "result": "success"}, 1.0, ) except HAProxyError: self._metrics.counter( "autoscaler_haproxy_command_total", {"cmd": "disable_slot", "result": "error"}, 1.0, ) raise def _haproxy_read_slot_health(self) -> dict: try: health = self._haproxy.read_slot_health() self._metrics.counter( "autoscaler_haproxy_command_total", {"cmd": "show_stat", "result": "success"}, 1.0, ) return health except HAProxyError: self._metrics.counter( "autoscaler_haproxy_command_total", {"cmd": "show_stat", "result": "error"}, 1.0, ) raise