nix-builder-autoscaler/agent/nix_builder_autoscaler/runtime/ec2.py

292 lines
11 KiB
Python
Raw Normal View History

"""EC2 runtime adapter for managing Spot instances."""
2026-02-27 11:59:16 +01:00
from __future__ import annotations
2026-02-27 13:48:52 +01:00
import http.client
import json
import logging
import random
2026-02-27 13:48:52 +01:00
import socket
import time
from typing import Any
import boto3
from botocore.exceptions import ClientError
from ..config import AwsConfig
2026-02-27 11:59:16 +01:00
from .base import RuntimeAdapter
from .base import RuntimeError as RuntimeAdapterError
log = logging.getLogger(__name__)

# EC2 ClientError codes that _call_with_backoff retries with jittered backoff;
# every other code is raised immediately.
_RETRYABLE_CODES: frozenset[str] = frozenset(("RequestLimitExceeded",))

# EC2 ClientError code -> normalized category attached to the raised
# RuntimeAdapterError. Codes missing from this table are reported as "unknown".
_ERROR_CATEGORIES: dict[str, str] = dict(
    InsufficientInstanceCapacity="capacity_unavailable",
    SpotMaxPriceTooLow="price_too_low",
    RequestLimitExceeded="throttled",
)
2026-02-27 11:59:16 +01:00
2026-02-27 13:48:52 +01:00
class _UnixSocketHTTPConnection(http.client.HTTPConnection):
    """HTTP connection over a Unix domain socket.

    The host name handed to the base class is a placeholder; the transport is
    always the Unix socket at *socket_path*.

    Args:
        socket_path: Filesystem path of the Unix domain socket to connect to.
        timeout: Seconds applied to the socket connect and to every
            subsequent read/write on it.
    """

    def __init__(self, socket_path: str, timeout: float = 1.0) -> None:
        super().__init__("local-tailscaled.sock", timeout=timeout)
        self._socket_path = socket_path

    def connect(self) -> None:
        """Open the Unix socket and install it as ``self.sock``.

        Raises:
            OSError: if the socket cannot be connected (missing path,
                permission denied, timeout, ...). No socket is left behind
                in ``self.sock`` in that case.
        """
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            # HTTPConnection only applies self.timeout inside its own
            # connect(); since that is overridden here, apply it explicitly
            # or the connect and response reads would block indefinitely.
            sock.settimeout(self.timeout)
            sock.connect(self._socket_path)
        except BaseException:
            sock.close()
            raise
        self.sock = sock
2026-02-27 11:59:16 +01:00
class EC2Runtime(RuntimeAdapter):
    """EC2 Spot instance runtime adapter.

    Launches, describes, lists and terminates one-time Spot instances created
    from a launch template, and resolves the Tailscale IP of running builders
    through the local tailscaled LocalAPI.

    Args:
        config: AWS configuration dataclass.
        environment: Environment tag value (e.g. ``"dev"``, ``"prod"``).
        _client: Optional pre-configured boto3 EC2 client (for testing).
        _tailscale_socket_path: Path of the tailscaled LocalAPI Unix socket.
    """

    def __init__(
        self,
        config: AwsConfig,
        environment: str = "dev",
        *,
        _client: Any = None,
        _tailscale_socket_path: str = "/run/tailscale/tailscaled.sock",
    ) -> None:
        # Use an explicit None check so an injected client is never replaced
        # just because it happens to be falsy.
        self._client: Any = (
            _client if _client is not None else boto3.client("ec2", region_name=config.region)
        )
        self._launch_template_id = config.launch_template_id
        self._subnet_ids = list(config.subnet_ids)
        # NOTE(review): security groups and the instance profile are read from
        # config but never passed to run_instances; presumably the launch
        # template supplies them — confirm.
        self._security_group_ids = list(config.security_group_ids)
        self._instance_profile_arn = config.instance_profile_arn
        self._environment = environment
        # Round-robin cursor over self._subnet_ids for successive launches.
        self._subnet_index = 0
        self._tailscale_socket_path = _tailscale_socket_path

    def launch_spot(self, slot_id: str, user_data: str) -> str:
        """Launch a spot instance for *slot_id*. Return instance ID.

        The instance is tagged with its slot (``AutoscalerSlot``) and with
        ``ManagedBy=nix-builder-autoscaler`` so it can be found again later.
        When subnets are configured they are used round-robin across calls.

        Raises:
            RuntimeAdapterError: if ``run_instances`` fails after retries.
        """
        params: dict[str, Any] = {
            "MinCount": 1,
            "MaxCount": 1,
            "LaunchTemplate": {
                "LaunchTemplateId": self._launch_template_id,
                "Version": "$Latest",
            },
            "InstanceMarketOptions": {
                "MarketType": "spot",
                "SpotOptions": {
                    "SpotInstanceType": "one-time",
                    "InstanceInterruptionBehavior": "terminate",
                },
            },
            "UserData": user_data,
            "TagSpecifications": [
                {
                    "ResourceType": "instance",
                    "Tags": [
                        {"Key": "Name", "Value": f"nix-builder-{slot_id}"},
                        {"Key": "AutoscalerSlot", "Value": slot_id},
                        {"Key": "ManagedBy", "Value": "nix-builder-autoscaler"},
                        {"Key": "Service", "Value": "nix-builder"},
                        {"Key": "Environment", "Value": self._environment},
                    ],
                }
            ],
        }
        if self._subnet_ids:
            subnet = self._subnet_ids[self._subnet_index % len(self._subnet_ids)]
            self._subnet_index += 1
            params["SubnetId"] = subnet
        resp = self._call_with_backoff(self._client.run_instances, **params)
        return resp["Instances"][0]["InstanceId"]

    def describe_instance(self, instance_id: str) -> dict:
        """Return normalized instance info dict.

        Keys:
            ``state``: EC2 instance state name.
            ``tailscale_ip``: resolved only for ``running`` instances that
                carry an ``AutoscalerSlot`` tag, otherwise ``None``.
            ``launch_time``: ISO-8601 string, or ``None`` if absent.

        An instance that cannot be described, or is absent from the response,
        is reported with state ``"terminated"``.
        """
        try:
            resp = self._call_with_backoff(
                self._client.describe_instances, InstanceIds=[instance_id]
            )
        except RuntimeAdapterError:
            # NOTE(review): every describe failure — including throttling that
            # outlived the retries — is reported as "terminated". Confirm that
            # callers tolerate a false "terminated" for a still-running instance.
            return {"state": "terminated", "tailscale_ip": None, "launch_time": None}
        reservations = resp.get("Reservations", [])
        if not reservations or not reservations[0].get("Instances"):
            return {"state": "terminated", "tailscale_ip": None, "launch_time": None}
        inst = reservations[0]["Instances"][0]

        tags = inst.get("Tags", [])
        slot_id = self._get_tag(tags, "AutoscalerSlot")
        state = inst["State"]["Name"]
        tailscale_ip: str | None = None
        if state == "running" and slot_id:
            tailscale_ip = self._discover_tailscale_ip(slot_id, instance_id)
        launch_time = inst.get("LaunchTime")
        return {
            "state": state,
            "tailscale_ip": tailscale_ip,
            "launch_time": launch_time.isoformat() if launch_time else None,
        }

    def terminate_instance(self, instance_id: str) -> None:
        """Terminate the instance.

        Raises:
            RuntimeAdapterError: if ``terminate_instances`` fails after retries.
        """
        self._call_with_backoff(self._client.terminate_instances, InstanceIds=[instance_id])

    def list_managed_instances(self) -> list[dict]:
        """Return list of managed instances.

        Includes every instance tagged ``ManagedBy=nix-builder-autoscaler`` in
        state pending, running, shutting-down or stopping, as dicts with keys
        ``instance_id``, ``state`` and ``slot_id`` (``None`` when untagged).
        Follows ``NextToken`` so results beyond the first page are included.
        """
        filters = [
            {"Name": "tag:ManagedBy", "Values": ["nix-builder-autoscaler"]},
            {
                "Name": "instance-state-name",
                "Values": ["pending", "running", "shutting-down", "stopping"],
            },
        ]
        result: list[dict] = []
        next_token: str | None = None
        while True:
            request: dict[str, Any] = {"Filters": filters}
            if next_token:
                request["NextToken"] = next_token
            resp = self._call_with_backoff(self._client.describe_instances, **request)
            for reservation in resp.get("Reservations", []):
                for inst in reservation.get("Instances", []):
                    tags = inst.get("Tags", [])
                    result.append(
                        {
                            "instance_id": inst["InstanceId"],
                            "state": inst["State"]["Name"],
                            "slot_id": self._get_tag(tags, "AutoscalerSlot"),
                        }
                    )
            token = resp.get("NextToken")
            # Only continue on a real, non-empty string token so a malformed
            # response can never cause an endless loop.
            if not isinstance(token, str) or not token:
                return result
            next_token = token

    def _call_with_backoff(self, fn: Any, *args: Any, max_retries: int = 3, **kwargs: Any) -> Any:
        """Call *fn* with exponential backoff and full jitter on retryable errors.

        Codes in ``_RETRYABLE_CODES`` are retried up to *max_retries* times.
        Before each retry it sleeps a random time drawn from a window that
        starts at 0.5s, doubles per retry and is capped at 10s. Any other
        ``ClientError``, or a retryable one once retries are used up, is
        re-raised as ``RuntimeAdapterError`` whose ``category`` comes from
        ``_ERROR_CATEGORIES`` (default ``"unknown"``). Exceptions other than
        ``ClientError`` propagate unchanged.
        """
        delay = 0.5
        for attempt in range(max_retries + 1):
            try:
                return fn(*args, **kwargs)
            except ClientError as e:
                # Tolerate an error payload without Error.Code instead of
                # leaking a KeyError.
                code = e.response.get("Error", {}).get("Code", "")
                if code in _RETRYABLE_CODES and attempt < max_retries:
                    log.warning(
                        "Retryable EC2 error (attempt %d/%d): %s",
                        attempt + 1,
                        max_retries,
                        code,
                    )
                    time.sleep(random.uniform(0, min(delay, 10.0)))
                    delay *= 2
                    continue
                category = _ERROR_CATEGORIES.get(code, "unknown")
                raise RuntimeAdapterError(str(e), category=category) from e
        # Reached only when max_retries < 0, which makes the loop run zero times.
        msg = "Retries exhausted"
        raise RuntimeAdapterError(msg, category="unknown")

    def _discover_tailscale_ip(self, slot_id: str, instance_id: str) -> str | None:
        """Resolve Tailscale IP for instance identity via local tailscaled LocalAPI.

        Only online peers that have both a host name and a Tailscale IP are
        considered. Matching proceeds in two stages:

        1. The host name is matched against ``nix-builder-<slot_id>-<instance_id>``.
        2. If nothing matched, it is matched against ``nix-builder-<slot_id>``.

        Matching is case-insensitive and by whole token (see ``_host_matches``).
        A stage succeeds only with exactly one match; ambiguity is logged and
        yields ``None``.
        """
        status = self._read_tailscale_status()
        if status is None:
            return None
        peers_obj = status.get("Peer")
        if not isinstance(peers_obj, dict):
            return None

        online_candidates: list[tuple[str, str]] = []
        for peer in peers_obj.values():
            if not isinstance(peer, dict):
                continue
            if not self._peer_is_online(peer):
                continue
            hostname = self._peer_hostname(peer)
            if hostname is None:
                continue
            ip = self._peer_tailscale_ip(peer)
            if ip is None:
                continue
            online_candidates.append((hostname, ip))

        identity = f"nix-builder-{slot_id}-{instance_id}".lower()
        identity_matches = [
            ip for host, ip in online_candidates if self._host_matches(host, identity)
        ]
        if len(identity_matches) == 1:
            return identity_matches[0]
        if len(identity_matches) > 1:
            log.warning(
                "tailscale_identity_ambiguous",
                extra={"slot_id": slot_id, "instance_id": instance_id},
            )
            return None

        slot_identity = f"nix-builder-{slot_id}".lower()
        slot_matches = [
            ip for host, ip in online_candidates if self._host_matches(host, slot_identity)
        ]
        if len(slot_matches) == 1:
            return slot_matches[0]
        if len(slot_matches) > 1:
            log.warning("tailscale_slot_ambiguous", extra={"slot_id": slot_id})
            return None
        return None

    @staticmethod
    def _host_matches(host: str, identity: str) -> bool:
        """Return True if *identity* occurs in *host* as a whole token.

        The characters directly before and after the occurrence must be
        non-alphanumeric (e.g. ``-`` or ``.``) or a string boundary. Thus
        ``nix-builder-slot1`` matches ``nix-builder-slot1-i-0abc`` and
        ``nix-builder-slot1.tailnet.ts.net`` but not ``nix-builder-slot10``,
        which a plain substring test would wrongly accept.
        """
        start = host.find(identity)
        while start != -1:
            end = start + len(identity)
            before_ok = start == 0 or not host[start - 1].isalnum()
            after_ok = end == len(host) or not host[end].isalnum()
            if before_ok and after_ok:
                return True
            start = host.find(identity, start + 1)
        return False

    def _read_tailscale_status(self) -> dict[str, Any] | None:
        """Query local tailscaled LocalAPI status endpoint over Unix socket.

        Best-effort: returns the decoded JSON object. Returns ``None`` on any
        of the following:

        - a socket or HTTP error;
        - a non-200 status;
        - an undecodable body;
        - a body that is not a JSON object.
        """
        conn = _UnixSocketHTTPConnection(self._tailscale_socket_path, timeout=1.0)
        try:
            conn.request(
                "GET",
                "/localapi/v0/status",
                headers={"Host": "local-tailscaled.sock", "Accept": "application/json"},
            )
            response = conn.getresponse()
            if response.status != 200:
                return None
            parsed = json.loads(response.read().decode())
        # OSError already covers permission and timeout failures; ValueError
        # covers both JSONDecodeError and UnicodeDecodeError.
        except (OSError, ValueError, http.client.HTTPException):
            return None
        finally:
            conn.close()
        return parsed if isinstance(parsed, dict) else None

    @staticmethod
    def _peer_is_online(peer: dict[str, Any]) -> bool:
        """Return the truthiness of the peer's ``Online`` flag."""
        return bool(peer.get("Online"))

    @staticmethod
    def _peer_hostname(peer: dict[str, Any]) -> str | None:
        """Return ``HostName`` (falling back to ``DNSName``), dot-trimmed and lowercased."""
        host = peer.get("HostName") or peer.get("DNSName")
        if not isinstance(host, str) or not host:
            return None
        return host.strip(".").lower()

    @staticmethod
    def _peer_tailscale_ip(peer: dict[str, Any]) -> str | None:
        """Return the peer's first IPv4 Tailscale address, else its first non-empty one."""
        ips = peer.get("TailscaleIPs")
        if not isinstance(ips, list):
            return None
        ipv4 = [ip for ip in ips if isinstance(ip, str) and "." in ip]
        if ipv4:
            return ipv4[0]
        for ip in ips:
            if isinstance(ip, str) and ip:
                return ip
        return None

    @staticmethod
    def _get_tag(tags: list[dict[str, str]], key: str) -> str | None:
        """Extract a tag value from an EC2 tag list."""
        for tag in tags:
            if tag.get("Key") == key:
                return tag.get("Value")
        return None