add runtime adapters, scheduler, reconciler, and their unit tests

This commit is contained in:
Abel Luck 2026-02-27 12:34:32 +01:00
parent d1976a5fd8
commit b63d69c81d
10 changed files with 1471 additions and 28 deletions

View file

@ -1,32 +1,175 @@
"""EC2 runtime adapter — stub for Plan 02."""
"""EC2 runtime adapter for managing Spot instances."""
from __future__ import annotations
import logging
import random
import time
from typing import Any
import boto3
from botocore.exceptions import ClientError
from ..config import AwsConfig
from .base import RuntimeAdapter
from .base import RuntimeError as RuntimeAdapterError
log = logging.getLogger(__name__)
# EC2 ClientError code → normalized error category
_ERROR_CATEGORIES: dict[str, str] = {
"InsufficientInstanceCapacity": "capacity_unavailable",
"SpotMaxPriceTooLow": "price_too_low",
"RequestLimitExceeded": "throttled",
}
_RETRYABLE_CODES: frozenset[str] = frozenset({"RequestLimitExceeded"})
class EC2Runtime(RuntimeAdapter):
"""EC2 Spot instance runtime adapter.
Full implementation in Plan 02.
Args:
config: AWS configuration dataclass.
environment: Environment tag value (e.g. ``"dev"``, ``"prod"``).
_client: Optional pre-configured boto3 EC2 client (for testing).
"""
def __init__(self, region: str, launch_template_id: str, **kwargs: object) -> None:
self._region = region
self._launch_template_id = launch_template_id
def __init__(
self,
config: AwsConfig,
environment: str = "dev",
*,
_client: Any = None,
) -> None:
self._client: Any = _client or boto3.client("ec2", region_name=config.region)
self._launch_template_id = config.launch_template_id
self._subnet_ids = list(config.subnet_ids)
self._security_group_ids = list(config.security_group_ids)
self._instance_profile_arn = config.instance_profile_arn
self._environment = environment
self._subnet_index = 0
def launch_spot(self, slot_id: str, user_data: str) -> str:
"""Launch a spot instance for slot_id."""
raise NotImplementedError
"""Launch a spot instance for *slot_id*. Return instance ID."""
params: dict[str, Any] = {
"MinCount": 1,
"MaxCount": 1,
"LaunchTemplate": {
"LaunchTemplateId": self._launch_template_id,
"Version": "$Latest",
},
"InstanceMarketOptions": {
"MarketType": "spot",
"SpotOptions": {
"SpotInstanceType": "one-time",
"InstanceInterruptionBehavior": "terminate",
},
},
"UserData": user_data,
"TagSpecifications": [
{
"ResourceType": "instance",
"Tags": [
{"Key": "Name", "Value": f"nix-builder-{slot_id}"},
{"Key": "AutoscalerSlot", "Value": slot_id},
{"Key": "ManagedBy", "Value": "nix-builder-autoscaler"},
{"Key": "Service", "Value": "nix-builder"},
{"Key": "Environment", "Value": self._environment},
],
}
],
}
if self._subnet_ids:
subnet = self._subnet_ids[self._subnet_index % len(self._subnet_ids)]
self._subnet_index += 1
params["SubnetId"] = subnet
resp = self._call_with_backoff(self._client.run_instances, **params)
return resp["Instances"][0]["InstanceId"]
def describe_instance(self, instance_id: str) -> dict:
"""Return normalized instance info."""
raise NotImplementedError
"""Return normalized instance info dict."""
try:
resp = self._call_with_backoff(
self._client.describe_instances, InstanceIds=[instance_id]
)
except RuntimeAdapterError:
return {"state": "terminated", "tailscale_ip": None, "launch_time": None}
reservations = resp.get("Reservations", [])
if not reservations or not reservations[0].get("Instances"):
return {"state": "terminated", "tailscale_ip": None, "launch_time": None}
inst = reservations[0]["Instances"][0]
launch_time = inst.get("LaunchTime")
return {
"state": inst["State"]["Name"],
"tailscale_ip": None,
"launch_time": launch_time.isoformat() if launch_time else None,
}
def terminate_instance(self, instance_id: str) -> None:
"""Terminate the instance."""
raise NotImplementedError
self._call_with_backoff(self._client.terminate_instances, InstanceIds=[instance_id])
def list_managed_instances(self) -> list[dict]:
"""Return list of managed instances."""
raise NotImplementedError
resp = self._call_with_backoff(
self._client.describe_instances,
Filters=[
{"Name": "tag:ManagedBy", "Values": ["nix-builder-autoscaler"]},
{
"Name": "instance-state-name",
"Values": ["pending", "running", "shutting-down", "stopping"],
},
],
)
result: list[dict] = []
for reservation in resp.get("Reservations", []):
for inst in reservation.get("Instances", []):
tags = inst.get("Tags", [])
result.append(
{
"instance_id": inst["InstanceId"],
"state": inst["State"]["Name"],
"slot_id": self._get_tag(tags, "AutoscalerSlot"),
}
)
return result
def _call_with_backoff(self, fn: Any, *args: Any, max_retries: int = 3, **kwargs: Any) -> Any:
"""Call *fn* with exponential backoff and full jitter on retryable errors."""
delay = 0.5
for attempt in range(max_retries + 1):
try:
return fn(*args, **kwargs)
except ClientError as e:
code = e.response["Error"]["Code"]
if code in _RETRYABLE_CODES and attempt < max_retries:
jitter = random.uniform(0, min(delay, 10.0))
time.sleep(jitter)
delay *= 2
log.warning(
"Retryable EC2 error (attempt %d/%d): %s",
attempt + 1,
max_retries,
code,
)
continue
category = _ERROR_CATEGORIES.get(code, "unknown")
raise RuntimeAdapterError(str(e), category=category) from e
# Unreachable — loop always returns or raises on every path
msg = "Retries exhausted"
raise RuntimeAdapterError(msg, category="unknown")
@staticmethod
def _get_tag(tags: list[dict[str, str]], key: str) -> str | None:
"""Extract a tag value from an EC2 tag list."""
for tag in tags:
if tag.get("Key") == key:
return tag.get("Value")
return None