"""Reconciler — advances slots through the state machine.

Each tick queries EC2 and HAProxy, then processes each slot according to
its current state: launching→booting→binding→ready, with draining and
terminating paths for teardown.
"""
from __future__ import annotations

import contextlib
import logging
import time
from datetime import datetime
from typing import TYPE_CHECKING

from .models import SlotState
from .providers.haproxy import HAProxyError

if TYPE_CHECKING:
    from .config import AppConfig
    from .metrics import MetricsRegistry
    from .providers.clock import Clock
    from .providers.haproxy import HAProxyRuntime
    from .runtime.base import RuntimeAdapter
    from .state_db import StateDB
log = logging.getLogger(__name__)
|
|
_TERMINAL_OR_STOPPED_STATES = ("terminated", "shutting-down", "stopping", "stopped")
|
|
|
|
|
|
class Reconciler:
    """Advances slots through the state machine by polling EC2 and HAProxy.

    Each tick is read-then-act: query EC2 and HAProxy once, then advance
    every slot according to its current state. Maintains binding
    health-check counters between ticks.
    """

    def __init__(
        self,
        db: StateDB,
        runtime: RuntimeAdapter,
        haproxy: HAProxyRuntime,
        config: AppConfig,
        clock: Clock,
        metrics: MetricsRegistry,
    ) -> None:
        self._db = db
        self._runtime = runtime
        self._haproxy = haproxy
        self._config = config
        self._clock = clock
        self._metrics = metrics
        # slot_id -> consecutive HAProxy "UP" observations while BINDING.
        # Survives between ticks; entries are dropped once a slot leaves
        # the BINDING state (see tick step 4).
        self._binding_up_counts: dict[str, int] = {}

    def tick(self) -> None:
        """Execute one reconciliation tick."""
        t0 = time.monotonic()

        # 1. Query EC2. Best-effort: a failed listing degrades to "no data"
        # rather than aborting the tick, so per-slot handlers still run.
        try:
            managed = self._runtime.list_managed_instances()
        except Exception:
            log.exception("ec2_list_failed")
            managed = []
        # Index instances by the slot they are tagged with; untagged
        # instances are ignored.
        ec2_by_slot: dict[str, dict] = {
            sid: inst for inst in managed if (sid := inst.get("slot_id"))
        }

        # 2. Query HAProxy. Also best-effort; BINDING slots simply see no
        # health data this tick and their up-counters reset.
        try:
            haproxy_health = self._haproxy_read_slot_health()
        except HAProxyError:
            log.warning("haproxy_stat_failed", exc_info=True)
            haproxy_health = {}

        # 3. Process each slot according to its current state.
        all_slots = self._db.list_slots()
        for slot in all_slots:
            state = slot["state"]
            if state == SlotState.LAUNCHING.value:
                self._handle_launching(slot)
            elif state == SlotState.BOOTING.value:
                self._handle_booting(slot)
            elif state == SlotState.BINDING.value:
                self._handle_binding(slot, haproxy_health)
            elif state == SlotState.READY.value:
                self._handle_ready(slot, ec2_by_slot)
            elif state == SlotState.DRAINING.value:
                self._handle_draining(slot)
            elif state == SlotState.TERMINATING.value:
                self._handle_terminating(slot, ec2_by_slot)

        # 4. Clean stale binding counters for slots no longer in BINDING.
        binding_ids = {s["slot_id"] for s in all_slots if s["state"] == SlotState.BINDING.value}
        stale = [k for k in self._binding_up_counts if k not in binding_ids]
        for k in stale:
            del self._binding_up_counts[k]

        # 5. Emit metrics.
        tick_duration = time.monotonic() - t0
        self._update_metrics(tick_duration)

    def _handle_launching(self, slot: dict) -> None:
        """Check if launching instance has reached running state.

        Transitions: LAUNCHING -> BOOTING on "running"; teardown when the
        instance is lost or the launch timeout elapses.
        """
        instance_id = slot["instance_id"]
        if not instance_id:
            # A LAUNCHING slot without an instance id is inconsistent state.
            self._db.update_slot_state(slot["slot_id"], SlotState.ERROR)
            return

        info = self._runtime.describe_instance(instance_id)
        ec2_state = info["state"]

        if ec2_state == "running":
            self._db.update_slot_state(slot["slot_id"], SlotState.BOOTING)
            log.info("slot_booting", extra={"slot_id": slot["slot_id"]})
            return

        if ec2_state in _TERMINAL_OR_STOPPED_STATES:
            self._begin_termination(
                slot,
                reason="slot_launch_lost",
                extra={"ec2_state": ec2_state},
            )
            return

        age_seconds = self._slot_state_age_seconds(slot)
        if age_seconds >= self._config.capacity.launch_timeout_seconds:
            self._begin_termination(
                slot,
                reason="slot_launch_timeout",
                extra={
                    "ec2_state": ec2_state,
                    "age_seconds": age_seconds,
                    "timeout_seconds": self._config.capacity.launch_timeout_seconds,
                },
            )

    def _handle_booting(self, slot: dict) -> None:
        """Check if booting instance has a Tailscale IP yet.

        Transitions: BOOTING -> BINDING once a Tailscale IP appears (and
        HAProxy is pointed at it); teardown when the instance is lost or
        the boot timeout elapses.
        """
        instance_id = slot["instance_id"]
        if not instance_id:
            # Inconsistent state: BOOTING implies an instance exists.
            self._db.update_slot_state(slot["slot_id"], SlotState.ERROR)
            return

        info = self._runtime.describe_instance(instance_id)
        ec2_state = info["state"]

        if ec2_state in _TERMINAL_OR_STOPPED_STATES:
            self._begin_termination(
                slot,
                reason="slot_boot_lost",
                extra={"ec2_state": ec2_state},
            )
            return

        tailscale_ip = info.get("tailscale_ip")
        if tailscale_ip is not None:
            self._db.update_slot_state(slot["slot_id"], SlotState.BINDING, instance_ip=tailscale_ip)
            # HAProxy setup is best-effort here: _handle_binding retries it
            # on every tick while the slot is BINDING.
            try:
                self._haproxy_set_slot_addr(slot["slot_id"], tailscale_ip)
                self._haproxy_enable_slot(slot["slot_id"])
            except HAProxyError:
                log.warning(
                    "haproxy_binding_setup_failed",
                    extra={"slot_id": slot["slot_id"]},
                    exc_info=True,
                )
            return

        age_seconds = self._slot_state_age_seconds(slot)
        if age_seconds >= self._config.capacity.boot_timeout_seconds:
            self._begin_termination(
                slot,
                reason="slot_boot_timeout",
                extra={
                    "ec2_state": ec2_state,
                    "age_seconds": age_seconds,
                    "timeout_seconds": self._config.capacity.boot_timeout_seconds,
                },
            )

    def _handle_binding(self, slot: dict, haproxy_health: dict) -> None:
        """Check HAProxy health to determine when slot is ready.

        Requires ``check_ready_up_count`` consecutive "UP" observations
        before promoting to READY; any non-UP tick resets the counter and
        retries the HAProxy address/enable setup.
        """
        slot_id = slot["slot_id"]
        instance_id = slot.get("instance_id")
        if instance_id:
            info = self._runtime.describe_instance(instance_id)
            ec2_state = info["state"]
            if ec2_state in _TERMINAL_OR_STOPPED_STATES:
                self._begin_termination(
                    slot,
                    reason="slot_binding_lost",
                    extra={"ec2_state": ec2_state},
                )
                self._binding_up_counts.pop(slot_id, None)
                return

        age_seconds = self._slot_state_age_seconds(slot)
        if age_seconds >= self._config.capacity.binding_timeout_seconds:
            self._begin_termination(
                slot,
                reason="slot_binding_timeout",
                extra={
                    "age_seconds": age_seconds,
                    "timeout_seconds": self._config.capacity.binding_timeout_seconds,
                },
            )
            self._binding_up_counts.pop(slot_id, None)
            return

        health = haproxy_health.get(slot_id)

        if health is not None and health.status == "UP":
            count = self._binding_up_counts.get(slot_id, 0) + 1
            self._binding_up_counts[slot_id] = count
            if count >= self._config.haproxy.check_ready_up_count:
                self._db.update_slot_state(slot_id, SlotState.READY)
                self._binding_up_counts.pop(slot_id, None)
                log.info("slot_ready", extra={"slot_id": slot_id})
        else:
            self._binding_up_counts[slot_id] = 0
            # Retry HAProxy setup; failures are silent because this runs
            # every tick until the slot either binds or times out.
            ip = slot.get("instance_ip")
            if ip:
                try:
                    self._haproxy_set_slot_addr(slot_id, ip)
                    self._haproxy_enable_slot(slot_id)
                except HAProxyError:
                    pass

    def _handle_ready(self, slot: dict, ec2_by_slot: dict[str, dict]) -> None:
        """Verify EC2 instance is still alive for ready slots.

        A missing/terminated instance moves the slot to ERROR; a
        shutting-down instance only flags ``interruption_pending`` so
        other components can react (e.g. spot interruption handling).
        """
        slot_id = slot["slot_id"]
        ec2_info = ec2_by_slot.get(slot_id)

        if ec2_info is None or ec2_info["state"] == "terminated":
            self._db.update_slot_state(slot_id, SlotState.ERROR, instance_id=None, instance_ip=None)
            log.warning("slot_ready_instance_gone", extra={"slot_id": slot_id})
        elif ec2_info["state"] == "shutting-down":
            self._db.update_slot_fields(slot_id, interruption_pending=1)
            log.info("slot_interruption_detected", extra={"slot_id": slot_id})

    def _handle_draining(self, slot: dict) -> None:
        """Disable HAProxy and terminate when drain conditions are met.

        Termination starts once the slot has no active leases or the drain
        timeout has elapsed, whichever comes first.
        """
        slot_id = slot["slot_id"]

        # Disable HAProxy (idempotent; a failure is retried next tick).
        with contextlib.suppress(HAProxyError):
            self._haproxy_disable_slot(slot_id)

        drain_duration = self._slot_state_age_seconds(slot)
        drain_timeout = self._config.capacity.drain_timeout_seconds
        if slot["lease_count"] == 0 or drain_duration >= drain_timeout:
            instance_id = slot.get("instance_id")
            if instance_id:
                self._terminate_instance_best_effort(slot_id, instance_id)
            self._db.update_slot_state(slot_id, SlotState.TERMINATING)
            log.info(
                "slot_terminating",
                extra={"slot_id": slot_id, "drain_duration": drain_duration},
            )

    def _handle_terminating(self, slot: dict, ec2_by_slot: dict[str, dict]) -> None:
        """Wait for EC2 to confirm termination, then reset slot to empty."""
        slot_id = slot["slot_id"]
        instance_id = slot.get("instance_id")

        if not instance_id:
            # Nothing to wait for — reset straight to EMPTY.
            self._db.update_slot_state(
                slot_id, SlotState.EMPTY, instance_id=None, instance_ip=None, lease_count=0
            )
            log.info("slot_emptied", extra={"slot_id": slot_id})
            return

        info = self._runtime.describe_instance(instance_id)
        state = info["state"]
        if state == "terminated":
            self._db.update_slot_state(
                slot_id, SlotState.EMPTY, instance_id=None, instance_ip=None, lease_count=0
            )
            log.info("slot_emptied", extra={"slot_id": slot_id})
            return

        age_seconds = self._slot_state_age_seconds(slot)
        if age_seconds >= self._config.capacity.terminating_timeout_seconds:
            self._terminate_instance_best_effort(slot_id, instance_id)
            # Reset last_state_change after a retry so repeated retries are paced.
            self._db.update_slot_state(slot_id, SlotState.TERMINATING)
            log.warning(
                "slot_termination_timeout_retry",
                extra={
                    "slot_id": slot_id,
                    "instance_id": instance_id,
                    "ec2_state": state,
                    "age_seconds": age_seconds,
                    "timeout_seconds": self._config.capacity.terminating_timeout_seconds,
                },
            )

    def _slot_state_age_seconds(self, slot: dict) -> float:
        """Return seconds elapsed since the slot's last state change."""
        now = self._clock.now()
        last_change = datetime.fromisoformat(slot["last_state_change"])
        return (now - last_change).total_seconds()

    def _begin_termination(self, slot: dict, reason: str, extra: dict | None = None) -> None:
        """Start teardown of a slot and log *reason* with context.

        With an instance id: best-effort terminate, slot -> TERMINATING.
        Without one: slot goes straight to EMPTY.
        """
        slot_id = slot["slot_id"]
        instance_id = slot.get("instance_id")
        if instance_id:
            self._terminate_instance_best_effort(slot_id, instance_id)
            self._db.update_slot_state(slot_id, SlotState.TERMINATING)
        else:
            self._db.update_slot_state(
                slot_id, SlotState.EMPTY, instance_id=None, instance_ip=None, lease_count=0
            )

        payload = {"slot_id": slot_id}
        if instance_id:
            payload["instance_id"] = instance_id
        if extra:
            payload.update(extra)
        log.warning(reason, extra=payload)

    def _terminate_instance_best_effort(self, slot_id: str, instance_id: str) -> None:
        """Terminate *instance_id*, recording the outcome as a metric.

        Failures are logged but never raised — termination is retried by
        the TERMINATING-state timeout handling on later ticks.
        """
        try:
            self._runtime.terminate_instance(instance_id)
            self._metrics.counter(
                "autoscaler_ec2_terminate_total",
                {"result": "success"},
                1.0,
            )
        except Exception:
            self._metrics.counter(
                "autoscaler_ec2_terminate_total",
                {"result": "error"},
                1.0,
            )
            log.warning(
                "terminate_failed",
                extra={"slot_id": slot_id, "instance_id": instance_id},
                exc_info=True,
            )

    def _update_metrics(self, tick_duration: float) -> None:
        """Emit reconciler metrics (per-state slot gauges + tick duration)."""
        summary = self._db.get_state_summary()
        for state, count in summary["slots"].items():
            self._metrics.gauge("autoscaler_slots_total", {"state": state}, float(count))
        self._metrics.histogram_observe("autoscaler_reconcile_duration_seconds", {}, tick_duration)

    def _haproxy_set_slot_addr(self, slot_id: str, ip: str) -> None:
        """Point the HAProxy server for *slot_id* at *ip*; count the outcome."""
        try:
            self._haproxy.set_slot_addr(slot_id, ip)
            self._metrics.counter(
                "autoscaler_haproxy_command_total",
                {"cmd": "set_slot_addr", "result": "success"},
                1.0,
            )
        except HAProxyError:
            self._metrics.counter(
                "autoscaler_haproxy_command_total",
                {"cmd": "set_slot_addr", "result": "error"},
                1.0,
            )
            raise

    def _haproxy_enable_slot(self, slot_id: str) -> None:
        """Enable the HAProxy server for *slot_id*; count the outcome."""
        try:
            self._haproxy.enable_slot(slot_id)
            self._metrics.counter(
                "autoscaler_haproxy_command_total",
                {"cmd": "enable_slot", "result": "success"},
                1.0,
            )
        except HAProxyError:
            self._metrics.counter(
                "autoscaler_haproxy_command_total",
                {"cmd": "enable_slot", "result": "error"},
                1.0,
            )
            raise

    def _haproxy_disable_slot(self, slot_id: str) -> None:
        """Disable the HAProxy server for *slot_id*; count the outcome."""
        try:
            self._haproxy.disable_slot(slot_id)
            self._metrics.counter(
                "autoscaler_haproxy_command_total",
                {"cmd": "disable_slot", "result": "success"},
                1.0,
            )
        except HAProxyError:
            self._metrics.counter(
                "autoscaler_haproxy_command_total",
                {"cmd": "disable_slot", "result": "error"},
                1.0,
            )
            raise

    def _haproxy_read_slot_health(self) -> dict:
        """Read per-slot health from HAProxy; count the outcome and re-raise on error."""
        try:
            health = self._haproxy.read_slot_health()
            self._metrics.counter(
                "autoscaler_haproxy_command_total",
                {"cmd": "show_stat", "result": "success"},
                1.0,
            )
            return health
        except HAProxyError:
            self._metrics.counter(
                "autoscaler_haproxy_command_total",
                {"cmd": "show_stat", "result": "error"},
                1.0,
            )
            raise