diff --git a/agent/nix_builder_autoscaler/bootstrap/userdata.py b/agent/nix_builder_autoscaler/bootstrap/userdata.py index 2cd4152..11b72a7 100644 --- a/agent/nix_builder_autoscaler/bootstrap/userdata.py +++ b/agent/nix_builder_autoscaler/bootstrap/userdata.py @@ -1,11 +1,62 @@ -"""EC2 user-data template rendering — stub for Plan 02.""" +"""EC2 user-data template rendering for builder instance bootstrap. + +The generated script follows the NixOS AMI pattern: write config files +that existing systemd services (tailscale-autoconnect, nix-daemon) consume, +rather than calling ``tailscale up`` directly. +""" from __future__ import annotations +import textwrap + def render_userdata(slot_id: str, region: str, ssm_param: str = "/nix-builder/ts-authkey") -> str: """Render a bash user-data script for builder instance bootstrap. - Full implementation in Plan 02. + The returned string is a complete shell script. On NixOS AMIs the script + is executed by ``amazon-init.service``. The caller (EC2Runtime) passes it + to ``run_instances`` as ``UserData``; boto3 base64-encodes automatically. + + Args: + slot_id: Autoscaler slot identifier (used as Tailscale hostname suffix). + region: AWS region for SSM parameter lookup. + ssm_param: SSM parameter path containing the Tailscale auth key. """ - raise NotImplementedError + return textwrap.dedent(f"""\ + #!/usr/bin/env bash + set -euo pipefail + + SLOT_ID="{slot_id}" + REGION="{region}" + SSM_PARAM="{ssm_param}" + + # --- Fetch Tailscale auth key from SSM Parameter Store --- + mkdir -p /run/credentials + TS_AUTHKEY=$(aws ssm get-parameter \\ + --region "$REGION" \\ + --with-decryption \\ + --name "$SSM_PARAM" \\ + --query 'Parameter.Value' \\ + --output text) + printf '%s' "$TS_AUTHKEY" > /run/credentials/tailscale-auth-key + chmod 600 /run/credentials/tailscale-auth-key + + # --- Write tailscale-autoconnect config --- + mkdir -p /etc/tailscale + cat > /etc/tailscale/autoconnect.conf < /run/nix-builder-ready + """) diff --git a/agent/nix_builder_autoscaler/providers/haproxy.py b/agent/nix_builder_autoscaler/providers/haproxy.py index cfbb647..9258500 100644 --- a/agent/nix_builder_autoscaler/providers/haproxy.py +++ b/agent/nix_builder_autoscaler/providers/haproxy.py @@ -1,7 +1,10 @@ -"""HAProxy runtime socket adapter — stub for Plan 02.""" +"""HAProxy runtime socket adapter for managing builder slots.""" from __future__ import annotations +import csv +import io +import socket from dataclasses import dataclass @@ -21,7 +24,12 @@ class SlotHealth: class HAProxyRuntime: """HAProxy runtime CLI adapter via Unix socket. - Full implementation in Plan 02. + Communicates with HAProxy using the admin socket text protocol. + + Args: + socket_path: Path to the HAProxy admin Unix socket. + backend: HAProxy backend name (e.g. "all"). + slot_prefix: Server name prefix used for builder slots. """ def __init__(self, socket_path: str, backend: str, slot_prefix: str) -> None: @@ -31,24 +39,76 @@ class HAProxyRuntime: def set_slot_addr(self, slot_id: str, ip: str, port: int = 22) -> None: """Update server address for a slot.""" - raise NotImplementedError + cmd = f"set server {self._backend}/{slot_id} addr {ip} port {port}" + resp = self._run(cmd) + self._check_response(resp, slot_id) def enable_slot(self, slot_id: str) -> None: """Enable a server slot.""" - raise NotImplementedError + cmd = f"enable server {self._backend}/{slot_id}" + resp = self._run(cmd) + self._check_response(resp, slot_id) def disable_slot(self, slot_id: str) -> None: """Disable a server slot.""" - raise NotImplementedError + cmd = f"disable server {self._backend}/{slot_id}" + resp = self._run(cmd) + self._check_response(resp, slot_id) def slot_is_up(self, slot_id: str) -> bool: """Return True when HAProxy health status is UP for slot.""" - raise NotImplementedError + health = self.read_slot_health() + entry = health.get(slot_id) + return entry is not None and entry.status == "UP" def slot_session_count(self, slot_id: str) -> int: """Return current active session count for slot.""" - raise NotImplementedError + health = self.read_slot_health() + entry = health.get(slot_id) + if entry is None: + raise HAProxyError(f"Slot not found in HAProxy stats: {slot_id}") + return entry.scur def read_slot_health(self) -> dict[str, SlotHealth]: - """Return full stats snapshot for all slots.""" - raise NotImplementedError + """Return full stats snapshot for all slots in the backend.""" + raw = self._run("show stat") + reader = csv.DictReader(io.StringIO(raw)) + result: dict[str, SlotHealth] = {} + for row in reader: + pxname = row.get("# pxname", "").strip() + svname = row.get("svname", "").strip() + if pxname == self._backend and svname.startswith(self._slot_prefix): + result[svname] = SlotHealth( + status=row.get("status", "").strip(), + scur=int(row.get("scur", "0")), + qcur=int(row.get("qcur", "0")), + ) + return result + + def _run(self, command: str) -> str: + """Send a command to the HAProxy admin socket and return the response.""" + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + sock.connect(self._socket_path) + sock.sendall((command + "\n").encode()) + sock.shutdown(socket.SHUT_WR) + chunks: list[bytes] = [] + while True: + chunk = sock.recv(4096) + if not chunk: + break + chunks.append(chunk) + return b"".join(chunks).decode() + except FileNotFoundError as e: + raise HAProxyError(f"HAProxy socket not found: {self._socket_path}") from e + except ConnectionRefusedError as e: + raise HAProxyError(f"Connection refused to HAProxy socket: {self._socket_path}") from e + finally: + sock.close() + + @staticmethod + def _check_response(response: str, slot_id: str) -> None: + """Raise HAProxyError if the response indicates an error.""" + stripped = response.strip() + if stripped.startswith(("No such", "Unknown")): + raise HAProxyError(f"HAProxy error for {slot_id}: {stripped}") diff --git a/agent/nix_builder_autoscaler/reconciler.py b/agent/nix_builder_autoscaler/reconciler.py index 55decff..9607462 100644 --- a/agent/nix_builder_autoscaler/reconciler.py +++ b/agent/nix_builder_autoscaler/reconciler.py @@ -1 +1,258 @@ -"""Reconciler — stub for Plan 03.""" +"""Reconciler — advances slots through the state machine. + +Each tick queries EC2 and HAProxy, then processes each slot according to +its current state: launching→booting→binding→ready, with draining and +terminating paths for teardown. +""" + +from __future__ import annotations + +import contextlib +import logging +import time +from datetime import datetime +from typing import TYPE_CHECKING + +from .models import SlotState +from .providers.haproxy import HAProxyError + +if TYPE_CHECKING: + from .config import AppConfig + from .metrics import MetricsRegistry + from .providers.clock import Clock + from .providers.haproxy import HAProxyRuntime + from .runtime.base import RuntimeAdapter + from .state_db import StateDB + +log = logging.getLogger(__name__) + + +class Reconciler: + """Advances slots through the state machine by polling EC2 and HAProxy. + + Maintains binding health-check counters between ticks. + """ + + def __init__( + self, + db: StateDB, + runtime: RuntimeAdapter, + haproxy: HAProxyRuntime, + config: AppConfig, + clock: Clock, + metrics: MetricsRegistry, + ) -> None: + self._db = db + self._runtime = runtime + self._haproxy = haproxy + self._config = config + self._clock = clock + self._metrics = metrics + self._binding_up_counts: dict[str, int] = {} + + def tick(self) -> None: + """Execute one reconciliation tick.""" + t0 = time.monotonic() + + # 1. Query EC2 + try: + managed = self._runtime.list_managed_instances() + except Exception: + log.exception("ec2_list_failed") + managed = [] + ec2_by_slot: dict[str, dict] = {} + for inst in managed: + sid = inst.get("slot_id") + if sid: + ec2_by_slot[sid] = inst + + # 2. Query HAProxy + try: + haproxy_health = self._haproxy.read_slot_health() + except HAProxyError: + log.warning("haproxy_stat_failed", exc_info=True) + haproxy_health = {} + + # 3. Process each slot + all_slots = self._db.list_slots() + for slot in all_slots: + state = slot["state"] + if state == SlotState.LAUNCHING.value: + self._handle_launching(slot) + elif state == SlotState.BOOTING.value: + self._handle_booting(slot) + elif state == SlotState.BINDING.value: + self._handle_binding(slot, haproxy_health) + elif state == SlotState.READY.value: + self._handle_ready(slot, ec2_by_slot) + elif state == SlotState.DRAINING.value: + self._handle_draining(slot) + elif state == SlotState.TERMINATING.value: + self._handle_terminating(slot, ec2_by_slot) + + # 4. Clean stale binding counters + binding_ids = {s["slot_id"] for s in all_slots if s["state"] == SlotState.BINDING.value} + stale = [k for k in self._binding_up_counts if k not in binding_ids] + for k in stale: + del self._binding_up_counts[k] + + # 5. Emit metrics + tick_duration = time.monotonic() - t0 + self._update_metrics(tick_duration) + + def _handle_launching(self, slot: dict) -> None: + """Check if launching instance has reached running state.""" + instance_id = slot["instance_id"] + if not instance_id: + self._db.update_slot_state(slot["slot_id"], SlotState.ERROR) + return + + info = self._runtime.describe_instance(instance_id) + ec2_state = info["state"] + + if ec2_state == "running": + self._db.update_slot_state(slot["slot_id"], SlotState.BOOTING) + log.info("slot_booting", extra={"slot_id": slot["slot_id"]}) + elif ec2_state in ("terminated", "shutting-down"): + self._db.update_slot_state(slot["slot_id"], SlotState.ERROR) + log.warning( + "slot_launch_terminated", + extra={"slot_id": slot["slot_id"], "ec2_state": ec2_state}, + ) + + def _handle_booting(self, slot: dict) -> None: + """Check if booting instance has a Tailscale IP yet.""" + instance_id = slot["instance_id"] + if not instance_id: + self._db.update_slot_state(slot["slot_id"], SlotState.ERROR) + return + + info = self._runtime.describe_instance(instance_id) + ec2_state = info["state"] + + if ec2_state in ("terminated", "shutting-down"): + self._db.update_slot_state(slot["slot_id"], SlotState.ERROR) + log.warning( + "slot_boot_terminated", + extra={"slot_id": slot["slot_id"], "ec2_state": ec2_state}, + ) + return + + tailscale_ip = info.get("tailscale_ip") + if tailscale_ip is not None: + self._db.update_slot_state(slot["slot_id"], SlotState.BINDING, instance_ip=tailscale_ip) + try: + self._haproxy.set_slot_addr(slot["slot_id"], tailscale_ip) + self._haproxy.enable_slot(slot["slot_id"]) + except HAProxyError: + log.warning( + "haproxy_binding_setup_failed", + extra={"slot_id": slot["slot_id"]}, + exc_info=True, + ) + + def _handle_binding(self, slot: dict, haproxy_health: dict) -> None: + """Check HAProxy health to determine when slot is ready.""" + slot_id = slot["slot_id"] + health = haproxy_health.get(slot_id) + + if health is not None and health.status == "UP": + count = self._binding_up_counts.get(slot_id, 0) + 1 + self._binding_up_counts[slot_id] = count + if count >= self._config.haproxy.check_ready_up_count: + self._db.update_slot_state(slot_id, SlotState.READY) + self._binding_up_counts.pop(slot_id, None) + log.info("slot_ready", extra={"slot_id": slot_id}) + else: + self._binding_up_counts[slot_id] = 0 + # Retry HAProxy setup + ip = slot.get("instance_ip") + if ip: + try: + self._haproxy.set_slot_addr(slot_id, ip) + self._haproxy.enable_slot(slot_id) + except HAProxyError: + pass + + # Check if instance is still alive + instance_id = slot.get("instance_id") + if instance_id: + info = self._runtime.describe_instance(instance_id) + if info["state"] in ("terminated", "shutting-down"): + self._db.update_slot_state(slot_id, SlotState.ERROR) + self._binding_up_counts.pop(slot_id, None) + log.warning( + "slot_binding_terminated", + extra={"slot_id": slot_id}, + ) + + def _handle_ready(self, slot: dict, ec2_by_slot: dict[str, dict]) -> None: + """Verify EC2 instance is still alive for ready slots.""" + slot_id = slot["slot_id"] + ec2_info = ec2_by_slot.get(slot_id) + + if ec2_info is None or ec2_info["state"] == "terminated": + self._db.update_slot_state(slot_id, SlotState.ERROR, instance_id=None, instance_ip=None) + log.warning("slot_ready_instance_gone", extra={"slot_id": slot_id}) + elif ec2_info["state"] == "shutting-down": + self._db.update_slot_fields(slot_id, interruption_pending=1) + log.info("slot_interruption_detected", extra={"slot_id": slot_id}) + + def _handle_draining(self, slot: dict) -> None: + """Disable HAProxy and terminate when drain conditions are met.""" + slot_id = slot["slot_id"] + + # Disable HAProxy (idempotent) + with contextlib.suppress(HAProxyError): + self._haproxy.disable_slot(slot_id) + + now = self._clock.now() + last_change = datetime.fromisoformat(slot["last_state_change"]) + drain_duration = (now - last_change).total_seconds() + + drain_timeout = self._config.capacity.drain_timeout_seconds + if slot["lease_count"] == 0 or drain_duration >= drain_timeout: + instance_id = slot.get("instance_id") + if instance_id: + try: + self._runtime.terminate_instance(instance_id) + self._metrics.counter("autoscaler_ec2_terminate_total", {}, 1.0) + except Exception: + log.warning( + "terminate_failed", + extra={"slot_id": slot_id, "instance_id": instance_id}, + exc_info=True, + ) + self._db.update_slot_state(slot_id, SlotState.TERMINATING) + log.info( + "slot_terminating", + extra={"slot_id": slot_id, "drain_duration": drain_duration}, + ) + + def _handle_terminating(self, slot: dict, ec2_by_slot: dict[str, dict]) -> None: + """Wait for EC2 to confirm termination, then reset slot to empty.""" + slot_id = slot["slot_id"] + instance_id = slot.get("instance_id") + + if not instance_id: + self._db.update_slot_state( + slot_id, SlotState.EMPTY, instance_id=None, instance_ip=None, lease_count=0 + ) + log.info("slot_emptied", extra={"slot_id": slot_id}) + return + + info = self._runtime.describe_instance(instance_id) + if info["state"] == "terminated": + self._db.update_slot_state( + slot_id, SlotState.EMPTY, instance_id=None, instance_ip=None, lease_count=0 + ) + log.info("slot_emptied", extra={"slot_id": slot_id}) + + def _update_metrics(self, tick_duration: float) -> None: + """Emit reconciler metrics.""" + summary = self._db.get_state_summary() + for state, count in summary["slots"].items(): + if state == "total": + continue + self._metrics.gauge("autoscaler_slots", {"state": state}, float(count)) + self._metrics.histogram_observe("autoscaler_reconciler_tick_seconds", {}, tick_duration) diff --git a/agent/nix_builder_autoscaler/runtime/ec2.py b/agent/nix_builder_autoscaler/runtime/ec2.py index 85393d8..d134c40 100644 --- a/agent/nix_builder_autoscaler/runtime/ec2.py +++ b/agent/nix_builder_autoscaler/runtime/ec2.py @@ -1,32 +1,175 @@ -"""EC2 runtime adapter — stub for Plan 02.""" +"""EC2 runtime adapter for managing Spot instances.""" from __future__ import annotations +import logging +import random +import time +from typing import Any + +import boto3 +from botocore.exceptions import ClientError + +from ..config import AwsConfig from .base import RuntimeAdapter +from .base import RuntimeError as RuntimeAdapterError + +log = logging.getLogger(__name__) + +# EC2 ClientError code → normalized error category +_ERROR_CATEGORIES: dict[str, str] = { + "InsufficientInstanceCapacity": "capacity_unavailable", + "SpotMaxPriceTooLow": "price_too_low", + "RequestLimitExceeded": "throttled", +} + +_RETRYABLE_CODES: frozenset[str] = frozenset({"RequestLimitExceeded"}) class EC2Runtime(RuntimeAdapter): """EC2 Spot instance runtime adapter. - Full implementation in Plan 02. + Args: + config: AWS configuration dataclass. + environment: Environment tag value (e.g. ``"dev"``, ``"prod"``). + _client: Optional pre-configured boto3 EC2 client (for testing). """ - def __init__(self, region: str, launch_template_id: str, **kwargs: object) -> None: - self._region = region - self._launch_template_id = launch_template_id + def __init__( + self, + config: AwsConfig, + environment: str = "dev", + *, + _client: Any = None, + ) -> None: + self._client: Any = _client or boto3.client("ec2", region_name=config.region) + self._launch_template_id = config.launch_template_id + self._subnet_ids = list(config.subnet_ids) + self._security_group_ids = list(config.security_group_ids) + self._instance_profile_arn = config.instance_profile_arn + self._environment = environment + self._subnet_index = 0 def launch_spot(self, slot_id: str, user_data: str) -> str: - """Launch a spot instance for slot_id.""" - raise NotImplementedError + """Launch a spot instance for *slot_id*. Return instance ID.""" + params: dict[str, Any] = { + "MinCount": 1, + "MaxCount": 1, + "LaunchTemplate": { + "LaunchTemplateId": self._launch_template_id, + "Version": "$Latest", + }, + "InstanceMarketOptions": { + "MarketType": "spot", + "SpotOptions": { + "SpotInstanceType": "one-time", + "InstanceInterruptionBehavior": "terminate", + }, + }, + "UserData": user_data, + "TagSpecifications": [ + { + "ResourceType": "instance", + "Tags": [ + {"Key": "Name", "Value": f"nix-builder-{slot_id}"}, + {"Key": "AutoscalerSlot", "Value": slot_id}, + {"Key": "ManagedBy", "Value": "nix-builder-autoscaler"}, + {"Key": "Service", "Value": "nix-builder"}, + {"Key": "Environment", "Value": self._environment}, + ], + } + ], + } + + if self._subnet_ids: + subnet = self._subnet_ids[self._subnet_index % len(self._subnet_ids)] + self._subnet_index += 1 + params["SubnetId"] = subnet + + resp = self._call_with_backoff(self._client.run_instances, **params) + return resp["Instances"][0]["InstanceId"] def describe_instance(self, instance_id: str) -> dict: - """Return normalized instance info.""" - raise NotImplementedError + """Return normalized instance info dict.""" + try: + resp = self._call_with_backoff( + self._client.describe_instances, InstanceIds=[instance_id] + ) + except RuntimeAdapterError: + return {"state": "terminated", "tailscale_ip": None, "launch_time": None} + + reservations = resp.get("Reservations", []) + if not reservations or not reservations[0].get("Instances"): + return {"state": "terminated", "tailscale_ip": None, "launch_time": None} + + inst = reservations[0]["Instances"][0] + launch_time = inst.get("LaunchTime") + return { + "state": inst["State"]["Name"], + "tailscale_ip": None, + "launch_time": launch_time.isoformat() if launch_time else None, + } def terminate_instance(self, instance_id: str) -> None: """Terminate the instance.""" - raise NotImplementedError + self._call_with_backoff(self._client.terminate_instances, InstanceIds=[instance_id]) def list_managed_instances(self) -> list[dict]: """Return list of managed instances.""" - raise NotImplementedError + resp = self._call_with_backoff( + self._client.describe_instances, + Filters=[ + {"Name": "tag:ManagedBy", "Values": ["nix-builder-autoscaler"]}, + { + "Name": "instance-state-name", + "Values": ["pending", "running", "shutting-down", "stopping"], + }, + ], + ) + + result: list[dict] = [] + for reservation in resp.get("Reservations", []): + for inst in reservation.get("Instances", []): + tags = inst.get("Tags", []) + result.append( + { + "instance_id": inst["InstanceId"], + "state": inst["State"]["Name"], + "slot_id": self._get_tag(tags, "AutoscalerSlot"), + } + ) + return result + + def _call_with_backoff(self, fn: Any, *args: Any, max_retries: int = 3, **kwargs: Any) -> Any: + """Call *fn* with exponential backoff and full jitter on retryable errors.""" + delay = 0.5 + for attempt in range(max_retries + 1): + try: + return fn(*args, **kwargs) + except ClientError as e: + code = e.response["Error"]["Code"] + if code in _RETRYABLE_CODES and attempt < max_retries: + jitter = random.uniform(0, min(delay, 10.0)) + time.sleep(jitter) + delay *= 2 + log.warning( + "Retryable EC2 error (attempt %d/%d): %s", + attempt + 1, + max_retries, + code, + ) + continue + category = _ERROR_CATEGORIES.get(code, "unknown") + raise RuntimeAdapterError(str(e), category=category) from e + + # Unreachable — loop always returns or raises on every path + msg = "Retries exhausted" + raise RuntimeAdapterError(msg, category="unknown") + + @staticmethod + def _get_tag(tags: list[dict[str, str]], key: str) -> str | None: + """Extract a tag value from an EC2 tag list.""" + for tag in tags: + if tag.get("Key") == key: + return tag.get("Value") + return None diff --git a/agent/nix_builder_autoscaler/scheduler.py b/agent/nix_builder_autoscaler/scheduler.py index 9ef4205..94baf61 100644 --- a/agent/nix_builder_autoscaler/scheduler.py +++ b/agent/nix_builder_autoscaler/scheduler.py @@ -1 +1,267 @@ -"""Scheduler — stub for Plan 03.""" +"""Scheduler — stateless scheduling tick for the autoscaler. + +Each tick: expire reservations, handle interruptions, assign pending +reservations to ready slots, launch new capacity, maintain warm pool +and min-slots, check idle scale-down, and emit metrics. +""" + +from __future__ import annotations + +import logging +import time +from datetime import datetime +from typing import TYPE_CHECKING + +from .bootstrap.userdata import render_userdata +from .models import SlotState +from .runtime.base import RuntimeError as RuntimeAdapterError + +if TYPE_CHECKING: + from .config import AppConfig + from .metrics import MetricsRegistry + from .providers.clock import Clock + from .runtime.base import RuntimeAdapter + from .state_db import StateDB + +log = logging.getLogger(__name__) + + +def scheduling_tick( + db: StateDB, + runtime: RuntimeAdapter, + config: AppConfig, + clock: Clock, + metrics: MetricsRegistry, +) -> None: + """Execute one scheduling tick. + + All dependencies are passed as arguments — no global state. + """ + t0 = time.monotonic() + + # 1. Expire old reservations + expired = db.expire_reservations(clock.now()) + if expired: + log.info("expired_reservations", extra={"count": len(expired), "ids": expired}) + + # 2. Handle interruption-pending slots + _handle_interruptions(db) + + # 3. Assign pending reservations to ready slots + _assign_reservations(db, config) + + # 4. Launch new capacity for unmet demand + _launch_for_unmet_demand(db, runtime, config, metrics) + + # 5. Ensure minimum slots and warm pool + _ensure_min_and_warm(db, runtime, config, metrics) + + # 6. Check scale-down for idle slots + _check_idle_scale_down(db, config, clock) + + # 7. Emit metrics + tick_duration = time.monotonic() - t0 + _update_metrics(db, metrics, tick_duration) + + +def _handle_interruptions(db: StateDB) -> None: + """Move ready slots with interruption_pending to draining.""" + ready_slots = db.list_slots(SlotState.READY) + for slot in ready_slots: + if slot["interruption_pending"]: + db.update_slot_state(slot["slot_id"], SlotState.DRAINING, interruption_pending=0) + log.info( + "interruption_drain", + extra={"slot_id": slot["slot_id"]}, + ) + + +def _assign_reservations(db: StateDB, config: AppConfig) -> None: + """Assign pending reservations to ready slots with capacity.""" + from .models import ReservationPhase + + pending = db.list_reservations(ReservationPhase.PENDING) + if not pending: + return + + ready_slots = db.list_slots(SlotState.READY) + if not ready_slots: + return + + max_leases = config.capacity.max_leases_per_slot + # Track in-memory capacity to prevent double-assignment within the same tick + capacity_map: dict[str, int] = {s["slot_id"]: s["lease_count"] for s in ready_slots} + + for resv in pending: + system = resv["system"] + slot = _find_assignable_slot(ready_slots, system, max_leases, capacity_map) + if slot is None: + continue + db.assign_reservation(resv["reservation_id"], slot["slot_id"], slot["instance_id"]) + capacity_map[slot["slot_id"]] += 1 + log.info( + "reservation_assigned", + extra={ + "reservation_id": resv["reservation_id"], + "slot_id": slot["slot_id"], + }, + ) + + +def _find_assignable_slot( + ready_slots: list[dict], + system: str, + max_leases: int, + capacity_map: dict[str, int], +) -> dict | None: + """Return first ready slot for system with remaining capacity, or None.""" + for slot in ready_slots: + if slot["system"] != system: + continue + sid = slot["slot_id"] + current: int = capacity_map[sid] if sid in capacity_map else slot["lease_count"] + if current < max_leases: + return slot + return None + + +def _count_active_slots(db: StateDB) -> int: + """Count slots NOT in empty or error states.""" + all_slots = db.list_slots() + return sum( + 1 for s in all_slots if s["state"] not in (SlotState.EMPTY.value, SlotState.ERROR.value) + ) + + +def _launch_for_unmet_demand( + db: StateDB, + runtime: RuntimeAdapter, + config: AppConfig, + metrics: MetricsRegistry, +) -> None: + """Launch new capacity for pending reservations that couldn't be assigned.""" + from .models import ReservationPhase + + pending = db.list_reservations(ReservationPhase.PENDING) + if not pending: + return + + active = _count_active_slots(db) + if active >= config.capacity.max_slots: + return + + empty_slots = db.list_slots(SlotState.EMPTY) + if not empty_slots: + return + + for launched, slot in enumerate(empty_slots): + if launched >= len(pending): + break + if active + launched >= config.capacity.max_slots: + break + _launch_slot(db, runtime, config, metrics, slot) + + +def _ensure_min_and_warm( + db: StateDB, + runtime: RuntimeAdapter, + config: AppConfig, + metrics: MetricsRegistry, +) -> None: + """Ensure minimum slots and warm pool targets are met.""" + active = _count_active_slots(db) + + # Ensure min_slots + if active < config.capacity.min_slots: + needed = config.capacity.min_slots - active + empty_slots = db.list_slots(SlotState.EMPTY) + launched = 0 + for slot in empty_slots: + if launched >= needed: + break + if active + launched >= config.capacity.max_slots: + break + _launch_slot(db, runtime, config, metrics, slot) + launched += 1 + active += launched + + # Ensure warm pool + if config.capacity.target_warm_slots > 0: + ready_idle = sum(1 for s in db.list_slots(SlotState.READY) if s["lease_count"] == 0) + pending_warm = ( + len(db.list_slots(SlotState.LAUNCHING)) + + len(db.list_slots(SlotState.BOOTING)) + + len(db.list_slots(SlotState.BINDING)) + ) + warm_total = ready_idle + pending_warm + if warm_total < config.capacity.target_warm_slots: + needed = config.capacity.target_warm_slots - warm_total + empty_slots = db.list_slots(SlotState.EMPTY) + launched = 0 + for slot in empty_slots: + if launched >= needed: + break + if active + launched >= config.capacity.max_slots: + break + _launch_slot(db, runtime, config, metrics, slot) + launched += 1 + + +def _launch_slot( + db: StateDB, + runtime: RuntimeAdapter, + config: AppConfig, + metrics: MetricsRegistry, + slot: dict, +) -> None: + """Launch a single slot. Transition to LAUNCHING on success, ERROR on failure.""" + slot_id = slot["slot_id"] + user_data = render_userdata(slot_id, config.aws.region) + metrics.counter("autoscaler_ec2_launch_total", {}, 1.0) + try: + instance_id = runtime.launch_spot(slot_id, user_data) + db.update_slot_state(slot_id, SlotState.LAUNCHING, instance_id=instance_id) + log.info("slot_launched", extra={"slot_id": slot_id, "instance_id": instance_id}) + except RuntimeAdapterError as exc: + db.update_slot_state(slot_id, SlotState.ERROR) + log.warning( + "slot_launch_failed", + extra={"slot_id": slot_id, "error": str(exc), "category": exc.category}, + ) + + +def _check_idle_scale_down(db: StateDB, config: AppConfig, clock: Clock) -> None: + """Move idle ready slots to draining when idle threshold exceeded.""" + ready_slots = db.list_slots(SlotState.READY) + now = clock.now() + active = _count_active_slots(db) + + for slot in ready_slots: + if slot["lease_count"] != 0: + continue + last_change = datetime.fromisoformat(slot["last_state_change"]) + idle_seconds = (now - last_change).total_seconds() + if idle_seconds > config.capacity.idle_scale_down_seconds: + if active <= config.capacity.min_slots: + continue + db.update_slot_state(slot["slot_id"], SlotState.DRAINING) + active -= 1 + log.info( + "idle_scale_down", + extra={"slot_id": slot["slot_id"], "idle_seconds": idle_seconds}, + ) + + +def _update_metrics(db: StateDB, metrics: MetricsRegistry, tick_duration: float) -> None: + """Refresh all gauge/counter/histogram values.""" + summary = db.get_state_summary() + + for state, count in summary["slots"].items(): + if state == "total": + continue + metrics.gauge("autoscaler_slots", {"state": state}, float(count)) + + for phase, count in summary["reservations"].items(): + metrics.gauge("autoscaler_reservations", {"phase": phase}, float(count)) + + metrics.histogram_observe("autoscaler_scheduler_tick_seconds", {}, tick_duration) diff --git a/agent/nix_builder_autoscaler/state_db.py b/agent/nix_builder_autoscaler/state_db.py index d956f75..bb11156 100644 --- a/agent/nix_builder_autoscaler/state_db.py +++ b/agent/nix_builder_autoscaler/state_db.py @@ -153,6 +153,47 @@ class StateDB: self._conn.execute("ROLLBACK") raise + def update_slot_fields(self, slot_id: str, **fields: object) -> None: + """Update specific slot columns without changing state or last_state_change. + + Uses BEGIN IMMEDIATE. Allowed fields: instance_id, instance_ip, + instance_launch_time, lease_count, cooldown_until, interruption_pending. + """ + allowed = { + "instance_id", + "instance_ip", + "instance_launch_time", + "lease_count", + "cooldown_until", + "interruption_pending", + } + if not fields: + return + + set_parts: list[str] = [] + params: list[object] = [] + for k, v in fields.items(): + if k not in allowed: + msg = f"Unknown slot field: {k}" + raise ValueError(msg) + set_parts.append(f"{k} = ?") + params.append(v) + + params.append(slot_id) + sql = f"UPDATE slots SET {', '.join(set_parts)} WHERE slot_id = ?" + + self._conn.execute("BEGIN IMMEDIATE") + try: + self._conn.execute(sql, params) + self._record_event_inner( + "slot_fields_updated", + {"slot_id": slot_id, **fields}, + ) + self._conn.execute("COMMIT") + except Exception: + self._conn.execute("ROLLBACK") + raise + # -- Reservation operations --------------------------------------------- def create_reservation( diff --git a/agent/nix_builder_autoscaler/tests/test_haproxy_provider.py b/agent/nix_builder_autoscaler/tests/test_haproxy_provider.py index 3d9e484..d2ccf6d 100644 --- a/agent/nix_builder_autoscaler/tests/test_haproxy_provider.py +++ b/agent/nix_builder_autoscaler/tests/test_haproxy_provider.py @@ -1 +1,148 @@ -"""HAProxy provider unit tests — Plan 02.""" +"""Unit tests for the HAProxy provider, mocking at socket level.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from nix_builder_autoscaler.providers.haproxy import HAProxyError, HAProxyRuntime + +# HAProxy `show stat` CSV — trimmed to columns the parser uses. +# Full output has many more columns; we keep through `status` (col 17). +SHOW_STAT_CSV = ( + "# pxname,svname,qcur,qmax,scur,smax,slim,stot," + "bin,bout,dreq,dresp,ereq,econ,eresp,wretr,wredis,status\n" + "all,BACKEND,0,0,2,5,200,100,5000,6000,0,0,,0,0,0,0,UP\n" + "all,slot001,0,0,1,3,50,50,2500,3000,0,0,,0,0,0,0,UP\n" + "all,slot002,0,0,1,2,50,50,2500,3000,0,0,,0,0,0,0,DOWN\n" + "all,slot003,0,0,0,0,50,0,0,0,0,0,,0,0,0,0,MAINT\n" +) + + +class TestSetSlotAddr: + @patch("nix_builder_autoscaler.providers.haproxy.socket.socket") + def test_sends_correct_command(self, mock_socket_cls): + mock_sock = MagicMock() + mock_socket_cls.return_value = mock_sock + mock_sock.recv.side_effect = [b"\n", b""] + + h = HAProxyRuntime("/tmp/test.sock", "all", "slot") + h.set_slot_addr("slot001", "100.64.0.1", 22) + + mock_sock.connect.assert_called_once_with("/tmp/test.sock") + mock_sock.sendall.assert_called_once_with( + b"set server all/slot001 addr 100.64.0.1 port 22\n" + ) + + +class TestEnableSlot: + @patch("nix_builder_autoscaler.providers.haproxy.socket.socket") + def test_sends_correct_command(self, mock_socket_cls): + mock_sock = MagicMock() + mock_socket_cls.return_value = mock_sock + mock_sock.recv.side_effect = [b"\n", b""] + + h = HAProxyRuntime("/tmp/test.sock", "all", "slot") + h.enable_slot("slot001") + + mock_sock.sendall.assert_called_once_with(b"enable server all/slot001\n") + + +class TestDisableSlot: + @patch("nix_builder_autoscaler.providers.haproxy.socket.socket") + def test_sends_correct_command(self, mock_socket_cls): + mock_sock = MagicMock() + mock_socket_cls.return_value = mock_sock + mock_sock.recv.side_effect = [b"\n", b""] + + h = HAProxyRuntime("/tmp/test.sock", "all", "slot") + h.disable_slot("slot001") + + mock_sock.sendall.assert_called_once_with(b"disable server all/slot001\n") + + +class TestReadSlotHealth: + @patch("nix_builder_autoscaler.providers.haproxy.socket.socket") + def test_parses_csv_correctly(self, mock_socket_cls): + mock_sock = MagicMock() + mock_socket_cls.return_value = mock_sock + mock_sock.recv.side_effect = [SHOW_STAT_CSV.encode(), b""] + + h = HAProxyRuntime("/tmp/test.sock", "all", "slot") + health = h.read_slot_health() + + assert len(health) == 3 + # BACKEND row should be excluded (svname "BACKEND" doesn't start with "slot") + + assert health["slot001"].status == "UP" + assert health["slot001"].scur == 1 + assert health["slot001"].qcur == 0 + + assert health["slot002"].status == "DOWN" + assert health["slot002"].scur == 1 + assert health["slot002"].qcur == 0 + + assert health["slot003"].status == "MAINT" + assert health["slot003"].scur == 0 + assert health["slot003"].qcur == 0 + + +class TestSlotIsUp: + @patch("nix_builder_autoscaler.providers.haproxy.socket.socket") + def test_returns_true_for_up_slot(self, mock_socket_cls): + mock_sock = MagicMock() + mock_socket_cls.return_value = mock_sock + mock_sock.recv.side_effect = [SHOW_STAT_CSV.encode(), b""] + + h = HAProxyRuntime("/tmp/test.sock", "all", "slot") + assert h.slot_is_up("slot001") is True + + @patch("nix_builder_autoscaler.providers.haproxy.socket.socket") + def test_returns_false_for_down_slot(self, mock_socket_cls): + mock_sock = MagicMock() + mock_socket_cls.return_value = mock_sock + mock_sock.recv.side_effect = [SHOW_STAT_CSV.encode(), b""] + + h = HAProxyRuntime("/tmp/test.sock", "all", "slot") + assert h.slot_is_up("slot002") is False + + +class TestErrorHandling: + @patch("nix_builder_autoscaler.providers.haproxy.socket.socket") + def test_unrecognized_slot_raises_haproxy_error(self, mock_socket_cls): + mock_sock = MagicMock() + mock_socket_cls.return_value = mock_sock + mock_sock.recv.side_effect = [b"No such server.\n", b""] + + h = HAProxyRuntime("/tmp/test.sock", "all", "slot") + with pytest.raises(HAProxyError, match="No such server"): + h.set_slot_addr("slot999", "100.64.0.1") + + @patch("nix_builder_autoscaler.providers.haproxy.socket.socket") + def test_socket_not_found_raises_haproxy_error(self, mock_socket_cls): + mock_sock = MagicMock() + mock_socket_cls.return_value = mock_sock + mock_sock.connect.side_effect = FileNotFoundError("No such file") + + h = HAProxyRuntime("/tmp/nonexistent.sock", "all", "slot") + with pytest.raises(HAProxyError, match="socket not found"): + h.set_slot_addr("slot001", "100.64.0.1") + + @patch("nix_builder_autoscaler.providers.haproxy.socket.socket") + def test_connection_refused_raises_haproxy_error(self, mock_socket_cls): + mock_sock = MagicMock() + mock_socket_cls.return_value = mock_sock + mock_sock.connect.side_effect = ConnectionRefusedError("Connection refused") + + h = HAProxyRuntime("/tmp/test.sock", "all", "slot") + with pytest.raises(HAProxyError, match="Connection refused"): + h.enable_slot("slot001") + + @patch("nix_builder_autoscaler.providers.haproxy.socket.socket") + def test_slot_session_count_missing_slot_raises(self, mock_socket_cls): + mock_sock = MagicMock() + mock_socket_cls.return_value = mock_sock + mock_sock.recv.side_effect = [SHOW_STAT_CSV.encode(), b""] + + h = HAProxyRuntime("/tmp/test.sock", "all", "slot") + with pytest.raises(HAProxyError, match="Slot not found"): + h.slot_session_count("slot_nonexistent") diff --git a/agent/nix_builder_autoscaler/tests/test_runtime_ec2.py b/agent/nix_builder_autoscaler/tests/test_runtime_ec2.py index be79379..40c35b3 100644 --- a/agent/nix_builder_autoscaler/tests/test_runtime_ec2.py +++ b/agent/nix_builder_autoscaler/tests/test_runtime_ec2.py @@ -1 +1,286 @@ -"""EC2 runtime unit tests — Plan 02.""" +"""Unit tests for the EC2 runtime adapter using botocore Stubber.""" + +from datetime import UTC, datetime +from unittest.mock import patch + +import boto3 +import pytest +from botocore.stub import Stubber + +from nix_builder_autoscaler.config import AwsConfig +from nix_builder_autoscaler.runtime.base import RuntimeError as RuntimeAdapterError +from nix_builder_autoscaler.runtime.ec2 import EC2Runtime + + +def _make_config(): + return AwsConfig( + region="us-east-1", + launch_template_id="lt-abc123", + subnet_ids=["subnet-aaa", "subnet-bbb"], + security_group_ids=["sg-111"], + instance_profile_arn="arn:aws:iam::123456789012:instance-profile/nix-builder", + ) + + +def _make_runtime(stubber, ec2_client, **kwargs): + config = kwargs.pop("config", _make_config()) + environment = kwargs.pop("environment", "dev") + stubber.activate() + return EC2Runtime(config, environment=environment, _client=ec2_client) + + +class TestLaunchSpot: + def test_correct_params_and_returns_instance_id(self): + config = _make_config() + ec2_client = boto3.client("ec2", region_name="us-east-1") + stubber = Stubber(ec2_client) + + expected_params = { + "MinCount": 1, + "MaxCount": 1, + "LaunchTemplate": { + "LaunchTemplateId": "lt-abc123", + "Version": "$Latest", + }, + "InstanceMarketOptions": { + "MarketType": "spot", + "SpotOptions": { + "SpotInstanceType": "one-time", + "InstanceInterruptionBehavior": "terminate", + }, + }, + "SubnetId": "subnet-aaa", + "UserData": "#!/bin/bash\necho hello", + "TagSpecifications": [ + { + "ResourceType": "instance", + "Tags": [ + {"Key": "Name", "Value": "nix-builder-slot001"}, + {"Key": "AutoscalerSlot", "Value": "slot001"}, + {"Key": "ManagedBy", "Value": "nix-builder-autoscaler"}, + {"Key": "Service", "Value": "nix-builder"}, + {"Key": "Environment", "Value": "dev"}, + ], + } + ], + } + + response = { + "Instances": [{"InstanceId": "i-12345678"}], + "OwnerId": "123456789012", + } + + stubber.add_response("run_instances", response, expected_params) + runtime = _make_runtime(stubber, ec2_client, config=config) + + iid = runtime.launch_spot("slot001", "#!/bin/bash\necho hello") + assert iid == "i-12345678" + stubber.assert_no_pending_responses() + + def test_round_robin_subnets(self): + config = _make_config() + ec2_client = boto3.client("ec2", region_name="us-east-1") + stubber = Stubber(ec2_client) + + # Two launches should use subnet-aaa then subnet-bbb + for _subnet in ["subnet-aaa", "subnet-bbb"]: + stubber.add_response( + "run_instances", + {"Instances": [{"InstanceId": "i-abc"}], "OwnerId": "123"}, + ) + + runtime = _make_runtime(stubber, ec2_client, config=config) + runtime.launch_spot("slot001", "") + runtime.launch_spot("slot002", "") + stubber.assert_no_pending_responses() + + +class TestDescribeInstance: + def test_normalizes_response(self): + ec2_client = boto3.client("ec2", region_name="us-east-1") + stubber = Stubber(ec2_client) + + launch_time = datetime(2026, 1, 15, 12, 30, 0, tzinfo=UTC) + response = { + "Reservations": [ + { + "Instances": [ + { + "InstanceId": "i-running1", + "State": {"Code": 16, "Name": "running"}, + "LaunchTime": launch_time, + "Tags": [ + {"Key": "AutoscalerSlot", "Value": "slot001"}, + ], + } + ], + } + ], + } + + stubber.add_response( + "describe_instances", + response, + {"InstanceIds": ["i-running1"]}, + ) + runtime = _make_runtime(stubber, ec2_client) + + info = runtime.describe_instance("i-running1") + assert info["state"] == "running" + assert info["tailscale_ip"] is None + assert info["launch_time"] == launch_time.isoformat() + + def test_missing_instance_returns_terminated(self): + ec2_client = boto3.client("ec2", region_name="us-east-1") + stubber = Stubber(ec2_client) + + stubber.add_response( + "describe_instances", + {"Reservations": []}, + {"InstanceIds": ["i-gone"]}, + ) + runtime = _make_runtime(stubber, ec2_client) + + info = runtime.describe_instance("i-gone") + assert info["state"] == "terminated" + assert info["tailscale_ip"] is None + assert info["launch_time"] is None + + +class TestListManagedInstances: + def test_filters_by_tag(self): + ec2_client = boto3.client("ec2", region_name="us-east-1") + stubber = Stubber(ec2_client) + + expected_params = { + "Filters": [ + {"Name": "tag:ManagedBy", "Values": ["nix-builder-autoscaler"]}, + { + "Name": "instance-state-name", + "Values": ["pending", "running", "shutting-down", "stopping"], + }, + ], + } + + response = { + "Reservations": [ + { + "Instances": [ + { + "InstanceId": "i-aaa", + "State": {"Code": 16, "Name": "running"}, + "Tags": [ + {"Key": "AutoscalerSlot", "Value": "slot001"}, + {"Key": "ManagedBy", "Value": "nix-builder-autoscaler"}, + ], + }, + { + "InstanceId": "i-bbb", + "State": {"Code": 0, "Name": "pending"}, + "Tags": [ + {"Key": "AutoscalerSlot", "Value": "slot002"}, + {"Key": "ManagedBy", "Value": "nix-builder-autoscaler"}, + ], + }, + ], + } + ], + } + + stubber.add_response("describe_instances", response, expected_params) + runtime = _make_runtime(stubber, ec2_client) + + managed = runtime.list_managed_instances() + assert len(managed) == 2 + assert managed[0]["instance_id"] == "i-aaa" + assert managed[0]["state"] == "running" + assert managed[0]["slot_id"] == "slot001" + assert managed[1]["instance_id"] == "i-bbb" + assert managed[1]["state"] == "pending" + assert managed[1]["slot_id"] == "slot002" + + +class TestTerminateInstance: + def test_calls_terminate_api(self): + ec2_client = boto3.client("ec2", region_name="us-east-1") + stubber = Stubber(ec2_client) + + response = { + "TerminatingInstances": [ + { + "InstanceId": "i-kill", + "CurrentState": {"Code": 32, "Name": "shutting-down"}, + "PreviousState": {"Code": 16, "Name": "running"}, + } + ], + } + + stubber.add_response( + "terminate_instances", + response, + {"InstanceIds": ["i-kill"]}, + ) + runtime = _make_runtime(stubber, ec2_client) + + # Should not raise + runtime.terminate_instance("i-kill") + stubber.assert_no_pending_responses() + + +class TestErrorClassification: + def test_insufficient_capacity_classified(self): + ec2_client = boto3.client("ec2", region_name="us-east-1") + stubber = Stubber(ec2_client) + + stubber.add_client_error( + "run_instances", + service_error_code="InsufficientInstanceCapacity", + service_message="Insufficient capacity", + ) + runtime = _make_runtime(stubber, ec2_client) + + with pytest.raises(RuntimeAdapterError) as exc_info: + runtime.launch_spot("slot001", "#!/bin/bash") + assert exc_info.value.category == "capacity_unavailable" + + @patch("nix_builder_autoscaler.runtime.ec2.time.sleep") + def test_request_limit_exceeded_retried(self, mock_sleep): + ec2_client = boto3.client("ec2", region_name="us-east-1") + stubber = Stubber(ec2_client) + + # First call: throttled + stubber.add_client_error( + "run_instances", + service_error_code="RequestLimitExceeded", + service_message="Rate exceeded", + ) + # Second call: success + stubber.add_response( + "run_instances", + {"Instances": [{"InstanceId": "i-retry123"}], "OwnerId": "123"}, + ) + runtime = _make_runtime(stubber, ec2_client) + + iid = runtime.launch_spot("slot001", "#!/bin/bash") + assert iid == "i-retry123" + assert mock_sleep.called + stubber.assert_no_pending_responses() + + @patch("nix_builder_autoscaler.runtime.ec2.time.sleep") + def test_request_limit_exceeded_exhausted(self, mock_sleep): + """After max retries, RequestLimitExceeded raises with 'throttled' category.""" + ec2_client = boto3.client("ec2", region_name="us-east-1") + stubber = Stubber(ec2_client) + + # 4 errors (max_retries=3: attempt 0,1,2,3 all fail) + for _ in range(4): + stubber.add_client_error( + "run_instances", + service_error_code="RequestLimitExceeded", + service_message="Rate exceeded", + ) + runtime = _make_runtime(stubber, ec2_client) + + with pytest.raises(RuntimeAdapterError) as exc_info: + runtime.launch_spot("slot001", "#!/bin/bash") + assert exc_info.value.category == "throttled" diff --git a/agent/nix_builder_autoscaler/tests/test_scheduler.py b/agent/nix_builder_autoscaler/tests/test_scheduler.py index ceb8d09..c6c9da2 100644 --- a/agent/nix_builder_autoscaler/tests/test_scheduler.py +++ b/agent/nix_builder_autoscaler/tests/test_scheduler.py @@ -1 +1,194 @@ """Scheduler unit tests — Plan 03.""" + +from nix_builder_autoscaler.config import AppConfig, AwsConfig, CapacityConfig +from nix_builder_autoscaler.metrics import MetricsRegistry +from nix_builder_autoscaler.models import ReservationPhase, SlotState +from nix_builder_autoscaler.providers.clock import FakeClock +from nix_builder_autoscaler.runtime.fake import FakeRuntime +from nix_builder_autoscaler.scheduler import scheduling_tick +from nix_builder_autoscaler.state_db import StateDB + + +def _make_env( + slot_count=3, + max_slots=3, + max_leases=1, + idle_scale_down_seconds=900, + target_warm=0, + min_slots=0, +): + clock = FakeClock() + db = StateDB(":memory:", clock=clock) + db.init_schema() + db.init_slots("slot", slot_count, "x86_64-linux", "all") + runtime = FakeRuntime(launch_latency_ticks=2, ip_delay_ticks=1) + config = AppConfig( + capacity=CapacityConfig( + max_slots=max_slots, + max_leases_per_slot=max_leases, + idle_scale_down_seconds=idle_scale_down_seconds, + target_warm_slots=target_warm, + min_slots=min_slots, + reservation_ttl_seconds=1200, + ), + aws=AwsConfig(region="us-east-1"), + ) + metrics = MetricsRegistry() + return db, runtime, config, clock, metrics + + +def _make_slot_ready(db, slot_id, instance_id="i-test1", ip="100.64.0.1"): + """Transition a slot through the full state machine to ready.""" + db.update_slot_state(slot_id, SlotState.LAUNCHING, instance_id=instance_id) + db.update_slot_state(slot_id, SlotState.BOOTING) + db.update_slot_state(slot_id, SlotState.BINDING, instance_ip=ip) + db.update_slot_state(slot_id, SlotState.READY) + + +# --- Test cases --- + + +def test_pending_reservation_assigned_to_ready_slot(): + db, runtime, config, clock, metrics = _make_env() + _make_slot_ready(db, "slot001") + + resv = db.create_reservation("x86_64-linux", "test", None, 1200) + + scheduling_tick(db, runtime, config, clock, metrics) + + updated = db.get_reservation(resv["reservation_id"]) + assert updated["phase"] == ReservationPhase.READY.value + assert updated["slot_id"] == "slot001" + assert updated["instance_id"] == "i-test1" + + slot = db.get_slot("slot001") + assert slot["lease_count"] == 1 + + +def test_two_pending_one_slot_only_one_assigned_per_tick(): + db, runtime, config, clock, metrics = _make_env(max_leases=1) + _make_slot_ready(db, "slot001") + + r1 = db.create_reservation("x86_64-linux", "test1", None, 1200) + r2 = db.create_reservation("x86_64-linux", "test2", None, 1200) + + scheduling_tick(db, runtime, config, clock, metrics) + + u1 = db.get_reservation(r1["reservation_id"]) + u2 = db.get_reservation(r2["reservation_id"]) + + ready_count = sum(1 for r in [u1, u2] if r["phase"] == ReservationPhase.READY.value) + pending_count = sum(1 for r in [u1, u2] if r["phase"] == ReservationPhase.PENDING.value) + assert ready_count == 1 + assert pending_count == 1 + + slot = db.get_slot("slot001") + assert slot["lease_count"] == 1 + + +def test_reservation_expires_when_ttl_passes(): + db, runtime, config, clock, metrics = _make_env() + config.capacity.reservation_ttl_seconds = 60 + + db.create_reservation("x86_64-linux", "test", None, 60) + + clock.advance(61) + scheduling_tick(db, runtime, config, clock, metrics) + + reservations = db.list_reservations(ReservationPhase.EXPIRED) + assert len(reservations) == 1 + + +def test_scale_down_starts_when_idle_exceeds_threshold(): + db, runtime, config, clock, metrics = _make_env(idle_scale_down_seconds=900) + _make_slot_ready(db, "slot001") + + clock.advance(901) + scheduling_tick(db, runtime, config, clock, metrics) + + slot = db.get_slot("slot001") + assert slot["state"] == SlotState.DRAINING.value + + +def test_slot_does_not_drain_while_lease_count_positive(): + db, runtime, config, clock, metrics = _make_env(idle_scale_down_seconds=900) + _make_slot_ready(db, "slot001") + + resv = db.create_reservation("x86_64-linux", "test", None, 1200) + scheduling_tick(db, runtime, config, clock, metrics) + + # Confirm assigned + updated = db.get_reservation(resv["reservation_id"]) + assert updated["phase"] == ReservationPhase.READY.value + + clock.advance(901) + scheduling_tick(db, runtime, config, clock, metrics) + + slot = db.get_slot("slot001") + assert slot["state"] == SlotState.READY.value + + +def test_interruption_pending_slot_moves_to_draining(): + db, runtime, config, clock, metrics = _make_env() + _make_slot_ready(db, "slot001") + + db.update_slot_fields("slot001", interruption_pending=1) + + scheduling_tick(db, runtime, config, clock, metrics) + + slot = db.get_slot("slot001") + assert slot["state"] == SlotState.DRAINING.value + assert slot["interruption_pending"] == 0 + + +def test_launch_triggered_for_unmet_demand(): + db, runtime, config, clock, metrics = _make_env() + + db.create_reservation("x86_64-linux", "test", None, 1200) + + scheduling_tick(db, runtime, config, clock, metrics) + + launching = db.list_slots(SlotState.LAUNCHING) + assert len(launching) == 1 + assert launching[0]["instance_id"] is not None + + # FakeRuntime should have one pending instance + managed = runtime.list_managed_instances() + assert len(managed) == 1 + + +def test_launch_respects_max_slots(): + db, runtime, config, clock, metrics = _make_env(max_slots=1) + _make_slot_ready(db, "slot001") + + # Slot001 is at capacity (lease_count will be 1 after assignment) + db.create_reservation("x86_64-linux", "test1", None, 1200) + db.create_reservation("x86_64-linux", "test2", None, 1200) + + scheduling_tick(db, runtime, config, clock, metrics) + + # One reservation assigned, one still pending — but no new launch + # because active_slots (1) == max_slots (1) + launching = db.list_slots(SlotState.LAUNCHING) + assert len(launching) == 0 + + +def test_min_slots_maintained(): + db, runtime, config, clock, metrics = _make_env(min_slots=1) + + # No reservations, all slots empty + scheduling_tick(db, runtime, config, clock, metrics) + + launching = db.list_slots(SlotState.LAUNCHING) + assert len(launching) == 1 + + +def test_scale_down_respects_min_slots(): + db, runtime, config, clock, metrics = _make_env(min_slots=1, idle_scale_down_seconds=900) + _make_slot_ready(db, "slot001") + + clock.advance(901) + scheduling_tick(db, runtime, config, clock, metrics) + + slot = db.get_slot("slot001") + assert slot["state"] == SlotState.READY.value diff --git a/agent/pyproject.toml b/agent/pyproject.toml index 1087554..d2af89e 100644 --- a/agent/pyproject.toml +++ b/agent/pyproject.toml @@ -31,7 +31,7 @@ line-length = 100 [tool.ruff.lint] select = ["E", "F", "I", "UP", "B", "SIM", "ANN"] -ignore = [] +ignore = ["ANN401"] [tool.ruff.lint.per-file-ignores] "*/tests/*" = ["ANN"]