add runtime adapters, scheduler, reconciler, and their unit tests

This commit is contained in:
Abel Luck 2026-02-27 12:34:32 +01:00
parent d1976a5fd8
commit b63d69c81d
10 changed files with 1471 additions and 28 deletions

View file

@ -1,11 +1,62 @@
"""EC2 user-data template rendering — stub for Plan 02."""
"""EC2 user-data template rendering for builder instance bootstrap.
The generated script follows the NixOS AMI pattern: write config files
that existing systemd services (tailscale-autoconnect, nix-daemon) consume,
rather than calling ``tailscale up`` directly.
"""
from __future__ import annotations
import textwrap
def render_userdata(slot_id: str, region: str, ssm_param: str = "/nix-builder/ts-authkey") -> str:
    """Render a bash user-data script for builder instance bootstrap.

    The returned string is a complete shell script. On NixOS AMIs the script
    is executed by ``amazon-init.service``. The caller (EC2Runtime) passes it
    to ``run_instances`` as ``UserData``; boto3 base64-encodes automatically.

    Args:
        slot_id: Autoscaler slot identifier (used as Tailscale hostname suffix).
        region: AWS region for SSM parameter lookup.
        ssm_param: SSM parameter path containing the Tailscale auth key.

    Returns:
        The rendered user-data shell script.
    """
    # NOTE: the stale stub `raise NotImplementedError` has been removed — it
    # made the implementation below unreachable.
    return textwrap.dedent(f"""\
        #!/usr/bin/env bash
        set -euo pipefail

        SLOT_ID="{slot_id}"
        REGION="{region}"
        SSM_PARAM="{ssm_param}"

        # --- Fetch Tailscale auth key from SSM Parameter Store ---
        mkdir -p /run/credentials
        TS_AUTHKEY=$(aws ssm get-parameter \\
            --region "$REGION" \\
            --with-decryption \\
            --name "$SSM_PARAM" \\
            --query 'Parameter.Value' \\
            --output text)
        printf '%s' "$TS_AUTHKEY" > /run/credentials/tailscale-auth-key
        chmod 600 /run/credentials/tailscale-auth-key

        # --- Write tailscale-autoconnect config ---
        mkdir -p /etc/tailscale
        cat > /etc/tailscale/autoconnect.conf <<TSCONF
        TS_AUTHKEY_FILE=/run/credentials/tailscale-auth-key
        TS_AUTHKEY_EPHEMERAL=true
        TS_AUTHKEY_PREAUTHORIZED=true
        TS_HOSTNAME=nix-builder-$SLOT_ID
        TS_EXTRA_ARGS="--ssh --advertise-tags=tag:nix-builder"
        TSCONF

        # --- Start/restart tailscale-autoconnect so it picks up the config ---
        systemctl restart tailscale-autoconnect.service || true

        # --- Ensure nix-daemon is running ---
        systemctl start nix-daemon.service || true

        # --- Signal readiness ---
        echo "ready" > /run/nix-builder-ready
        """)

View file

@ -1,7 +1,10 @@
"""HAProxy runtime socket adapter — stub for Plan 02."""
"""HAProxy runtime socket adapter for managing builder slots."""
from __future__ import annotations
import csv
import io
import socket
from dataclasses import dataclass
@ -21,7 +24,12 @@ class SlotHealth:
class HAProxyRuntime:
"""HAProxy runtime CLI adapter via Unix socket.
Full implementation in Plan 02.
Communicates with HAProxy using the admin socket text protocol.
Args:
socket_path: Path to the HAProxy admin Unix socket.
backend: HAProxy backend name (e.g. "all").
slot_prefix: Server name prefix used for builder slots.
"""
def __init__(self, socket_path: str, backend: str, slot_prefix: str) -> None:
@ -31,24 +39,76 @@ class HAProxyRuntime:
def set_slot_addr(self, slot_id: str, ip: str, port: int = 22) -> None:
    """Update the address/port HAProxy routes to for *slot_id*.

    Removes the stale stub ``raise NotImplementedError`` that made the
    implementation unreachable.

    Raises:
        HAProxyError: If HAProxy rejects the command.
    """
    cmd = f"set server {self._backend}/{slot_id} addr {ip} port {port}"
    resp = self._run(cmd)
    self._check_response(resp, slot_id)
def enable_slot(self, slot_id: str) -> None:
    """Enable the HAProxy server slot *slot_id*.

    Removes the stale stub ``raise NotImplementedError`` that made the
    implementation unreachable.

    Raises:
        HAProxyError: If HAProxy rejects the command.
    """
    cmd = f"enable server {self._backend}/{slot_id}"
    resp = self._run(cmd)
    self._check_response(resp, slot_id)
def disable_slot(self, slot_id: str) -> None:
    """Disable the HAProxy server slot *slot_id*.

    Removes the stale stub ``raise NotImplementedError`` that made the
    implementation unreachable.

    Raises:
        HAProxyError: If HAProxy rejects the command.
    """
    cmd = f"disable server {self._backend}/{slot_id}"
    resp = self._run(cmd)
    self._check_response(resp, slot_id)
def slot_is_up(self, slot_id: str) -> bool:
    """Return True when HAProxy reports health status UP for *slot_id*.

    A slot missing from the stats snapshot is treated as not-up rather
    than an error. (Stale stub ``raise NotImplementedError`` removed.)
    """
    entry = self.read_slot_health().get(slot_id)
    return entry is not None and entry.status == "UP"
def slot_session_count(self, slot_id: str) -> int:
    """Return the current active session count (``scur``) for *slot_id*.

    (Stale stub ``raise NotImplementedError`` removed.)

    Raises:
        HAProxyError: If the slot does not appear in HAProxy stats.
    """
    entry = self.read_slot_health().get(slot_id)
    if entry is None:
        raise HAProxyError(f"Slot not found in HAProxy stats: {slot_id}")
    return entry.scur
def read_slot_health(self) -> dict[str, SlotHealth]:
    """Return a stats snapshot for all builder slots in the backend.

    Parses the ``show stat`` CSV output. The first column header is
    literally ``# pxname`` in HAProxy's output, hence the odd key.
    Numeric fields may be present but empty for some rows, so parse
    via ``or 0`` instead of a ``.get`` default — ``int("")`` raises
    ValueError. (Stale stub ``raise NotImplementedError`` removed.)
    """
    raw = self._run("show stat")
    reader = csv.DictReader(io.StringIO(raw))
    result: dict[str, SlotHealth] = {}
    for row in reader:
        pxname = (row.get("# pxname") or "").strip()
        svname = (row.get("svname") or "").strip()
        # Keep only real server rows for our backend; FRONTEND/BACKEND
        # aggregate rows don't match the slot prefix.
        if pxname == self._backend and svname.startswith(self._slot_prefix):
            result[svname] = SlotHealth(
                status=(row.get("status") or "").strip(),
                scur=int(row.get("scur") or 0),
                qcur=int(row.get("qcur") or 0),
            )
    return result
def _run(self, command: str) -> str:
    """Send *command* to the HAProxy admin socket and return the raw reply.

    Writes the command, half-closes the connection, then drains the
    response until EOF.

    Raises:
        HAProxyError: If the socket is missing or refuses the connection.
    """
    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        sock.connect(self._socket_path)
        sock.sendall(f"{command}\n".encode())
        # Half-close so HAProxy knows the command is complete.
        sock.shutdown(socket.SHUT_WR)
        buf = bytearray()
        while True:
            data = sock.recv(4096)
            if not data:
                break
            buf += data
        return bytes(buf).decode()
    except FileNotFoundError as e:
        raise HAProxyError(f"HAProxy socket not found: {self._socket_path}") from e
    except ConnectionRefusedError as e:
        raise HAProxyError(f"Connection refused to HAProxy socket: {self._socket_path}") from e
    finally:
        sock.close()
