add runtime adapters, scheduler, reconciler, and their unit tests
This commit is contained in:
parent
d1976a5fd8
commit
b63d69c81d
10 changed files with 1471 additions and 28 deletions
|
|
@ -1,32 +1,175 @@
|
|||
"""EC2 runtime adapter — stub for Plan 02."""
|
||||
"""EC2 runtime adapter for managing Spot instances."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from ..config import AwsConfig
|
||||
from .base import RuntimeAdapter
|
||||
from .base import RuntimeError as RuntimeAdapterError
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# EC2 ClientError code → normalized error category
|
||||
_ERROR_CATEGORIES: dict[str, str] = {
|
||||
"InsufficientInstanceCapacity": "capacity_unavailable",
|
||||
"SpotMaxPriceTooLow": "price_too_low",
|
||||
"RequestLimitExceeded": "throttled",
|
||||
}
|
||||
|
||||
_RETRYABLE_CODES: frozenset[str] = frozenset({"RequestLimitExceeded"})
|
||||
|
||||
|
||||
class EC2Runtime(RuntimeAdapter):
|
||||
"""EC2 Spot instance runtime adapter.
|
||||
|
||||
Full implementation in Plan 02.
|
||||
Args:
|
||||
config: AWS configuration dataclass.
|
||||
environment: Environment tag value (e.g. ``"dev"``, ``"prod"``).
|
||||
_client: Optional pre-configured boto3 EC2 client (for testing).
|
||||
"""
|
||||
|
||||
def __init__(self, region: str, launch_template_id: str, **kwargs: object) -> None:
|
||||
self._region = region
|
||||
self._launch_template_id = launch_template_id
|
||||
def __init__(
|
||||
self,
|
||||
config: AwsConfig,
|
||||
environment: str = "dev",
|
||||
*,
|
||||
_client: Any = None,
|
||||
) -> None:
|
||||
self._client: Any = _client or boto3.client("ec2", region_name=config.region)
|
||||
self._launch_template_id = config.launch_template_id
|
||||
self._subnet_ids = list(config.subnet_ids)
|
||||
self._security_group_ids = list(config.security_group_ids)
|
||||
self._instance_profile_arn = config.instance_profile_arn
|
||||
self._environment = environment
|
||||
self._subnet_index = 0
|
||||
|
||||
def launch_spot(self, slot_id: str, user_data: str) -> str:
|
||||
"""Launch a spot instance for slot_id."""
|
||||
raise NotImplementedError
|
||||
"""Launch a spot instance for *slot_id*. Return instance ID."""
|
||||
params: dict[str, Any] = {
|
||||
"MinCount": 1,
|
||||
"MaxCount": 1,
|
||||
"LaunchTemplate": {
|
||||
"LaunchTemplateId": self._launch_template_id,
|
||||
"Version": "$Latest",
|
||||
},
|
||||
"InstanceMarketOptions": {
|
||||
"MarketType": "spot",
|
||||
"SpotOptions": {
|
||||
"SpotInstanceType": "one-time",
|
||||
"InstanceInterruptionBehavior": "terminate",
|
||||
},
|
||||
},
|
||||
"UserData": user_data,
|
||||
"TagSpecifications": [
|
||||
{
|
||||
"ResourceType": "instance",
|
||||
"Tags": [
|
||||
{"Key": "Name", "Value": f"nix-builder-{slot_id}"},
|
||||
{"Key": "AutoscalerSlot", "Value": slot_id},
|
||||
{"Key": "ManagedBy", "Value": "nix-builder-autoscaler"},
|
||||
{"Key": "Service", "Value": "nix-builder"},
|
||||
{"Key": "Environment", "Value": self._environment},
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
if self._subnet_ids:
|
||||
subnet = self._subnet_ids[self._subnet_index % len(self._subnet_ids)]
|
||||
self._subnet_index += 1
|
||||
params["SubnetId"] = subnet
|
||||
|
||||
resp = self._call_with_backoff(self._client.run_instances, **params)
|
||||
return resp["Instances"][0]["InstanceId"]
|
||||
|
||||
def describe_instance(self, instance_id: str) -> dict:
|
||||
"""Return normalized instance info."""
|
||||
raise NotImplementedError
|
||||
"""Return normalized instance info dict."""
|
||||
try:
|
||||
resp = self._call_with_backoff(
|
||||
self._client.describe_instances, InstanceIds=[instance_id]
|
||||
)
|
||||
except RuntimeAdapterError:
|
||||
return {"state": "terminated", "tailscale_ip": None, "launch_time": None}
|
||||
|
||||
reservations = resp.get("Reservations", [])
|
||||
if not reservations or not reservations[0].get("Instances"):
|
||||
return {"state": "terminated", "tailscale_ip": None, "launch_time": None}
|
||||
|
||||
inst = reservations[0]["Instances"][0]
|
||||
launch_time = inst.get("LaunchTime")
|
||||
return {
|
||||
"state": inst["State"]["Name"],
|
||||
"tailscale_ip": None,
|
||||
"launch_time": launch_time.isoformat() if launch_time else None,
|
||||
}
|
||||
|
||||
def terminate_instance(self, instance_id: str) -> None:
|
||||
"""Terminate the instance."""
|
||||
raise NotImplementedError
|
||||
self._call_with_backoff(self._client.terminate_instances, InstanceIds=[instance_id])
|
||||
|
||||
def list_managed_instances(self) -> list[dict]:
|
||||
"""Return list of managed instances."""
|
||||
raise NotImplementedError
|
||||
resp = self._call_with_backoff(
|
||||
self._client.describe_instances,
|
||||
Filters=[
|
||||
{"Name": "tag:ManagedBy", "Values": ["nix-builder-autoscaler"]},
|
||||
{
|
||||
"Name": "instance-state-name",
|
||||
"Values": ["pending", "running", "shutting-down", "stopping"],
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
result: list[dict] = []
|
||||
for reservation in resp.get("Reservations", []):
|
||||
for inst in reservation.get("Instances", []):
|
||||
tags = inst.get("Tags", [])
|
||||
result.append(
|
||||
{
|
||||
"instance_id": inst["InstanceId"],
|
||||
"state": inst["State"]["Name"],
|
||||
"slot_id": self._get_tag(tags, "AutoscalerSlot"),
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
def _call_with_backoff(self, fn: Any, *args: Any, max_retries: int = 3, **kwargs: Any) -> Any:
|
||||
"""Call *fn* with exponential backoff and full jitter on retryable errors."""
|
||||
delay = 0.5
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
except ClientError as e:
|
||||
code = e.response["Error"]["Code"]
|
||||
if code in _RETRYABLE_CODES and attempt < max_retries:
|
||||
jitter = random.uniform(0, min(delay, 10.0))
|
||||
time.sleep(jitter)
|
||||
delay *= 2
|
||||
log.warning(
|
||||
"Retryable EC2 error (attempt %d/%d): %s",
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
code,
|
||||
)
|
||||
continue
|
||||
category = _ERROR_CATEGORIES.get(code, "unknown")
|
||||
raise RuntimeAdapterError(str(e), category=category) from e
|
||||
|
||||
# Unreachable — loop always returns or raises on every path
|
||||
msg = "Retries exhausted"
|
||||
raise RuntimeAdapterError(msg, category="unknown")
|
||||
|
||||
@staticmethod
|
||||
def _get_tag(tags: list[dict[str, str]], key: str) -> str | None:
|
||||
"""Extract a tag value from an EC2 tag list."""
|
||||
for tag in tags:
|
||||
if tag.get("Key") == key:
|
||||
return tag.get("Value")
|
||||
return None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue