nix-builder-autoscaler/agent/nix_builder_autoscaler/runtime/ec2.py

438 lines
17 KiB
Python

"""EC2 runtime adapter for managing Spot instances."""
from __future__ import annotations
import http.client
import json
import logging
import random
import socket
import threading
import time
from datetime import UTC, datetime
from typing import Any
import boto3
from botocore.exceptions import ClientError
from ..config import AwsConfig
from .base import RuntimeAdapter
from .base import RuntimeError as RuntimeAdapterError
log = logging.getLogger(__name__)

# Normalized category attached to the RuntimeAdapterError raised for a given EC2
# ClientError code; EC2Runtime._call_with_backoff falls back to "unknown" for codes
# that are not listed here.
_ERROR_CATEGORIES: dict[str, str] = dict(
    InsufficientInstanceCapacity="capacity_unavailable",
    SpotMaxPriceTooLow="price_too_low",
    RequestLimitExceeded="throttled",
)

# ClientError codes that EC2Runtime._call_with_backoff retries with jittered backoff.
_RETRYABLE_CODES: frozenset[str] = frozenset(("RequestLimitExceeded",))
class _UnixSocketHTTPConnection(http.client.HTTPConnection):
"""HTTP connection over a Unix domain socket."""
def __init__(self, socket_path: str, timeout: float = 1.0) -> None:
super().__init__("local-tailscaled.sock", timeout=timeout)
self._socket_path = socket_path
def connect(self) -> None:
self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.sock.connect(self._socket_path)
class EC2Runtime(RuntimeAdapter):
    """EC2 Spot instance runtime adapter.

    Launches, describes, lists and terminates builder instances through the EC2 API
    (optionally using credentials from an assumed IAM role) and resolves a running
    instance's Tailscale IP by querying the local tailscaled LocalAPI over its Unix
    socket.

    Args:
        config: AWS configuration dataclass.
        environment: Environment tag value (e.g. ``"dev"``, ``"prod"``).
        _client: Optional pre-configured boto3 EC2 client (for testing).
        _tailscale_socket_path: Path of the local tailscaled LocalAPI Unix socket.
    """

    def __init__(
        self,
        config: AwsConfig,
        environment: str = "dev",
        *,
        _client: Any = None,
        _tailscale_socket_path: str = "/run/tailscale/tailscaled.sock",
    ) -> None:
        self._base_client: Any = _client or boto3.client("ec2", region_name=config.region)
        self._region = config.region
        # An empty (or whitespace-only) ARN disables role assumption entirely.
        self._assume_role_arn = config.assume_role_arn.strip()
        self._assume_role_session_name = "nix-builder-autoscaler"
        # EC2 client built from assumed-role credentials, plus their expiry time.
        self._assumed_client: Any | None = None
        self._assumed_client_expiration: datetime | None = None
        self._assume_role_lock = threading.Lock()
        self._sts_client: Any | None = None
        if self._assume_role_arn != "":
            self._sts_client = boto3.client("sts", region_name=config.region)
        self._instance_type = config.instance_type
        self._launch_template_id = config.launch_template_id
        self._on_demand_launch_template_id = config.on_demand_launch_template_id
        self._subnet_ids = list(config.subnet_ids)
        self._security_group_ids = list(config.security_group_ids)
        self._instance_profile_arn = config.instance_profile_arn
        self._environment = environment
        # Round-robin cursor into _subnet_ids, advanced by launch_instance.
        self._subnet_index = 0
        self._tailscale_socket_path = _tailscale_socket_path

    def preflight_validate(self) -> None:
        """Check that the configured instance type is offered in the configured subnets' AZs.

        Logs an error if the instance type is offered in none of the relevant
        locations, and a warning if it is missing from only some subnet AZs, so
        misconfigurations surface at startup rather than on every failed launch.
        Never raises: any failure (including EC2/STS API errors) is logged as a warning
        so a transient permissions issue does not prevent startup.
        """
        try:
            ec2_client = self._get_ec2_client()
            # AZs of the configured subnets. Stays empty when no subnets are configured
            # (or none report an AZ), in which case the whole region is queried.
            target_azs: set[str] = set()
            if self._subnet_ids:
                resp = self._call_with_backoff(
                    ec2_client.describe_subnets, SubnetIds=self._subnet_ids
                )
                for subnet in resp.get("Subnets", []):
                    az = subnet.get("AvailabilityZone")
                    if az:
                        target_azs.add(az)
            filters: list[dict[str, Any]] = [
                {"Name": "instance-type", "Values": [self._instance_type]},
            ]
            if target_azs:
                filters.append({"Name": "location", "Values": list(target_azs)})
            resp = self._call_with_backoff(
                ec2_client.describe_instance_type_offerings,
                LocationType="availability-zone",
                Filters=filters,
            )
            available_azs = {o["Location"] for o in resp.get("InstanceTypeOfferings", [])}
            if not available_azs:
                # When the query was scoped to the subnet AZs, an empty result only
                # proves absence from those AZs, not from the whole region.
                if target_azs:
                    scope = f"any configured subnet AZ {sorted(target_azs)}"
                else:
                    scope = f"region {ec2_client.meta.region_name!r}"
                log.error(
                    "preflight_misconfiguration",
                    extra={
                        "error": (
                            f"instance type {self._instance_type!r} is not available in"
                            f" {scope} - all launches will fail with Unsupported"
                        ),
                        "category": "misconfiguration",
                    },
                )
                return
            missing_azs = target_azs - available_azs
            if missing_azs:
                log.warning(
                    "preflight_misconfiguration",
                    extra={
                        "error": (
                            f"instance type {self._instance_type!r} is not available in"
                            f" AZs {sorted(missing_azs)} - launches into those subnets will"
                            f" fail with Unsupported"
                        ),
                        "category": "misconfiguration",
                    },
                )
            else:
                log.info(
                    "preflight_ok",
                    extra={
                        "error": None,
                        "category": None,
                    },
                )
        except Exception as exc:
            # Deliberate best-effort boundary: preflight must never block startup.
            log.warning(
                "preflight_validate_failed",
                extra={"error": str(exc), "category": "unknown"},
            )

    def launch_instance(
        self, slot_id: str, user_data: str, *, nested_virtualization: bool = False
    ) -> str:
        """Launch an instance for *slot_id*. Return instance ID.

        When nested_virtualization is True, an on-demand instance is launched using the
        on-demand launch template (cpu_options nested virt enabled, no spot market
        options). When False (default), a spot instance is launched using the spot
        launch template. Subnets, if configured, are used in round-robin order.

        Raises:
            RuntimeAdapterError: category ``"misconfiguration"`` when nested
                virtualization is requested without an on-demand launch template, or
                the category derived from the EC2 error code if ``RunInstances`` fails.
        """
        if nested_virtualization:
            if not self._on_demand_launch_template_id:
                raise RuntimeAdapterError(
                    "nested_virtualization=True but on_demand_launch_template_id is not configured",
                    category="misconfiguration",
                )
            lt_id = self._on_demand_launch_template_id
        else:
            lt_id = self._launch_template_id
        params: dict[str, Any] = {
            "MinCount": 1,
            "MaxCount": 1,
            "LaunchTemplate": {
                "LaunchTemplateId": lt_id,
                "Version": "$Latest",
            },
            "UserData": user_data,
            "TagSpecifications": [
                {
                    "ResourceType": "instance",
                    "Tags": [
                        {"Key": "Name", "Value": f"nix-builder-{slot_id}"},
                        {"Key": "AutoscalerSlot", "Value": slot_id},
                        {"Key": "ManagedBy", "Value": "nix-builder-autoscaler"},
                        {"Key": "Service", "Value": "nix-builder"},
                        {"Key": "Environment", "Value": self._environment},
                    ],
                }
            ],
        }
        if not nested_virtualization:
            params["InstanceMarketOptions"] = {
                "MarketType": "spot",
                "SpotOptions": {
                    "SpotInstanceType": "one-time",
                    "InstanceInterruptionBehavior": "terminate",
                },
            }
        if self._subnet_ids:
            subnet = self._subnet_ids[self._subnet_index % len(self._subnet_ids)]
            self._subnet_index += 1
            params["SubnetId"] = subnet
        ec2_client = self._get_ec2_client()
        resp = self._call_with_backoff(ec2_client.run_instances, **params)
        return resp["Instances"][0]["InstanceId"]

    def describe_instance(self, instance_id: str) -> dict:
        """Return ``state``, ``tailscale_ip`` and ``launch_time`` for *instance_id*.

        ``launch_time`` is an ISO 8601 string (or ``None``); ``tailscale_ip`` is only
        looked up for running instances carrying an ``AutoscalerSlot`` tag. Missing
        instances and any adapter error are reported as ``state == "terminated"``.
        """
        try:
            ec2_client = self._get_ec2_client()
            resp = self._call_with_backoff(ec2_client.describe_instances, InstanceIds=[instance_id])
        except RuntimeAdapterError:
            # Presumably meant for unknown instance IDs (unmapped codes surface as
            # RuntimeAdapterError with category "unknown").
            # NOTE(review): this also reports throttling and assume-role failures as
            # "terminated", which may make callers treat a live instance as gone -
            # confirm whether those categories should propagate instead.
            return {"state": "terminated", "tailscale_ip": None, "launch_time": None}
        reservations = resp.get("Reservations", [])
        if not reservations or not reservations[0].get("Instances"):
            return {"state": "terminated", "tailscale_ip": None, "launch_time": None}
        inst = reservations[0]["Instances"][0]
        tags = inst.get("Tags", [])
        slot_id = self._get_tag(tags, "AutoscalerSlot")
        state = inst["State"]["Name"]
        tailscale_ip: str | None = None
        if state == "running" and slot_id:
            tailscale_ip = self._discover_tailscale_ip(slot_id, instance_id)
        launch_time = inst.get("LaunchTime")
        return {
            "state": state,
            "tailscale_ip": tailscale_ip,
            "launch_time": launch_time.isoformat() if launch_time else None,
        }

    def terminate_instance(self, instance_id: str) -> None:
        """Terminate the instance."""
        ec2_client = self._get_ec2_client()
        self._call_with_backoff(ec2_client.terminate_instances, InstanceIds=[instance_id])

    def list_managed_instances(self) -> list[dict]:
        """Return ``instance_id``/``state``/``slot_id`` dicts for managed instances.

        Only instances tagged ``ManagedBy=nix-builder-autoscaler`` in the pending,
        running, shutting-down or stopping states are included. ``NextToken`` is
        followed so results beyond the first ``DescribeInstances`` page are not lost.
        """
        ec2_client = self._get_ec2_client()
        request: dict[str, Any] = {
            "Filters": [
                {"Name": "tag:ManagedBy", "Values": ["nix-builder-autoscaler"]},
                {
                    "Name": "instance-state-name",
                    "Values": ["pending", "running", "shutting-down", "stopping"],
                },
            ],
        }
        result: list[dict] = []
        while True:
            resp = self._call_with_backoff(ec2_client.describe_instances, **request)
            for reservation in resp.get("Reservations", []):
                for inst in reservation.get("Instances", []):
                    tags = inst.get("Tags", [])
                    result.append(
                        {
                            "instance_id": inst["InstanceId"],
                            "state": inst["State"]["Name"],
                            "slot_id": self._get_tag(tags, "AutoscalerSlot"),
                        }
                    )
            next_token = resp.get("NextToken")
            # Stop on a missing/empty token; the str check also keeps non-string
            # placeholders (e.g. from stubbed clients) from looping forever.
            if not isinstance(next_token, str) or not next_token:
                break
            request["NextToken"] = next_token
        return result

    def _get_ec2_client(self) -> Any:
        """Return the assumed-role EC2 client if a role is configured, else the base client."""
        if self._assume_role_arn == "":
            return self._base_client
        return self._get_assumed_ec2_client()

    def _get_assumed_ec2_client(self) -> Any:
        """Return an EC2 client using assumed-role credentials, refreshing them when needed.

        The cached client is reused while its credentials have more than 300 seconds
        left; otherwise ``AssumeRole`` is called and a new client is built. Refreshes
        are serialized by ``_assume_role_lock``.

        Raises:
            RuntimeAdapterError: category ``"misconfiguration"`` if the STS client is
                missing or the ``AssumeRole`` call fails.
        """
        with self._assume_role_lock:
            # Read the clock after acquiring the lock: a timestamp taken before waiting
            # on another thread could make nearly-expired credentials look fresh.
            now = datetime.now(UTC)
            if (
                self._assumed_client is not None
                and self._assumed_client_expiration is not None
                and (self._assumed_client_expiration - now).total_seconds() > 300
            ):
                return self._assumed_client
            if self._sts_client is None:
                raise RuntimeAdapterError(
                    "assume_role_arn is set but STS client is not initialized",
                    category="misconfiguration",
                )
            try:
                response = self._sts_client.assume_role(
                    RoleArn=self._assume_role_arn,
                    RoleSessionName=self._assume_role_session_name,
                )
            except ClientError as exc:
                raise RuntimeAdapterError(str(exc), category="misconfiguration") from exc
            credentials = response["Credentials"]
            self._assumed_client_expiration = credentials["Expiration"]
            self._assumed_client = boto3.client(
                "ec2",
                region_name=self._region,
                aws_access_key_id=credentials["AccessKeyId"],
                aws_secret_access_key=credentials["SecretAccessKey"],
                aws_session_token=credentials["SessionToken"],
            )
            return self._assumed_client

    def _call_with_backoff(self, fn: Any, *args: Any, max_retries: int = 3, **kwargs: Any) -> Any:
        """Call *fn* with exponential backoff and full jitter on retryable errors.

        Codes in ``_RETRYABLE_CODES`` are retried up to *max_retries* times, sleeping a
        random duration in ``[0, min(delay, 10.0)]`` seconds, where ``delay`` starts at
        0.5 s and doubles after each retry. Any other ``ClientError`` (or a retryable
        one once retries are exhausted) is re-raised as ``RuntimeAdapterError`` with
        the category from ``_ERROR_CATEGORIES`` (``"unknown"`` if unmapped).
        """
        delay = 0.5
        for attempt in range(max_retries + 1):
            try:
                return fn(*args, **kwargs)
            except ClientError as e:
                code = e.response["Error"]["Code"]
                if code in _RETRYABLE_CODES and attempt < max_retries:
                    log.warning(
                        "Retryable EC2 error (attempt %d/%d): %s",
                        attempt + 1,
                        max_retries,
                        code,
                    )
                    time.sleep(random.uniform(0, min(delay, 10.0)))
                    delay *= 2
                    continue
                category = _ERROR_CATEGORIES.get(code, "unknown")
                raise RuntimeAdapterError(str(e), category=category) from e
        # Unreachable — loop always returns or raises on every path
        msg = "Retries exhausted"
        raise RuntimeAdapterError(msg, category="unknown")

    def _discover_tailscale_ip(self, slot_id: str, instance_id: str) -> str | None:
        """Resolve the Tailscale IP for *slot_id*/*instance_id* via the local tailscaled.

        Among online peers, prefers a unique peer whose name contains
        ``nix-builder-<slot_id>-<instance_id>``; failing that, a unique peer matching
        ``nix-builder-<slot_id>``. Returns ``None`` when status is unavailable, nothing
        matches, or a match is ambiguous (logged as a warning).
        """
        status = self._read_tailscale_status()
        if status is None:
            return None
        peers_obj = status.get("Peer")
        if not isinstance(peers_obj, dict):
            return None
        online_candidates: list[tuple[str, str]] = []
        for peer in peers_obj.values():
            if not isinstance(peer, dict) or not self._peer_is_online(peer):
                continue
            hostname = self._peer_hostname(peer)
            if hostname is None:
                continue
            ip = self._peer_tailscale_ip(peer)
            if ip is None:
                continue
            online_candidates.append((hostname, ip))

        identity = f"nix-builder-{slot_id}-{instance_id}".lower()
        identity_matches = [
            ip for host, ip in online_candidates if self._host_has_identity(host, identity)
        ]
        if len(identity_matches) == 1:
            return identity_matches[0]
        if len(identity_matches) > 1:
            log.warning(
                "tailscale_identity_ambiguous",
                extra={"slot_id": slot_id, "instance_id": instance_id},
            )
            return None

        slot_identity = f"nix-builder-{slot_id}".lower()
        slot_matches = [
            ip for host, ip in online_candidates if self._host_has_identity(host, slot_identity)
        ]
        if len(slot_matches) == 1:
            return slot_matches[0]
        if len(slot_matches) > 1:
            log.warning("tailscale_slot_ambiguous", extra={"slot_id": slot_id})
            return None
        return None

    @staticmethod
    def _host_has_identity(host: str, identity: str) -> bool:
        """Return True if *identity* occurs in *host* followed by a name boundary.

        Plain substring matching would let ``nix-builder-slot1`` match a peer named
        ``nix-builder-slot10-...``. Requiring the match to be followed by the end of
        the name, ``-`` or ``.`` prevents that while still allowing instance-ID
        suffixes and MagicDNS domains after the identity.
        """
        start = host.find(identity)
        while start != -1:
            end = start + len(identity)
            if end == len(host) or host[end] in "-.":
                return True
            start = host.find(identity, start + 1)
        return False

    def _read_tailscale_status(self) -> dict[str, Any] | None:
        """Fetch ``/localapi/v0/status`` from tailscaled over its Unix socket.

        Best effort: returns the decoded JSON object, or ``None`` on a non-200
        response, a non-object payload, or any socket, HTTP, decoding or JSON error.
        """
        conn = _UnixSocketHTTPConnection(self._tailscale_socket_path, timeout=1.0)
        try:
            conn.request(
                "GET",
                "/localapi/v0/status",
                headers={"Host": "local-tailscaled.sock", "Accept": "application/json"},
            )
            response = conn.getresponse()
            if response.status != 200:
                return None
            parsed = json.loads(response.read().decode())
        except (OSError, ValueError, http.client.HTTPException):
            # OSError covers PermissionError/TimeoutError; ValueError covers
            # json.JSONDecodeError and UnicodeDecodeError.
            return None
        finally:
            conn.close()
        return parsed if isinstance(parsed, dict) else None

    @staticmethod
    def _peer_is_online(peer: dict[str, Any]) -> bool:
        """Return the truthiness of the peer's ``Online`` field."""
        return bool(peer.get("Online"))

    @staticmethod
    def _peer_hostname(peer: dict[str, Any]) -> str | None:
        """Return the peer's ``HostName`` (else ``DNSName``), dot-trimmed and lowercased."""
        host = peer.get("HostName") or peer.get("DNSName")
        if not isinstance(host, str) or not host:
            return None
        return host.strip(".").lower()

    @staticmethod
    def _peer_tailscale_ip(peer: dict[str, Any]) -> str | None:
        """Return the peer's first IPv4 Tailscale IP, else its first non-empty IP."""
        ips = peer.get("TailscaleIPs")
        if not isinstance(ips, list):
            return None
        ipv4 = [ip for ip in ips if isinstance(ip, str) and "." in ip]
        if ipv4:
            return ipv4[0]
        for ip in ips:
            if isinstance(ip, str) and ip:
                return ip
        return None

    @staticmethod
    def _get_tag(tags: list[dict[str, str]], key: str) -> str | None:
        """Extract a tag value from an EC2 tag list."""
        for tag in tags:
            if tag.get("Key") == key:
                return tag.get("Value")
        return None