add runtime adapters, scheduler, reconciler, and their unit tests
This commit is contained in:
parent
d1976a5fd8
commit
b63d69c81d
10 changed files with 1471 additions and 28 deletions
|
|
@ -1,11 +1,62 @@
|
|||
"""EC2 user-data template rendering — stub for Plan 02."""
|
||||
"""EC2 user-data template rendering for builder instance bootstrap.
|
||||
|
||||
The generated script follows the NixOS AMI pattern: write config files
|
||||
that existing systemd services (tailscale-autoconnect, nix-daemon) consume,
|
||||
rather than calling ``tailscale up`` directly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import textwrap
|
||||
|
||||
|
||||
def render_userdata(slot_id: str, region: str, ssm_param: str = "/nix-builder/ts-authkey") -> str:
    """Render a bash user-data script for builder instance bootstrap.

    The returned string is a complete shell script. On NixOS AMIs the script
    is executed by ``amazon-init.service``. The caller (EC2Runtime) passes it
    to ``run_instances`` as ``UserData``; boto3 base64-encodes automatically.

    The generated script follows the NixOS AMI pattern: write config files
    that existing systemd services (tailscale-autoconnect, nix-daemon)
    consume, rather than calling ``tailscale up`` directly.

    Args:
        slot_id: Autoscaler slot identifier (used as Tailscale hostname suffix).
        region: AWS region for SSM parameter lookup.
        ssm_param: SSM parameter path containing the Tailscale auth key.

    Returns:
        The rendered shell script as a single string.
    """
    return textwrap.dedent(f"""\
        #!/usr/bin/env bash
        set -euo pipefail

        SLOT_ID="{slot_id}"
        REGION="{region}"
        SSM_PARAM="{ssm_param}"

        # --- Fetch Tailscale auth key from SSM Parameter Store ---
        mkdir -p /run/credentials
        # Pre-create the key file with owner-only permissions so the secret
        # is never world-readable, even briefly (fixes write-then-chmod race).
        install -m 600 /dev/null /run/credentials/tailscale-auth-key
        TS_AUTHKEY=$(aws ssm get-parameter \\
          --region "$REGION" \\
          --with-decryption \\
          --name "$SSM_PARAM" \\
          --query 'Parameter.Value' \\
          --output text)
        printf '%s' "$TS_AUTHKEY" > /run/credentials/tailscale-auth-key

        # --- Write tailscale-autoconnect config ---
        mkdir -p /etc/tailscale
        cat > /etc/tailscale/autoconnect.conf <<TSCONF
        TS_AUTHKEY_FILE=/run/credentials/tailscale-auth-key
        TS_AUTHKEY_EPHEMERAL=true
        TS_AUTHKEY_PREAUTHORIZED=true
        TS_HOSTNAME=nix-builder-$SLOT_ID
        TS_EXTRA_ARGS="--ssh --advertise-tags=tag:nix-builder"
        TSCONF

        # --- Start/restart tailscale-autoconnect so it picks up the config ---
        systemctl restart tailscale-autoconnect.service || true

        # --- Ensure nix-daemon is running ---
        systemctl start nix-daemon.service || true

        # --- Signal readiness ---
        echo "ready" > /run/nix-builder-ready
        """)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
"""HAProxy runtime socket adapter — stub for Plan 02."""
|
||||
"""HAProxy runtime socket adapter for managing builder slots."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import io
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
|
|
@ -21,7 +24,12 @@ class SlotHealth:
|
|||
class HAProxyRuntime:
|
||||
"""HAProxy runtime CLI adapter via Unix socket.
|
||||
|
||||
Full implementation in Plan 02.
|
||||
Communicates with HAProxy using the admin socket text protocol.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the HAProxy admin Unix socket.
|
||||
backend: HAProxy backend name (e.g. "all").
|
||||
slot_prefix: Server name prefix used for builder slots.
|
||||
"""
|
||||
|
||||
def __init__(self, socket_path: str, backend: str, slot_prefix: str) -> None:
|
||||
|
|
@ -31,24 +39,76 @@ class HAProxyRuntime:
|
|||
|
||||
def set_slot_addr(self, slot_id: str, ip: str, port: int = 22) -> None:
|
||||
"""Update server address for a slot."""
|
||||
raise NotImplementedError
|
||||
cmd = f"set server {self._backend}/{slot_id} addr {ip} port {port}"
|
||||
resp = self._run(cmd)
|
||||
self._check_response(resp, slot_id)
|
||||
|
||||
def enable_slot(self, slot_id: str) -> None:
|
||||
"""Enable a server slot."""
|
||||
raise NotImplementedError
|
||||
cmd = f"enable server {self._backend}/{slot_id}"
|
||||
resp = self._run(cmd)
|
||||
self._check_response(resp, slot_id)
|
||||
|
||||
def disable_slot(self, slot_id: str) -> None:
|
||||
"""Disable a server slot."""
|
||||
raise NotImplementedError
|
||||
cmd = f"disable server {self._backend}/{slot_id}"
|
||||
resp = self._run(cmd)
|
||||
self._check_response(resp, slot_id)
|
||||
|
||||
def slot_is_up(self, slot_id: str) -> bool:
|
||||
"""Return True when HAProxy health status is UP for slot."""
|
||||
raise NotImplementedError
|
||||
health = self.read_slot_health()
|
||||
entry = health.get(slot_id)
|
||||
return entry is not None and entry.status == "UP"
|
||||
|
||||
def slot_session_count(self, slot_id: str) -> int:
|
||||
"""Return current active session count for slot."""
|
||||
raise NotImplementedError
|
||||
health = self.read_slot_health()
|
||||
entry = health.get(slot_id)
|
||||
if entry is None:
|
||||
raise HAProxyError(f"Slot not found in HAProxy stats: {slot_id}")
|
||||
return entry.scur
|
||||
|
||||
def read_slot_health(self) -> dict[str, SlotHealth]:
|
||||
"""Return full stats snapshot for all slots."""
|
||||
raise NotImplementedError
|
||||
"""Return full stats snapshot for all slots in the backend."""
|
||||
raw = self._run("show stat")
|
||||
reader = csv.DictReader(io.StringIO(raw))
|
||||
result: dict[str, SlotHealth] = {}
|
||||
for row in reader:
|
||||
pxname = row.get("# pxname", "").strip()
|
||||
svname = row.get("svname", "").strip()
|
||||
if pxname == self._backend and svname.startswith(self._slot_prefix):
|
||||
result[svname] = SlotHealth(
|
||||
status=row.get("status", "").strip(),
|
||||
scur=int(row.get("scur", "0")),
|
||||
qcur=int(row.get("qcur", "0")),
|
||||
)
|
||||
return result
|
||||
|
||||
def _run(self, command: str) -> str:
|
||||
"""Send a command to the HAProxy admin socket and return the response."""
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
try:
|
||||
sock.connect(self._socket_path)
|
||||
sock.sendall((command + "\n").encode())
|
||||
sock.shutdown(socket.SHUT_WR)
|
||||
chunks: list[bytes] = []
|
||||
while True:
|
||||
chunk = sock.recv(4096)
|
||||
if not chunk:
|
||||
break
|
||||
chunks.append(chunk)
|
||||
return b"".join(chunks).decode()
|
||||
except FileNotFoundError as e:
|
||||
raise HAProxyError(f"HAProxy socket not found: {self._socket_path}") from e
|
||||
except ConnectionRefusedError as e:
|
||||
raise HAProxyError(f"Connection refused to HAProxy socket: {self._socket_path}") from e
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
@staticmethod
|
||||
def _check_response(response: str, slot_id: str) -> None:
|
||||
"""Raise HAProxyError if the response indicates an error."""
|
||||
stripped = response.strip()
|
||||
if stripped.startswith(("No such", "Unknown")):
|
||||
raise HAProxyError(f"HAProxy error for {slot_id}: {stripped}")
|
||||
|
|
|
|||
|
|
@ -1 +1,258 @@
|
|||
"""Reconciler — stub for Plan 03."""
|
||||
"""Reconciler — advances slots through the state machine.
|
||||
|
||||
Each tick queries EC2 and HAProxy, then processes each slot according to
|
||||
its current state: launching→booting→binding→ready, with draining and
|
||||
terminating paths for teardown.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .models import SlotState
|
||||
from .providers.haproxy import HAProxyError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .config import AppConfig
|
||||
from .metrics import MetricsRegistry
|
||||
from .providers.clock import Clock
|
||||
from .providers.haproxy import HAProxyRuntime
|
||||
from .runtime.base import RuntimeAdapter
|
||||
from .state_db import StateDB
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Reconciler:
    """Advances slots through the state machine by polling EC2 and HAProxy.

    Maintains binding health-check counters between ticks.
    """

    def __init__(
        self,
        db: StateDB,
        runtime: RuntimeAdapter,
        haproxy: HAProxyRuntime,
        config: AppConfig,
        clock: Clock,
        metrics: MetricsRegistry,
    ) -> None:
        self._db = db
        self._runtime = runtime
        self._haproxy = haproxy
        self._config = config
        self._clock = clock
        self._metrics = metrics
        # slot_id -> consecutive UP observations while in BINDING; a slot is
        # promoted to READY only after check_ready_up_count consecutive UPs.
        self._binding_up_counts: dict[str, int] = {}

    def tick(self) -> None:
        """Execute one reconciliation tick."""
        t0 = time.monotonic()

        # 1. Query EC2 — a failure here degrades to an empty view rather than
        # aborting the tick; per-slot handlers still run.
        try:
            managed = self._runtime.list_managed_instances()
        except Exception:
            log.exception("ec2_list_failed")
            managed = []
        ec2_by_slot: dict[str, dict] = {}
        for inst in managed:
            sid = inst.get("slot_id")
            if sid:
                ec2_by_slot[sid] = inst

        # 2. Query HAProxy — likewise best-effort; an empty snapshot means
        # BINDING slots simply see "not UP" this tick.
        try:
            haproxy_health = self._haproxy.read_slot_health()
        except HAProxyError:
            log.warning("haproxy_stat_failed", exc_info=True)
            haproxy_health = {}

        # 3. Process each slot according to its current state.
        all_slots = self._db.list_slots()
        for slot in all_slots:
            state = slot["state"]
            if state == SlotState.LAUNCHING.value:
                self._handle_launching(slot)
            elif state == SlotState.BOOTING.value:
                self._handle_booting(slot)
            elif state == SlotState.BINDING.value:
                self._handle_binding(slot, haproxy_health)
            elif state == SlotState.READY.value:
                self._handle_ready(slot, ec2_by_slot)
            elif state == SlotState.DRAINING.value:
                self._handle_draining(slot)
            elif state == SlotState.TERMINATING.value:
                self._handle_terminating(slot, ec2_by_slot)

        # 4. Clean stale binding counters for slots that left BINDING.
        binding_ids = {s["slot_id"] for s in all_slots if s["state"] == SlotState.BINDING.value}
        stale = [k for k in self._binding_up_counts if k not in binding_ids]
        for k in stale:
            del self._binding_up_counts[k]

        # 5. Emit metrics
        tick_duration = time.monotonic() - t0
        self._update_metrics(tick_duration)

    def _handle_launching(self, slot: dict) -> None:
        """Check if launching instance has reached running state."""
        instance_id = slot["instance_id"]
        if not instance_id:
            # LAUNCHING without an instance id is inconsistent state — park in ERROR.
            self._db.update_slot_state(slot["slot_id"], SlotState.ERROR)
            return

        info = self._runtime.describe_instance(instance_id)
        ec2_state = info["state"]

        if ec2_state == "running":
            self._db.update_slot_state(slot["slot_id"], SlotState.BOOTING)
            log.info("slot_booting", extra={"slot_id": slot["slot_id"]})
        elif ec2_state in ("terminated", "shutting-down"):
            self._db.update_slot_state(slot["slot_id"], SlotState.ERROR)
            log.warning(
                "slot_launch_terminated",
                extra={"slot_id": slot["slot_id"], "ec2_state": ec2_state},
            )

    def _handle_booting(self, slot: dict) -> None:
        """Check if booting instance has a Tailscale IP yet."""
        instance_id = slot["instance_id"]
        if not instance_id:
            self._db.update_slot_state(slot["slot_id"], SlotState.ERROR)
            return

        info = self._runtime.describe_instance(instance_id)
        ec2_state = info["state"]

        if ec2_state in ("terminated", "shutting-down"):
            self._db.update_slot_state(slot["slot_id"], SlotState.ERROR)
            log.warning(
                "slot_boot_terminated",
                extra={"slot_id": slot["slot_id"], "ec2_state": ec2_state},
            )
            return

        tailscale_ip = info.get("tailscale_ip")
        if tailscale_ip is not None:
            # Record the IP and move to BINDING first; the HAProxy calls below
            # are best-effort and retried each tick by _handle_binding.
            self._db.update_slot_state(slot["slot_id"], SlotState.BINDING, instance_ip=tailscale_ip)
            try:
                self._haproxy.set_slot_addr(slot["slot_id"], tailscale_ip)
                self._haproxy.enable_slot(slot["slot_id"])
            except HAProxyError:
                log.warning(
                    "haproxy_binding_setup_failed",
                    extra={"slot_id": slot["slot_id"]},
                    exc_info=True,
                )

    def _handle_binding(self, slot: dict, haproxy_health: dict) -> None:
        """Check HAProxy health to determine when slot is ready."""
        slot_id = slot["slot_id"]
        health = haproxy_health.get(slot_id)

        if health is not None and health.status == "UP":
            # Require N consecutive UP observations before declaring READY.
            count = self._binding_up_counts.get(slot_id, 0) + 1
            self._binding_up_counts[slot_id] = count
            if count >= self._config.haproxy.check_ready_up_count:
                self._db.update_slot_state(slot_id, SlotState.READY)
                self._binding_up_counts.pop(slot_id, None)
                log.info("slot_ready", extra={"slot_id": slot_id})
        else:
            # Any non-UP observation resets the streak.
            self._binding_up_counts[slot_id] = 0
            # Retry HAProxy setup in case the earlier set/enable calls failed.
            ip = slot.get("instance_ip")
            if ip:
                try:
                    self._haproxy.set_slot_addr(slot_id, ip)
                    self._haproxy.enable_slot(slot_id)
                except HAProxyError:
                    pass

            # Check if instance is still alive
            instance_id = slot.get("instance_id")
            if instance_id:
                info = self._runtime.describe_instance(instance_id)
                if info["state"] in ("terminated", "shutting-down"):
                    self._db.update_slot_state(slot_id, SlotState.ERROR)
                    self._binding_up_counts.pop(slot_id, None)
                    log.warning(
                        "slot_binding_terminated",
                        extra={"slot_id": slot_id},
                    )

    def _handle_ready(self, slot: dict, ec2_by_slot: dict[str, dict]) -> None:
        """Verify EC2 instance is still alive for ready slots."""
        slot_id = slot["slot_id"]
        ec2_info = ec2_by_slot.get(slot_id)

        if ec2_info is None or ec2_info["state"] == "terminated":
            # Instance disappeared under us — clear its identity fields.
            self._db.update_slot_state(slot_id, SlotState.ERROR, instance_id=None, instance_ip=None)
            log.warning("slot_ready_instance_gone", extra={"slot_id": slot_id})
        elif ec2_info["state"] == "shutting-down":
            # Flag only; the scheduler moves flagged slots to DRAINING.
            self._db.update_slot_fields(slot_id, interruption_pending=1)
            log.info("slot_interruption_detected", extra={"slot_id": slot_id})

    def _handle_draining(self, slot: dict) -> None:
        """Disable HAProxy and terminate when drain conditions are met."""
        slot_id = slot["slot_id"]

        # Disable HAProxy (idempotent)
        with contextlib.suppress(HAProxyError):
            self._haproxy.disable_slot(slot_id)

        now = self._clock.now()
        last_change = datetime.fromisoformat(slot["last_state_change"])
        drain_duration = (now - last_change).total_seconds()

        # Terminate once all leases are released, or force after the timeout.
        drain_timeout = self._config.capacity.drain_timeout_seconds
        if slot["lease_count"] == 0 or drain_duration >= drain_timeout:
            instance_id = slot.get("instance_id")
            if instance_id:
                try:
                    self._runtime.terminate_instance(instance_id)
                    self._metrics.counter("autoscaler_ec2_terminate_total", {}, 1.0)
                except Exception:
                    # Transition anyway; _handle_terminating re-checks EC2 state.
                    log.warning(
                        "terminate_failed",
                        extra={"slot_id": slot_id, "instance_id": instance_id},
                        exc_info=True,
                    )
            self._db.update_slot_state(slot_id, SlotState.TERMINATING)
            log.info(
                "slot_terminating",
                extra={"slot_id": slot_id, "drain_duration": drain_duration},
            )

    def _handle_terminating(self, slot: dict, ec2_by_slot: dict[str, dict]) -> None:
        """Wait for EC2 to confirm termination, then reset slot to empty."""
        slot_id = slot["slot_id"]
        instance_id = slot.get("instance_id")

        if not instance_id:
            self._db.update_slot_state(
                slot_id, SlotState.EMPTY, instance_id=None, instance_ip=None, lease_count=0
            )
            log.info("slot_emptied", extra={"slot_id": slot_id})
            return

        info = self._runtime.describe_instance(instance_id)
        if info["state"] == "terminated":
            self._db.update_slot_state(
                slot_id, SlotState.EMPTY, instance_id=None, instance_ip=None, lease_count=0
            )
            log.info("slot_emptied", extra={"slot_id": slot_id})

    def _update_metrics(self, tick_duration: float) -> None:
        """Emit reconciler metrics."""
        summary = self._db.get_state_summary()
        for state, count in summary["slots"].items():
            if state == "total":
                continue
            self._metrics.gauge("autoscaler_slots", {"state": state}, float(count))
        self._metrics.histogram_observe("autoscaler_reconciler_tick_seconds", {}, tick_duration)
|
||||
|
|
|
|||
|
|
@ -1,32 +1,175 @@
|
|||
"""EC2 runtime adapter — stub for Plan 02."""
|
||||
"""EC2 runtime adapter for managing Spot instances."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from ..config import AwsConfig
|
||||
from .base import RuntimeAdapter
|
||||
from .base import RuntimeError as RuntimeAdapterError
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# EC2 ClientError code → normalized error category
# Categories are surfaced on RuntimeAdapterError.category so callers can
# react without parsing AWS error strings; unmapped codes become "unknown".
_ERROR_CATEGORIES: dict[str, str] = {
    "InsufficientInstanceCapacity": "capacity_unavailable",
    "SpotMaxPriceTooLow": "price_too_low",
    "RequestLimitExceeded": "throttled",
}