@staticmethod
def _check_response(response: str, slot_id: str) -> None:
    """Raise HAProxyError when *response* reports a command failure.

    HAProxy error replies begin with "No such ..." or "Unknown ...".
    """
    text = response.strip()
    for marker in ("No such", "Unknown"):
        if text.startswith(marker):
            raise HAProxyError(f"HAProxy error for {slot_id}: {text}")

View file

@ -1 +1,258 @@
"""Reconciler — stub for Plan 03."""
"""Reconciler — advances slots through the state machine.
Each tick queries EC2 and HAProxy, then processes each slot according to
its current state: launching → booting → binding → ready, with draining and
terminating paths for teardown.
"""
from __future__ import annotations
import contextlib
import logging
import time
from datetime import datetime
from typing import TYPE_CHECKING
from .models import SlotState
from .providers.haproxy import HAProxyError
if TYPE_CHECKING:
from .config import AppConfig
from .metrics import MetricsRegistry
from .providers.clock import Clock
from .providers.haproxy import HAProxyRuntime
from .runtime.base import RuntimeAdapter
from .state_db import StateDB
log = logging.getLogger(__name__)
class Reconciler:
    """Advances slots through the state machine by polling EC2 and HAProxy.

    Maintains binding health-check counters between ticks.
    """

    def __init__(
        self,
        db: StateDB,
        runtime: RuntimeAdapter,
        haproxy: HAProxyRuntime,
        config: AppConfig,
        clock: Clock,
        metrics: MetricsRegistry,
    ) -> None:
        self._db = db
        self._runtime = runtime
        self._haproxy = haproxy
        self._config = config
        self._clock = clock
        self._metrics = metrics
        # slot_id -> consecutive ticks the slot was observed UP while BINDING.
        self._binding_up_counts: dict[str, int] = {}

    def tick(self) -> None:
        """Execute one reconciliation tick."""
        t0 = time.monotonic()
        # 1. Query EC2 — a failed listing degrades to an empty view for this
        #    tick rather than aborting it.
        try:
            managed = self._runtime.list_managed_instances()
        except Exception:
            log.exception("ec2_list_failed")
            managed = []
        ec2_by_slot: dict[str, dict] = {}
        for inst in managed:
            sid = inst.get("slot_id")
            if sid:
                ec2_by_slot[sid] = inst
        # 2. Query HAProxy — likewise degrades to an empty health map.
        try:
            haproxy_health = self._haproxy.read_slot_health()
        except HAProxyError:
            log.warning("haproxy_stat_failed", exc_info=True)
            haproxy_health = {}
        # 3. Process each slot according to its current state.
        all_slots = self._db.list_slots()
        for slot in all_slots:
            state = slot["state"]
            if state == SlotState.LAUNCHING.value:
                self._handle_launching(slot)
            elif state == SlotState.BOOTING.value:
                self._handle_booting(slot)
            elif state == SlotState.BINDING.value:
                self._handle_binding(slot, haproxy_health)
            elif state == SlotState.READY.value:
                self._handle_ready(slot, ec2_by_slot)
            elif state == SlotState.DRAINING.value:
                self._handle_draining(slot)
            elif state == SlotState.TERMINATING.value:
                self._handle_terminating(slot, ec2_by_slot)
        # 4. Clean stale binding counters for slots that left BINDING.
        binding_ids = {s["slot_id"] for s in all_slots if s["state"] == SlotState.BINDING.value}
        stale = [k for k in self._binding_up_counts if k not in binding_ids]
        for k in stale:
            del self._binding_up_counts[k]
        # 5. Emit metrics
        tick_duration = time.monotonic() - t0
        self._update_metrics(tick_duration)

    def _handle_launching(self, slot: dict) -> None:
        """Check if launching instance has reached running state."""
        instance_id = slot["instance_id"]
        # A LAUNCHING slot with no instance_id is inconsistent state.
        if not instance_id:
            self._db.update_slot_state(slot["slot_id"], SlotState.ERROR)
            return
        info = self._runtime.describe_instance(instance_id)
        ec2_state = info["state"]
        if ec2_state == "running":
            self._db.update_slot_state(slot["slot_id"], SlotState.BOOTING)
            log.info("slot_booting", extra={"slot_id": slot["slot_id"]})
        elif ec2_state in ("terminated", "shutting-down"):
            self._db.update_slot_state(slot["slot_id"], SlotState.ERROR)
            log.warning(
                "slot_launch_terminated",
                extra={"slot_id": slot["slot_id"], "ec2_state": ec2_state},
            )

    def _handle_booting(self, slot: dict) -> None:
        """Check if booting instance has a Tailscale IP yet."""
        instance_id = slot["instance_id"]
        if not instance_id:
            self._db.update_slot_state(slot["slot_id"], SlotState.ERROR)
            return
        info = self._runtime.describe_instance(instance_id)
        ec2_state = info["state"]
        if ec2_state in ("terminated", "shutting-down"):
            self._db.update_slot_state(slot["slot_id"], SlotState.ERROR)
            log.warning(
                "slot_boot_terminated",
                extra={"slot_id": slot["slot_id"], "ec2_state": ec2_state},
            )
            return
        tailscale_ip = info.get("tailscale_ip")
        if tailscale_ip is not None:
            self._db.update_slot_state(slot["slot_id"], SlotState.BINDING, instance_ip=tailscale_ip)
            # Best-effort HAProxy wiring: _handle_binding retries on failure,
            # so a transient socket error here is only logged.
            try:
                self._haproxy.set_slot_addr(slot["slot_id"], tailscale_ip)
                self._haproxy.enable_slot(slot["slot_id"])
            except HAProxyError:
                log.warning(
                    "haproxy_binding_setup_failed",
                    extra={"slot_id": slot["slot_id"]},
                    exc_info=True,
                )

    def _handle_binding(self, slot: dict, haproxy_health: dict) -> None:
        """Check HAProxy health to determine when slot is ready."""
        slot_id = slot["slot_id"]
        health = haproxy_health.get(slot_id)
        if health is not None and health.status == "UP":
            # Require N consecutive UP observations before declaring READY.
            count = self._binding_up_counts.get(slot_id, 0) + 1
            self._binding_up_counts[slot_id] = count
            if count >= self._config.haproxy.check_ready_up_count:
                self._db.update_slot_state(slot_id, SlotState.READY)
                self._binding_up_counts.pop(slot_id, None)
                log.info("slot_ready", extra={"slot_id": slot_id})
        else:
            # Any non-UP observation resets the consecutive-UP streak.
            self._binding_up_counts[slot_id] = 0
            # Retry HAProxy setup
            ip = slot.get("instance_ip")
            if ip:
                try:
                    self._haproxy.set_slot_addr(slot_id, ip)
                    self._haproxy.enable_slot(slot_id)
                except HAProxyError:
                    pass
            # Check if instance is still alive
            instance_id = slot.get("instance_id")
            if instance_id:
                info = self._runtime.describe_instance(instance_id)
                if info["state"] in ("terminated", "shutting-down"):
                    self._db.update_slot_state(slot_id, SlotState.ERROR)
                    self._binding_up_counts.pop(slot_id, None)
                    log.warning(
                        "slot_binding_terminated",
                        extra={"slot_id": slot_id},
                    )

    def _handle_ready(self, slot: dict, ec2_by_slot: dict[str, dict]) -> None:
        """Verify EC2 instance is still alive for ready slots."""
        slot_id = slot["slot_id"]
        ec2_info = ec2_by_slot.get(slot_id)
        if ec2_info is None or ec2_info["state"] == "terminated":
            self._db.update_slot_state(slot_id, SlotState.ERROR, instance_id=None, instance_ip=None)
            log.warning("slot_ready_instance_gone", extra={"slot_id": slot_id})
        elif ec2_info["state"] == "shutting-down":
            # Flag only; the scheduler moves the slot to DRAINING next tick.
            self._db.update_slot_fields(slot_id, interruption_pending=1)
            log.info("slot_interruption_detected", extra={"slot_id": slot_id})

    def _handle_draining(self, slot: dict) -> None:
        """Disable HAProxy and terminate when drain conditions are met."""
        slot_id = slot["slot_id"]
        # Disable HAProxy (idempotent)
        with contextlib.suppress(HAProxyError):
            self._haproxy.disable_slot(slot_id)
        now = self._clock.now()
        last_change = datetime.fromisoformat(slot["last_state_change"])
        drain_duration = (now - last_change).total_seconds()
        drain_timeout = self._config.capacity.drain_timeout_seconds
        # Terminate once all leases are released, or force after the timeout.
        if slot["lease_count"] == 0 or drain_duration >= drain_timeout:
            instance_id = slot.get("instance_id")
            if instance_id:
                try:
                    self._runtime.terminate_instance(instance_id)
                    self._metrics.counter("autoscaler_ec2_terminate_total", {}, 1.0)
                except Exception:
                    # Proceed to TERMINATING regardless; the terminating
                    # handler re-checks the instance's actual EC2 state.
                    log.warning(
                        "terminate_failed",
                        extra={"slot_id": slot_id, "instance_id": instance_id},
                        exc_info=True,
                    )
            self._db.update_slot_state(slot_id, SlotState.TERMINATING)
            log.info(
                "slot_terminating",
                extra={"slot_id": slot_id, "drain_duration": drain_duration},
            )

    def _handle_terminating(self, slot: dict, ec2_by_slot: dict[str, dict]) -> None:
        """Wait for EC2 to confirm termination, then reset slot to empty."""
        slot_id = slot["slot_id"]
        instance_id = slot.get("instance_id")
        if not instance_id:
            self._db.update_slot_state(
                slot_id, SlotState.EMPTY, instance_id=None, instance_ip=None, lease_count=0
            )
            log.info("slot_emptied", extra={"slot_id": slot_id})
            return
        info = self._runtime.describe_instance(instance_id)
        if info["state"] == "terminated":
            self._db.update_slot_state(
                slot_id, SlotState.EMPTY, instance_id=None, instance_ip=None, lease_count=0
            )
            log.info("slot_emptied", extra={"slot_id": slot_id})

    def _update_metrics(self, tick_duration: float) -> None:
        """Emit reconciler metrics."""
        summary = self._db.get_state_summary()
        for state, count in summary["slots"].items():
            # "total" is an aggregate, not a state label.
            if state == "total":
                continue
            self._metrics.gauge("autoscaler_slots", {"state": state}, float(count))
        self._metrics.histogram_observe("autoscaler_reconciler_tick_seconds", {}, tick_duration)

