diff --git a/agent/nix_builder_autoscaler/config.py b/agent/nix_builder_autoscaler/config.py index 267ef5c..94203bd 100644 --- a/agent/nix_builder_autoscaler/config.py +++ b/agent/nix_builder_autoscaler/config.py @@ -28,7 +28,6 @@ class AwsConfig: subnet_ids: list[str] = field(default_factory=list) security_group_ids: list[str] = field(default_factory=list) instance_profile_arn: str = "" - assume_role_arn: str = "" @dataclass diff --git a/agent/nix_builder_autoscaler/runtime/ec2.py b/agent/nix_builder_autoscaler/runtime/ec2.py index 13871c4..aa5bab5 100644 --- a/agent/nix_builder_autoscaler/runtime/ec2.py +++ b/agent/nix_builder_autoscaler/runtime/ec2.py @@ -7,9 +7,7 @@ import json import logging import random import socket -import threading import time -from datetime import UTC, datetime from typing import Any import boto3 @@ -60,16 +58,7 @@ class EC2Runtime(RuntimeAdapter): _client: Any = None, _tailscale_socket_path: str = "/run/tailscale/tailscaled.sock", ) -> None: - self._base_client: Any = _client or boto3.client("ec2", region_name=config.region) - self._region = config.region - self._assume_role_arn = config.assume_role_arn.strip() - self._assume_role_session_name = "nix-builder-autoscaler" - self._assumed_client: Any | None = None - self._assumed_client_expiration: datetime | None = None - self._assume_role_lock = threading.Lock() - self._sts_client: Any | None = None - if self._assume_role_arn != "": - self._sts_client = boto3.client("sts", region_name=config.region) + self._client: Any = _client or boto3.client("ec2", region_name=config.region) self._instance_type = config.instance_type self._launch_template_id = config.launch_template_id self._on_demand_launch_template_id = config.on_demand_launch_template_id @@ -89,12 +78,9 @@ class EC2Runtime(RuntimeAdapter): warnings so a transient permissions issue does not prevent startup. """ try: - ec2_client = self._get_ec2_client() target_azs: set[str] = set() if self._subnet_ids: - resp = self._call_with_backoff( - ec2_client.describe_subnets, SubnetIds=self._subnet_ids - ) + resp = self._client.describe_subnets(SubnetIds=self._subnet_ids) for subnet in resp.get("Subnets", []): az = subnet.get("AvailabilityZone") if az: @@ -106,15 +92,14 @@ class EC2Runtime(RuntimeAdapter): if target_azs: filters.append({"Name": "location", "Values": list(target_azs)}) - resp = self._call_with_backoff( - ec2_client.describe_instance_type_offerings, + resp = self._client.describe_instance_type_offerings( LocationType="availability-zone", Filters=filters, ) available_azs = {o["Location"] for o in resp.get("InstanceTypeOfferings", [])} if not available_azs: - region = ec2_client.meta.region_name + region = self._client.meta.region_name log.error( "preflight_misconfiguration", extra={ @@ -209,15 +194,15 @@ class EC2Runtime(RuntimeAdapter): self._subnet_index += 1 params["SubnetId"] = subnet - ec2_client = self._get_ec2_client() - resp = self._call_with_backoff(ec2_client.run_instances, **params) + resp = self._call_with_backoff(self._client.run_instances, **params) return resp["Instances"][0]["InstanceId"] def describe_instance(self, instance_id: str) -> dict: """Return normalized instance info dict.""" try: - ec2_client = self._get_ec2_client() - resp = self._call_with_backoff(ec2_client.describe_instances, InstanceIds=[instance_id]) + resp = self._call_with_backoff( + self._client.describe_instances, InstanceIds=[instance_id] + ) except RuntimeAdapterError: return {"state": "terminated", "tailscale_ip": None, "launch_time": None} @@ -242,14 +227,12 @@ class EC2Runtime(RuntimeAdapter): def terminate_instance(self, instance_id: str) -> None: """Terminate the instance.""" - ec2_client = self._get_ec2_client() - self._call_with_backoff(ec2_client.terminate_instances, InstanceIds=[instance_id]) + self._call_with_backoff(self._client.terminate_instances, InstanceIds=[instance_id]) def list_managed_instances(self) -> list[dict]: """Return list of managed instances.""" - ec2_client = self._get_ec2_client() resp = self._call_with_backoff( - ec2_client.describe_instances, + self._client.describe_instances, Filters=[ {"Name": "tag:ManagedBy", "Values": ["nix-builder-autoscaler"]}, { @@ -272,45 +255,6 @@ class EC2Runtime(RuntimeAdapter): ) return result - def _get_ec2_client(self) -> Any: - if self._assume_role_arn == "": - return self._base_client - return self._get_assumed_ec2_client() - - def _get_assumed_ec2_client(self) -> Any: - now = datetime.now(UTC) - with self._assume_role_lock: - if ( - self._assumed_client is not None - and self._assumed_client_expiration is not None - and (self._assumed_client_expiration - now).total_seconds() > 300 - ): - return self._assumed_client - - if self._sts_client is None: - raise RuntimeAdapterError( - "assume_role_arn is set but STS client is not initialized", - category="misconfiguration", - ) - - try: - response = self._sts_client.assume_role( - RoleArn=self._assume_role_arn, - RoleSessionName=self._assume_role_session_name, - ) - except ClientError as exc: - raise RuntimeAdapterError(str(exc), category="misconfiguration") from exc - credentials = response["Credentials"] - self._assumed_client_expiration = credentials["Expiration"] - self._assumed_client = boto3.client( - "ec2", - region_name=self._region, - aws_access_key_id=credentials["AccessKeyId"], - aws_secret_access_key=credentials["SecretAccessKey"], - aws_session_token=credentials["SessionToken"], - ) - return self._assumed_client - def _call_with_backoff(self, fn: Any, *args: Any, max_retries: int = 3, **kwargs: Any) -> Any: """Call *fn* with exponential backoff and full jitter on retryable errors.""" delay = 0.5 diff --git a/agent/nix_builder_autoscaler/tests/test_runtime_ec2.py b/agent/nix_builder_autoscaler/tests/test_runtime_ec2.py index 765288d..de8b26a 100644 --- a/agent/nix_builder_autoscaler/tests/test_runtime_ec2.py +++ b/agent/nix_builder_autoscaler/tests/test_runtime_ec2.py @@ -1,6 +1,6 @@ """Unit tests for the EC2 runtime adapter using botocore Stubber.""" -from datetime import UTC, datetime, timedelta +from datetime import UTC, datetime from unittest.mock import patch import boto3 @@ -462,62 +462,3 @@ class TestErrorClassification: with pytest.raises(RuntimeAdapterError) as exc_info: runtime.launch_instance("slot001", "#!/bin/bash") assert exc_info.value.category == "throttled" - - -class TestAssumeRole: - def test_uses_assumed_role_credentials_for_ec2_calls(self): - config = _make_config() - config.assume_role_arn = "arn:aws:iam::210987654321:role/buildbot-autoscaler-controller" - - base_ec2 = boto3.client("ec2", region_name="us-east-1") - assumed_ec2 = boto3.client("ec2", region_name="us-east-1") - sts_client = boto3.client("sts", region_name="us-east-1") - - sts_stubber = Stubber(sts_client) - sts_stubber.add_response( - "assume_role", - { - "Credentials": { - "AccessKeyId": "ASIAAAAAAAAAAAAAAAAA", - "SecretAccessKey": "s" * 40, - "SessionToken": "t" * 256, - "Expiration": datetime.now(UTC) + timedelta(hours=1), - }, - "AssumedRoleUser": { - "AssumedRoleId": "AROA1234567890EXAMPLE:nix-builder-autoscaler", - "Arn": ( - "arn:aws:sts::210987654321:assumed-role/" - "buildbot-autoscaler-controller/nix-builder-autoscaler" - ), - }, - }, - { - "RoleArn": config.assume_role_arn, - "RoleSessionName": "nix-builder-autoscaler", - }, - ) - sts_stubber.activate() - - assumed_stubber = Stubber(assumed_ec2) - assumed_stubber.add_response( - "run_instances", - {"Instances": [{"InstanceId": "i-assumed"}], "OwnerId": "210987654321"}, - ) - assumed_stubber.activate() - - real_boto3_client = boto3.client - - def _patched_client(service_name, **kwargs): - if service_name == "sts": - return sts_client - if service_name == "ec2" and kwargs.get("aws_access_key_id") == "ASIAAAAAAAAAAAAAAAAA": - return assumed_ec2 - return real_boto3_client(service_name, **kwargs) - - with patch("nix_builder_autoscaler.runtime.ec2.boto3.client", side_effect=_patched_client): - runtime = EC2Runtime(config, _client=base_ec2) - instance_id = runtime.launch_instance("slot001", "#!/bin/bash") - - assert instance_id == "i-assumed" - sts_stubber.assert_no_pending_responses() - assumed_stubber.assert_no_pending_responses() diff --git a/nix/modules/nixos/services/nix-builder-autoscaler.nix b/nix/modules/nixos/services/nix-builder-autoscaler.nix index 205ebf2..e37ce99 100644 --- a/nix/modules/nixos/services/nix-builder-autoscaler.nix +++ b/nix/modules/nixos/services/nix-builder-autoscaler.nix @@ -97,12 +97,6 @@ in default = ""; description = "Optional instance profile ARN override."; }; - - assumeRoleArnFile = lib.mkOption { - type = lib.types.nullOr lib.types.str; - default = null; - description = "Optional file containing an IAM role ARN for cross-account autoscaler control-plane calls."; - }; }; haproxy = { @@ -335,9 +329,6 @@ in ${lib.optionalString (cfg.aws.onDemandLaunchTemplateIdFile != null) '' on_demand_launch_template_id="$(tr -d '\n' < ${lib.escapeShellArg cfg.aws.onDemandLaunchTemplateIdFile})" ''} - ${lib.optionalString (cfg.aws.assumeRoleArnFile != null) '' - assume_role_arn="$(tr -d '\n' < ${lib.escapeShellArg cfg.aws.assumeRoleArnFile})" - ''} cat > ${generatedConfigPath} <