"""Fake runtime adapter for testing.""" from __future__ import annotations import uuid from dataclasses import dataclass from .base import RuntimeAdapter from .base import RuntimeError as RuntimeAdapterError @dataclass class _FakeInstance: instance_id: str slot_id: str state: str = "pending" tailscale_ip: str | None = None launch_time: str = "" ticks_to_running: int = 0 ticks_to_ip: int = 0 interrupted: bool = False class FakeRuntime(RuntimeAdapter): """In-memory runtime adapter for deterministic testing. Args: launch_latency_ticks: Number of tick() calls before instance becomes running. ip_delay_ticks: Additional ticks after running before tailscale_ip appears. """ def __init__(self, launch_latency_ticks: int = 2, ip_delay_ticks: int = 1) -> None: self._launch_latency = launch_latency_ticks self._ip_delay = ip_delay_ticks self._instances: dict[str, _FakeInstance] = {} self._launch_failures: set[str] = set() self._interruptions: set[str] = set() self._tick_count: int = 0 self._next_ip_counter: int = 1 def launch_spot(self, slot_id: str, user_data: str) -> str: """Launch a fake spot instance.""" if slot_id in self._launch_failures: self._launch_failures.discard(slot_id) raise RuntimeAdapterError( f"Simulated launch failure for {slot_id}", category="capacity_unavailable", ) iid = f"i-fake-{uuid.uuid4().hex[:12]}" self._instances[iid] = _FakeInstance( instance_id=iid, slot_id=slot_id, state="pending", launch_time=f"2026-01-01T00:00:{self._tick_count:02d}Z", ticks_to_running=self._launch_latency, ticks_to_ip=self._launch_latency + self._ip_delay, ) return iid def describe_instance(self, instance_id: str) -> dict: """Return normalized instance info.""" inst = self._instances.get(instance_id) if inst is None: return {"state": "terminated", "tailscale_ip": None, "launch_time": None} if instance_id in self._interruptions: self._interruptions.discard(instance_id) inst.state = "terminated" inst.interrupted = True return { "state": inst.state, "tailscale_ip": inst.tailscale_ip, "launch_time": inst.launch_time, } def terminate_instance(self, instance_id: str) -> None: """Terminate a fake instance.""" inst = self._instances.get(instance_id) if inst is not None: inst.state = "terminated" def list_managed_instances(self) -> list[dict]: """List all non-terminated fake instances.""" result: list[dict] = [] for inst in self._instances.values(): if inst.state != "terminated": result.append( { "instance_id": inst.instance_id, "state": inst.state, "slot_id": inst.slot_id, } ) return result # -- Test helpers ------------------------------------------------------- def tick(self) -> None: """Advance internal tick counter and progress instance states.""" self._tick_count += 1 for inst in self._instances.values(): if inst.state == "terminated": continue if inst.state == "pending" and self._tick_count >= inst.ticks_to_running: inst.state = "running" if ( inst.state == "running" and inst.tailscale_ip is None and self._tick_count >= inst.ticks_to_ip ): inst.tailscale_ip = f"100.64.0.{self._next_ip_counter}" self._next_ip_counter += 1 def inject_launch_failure(self, slot_id: str) -> None: """Make the next launch_spot call for this slot_id raise an error.""" self._launch_failures.add(slot_id) def inject_interruption(self, instance_id: str) -> None: """Make the next describe_instance call for this instance return terminated.""" self._interruptions.add(instance_id)