View file

@ -1,32 +1,175 @@
"""EC2 runtime adapter — stub for Plan 02."""
"""EC2 runtime adapter for managing Spot instances."""
from __future__ import annotations
import logging
import random
import time
from typing import Any
import boto3
from botocore.exceptions import ClientError
from ..config import AwsConfig
from .base import RuntimeAdapter
from .base import RuntimeError as RuntimeAdapterError
log = logging.getLogger(__name__)
# EC2 ClientError code → normalized error category
_ERROR_CATEGORIES: dict[str, str] = {
"InsufficientInstanceCapacity": "capacity_unavailable",
"SpotMaxPriceTooLow": "price_too_low",
"RequestLimitExceeded": "throttled",
}
_RETRYABLE_CODES: frozenset[str] = frozenset({"RequestLimitExceeded"})
class EC2Runtime(RuntimeAdapter):
    """EC2 Spot instance runtime adapter.

    Removes merge artifacts: the stale stub ``__init__(region,
    launch_template_id)`` and the unreachable ``raise NotImplementedError``
    statements that preceded each real method body.

    Args:
        config: AWS configuration dataclass.
        environment: Environment tag value (e.g. ``"dev"``, ``"prod"``).
        _client: Optional pre-configured boto3 EC2 client (for testing).
    """

    def __init__(
        self,
        config: AwsConfig,
        environment: str = "dev",
        *,
        _client: Any = None,
    ) -> None:
        self._client: Any = _client or boto3.client("ec2", region_name=config.region)
        self._launch_template_id = config.launch_template_id
        self._subnet_ids = list(config.subnet_ids)
        self._security_group_ids = list(config.security_group_ids)
        self._instance_profile_arn = config.instance_profile_arn
        self._environment = environment
        # Round-robin cursor over subnet_ids so launches spread across subnets.
        self._subnet_index = 0

    def launch_spot(self, slot_id: str, user_data: str) -> str:
        """Launch a spot instance for *slot_id*. Return instance ID.

        Raises:
            RuntimeAdapterError: On any EC2 API failure (with a normalized
                ``category`` such as ``capacity_unavailable``).
        """
        params: dict[str, Any] = {
            "MinCount": 1,
            "MaxCount": 1,
            "LaunchTemplate": {
                "LaunchTemplateId": self._launch_template_id,
                "Version": "$Latest",
            },
            "InstanceMarketOptions": {
                "MarketType": "spot",
                "SpotOptions": {
                    "SpotInstanceType": "one-time",
                    "InstanceInterruptionBehavior": "terminate",
                },
            },
            "UserData": user_data,
            "TagSpecifications": [
                {
                    "ResourceType": "instance",
                    "Tags": [
                        {"Key": "Name", "Value": f"nix-builder-{slot_id}"},
                        {"Key": "AutoscalerSlot", "Value": slot_id},
                        {"Key": "ManagedBy", "Value": "nix-builder-autoscaler"},
                        {"Key": "Service", "Value": "nix-builder"},
                        {"Key": "Environment", "Value": self._environment},
                    ],
                }
            ],
        }
        if self._subnet_ids:
            # Rotate through subnets launch-by-launch.
            subnet = self._subnet_ids[self._subnet_index % len(self._subnet_ids)]
            self._subnet_index += 1
            params["SubnetId"] = subnet
        resp = self._call_with_backoff(self._client.run_instances, **params)
        return resp["Instances"][0]["InstanceId"]

    def describe_instance(self, instance_id: str) -> dict:
        """Return normalized instance info dict.

        Unknown/unfindable instances are reported as ``terminated`` so
        callers treat them as gone. NOTE(review): this also maps transient
        API errors (e.g. throttling past retries) to "terminated" — confirm
        that is intended.
        """
        try:
            resp = self._call_with_backoff(
                self._client.describe_instances, InstanceIds=[instance_id]
            )
        except RuntimeAdapterError:
            return {"state": "terminated", "tailscale_ip": None, "launch_time": None}
        reservations = resp.get("Reservations", [])
        if not reservations or not reservations[0].get("Instances"):
            return {"state": "terminated", "tailscale_ip": None, "launch_time": None}
        inst = reservations[0]["Instances"][0]
        launch_time = inst.get("LaunchTime")
        return {
            "state": inst["State"]["Name"],
            # Tailscale IP is not discoverable via EC2; filled in elsewhere.
            "tailscale_ip": None,
            "launch_time": launch_time.isoformat() if launch_time else None,
        }

    def terminate_instance(self, instance_id: str) -> None:
        """Terminate the instance."""
        self._call_with_backoff(self._client.terminate_instances, InstanceIds=[instance_id])

    def list_managed_instances(self) -> list[dict]:
        """Return all non-terminated instances tagged as managed by the autoscaler."""
        resp = self._call_with_backoff(
            self._client.describe_instances,
            Filters=[
                {"Name": "tag:ManagedBy", "Values": ["nix-builder-autoscaler"]},
                {
                    "Name": "instance-state-name",
                    "Values": ["pending", "running", "shutting-down", "stopping"],
                },
            ],
        )
        result: list[dict] = []
        for reservation in resp.get("Reservations", []):
            for inst in reservation.get("Instances", []):
                tags = inst.get("Tags", [])
                result.append(
                    {
                        "instance_id": inst["InstanceId"],
                        "state": inst["State"]["Name"],
                        "slot_id": self._get_tag(tags, "AutoscalerSlot"),
                    }
                )
        return result

    def _call_with_backoff(self, fn: Any, *args: Any, max_retries: int = 3, **kwargs: Any) -> Any:
        """Call *fn* with exponential backoff and full jitter on retryable errors.

        Raises:
            RuntimeAdapterError: With a normalized ``category`` on any
                non-retryable ClientError or once retries are exhausted.
        """
        delay = 0.5
        for attempt in range(max_retries + 1):
            try:
                return fn(*args, **kwargs)
            except ClientError as e:
                code = e.response["Error"]["Code"]
                if code in _RETRYABLE_CODES and attempt < max_retries:
                    # Full jitter, capped at 10s per sleep.
                    jitter = random.uniform(0, min(delay, 10.0))
                    time.sleep(jitter)
                    delay *= 2
                    log.warning(
                        "Retryable EC2 error (attempt %d/%d): %s",
                        attempt + 1,
                        max_retries,
                        code,
                    )
                    continue
                category = _ERROR_CATEGORIES.get(code, "unknown")
                raise RuntimeAdapterError(str(e), category=category) from e
        # Unreachable — loop always returns or raises on every path
        msg = "Retries exhausted"
        raise RuntimeAdapterError(msg, category="unknown")

    @staticmethod
    def _get_tag(tags: list[dict[str, str]], key: str) -> str | None:
        """Extract a tag value from an EC2 tag list, or None if absent."""
        for tag in tags:
            if tag.get("Key") == key:
                return tag.get("Value")
        return None