# Error codes retried with backoff by _call_with_backoff (throttling only;
# capacity/price failures are permanent for this attempt).
_RETRYABLE_CODES: frozenset[str] = frozenset({"RequestLimitExceeded"})
|
||||
|
||||
|
||||
class EC2Runtime(RuntimeAdapter):
    """EC2 Spot instance runtime adapter.

    Args:
        config: AWS configuration dataclass.
        environment: Environment tag value (e.g. ``"dev"``, ``"prod"``).
        _client: Optional pre-configured boto3 EC2 client (for testing).
    """

    def __init__(
        self,
        config: AwsConfig,
        environment: str = "dev",
        *,
        _client: Any = None,
    ) -> None:
        self._client: Any = _client or boto3.client("ec2", region_name=config.region)
        self._launch_template_id = config.launch_template_id
        self._subnet_ids = list(config.subnet_ids)
        # NOTE(review): _security_group_ids and _instance_profile_arn are
        # stored but not referenced by launch_spot — presumably supplied via
        # the launch template; confirm before removing.
        self._security_group_ids = list(config.security_group_ids)
        self._instance_profile_arn = config.instance_profile_arn
        self._environment = environment
        # Round-robin cursor over _subnet_ids to spread launches across AZs.
        self._subnet_index = 0

    def launch_spot(self, slot_id: str, user_data: str) -> str:
        """Launch a spot instance for *slot_id*. Return instance ID."""
        params: dict[str, Any] = {
            "MinCount": 1,
            "MaxCount": 1,
            "LaunchTemplate": {
                "LaunchTemplateId": self._launch_template_id,
                "Version": "$Latest",
            },
            # one-time + terminate: interrupted spot capacity is never restarted;
            # the reconciler detects termination and recycles the slot.
            "InstanceMarketOptions": {
                "MarketType": "spot",
                "SpotOptions": {
                    "SpotInstanceType": "one-time",
                    "InstanceInterruptionBehavior": "terminate",
                },
            },
            "UserData": user_data,
            # ManagedBy/AutoscalerSlot tags are how list_managed_instances
            # finds and maps instances back to slots.
            "TagSpecifications": [
                {
                    "ResourceType": "instance",
                    "Tags": [
                        {"Key": "Name", "Value": f"nix-builder-{slot_id}"},
                        {"Key": "AutoscalerSlot", "Value": slot_id},
                        {"Key": "ManagedBy", "Value": "nix-builder-autoscaler"},
                        {"Key": "Service", "Value": "nix-builder"},
                        {"Key": "Environment", "Value": self._environment},
                    ],
                }
            ],
        }

        if self._subnet_ids:
            subnet = self._subnet_ids[self._subnet_index % len(self._subnet_ids)]
            self._subnet_index += 1
            params["SubnetId"] = subnet

        resp = self._call_with_backoff(self._client.run_instances, **params)
        return resp["Instances"][0]["InstanceId"]

    def describe_instance(self, instance_id: str) -> dict:
        """Return normalized instance info dict.

        Keys: "state" (EC2 state name), "tailscale_ip", "launch_time"
        (ISO-8601 string or None). Unknown/missing instances are reported
        as terminated rather than raising.
        """
        try:
            resp = self._call_with_backoff(
                self._client.describe_instances, InstanceIds=[instance_id]
            )
        except RuntimeAdapterError:
            # NOTE(review): this also maps throttling/errors to "terminated" —
            # presumably intentional fail-safe; confirm.
            return {"state": "terminated", "tailscale_ip": None, "launch_time": None}

        reservations = resp.get("Reservations", [])
        if not reservations or not reservations[0].get("Instances"):
            return {"state": "terminated", "tailscale_ip": None, "launch_time": None}

        inst = reservations[0]["Instances"][0]
        launch_time = inst.get("LaunchTime")
        return {
            "state": inst["State"]["Name"],
            # tailscale_ip is always None here; EC2 does not know the tailnet
            # address — presumably filled in by another component. TODO confirm.
            "tailscale_ip": None,
            "launch_time": launch_time.isoformat() if launch_time else None,
        }

    def terminate_instance(self, instance_id: str) -> None:
        """Terminate the instance."""
        self._call_with_backoff(self._client.terminate_instances, InstanceIds=[instance_id])

    def list_managed_instances(self) -> list[dict]:
        """Return list of managed instances.

        Each entry has "instance_id", "state", and "slot_id" (from the
        AutoscalerSlot tag; None when absent). Already-terminated instances
        are excluded by the state filter.
        """
        resp = self._call_with_backoff(
            self._client.describe_instances,
            Filters=[
                {"Name": "tag:ManagedBy", "Values": ["nix-builder-autoscaler"]},
                {
                    "Name": "instance-state-name",
                    "Values": ["pending", "running", "shutting-down", "stopping"],
                },
            ],
        )

        result: list[dict] = []
        for reservation in resp.get("Reservations", []):
            for inst in reservation.get("Instances", []):
                tags = inst.get("Tags", [])
                result.append(
                    {
                        "instance_id": inst["InstanceId"],
                        "state": inst["State"]["Name"],
                        "slot_id": self._get_tag(tags, "AutoscalerSlot"),
                    }
                )
        return result

    def _call_with_backoff(self, fn: Any, *args: Any, max_retries: int = 3, **kwargs: Any) -> Any:
        """Call *fn* with exponential backoff and full jitter on retryable errors.

        Raises:
            RuntimeAdapterError: with a normalized ``category`` for any
                non-retryable ClientError or once retries are exhausted.
        """
        delay = 0.5
        for attempt in range(max_retries + 1):
            try:
                return fn(*args, **kwargs)
            except ClientError as e:
                code = e.response["Error"]["Code"]
                if code in _RETRYABLE_CODES and attempt < max_retries:
                    # Full jitter: sleep uniform(0, delay), capped at 10s.
                    jitter = random.uniform(0, min(delay, 10.0))
                    time.sleep(jitter)
                    delay *= 2
                    log.warning(
                        "Retryable EC2 error (attempt %d/%d): %s",
                        attempt + 1,
                        max_retries,
                        code,
                    )
                    continue
                category = _ERROR_CATEGORIES.get(code, "unknown")
                raise RuntimeAdapterError(str(e), category=category) from e

        # Unreachable — loop always returns or raises on every path
        msg = "Retries exhausted"
        raise RuntimeAdapterError(msg, category="unknown")

    @staticmethod
    def _get_tag(tags: list[dict[str, str]], key: str) -> str | None:
        """Extract a tag value from an EC2 tag list."""
        for tag in tags:
            if tag.get("Key") == key:
                return tag.get("Value")
        return None
