WIP autoscaler agent

This commit is contained in:
Abel Luck 2026-02-27 11:59:16 +01:00
parent c610a3e284
commit 28059dcedf
34 changed files with 2409 additions and 35 deletions

View file

@ -0,0 +1,3 @@
"""Nix builder autoscaler daemon."""
__version__ = "0.1.0"

View file

@ -0,0 +1,24 @@
"""Daemon entry point: python -m nix_builder_autoscaler — stub for Plan 04."""
from __future__ import annotations
import argparse
import sys
def main() -> None:
    """Parse CLI arguments and start the daemon (stub until Plan 04)."""
    arg_parser = argparse.ArgumentParser(
        prog="nix-builder-autoscaler",
        description="Nix builder autoscaler daemon",
    )
    arg_parser.add_argument("--config", required=True, help="Path to TOML config file")
    opts = arg_parser.parse_args()
    # Stub behavior: report what would happen, then exit successfully.
    print(f"nix-builder-autoscaler: would start with config {opts.config}")
    print("Full daemon implementation in Plan 04.")
    sys.exit(0)
if __name__ == "__main__":
    main()

View file

@ -0,0 +1,19 @@
"""FastAPI application — stub for Plan 04."""
from __future__ import annotations
from fastapi import FastAPI
def create_app() -> FastAPI:
    """Build and return the FastAPI application.

    Full implementation in Plan 04.
    """
    application = FastAPI(title="nix-builder-autoscaler", version="0.1.0")

    @application.get("/health/live")
    def health_live() -> dict[str, str]:
        # Liveness probe: reports ok whenever the process is serving requests.
        return {"status": "ok"}

    return application

View file

@ -0,0 +1,11 @@
"""EC2 user-data template rendering — stub for Plan 02."""
from __future__ import annotations
def render_userdata(slot_id: str, region: str, ssm_param: str = "/nix-builder/ts-authkey") -> str:
    """Render a bash user-data script for builder instance bootstrap.

    Args:
        slot_id: Slot identifier the new instance will serve.
        region: AWS region the instance is launched in.
        ssm_param: SSM parameter path holding the Tailscale auth key
            (default matches the documented env/SSM layout).

    Raises:
        NotImplementedError: always — full implementation in Plan 02.
    """
    raise NotImplementedError

View file

@ -0,0 +1,26 @@
"""autoscalerctl CLI entry point — stub for Plan 04."""
from __future__ import annotations
import argparse
import sys
def main() -> None:
    """Entry point for the autoscalerctl CLI."""
    cli = argparse.ArgumentParser(prog="autoscalerctl", description="Autoscaler CLI")
    cli.add_argument(
        "--socket", default="/run/nix-builder-autoscaler/daemon.sock", help="Daemon socket path"
    )
    commands = cli.add_subparsers(dest="command")
    for cmd_name, cmd_help in (
        ("status", "Show daemon status"),
        ("slots", "List slots"),
        ("reservations", "List reservations"),
    ):
        commands.add_parser(cmd_name, help=cmd_help)
    parsed = cli.parse_args()
    if not parsed.command:
        cli.print_help()
        sys.exit(1)
    # Stub: every subcommand reports unimplemented and exits non-zero.
    print(f"autoscalerctl: command '{parsed.command}' not yet implemented")
    sys.exit(1)

View file

@ -0,0 +1,155 @@
"""Configuration loading from TOML with environment variable overrides."""
from __future__ import annotations
import os
import tomllib
from dataclasses import dataclass, field
from pathlib import Path
@dataclass
class ServerConfig:
    """[server] section."""
    # Unix socket path the daemon control API listens on.
    socket_path: str = "/run/nix-builder-autoscaler/daemon.sock"
    # Root log level name (e.g. "info", "debug") — see logging setup.
    log_level: str = "info"
    # SQLite state database location (see StateDB).
    db_path: str = "/var/lib/nix-builder-autoscaler/state.db"
@dataclass
class AwsConfig:
    """[aws] section."""
    # Region to launch builder instances in; AWS_REGION env var overrides this.
    region: str = "us-east-1"
    # Launch template used for spot instances (required in production; no default).
    launch_template_id: str = ""
    # Candidate subnets / security groups for launched instances.
    subnet_ids: list[str] = field(default_factory=list)
    security_group_ids: list[str] = field(default_factory=list)
    # IAM instance profile attached to builder instances.
    instance_profile_arn: str = ""
@dataclass
class HaproxyConfig:
    """[haproxy] section."""
    # HAProxy runtime API (admin) socket path.
    runtime_socket: str = "/run/haproxy/admin.sock"
    # Backend whose server slots are managed.
    backend: str = "all"
    # Server slot naming: f"{slot_prefix}NNN" (zero-padded, see StateDB.init_slots).
    slot_prefix: str = "slot"
    slot_count: int = 8
    # Consecutive UP health checks before a slot counts as ready —
    # NOTE(review): exact semantics land with the HAProxy adapter in Plan 02.
    check_ready_up_count: int = 2
@dataclass
class SystemConfig:
    """[[systems]] entry for per-architecture capacity policy."""
    # Nix system identifier, e.g. "x86_64-linux".
    name: str = "x86_64-linux"
    # Capacity bounds for this architecture.
    min_slots: int = 0
    max_slots: int = 8
    # Warm-standby target kept ready ahead of demand.
    target_warm_slots: int = 0
    max_leases_per_slot: int = 1
    # How many instances may be launched per scaling decision.
    launch_batch_size: int = 1
    # Idle time before a slot is eligible for scale-down.
    scale_down_idle_seconds: int = 900
@dataclass
class CapacityConfig:
    """[capacity] section — global defaults."""
    # System assumed when a request does not specify one.
    default_system: str = "x86_64-linux"
    min_slots: int = 0
    max_slots: int = 8
    target_warm_slots: int = 0
    max_leases_per_slot: int = 1
    # TTL applied to new reservations (seconds); expired ones free their lease.
    reservation_ttl_seconds: int = 1200
    idle_scale_down_seconds: int = 900
    # How long a draining slot may linger before termination.
    drain_timeout_seconds: int = 120
@dataclass
class SecurityConfig:
    """[security] section."""
    # Permissions applied to the daemon's Unix socket (octal string).
    socket_mode: str = "0660"
    # Owner/group of the socket so the CI user can connect.
    socket_owner: str = "buildbot"
    socket_group: str = "buildbot"
@dataclass
class SchedulerConfig:
    """[scheduler] section."""
    # Interval between scheduler decision ticks (seconds).
    tick_seconds: float = 3.0
    # Interval between full reconciliation passes against EC2/HAProxy (seconds).
    reconcile_seconds: float = 15.0
@dataclass
class AppConfig:
    """Top-level application configuration.

    One attribute per TOML table; every section falls back to its dataclass
    defaults when absent from the file.
    """
    server: ServerConfig = field(default_factory=ServerConfig)
    aws: AwsConfig = field(default_factory=AwsConfig)
    haproxy: HaproxyConfig = field(default_factory=HaproxyConfig)
    capacity: CapacityConfig = field(default_factory=CapacityConfig)
    security: SecurityConfig = field(default_factory=SecurityConfig)
    scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)
    # Per-architecture overrides from [[systems]] array-of-tables.
    systems: list[SystemConfig] = field(default_factory=list)
# ---------------------------------------------------------------------------
# Environment variable overrides
# ---------------------------------------------------------------------------
# AUTOSCALER_TAILSCALE_API_TOKEN — Tailscale API token for IP discovery
# AWS_REGION — override aws.region
# AWS_ACCESS_KEY_ID — explicit AWS credential
# AWS_SECRET_ACCESS_KEY — explicit AWS credential
def _apply_env_overrides(cfg: AppConfig) -> None:
"""Apply environment variable overrides for secrets and region."""
region = os.environ.get("AWS_REGION")
if region:
cfg.aws.region = region
def _build_dataclass(cls: type, data: dict) -> object: # noqa: ANN001
"""Construct a dataclass from a dict, ignoring unknown keys."""
valid = {f.name for f in cls.__dataclass_fields__.values()} # type: ignore[attr-defined]
return cls(**{k: v for k, v in data.items() if k in valid})
def load_config(path: Path) -> AppConfig:
    """Load configuration from a TOML file.

    Missing sections keep their dataclass defaults; unknown keys inside a
    section are ignored. Environment overrides are applied last.

    Args:
        path: Path to the TOML config file.

    Returns:
        Validated AppConfig instance.
    """
    with open(path, "rb") as f:
        raw = tomllib.load(f)
    cfg = AppConfig()
    # TOML table name -> (AppConfig attribute, section dataclass).
    sections: dict[str, tuple[str, type]] = {
        "server": ("server", ServerConfig),
        "aws": ("aws", AwsConfig),
        "haproxy": ("haproxy", HaproxyConfig),
        "capacity": ("capacity", CapacityConfig),
        "security": ("security", SecurityConfig),
        "scheduler": ("scheduler", SchedulerConfig),
    }
    for table, (attr, section_cls) in sections.items():
        if table in raw:
            setattr(cfg, attr, _build_dataclass(section_cls, raw[table]))
    if "systems" in raw:
        # [[systems]] is an array of tables, one entry per architecture.
        cfg.systems = [
            _build_dataclass(SystemConfig, s)  # type: ignore[misc]
            for s in raw["systems"]
        ]
    _apply_env_overrides(cfg)
    return cfg

View file

@ -0,0 +1,46 @@
"""Structured JSON logging setup."""
from __future__ import annotations
import json
import logging
import sys
from datetime import UTC, datetime
from typing import Any
class JSONFormatter(logging.Formatter):
"""Format log records as single-line JSON."""
EXTRA_FIELDS = ("slot_id", "reservation_id", "instance_id", "request_id")
def format(self, record: logging.LogRecord) -> str:
"""Format a log record as JSON."""
entry: dict[str, Any] = {
"ts": datetime.now(UTC).isoformat(),
"level": record.levelname,
"logger": record.name,
"message": record.getMessage(),
}
for field in self.EXTRA_FIELDS:
val = getattr(record, field, None)
if val is not None:
entry[field] = val
if record.exc_info and record.exc_info[1] is not None:
entry["exception"] = self.formatException(record.exc_info)
return json.dumps(entry, default=str)
def setup_logging(level: str = "INFO") -> None:
    """Configure the root logger with JSON output to stderr.

    Args:
        level: Log level name (DEBUG, INFO, WARNING, ERROR).
    """
    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setFormatter(JSONFormatter())
    root_logger = logging.getLogger()
    # Drop any pre-existing handlers so every record is emitted exactly once.
    root_logger.handlers.clear()
    root_logger.addHandler(stderr_handler)
    # Unrecognized level names fall back to INFO rather than raising.
    root_logger.setLevel(getattr(logging, level.upper(), logging.INFO))

View file

@ -0,0 +1,103 @@
"""In-memory Prometheus metrics registry.
No prometheus_client dependency; formats text manually.
"""
from __future__ import annotations
import threading
from typing import Any
def _labels_key(labels: dict[str, str]) -> tuple[tuple[str, str], ...]:
return tuple(sorted(labels.items()))
def _format_labels(labels: dict[str, str]) -> str:
if not labels:
return ""
parts = ",".join(f'{k}="{v}"' for k, v in sorted(labels.items()))
return "{" + parts + "}"
class MetricsRegistry:
    """Thread-safe in-memory metrics store with Prometheus text output."""

    def __init__(self) -> None:
        # A single lock guards all three families; every operation is short.
        self._lock = threading.Lock()
        # Metric name -> {canonical sorted label tuple -> value}.
        self._gauges: dict[str, dict[tuple[tuple[str, str], ...], float]] = {}
        self._counters: dict[str, dict[tuple[tuple[str, str], ...], float]] = {}
        # Metric name -> {label tuple -> {"labels", "buckets", "sum", "count"}}.
        self._histograms: dict[str, dict[tuple[tuple[str, str], ...], Any]] = {}

    def gauge(self, name: str, labels: dict[str, str], value: float) -> None:
        """Set a gauge value."""
        key = _labels_key(labels)
        with self._lock:
            self._gauges.setdefault(name, {})[key] = value

    def counter(self, name: str, labels: dict[str, str], increment: float = 1.0) -> None:
        """Increment a counter by `increment` (default 1.0)."""
        key = _labels_key(labels)
        with self._lock:
            series = self._counters.setdefault(name, {})
            series[key] = series.get(key, 0.0) + increment

    def histogram_observe(self, name: str, labels: dict[str, str], value: float) -> None:
        """Record a histogram observation.

        Uses fixed buckets: 0.01, 0.05, 0.1, 0.5, 1, 5, 10, 30, 60, 120, +Inf.
        """
        key = _labels_key(labels)
        buckets = (0.01, 0.05, 0.1, 0.5, 1.0, 5.0, 10.0, 30.0, 60.0, 120.0)
        with self._lock:
            if name not in self._histograms:
                self._histograms[name] = {}
            if key not in self._histograms[name]:
                self._histograms[name][key] = {
                    "labels": labels,
                    "buckets": {b: 0 for b in buckets},
                    "sum": 0.0,
                    "count": 0,
                }
            entry = self._histograms[name][key]
            entry["sum"] += value
            entry["count"] += 1
            # Store per-bucket (NON-cumulative) counts: only the first bucket
            # large enough to hold the value is incremented; render() builds
            # the cumulative le-series. Incrementing every bucket >= value
            # here would double-count once render() accumulates.
            for b in buckets:
                if value <= b:
                    entry["buckets"][b] += 1
                    break
            # Values above the largest bucket only appear in +Inf via "count".

    def render(self) -> str:
        """Render all metrics in Prometheus text exposition format."""
        lines: list[str] = []
        with self._lock:
            for name, series in sorted(self._gauges.items()):
                lines.append(f"# TYPE {name} gauge")
                for key, val in sorted(series.items()):
                    labels = dict(key)
                    lines.append(f"{name}{_format_labels(labels)} {val}")
            for name, series in sorted(self._counters.items()):
                lines.append(f"# TYPE {name} counter")
                for key, val in sorted(series.items()):
                    labels = dict(key)
                    lines.append(f"{name}{_format_labels(labels)} {val}")
            for name, series in sorted(self._histograms.items()):
                lines.append(f"# TYPE {name} histogram")
                for _key, entry in sorted(series.items()):
                    labels = entry["labels"]
                    # Buckets are cumulative in the exposition format.
                    cumulative = 0
                    for b, count in sorted(entry["buckets"].items()):
                        cumulative += count
                        le_labels = {**labels, "le": str(b)}
                        lines.append(f"{name}_bucket{_format_labels(le_labels)} {cumulative}")
                    inf_labels = {**labels, "le": "+Inf"}
                    lines.append(f"{name}_bucket{_format_labels(inf_labels)} {entry['count']}")
                    lines.append(f"{name}_sum{_format_labels(labels)} {entry['sum']}")
                    lines.append(f"{name}_count{_format_labels(labels)} {entry['count']}")
        lines.append("")
        return "\n".join(lines)

View file

@ -0,0 +1,153 @@
"""Data models for the autoscaler daemon."""
from __future__ import annotations
from datetime import datetime
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, Field
class SlotState(StrEnum):
    """Exhaustive slot states.

    Lifecycle suggested by the names (authoritative transitions land with the
    scheduler in Plan 03): empty -> launching -> booting -> binding -> ready
    -> draining -> terminating, with error as the failure state.
    """
    EMPTY = "empty"
    LAUNCHING = "launching"
    BOOTING = "booting"
    BINDING = "binding"
    READY = "ready"
    DRAINING = "draining"
    TERMINATING = "terminating"
    ERROR = "error"
class ReservationPhase(StrEnum):
    """Exhaustive reservation phases.

    pending -> ready on slot assignment; released (explicit) and expired
    (TTL passed) are terminal — see StateDB's reservation operations.
    """
    PENDING = "pending"
    READY = "ready"
    FAILED = "failed"
    RELEASED = "released"
    EXPIRED = "expired"
# ---------------------------------------------------------------------------
# API request models
# ---------------------------------------------------------------------------
class ReservationRequest(BaseModel):
    """POST /v1/reservations request body."""
    # Target Nix system, e.g. "x86_64-linux".
    system: str
    # Free-text reason, persisted with the reservation for auditing.
    reason: str
    # Optional CI build identifier — presumably a Buildbot build id (verify).
    build_id: int | None = None
class CapacityHint(BaseModel):
    """POST /v1/hints/capacity request body."""
    # Identifier of the reporting builder.
    builder: str
    # Queue depth and in-flight count for `system` as of `timestamp`.
    queued: int
    running: int
    system: str
    timestamp: datetime
# ---------------------------------------------------------------------------
# API response models
# ---------------------------------------------------------------------------
class ReservationResponse(BaseModel):
    """Reservation representation returned by the API."""
    reservation_id: str
    phase: ReservationPhase
    # Slot / instance are None until the reservation is assigned (phase ready).
    slot: str | None = None
    instance_id: str | None = None
    system: str
    created_at: datetime
    updated_at: datetime
    # TTL deadline; past this the reservation is expired by the daemon.
    expires_at: datetime
    released_at: datetime | None = None
class SlotInfo(BaseModel):
    """Slot representation returned by the API."""
    slot_id: str
    system: str
    state: SlotState
    # None while no instance is bound to the slot.
    instance_id: str | None = None
    instance_ip: str | None = None
    # Number of active reservations leased against this slot.
    lease_count: int
    last_state_change: datetime
class SlotsSummary(BaseModel):
    """Aggregate slot counts by state (one counter per SlotState value)."""
    total: int = 0
    ready: int = 0
    launching: int = 0
    booting: int = 0
    binding: int = 0
    draining: int = 0
    terminating: int = 0
    empty: int = 0
    error: int = 0
class ReservationsSummary(BaseModel):
    """Aggregate reservation counts by phase.

    Active phases only — released/expired reservations are intentionally
    omitted (matches StateDB.get_state_summary).
    """
    pending: int = 0
    ready: int = 0
    failed: int = 0
class Ec2Summary(BaseModel):
    """EC2 subsystem health."""
    # False when the last EC2 API interaction failed.
    api_ok: bool = True
    last_reconcile_at: datetime | None = None
class HaproxySummary(BaseModel):
    """HAProxy subsystem health."""
    # False when the runtime socket was unreachable on the last attempt.
    socket_ok: bool = True
    last_stat_poll_at: datetime | None = None
class StateSummary(BaseModel):
    """GET /v1/state/summary response."""
    slots: SlotsSummary = Field(default_factory=SlotsSummary)
    reservations: ReservationsSummary = Field(default_factory=ReservationsSummary)
    ec2: Ec2Summary = Field(default_factory=Ec2Summary)
    haproxy: HaproxySummary = Field(default_factory=HaproxySummary)
class ErrorDetail(BaseModel):
    """Structured error detail."""
    # Stable machine-readable error code; message is for humans.
    code: str
    message: str
    # Hint for clients: True when retrying the same request may succeed.
    retryable: bool = False
    details: dict[str, Any] | None = None
class ErrorResponse(BaseModel):
    """Standard error response envelope."""
    error: ErrorDetail
    # Correlates the failed request with server-side log entries.
    request_id: str
class HealthResponse(BaseModel):
    """Health check response (e.g. {"status": "ok"})."""
    status: str

View file

@ -0,0 +1,39 @@
"""Injectable clock abstraction for testability."""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from typing import Protocol
class Clock(Protocol):
    """Clock protocol — provides current UTC time.

    Structural type: anything with a matching now() satisfies it.
    """
    def now(self) -> datetime:
        """Return the current time as a timezone-aware datetime."""
        ...
class SystemClock:
    """Real wall-clock implementation."""
    def now(self) -> datetime:
        """Return the current UTC time (timezone-aware)."""
        return datetime.now(UTC)
class FakeClock:
"""Deterministic clock for tests."""
def __init__(self, start: datetime | None = None) -> None:
self._now = start or datetime(2026, 1, 1, tzinfo=UTC)
def now(self) -> datetime:
"""Return the fixed current time."""
return self._now
def advance(self, seconds: float) -> None:
"""Advance the clock by the given number of seconds."""
self._now += timedelta(seconds=seconds)
def set(self, dt: datetime) -> None:
"""Set the clock to an exact time."""
self._now = dt

View file

@ -0,0 +1,54 @@
"""HAProxy runtime socket adapter — stub for Plan 02."""
from __future__ import annotations
from dataclasses import dataclass
class HAProxyError(Exception):
    """Error communicating with HAProxy runtime socket.

    Intended to be raised by HAProxyRuntime once implemented (Plan 02).
    """
@dataclass
class SlotHealth:
    """Health status for a single HAProxy server slot."""
    # HAProxy check status string (`status` stats field, e.g. "UP"/"DOWN").
    status: str
    # Current active sessions (HAProxy `scur` stat).
    scur: int
    # Current queued requests (HAProxy `qcur` stat).
    qcur: int
class HAProxyRuntime:
    """HAProxy runtime CLI adapter via Unix socket.

    Manipulates server slots (named with `slot_prefix`) inside `backend`
    through the runtime API socket. All methods are stubs —
    full implementation in Plan 02.
    """
    def __init__(self, socket_path: str, backend: str, slot_prefix: str) -> None:
        # Connection/target parameters only; no socket is opened here.
        self._socket_path = socket_path
        self._backend = backend
        self._slot_prefix = slot_prefix
    def set_slot_addr(self, slot_id: str, ip: str, port: int = 22) -> None:
        """Update server address for a slot (port 22: builders are reached over SSH — verify in Plan 02)."""
        raise NotImplementedError
    def enable_slot(self, slot_id: str) -> None:
        """Enable a server slot."""
        raise NotImplementedError
    def disable_slot(self, slot_id: str) -> None:
        """Disable a server slot."""
        raise NotImplementedError
    def slot_is_up(self, slot_id: str) -> bool:
        """Return True when HAProxy health status is UP for slot."""
        raise NotImplementedError
    def slot_session_count(self, slot_id: str) -> int:
        """Return current active session count for slot."""
        raise NotImplementedError
    def read_slot_health(self) -> dict[str, SlotHealth]:
        """Return full stats snapshot for all slots, keyed by slot id."""
        raise NotImplementedError

View file

@ -0,0 +1 @@
"""Reconciler — stub for Plan 03."""

View file

@ -0,0 +1,43 @@
"""Abstract base class for runtime adapters."""
from __future__ import annotations
from abc import ABC, abstractmethod
class RuntimeError(Exception):
    """Base error for runtime adapter failures.

    NOTE(review): this shadows the builtin RuntimeError within this module;
    importers already alias it (e.g. `as RuntimeAdapterError`). Renaming would
    change the public interface, so it is only flagged here.

    Attributes:
        category: Normalized error category for retry/classification logic.
    """
    def __init__(self, message: str, category: str = "unknown") -> None:
        super().__init__(message)
        # e.g. "capacity_unavailable" (see FakeRuntime); defaults to "unknown".
        self.category = category
class RuntimeAdapter(ABC):
    """Interface for compute runtime backends (EC2, fake, etc.).

    Implementations raise the module's RuntimeError (aliased by importers as
    RuntimeAdapterError) for backend failures.
    """
    @abstractmethod
    def launch_spot(self, slot_id: str, user_data: str) -> str:
        """Launch a spot instance for slot_id. Return instance_id."""
    @abstractmethod
    def describe_instance(self, instance_id: str) -> dict:
        """Return normalized instance info dict.

        Keys: state, tailscale_ip (or None), launch_time.
        """
    @abstractmethod
    def terminate_instance(self, instance_id: str) -> None:
        """Terminate the instance."""
    @abstractmethod
    def list_managed_instances(self) -> list[dict]:
        """Return list of instances tagged ManagedBy=nix-builder-autoscaler.

        Each entry has instance_id, state, slot_id (from AutoscalerSlot tag).
        """

View file

@ -0,0 +1,32 @@
"""EC2 runtime adapter — stub for Plan 02."""
from __future__ import annotations
from .base import RuntimeAdapter
class EC2Runtime(RuntimeAdapter):
    """EC2 Spot instance runtime adapter.

    All methods are stubs — full implementation in Plan 02.
    """
    def __init__(self, region: str, launch_template_id: str, **kwargs: object) -> None:
        self._region = region
        self._launch_template_id = launch_template_id
        # NOTE(review): extra kwargs (subnets, security groups, ...) are
        # currently accepted but silently ignored by this stub.
    def launch_spot(self, slot_id: str, user_data: str) -> str:
        """Launch a spot instance for slot_id."""
        raise NotImplementedError
    def describe_instance(self, instance_id: str) -> dict:
        """Return normalized instance info."""
        raise NotImplementedError
    def terminate_instance(self, instance_id: str) -> None:
        """Terminate the instance."""
        raise NotImplementedError
    def list_managed_instances(self) -> list[dict]:
        """Return list of managed instances."""
        raise NotImplementedError

View file

@ -0,0 +1,122 @@
"""Fake runtime adapter for testing."""
from __future__ import annotations
import uuid
from dataclasses import dataclass
from .base import RuntimeAdapter
from .base import RuntimeError as RuntimeAdapterError
@dataclass
class _FakeInstance:
    # Synthetic id: "i-fake-" + 12 hex chars.
    instance_id: str
    slot_id: str
    # EC2-like lifecycle: "pending" -> "running" -> "terminated".
    state: str = "pending"
    # Assigned once the tick counter reaches ticks_to_ip; None before that.
    tailscale_ip: str | None = None
    launch_time: str = ""
    # Tick-counter thresholds compared against FakeRuntime._tick_count.
    ticks_to_running: int = 0
    ticks_to_ip: int = 0
    # True when terminated via an injected spot interruption.
    interrupted: bool = False
class FakeRuntime(RuntimeAdapter):
    """In-memory runtime adapter for deterministic testing.

    Args:
        launch_latency_ticks: Number of tick() calls after launch before an
            instance becomes running.
        ip_delay_ticks: Additional ticks after running before tailscale_ip appears.
    """

    def __init__(self, launch_latency_ticks: int = 2, ip_delay_ticks: int = 1) -> None:
        self._launch_latency = launch_latency_ticks
        self._ip_delay = ip_delay_ticks
        self._instances: dict[str, _FakeInstance] = {}
        # Slot ids whose next launch_spot() call should fail (one-shot).
        self._launch_failures: set[str] = set()
        # Instance ids whose next describe_instance() reports an interruption.
        self._interruptions: set[str] = set()
        self._tick_count: int = 0
        self._next_ip_counter: int = 1

    def launch_spot(self, slot_id: str, user_data: str) -> str:
        """Launch a fake spot instance for slot_id and return a synthetic id.

        Raises:
            RuntimeAdapterError: when a failure was injected for this slot
                (category "capacity_unavailable"); the injection is consumed.
        """
        if slot_id in self._launch_failures:
            self._launch_failures.discard(slot_id)
            raise RuntimeAdapterError(
                f"Simulated launch failure for {slot_id}",
                category="capacity_unavailable",
            )
        iid = f"i-fake-{uuid.uuid4().hex[:12]}"
        # Encode the launch tick into a synthetic-but-valid timestamp; divmod
        # keeps minutes/seconds in range even past 60 ticks (a bare
        # f"...:{tick:02d}Z" would emit invalid times like 00:00:75).
        minutes, seconds = divmod(self._tick_count, 60)
        hours, minutes = divmod(minutes, 60)
        self._instances[iid] = _FakeInstance(
            instance_id=iid,
            slot_id=slot_id,
            state="pending",
            launch_time=f"2026-01-01T{hours:02d}:{minutes:02d}:{seconds:02d}Z",
            # Deadlines are relative to the launch tick so an instance launched
            # after ticks have elapsed still experiences the configured latency.
            ticks_to_running=self._tick_count + self._launch_latency,
            ticks_to_ip=self._tick_count + self._launch_latency + self._ip_delay,
        )
        return iid

    def describe_instance(self, instance_id: str) -> dict:
        """Return normalized instance info (state, tailscale_ip, launch_time)."""
        inst = self._instances.get(instance_id)
        if inst is None:
            # Unknown ids behave like long-gone instances.
            return {"state": "terminated", "tailscale_ip": None, "launch_time": None}
        if instance_id in self._interruptions:
            # Consume the injected spot interruption exactly once.
            self._interruptions.discard(instance_id)
            inst.state = "terminated"
            inst.interrupted = True
        return {
            "state": inst.state,
            "tailscale_ip": inst.tailscale_ip,
            "launch_time": inst.launch_time,
        }

    def terminate_instance(self, instance_id: str) -> None:
        """Terminate a fake instance (no-op for unknown ids)."""
        inst = self._instances.get(instance_id)
        if inst is not None:
            inst.state = "terminated"

    def list_managed_instances(self) -> list[dict]:
        """List all non-terminated fake instances."""
        return [
            {
                "instance_id": inst.instance_id,
                "state": inst.state,
                "slot_id": inst.slot_id,
            }
            for inst in self._instances.values()
            if inst.state != "terminated"
        ]

    # -- Test helpers -------------------------------------------------------
    def tick(self) -> None:
        """Advance internal tick counter and progress instance states."""
        self._tick_count += 1
        for inst in self._instances.values():
            if inst.state == "terminated":
                continue
            if inst.state == "pending" and self._tick_count >= inst.ticks_to_running:
                inst.state = "running"
            if (
                inst.state == "running"
                and inst.tailscale_ip is None
                and self._tick_count >= inst.ticks_to_ip
            ):
                # Tailscale-style CGNAT address; counter supports < 255 instances.
                inst.tailscale_ip = f"100.64.0.{self._next_ip_counter}"
                self._next_ip_counter += 1

    def inject_launch_failure(self, slot_id: str) -> None:
        """Make the next launch_spot call for this slot_id raise an error."""
        self._launch_failures.add(slot_id)

    def inject_interruption(self, instance_id: str) -> None:
        """Make the next describe_instance call for this instance return terminated."""
        self._interruptions.add(instance_id)

View file

@ -0,0 +1 @@
"""Scheduler — stub for Plan 03."""

View file

@ -0,0 +1,400 @@
"""SQLite state persistence layer.
All write operations use BEGIN IMMEDIATE transactions for crash safety.
"""
from __future__ import annotations
import json
import sqlite3
import uuid
from datetime import UTC, datetime, timedelta
from pathlib import Path
from typing import TYPE_CHECKING
from .models import ReservationPhase, SlotState
if TYPE_CHECKING:
from .providers.clock import Clock
# Schema notes: all timestamps are stored as ISO-8601 TEXT; lease_count tracks
# active reservations per slot; events is an append-only audit log with JSON
# payloads. Tables are created idempotently by StateDB.init_schema().
_SCHEMA = """
CREATE TABLE IF NOT EXISTS slots (
slot_id TEXT PRIMARY KEY,
system TEXT NOT NULL,
state TEXT NOT NULL,
instance_id TEXT,
instance_ip TEXT,
instance_launch_time TEXT,
bound_backend TEXT NOT NULL,
lease_count INTEGER NOT NULL DEFAULT 0,
last_state_change TEXT NOT NULL,
cooldown_until TEXT,
interruption_pending INTEGER NOT NULL DEFAULT 0
);
CREATE TABLE IF NOT EXISTS reservations (
reservation_id TEXT PRIMARY KEY,
system TEXT NOT NULL,
phase TEXT NOT NULL,
slot_id TEXT,
instance_id TEXT,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
expires_at TEXT NOT NULL,
released_at TEXT,
reason TEXT,
build_id INTEGER
);
CREATE TABLE IF NOT EXISTS events (
event_id INTEGER PRIMARY KEY AUTOINCREMENT,
ts TEXT NOT NULL,
kind TEXT NOT NULL,
payload_json TEXT NOT NULL
);
"""
def _now_iso(clock: Clock | None = None) -> str:
if clock is not None:
return clock.now().isoformat()
return datetime.now(UTC).isoformat()
def _row_to_dict(cursor: sqlite3.Cursor, row: tuple) -> dict: # type: ignore[type-arg]
"""Convert a sqlite3 row to a dict using column names."""
cols = [d[0] for d in cursor.description]
return dict(zip(cols, row, strict=False))
class StateDB:
"""SQLite-backed state store for slots, reservations, and events."""
def __init__(self, db_path: str | Path = ":memory:", clock: Clock | None = None) -> None:
self._conn = sqlite3.connect(str(db_path), check_same_thread=False)
self._conn.execute("PRAGMA journal_mode=WAL")
self._conn.execute("PRAGMA busy_timeout=5000")
self._clock = clock
def init_schema(self) -> None:
"""Create tables if they don't exist."""
self._conn.executescript(_SCHEMA)
def init_slots(self, slot_prefix: str, slot_count: int, system: str, backend: str) -> None:
"""Ensure all expected slots exist, creating missing ones as empty."""
now = _now_iso(self._clock)
for i in range(1, slot_count + 1):
slot_id = f"{slot_prefix}{i:03d}"
bound = f"{backend}/{slot_id}"
self._conn.execute(
"""INSERT OR IGNORE INTO slots
(slot_id, system, state, bound_backend, lease_count, last_state_change)
VALUES (?, ?, ?, ?, 0, ?)""",
(slot_id, system, SlotState.EMPTY.value, bound, now),
)
self._conn.commit()
# -- Slot operations ----------------------------------------------------
def get_slot(self, slot_id: str) -> dict | None:
"""Return a slot row as dict, or None."""
cur = self._conn.execute("SELECT * FROM slots WHERE slot_id = ?", (slot_id,))
row = cur.fetchone()
if row is None:
return None
return _row_to_dict(cur, row)
def list_slots(self, state: SlotState | None = None) -> list[dict]:
"""List slots, optionally filtered by state."""
if state is not None:
cur = self._conn.execute(
"SELECT * FROM slots WHERE state = ? ORDER BY slot_id", (state.value,)
)
else:
cur = self._conn.execute("SELECT * FROM slots ORDER BY slot_id")
return [_row_to_dict(cur, row) for row in cur.fetchall()]
def update_slot_state(self, slot_id: str, new_state: SlotState, **fields: object) -> None:
"""Atomically transition a slot to a new state and record an event.
Additional fields (instance_id, instance_ip, etc.) can be passed as kwargs.
"""
now = _now_iso(self._clock)
set_parts = ["state = ?", "last_state_change = ?"]
params: list[object] = [new_state.value, now]
allowed = {
"instance_id",
"instance_ip",
"instance_launch_time",
"lease_count",
"cooldown_until",
"interruption_pending",
}
for k, v in fields.items():
if k not in allowed:
msg = f"Unknown slot field: {k}"
raise ValueError(msg)
set_parts.append(f"{k} = ?")
params.append(v)
params.append(slot_id)
sql = f"UPDATE slots SET {', '.join(set_parts)} WHERE slot_id = ?"
self._conn.execute("BEGIN IMMEDIATE")
try:
self._conn.execute(sql, params)
self._record_event_inner(
"slot_state_change",
{"slot_id": slot_id, "new_state": new_state.value, **fields},
)
self._conn.execute("COMMIT")
except Exception:
self._conn.execute("ROLLBACK")
raise
# -- Reservation operations ---------------------------------------------
def create_reservation(
self,
system: str,
reason: str,
build_id: int | None,
ttl_seconds: int,
) -> dict:
"""Create a new pending reservation. Returns the reservation row as dict."""
now = _now_iso(self._clock)
if self._clock is not None:
expires = (self._clock.now() + timedelta(seconds=ttl_seconds)).isoformat()
else:
expires = (datetime.now(UTC) + timedelta(seconds=ttl_seconds)).isoformat()
rid = f"resv_{uuid.uuid4().hex}"
self._conn.execute("BEGIN IMMEDIATE")
try:
self._conn.execute(
"""INSERT INTO reservations
(reservation_id, system, phase, created_at, updated_at,
expires_at, reason, build_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
(rid, system, ReservationPhase.PENDING.value, now, now, expires, reason, build_id),
)
self._record_event_inner(
"reservation_created",
{"reservation_id": rid, "system": system, "reason": reason},
)
self._conn.execute("COMMIT")
except Exception:
self._conn.execute("ROLLBACK")
raise
return self.get_reservation(rid) # type: ignore[return-value]
def get_reservation(self, reservation_id: str) -> dict | None:
"""Return a reservation row as dict, or None."""
cur = self._conn.execute(
"SELECT * FROM reservations WHERE reservation_id = ?", (reservation_id,)
)
row = cur.fetchone()
if row is None:
return None
return _row_to_dict(cur, row)
def list_reservations(self, phase: ReservationPhase | None = None) -> list[dict]:
"""List reservations, optionally filtered by phase."""
if phase is not None:
cur = self._conn.execute(
"SELECT * FROM reservations WHERE phase = ? ORDER BY created_at",
(phase.value,),
)
else:
cur = self._conn.execute("SELECT * FROM reservations ORDER BY created_at")
return [_row_to_dict(cur, row) for row in cur.fetchall()]
def assign_reservation(self, reservation_id: str, slot_id: str, instance_id: str) -> None:
"""Assign a pending reservation to a ready slot.
Atomically: update reservation phase to ready, set slot_id/instance_id,
and increment slot lease_count.
"""
now = _now_iso(self._clock)
self._conn.execute("BEGIN IMMEDIATE")
try:
self._conn.execute(
"""UPDATE reservations
SET phase = ?, slot_id = ?, instance_id = ?, updated_at = ?
WHERE reservation_id = ? AND phase = ?""",
(
ReservationPhase.READY.value,
slot_id,
instance_id,
now,
reservation_id,
ReservationPhase.PENDING.value,
),
)
self._conn.execute(
"UPDATE slots SET lease_count = lease_count + 1 WHERE slot_id = ?",
(slot_id,),
)
self._record_event_inner(
"reservation_assigned",
{
"reservation_id": reservation_id,
"slot_id": slot_id,
"instance_id": instance_id,
},
)
self._conn.execute("COMMIT")
except Exception:
self._conn.execute("ROLLBACK")
raise
def release_reservation(self, reservation_id: str) -> dict | None:
"""Release a reservation, decrementing the slot lease count."""
now = _now_iso(self._clock)
self._conn.execute("BEGIN IMMEDIATE")
try:
cur = self._conn.execute(
"SELECT * FROM reservations WHERE reservation_id = ?",
(reservation_id,),
)
row = cur.fetchone()
if row is None:
self._conn.execute("ROLLBACK")
return None
resv = _row_to_dict(cur, row)
old_phase = resv["phase"]
if old_phase in (ReservationPhase.RELEASED.value, ReservationPhase.EXPIRED.value):
self._conn.execute("ROLLBACK")
return resv
self._conn.execute(
"""UPDATE reservations
SET phase = ?, released_at = ?, updated_at = ?
WHERE reservation_id = ?""",
(ReservationPhase.RELEASED.value, now, now, reservation_id),
)
if resv["slot_id"] and old_phase == ReservationPhase.READY.value:
self._conn.execute(
"""UPDATE slots SET lease_count = MAX(lease_count - 1, 0)
WHERE slot_id = ?""",
(resv["slot_id"],),
)
self._record_event_inner("reservation_released", {"reservation_id": reservation_id})
self._conn.execute("COMMIT")
except Exception:
self._conn.execute("ROLLBACK")
raise
return self.get_reservation(reservation_id)
def expire_reservations(self, now: datetime) -> list[str]:
"""Expire all reservations past their expires_at. Returns expired IDs."""
now_iso = now.isoformat()
expired_ids: list[str] = []
self._conn.execute("BEGIN IMMEDIATE")
try:
cur = self._conn.execute(
"""SELECT reservation_id, slot_id, phase FROM reservations
WHERE phase IN (?, ?) AND expires_at <= ?""",
(ReservationPhase.PENDING.value, ReservationPhase.READY.value, now_iso),
)
rows = cur.fetchall()
for row in rows:
rid, slot_id, phase = row
expired_ids.append(rid)
self._conn.execute(
"""UPDATE reservations
SET phase = ?, updated_at = ?
WHERE reservation_id = ?""",
(ReservationPhase.EXPIRED.value, now_iso, rid),
)
if slot_id and phase == ReservationPhase.READY.value:
self._conn.execute(
"""UPDATE slots SET lease_count = MAX(lease_count - 1, 0)
WHERE slot_id = ?""",
(slot_id,),
)
self._record_event_inner("reservation_expired", {"reservation_id": rid})
self._conn.execute("COMMIT")
except Exception:
self._conn.execute("ROLLBACK")
raise
return expired_ids
# -- Events -------------------------------------------------------------
def record_event(self, kind: str, payload: dict) -> None: # type: ignore[type-arg]
"""Record an audit event."""
self._conn.execute("BEGIN IMMEDIATE")
try:
self._record_event_inner(kind, payload)
self._conn.execute("COMMIT")
except Exception:
self._conn.execute("ROLLBACK")
raise
def _record_event_inner(self, kind: str, payload: dict) -> None: # type: ignore[type-arg]
"""Insert an event row (must be called inside an active transaction)."""
now = _now_iso(self._clock)
self._conn.execute(
"INSERT INTO events (ts, kind, payload_json) VALUES (?, ?, ?)",
(now, kind, json.dumps(payload, default=str)),
)
# -- Summaries ----------------------------------------------------------
def get_state_summary(self) -> dict:
    """Return aggregate counts of slots by state and active reservations by phase."""
    cur = self._conn.execute("SELECT state, COUNT(*) FROM slots GROUP BY state")
    by_state: dict[str, int] = dict(cur.fetchall())
    cur = self._conn.execute(
        "SELECT phase, COUNT(*) FROM reservations WHERE phase IN (?, ?, ?) GROUP BY phase",
        (
            ReservationPhase.PENDING.value,
            ReservationPhase.READY.value,
            ReservationPhase.FAILED.value,
        ),
    )
    by_phase: dict[str, int] = dict(cur.fetchall())
    # Every known state/phase is listed explicitly so absent ones show as 0.
    slot_summary: dict[str, int] = {"total": sum(by_state.values())}
    for state in (
        "ready",
        "launching",
        "booting",
        "binding",
        "draining",
        "terminating",
        "empty",
        "error",
    ):
        slot_summary[state] = by_state.get(state, 0)
    return {
        "slots": slot_summary,
        "reservations": {
            "pending": by_phase.get("pending", 0),
            "ready": by_phase.get("ready", 0),
            "failed": by_phase.get("failed", 0),
        },
    }
def close(self) -> None:
    """Release the underlying SQLite connection; the store is unusable after this."""
    self._conn.close()

View file

@ -0,0 +1 @@
"""End-to-end integration tests with FakeRuntime — Plan 05."""

View file

@ -0,0 +1 @@
"""HAProxy provider unit tests — Plan 02."""

View file

@ -0,0 +1 @@
"""Reservations API unit tests — Plan 04."""

View file

@ -0,0 +1 @@
"""EC2 runtime unit tests — Plan 02."""

View file

@ -0,0 +1,116 @@
"""Unit tests for the FakeRuntime adapter."""
import contextlib
from nix_builder_autoscaler.runtime.base import RuntimeError as RuntimeAdapterError
from nix_builder_autoscaler.runtime.fake import FakeRuntime
class TestLaunchSpot:
    """launch_spot hands back a synthetic instance that starts out pending."""

    def test_returns_synthetic_instance_id(self):
        runtime = FakeRuntime()
        instance_id = runtime.launch_spot("slot001", "#!/bin/bash\necho hello")
        assert instance_id.startswith("i-fake-")
        assert len(instance_id) > 10

    def test_instance_starts_pending(self):
        runtime = FakeRuntime()
        instance_id = runtime.launch_spot("slot001", "")
        described = runtime.describe_instance(instance_id)
        assert described["state"] == "pending"
        assert described["tailscale_ip"] is None
class TestTickProgression:
    """tick() drives the simulated boot and Tailscale-IP assignment forward."""

    def test_transitions_to_running_after_configured_ticks(self):
        runtime = FakeRuntime(launch_latency_ticks=3, ip_delay_ticks=1)
        instance_id = runtime.launch_spot("slot001", "")
        runtime.tick()
        runtime.tick()
        # Two ticks in: still one short of the configured launch latency.
        assert runtime.describe_instance(instance_id)["state"] == "pending"
        runtime.tick()  # third tick crosses the latency threshold
        assert runtime.describe_instance(instance_id)["state"] == "running"

    def test_tailscale_ip_appears_after_configured_delay(self):
        runtime = FakeRuntime(launch_latency_ticks=2, ip_delay_ticks=2)
        instance_id = runtime.launch_spot("slot001", "")
        runtime.tick()
        runtime.tick()
        info = runtime.describe_instance(instance_id)
        assert info["state"] == "running"
        assert info["tailscale_ip"] is None
        runtime.tick()  # tick 3: running, but the IP needs one more tick
        assert runtime.describe_instance(instance_id)["tailscale_ip"] is None
        runtime.tick()  # tick 4: IP finally assigned
        info = runtime.describe_instance(instance_id)
        assert info["tailscale_ip"] is not None
        assert info["tailscale_ip"].startswith("100.64.0.")
class TestInjectedFailure:
    """inject_launch_failure makes exactly one subsequent launch fail."""

    def test_launch_failure_raises(self):
        runtime = FakeRuntime()
        runtime.inject_launch_failure("slot001")
        try:
            runtime.launch_spot("slot001", "")
        except RuntimeAdapterError as err:
            assert err.category == "capacity_unavailable"
        else:
            raise AssertionError("Should have raised")

    def test_failure_is_one_shot(self):
        runtime = FakeRuntime()
        runtime.inject_launch_failure("slot001")
        with contextlib.suppress(RuntimeAdapterError):
            runtime.launch_spot("slot001", "")
        # The injected failure was consumed above; the retry must succeed.
        assert runtime.launch_spot("slot001", "").startswith("i-fake-")
class TestInjectedInterruption:
    """inject_interruption simulates a spot-capacity reclaim."""

    def test_interruption_returns_terminated(self):
        runtime = FakeRuntime(launch_latency_ticks=1)
        instance_id = runtime.launch_spot("slot001", "")
        runtime.tick()
        assert runtime.describe_instance(instance_id)["state"] == "running"
        runtime.inject_interruption(instance_id)
        assert runtime.describe_instance(instance_id)["state"] == "terminated"

    def test_interruption_is_one_shot(self):
        """Once the interruption fires, later describes keep reporting terminated."""
        runtime = FakeRuntime(launch_latency_ticks=1)
        instance_id = runtime.launch_spot("slot001", "")
        runtime.tick()
        runtime.inject_interruption(instance_id)
        runtime.describe_instance(instance_id)  # first describe consumes the injection
        assert runtime.describe_instance(instance_id)["state"] == "terminated"
class TestTerminate:
    """terminate_instance moves a running instance to the terminated state."""

    def test_terminate_marks_instance(self):
        runtime = FakeRuntime(launch_latency_ticks=1)
        instance_id = runtime.launch_spot("slot001", "")
        runtime.tick()
        runtime.terminate_instance(instance_id)
        assert runtime.describe_instance(instance_id)["state"] == "terminated"
class TestListManaged:
    """list_managed_instances excludes terminated instances."""

    def test_lists_non_terminated(self):
        runtime = FakeRuntime(launch_latency_ticks=1)
        doomed = runtime.launch_spot("slot001", "")
        survivor = runtime.launch_spot("slot002", "")
        runtime.tick()
        runtime.terminate_instance(doomed)
        listed = {entry["instance_id"] for entry in runtime.list_managed_instances()}
        assert survivor in listed
        assert doomed not in listed

View file

@ -0,0 +1 @@
"""Scheduler unit tests — Plan 03."""