View file

@ -1 +1,267 @@
"""Scheduler — stub for Plan 03."""
"""Scheduler — stateless scheduling tick for the autoscaler.
Each tick: expire reservations, handle interruptions, assign pending
reservations to ready slots, launch new capacity, maintain warm pool
and min-slots, check idle scale-down, and emit metrics.
"""
from __future__ import annotations
import logging
import time
from datetime import datetime
from typing import TYPE_CHECKING
from .bootstrap.userdata import render_userdata
from .models import SlotState
from .runtime.base import RuntimeError as RuntimeAdapterError
if TYPE_CHECKING:
from .config import AppConfig
from .metrics import MetricsRegistry
from .providers.clock import Clock
from .runtime.base import RuntimeAdapter
from .state_db import StateDB
log = logging.getLogger(__name__)
def scheduling_tick(
    db: StateDB,
    runtime: RuntimeAdapter,
    config: AppConfig,
    clock: Clock,
    metrics: MetricsRegistry,
) -> None:
    """Execute one scheduling tick.

    All dependencies are passed as arguments — no global state.
    """
    t0 = time.monotonic()
    # 1. Expire old reservations
    expired = db.expire_reservations(clock.now())
    if expired:
        log.info("expired_reservations", extra={"count": len(expired), "ids": expired})
    # 2. Handle interruption-pending slots
    _handle_interruptions(db)
    # 3. Assign pending reservations to ready slots
    _assign_reservations(db, config)
    # 4. Launch new capacity for unmet demand
    _launch_for_unmet_demand(db, runtime, config, metrics)
    # 5. Ensure minimum slots and warm pool
    _ensure_min_and_warm(db, runtime, config, metrics)
    # 6. Check scale-down for idle slots
    _check_idle_scale_down(db, config, clock)
    # 7. Emit metrics
    tick_duration = time.monotonic() - t0
    _update_metrics(db, metrics, tick_duration)
def _handle_interruptions(db: StateDB) -> None:
    """Move ready slots flagged with interruption_pending into DRAINING."""
    for ready in db.list_slots(SlotState.READY):
        if not ready["interruption_pending"]:
            continue
        db.update_slot_state(ready["slot_id"], SlotState.DRAINING, interruption_pending=0)
        log.info(
            "interruption_drain",
            extra={"slot_id": ready["slot_id"]},
        )
def _assign_reservations(db: StateDB, config: AppConfig) -> None:
    """Assign pending reservations to ready slots with capacity."""
    # NOTE(review): imported locally — presumably to avoid an import cycle
    # with .models; confirm.
    from .models import ReservationPhase

    pending = db.list_reservations(ReservationPhase.PENDING)
    if not pending:
        return
    ready_slots = db.list_slots(SlotState.READY)
    if not ready_slots:
        return
    max_leases = config.capacity.max_leases_per_slot
    # Track in-memory capacity to prevent double-assignment within the same tick
    capacity_map: dict[str, int] = {s["slot_id"]: s["lease_count"] for s in ready_slots}
    for resv in pending:
        system = resv["system"]
        slot = _find_assignable_slot(ready_slots, system, max_leases, capacity_map)
        if slot is None:
            # No capacity for this system; _launch_for_unmet_demand handles it.
            continue
        db.assign_reservation(resv["reservation_id"], slot["slot_id"], slot["instance_id"])
        capacity_map[slot["slot_id"]] += 1
        log.info(
            "reservation_assigned",
            extra={
                "reservation_id": resv["reservation_id"],
                "slot_id": slot["slot_id"],
            },
        )
def _find_assignable_slot(
ready_slots: list[dict],
system: str,
max_leases: int,
capacity_map: dict[str, int],
) -> dict | None:
"""Return first ready slot for system with remaining capacity, or None."""
for slot in ready_slots:
if slot["system"] != system:
continue
sid = slot["slot_id"]
current: int = capacity_map[sid] if sid in capacity_map else slot["lease_count"]
if current < max_leases:
return slot
return None
def _count_active_slots(db: StateDB) -> int:
    """Count slots NOT in empty or error states."""
    inactive = (SlotState.EMPTY.value, SlotState.ERROR.value)
    total = 0
    for entry in db.list_slots():
        if entry["state"] not in inactive:
            total += 1
    return total
def _launch_for_unmet_demand(
    db: StateDB,
    runtime: RuntimeAdapter,
    config: AppConfig,
    metrics: MetricsRegistry,
) -> None:
    """Launch new capacity for pending reservations that couldn't be assigned."""
    from .models import ReservationPhase

    pending = db.list_reservations(ReservationPhase.PENDING)
    if not pending:
        return
    active = _count_active_slots(db)
    if active >= config.capacity.max_slots:
        return
    empty_slots = db.list_slots(SlotState.EMPTY)
    if not empty_slots:
        return
    # `launched` is both the enumerate index and the number of slots launched
    # so far — each iteration launches exactly one slot.
    for launched, slot in enumerate(empty_slots):
        if launched >= len(pending):
            break
        if active + launched >= config.capacity.max_slots:
            break
        _launch_slot(db, runtime, config, metrics, slot)
def _ensure_min_and_warm(
    db: StateDB,
    runtime: RuntimeAdapter,
    config: AppConfig,
    metrics: MetricsRegistry,
) -> None:
    """Ensure minimum slots and warm pool targets are met."""
    active = _count_active_slots(db)
    # Ensure min_slots
    if active < config.capacity.min_slots:
        needed = config.capacity.min_slots - active
        empty_slots = db.list_slots(SlotState.EMPTY)
        launched = 0
        for slot in empty_slots:
            if launched >= needed:
                break
            if active + launched >= config.capacity.max_slots:
                break
            _launch_slot(db, runtime, config, metrics, slot)
            launched += 1
        # Account for the min-slots launches before the warm-pool pass.
        active += launched
    # Ensure warm pool
    if config.capacity.target_warm_slots > 0:
        # Warm = idle READY slots plus anything already on its way up.
        ready_idle = sum(1 for s in db.list_slots(SlotState.READY) if s["lease_count"] == 0)
        pending_warm = (
            len(db.list_slots(SlotState.LAUNCHING))
            + len(db.list_slots(SlotState.BOOTING))
            + len(db.list_slots(SlotState.BINDING))
        )
        warm_total = ready_idle + pending_warm
        if warm_total < config.capacity.target_warm_slots:
            needed = config.capacity.target_warm_slots - warm_total
            # Re-list: the min-slots pass above may have consumed EMPTY slots.
            empty_slots = db.list_slots(SlotState.EMPTY)
            launched = 0
            for slot in empty_slots:
                if launched >= needed:
                    break
                if active + launched >= config.capacity.max_slots:
                    break
                _launch_slot(db, runtime, config, metrics, slot)
                launched += 1
