Add optional autoscaler cross-account assume-role support

This commit is contained in:
Abel Luck 2026-03-05 12:38:10 +01:00
parent 5092005e05
commit 4c7333ca07
3 changed files with 77 additions and 10 deletions

View file

@ -28,6 +28,7 @@ class AwsConfig:
subnet_ids: list[str] = field(default_factory=list)
security_group_ids: list[str] = field(default_factory=list)
instance_profile_arn: str = ""
assume_role_arn: str = ""
@dataclass

View file

@ -7,7 +7,9 @@ import json
import logging
import random
import socket
import threading
import time
from datetime import UTC, datetime
from typing import Any
import boto3
@ -58,7 +60,16 @@ class EC2Runtime(RuntimeAdapter):
_client: Any = None,
_tailscale_socket_path: str = "/run/tailscale/tailscaled.sock",
) -> None:
self._client: Any = _client or boto3.client("ec2", region_name=config.region)
self._base_client: Any = _client or boto3.client("ec2", region_name=config.region)
self._region = config.region
self._assume_role_arn = config.assume_role_arn.strip()
self._assume_role_session_name = "nix-builder-autoscaler"
self._assumed_client: Any | None = None
self._assumed_client_expiration: datetime | None = None
self._assume_role_lock = threading.Lock()
self._sts_client: Any | None = None
if self._assume_role_arn != "":
self._sts_client = boto3.client("sts", region_name=config.region)
self._instance_type = config.instance_type
self._launch_template_id = config.launch_template_id
self._on_demand_launch_template_id = config.on_demand_launch_template_id
@ -78,9 +89,12 @@ class EC2Runtime(RuntimeAdapter):
warnings so a transient permissions issue does not prevent startup.
"""
try:
ec2_client = self._get_ec2_client()
target_azs: set[str] = set()
if self._subnet_ids:
resp = self._client.describe_subnets(SubnetIds=self._subnet_ids)
resp = self._call_with_backoff(
ec2_client.describe_subnets, SubnetIds=self._subnet_ids
)
for subnet in resp.get("Subnets", []):
az = subnet.get("AvailabilityZone")
if az:
@ -92,14 +106,15 @@ class EC2Runtime(RuntimeAdapter):
if target_azs:
filters.append({"Name": "location", "Values": list(target_azs)})
resp = self._client.describe_instance_type_offerings(
resp = self._call_with_backoff(
ec2_client.describe_instance_type_offerings,
LocationType="availability-zone",
Filters=filters,
)
available_azs = {o["Location"] for o in resp.get("InstanceTypeOfferings", [])}
if not available_azs:
region = self._client.meta.region_name
region = ec2_client.meta.region_name
log.error(
"preflight_misconfiguration",
extra={
@ -194,15 +209,15 @@ class EC2Runtime(RuntimeAdapter):
self._subnet_index += 1
params["SubnetId"] = subnet
resp = self._call_with_backoff(self._client.run_instances, **params)
ec2_client = self._get_ec2_client()
resp = self._call_with_backoff(ec2_client.run_instances, **params)
return resp["Instances"][0]["InstanceId"]
def describe_instance(self, instance_id: str) -> dict:
"""Return normalized instance info dict."""
try:
resp = self._call_with_backoff(
self._client.describe_instances, InstanceIds=[instance_id]
)
ec2_client = self._get_ec2_client()
resp = self._call_with_backoff(ec2_client.describe_instances, InstanceIds=[instance_id])
except RuntimeAdapterError:
return {"state": "terminated", "tailscale_ip": None, "launch_time": None}
@ -227,12 +242,14 @@ class EC2Runtime(RuntimeAdapter):
def terminate_instance(self, instance_id: str) -> None:
"""Terminate the instance."""
self._call_with_backoff(self._client.terminate_instances, InstanceIds=[instance_id])
ec2_client = self._get_ec2_client()
self._call_with_backoff(ec2_client.terminate_instances, InstanceIds=[instance_id])
def list_managed_instances(self) -> list[dict]:
"""Return list of managed instances."""
ec2_client = self._get_ec2_client()
resp = self._call_with_backoff(
self._client.describe_instances,
ec2_client.describe_instances,
Filters=[
{"Name": "tag:ManagedBy", "Values": ["nix-builder-autoscaler"]},
{
@ -255,6 +272,45 @@ class EC2Runtime(RuntimeAdapter):
)
return result
def _get_ec2_client(self) -> Any:
if self._assume_role_arn == "":
return self._base_client
return self._get_assumed_ec2_client()
def _get_assumed_ec2_client(self) -> Any:
now = datetime.now(UTC)
with self._assume_role_lock:
if (
self._assumed_client is not None
and self._assumed_client_expiration is not None
and (self._assumed_client_expiration - now).total_seconds() > 300
):
return self._assumed_client
if self._sts_client is None:
raise RuntimeAdapterError(
"assume_role_arn is set but STS client is not initialized",
category="misconfiguration",
)
try:
response = self._sts_client.assume_role(
RoleArn=self._assume_role_arn,
RoleSessionName=self._assume_role_session_name,
)
except ClientError as exc:
raise RuntimeAdapterError(str(exc), category="misconfiguration") from exc
credentials = response["Credentials"]
self._assumed_client_expiration = credentials["Expiration"]
self._assumed_client = boto3.client(
"ec2",
region_name=self._region,
aws_access_key_id=credentials["AccessKeyId"],
aws_secret_access_key=credentials["SecretAccessKey"],
aws_session_token=credentials["SessionToken"],
)
return self._assumed_client
def _call_with_backoff(self, fn: Any, *args: Any, max_retries: int = 3, **kwargs: Any) -> Any:
"""Call *fn* with exponential backoff and full jitter on retryable errors."""
delay = 0.5

View file

@ -97,6 +97,12 @@ in
default = "";
description = "Optional instance profile ARN override.";
};
assumeRoleArnFile = lib.mkOption {
type = lib.types.nullOr lib.types.str;
default = null;
description = "Optional file containing an IAM role ARN for cross-account autoscaler control-plane calls.";
};
};
haproxy = {
@ -329,6 +335,9 @@ in
${lib.optionalString (cfg.aws.onDemandLaunchTemplateIdFile != null) ''
on_demand_launch_template_id="$(tr -d '\n' < ${lib.escapeShellArg cfg.aws.onDemandLaunchTemplateIdFile})"
''}
${lib.optionalString (cfg.aws.assumeRoleArnFile != null) ''
assume_role_arn="$(tr -d '\n' < ${lib.escapeShellArg cfg.aws.assumeRoleArnFile})"
''}
cat > ${generatedConfigPath} <<EOF
[server]
@ -346,6 +355,7 @@ in
subnet_ids = $subnet_ids_json
security_group_ids = ${tomlStringList cfg.aws.securityGroupIds}
instance_profile_arn = "${cfg.aws.instanceProfileArn}"
${lib.optionalString (cfg.aws.assumeRoleArnFile != null) ''assume_role_arn = "$assume_role_arn"''}
[haproxy]
runtime_socket = "${cfg.haproxy.runtimeSocket}"