|
||||
|
|
|
|||
|
|
@ -1 +1,267 @@
|
|||
"""Scheduler — stub for Plan 03."""
|
||||
"""Scheduler — stateless scheduling tick for the autoscaler.
|
||||
|
||||
Each tick: expire reservations, handle interruptions, assign pending
|
||||
reservations to ready slots, launch new capacity, maintain warm pool
|
||||
and min-slots, check idle scale-down, and emit metrics.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .bootstrap.userdata import render_userdata
|
||||
from .models import SlotState
|
||||
from .runtime.base import RuntimeError as RuntimeAdapterError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .config import AppConfig
|
||||
from .metrics import MetricsRegistry
|
||||
from .providers.clock import Clock
|
||||
from .runtime.base import RuntimeAdapter
|
||||
from .state_db import StateDB
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def scheduling_tick(
    db: StateDB,
    runtime: RuntimeAdapter,
    config: AppConfig,
    clock: Clock,
    metrics: MetricsRegistry,
) -> None:
    """Execute one scheduling tick.

    All dependencies are passed as arguments — no global state.
    """
    started = time.monotonic()

    # Drop reservations whose deadline has passed.
    expired = db.expire_reservations(clock.now())
    if expired:
        log.info("expired_reservations", extra={"count": len(expired), "ids": expired})

    # Drain slots flagged for spot interruption.
    _handle_interruptions(db)

    # Place waiting reservations onto ready capacity.
    _assign_reservations(db, config)

    # Bring up instances for demand that could not be placed.
    _launch_for_unmet_demand(db, runtime, config, metrics)

    # Top up to min_slots and the warm-pool target.
    _ensure_min_and_warm(db, runtime, config, metrics)

    # Retire capacity that has sat idle too long.
    _check_idle_scale_down(db, config, clock)

    # Refresh gauges and timing.
    _update_metrics(db, metrics, time.monotonic() - started)
|
||||
|
||||
|
||||
def _handle_interruptions(db: StateDB) -> None:
    """Move ready slots with interruption_pending to draining."""
    for slot in db.list_slots(SlotState.READY):
        if not slot["interruption_pending"]:
            continue
        db.update_slot_state(slot["slot_id"], SlotState.DRAINING, interruption_pending=0)
        log.info(
            "interruption_drain",
            extra={"slot_id": slot["slot_id"]},
        )
|
||||
|
||||
|
||||
def _assign_reservations(db: StateDB, config: AppConfig) -> None:
    """Assign pending reservations to ready slots with capacity."""
    from .models import ReservationPhase

    pending = db.list_reservations(ReservationPhase.PENDING)
    if not pending:
        return

    ready_slots = db.list_slots(SlotState.READY)
    if not ready_slots:
        return

    max_leases = config.capacity.max_leases_per_slot
    # In-memory lease tally so one tick never double-books a slot beyond
    # its capacity before the DB catches up.
    live_leases: dict[str, int] = {s["slot_id"]: s["lease_count"] for s in ready_slots}

    for resv in pending:
        slot = _find_assignable_slot(ready_slots, resv["system"], max_leases, live_leases)
        if slot is None:
            continue
        db.assign_reservation(resv["reservation_id"], slot["slot_id"], slot["instance_id"])
        live_leases[slot["slot_id"]] += 1
        log.info(
            "reservation_assigned",
            extra={
                "reservation_id": resv["reservation_id"],
                "slot_id": slot["slot_id"],
            },
        )
|
||||
|
||||
|
||||
def _find_assignable_slot(
|
||||
ready_slots: list[dict],
|
||||
system: str,
|
||||
max_leases: int,
|
||||
capacity_map: dict[str, int],
|
||||
) -> dict | None:
|
||||
"""Return first ready slot for system with remaining capacity, or None."""
|
||||
for slot in ready_slots:
|
||||
if slot["system"] != system:
|
||||
continue
|
||||
sid = slot["slot_id"]
|
||||
current: int = capacity_map[sid] if sid in capacity_map else slot["lease_count"]
|
||||
if current < max_leases:
|
||||
return slot
|
||||
return None
|
||||
|
||||
|
||||
def _count_active_slots(db: StateDB) -> int:
    """Count slots NOT in empty or error states."""
    inactive = (SlotState.EMPTY.value, SlotState.ERROR.value)
    return sum(1 for s in db.list_slots() if s["state"] not in inactive)
|
||||
|
||||
|
||||
def _launch_for_unmet_demand(
    db: StateDB,
    runtime: RuntimeAdapter,
    config: AppConfig,
    metrics: MetricsRegistry,
) -> None:
    """Launch new capacity for pending reservations that couldn't be assigned."""
    from .models import ReservationPhase

    pending = db.list_reservations(ReservationPhase.PENDING)
    if not pending:
        return

    active = _count_active_slots(db)
    if active >= config.capacity.max_slots:
        return

    empty_slots = db.list_slots(SlotState.EMPTY)
    if not empty_slots:
        return

    # One launch per still-pending reservation, capped by max_slots and by
    # how many empty slots exist.
    for idx, slot in enumerate(empty_slots):
        if idx >= len(pending) or active + idx >= config.capacity.max_slots:
            break
        _launch_slot(db, runtime, config, metrics, slot)