def _launch_slot(
    db: StateDB,
    runtime: RuntimeAdapter,
    config: AppConfig,
    metrics: MetricsRegistry,
    slot: dict,
) -> None:
    """Launch a single slot. Transition to LAUNCHING on success, ERROR on failure."""
    slot_id = slot["slot_id"]
    user_data = render_userdata(slot_id, config.aws.region)
    # Counts launch *attempts*, incremented before the API call.
    metrics.counter("autoscaler_ec2_launch_total", {}, 1.0)
    try:
        instance_id = runtime.launch_spot(slot_id, user_data)
        db.update_slot_state(slot_id, SlotState.LAUNCHING, instance_id=instance_id)
        log.info("slot_launched", extra={"slot_id": slot_id, "instance_id": instance_id})
    except RuntimeAdapterError as exc:
        db.update_slot_state(slot_id, SlotState.ERROR)
        log.warning(
            "slot_launch_failed",
            extra={"slot_id": slot_id, "error": str(exc), "category": exc.category},
        )
def _check_idle_scale_down(db: StateDB, config: AppConfig, clock: Clock) -> None:
    """Move idle ready slots to draining when idle threshold exceeded."""
    ready_slots = db.list_slots(SlotState.READY)
    now = clock.now()
    active = _count_active_slots(db)
    for slot in ready_slots:
        if slot["lease_count"] != 0:
            continue
        # Idle time is measured since the slot last changed state.
        last_change = datetime.fromisoformat(slot["last_state_change"])
        idle_seconds = (now - last_change).total_seconds()
        if idle_seconds > config.capacity.idle_scale_down_seconds:
            # Never drain below the configured floor.
            if active <= config.capacity.min_slots:
                continue
            db.update_slot_state(slot["slot_id"], SlotState.DRAINING)
            # Track in-loop so one tick can't drain past min_slots.
            active -= 1
            log.info(
                "idle_scale_down",
                extra={"slot_id": slot["slot_id"], "idle_seconds": idle_seconds},
            )
def _update_metrics(db: StateDB, metrics: MetricsRegistry, tick_duration: float) -> None:
"""Refresh all gauge/counter/histogram values."""
summary = db.get_state_summary()
for state, count in summary["slots"].items():
if state == "total":
continue
metrics.gauge("autoscaler_slots", {"state": state}, float(count))
for phase, count in summary["reservations"].items():
metrics.gauge("autoscaler_reservations", {"phase": phase}, float(count))
metrics.histogram_observe("autoscaler_scheduler_tick_seconds", {}, tick_duration)

View file

@ -153,6 +153,47 @@ class StateDB:
self._conn.execute("ROLLBACK")
raise
def update_slot_fields(self, slot_id: str, **fields: object) -> None:
    """Update specific slot columns without changing state or last_state_change.

    Uses BEGIN IMMEDIATE. Allowed fields: instance_id, instance_ip,
    instance_launch_time, lease_count, cooldown_until, interruption_pending.

    Args:
        slot_id: Primary key of the slot row to update.
        **fields: Column/value pairs; keys are validated against the
            allow-list before being interpolated into the SQL.

    Raises:
        ValueError: If a field name is not in the allow-list.
    """
    allowed = {
        "instance_id",
        "instance_ip",
        "instance_launch_time",
        "lease_count",
        "cooldown_until",
        "interruption_pending",
    }
    if not fields:
        return
    set_parts: list[str] = []
    params: list[object] = []
    for k, v in fields.items():
        if k not in allowed:
            msg = f"Unknown slot field: {k}"
            raise ValueError(msg)
        set_parts.append(f"{k} = ?")
        params.append(v)
    params.append(slot_id)
    # Column names are allow-listed above; values go through placeholders,
    # so this f-string cannot inject.
    sql = f"UPDATE slots SET {', '.join(set_parts)} WHERE slot_id = ?"
    self._conn.execute("BEGIN IMMEDIATE")
    try:
        self._conn.execute(sql, params)
        # Audit event recorded inside the same transaction.
        self._record_event_inner(
            "slot_fields_updated",
            {"slot_id": slot_id, **fields},
        )
        self._conn.execute("COMMIT")
    except Exception:
        self._conn.execute("ROLLBACK")
        raise
