agent: complete plan05 closeout

This commit is contained in:
Abel Luck 2026-02-27 13:48:52 +01:00
parent 33ba248c49
commit 2f0fffa905
12 changed files with 1347 additions and 313 deletions

View file

@ -2,8 +2,11 @@
from __future__ import annotations
import http.client
import json
import logging
import random
import socket
import time
from typing import Any
@ -26,6 +29,18 @@ _ERROR_CATEGORIES: dict[str, str] = {
_RETRYABLE_CODES: frozenset[str] = frozenset({"RequestLimitExceeded"})
class _UnixSocketHTTPConnection(http.client.HTTPConnection):
"""HTTP connection over a Unix domain socket."""
def __init__(self, socket_path: str, timeout: float = 1.0) -> None:
super().__init__("local-tailscaled.sock", timeout=timeout)
self._socket_path = socket_path
def connect(self) -> None:
self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.sock.connect(self._socket_path)
class EC2Runtime(RuntimeAdapter):
"""EC2 Spot instance runtime adapter.
@ -41,6 +56,7 @@ class EC2Runtime(RuntimeAdapter):
environment: str = "dev",
*,
_client: Any = None,
_tailscale_socket_path: str = "/run/tailscale/tailscaled.sock",
) -> None:
self._client: Any = _client or boto3.client("ec2", region_name=config.region)
self._launch_template_id = config.launch_template_id
@ -49,6 +65,7 @@ class EC2Runtime(RuntimeAdapter):
self._instance_profile_arn = config.instance_profile_arn
self._environment = environment
self._subnet_index = 0
self._tailscale_socket_path = _tailscale_socket_path
def launch_spot(self, slot_id: str, user_data: str) -> str:
"""Launch a spot instance for *slot_id*. Return instance ID."""
@ -103,10 +120,17 @@ class EC2Runtime(RuntimeAdapter):
return {"state": "terminated", "tailscale_ip": None, "launch_time": None}
inst = reservations[0]["Instances"][0]
tags = inst.get("Tags", [])
slot_id = self._get_tag(tags, "AutoscalerSlot")
state = inst["State"]["Name"]
tailscale_ip: str | None = None
if state == "running" and slot_id:
tailscale_ip = self._discover_tailscale_ip(slot_id, instance_id)
launch_time = inst.get("LaunchTime")
return {
"state": inst["State"]["Name"],
"tailscale_ip": None,
"state": state,
"tailscale_ip": tailscale_ip,
"launch_time": launch_time.isoformat() if launch_time else None,
}
@ -166,6 +190,98 @@ class EC2Runtime(RuntimeAdapter):
msg = "Retries exhausted"
raise RuntimeAdapterError(msg, category="unknown")
def _discover_tailscale_ip(self, slot_id: str, instance_id: str) -> str | None:
"""Resolve Tailscale IP for instance identity via local tailscaled LocalAPI."""
status = self._read_tailscale_status()
if status is None:
return None
peers_obj = status.get("Peer")
if not isinstance(peers_obj, dict):
return None
online_candidates: list[tuple[str, str]] = []
for peer in peers_obj.values():
if not isinstance(peer, dict):
continue
if not self._peer_is_online(peer):
continue
hostname = self._peer_hostname(peer)
if hostname is None:
continue
ip = self._peer_tailscale_ip(peer)
if ip is None:
continue
online_candidates.append((hostname, ip))
identity = f"nix-builder-{slot_id}-{instance_id}".lower()
identity_matches = [ip for host, ip in online_candidates if identity in host]
if len(identity_matches) == 1:
return identity_matches[0]
if len(identity_matches) > 1:
log.warning(
"tailscale_identity_ambiguous",
extra={"slot_id": slot_id, "instance_id": instance_id},
)
return None
slot_identity = f"nix-builder-{slot_id}".lower()
slot_matches = [ip for host, ip in online_candidates if slot_identity in host]
if len(slot_matches) == 1:
return slot_matches[0]
if len(slot_matches) > 1:
log.warning("tailscale_slot_ambiguous", extra={"slot_id": slot_id})
return None
return None
def _read_tailscale_status(self) -> dict[str, Any] | None:
"""Query local tailscaled LocalAPI status endpoint over Unix socket."""
conn = _UnixSocketHTTPConnection(self._tailscale_socket_path, timeout=1.0)
try:
conn.request(
"GET",
"/localapi/v0/status",
headers={"Host": "local-tailscaled.sock", "Accept": "application/json"},
)
response = conn.getresponse()
if response.status != 200:
return None
payload = response.read()
parsed = json.loads(payload.decode())
if isinstance(parsed, dict):
return parsed
return None
except (OSError, PermissionError, TimeoutError, json.JSONDecodeError, UnicodeDecodeError):
return None
except http.client.HTTPException:
return None
finally:
conn.close()
@staticmethod
def _peer_is_online(peer: dict[str, Any]) -> bool:
return bool(peer.get("Online") or peer.get("Active"))
@staticmethod
def _peer_hostname(peer: dict[str, Any]) -> str | None:
host = peer.get("HostName") or peer.get("DNSName")
if not isinstance(host, str) or not host:
return None
return host.strip(".").lower()
@staticmethod
def _peer_tailscale_ip(peer: dict[str, Any]) -> str | None:
ips = peer.get("TailscaleIPs")
if not isinstance(ips, list):
return None
ipv4 = [ip for ip in ips if isinstance(ip, str) and "." in ip]
if ipv4:
return ipv4[0]
for ip in ips:
if isinstance(ip, str) and ip:
return ip
return None
@staticmethod
def _get_tag(tags: list[dict[str, str]], key: str) -> str | None:
"""Extract a tag value from an EC2 tag list."""