|
||||
|
||||
|
||||
def _launch_up_to(
    db: StateDB,
    runtime: RuntimeAdapter,
    config: AppConfig,
    metrics: MetricsRegistry,
    needed: int,
    active: int,
) -> int:
    """Launch up to *needed* empty slots without exceeding max_slots.

    Args:
        needed: How many additional slots are wanted.
        active: Current count of active (non-empty, non-error) slots, used
            to enforce the max_slots cap.

    Returns:
        The number of slots actually launched (attempted; a launch that
        fails still consumes a slot, which moves to ERROR).
    """
    launched = 0
    for slot in db.list_slots(SlotState.EMPTY):
        if launched >= needed or active + launched >= config.capacity.max_slots:
            break
        _launch_slot(db, runtime, config, metrics, slot)
        launched += 1
    return launched


def _ensure_min_and_warm(
    db: StateDB,
    runtime: RuntimeAdapter,
    config: AppConfig,
    metrics: MetricsRegistry,
) -> None:
    """Ensure minimum slots and warm pool targets are met.

    Both targets share the same capped launch loop (_launch_up_to); the
    min-slots pass runs first and its launches count toward the warm pool's
    max_slots budget.
    """
    active = _count_active_slots(db)

    # Ensure min_slots
    if active < config.capacity.min_slots:
        active += _launch_up_to(
            db, runtime, config, metrics, config.capacity.min_slots - active, active
        )

    # Ensure warm pool: idle-ready slots plus capacity already on its way
    # (launching/booting/binding) must reach target_warm_slots.
    if config.capacity.target_warm_slots > 0:
        ready_idle = sum(1 for s in db.list_slots(SlotState.READY) if s["lease_count"] == 0)
        pending_warm = (
            len(db.list_slots(SlotState.LAUNCHING))
            + len(db.list_slots(SlotState.BOOTING))
            + len(db.list_slots(SlotState.BINDING))
        )
        warm_total = ready_idle + pending_warm
        if warm_total < config.capacity.target_warm_slots:
            _launch_up_to(
                db, runtime, config, metrics,
                config.capacity.target_warm_slots - warm_total, active,
            )
|
||||
|
||||
|
||||
def _launch_slot(
    db: StateDB,
    runtime: RuntimeAdapter,
    config: AppConfig,
    metrics: MetricsRegistry,
    slot: dict,
) -> None:
    """Launch a single slot. Transition to LAUNCHING on success, ERROR on failure."""
    slot_id = slot["slot_id"]
    user_data = render_userdata(slot_id, config.aws.region)
    # Counts attempts, successful or not.
    metrics.counter("autoscaler_ec2_launch_total", {}, 1.0)
    try:
        instance_id = runtime.launch_spot(slot_id, user_data)
    except RuntimeAdapterError as exc:
        db.update_slot_state(slot_id, SlotState.ERROR)
        log.warning(
            "slot_launch_failed",
            extra={"slot_id": slot_id, "error": str(exc), "category": exc.category},
        )
        return
    db.update_slot_state(slot_id, SlotState.LAUNCHING, instance_id=instance_id)
    log.info("slot_launched", extra={"slot_id": slot_id, "instance_id": instance_id})
|
||||
|
||||
|
||||
def _check_idle_scale_down(db: StateDB, config: AppConfig, clock: Clock) -> None:
    """Move idle ready slots to draining when idle threshold exceeded."""
    ready_slots = db.list_slots(SlotState.READY)
    now = clock.now()
    active = _count_active_slots(db)

    for slot in ready_slots:
        # Only slots with no live leases are scale-down candidates.
        if slot["lease_count"] != 0:
            continue
        since_change = datetime.fromisoformat(slot["last_state_change"])
        idle_seconds = (now - since_change).total_seconds()
        if idle_seconds <= config.capacity.idle_scale_down_seconds:
            continue
        # Never drain below the configured floor.
        if active <= config.capacity.min_slots:
            continue
        db.update_slot_state(slot["slot_id"], SlotState.DRAINING)
        active -= 1
        log.info(
            "idle_scale_down",
            extra={"slot_id": slot["slot_id"], "idle_seconds": idle_seconds},
        )
|
||||
|
||||
|
||||
def _update_metrics(db: StateDB, metrics: MetricsRegistry, tick_duration: float) -> None:
|
||||
"""Refresh all gauge/counter/histogram values."""
|
||||
summary = db.get_state_summary()
|
||||
|
||||
for state, count in summary["slots"].items():
|
||||
if state == "total":
|
||||
continue
|
||||
metrics.gauge("autoscaler_slots", {"state": state}, float(count))
|
||||
|
||||
for phase, count in summary["reservations"].items():
|
||||
metrics.gauge("autoscaler_reservations", {"phase": phase}, float(count))
|
||||
|
||||
metrics.histogram_observe("autoscaler_scheduler_tick_seconds", {}, tick_duration)
|
||||
|
|
|
|||
|
|
@ -153,6 +153,47 @@ class StateDB:
|
|||
self._conn.execute("ROLLBACK")
|
||||
raise
|
||||
|
||||
def update_slot_fields(self, slot_id: str, **fields: object) -> None:
    """Update selected slot columns in place.

    Unlike ``update_slot_state`` this never touches ``state`` or
    ``last_state_change``. Runs inside a BEGIN IMMEDIATE transaction and
    records a ``slot_fields_updated`` event alongside the UPDATE.

    Raises:
        ValueError: if any keyword is not a whitelisted column.
    """
    allowed = {
        "instance_id",
        "instance_ip",
        "instance_launch_time",
        "lease_count",
        "cooldown_until",
        "interruption_pending",
    }
    if not fields:
        return

    assignments: list[str] = []
    values: list[object] = []
    for column, value in fields.items():
        if column not in allowed:
            msg = f"Unknown slot field: {column}"
            raise ValueError(msg)
        assignments.append(f"{column} = ?")
        values.append(value)
    values.append(slot_id)

    stmt = f"UPDATE slots SET {', '.join(assignments)} WHERE slot_id = ?"
    self._conn.execute("BEGIN IMMEDIATE")
    try:
        self._conn.execute(stmt, values)
        self._record_event_inner(
            "slot_fields_updated",
            {"slot_id": slot_id, **fields},
        )
        self._conn.execute("COMMIT")
    except Exception:
        self._conn.execute("ROLLBACK")
        raise
|
||||
|
||||
# -- Reservation operations ---------------------------------------------
|
||||
|
||||
def create_reservation(
|
||||
|
|
|
|||
|
|
@ -1 +1,148 @@
|
|||
"""HAProxy provider unit tests — Plan 02."""
|
||||
"""Unit tests for the HAProxy provider, mocking at socket level."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nix_builder_autoscaler.providers.haproxy import HAProxyError, HAProxyRuntime
|
||||
|
||||
# HAProxy `show stat` CSV fixture — trimmed to the columns the parser uses.
# Real output carries many more columns; we keep everything through `status`
# (column index 17).
_STAT_ROWS = [
    "# pxname,svname,qcur,qmax,scur,smax,slim,stot,"
    "bin,bout,dreq,dresp,ereq,econ,eresp,wretr,wredis,status",
    "all,BACKEND,0,0,2,5,200,100,5000,6000,0,0,,0,0,0,0,UP",
    "all,slot001,0,0,1,3,50,50,2500,3000,0,0,,0,0,0,0,UP",
    "all,slot002,0,0,1,2,50,50,2500,3000,0,0,,0,0,0,0,DOWN",
    "all,slot003,0,0,0,0,50,0,0,0,0,0,,0,0,0,0,MAINT",
]
SHOW_STAT_CSV = "".join(row + "\n" for row in _STAT_ROWS)
|
||||
|
||||
|
||||
class TestSetSlotAddr:
    """`set_slot_addr` must issue a single `set server` command."""

    @patch("nix_builder_autoscaler.providers.haproxy.socket.socket")
    def test_sends_correct_command(self, socket_cls):
        sock = MagicMock()
        socket_cls.return_value = sock
        sock.recv.side_effect = [b"\n", b""]

        runtime = HAProxyRuntime("/tmp/test.sock", "all", "slot")
        runtime.set_slot_addr("slot001", "100.64.0.1", 22)

        sock.connect.assert_called_once_with("/tmp/test.sock")
        sock.sendall.assert_called_once_with(
            b"set server all/slot001 addr 100.64.0.1 port 22\n"
        )
|
||||
|
||||
|
||||
class TestEnableSlot:
    """`enable_slot` must issue a single `enable server` command."""

    @patch("nix_builder_autoscaler.providers.haproxy.socket.socket")
    def test_sends_correct_command(self, socket_cls):
        sock = MagicMock()
        socket_cls.return_value = sock
        sock.recv.side_effect = [b"\n", b""]

        HAProxyRuntime("/tmp/test.sock", "all", "slot").enable_slot("slot001")

        sock.sendall.assert_called_once_with(b"enable server all/slot001\n")


class TestDisableSlot:
    """`disable_slot` must issue a single `disable server` command."""

    @patch("nix_builder_autoscaler.providers.haproxy.socket.socket")
    def test_sends_correct_command(self, socket_cls):
        sock = MagicMock()
        socket_cls.return_value = sock
        sock.recv.side_effect = [b"\n", b""]

        HAProxyRuntime("/tmp/test.sock", "all", "slot").disable_slot("slot001")

        sock.sendall.assert_called_once_with(b"disable server all/slot001\n")
|
||||
|
||||
|
||||
class TestReadSlotHealth:
    """`read_slot_health` parses the stats CSV into per-slot records."""

    @patch("nix_builder_autoscaler.providers.haproxy.socket.socket")
    def test_parses_csv_correctly(self, socket_cls):
        sock = MagicMock()
        socket_cls.return_value = sock
        sock.recv.side_effect = [SHOW_STAT_CSV.encode(), b""]

        health = HAProxyRuntime("/tmp/test.sock", "all", "slot").read_slot_health()

        # The BACKEND aggregate row is excluded: "BACKEND" lacks the "slot" prefix.
        assert len(health) == 3
        expected = {
            "slot001": ("UP", 1, 0),
            "slot002": ("DOWN", 1, 0),
            "slot003": ("MAINT", 0, 0),
        }
        for slot_id, (status, scur, qcur) in expected.items():
            assert health[slot_id].status == status
            assert health[slot_id].scur == scur
            assert health[slot_id].qcur == qcur
|
||||
|
||||
|
||||
class TestSlotIsUp:
    """`slot_is_up` reflects the UP/DOWN status from the stats CSV."""

    @patch("nix_builder_autoscaler.providers.haproxy.socket.socket")
    def test_returns_true_for_up_slot(self, socket_cls):
        sock = MagicMock()
        socket_cls.return_value = sock
        sock.recv.side_effect = [SHOW_STAT_CSV.encode(), b""]

        runtime = HAProxyRuntime("/tmp/test.sock", "all", "slot")
        assert runtime.slot_is_up("slot001") is True

    @patch("nix_builder_autoscaler.providers.haproxy.socket.socket")
    def test_returns_false_for_down_slot(self, socket_cls):
        sock = MagicMock()
        socket_cls.return_value = sock
        sock.recv.side_effect = [SHOW_STAT_CSV.encode(), b""]

        runtime = HAProxyRuntime("/tmp/test.sock", "all", "slot")
        assert runtime.slot_is_up("slot002") is False
|
||||
|
||||
|
||||
class TestErrorHandling:
    """Socket and protocol failures all surface as HAProxyError."""

    @patch("nix_builder_autoscaler.providers.haproxy.socket.socket")
    def test_unrecognized_slot_raises_haproxy_error(self, socket_cls):
        sock = MagicMock()
        socket_cls.return_value = sock
        sock.recv.side_effect = [b"No such server.\n", b""]

        runtime = HAProxyRuntime("/tmp/test.sock", "all", "slot")
        with pytest.raises(HAProxyError, match="No such server"):
            runtime.set_slot_addr("slot999", "100.64.0.1")

    @patch("nix_builder_autoscaler.providers.haproxy.socket.socket")
    def test_socket_not_found_raises_haproxy_error(self, socket_cls):
        sock = MagicMock()
        socket_cls.return_value = sock
        sock.connect.side_effect = FileNotFoundError("No such file")

        runtime = HAProxyRuntime("/tmp/nonexistent.sock", "all", "slot")
        with pytest.raises(HAProxyError, match="socket not found"):
            runtime.set_slot_addr("slot001", "100.64.0.1")

    @patch("nix_builder_autoscaler.providers.haproxy.socket.socket")
    def test_connection_refused_raises_haproxy_error(self, socket_cls):
        sock = MagicMock()
        socket_cls.return_value = sock
        sock.connect.side_effect = ConnectionRefusedError("Connection refused")

        runtime = HAProxyRuntime("/tmp/test.sock", "all", "slot")
        with pytest.raises(HAProxyError, match="Connection refused"):
            runtime.enable_slot("slot001")

    @patch("nix_builder_autoscaler.providers.haproxy.socket.socket")
    def test_slot_session_count_missing_slot_raises(self, socket_cls):
        sock = MagicMock()
        socket_cls.return_value = sock
        sock.recv.side_effect = [SHOW_STAT_CSV.encode(), b""]

        runtime = HAProxyRuntime("/tmp/test.sock", "all", "slot")
        with pytest.raises(HAProxyError, match="Slot not found"):
            runtime.slot_session_count("slot_nonexistent")
|
||||
|
|
|
|||
|
|
@ -1 +1,286 @@
|
|||
"""EC2 runtime unit tests — Plan 02."""
|
||||
"""Unit tests for the EC2 runtime adapter using botocore Stubber."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import patch
|
||||
|
||||
import boto3
|
||||
import pytest
|
||||
from botocore.stub import Stubber
|
||||
|
||||
from nix_builder_autoscaler.config import AwsConfig
|
||||
from nix_builder_autoscaler.runtime.base import RuntimeError as RuntimeAdapterError
|
||||
from nix_builder_autoscaler.runtime.ec2 import EC2Runtime
|
||||
|
||||
|
||||
def _make_config():
    """Build an AwsConfig with two subnets so round-robin behaviour is testable."""
    kwargs = {
        "region": "us-east-1",
        "launch_template_id": "lt-abc123",
        "subnet_ids": ["subnet-aaa", "subnet-bbb"],
        "security_group_ids": ["sg-111"],
        "instance_profile_arn": "arn:aws:iam::123456789012:instance-profile/nix-builder",
    }
    return AwsConfig(**kwargs)
|
||||
|
||||
|
||||
def _make_runtime(stubber, ec2_client, **kwargs):
    """Activate *stubber* and return an EC2Runtime bound to the stubbed client."""
    env = kwargs.pop("environment", "dev")
    cfg = kwargs.pop("config", _make_config())
    stubber.activate()
    return EC2Runtime(cfg, environment=env, _client=ec2_client)
|
||||
|
||||
|
||||
class TestLaunchSpot:
    """launch_spot sends the expected RunInstances request."""

    def test_correct_params_and_returns_instance_id(self):
        config = _make_config()
        client = boto3.client("ec2", region_name="us-east-1")
        stub = Stubber(client)

        tags = [
            {"Key": "Name", "Value": "nix-builder-slot001"},
            {"Key": "AutoscalerSlot", "Value": "slot001"},
            {"Key": "ManagedBy", "Value": "nix-builder-autoscaler"},
            {"Key": "Service", "Value": "nix-builder"},
            {"Key": "Environment", "Value": "dev"},
        ]
        expected = {
            "MinCount": 1,
            "MaxCount": 1,
            "LaunchTemplate": {
                "LaunchTemplateId": "lt-abc123",
                "Version": "$Latest",
            },
            "InstanceMarketOptions": {
                "MarketType": "spot",
                "SpotOptions": {
                    "SpotInstanceType": "one-time",
                    "InstanceInterruptionBehavior": "terminate",
                },
            },
            "SubnetId": "subnet-aaa",
            "UserData": "#!/bin/bash\necho hello",
            "TagSpecifications": [{"ResourceType": "instance", "Tags": tags}],
        }
        stub.add_response(
            "run_instances",
            {"Instances": [{"InstanceId": "i-12345678"}], "OwnerId": "123456789012"},
            expected,
        )
        runtime = _make_runtime(stub, client, config=config)

        assert runtime.launch_spot("slot001", "#!/bin/bash\necho hello") == "i-12345678"
        stub.assert_no_pending_responses()

    def test_round_robin_subnets(self):
        client = boto3.client("ec2", region_name="us-east-1")
        stub = Stubber(client)

        # Two queued responses: launches should hit subnet-aaa then subnet-bbb.
        for _ in range(2):
            stub.add_response(
                "run_instances",
                {"Instances": [{"InstanceId": "i-abc"}], "OwnerId": "123"},
            )

        runtime = _make_runtime(stub, client, config=_make_config())
        runtime.launch_spot("slot001", "")
        runtime.launch_spot("slot002", "")
        stub.assert_no_pending_responses()
|
||||
|
||||
|
||||
class TestDescribeInstance:
    """describe_instance normalizes the raw EC2 response."""

    def test_normalizes_response(self):
        client = boto3.client("ec2", region_name="us-east-1")
        stub = Stubber(client)

        launch_time = datetime(2026, 1, 15, 12, 30, 0, tzinfo=UTC)
        instance = {
            "InstanceId": "i-running1",
            "State": {"Code": 16, "Name": "running"},
            "LaunchTime": launch_time,
            "Tags": [{"Key": "AutoscalerSlot", "Value": "slot001"}],
        }
        stub.add_response(
            "describe_instances",
            {"Reservations": [{"Instances": [instance]}]},
            {"InstanceIds": ["i-running1"]},
        )
        runtime = _make_runtime(stub, client)

        info = runtime.describe_instance("i-running1")
        assert info["state"] == "running"
        assert info["tailscale_ip"] is None
        assert info["launch_time"] == launch_time.isoformat()

    def test_missing_instance_returns_terminated(self):
        client = boto3.client("ec2", region_name="us-east-1")
        stub = Stubber(client)

        stub.add_response(
            "describe_instances",
            {"Reservations": []},
            {"InstanceIds": ["i-gone"]},
        )
        runtime = _make_runtime(stub, client)

        # An instance EC2 no longer reports is treated as terminated.
        info = runtime.describe_instance("i-gone")
        assert info["state"] == "terminated"
        assert info["tailscale_ip"] is None
        assert info["launch_time"] is None
|
||||
|
||||
|
||||
class TestListManagedInstances:
    """list_managed_instances filters by the ManagedBy tag and live states."""

    @staticmethod
    def _instance(instance_id, state_code, state_name, slot_id):
        """Build one EC2 instance dict carrying the autoscaler tags."""
        return {
            "InstanceId": instance_id,
            "State": {"Code": state_code, "Name": state_name},
            "Tags": [
                {"Key": "AutoscalerSlot", "Value": slot_id},
                {"Key": "ManagedBy", "Value": "nix-builder-autoscaler"},
            ],
        }

    def test_filters_by_tag(self):
        client = boto3.client("ec2", region_name="us-east-1")
        stub = Stubber(client)

        expected = {
            "Filters": [
                {"Name": "tag:ManagedBy", "Values": ["nix-builder-autoscaler"]},
                {
                    "Name": "instance-state-name",
                    "Values": ["pending", "running", "shutting-down", "stopping"],
                },
            ],
        }
        response = {
            "Reservations": [
                {
                    "Instances": [
                        self._instance("i-aaa", 16, "running", "slot001"),
                        self._instance("i-bbb", 0, "pending", "slot002"),
                    ],
                }
            ],
        }
        stub.add_response("describe_instances", response, expected)
        runtime = _make_runtime(stub, client)

        managed = runtime.list_managed_instances()
        assert len(managed) == 2
        summary = [(m["instance_id"], m["state"], m["slot_id"]) for m in managed]
        assert summary == [
            ("i-aaa", "running", "slot001"),
            ("i-bbb", "pending", "slot002"),
        ]
|
||||
|
||||
|
||||
class TestTerminateInstance:
    """terminate_instance forwards to the EC2 TerminateInstances API."""

    def test_calls_terminate_api(self):
        client = boto3.client("ec2", region_name="us-east-1")
        stub = Stubber(client)

        stub.add_response(
            "terminate_instances",
            {
                "TerminatingInstances": [
                    {
                        "InstanceId": "i-kill",
                        "CurrentState": {"Code": 32, "Name": "shutting-down"},
                        "PreviousState": {"Code": 16, "Name": "running"},
                    }
                ],
            },
            {"InstanceIds": ["i-kill"]},
        )
        runtime = _make_runtime(stub, client)

        # Must complete without raising.
        runtime.terminate_instance("i-kill")
        stub.assert_no_pending_responses()
|
||||
|
||||
|
||||
class TestErrorClassification:
    """AWS client errors map onto RuntimeAdapterError categories."""

    def test_insufficient_capacity_classified(self):
        client = boto3.client("ec2", region_name="us-east-1")
        stub = Stubber(client)

        stub.add_client_error(
            "run_instances",
            service_error_code="InsufficientInstanceCapacity",
            service_message="Insufficient capacity",
        )
        runtime = _make_runtime(stub, client)

        with pytest.raises(RuntimeAdapterError) as exc_info:
            runtime.launch_spot("slot001", "#!/bin/bash")
        assert exc_info.value.category == "capacity_unavailable"

    @patch("nix_builder_autoscaler.runtime.ec2.time.sleep")
    def test_request_limit_exceeded_retried(self, sleep_mock):
        client = boto3.client("ec2", region_name="us-east-1")
        stub = Stubber(client)

        # Throttled once, then success on the retry.
        stub.add_client_error(
            "run_instances",
            service_error_code="RequestLimitExceeded",
            service_message="Rate exceeded",
        )
        stub.add_response(
            "run_instances",
            {"Instances": [{"InstanceId": "i-retry123"}], "OwnerId": "123"},
        )
        runtime = _make_runtime(stub, client)

        assert runtime.launch_spot("slot001", "#!/bin/bash") == "i-retry123"
        assert sleep_mock.called
        stub.assert_no_pending_responses()

    @patch("nix_builder_autoscaler.runtime.ec2.time.sleep")
    def test_request_limit_exceeded_exhausted(self, sleep_mock):
        """After max retries, RequestLimitExceeded raises with 'throttled' category."""
        client = boto3.client("ec2", region_name="us-east-1")
        stub = Stubber(client)

        # max_retries=3 means attempts 0,1,2,3 can all fail: queue four errors.
        for _ in range(4):
            stub.add_client_error(
                "run_instances",
                service_error_code="RequestLimitExceeded",
                service_message="Rate exceeded",
            )
        runtime = _make_runtime(stub, client)

        with pytest.raises(RuntimeAdapterError) as exc_info:
            runtime.launch_spot("slot001", "#!/bin/bash")
        assert exc_info.value.category == "throttled"
|
||||
|
|
|
|||
|
|
@ -1 +1,194 @@
|
|||
"""Scheduler unit tests — Plan 03."""
|
||||
|
||||
from nix_builder_autoscaler.config import AppConfig, AwsConfig, CapacityConfig
|
||||
from nix_builder_autoscaler.metrics import MetricsRegistry
|
||||
from nix_builder_autoscaler.models import ReservationPhase, SlotState
|
||||
from nix_builder_autoscaler.providers.clock import FakeClock
|
||||
from nix_builder_autoscaler.runtime.fake import FakeRuntime
|
||||
from nix_builder_autoscaler.scheduler import scheduling_tick
|
||||
from nix_builder_autoscaler.state_db import StateDB
|
||||
|
||||
|
||||
def _make_env(
    slot_count=3,
    max_slots=3,
    max_leases=1,
    idle_scale_down_seconds=900,
    target_warm=0,
    min_slots=0,
):
    """Assemble an in-memory scheduler environment for one test.

    Returns (db, runtime, config, clock, metrics) wired together with a
    FakeClock and a FakeRuntime so ticks are fully deterministic.
    """
    clock = FakeClock()
    db = StateDB(":memory:", clock=clock)
    db.init_schema()
    db.init_slots("slot", slot_count, "x86_64-linux", "all")

    capacity = CapacityConfig(
        max_slots=max_slots,
        max_leases_per_slot=max_leases,
        idle_scale_down_seconds=idle_scale_down_seconds,
        target_warm_slots=target_warm,
        min_slots=min_slots,
        reservation_ttl_seconds=1200,
    )
    config = AppConfig(capacity=capacity, aws=AwsConfig(region="us-east-1"))
    runtime = FakeRuntime(launch_latency_ticks=2, ip_delay_ticks=1)
    return db, runtime, config, clock, MetricsRegistry()
|
||||
|
||||
|
||||
def _make_slot_ready(db, slot_id, instance_id="i-test1", ip="100.64.0.1"):
    """Walk a slot through the full LAUNCHING → BOOTING → BINDING → READY path."""
    transitions = [
        (SlotState.LAUNCHING, {"instance_id": instance_id}),
        (SlotState.BOOTING, {}),
        (SlotState.BINDING, {"instance_ip": ip}),
        (SlotState.READY, {}),
    ]
    for state, extra in transitions:
        db.update_slot_state(slot_id, state, **extra)
|
||||
|
||||
|
||||
# --- Test cases ---
|
||||
|
||||
|
||||
def test_pending_reservation_assigned_to_ready_slot():
    """A pending reservation binds to an available READY slot in one tick."""
    db, runtime, config, clock, metrics = _make_env()
    _make_slot_ready(db, "slot001")
    resv = db.create_reservation("x86_64-linux", "test", None, 1200)

    scheduling_tick(db, runtime, config, clock, metrics)

    after = db.get_reservation(resv["reservation_id"])
    assert after["phase"] == ReservationPhase.READY.value
    assert after["slot_id"] == "slot001"
    assert after["instance_id"] == "i-test1"
    assert db.get_slot("slot001")["lease_count"] == 1


def test_two_pending_one_slot_only_one_assigned_per_tick():
    """With max_leases=1, only one of two reservations gets the slot."""
    db, runtime, config, clock, metrics = _make_env(max_leases=1)
    _make_slot_ready(db, "slot001")
    ids = [
        db.create_reservation("x86_64-linux", name, None, 1200)["reservation_id"]
        for name in ("test1", "test2")
    ]

    scheduling_tick(db, runtime, config, clock, metrics)

    phases = [db.get_reservation(rid)["phase"] for rid in ids]
    assert phases.count(ReservationPhase.READY.value) == 1
    assert phases.count(ReservationPhase.PENDING.value) == 1
    assert db.get_slot("slot001")["lease_count"] == 1


def test_reservation_expires_when_ttl_passes():
    """A reservation past its TTL is moved to EXPIRED by the tick."""
    db, runtime, config, clock, metrics = _make_env()
    config.capacity.reservation_ttl_seconds = 60
    db.create_reservation("x86_64-linux", "test", None, 60)

    clock.advance(61)
    scheduling_tick(db, runtime, config, clock, metrics)

    assert len(db.list_reservations(ReservationPhase.EXPIRED)) == 1
|
||||
|
||||
|
||||
def test_scale_down_starts_when_idle_exceeds_threshold():
    """An idle READY slot past the threshold is put into DRAINING."""
    db, runtime, config, clock, metrics = _make_env(idle_scale_down_seconds=900)
    _make_slot_ready(db, "slot001")

    clock.advance(901)
    scheduling_tick(db, runtime, config, clock, metrics)

    assert db.get_slot("slot001")["state"] == SlotState.DRAINING.value


def test_slot_does_not_drain_while_lease_count_positive():
    """A leased slot never idles out, regardless of elapsed time."""
    db, runtime, config, clock, metrics = _make_env(idle_scale_down_seconds=900)
    _make_slot_ready(db, "slot001")

    resv = db.create_reservation("x86_64-linux", "test", None, 1200)
    scheduling_tick(db, runtime, config, clock, metrics)
    # Sanity check: the reservation took the slot's lease.
    after = db.get_reservation(resv["reservation_id"])
    assert after["phase"] == ReservationPhase.READY.value

    clock.advance(901)
    scheduling_tick(db, runtime, config, clock, metrics)

    assert db.get_slot("slot001")["state"] == SlotState.READY.value
|
||||
|
||||
|
||||
def test_interruption_pending_slot_moves_to_draining():
    """A spot-interruption flag drains the slot and is then cleared."""
    db, runtime, config, clock, metrics = _make_env()
    _make_slot_ready(db, "slot001")
    db.update_slot_fields("slot001", interruption_pending=1)

    scheduling_tick(db, runtime, config, clock, metrics)

    slot = db.get_slot("slot001")
    assert slot["state"] == SlotState.DRAINING.value
    # The flag is consumed once acted upon.
    assert slot["interruption_pending"] == 0


def test_launch_triggered_for_unmet_demand():
    """Unsatisfiable pending demand launches a new slot instance."""
    db, runtime, config, clock, metrics = _make_env()
    db.create_reservation("x86_64-linux", "test", None, 1200)

    scheduling_tick(db, runtime, config, clock, metrics)

    launching = db.list_slots(SlotState.LAUNCHING)
    assert len(launching) == 1
    assert launching[0]["instance_id"] is not None

    # The fake runtime should now track exactly one pending instance.
    assert len(runtime.list_managed_instances()) == 1
|
||||
|
||||
|
||||
def test_launch_respects_max_slots():
    """No new launch occurs once active slots equal max_slots."""
    db, runtime, config, clock, metrics = _make_env(max_slots=1)
    _make_slot_ready(db, "slot001")
    # slot001 will be fully leased after assignment (max_leases defaults to 1).
    for name in ("test1", "test2"):
        db.create_reservation("x86_64-linux", name, None, 1200)

    scheduling_tick(db, runtime, config, clock, metrics)

    # One reservation assigned, one still pending — but no new launch,
    # because active slots (1) already equals max_slots (1).
    assert len(db.list_slots(SlotState.LAUNCHING)) == 0


def test_min_slots_maintained():
    """The min_slots floor alone triggers a launch even with zero demand."""
    db, runtime, config, clock, metrics = _make_env(min_slots=1)

    scheduling_tick(db, runtime, config, clock, metrics)

    assert len(db.list_slots(SlotState.LAUNCHING)) == 1


def test_scale_down_respects_min_slots():
    """Idle scale-down never drains the pool below min_slots."""
    db, runtime, config, clock, metrics = _make_env(min_slots=1, idle_scale_down_seconds=900)
    _make_slot_ready(db, "slot001")

    clock.advance(901)
    scheduling_tick(db, runtime, config, clock, metrics)

    assert db.get_slot("slot001")["state"] == SlotState.READY.value
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ line-length = 100
|
|||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "UP", "B", "SIM", "ANN"]
|
||||
ignore = []
|
||||
ignore = ["ANN401"]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"*/tests/*" = ["ANN"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue