nix-builder-autoscaler/agent/nix_builder_autoscaler/state_db.py

469 lines
17 KiB
Python
Raw Permalink Normal View History

2026-02-27 11:59:16 +01:00
"""SQLite state persistence layer.
All write operations use BEGIN IMMEDIATE transactions for crash safety.
"""
from __future__ import annotations

import json
import sqlite3
import threading
import uuid
from contextlib import contextmanager
from datetime import UTC, datetime, timedelta
from pathlib import Path
from typing import TYPE_CHECKING

from .models import ReservationPhase, SlotState

if TYPE_CHECKING:
    from collections.abc import Iterator

    from .providers.clock import Clock
# DDL for the three tables this module persists: ``slots``, ``reservations``
# and ``events``.  Every statement uses CREATE TABLE IF NOT EXISTS, so running
# the script against an already-initialised database leaves existing tables
# untouched.  Timestamp columns are declared TEXT; the writers in this module
# fill them with ``datetime.isoformat()`` strings.
_SCHEMA = """
CREATE TABLE IF NOT EXISTS slots (
slot_id TEXT PRIMARY KEY,
system TEXT NOT NULL,
state TEXT NOT NULL,
instance_id TEXT,
instance_ip TEXT,
instance_launch_time TEXT,
bound_backend TEXT NOT NULL,
lease_count INTEGER NOT NULL DEFAULT 0,
last_state_change TEXT NOT NULL,
cooldown_until TEXT,
interruption_pending INTEGER NOT NULL DEFAULT 0
);
CREATE TABLE IF NOT EXISTS reservations (
reservation_id TEXT PRIMARY KEY,
system TEXT NOT NULL,
phase TEXT NOT NULL,
slot_id TEXT,
instance_id TEXT,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
expires_at TEXT NOT NULL,
released_at TEXT,
reason TEXT,
build_id INTEGER
);
CREATE TABLE IF NOT EXISTS events (
event_id INTEGER PRIMARY KEY AUTOINCREMENT,
ts TEXT NOT NULL,
kind TEXT NOT NULL,
payload_json TEXT NOT NULL
);
"""
def _now_iso(clock: Clock | None = None) -> str:
if clock is not None:
return clock.now().isoformat()
return datetime.now(UTC).isoformat()
def _row_to_dict(cursor: sqlite3.Cursor, row: tuple) -> dict: # type: ignore[type-arg]
"""Convert a sqlite3 row to a dict using column names."""
cols = [d[0] for d in cursor.description]
return dict(zip(cols, row, strict=False))
class StateDB:
"""SQLite-backed state store for slots, reservations, and events."""
def __init__(self, db_path: str | Path = ":memory:", clock: Clock | None = None) -> None:
self._conn = sqlite3.connect(str(db_path), check_same_thread=False)
self._conn.execute("PRAGMA journal_mode=WAL")
self._conn.execute("PRAGMA busy_timeout=5000")
self._clock = clock
2026-02-27 13:48:52 +01:00
self._lock = threading.RLock()
2026-02-27 11:59:16 +01:00
def init_schema(self) -> None:
"""Create tables if they don't exist."""
2026-02-27 13:48:52 +01:00
with self._lock:
self._conn.executescript(_SCHEMA)
2026-02-27 11:59:16 +01:00
def init_slots(self, slot_prefix: str, slot_count: int, system: str, backend: str) -> None:
"""Ensure all expected slots exist, creating missing ones as empty."""
2026-02-27 13:48:52 +01:00
with self._lock:
now = _now_iso(self._clock)
for i in range(1, slot_count + 1):
slot_id = f"{slot_prefix}{i:03d}"
bound = f"{backend}/{slot_id}"
self._conn.execute(
"""INSERT OR IGNORE INTO slots
(slot_id, system, state, bound_backend, lease_count, last_state_change)
VALUES (?, ?, ?, ?, 0, ?)""",
(slot_id, system, SlotState.EMPTY.value, bound, now),
)
self._conn.commit()
2026-02-27 11:59:16 +01:00
# -- Slot operations ----------------------------------------------------
def get_slot(self, slot_id: str) -> dict | None:
"""Return a slot row as dict, or None."""
2026-02-27 13:48:52 +01:00
with self._lock:
cur = self._conn.execute("SELECT * FROM slots WHERE slot_id = ?", (slot_id,))
row = cur.fetchone()
if row is None:
return None
return _row_to_dict(cur, row)
2026-02-27 11:59:16 +01:00
def list_slots(self, state: SlotState | None = None) -> list[dict]:
"""List slots, optionally filtered by state."""
2026-02-27 13:48:52 +01:00
with self._lock:
if state is not None:
cur = self._conn.execute(
"SELECT * FROM slots WHERE state = ? ORDER BY slot_id", (state.value,)
)
else:
cur = self._conn.execute("SELECT * FROM slots ORDER BY slot_id")
return [_row_to_dict(cur, row) for row in cur.fetchall()]
2026-02-27 11:59:16 +01:00
def update_slot_state(self, slot_id: str, new_state: SlotState, **fields: object) -> None:
"""Atomically transition a slot to a new state and record an event.
Additional fields (instance_id, instance_ip, etc.) can be passed as kwargs.
"""
2026-02-27 13:48:52 +01:00
with self._lock:
now = _now_iso(self._clock)
set_parts = ["state = ?", "last_state_change = ?"]
params: list[object] = [new_state.value, now]
allowed = {
"instance_id",
"instance_ip",
"instance_launch_time",
"lease_count",
"cooldown_until",
"interruption_pending",
}
for k, v in fields.items():
if k not in allowed:
msg = f"Unknown slot field: {k}"
raise ValueError(msg)
set_parts.append(f"{k} = ?")
params.append(v)
params.append(slot_id)
sql = f"UPDATE slots SET {', '.join(set_parts)} WHERE slot_id = ?"
self._conn.execute("BEGIN IMMEDIATE")
try:
self._conn.execute(sql, params)
self._record_event_inner(
"slot_state_change",
{"slot_id": slot_id, "new_state": new_state.value, **fields},
)
self._conn.execute("COMMIT")
except Exception:
self._conn.execute("ROLLBACK")
raise
2026-02-27 11:59:16 +01:00
def update_slot_fields(self, slot_id: str, **fields: object) -> None:
"""Update specific slot columns without changing state or last_state_change.
Uses BEGIN IMMEDIATE. Allowed fields: instance_id, instance_ip,
instance_launch_time, lease_count, cooldown_until, interruption_pending.
"""
2026-02-27 13:48:52 +01:00
with self._lock:
allowed = {
"instance_id",
"instance_ip",
"instance_launch_time",
"lease_count",
"cooldown_until",
"interruption_pending",
}
if not fields:
return
set_parts: list[str] = []
params: list[object] = []
for k, v in fields.items():
if k not in allowed:
msg = f"Unknown slot field: {k}"
raise ValueError(msg)
set_parts.append(f"{k} = ?")
params.append(v)
params.append(slot_id)
sql = f"UPDATE slots SET {', '.join(set_parts)} WHERE slot_id = ?"
self._conn.execute("BEGIN IMMEDIATE")
try:
self._conn.execute(sql, params)
self._record_event_inner(
"slot_fields_updated",
{"slot_id": slot_id, **fields},
)
self._conn.execute("COMMIT")
except Exception:
self._conn.execute("ROLLBACK")
raise
2026-02-27 11:59:16 +01:00
# -- Reservation operations ---------------------------------------------
def create_reservation(
self,
system: str,
reason: str,
build_id: int | None,
ttl_seconds: int,
) -> dict:
"""Create a new pending reservation. Returns the reservation row as dict."""
2026-02-27 13:48:52 +01:00
with self._lock:
now = _now_iso(self._clock)
if self._clock is not None:
expires = (self._clock.now() + timedelta(seconds=ttl_seconds)).isoformat()
else:
expires = (datetime.now(UTC) + timedelta(seconds=ttl_seconds)).isoformat()
rid = f"resv_{uuid.uuid4().hex}"
self._conn.execute("BEGIN IMMEDIATE")
try:
self._conn.execute(
"""INSERT INTO reservations
(reservation_id, system, phase, created_at, updated_at,
expires_at, reason, build_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
(
rid,
system,
ReservationPhase.PENDING.value,
now,
now,
expires,
reason,
build_id,
),
)
self._record_event_inner(
"reservation_created",
{"reservation_id": rid, "system": system, "reason": reason},
)
self._conn.execute("COMMIT")
except Exception:
self._conn.execute("ROLLBACK")
raise
2026-02-27 11:59:16 +01:00
2026-02-27 13:48:52 +01:00
return self.get_reservation(rid) # type: ignore[return-value]
2026-02-27 11:59:16 +01:00
def get_reservation(self, reservation_id: str) -> dict | None:
"""Return a reservation row as dict, or None."""
2026-02-27 13:48:52 +01:00
with self._lock:
cur = self._conn.execute(
"SELECT * FROM reservations WHERE reservation_id = ?", (reservation_id,)
)
row = cur.fetchone()
if row is None:
return None
return _row_to_dict(cur, row)
2026-02-27 11:59:16 +01:00
def list_reservations(self, phase: ReservationPhase | None = None) -> list[dict]:
"""List reservations, optionally filtered by phase."""
2026-02-27 13:48:52 +01:00
with self._lock:
if phase is not None:
cur = self._conn.execute(
"SELECT * FROM reservations WHERE phase = ? ORDER BY created_at",
(phase.value,),
)
else:
cur = self._conn.execute("SELECT * FROM reservations ORDER BY created_at")
return [_row_to_dict(cur, row) for row in cur.fetchall()]
2026-02-27 11:59:16 +01:00
def assign_reservation(self, reservation_id: str, slot_id: str, instance_id: str) -> None:
"""Assign a pending reservation to a ready slot.
Atomically: update reservation phase to ready, set slot_id/instance_id,
and increment slot lease_count.
"""
2026-02-27 13:48:52 +01:00
with self._lock:
now = _now_iso(self._clock)
2026-02-27 11:59:16 +01:00
2026-02-27 13:48:52 +01:00
self._conn.execute("BEGIN IMMEDIATE")
try:
self._conn.execute(
"""UPDATE reservations
SET phase = ?, slot_id = ?, instance_id = ?, updated_at = ?
WHERE reservation_id = ? AND phase = ?""",
(
ReservationPhase.READY.value,
slot_id,
instance_id,
now,
reservation_id,
ReservationPhase.PENDING.value,
),
)
self._conn.execute(
"UPDATE slots SET lease_count = lease_count + 1 WHERE slot_id = ?",
(slot_id,),
)
self._record_event_inner(
"reservation_assigned",
{
"reservation_id": reservation_id,
"slot_id": slot_id,
"instance_id": instance_id,
},
)
self._conn.execute("COMMIT")
except Exception:
self._conn.execute("ROLLBACK")
raise
2026-02-27 11:59:16 +01:00
def release_reservation(self, reservation_id: str) -> dict | None:
"""Release a reservation, decrementing the slot lease count."""
2026-02-27 13:48:52 +01:00
with self._lock:
now = _now_iso(self._clock)
self._conn.execute("BEGIN IMMEDIATE")
try:
cur = self._conn.execute(
"SELECT * FROM reservations WHERE reservation_id = ?",
(reservation_id,),
2026-02-27 11:59:16 +01:00
)
2026-02-27 13:48:52 +01:00
row = cur.fetchone()
if row is None:
self._conn.execute("ROLLBACK")
return None
2026-02-27 11:59:16 +01:00
2026-02-27 13:48:52 +01:00
resv = _row_to_dict(cur, row)
old_phase = resv["phase"]
2026-02-27 11:59:16 +01:00
2026-02-27 13:48:52 +01:00
if old_phase in (ReservationPhase.RELEASED.value, ReservationPhase.EXPIRED.value):
self._conn.execute("ROLLBACK")
return resv
2026-02-27 11:59:16 +01:00
self._conn.execute(
"""UPDATE reservations
2026-02-27 13:48:52 +01:00
SET phase = ?, released_at = ?, updated_at = ?
2026-02-27 11:59:16 +01:00
WHERE reservation_id = ?""",
2026-02-27 13:48:52 +01:00
(ReservationPhase.RELEASED.value, now, now, reservation_id),
2026-02-27 11:59:16 +01:00
)
2026-02-27 13:48:52 +01:00
if resv["slot_id"] and old_phase == ReservationPhase.READY.value:
2026-02-27 11:59:16 +01:00
self._conn.execute(
"""UPDATE slots SET lease_count = MAX(lease_count - 1, 0)
WHERE slot_id = ?""",
2026-02-27 13:48:52 +01:00
(resv["slot_id"],),
2026-02-27 11:59:16 +01:00
)
2026-02-27 13:48:52 +01:00
self._record_event_inner("reservation_released", {"reservation_id": reservation_id})
self._conn.execute("COMMIT")
except Exception:
self._conn.execute("ROLLBACK")
raise
return self.get_reservation(reservation_id)
def expire_reservations(self, now: datetime) -> list[str]:
"""Expire all reservations past their expires_at. Returns expired IDs."""
with self._lock:
now_iso = now.isoformat()
expired_ids: list[str] = []
self._conn.execute("BEGIN IMMEDIATE")
try:
cur = self._conn.execute(
"""SELECT reservation_id, slot_id, phase FROM reservations
WHERE phase IN (?, ?) AND expires_at <= ?""",
(ReservationPhase.PENDING.value, ReservationPhase.READY.value, now_iso),
)
rows = cur.fetchall()
2026-02-27 11:59:16 +01:00
2026-02-27 13:48:52 +01:00
for row in rows:
rid, slot_id, phase = row
expired_ids.append(rid)
self._conn.execute(
"""UPDATE reservations
SET phase = ?, updated_at = ?
WHERE reservation_id = ?""",
(ReservationPhase.EXPIRED.value, now_iso, rid),
)
if slot_id and phase == ReservationPhase.READY.value:
self._conn.execute(
"""UPDATE slots SET lease_count = MAX(lease_count - 1, 0)
WHERE slot_id = ?""",
(slot_id,),
)
self._record_event_inner("reservation_expired", {"reservation_id": rid})
self._conn.execute("COMMIT")
except Exception:
self._conn.execute("ROLLBACK")
raise
return expired_ids
2026-02-27 11:59:16 +01:00
# -- Events -------------------------------------------------------------
def record_event(self, kind: str, payload: dict) -> None: # type: ignore[type-arg]
"""Record an audit event."""
2026-02-27 13:48:52 +01:00
with self._lock:
self._conn.execute("BEGIN IMMEDIATE")
try:
self._record_event_inner(kind, payload)
self._conn.execute("COMMIT")
except Exception:
self._conn.execute("ROLLBACK")
raise
2026-02-27 11:59:16 +01:00
def _record_event_inner(self, kind: str, payload: dict) -> None: # type: ignore[type-arg]
"""Insert an event row (must be called inside an active transaction)."""
2026-02-27 13:48:52 +01:00
with self._lock:
now = _now_iso(self._clock)
self._conn.execute(
"INSERT INTO events (ts, kind, payload_json) VALUES (?, ?, ?)",
(now, kind, json.dumps(payload, default=str)),
)
2026-02-27 11:59:16 +01:00
# -- Summaries ----------------------------------------------------------
def get_state_summary(self) -> dict:
"""Return aggregate slot and reservation counts."""
2026-02-27 13:48:52 +01:00
with self._lock:
slot_counts: dict[str, int] = {}
cur = self._conn.execute("SELECT state, COUNT(*) FROM slots GROUP BY state")
for state_val, count in cur.fetchall():
slot_counts[state_val] = count
total_slots = sum(slot_counts.values())
resv_counts: dict[str, int] = {}
cur = self._conn.execute(
"SELECT phase, COUNT(*) FROM reservations WHERE phase IN (?, ?, ?) GROUP BY phase",
(
ReservationPhase.PENDING.value,
ReservationPhase.READY.value,
ReservationPhase.FAILED.value,
),
)
for phase_val, count in cur.fetchall():
resv_counts[phase_val] = count
return {
"slots": {
"total": total_slots,
"ready": slot_counts.get("ready", 0),
"launching": slot_counts.get("launching", 0),
"booting": slot_counts.get("booting", 0),
"binding": slot_counts.get("binding", 0),
"draining": slot_counts.get("draining", 0),
"terminating": slot_counts.get("terminating", 0),
"empty": slot_counts.get("empty", 0),
"error": slot_counts.get("error", 0),
},
"reservations": {
"pending": resv_counts.get("pending", 0),
"ready": resv_counts.get("ready", 0),
"failed": resv_counts.get("failed", 0),
},
}
2026-02-27 11:59:16 +01:00
def close(self) -> None:
"""Close the database connection."""
2026-02-27 13:48:52 +01:00
with self._lock:
self._conn.close()