"""EC2 runtime adapter for managing Spot instances.""" from __future__ import annotations import http.client import json import logging import random import socket import time from typing import Any import boto3 from botocore.exceptions import ClientError from ..config import AwsConfig from .base import RuntimeAdapter from .base import RuntimeError as RuntimeAdapterError log = logging.getLogger(__name__) # EC2 ClientError code → normalized error category _ERROR_CATEGORIES: dict[str, str] = { "InsufficientInstanceCapacity": "capacity_unavailable", "SpotMaxPriceTooLow": "price_too_low", "RequestLimitExceeded": "throttled", } _RETRYABLE_CODES: frozenset[str] = frozenset({"RequestLimitExceeded"}) class _UnixSocketHTTPConnection(http.client.HTTPConnection): """HTTP connection over a Unix domain socket.""" def __init__(self, socket_path: str, timeout: float = 1.0) -> None: super().__init__("local-tailscaled.sock", timeout=timeout) self._socket_path = socket_path def connect(self) -> None: self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.sock.connect(self._socket_path) class EC2Runtime(RuntimeAdapter): """EC2 Spot instance runtime adapter. Args: config: AWS configuration dataclass. environment: Environment tag value (e.g. ``"dev"``, ``"prod"``). _client: Optional pre-configured boto3 EC2 client (for testing). """ def __init__( self, config: AwsConfig, environment: str = "dev", *, _client: Any = None, _tailscale_socket_path: str = "/run/tailscale/tailscaled.sock", ) -> None: self._client: Any = _client or boto3.client("ec2", region_name=config.region) self._launch_template_id = config.launch_template_id self._subnet_ids = list(config.subnet_ids) self._security_group_ids = list(config.security_group_ids) self._instance_profile_arn = config.instance_profile_arn self._environment = environment self._subnet_index = 0 self._tailscale_socket_path = _tailscale_socket_path def launch_spot(self, slot_id: str, user_data: str) -> str: """Launch a spot instance for *slot_id*. Return instance ID.""" params: dict[str, Any] = { "MinCount": 1, "MaxCount": 1, "LaunchTemplate": { "LaunchTemplateId": self._launch_template_id, "Version": "$Latest", }, "InstanceMarketOptions": { "MarketType": "spot", "SpotOptions": { "SpotInstanceType": "one-time", "InstanceInterruptionBehavior": "terminate", }, }, "UserData": user_data, "TagSpecifications": [ { "ResourceType": "instance", "Tags": [ {"Key": "Name", "Value": f"nix-builder-{slot_id}"}, {"Key": "AutoscalerSlot", "Value": slot_id}, {"Key": "ManagedBy", "Value": "nix-builder-autoscaler"}, {"Key": "Service", "Value": "nix-builder"}, {"Key": "Environment", "Value": self._environment}, ], } ], } if self._subnet_ids: subnet = self._subnet_ids[self._subnet_index % len(self._subnet_ids)] self._subnet_index += 1 params["SubnetId"] = subnet resp = self._call_with_backoff(self._client.run_instances, **params) return resp["Instances"][0]["InstanceId"] def describe_instance(self, instance_id: str) -> dict: """Return normalized instance info dict.""" try: resp = self._call_with_backoff( self._client.describe_instances, InstanceIds=[instance_id] ) except RuntimeAdapterError: return {"state": "terminated", "tailscale_ip": None, "launch_time": None} reservations = resp.get("Reservations", []) if not reservations or not reservations[0].get("Instances"): return {"state": "terminated", "tailscale_ip": None, "launch_time": None} inst = reservations[0]["Instances"][0] tags = inst.get("Tags", []) slot_id = self._get_tag(tags, "AutoscalerSlot") state = inst["State"]["Name"] tailscale_ip: str | None = None if state == "running" and slot_id: tailscale_ip = self._discover_tailscale_ip(slot_id, instance_id) launch_time = inst.get("LaunchTime") return { "state": state, "tailscale_ip": tailscale_ip, "launch_time": launch_time.isoformat() if launch_time else None, } def terminate_instance(self, instance_id: str) -> None: """Terminate the instance.""" self._call_with_backoff(self._client.terminate_instances, InstanceIds=[instance_id]) def list_managed_instances(self) -> list[dict]: """Return list of managed instances.""" resp = self._call_with_backoff( self._client.describe_instances, Filters=[ {"Name": "tag:ManagedBy", "Values": ["nix-builder-autoscaler"]}, { "Name": "instance-state-name", "Values": ["pending", "running", "shutting-down", "stopping"], }, ], ) result: list[dict] = [] for reservation in resp.get("Reservations", []): for inst in reservation.get("Instances", []): tags = inst.get("Tags", []) result.append( { "instance_id": inst["InstanceId"], "state": inst["State"]["Name"], "slot_id": self._get_tag(tags, "AutoscalerSlot"), } ) return result def _call_with_backoff(self, fn: Any, *args: Any, max_retries: int = 3, **kwargs: Any) -> Any: """Call *fn* with exponential backoff and full jitter on retryable errors.""" delay = 0.5 for attempt in range(max_retries + 1): try: return fn(*args, **kwargs) except ClientError as e: code = e.response["Error"]["Code"] if code in _RETRYABLE_CODES and attempt < max_retries: jitter = random.uniform(0, min(delay, 10.0)) time.sleep(jitter) delay *= 2 log.warning( "Retryable EC2 error (attempt %d/%d): %s", attempt + 1, max_retries, code, ) continue category = _ERROR_CATEGORIES.get(code, "unknown") raise RuntimeAdapterError(str(e), category=category) from e # Unreachable — loop always returns or raises on every path msg = "Retries exhausted" raise RuntimeAdapterError(msg, category="unknown") def _discover_tailscale_ip(self, slot_id: str, instance_id: str) -> str | None: """Resolve Tailscale IP for instance identity via local tailscaled LocalAPI.""" status = self._read_tailscale_status() if status is None: return None peers_obj = status.get("Peer") if not isinstance(peers_obj, dict): return None online_candidates: list[tuple[str, str]] = [] for peer in peers_obj.values(): if not isinstance(peer, dict): continue if not self._peer_is_online(peer): continue hostname = self._peer_hostname(peer) if hostname is None: continue ip = self._peer_tailscale_ip(peer) if ip is None: continue online_candidates.append((hostname, ip)) identity = f"nix-builder-{slot_id}-{instance_id}".lower() identity_matches = [ip for host, ip in online_candidates if identity in host] if len(identity_matches) == 1: return identity_matches[0] if len(identity_matches) > 1: log.warning( "tailscale_identity_ambiguous", extra={"slot_id": slot_id, "instance_id": instance_id}, ) return None slot_identity = f"nix-builder-{slot_id}".lower() slot_matches = [ip for host, ip in online_candidates if slot_identity in host] if len(slot_matches) == 1: return slot_matches[0] if len(slot_matches) > 1: log.warning("tailscale_slot_ambiguous", extra={"slot_id": slot_id}) return None return None def _read_tailscale_status(self) -> dict[str, Any] | None: """Query local tailscaled LocalAPI status endpoint over Unix socket.""" conn = _UnixSocketHTTPConnection(self._tailscale_socket_path, timeout=1.0) try: conn.request( "GET", "/localapi/v0/status", headers={"Host": "local-tailscaled.sock", "Accept": "application/json"}, ) response = conn.getresponse() if response.status != 200: return None payload = response.read() parsed = json.loads(payload.decode()) if isinstance(parsed, dict): return parsed return None except (OSError, PermissionError, TimeoutError, json.JSONDecodeError, UnicodeDecodeError): return None except http.client.HTTPException: return None finally: conn.close() @staticmethod def _peer_is_online(peer: dict[str, Any]) -> bool: return bool(peer.get("Online") or peer.get("Active")) @staticmethod def _peer_hostname(peer: dict[str, Any]) -> str | None: host = peer.get("HostName") or peer.get("DNSName") if not isinstance(host, str) or not host: return None return host.strip(".").lower() @staticmethod def _peer_tailscale_ip(peer: dict[str, Any]) -> str | None: ips = peer.get("TailscaleIPs") if not isinstance(ips, list): return None ipv4 = [ip for ip in ips if isinstance(ip, str) and "." in ip] if ipv4: return ipv4[0] for ip in ips: if isinstance(ip, str) and ip: return ip return None @staticmethod def _get_tag(tags: list[dict[str, str]], key: str) -> str | None: """Extract a tag value from an EC2 tag list.""" for tag in tags: if tag.get("Key") == key: return tag.get("Value") return None