diff --git a/agent/nix_builder_autoscaler/config.py b/agent/nix_builder_autoscaler/config.py
index 94203bd..267ef5c 100644
--- a/agent/nix_builder_autoscaler/config.py
+++ b/agent/nix_builder_autoscaler/config.py
@@ -28,6 +28,7 @@ class AwsConfig:
     subnet_ids: list[str] = field(default_factory=list)
     security_group_ids: list[str] = field(default_factory=list)
     instance_profile_arn: str = ""
+    assume_role_arn: str = ""
 
 
 @dataclass
diff --git a/agent/nix_builder_autoscaler/runtime/ec2.py b/agent/nix_builder_autoscaler/runtime/ec2.py
index aa5bab5..13871c4 100644
--- a/agent/nix_builder_autoscaler/runtime/ec2.py
+++ b/agent/nix_builder_autoscaler/runtime/ec2.py
@@ -7,7 +7,9 @@ import json
 import logging
 import random
 import socket
+import threading
 import time
+from datetime import UTC, datetime
 from typing import Any
 
 import boto3
@@ -58,7 +60,16 @@ class EC2Runtime(RuntimeAdapter):
         _client: Any = None,
         _tailscale_socket_path: str = "/run/tailscale/tailscaled.sock",
     ) -> None:
-        self._client: Any = _client or boto3.client("ec2", region_name=config.region)
+        self._base_client: Any = _client or boto3.client("ec2", region_name=config.region)
+        self._region = config.region
+        self._assume_role_arn = config.assume_role_arn.strip()
+        self._assume_role_session_name = "nix-builder-autoscaler"
+        self._assumed_client: Any | None = None
+        self._assumed_client_expiration: datetime | None = None
+        self._assume_role_lock = threading.Lock()
+        self._sts_client: Any | None = None
+        if self._assume_role_arn:
+            self._sts_client = boto3.client("sts", region_name=config.region)
         self._instance_type = config.instance_type
         self._launch_template_id = config.launch_template_id
         self._on_demand_launch_template_id = config.on_demand_launch_template_id
@@ -78,9 +89,12 @@ class EC2Runtime(RuntimeAdapter):
         warnings so a transient permissions issue does not prevent startup.
         """
         try:
+            ec2_client = self._get_ec2_client()
             target_azs: set[str] = set()
             if self._subnet_ids:
-                resp = self._client.describe_subnets(SubnetIds=self._subnet_ids)
+                resp = self._call_with_backoff(
+                    ec2_client.describe_subnets, SubnetIds=self._subnet_ids
+                )
                 for subnet in resp.get("Subnets", []):
                     az = subnet.get("AvailabilityZone")
                     if az:
@@ -92,14 +106,15 @@ class EC2Runtime(RuntimeAdapter):
             if target_azs:
                 filters.append({"Name": "location", "Values": list(target_azs)})
 
-            resp = self._client.describe_instance_type_offerings(
+            resp = self._call_with_backoff(
+                ec2_client.describe_instance_type_offerings,
                 LocationType="availability-zone",
                 Filters=filters,
             )
             available_azs = {o["Location"] for o in resp.get("InstanceTypeOfferings", [])}
 
             if not available_azs:
-                region = self._client.meta.region_name
+                region = ec2_client.meta.region_name
                 log.error(
                     "preflight_misconfiguration",
                     extra={
@@ -194,15 +209,15 @@ class EC2Runtime(RuntimeAdapter):
             self._subnet_index += 1
             params["SubnetId"] = subnet
 
-        resp = self._call_with_backoff(self._client.run_instances, **params)
+        ec2_client = self._get_ec2_client()
+        resp = self._call_with_backoff(ec2_client.run_instances, **params)
         return resp["Instances"][0]["InstanceId"]
 
     def describe_instance(self, instance_id: str) -> dict:
         """Return normalized instance info dict."""
         try:
-            resp = self._call_with_backoff(
-                self._client.describe_instances, InstanceIds=[instance_id]
-            )
+            ec2_client = self._get_ec2_client()
+            resp = self._call_with_backoff(ec2_client.describe_instances, InstanceIds=[instance_id])
         except RuntimeAdapterError:
             return {"state": "terminated", "tailscale_ip": None, "launch_time": None}
 
@@ -227,12 +242,14 @@ class EC2Runtime(RuntimeAdapter):
 
     def terminate_instance(self, instance_id: str) -> None:
         """Terminate the instance."""
-        self._call_with_backoff(self._client.terminate_instances, InstanceIds=[instance_id])
+        ec2_client = self._get_ec2_client()
+        self._call_with_backoff(ec2_client.terminate_instances, InstanceIds=[instance_id])
 
     def list_managed_instances(self) -> list[dict]:
         """Return list of managed instances."""
+        ec2_client = self._get_ec2_client()
         resp = self._call_with_backoff(
-            self._client.describe_instances,
+            ec2_client.describe_instances,
             Filters=[
                 {"Name": "tag:ManagedBy", "Values": ["nix-builder-autoscaler"]},
                 {
@@ -255,6 +272,62 @@ class EC2Runtime(RuntimeAdapter):
                 )
         return result
 
+    def _get_ec2_client(self) -> Any:
+        """Return the EC2 client to use for AWS calls.
+
+        Returns the base client when no ``assume_role_arn`` is configured,
+        otherwise a client built from assumed-role credentials.
+        """
+        if not self._assume_role_arn:
+            return self._base_client
+        return self._get_assumed_ec2_client()
+
+    def _get_assumed_ec2_client(self) -> Any:
+        """Return an EC2 client backed by assumed-role credentials.
+
+        The cached client is reused while its credentials have more than
+        300 seconds left; otherwise the role is assumed again through STS.
+        Raises RuntimeAdapterError if the STS client is missing or the
+        ``assume_role`` call fails.
+        """
+        with self._assume_role_lock:
+            # Read the clock under the lock so time spent waiting for another
+            # thread's refresh is not counted as remaining credential lifetime.
+            now = datetime.now(UTC)
+            if (
+                self._assumed_client is not None
+                and self._assumed_client_expiration is not None
+                and (self._assumed_client_expiration - now).total_seconds() > 300
+            ):
+                return self._assumed_client
+
+            if self._sts_client is None:
+                raise RuntimeAdapterError(
+                    "assume_role_arn is set but STS client is not initialized",
+                    category="misconfiguration",
+                )
+
+            try:
+                response = self._sts_client.assume_role(
+                    RoleArn=self._assume_role_arn,
+                    RoleSessionName=self._assume_role_session_name,
+                )
+            except ClientError as exc:
+                raise RuntimeAdapterError(str(exc), category="misconfiguration") from exc
+            credentials = response["Credentials"]
+            # Build the client before touching the cache so a failure here cannot
+            # leave the previous client paired with the new expiration time.
+            assumed_client = boto3.client(
+                "ec2",
+                region_name=self._region,
+                aws_access_key_id=credentials["AccessKeyId"],
+                aws_secret_access_key=credentials["SecretAccessKey"],
+                aws_session_token=credentials["SessionToken"],
+            )
+            self._assumed_client = assumed_client
+            self._assumed_client_expiration = credentials["Expiration"]
+            return assumed_client
+
     def _call_with_backoff(self, fn: Any, *args: Any, max_retries: int = 3, **kwargs: Any) -> Any:
         """Call *fn* with exponential backoff and full jitter on retryable errors."""
         delay = 0.5
diff --git a/nix/modules/nixos/services/nix-builder-autoscaler.nix b/nix/modules/nixos/services/nix-builder-autoscaler.nix
index e37ce99..205ebf2 100644
--- a/nix/modules/nixos/services/nix-builder-autoscaler.nix
+++ b/nix/modules/nixos/services/nix-builder-autoscaler.nix
@@ -97,6 +97,12 @@ in
         default = "";
         description = "Optional instance profile ARN override.";
       };
+
+      assumeRoleArnFile = lib.mkOption {
+        type = lib.types.nullOr lib.types.str;
+        default = null;
+        description = "Optional file containing an IAM role ARN for cross-account autoscaler control-plane calls.";
+      };
     };
 
     haproxy = {
@@ -329,6 +335,9 @@ in
         ${lib.optionalString (cfg.aws.onDemandLaunchTemplateIdFile != null) ''
           on_demand_launch_template_id="$(tr -d '\n' < ${lib.escapeShellArg cfg.aws.onDemandLaunchTemplateIdFile})"
         ''}
+        ${lib.optionalString (cfg.aws.assumeRoleArnFile != null) ''
+          assume_role_arn="$(tr -d '\n' < ${lib.escapeShellArg cfg.aws.assumeRoleArnFile})"
+        ''}
         cat > ${generatedConfigPath} <