400 lines
14 KiB
Python
400 lines
14 KiB
Python
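"""SQLite state persistence layer.

Slot state transitions, reservation lifecycle changes, and audit events are
written inside ``BEGIN IMMEDIATE`` transactions, so each change and the event
that records it either commit together or roll back together.
"""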
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
import sqlite3
import uuid
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import UTC, datetime, timedelta
from pathlib import Path
from typing import TYPE_CHECKING
|
|
|
|
from .models import ReservationPhase, SlotState
|
|
|
|
if TYPE_CHECKING:
|
|
from .providers.clock import Clock
|
|
|
|
# DDL applied by StateDB.init_schema(). Every statement uses IF NOT EXISTS so
# the script can be re-run safely against an existing database.
#
# The reservations index serves expire_reservations() (filters on phase and
# expires_at) and list_reservations(phase=...) (filters on phase). Reservation
# rows are never deleted, so without it both queries scan the whole table.
_SCHEMA = """
CREATE TABLE IF NOT EXISTS slots (
    slot_id TEXT PRIMARY KEY,
    system TEXT NOT NULL,
    state TEXT NOT NULL,
    instance_id TEXT,
    instance_ip TEXT,
    instance_launch_time TEXT,
    bound_backend TEXT NOT NULL,
    lease_count INTEGER NOT NULL DEFAULT 0,
    last_state_change TEXT NOT NULL,
    cooldown_until TEXT,
    interruption_pending INTEGER NOT NULL DEFAULT 0
);

CREATE TABLE IF NOT EXISTS reservations (
    reservation_id TEXT PRIMARY KEY,
    system TEXT NOT NULL,
    phase TEXT NOT NULL,
    slot_id TEXT,
    instance_id TEXT,
    created_at TEXT NOT NULL,
    updated_at TEXT NOT NULL,
    expires_at TEXT NOT NULL,
    released_at TEXT,
    reason TEXT,
    build_id INTEGER
);

CREATE INDEX IF NOT EXISTS idx_reservations_phase_expires
    ON reservations (phase, expires_at);

CREATE TABLE IF NOT EXISTS events (
    event_id INTEGER PRIMARY KEY AUTOINCREMENT,
    ts TEXT NOT NULL,
    kind TEXT NOT NULL,
    payload_json TEXT NOT NULL
);
"""
|
|
|
|
|
|
def _now_iso(clock: Clock | None = None) -> str:
|
|
if clock is not None:
|
|
return clock.now().isoformat()
|
|
return datetime.now(UTC).isoformat()
|
|
|
|
|
|
def _row_to_dict(cursor: sqlite3.Cursor, row: tuple) -> dict: # type: ignore[type-arg]
|
|
"""Convert a sqlite3 row to a dict using column names."""
|
|
cols = [d[0] for d in cursor.description]
|
|
return dict(zip(cols, row, strict=False))
|
|
|
|
|
|
class StateDB:
    """SQLite-backed state store for slots, reservations, and events.

    All methods share the single ``sqlite3`` connection opened in ``__init__``.
    Multi-statement writes go through ``_write_txn`` so that a state change and
    the audit event describing it commit (or roll back) together.
    """

    # Columns ``update_slot_state`` may set in addition to ``state``. The keys
    # are interpolated into the UPDATE statement, so this allowlist is what
    # keeps that string-built SQL safe.
    _SLOT_UPDATABLE_FIELDS = frozenset(
        {
            "instance_id",
            "instance_ip",
            "instance_launch_time",
            "lease_count",
            "cooldown_until",
            "interruption_pending",
        }
    )

    def __init__(self, db_path: str | Path = ":memory:", clock: Clock | None = None) -> None:
        """Open the database connection.

        Args:
            db_path: Path of the SQLite database file, or ``":memory:"``.
            clock: Optional time source for all timestamps; when ``None`` the
                system clock in UTC is used.
        """
        # check_same_thread=False permits calls from threads other than the
        # creating one; note this class adds no locking of its own.
        self._conn = sqlite3.connect(str(db_path), check_same_thread=False)
        self._conn.execute("PRAGMA journal_mode=WAL")
        # Wait up to 5000 ms on a locked database before raising.
        self._conn.execute("PRAGMA busy_timeout=5000")
        self._clock = clock

    # -- Internal helpers ---------------------------------------------------

    @contextmanager
    def _write_txn(self) -> Iterator[None]:
        """Run the ``with`` body inside a ``BEGIN IMMEDIATE`` transaction.

        Commits when the body exits normally (including via ``return``). If the
        body or the COMMIT raises, rolls back and re-raises the exception.
        """
        self._conn.execute("BEGIN IMMEDIATE")
        try:
            yield
            self._conn.execute("COMMIT")
        except BaseException:
            # SQLite can abort the transaction itself on some errors; only issue
            # ROLLBACK if one is still open so the original error is not masked
            # by "cannot rollback - no transaction is active".
            if self._conn.in_transaction:
                self._conn.execute("ROLLBACK")
            raise

    def _query_one(self, sql: str, params: tuple[object, ...] = ()) -> dict | None:
        """Run *sql* and return the first row as a dict, or ``None`` if there is none."""
        cur = self._conn.execute(sql, params)
        row = cur.fetchone()
        return None if row is None else _row_to_dict(cur, row)

    def _query_all(self, sql: str, params: tuple[object, ...] = ()) -> list[dict]:
        """Run *sql* and return every result row as a dict."""
        cur = self._conn.execute(sql, params)
        return [_row_to_dict(cur, row) for row in cur.fetchall()]

    def _now_dt(self) -> datetime:
        """Return the current time from the injected clock, or UTC wall time."""
        return self._clock.now() if self._clock is not None else datetime.now(UTC)

    # -- Setup --------------------------------------------------------------

    def init_schema(self) -> None:
        """Create tables if they don't exist."""
        # executescript() commits any pending transaction before running.
        self._conn.executescript(_SCHEMA)

    def init_slots(self, slot_prefix: str, slot_count: int, system: str, backend: str) -> None:
        """Ensure all expected slots exist, creating missing ones as empty.

        Slot ids are ``f"{slot_prefix}{i:03d}"`` for ``i`` in ``1..slot_count``,
        and each new slot is bound to ``f"{backend}/{slot_id}"``. Rows that
        already exist are left untouched (``INSERT OR IGNORE``). All inserts
        happen in one transaction that is rolled back if any of them fails.
        """
        now = _now_iso(self._clock)
        slot_ids = [f"{slot_prefix}{i:03d}" for i in range(1, slot_count + 1)]
        with self._write_txn():
            self._conn.executemany(
                """INSERT OR IGNORE INTO slots
                   (slot_id, system, state, bound_backend, lease_count, last_state_change)
                   VALUES (?, ?, ?, ?, 0, ?)""",
                [
                    (slot_id, system, SlotState.EMPTY.value, f"{backend}/{slot_id}", now)
                    for slot_id in slot_ids
                ],
            )

    # -- Slot operations ----------------------------------------------------

    def get_slot(self, slot_id: str) -> dict | None:
        """Return a slot row as dict, or None."""
        return self._query_one("SELECT * FROM slots WHERE slot_id = ?", (slot_id,))

    def list_slots(self, state: SlotState | None = None) -> list[dict]:
        """List slots ordered by slot_id, optionally filtered by state."""
        if state is None:
            return self._query_all("SELECT * FROM slots ORDER BY slot_id")
        return self._query_all(
            "SELECT * FROM slots WHERE state = ? ORDER BY slot_id", (state.value,)
        )

    def update_slot_state(self, slot_id: str, new_state: SlotState, **fields: object) -> None:
        """Atomically transition a slot to a new state and record an event.

        ``last_state_change`` is set to the current time. Additional columns
        (instance_id, instance_ip, etc.) can be passed as keyword arguments;
        only names in ``_SLOT_UPDATABLE_FIELDS`` are accepted. A
        ``slot_state_change`` event carrying the slot id, new state and the
        extra fields is recorded in the same transaction.

        NOTE(review): an unknown ``slot_id`` updates no rows but the event is
        still recorded.

        Raises:
            ValueError: If a keyword argument is not an updatable slot column.
                This is checked before anything is written.
        """
        now = _now_iso(self._clock)
        unknown = next((k for k in fields if k not in self._SLOT_UPDATABLE_FIELDS), None)
        if unknown is not None:
            msg = f"Unknown slot field: {unknown}"
            raise ValueError(msg)

        assignments = ["state = ?", "last_state_change = ?", *(f"{k} = ?" for k in fields)]
        params: list[object] = [new_state.value, now, *fields.values(), slot_id]
        sql = f"UPDATE slots SET {', '.join(assignments)} WHERE slot_id = ?"

        with self._write_txn():
            self._conn.execute(sql, params)
            self._record_event_inner(
                "slot_state_change",
                {"slot_id": slot_id, "new_state": new_state.value, **fields},
            )

    # -- Reservation operations ---------------------------------------------

    def create_reservation(
        self,
        system: str,
        reason: str,
        build_id: int | None,
        ttl_seconds: int,
    ) -> dict:
        """Create a new pending reservation. Returns the reservation row as dict.

        ``created_at``/``updated_at`` and ``expires_at`` are derived from a
        single clock reading, so ``expires_at - created_at`` is exactly
        ``ttl_seconds``. A ``reservation_created`` event is recorded in the
        same transaction.
        """
        base = self._now_dt()
        now = base.isoformat()
        expires = (base + timedelta(seconds=ttl_seconds)).isoformat()
        rid = f"resv_{uuid.uuid4().hex}"

        with self._write_txn():
            self._conn.execute(
                """INSERT INTO reservations
                   (reservation_id, system, phase, created_at, updated_at,
                    expires_at, reason, build_id)
                   VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
                (rid, system, ReservationPhase.PENDING.value, now, now, expires, reason, build_id),
            )
            self._record_event_inner(
                "reservation_created",
                {"reservation_id": rid, "system": system, "reason": reason},
            )

        return self.get_reservation(rid)  # type: ignore[return-value]

    def get_reservation(self, reservation_id: str) -> dict | None:
        """Return a reservation row as dict, or None."""
        return self._query_one(
            "SELECT * FROM reservations WHERE reservation_id = ?", (reservation_id,)
        )

    def list_reservations(self, phase: ReservationPhase | None = None) -> list[dict]:
        """List reservations ordered by created_at, optionally filtered by phase."""
        if phase is None:
            return self._query_all("SELECT * FROM reservations ORDER BY created_at")
        return self._query_all(
            "SELECT * FROM reservations WHERE phase = ? ORDER BY created_at",
            (phase.value,),
        )

    def assign_reservation(self, reservation_id: str, slot_id: str, instance_id: str) -> bool:
        """Assign a pending reservation to a slot.

        In one transaction: move the reservation from PENDING to READY, set its
        slot_id/instance_id, increment the slot's lease_count, and record a
        ``reservation_assigned`` event. The slot's own state is not checked here.

        Returns:
            True if the reservation was assigned. False if it does not exist or
            is no longer PENDING. In that case nothing is written; in particular
            the slot's lease_count is not incremented.
        """
        now = _now_iso(self._clock)

        with self._write_txn():
            cur = self._conn.execute(
                """UPDATE reservations
                   SET phase = ?, slot_id = ?, instance_id = ?, updated_at = ?
                   WHERE reservation_id = ? AND phase = ?""",
                (
                    ReservationPhase.READY.value,
                    slot_id,
                    instance_id,
                    now,
                    reservation_id,
                    ReservationPhase.PENDING.value,
                ),
            )
            if cur.rowcount == 0:
                # Incrementing here would leak a lease that no release/expiry
                # ever decrements.
                return False
            self._conn.execute(
                "UPDATE slots SET lease_count = lease_count + 1 WHERE slot_id = ?",
                (slot_id,),
            )
            self._record_event_inner(
                "reservation_assigned",
                {
                    "reservation_id": reservation_id,
                    "slot_id": slot_id,
                    "instance_id": instance_id,
                },
            )
        return True

    def release_reservation(self, reservation_id: str) -> dict | None:
        """Release a reservation, decrementing the slot lease count.

        Returns:
            None if the reservation does not exist. The unchanged row if it is
            already RELEASED or EXPIRED (nothing is written). Otherwise the row
            after it was marked RELEASED. The slot's lease_count is decremented
            (floored at 0) only when the reservation was READY with a slot.
        """
        now = _now_iso(self._clock)

        with self._write_txn():
            resv = self._query_one(
                "SELECT * FROM reservations WHERE reservation_id = ?",
                (reservation_id,),
            )
            if resv is None:
                return None

            old_phase = resv["phase"]
            if old_phase in (ReservationPhase.RELEASED.value, ReservationPhase.EXPIRED.value):
                return resv

            self._conn.execute(
                """UPDATE reservations
                   SET phase = ?, released_at = ?, updated_at = ?
                   WHERE reservation_id = ?""",
                (ReservationPhase.RELEASED.value, now, now, reservation_id),
            )
            if resv["slot_id"] and old_phase == ReservationPhase.READY.value:
                self._conn.execute(
                    """UPDATE slots SET lease_count = MAX(lease_count - 1, 0)
                       WHERE slot_id = ?""",
                    (resv["slot_id"],),
                )
            self._record_event_inner("reservation_released", {"reservation_id": reservation_id})

        return self.get_reservation(reservation_id)

    def expire_reservations(self, now: datetime) -> list[str]:
        """Expire all PENDING/READY reservations with ``expires_at <= now``.

        Each expired reservation is marked EXPIRED with ``updated_at = now``. A
        READY one with a slot decrements that slot's lease_count (floored at
        0), and an event is recorded per reservation, all in one transaction.

        NOTE(review): ``expires_at`` is compared as text against
        ``now.isoformat()``. This orders correctly only if *now* uses the same
        UTC offset as the stored timestamps. Confirm callers pass an aware time
        in the same zone as the clock.

        Returns:
            The ids of the reservations that were expired.
        """
        now_iso = now.isoformat()
        expired_ids: list[str] = []

        with self._write_txn():
            rows = self._conn.execute(
                """SELECT reservation_id, slot_id, phase FROM reservations
                   WHERE phase IN (?, ?) AND expires_at <= ?""",
                (ReservationPhase.PENDING.value, ReservationPhase.READY.value, now_iso),
            ).fetchall()

            for rid, slot_id, phase in rows:
                expired_ids.append(rid)
                self._conn.execute(
                    """UPDATE reservations
                       SET phase = ?, updated_at = ?
                       WHERE reservation_id = ?""",
                    (ReservationPhase.EXPIRED.value, now_iso, rid),
                )
                if slot_id and phase == ReservationPhase.READY.value:
                    self._conn.execute(
                        """UPDATE slots SET lease_count = MAX(lease_count - 1, 0)
                           WHERE slot_id = ?""",
                        (slot_id,),
                    )
                self._record_event_inner("reservation_expired", {"reservation_id": rid})

        return expired_ids

    # -- Events -------------------------------------------------------------

    def record_event(self, kind: str, payload: dict) -> None:  # type: ignore[type-arg]
        """Record an audit event in its own transaction."""
        with self._write_txn():
            self._record_event_inner(kind, payload)

    def _record_event_inner(self, kind: str, payload: dict) -> None:  # type: ignore[type-arg]
        """Insert an event row (must be called inside an active transaction).

        The payload is JSON-encoded; values json cannot encode are converted
        with ``str``.
        """
        now = _now_iso(self._clock)
        self._conn.execute(
            "INSERT INTO events (ts, kind, payload_json) VALUES (?, ?, ?)",
            (now, kind, json.dumps(payload, default=str)),
        )

    # -- Summaries ----------------------------------------------------------

    def get_state_summary(self) -> dict:
        """Return aggregate slot and reservation counts.

        ``slots.total`` counts every slot row. The per-state entries cover a
        fixed set of state names and are 0 when no slot is in that state.
        Reservation counts cover only PENDING, READY and FAILED.
        """
        slot_counts: dict[str, int] = dict(
            self._conn.execute("SELECT state, COUNT(*) FROM slots GROUP BY state").fetchall()
        )
        slot_states = (
            "ready",
            "launching",
            "booting",
            "binding",
            "draining",
            "terminating",
            "empty",
            "error",
        )

        resv_counts: dict[str, int] = dict(
            self._conn.execute(
                "SELECT phase, COUNT(*) FROM reservations WHERE phase IN (?, ?, ?) GROUP BY phase",
                (
                    ReservationPhase.PENDING.value,
                    ReservationPhase.READY.value,
                    ReservationPhase.FAILED.value,
                ),
            ).fetchall()
        )

        return {
            "slots": {
                "total": sum(slot_counts.values()),
                **{name: slot_counts.get(name, 0) for name in slot_states},
            },
            "reservations": {
                "pending": resv_counts.get(ReservationPhase.PENDING.value, 0),
                "ready": resv_counts.get(ReservationPhase.READY.value, 0),
                "failed": resv_counts.get(ReservationPhase.FAILED.value, 0),
            },
        }

    def close(self) -> None:
        """Close the database connection."""
        self._conn.close()
|