"""SQLite state persistence layer. All write operations use BEGIN IMMEDIATE transactions for crash safety. """ from __future__ import annotations import json import sqlite3 import threading import uuid from datetime import UTC, datetime, timedelta from pathlib import Path from typing import TYPE_CHECKING from .models import ReservationPhase, SlotState if TYPE_CHECKING: from .providers.clock import Clock _SCHEMA = """ CREATE TABLE IF NOT EXISTS slots ( slot_id TEXT PRIMARY KEY, system TEXT NOT NULL, state TEXT NOT NULL, instance_id TEXT, instance_ip TEXT, instance_launch_time TEXT, bound_backend TEXT NOT NULL, lease_count INTEGER NOT NULL DEFAULT 0, last_state_change TEXT NOT NULL, cooldown_until TEXT, interruption_pending INTEGER NOT NULL DEFAULT 0 ); CREATE TABLE IF NOT EXISTS reservations ( reservation_id TEXT PRIMARY KEY, system TEXT NOT NULL, phase TEXT NOT NULL, slot_id TEXT, instance_id TEXT, created_at TEXT NOT NULL, updated_at TEXT NOT NULL, expires_at TEXT NOT NULL, released_at TEXT, reason TEXT, build_id INTEGER ); CREATE TABLE IF NOT EXISTS events ( event_id INTEGER PRIMARY KEY AUTOINCREMENT, ts TEXT NOT NULL, kind TEXT NOT NULL, payload_json TEXT NOT NULL ); """ def _now_iso(clock: Clock | None = None) -> str: if clock is not None: return clock.now().isoformat() return datetime.now(UTC).isoformat() def _row_to_dict(cursor: sqlite3.Cursor, row: tuple) -> dict: # type: ignore[type-arg] """Convert a sqlite3 row to a dict using column names.""" cols = [d[0] for d in cursor.description] return dict(zip(cols, row, strict=False)) class StateDB: """SQLite-backed state store for slots, reservations, and events.""" def __init__(self, db_path: str | Path = ":memory:", clock: Clock | None = None) -> None: self._conn = sqlite3.connect(str(db_path), check_same_thread=False) self._conn.execute("PRAGMA journal_mode=WAL") self._conn.execute("PRAGMA busy_timeout=5000") self._clock = clock self._lock = threading.RLock() def init_schema(self) -> None: """Create tables if they don't exist.""" with self._lock: self._conn.executescript(_SCHEMA) def init_slots(self, slot_prefix: str, slot_count: int, system: str, backend: str) -> None: """Ensure all expected slots exist, creating missing ones as empty.""" with self._lock: now = _now_iso(self._clock) for i in range(1, slot_count + 1): slot_id = f"{slot_prefix}{i:03d}" bound = f"{backend}/{slot_id}" self._conn.execute( """INSERT OR IGNORE INTO slots (slot_id, system, state, bound_backend, lease_count, last_state_change) VALUES (?, ?, ?, ?, 0, ?)""", (slot_id, system, SlotState.EMPTY.value, bound, now), ) self._conn.commit() # -- Slot operations ---------------------------------------------------- def get_slot(self, slot_id: str) -> dict | None: """Return a slot row as dict, or None.""" with self._lock: cur = self._conn.execute("SELECT * FROM slots WHERE slot_id = ?", (slot_id,)) row = cur.fetchone() if row is None: return None return _row_to_dict(cur, row) def list_slots(self, state: SlotState | None = None) -> list[dict]: """List slots, optionally filtered by state.""" with self._lock: if state is not None: cur = self._conn.execute( "SELECT * FROM slots WHERE state = ? ORDER BY slot_id", (state.value,) ) else: cur = self._conn.execute("SELECT * FROM slots ORDER BY slot_id") return [_row_to_dict(cur, row) for row in cur.fetchall()] def update_slot_state(self, slot_id: str, new_state: SlotState, **fields: object) -> None: """Atomically transition a slot to a new state and record an event. Additional fields (instance_id, instance_ip, etc.) can be passed as kwargs. """ with self._lock: now = _now_iso(self._clock) set_parts = ["state = ?", "last_state_change = ?"] params: list[object] = [new_state.value, now] allowed = { "instance_id", "instance_ip", "instance_launch_time", "lease_count", "cooldown_until", "interruption_pending", } for k, v in fields.items(): if k not in allowed: msg = f"Unknown slot field: {k}" raise ValueError(msg) set_parts.append(f"{k} = ?") params.append(v) params.append(slot_id) sql = f"UPDATE slots SET {', '.join(set_parts)} WHERE slot_id = ?" self._conn.execute("BEGIN IMMEDIATE") try: self._conn.execute(sql, params) self._record_event_inner( "slot_state_change", {"slot_id": slot_id, "new_state": new_state.value, **fields}, ) self._conn.execute("COMMIT") except Exception: self._conn.execute("ROLLBACK") raise def update_slot_fields(self, slot_id: str, **fields: object) -> None: """Update specific slot columns without changing state or last_state_change. Uses BEGIN IMMEDIATE. Allowed fields: instance_id, instance_ip, instance_launch_time, lease_count, cooldown_until, interruption_pending. """ with self._lock: allowed = { "instance_id", "instance_ip", "instance_launch_time", "lease_count", "cooldown_until", "interruption_pending", } if not fields: return set_parts: list[str] = [] params: list[object] = [] for k, v in fields.items(): if k not in allowed: msg = f"Unknown slot field: {k}" raise ValueError(msg) set_parts.append(f"{k} = ?") params.append(v) params.append(slot_id) sql = f"UPDATE slots SET {', '.join(set_parts)} WHERE slot_id = ?" self._conn.execute("BEGIN IMMEDIATE") try: self._conn.execute(sql, params) self._record_event_inner( "slot_fields_updated", {"slot_id": slot_id, **fields}, ) self._conn.execute("COMMIT") except Exception: self._conn.execute("ROLLBACK") raise # -- Reservation operations --------------------------------------------- def create_reservation( self, system: str, reason: str, build_id: int | None, ttl_seconds: int, ) -> dict: """Create a new pending reservation. Returns the reservation row as dict.""" with self._lock: now = _now_iso(self._clock) if self._clock is not None: expires = (self._clock.now() + timedelta(seconds=ttl_seconds)).isoformat() else: expires = (datetime.now(UTC) + timedelta(seconds=ttl_seconds)).isoformat() rid = f"resv_{uuid.uuid4().hex}" self._conn.execute("BEGIN IMMEDIATE") try: self._conn.execute( """INSERT INTO reservations (reservation_id, system, phase, created_at, updated_at, expires_at, reason, build_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", ( rid, system, ReservationPhase.PENDING.value, now, now, expires, reason, build_id, ), ) self._record_event_inner( "reservation_created", {"reservation_id": rid, "system": system, "reason": reason}, ) self._conn.execute("COMMIT") except Exception: self._conn.execute("ROLLBACK") raise return self.get_reservation(rid) # type: ignore[return-value] def get_reservation(self, reservation_id: str) -> dict | None: """Return a reservation row as dict, or None.""" with self._lock: cur = self._conn.execute( "SELECT * FROM reservations WHERE reservation_id = ?", (reservation_id,) ) row = cur.fetchone() if row is None: return None return _row_to_dict(cur, row) def list_reservations(self, phase: ReservationPhase | None = None) -> list[dict]: """List reservations, optionally filtered by phase.""" with self._lock: if phase is not None: cur = self._conn.execute( "SELECT * FROM reservations WHERE phase = ? ORDER BY created_at", (phase.value,), ) else: cur = self._conn.execute("SELECT * FROM reservations ORDER BY created_at") return [_row_to_dict(cur, row) for row in cur.fetchall()] def assign_reservation(self, reservation_id: str, slot_id: str, instance_id: str) -> None: """Assign a pending reservation to a ready slot. Atomically: update reservation phase to ready, set slot_id/instance_id, and increment slot lease_count. """ with self._lock: now = _now_iso(self._clock) self._conn.execute("BEGIN IMMEDIATE") try: self._conn.execute( """UPDATE reservations SET phase = ?, slot_id = ?, instance_id = ?, updated_at = ? WHERE reservation_id = ? AND phase = ?""", ( ReservationPhase.READY.value, slot_id, instance_id, now, reservation_id, ReservationPhase.PENDING.value, ), ) self._conn.execute( "UPDATE slots SET lease_count = lease_count + 1 WHERE slot_id = ?", (slot_id,), ) self._record_event_inner( "reservation_assigned", { "reservation_id": reservation_id, "slot_id": slot_id, "instance_id": instance_id, }, ) self._conn.execute("COMMIT") except Exception: self._conn.execute("ROLLBACK") raise def release_reservation(self, reservation_id: str) -> dict | None: """Release a reservation, decrementing the slot lease count.""" with self._lock: now = _now_iso(self._clock) self._conn.execute("BEGIN IMMEDIATE") try: cur = self._conn.execute( "SELECT * FROM reservations WHERE reservation_id = ?", (reservation_id,), ) row = cur.fetchone() if row is None: self._conn.execute("ROLLBACK") return None resv = _row_to_dict(cur, row) old_phase = resv["phase"] if old_phase in (ReservationPhase.RELEASED.value, ReservationPhase.EXPIRED.value): self._conn.execute("ROLLBACK") return resv self._conn.execute( """UPDATE reservations SET phase = ?, released_at = ?, updated_at = ? WHERE reservation_id = ?""", (ReservationPhase.RELEASED.value, now, now, reservation_id), ) if resv["slot_id"] and old_phase == ReservationPhase.READY.value: self._conn.execute( """UPDATE slots SET lease_count = MAX(lease_count - 1, 0) WHERE slot_id = ?""", (resv["slot_id"],), ) self._record_event_inner("reservation_released", {"reservation_id": reservation_id}) self._conn.execute("COMMIT") except Exception: self._conn.execute("ROLLBACK") raise return self.get_reservation(reservation_id) def expire_reservations(self, now: datetime) -> list[str]: """Expire all reservations past their expires_at. Returns expired IDs.""" with self._lock: now_iso = now.isoformat() expired_ids: list[str] = [] self._conn.execute("BEGIN IMMEDIATE") try: cur = self._conn.execute( """SELECT reservation_id, slot_id, phase FROM reservations WHERE phase IN (?, ?) AND expires_at <= ?""", (ReservationPhase.PENDING.value, ReservationPhase.READY.value, now_iso), ) rows = cur.fetchall() for row in rows: rid, slot_id, phase = row expired_ids.append(rid) self._conn.execute( """UPDATE reservations SET phase = ?, updated_at = ? WHERE reservation_id = ?""", (ReservationPhase.EXPIRED.value, now_iso, rid), ) if slot_id and phase == ReservationPhase.READY.value: self._conn.execute( """UPDATE slots SET lease_count = MAX(lease_count - 1, 0) WHERE slot_id = ?""", (slot_id,), ) self._record_event_inner("reservation_expired", {"reservation_id": rid}) self._conn.execute("COMMIT") except Exception: self._conn.execute("ROLLBACK") raise return expired_ids # -- Events ------------------------------------------------------------- def record_event(self, kind: str, payload: dict) -> None: # type: ignore[type-arg] """Record an audit event.""" with self._lock: self._conn.execute("BEGIN IMMEDIATE") try: self._record_event_inner(kind, payload) self._conn.execute("COMMIT") except Exception: self._conn.execute("ROLLBACK") raise def _record_event_inner(self, kind: str, payload: dict) -> None: # type: ignore[type-arg] """Insert an event row (must be called inside an active transaction).""" with self._lock: now = _now_iso(self._clock) self._conn.execute( "INSERT INTO events (ts, kind, payload_json) VALUES (?, ?, ?)", (now, kind, json.dumps(payload, default=str)), ) # -- Summaries ---------------------------------------------------------- def get_state_summary(self) -> dict: """Return aggregate slot and reservation counts.""" with self._lock: slot_counts: dict[str, int] = {} cur = self._conn.execute("SELECT state, COUNT(*) FROM slots GROUP BY state") for state_val, count in cur.fetchall(): slot_counts[state_val] = count total_slots = sum(slot_counts.values()) resv_counts: dict[str, int] = {} cur = self._conn.execute( "SELECT phase, COUNT(*) FROM reservations WHERE phase IN (?, ?, ?) GROUP BY phase", ( ReservationPhase.PENDING.value, ReservationPhase.READY.value, ReservationPhase.FAILED.value, ), ) for phase_val, count in cur.fetchall(): resv_counts[phase_val] = count return { "slots": { "total": total_slots, "ready": slot_counts.get("ready", 0), "launching": slot_counts.get("launching", 0), "booting": slot_counts.get("booting", 0), "binding": slot_counts.get("binding", 0), "draining": slot_counts.get("draining", 0), "terminating": slot_counts.get("terminating", 0), "empty": slot_counts.get("empty", 0), "error": slot_counts.get("error", 0), }, "reservations": { "pending": resv_counts.get("pending", 0), "ready": resv_counts.get("ready", 0), "failed": resv_counts.get("failed", 0), }, } def close(self) -> None: """Close the database connection.""" with self._lock: self._conn.close()