# -- Reservation operations ---------------------------------------------
def create_reservation(

View file

@ -1 +1,148 @@
"""HAProxy provider unit tests — Plan 02."""
"""Unit tests for the HAProxy provider, mocking at socket level."""
from unittest.mock import MagicMock, patch
import pytest
from nix_builder_autoscaler.providers.haproxy import HAProxyError, HAProxyRuntime
# HAProxy `show stat` CSV — trimmed to columns the parser uses.
# Full output has many more columns; we keep through `status` (col 17).
SHOW_STAT_CSV = (
"# pxname,svname,qcur,qmax,scur,smax,slim,stot,"
"bin,bout,dreq,dresp,ereq,econ,eresp,wretr,wredis,status\n"
"all,BACKEND,0,0,2,5,200,100,5000,6000,0,0,,0,0,0,0,UP\n"
"all,slot001,0,0,1,3,50,50,2500,3000,0,0,,0,0,0,0,UP\n"
"all,slot002,0,0,1,2,50,50,2500,3000,0,0,,0,0,0,0,DOWN\n"
"all,slot003,0,0,0,0,50,0,0,0,0,0,,0,0,0,0,MAINT\n"
)
class TestSetSlotAddr:
    """`set server` commands are sent verbatim over the admin socket."""

    @patch("nix_builder_autoscaler.providers.haproxy.socket.socket")
    def test_sends_correct_command(self, mock_socket_cls):
        mock_sock = MagicMock()
        mock_socket_cls.return_value = mock_sock
        # Reply: success line, then EOF.
        mock_sock.recv.side_effect = [b"\n", b""]
        h = HAProxyRuntime("/tmp/test.sock", "all", "slot")
        h.set_slot_addr("slot001", "100.64.0.1", 22)
        mock_sock.connect.assert_called_once_with("/tmp/test.sock")
        mock_sock.sendall.assert_called_once_with(
            b"set server all/slot001 addr 100.64.0.1 port 22\n"
        )
class TestEnableSlot:
    """`enable server` commands are sent verbatim over the admin socket."""

    @patch("nix_builder_autoscaler.providers.haproxy.socket.socket")
    def test_sends_correct_command(self, mock_socket_cls):
        mock_sock = MagicMock()
        mock_socket_cls.return_value = mock_sock
        # Reply: success line, then EOF.
        mock_sock.recv.side_effect = [b"\n", b""]
        h = HAProxyRuntime("/tmp/test.sock", "all", "slot")
        h.enable_slot("slot001")
        mock_sock.sendall.assert_called_once_with(b"enable server all/slot001\n")
class TestDisableSlot:
    """`disable server` commands are sent verbatim over the admin socket."""

    @patch("nix_builder_autoscaler.providers.haproxy.socket.socket")
    def test_sends_correct_command(self, mock_socket_cls):
        mock_sock = MagicMock()
        mock_socket_cls.return_value = mock_sock
        # Reply: success line, then EOF.
        mock_sock.recv.side_effect = [b"\n", b""]
        h = HAProxyRuntime("/tmp/test.sock", "all", "slot")
        h.disable_slot("slot001")
        mock_sock.sendall.assert_called_once_with(b"disable server all/slot001\n")
class TestReadSlotHealth:
    """`show stat` CSV parsing: slot rows only, correct field extraction."""

    @patch("nix_builder_autoscaler.providers.haproxy.socket.socket")
    def test_parses_csv_correctly(self, mock_socket_cls):
        mock_sock = MagicMock()
        mock_socket_cls.return_value = mock_sock
        mock_sock.recv.side_effect = [SHOW_STAT_CSV.encode(), b""]
        h = HAProxyRuntime("/tmp/test.sock", "all", "slot")
        health = h.read_slot_health()
        assert len(health) == 3
        # BACKEND row should be excluded (svname "BACKEND" doesn't start with "slot")
        assert health["slot001"].status == "UP"
        assert health["slot001"].scur == 1
        assert health["slot001"].qcur == 0
        assert health["slot002"].status == "DOWN"
        assert health["slot002"].scur == 1
        assert health["slot002"].qcur == 0
        assert health["slot003"].status == "MAINT"
        assert health["slot003"].scur == 0
        assert health["slot003"].qcur == 0
class TestSlotIsUp:
    """slot_is_up maps HAProxy status to a boolean (only "UP" is up)."""

    @patch("nix_builder_autoscaler.providers.haproxy.socket.socket")
    def test_returns_true_for_up_slot(self, mock_socket_cls):
        mock_sock = MagicMock()
        mock_socket_cls.return_value = mock_sock
        mock_sock.recv.side_effect = [SHOW_STAT_CSV.encode(), b""]
        h = HAProxyRuntime("/tmp/test.sock", "all", "slot")
        assert h.slot_is_up("slot001") is True

    @patch("nix_builder_autoscaler.providers.haproxy.socket.socket")
    def test_returns_false_for_down_slot(self, mock_socket_cls):
        mock_sock = MagicMock()
        mock_socket_cls.return_value = mock_sock
        mock_sock.recv.side_effect = [SHOW_STAT_CSV.encode(), b""]
        h = HAProxyRuntime("/tmp/test.sock", "all", "slot")
        assert h.slot_is_up("slot002") is False
class TestErrorHandling:
    """All failure modes surface as HAProxyError with a descriptive message."""

    @patch("nix_builder_autoscaler.providers.haproxy.socket.socket")
    def test_unrecognized_slot_raises_haproxy_error(self, mock_socket_cls):
        mock_sock = MagicMock()
        mock_socket_cls.return_value = mock_sock
        # HAProxy replies "No such server." for unknown server names.
        mock_sock.recv.side_effect = [b"No such server.\n", b""]
        h = HAProxyRuntime("/tmp/test.sock", "all", "slot")
        with pytest.raises(HAProxyError, match="No such server"):
            h.set_slot_addr("slot999", "100.64.0.1")

    @patch("nix_builder_autoscaler.providers.haproxy.socket.socket")
    def test_socket_not_found_raises_haproxy_error(self, mock_socket_cls):
        mock_sock = MagicMock()
        mock_socket_cls.return_value = mock_sock
        mock_sock.connect.side_effect = FileNotFoundError("No such file")
        h = HAProxyRuntime("/tmp/nonexistent.sock", "all", "slot")
        with pytest.raises(HAProxyError, match="socket not found"):
            h.set_slot_addr("slot001", "100.64.0.1")

    @patch("nix_builder_autoscaler.providers.haproxy.socket.socket")
    def test_connection_refused_raises_haproxy_error(self, mock_socket_cls):
        mock_sock = MagicMock()
        mock_socket_cls.return_value = mock_sock
        mock_sock.connect.side_effect = ConnectionRefusedError("Connection refused")
        h = HAProxyRuntime("/tmp/test.sock", "all", "slot")
        with pytest.raises(HAProxyError, match="Connection refused"):
            h.enable_slot("slot001")

    @patch("nix_builder_autoscaler.providers.haproxy.socket.socket")
    def test_slot_session_count_missing_slot_raises(self, mock_socket_cls):
        mock_sock = MagicMock()
        mock_socket_cls.return_value = mock_sock
        mock_sock.recv.side_effect = [SHOW_STAT_CSV.encode(), b""]
        h = HAProxyRuntime("/tmp/test.sock", "all", "slot")
        with pytest.raises(HAProxyError, match="Slot not found"):
            h.slot_session_count("slot_nonexistent")

View file

@ -1 +1,286 @@
"""EC2 runtime unit tests — Plan 02."""
"""Unit tests for the EC2 runtime adapter using botocore Stubber."""
from datetime import UTC, datetime
from unittest.mock import patch

import boto3
import pytest
from botocore.stub import ANY, Stubber

from nix_builder_autoscaler.config import AwsConfig
from nix_builder_autoscaler.runtime.base import RuntimeError as RuntimeAdapterError
from nix_builder_autoscaler.runtime.ec2 import EC2Runtime
def _make_config():
    """Return an AwsConfig populated with fixed fixture values for the tests."""
    fixture = {
        "region": "us-east-1",
        "launch_template_id": "lt-abc123",
        "subnet_ids": ["subnet-aaa", "subnet-bbb"],
        "security_group_ids": ["sg-111"],
        "instance_profile_arn": "arn:aws:iam::123456789012:instance-profile/nix-builder",
    }
    return AwsConfig(**fixture)
def _make_runtime(stubber, ec2_client, *, config=None, environment="dev"):
    """Activate *stubber* and build an EC2Runtime wired to the stubbed client.

    The previous ``**kwargs``/``pop`` version silently discarded any
    unrecognized keyword argument (a typo such as ``enviroment=`` would go
    unnoticed) and built a fresh default config even when one was supplied.
    Explicit keyword-only parameters keep the visible call sites working
    while making bad arguments fail loudly.

    Args:
        stubber: botocore Stubber wrapping *ec2_client*; activated here so
            individual tests don't have to remember to do it.
        ec2_client: boto3 EC2 client injected into the runtime under test.
        config: AwsConfig to use; defaults to the shared test fixture.
        environment: Environment tag value applied to launched instances.
    """
    if config is None:
        config = _make_config()
    stubber.activate()
    return EC2Runtime(config, environment=environment, _client=ec2_client)
class TestLaunchSpot:
    """launch_spot() request construction and subnet placement."""

    def test_correct_params_and_returns_instance_id(self):
        """The full run_instances payload is pinned and the instance id returned."""
        config = _make_config()
        ec2_client = boto3.client("ec2", region_name="us-east-1")
        stubber = Stubber(ec2_client)
        expected_params = {
            "MinCount": 1,
            "MaxCount": 1,
            "LaunchTemplate": {
                "LaunchTemplateId": "lt-abc123",
                "Version": "$Latest",
            },
            "InstanceMarketOptions": {
                "MarketType": "spot",
                "SpotOptions": {
                    "SpotInstanceType": "one-time",
                    "InstanceInterruptionBehavior": "terminate",
                },
            },
            "SubnetId": "subnet-aaa",
            "UserData": "#!/bin/bash\necho hello",
            "TagSpecifications": [
                {
                    "ResourceType": "instance",
                    "Tags": [
                        {"Key": "Name", "Value": "nix-builder-slot001"},
                        {"Key": "AutoscalerSlot", "Value": "slot001"},
                        {"Key": "ManagedBy", "Value": "nix-builder-autoscaler"},
                        {"Key": "Service", "Value": "nix-builder"},
                        {"Key": "Environment", "Value": "dev"},
                    ],
                }
            ],
        }
        response = {
            "Instances": [{"InstanceId": "i-12345678"}],
            "OwnerId": "123456789012",
        }
        stubber.add_response("run_instances", response, expected_params)
        runtime = _make_runtime(stubber, ec2_client, config=config)
        iid = runtime.launch_spot("slot001", "#!/bin/bash\necho hello")
        assert iid == "i-12345678"
        stubber.assert_no_pending_responses()

    def test_round_robin_subnets(self):
        """Consecutive launches rotate through the configured subnets in order."""
        config = _make_config()
        ec2_client = boto3.client("ec2", region_name="us-east-1")
        stubber = Stubber(ec2_client)
        # Pin only SubnetId for each call; ANY wildcards the rest of the
        # payload so this test fails on placement, not tag/userdata details.
        for subnet in ["subnet-aaa", "subnet-bbb"]:
            stubber.add_response(
                "run_instances",
                {"Instances": [{"InstanceId": "i-abc"}], "OwnerId": "123"},
                {
                    "MinCount": ANY,
                    "MaxCount": ANY,
                    "LaunchTemplate": ANY,
                    "InstanceMarketOptions": ANY,
                    "SubnetId": subnet,
                    "UserData": ANY,
                    "TagSpecifications": ANY,
                },
            )
        runtime = _make_runtime(stubber, ec2_client, config=config)
        runtime.launch_spot("slot001", "")
        runtime.launch_spot("slot002", "")
        stubber.assert_no_pending_responses()
class TestDescribeInstance:
    """describe_instance() flattening of raw EC2 responses."""

    def test_normalizes_response(self):
        """A running instance is flattened to state/tailscale_ip/launch_time."""
        ec2_client = boto3.client("ec2", region_name="us-east-1")
        stubber = Stubber(ec2_client)
        launched_at = datetime(2026, 1, 15, 12, 30, 0, tzinfo=UTC)
        instance = {
            "InstanceId": "i-running1",
            "State": {"Code": 16, "Name": "running"},
            "LaunchTime": launched_at,
            "Tags": [{"Key": "AutoscalerSlot", "Value": "slot001"}],
        }
        stubber.add_response(
            "describe_instances",
            {"Reservations": [{"Instances": [instance]}]},
            {"InstanceIds": ["i-running1"]},
        )
        runtime = _make_runtime(stubber, ec2_client)
        info = runtime.describe_instance("i-running1")
        assert info["state"] == "running"
        assert info["tailscale_ip"] is None
        assert info["launch_time"] == launched_at.isoformat()

    def test_missing_instance_returns_terminated(self):
        """An empty Reservations list is reported as a terminated instance."""
        ec2_client = boto3.client("ec2", region_name="us-east-1")
        stubber = Stubber(ec2_client)
        stubber.add_response(
            "describe_instances",
            {"Reservations": []},
            {"InstanceIds": ["i-gone"]},
        )
        runtime = _make_runtime(stubber, ec2_client)
        info = runtime.describe_instance("i-gone")
        assert info["state"] == "terminated"
        assert info["tailscale_ip"] is None
        assert info["launch_time"] is None
class TestListManagedInstances:
    """list_managed_instances() discovery via tag and state filters."""

    def test_filters_by_tag(self):
        """The describe call is scoped server-side and results are flattened.

        The adapter must filter by the ManagedBy tag and by instance states
        that are still alive, then map each instance to a dict with
        instance_id, state, and the slot_id taken from the AutoscalerSlot tag.
        """
        ec2_client = boto3.client("ec2", region_name="us-east-1")
        stubber = Stubber(ec2_client)
        # Both filters are pinned: a missing filter would make the adapter
        # pull unrelated (or already-terminated) instances into reconciliation.
        expected_params = {
            "Filters": [
                {"Name": "tag:ManagedBy", "Values": ["nix-builder-autoscaler"]},
                {
                    "Name": "instance-state-name",
                    "Values": ["pending", "running", "shutting-down", "stopping"],
                },
            ],
        }
        response = {
            "Reservations": [
                {
                    "Instances": [
                        {
                            "InstanceId": "i-aaa",
                            "State": {"Code": 16, "Name": "running"},
                            "Tags": [
                                {"Key": "AutoscalerSlot", "Value": "slot001"},
                                {"Key": "ManagedBy", "Value": "nix-builder-autoscaler"},
                            ],
                        },
                        {
                            "InstanceId": "i-bbb",
                            "State": {"Code": 0, "Name": "pending"},
                            "Tags": [
                                {"Key": "AutoscalerSlot", "Value": "slot002"},
                                {"Key": "ManagedBy", "Value": "nix-builder-autoscaler"},
                            ],
                        },
                    ],
                }
            ],
        }
        stubber.add_response("describe_instances", response, expected_params)
        runtime = _make_runtime(stubber, ec2_client)
        managed = runtime.list_managed_instances()
        assert len(managed) == 2
        # slot_id comes from each instance's AutoscalerSlot tag.
        assert managed[0]["instance_id"] == "i-aaa"
        assert managed[0]["state"] == "running"
        assert managed[0]["slot_id"] == "slot001"
        assert managed[1]["instance_id"] == "i-bbb"
        assert managed[1]["state"] == "pending"
        assert managed[1]["slot_id"] == "slot002"
class TestTerminateInstance:
    """terminate_instance() delegation to the terminate_instances API."""

    def test_calls_terminate_api(self):
        """A single terminate_instances call is issued and no error raised."""
        ec2_client = boto3.client("ec2", region_name="us-east-1")
        stubber = Stubber(ec2_client)
        api_response = {
            "TerminatingInstances": [
                {
                    "InstanceId": "i-kill",
                    "CurrentState": {"Code": 32, "Name": "shutting-down"},
                    "PreviousState": {"Code": 16, "Name": "running"},
                }
            ],
        }
        stubber.add_response(
            "terminate_instances", api_response, {"InstanceIds": ["i-kill"]}
        )
        runtime = _make_runtime(stubber, ec2_client)
        runtime.terminate_instance("i-kill")  # must complete without raising
        stubber.assert_no_pending_responses()
class TestErrorClassification:
    """Mapping of EC2 client errors onto RuntimeAdapterError categories."""

    def test_insufficient_capacity_classified(self):
        """InsufficientInstanceCapacity maps to category "capacity_unavailable"."""
        ec2_client = boto3.client("ec2", region_name="us-east-1")
        stubber = Stubber(ec2_client)
        stubber.add_client_error(
            "run_instances",
            service_error_code="InsufficientInstanceCapacity",
            service_message="Insufficient capacity",
        )
        runtime = _make_runtime(stubber, ec2_client)
        with pytest.raises(RuntimeAdapterError) as exc_info:
            runtime.launch_spot("slot001", "#!/bin/bash")
        assert exc_info.value.category == "capacity_unavailable"

    @patch("nix_builder_autoscaler.runtime.ec2.time.sleep")
    def test_request_limit_exceeded_retried(self, mock_sleep):
        """A throttled launch is retried (after a sleep) and then succeeds."""
        ec2_client = boto3.client("ec2", region_name="us-east-1")
        stubber = Stubber(ec2_client)
        # First call: throttled
        stubber.add_client_error(
            "run_instances",
            service_error_code="RequestLimitExceeded",
            service_message="Rate exceeded",
        )
        # Second call: success
        stubber.add_response(
            "run_instances",
            {"Instances": [{"InstanceId": "i-retry123"}], "OwnerId": "123"},
        )
        runtime = _make_runtime(stubber, ec2_client)
        iid = runtime.launch_spot("slot001", "#!/bin/bash")
        assert iid == "i-retry123"
        # Sleep is patched out so the backoff is observed without waiting.
        assert mock_sleep.called
        stubber.assert_no_pending_responses()

    @patch("nix_builder_autoscaler.runtime.ec2.time.sleep")
    def test_request_limit_exceeded_exhausted(self, mock_sleep):
        """After max retries, RequestLimitExceeded raises with 'throttled' category."""
        ec2_client = boto3.client("ec2", region_name="us-east-1")
        stubber = Stubber(ec2_client)
        # 4 errors (max_retries=3: attempt 0,1,2,3 all fail)
        for _ in range(4):
            stubber.add_client_error(
                "run_instances",
                service_error_code="RequestLimitExceeded",
                service_message="Rate exceeded",
            )
        runtime = _make_runtime(stubber, ec2_client)
        with pytest.raises(RuntimeAdapterError) as exc_info:
            runtime.launch_spot("slot001", "#!/bin/bash")
        assert exc_info.value.category == "throttled"

View file

@ -1 +1,194 @@
"""Scheduler unit tests — Plan 03."""
from nix_builder_autoscaler.config import AppConfig, AwsConfig, CapacityConfig
from nix_builder_autoscaler.metrics import MetricsRegistry
from nix_builder_autoscaler.models import ReservationPhase, SlotState
from nix_builder_autoscaler.providers.clock import FakeClock
from nix_builder_autoscaler.runtime.fake import FakeRuntime
from nix_builder_autoscaler.scheduler import scheduling_tick
from nix_builder_autoscaler.state_db import StateDB
def _make_env(
    slot_count=3,
    max_slots=3,
    max_leases=1,
    idle_scale_down_seconds=900,
    target_warm=0,
    min_slots=0,
):
    """Assemble an in-memory scheduler environment for a single test.

    Returns a ``(db, runtime, config, clock, metrics)`` tuple wired together
    around a FakeClock and FakeRuntime, with the capacity knobs exposed as
    parameters.
    """
    fake_clock = FakeClock()
    state = StateDB(":memory:", clock=fake_clock)
    state.init_schema()
    state.init_slots("slot", slot_count, "x86_64-linux", "all")
    capacity = CapacityConfig(
        max_slots=max_slots,
        max_leases_per_slot=max_leases,
        idle_scale_down_seconds=idle_scale_down_seconds,
        target_warm_slots=target_warm,
        min_slots=min_slots,
        reservation_ttl_seconds=1200,
    )
    app_config = AppConfig(capacity=capacity, aws=AwsConfig(region="us-east-1"))
    fake_runtime = FakeRuntime(launch_latency_ticks=2, ip_delay_ticks=1)
    return state, fake_runtime, app_config, fake_clock, MetricsRegistry()
def _make_slot_ready(db, slot_id, instance_id="i-test1", ip="100.64.0.1"):
    """Walk *slot_id* through launching -> booting -> binding -> ready."""
    transitions = [
        (SlotState.LAUNCHING, {"instance_id": instance_id}),
        (SlotState.BOOTING, {}),
        (SlotState.BINDING, {"instance_ip": ip}),
        (SlotState.READY, {}),
    ]
    for state, extra in transitions:
        db.update_slot_state(slot_id, state, **extra)
# --- Test cases ---
def test_pending_reservation_assigned_to_ready_slot():
    """A pending reservation lands on the ready slot within one tick."""
    db, runtime, config, clock, metrics = _make_env()
    _make_slot_ready(db, "slot001")
    created = db.create_reservation("x86_64-linux", "test", None, 1200)
    scheduling_tick(db, runtime, config, clock, metrics)
    resv = db.get_reservation(created["reservation_id"])
    assert resv["phase"] == ReservationPhase.READY.value
    assert resv["slot_id"] == "slot001"
    assert resv["instance_id"] == "i-test1"
    # Assignment consumes exactly one lease on the slot.
    assert db.get_slot("slot001")["lease_count"] == 1
def test_two_pending_one_slot_only_one_assigned_per_tick():
    """With max_leases=1, a single slot absorbs only one of two reservations."""
    db, runtime, config, clock, metrics = _make_env(max_leases=1)
    _make_slot_ready(db, "slot001")
    first = db.create_reservation("x86_64-linux", "test1", None, 1200)
    second = db.create_reservation("x86_64-linux", "test2", None, 1200)
    scheduling_tick(db, runtime, config, clock, metrics)
    phases = [
        db.get_reservation(r["reservation_id"])["phase"] for r in (first, second)
    ]
    assert phases.count(ReservationPhase.READY.value) == 1
    assert phases.count(ReservationPhase.PENDING.value) == 1
    assert db.get_slot("slot001")["lease_count"] == 1
def test_reservation_expires_when_ttl_passes():
    """Reservations older than their TTL flip to EXPIRED on the next tick."""
    db, runtime, config, clock, metrics = _make_env()
    config.capacity.reservation_ttl_seconds = 60
    db.create_reservation("x86_64-linux", "test", None, 60)
    clock.advance(61)  # one second past the TTL
    scheduling_tick(db, runtime, config, clock, metrics)
    expired = db.list_reservations(ReservationPhase.EXPIRED)
    assert len(expired) == 1
def test_scale_down_starts_when_idle_exceeds_threshold():
    """An idle ready slot starts draining once idle_scale_down_seconds elapses."""
    db, runtime, config, clock, metrics = _make_env(idle_scale_down_seconds=900)
    _make_slot_ready(db, "slot001")
    clock.advance(901)  # just over the idle threshold
    scheduling_tick(db, runtime, config, clock, metrics)
    assert db.get_slot("slot001")["state"] == SlotState.DRAINING.value
def test_slot_does_not_drain_while_lease_count_positive():
    """A leased slot is never scaled down, no matter how long it sits."""
    db, runtime, config, clock, metrics = _make_env(idle_scale_down_seconds=900)
    _make_slot_ready(db, "slot001")
    created = db.create_reservation("x86_64-linux", "test", None, 1200)
    scheduling_tick(db, runtime, config, clock, metrics)
    # Sanity check: the reservation really was assigned before advancing time.
    assigned = db.get_reservation(created["reservation_id"])
    assert assigned["phase"] == ReservationPhase.READY.value
    clock.advance(901)
    scheduling_tick(db, runtime, config, clock, metrics)
    assert db.get_slot("slot001")["state"] == SlotState.READY.value
def test_interruption_pending_slot_moves_to_draining():
    """A spot-interruption flag drains the slot and is cleared afterwards."""
    db, runtime, config, clock, metrics = _make_env()
    _make_slot_ready(db, "slot001")
    db.update_slot_fields("slot001", interruption_pending=1)
    scheduling_tick(db, runtime, config, clock, metrics)
    interrupted = db.get_slot("slot001")
    assert interrupted["state"] == SlotState.DRAINING.value
    assert interrupted["interruption_pending"] == 0
def test_launch_triggered_for_unmet_demand():
    """Pending demand with no ready slot triggers exactly one launch."""
    db, runtime, config, clock, metrics = _make_env()
    db.create_reservation("x86_64-linux", "test", None, 1200)
    scheduling_tick(db, runtime, config, clock, metrics)
    launching = db.list_slots(SlotState.LAUNCHING)
    assert len(launching) == 1
    assert launching[0]["instance_id"] is not None
    # The fake runtime should now track the single launched instance.
    assert len(runtime.list_managed_instances()) == 1
def test_launch_respects_max_slots():
    """No new launch is started once active slots reach max_slots."""
    db, runtime, config, clock, metrics = _make_env(max_slots=1)
    _make_slot_ready(db, "slot001")
    # slot001 absorbs one reservation; the other stays pending, but with
    # active_slots (1) == max_slots (1) the scheduler must not launch more.
    db.create_reservation("x86_64-linux", "test1", None, 1200)
    db.create_reservation("x86_64-linux", "test2", None, 1200)
    scheduling_tick(db, runtime, config, clock, metrics)
    assert not db.list_slots(SlotState.LAUNCHING)
def test_min_slots_maintained():
    """With min_slots=1 and zero demand, the scheduler still warms one slot."""
    db, runtime, config, clock, metrics = _make_env(min_slots=1)
    scheduling_tick(db, runtime, config, clock, metrics)
    assert len(db.list_slots(SlotState.LAUNCHING)) == 1
def test_scale_down_respects_min_slots():
    """An idle slot is kept alive when draining it would violate min_slots."""
    db, runtime, config, clock, metrics = _make_env(
        min_slots=1, idle_scale_down_seconds=900
    )
    _make_slot_ready(db, "slot001")
    clock.advance(901)
    scheduling_tick(db, runtime, config, clock, metrics)
    assert db.get_slot("slot001")["state"] == SlotState.READY.value

View file

@ -31,7 +31,7 @@ line-length = 100
[tool.ruff.lint]
select = ["E", "F", "I", "UP", "B", "SIM", "ANN"]
ignore = []
ignore = ["ANN401"]
[tool.ruff.lint.per-file-ignores]
"*/tests/*" = ["ANN"]