Add optional autoscaler cross-account assume-role support

This commit is contained in:
Abel Luck 2026-03-05 12:38:10 +01:00
parent 5092005e05
commit 4c7333ca07
3 changed files with 77 additions and 10 deletions

View file

@ -28,6 +28,7 @@ class AwsConfig:
subnet_ids: list[str] = field(default_factory=list) subnet_ids: list[str] = field(default_factory=list)
security_group_ids: list[str] = field(default_factory=list) security_group_ids: list[str] = field(default_factory=list)
instance_profile_arn: str = "" instance_profile_arn: str = ""
assume_role_arn: str = ""
@dataclass @dataclass

View file

@ -7,7 +7,9 @@ import json
import logging import logging
import random import random
import socket import socket
import threading
import time import time
from datetime import UTC, datetime
from typing import Any from typing import Any
import boto3 import boto3
@ -58,7 +60,16 @@ class EC2Runtime(RuntimeAdapter):
_client: Any = None, _client: Any = None,
_tailscale_socket_path: str = "/run/tailscale/tailscaled.sock", _tailscale_socket_path: str = "/run/tailscale/tailscaled.sock",
) -> None: ) -> None:
self._client: Any = _client or boto3.client("ec2", region_name=config.region) self._base_client: Any = _client or boto3.client("ec2", region_name=config.region)
self._region = config.region
self._assume_role_arn = config.assume_role_arn.strip()
self._assume_role_session_name = "nix-builder-autoscaler"
self._assumed_client: Any | None = None
self._assumed_client_expiration: datetime | None = None
self._assume_role_lock = threading.Lock()
self._sts_client: Any | None = None
if self._assume_role_arn != "":
self._sts_client = boto3.client("sts", region_name=config.region)
self._instance_type = config.instance_type self._instance_type = config.instance_type
self._launch_template_id = config.launch_template_id self._launch_template_id = config.launch_template_id
self._on_demand_launch_template_id = config.on_demand_launch_template_id self._on_demand_launch_template_id = config.on_demand_launch_template_id
@ -78,9 +89,12 @@ class EC2Runtime(RuntimeAdapter):
warnings so a transient permissions issue does not prevent startup. warnings so a transient permissions issue does not prevent startup.
""" """
try: try:
ec2_client = self._get_ec2_client()
target_azs: set[str] = set() target_azs: set[str] = set()
if self._subnet_ids: if self._subnet_ids:
resp = self._client.describe_subnets(SubnetIds=self._subnet_ids) resp = self._call_with_backoff(
ec2_client.describe_subnets, SubnetIds=self._subnet_ids
)
for subnet in resp.get("Subnets", []): for subnet in resp.get("Subnets", []):
az = subnet.get("AvailabilityZone") az = subnet.get("AvailabilityZone")
if az: if az:
@ -92,14 +106,15 @@ class EC2Runtime(RuntimeAdapter):
if target_azs: if target_azs:
filters.append({"Name": "location", "Values": list(target_azs)}) filters.append({"Name": "location", "Values": list(target_azs)})
resp = self._client.describe_instance_type_offerings( resp = self._call_with_backoff(
ec2_client.describe_instance_type_offerings,
LocationType="availability-zone", LocationType="availability-zone",
Filters=filters, Filters=filters,
) )
available_azs = {o["Location"] for o in resp.get("InstanceTypeOfferings", [])} available_azs = {o["Location"] for o in resp.get("InstanceTypeOfferings", [])}
if not available_azs: if not available_azs:
region = self._client.meta.region_name region = ec2_client.meta.region_name
log.error( log.error(
"preflight_misconfiguration", "preflight_misconfiguration",
extra={ extra={
@ -194,15 +209,15 @@ class EC2Runtime(RuntimeAdapter):
self._subnet_index += 1 self._subnet_index += 1
params["SubnetId"] = subnet params["SubnetId"] = subnet
resp = self._call_with_backoff(self._client.run_instances, **params) ec2_client = self._get_ec2_client()
resp = self._call_with_backoff(ec2_client.run_instances, **params)
return resp["Instances"][0]["InstanceId"] return resp["Instances"][0]["InstanceId"]
def describe_instance(self, instance_id: str) -> dict: def describe_instance(self, instance_id: str) -> dict:
"""Return normalized instance info dict.""" """Return normalized instance info dict."""
try: try:
resp = self._call_with_backoff( ec2_client = self._get_ec2_client()
self._client.describe_instances, InstanceIds=[instance_id] resp = self._call_with_backoff(ec2_client.describe_instances, InstanceIds=[instance_id])
)
except RuntimeAdapterError: except RuntimeAdapterError:
return {"state": "terminated", "tailscale_ip": None, "launch_time": None} return {"state": "terminated", "tailscale_ip": None, "launch_time": None}
@ -227,12 +242,14 @@ class EC2Runtime(RuntimeAdapter):
def terminate_instance(self, instance_id: str) -> None: def terminate_instance(self, instance_id: str) -> None:
"""Terminate the instance.""" """Terminate the instance."""
self._call_with_backoff(self._client.terminate_instances, InstanceIds=[instance_id]) ec2_client = self._get_ec2_client()
self._call_with_backoff(ec2_client.terminate_instances, InstanceIds=[instance_id])
def list_managed_instances(self) -> list[dict]: def list_managed_instances(self) -> list[dict]:
"""Return list of managed instances.""" """Return list of managed instances."""
ec2_client = self._get_ec2_client()
resp = self._call_with_backoff( resp = self._call_with_backoff(
self._client.describe_instances, ec2_client.describe_instances,
Filters=[ Filters=[
{"Name": "tag:ManagedBy", "Values": ["nix-builder-autoscaler"]}, {"Name": "tag:ManagedBy", "Values": ["nix-builder-autoscaler"]},
{ {
@ -255,6 +272,45 @@ class EC2Runtime(RuntimeAdapter):
) )
return result return result
def _get_ec2_client(self) -> Any:
    """Pick the EC2 client that API calls should go through.

    When an assume-role ARN is configured the call is delegated to
    ``_get_assumed_ec2_client``; otherwise the base client stored on the
    instance is handed back as-is.
    """
    if self._assume_role_arn != "":
        return self._get_assumed_ec2_client()
    return self._base_client
def _get_assumed_ec2_client(self) -> Any:
now = datetime.now(UTC)
with self._assume_role_lock:
if (
self._assumed_client is not None
and self._assumed_client_expiration is not None
and (self._assumed_client_expiration - now).total_seconds() > 300
):
return self._assumed_client
if self._sts_client is None:
raise RuntimeAdapterError(
"assume_role_arn is set but STS client is not initialized",
category="misconfiguration",
)
try:
response = self._sts_client.assume_role(
RoleArn=self._assume_role_arn,
RoleSessionName=self._assume_role_session_name,
)
except ClientError as exc:
raise RuntimeAdapterError(str(exc), category="misconfiguration") from exc
credentials = response["Credentials"]
self._assumed_client_expiration = credentials["Expiration"]
self._assumed_client = boto3.client(
"ec2",
region_name=self._region,
aws_access_key_id=credentials["AccessKeyId"],
aws_secret_access_key=credentials["SecretAccessKey"],
aws_session_token=credentials["SessionToken"],
)
return self._assumed_client
def _call_with_backoff(self, fn: Any, *args: Any, max_retries: int = 3, **kwargs: Any) -> Any: def _call_with_backoff(self, fn: Any, *args: Any, max_retries: int = 3, **kwargs: Any) -> Any:
"""Call *fn* with exponential backoff and full jitter on retryable errors.""" """Call *fn* with exponential backoff and full jitter on retryable errors."""
delay = 0.5 delay = 0.5

View file

@ -97,6 +97,12 @@ in
default = ""; default = "";
description = "Optional instance profile ARN override."; description = "Optional instance profile ARN override.";
}; };
# Path to a file holding the role ARN. When non-null, the config-generation
# script reads the file with newlines stripped and writes the value as
# `assume_role_arn` in the generated config; when null the key is omitted.
assumeRoleArnFile = lib.mkOption {
type = lib.types.nullOr lib.types.str;
default = null;
description = "Optional file containing an IAM role ARN for cross-account autoscaler control-plane calls.";
};
}; };
haproxy = { haproxy = {
@ -329,6 +335,9 @@ in
${lib.optionalString (cfg.aws.onDemandLaunchTemplateIdFile != null) '' ${lib.optionalString (cfg.aws.onDemandLaunchTemplateIdFile != null) ''
on_demand_launch_template_id="$(tr -d '\n' < ${lib.escapeShellArg cfg.aws.onDemandLaunchTemplateIdFile})" on_demand_launch_template_id="$(tr -d '\n' < ${lib.escapeShellArg cfg.aws.onDemandLaunchTemplateIdFile})"
''} ''}
${lib.optionalString (cfg.aws.assumeRoleArnFile != null) ''
assume_role_arn="$(tr -d '\n' < ${lib.escapeShellArg cfg.aws.assumeRoleArnFile})"
''}
cat > ${generatedConfigPath} <<EOF cat > ${generatedConfigPath} <<EOF
[server] [server]
@ -346,6 +355,7 @@ in
subnet_ids = $subnet_ids_json subnet_ids = $subnet_ids_json
security_group_ids = ${tomlStringList cfg.aws.securityGroupIds} security_group_ids = ${tomlStringList cfg.aws.securityGroupIds}
instance_profile_arn = "${cfg.aws.instanceProfileArn}" instance_profile_arn = "${cfg.aws.instanceProfileArn}"
${lib.optionalString (cfg.aws.assumeRoleArnFile != null) ''assume_role_arn = "$assume_role_arn"''}
[haproxy] [haproxy]
runtime_socket = "${cfg.haproxy.runtimeSocket}" runtime_socket = "${cfg.haproxy.runtimeSocket}"