Compare commits

...

2 commits

Author SHA1 Message Date
f0fd0f342e Add EC2 runtime test for assume-role path
Some checks failed
buildbot/nix-eval Build done.
buildbot/nix-build gitea:ops/nix-builder-autoscaler#checks.x86_64-linux.buildbot-autoscale-ext-pyright Build done.
buildbot/nix-build gitea:ops/nix-builder-autoscaler#checks.x86_64-linux.package-default Build done.
buildbot/nix-build gitea:ops/nix-builder-autoscaler#checks.x86_64-linux.app-autoscalerctl Build done.
buildbot/nix-build gitea:ops/nix-builder-autoscaler#checks.x86_64-linux.app-nix-builder-autoscaler Build done.
buildbot/nix-build gitea:ops/nix-builder-autoscaler#checks.x86_64-linux.app-default Build done.
buildbot/nix-build gitea:ops/nix-builder-autoscaler#checks.x86_64-linux.buildbot-autoscale-ext-ruff Build done.
buildbot/nix-build gitea:ops/nix-builder-autoscaler#checks.x86_64-linux.package-nix-builder-autoscaler Build done.
buildbot/nix-build gitea:ops/nix-builder-autoscaler#checks.x86_64-linux.nix-builder-autoscaler-integration-tests Build done.
buildbot/nix-build gitea:ops/nix-builder-autoscaler#checks.x86_64-linux.nix-builder-autoscaler-pyright Build done.
buildbot/nix-build gitea:ops/nix-builder-autoscaler#checks.x86_64-linux.nix-builder-autoscaler-ruff Build done.
buildbot/nix-build gitea:ops/nix-builder-autoscaler#checks.x86_64-linux.nix-builder-autoscaler-unit-tests Build done.
buildbot/nix-build Build done.
2026-03-05 12:42:57 +01:00
4c7333ca07 Add optional autoscaler cross-account assume-role support 2026-03-05 12:38:10 +01:00
4 changed files with 137 additions and 11 deletions

View file

@ -28,6 +28,7 @@ class AwsConfig:
subnet_ids: list[str] = field(default_factory=list) subnet_ids: list[str] = field(default_factory=list)
security_group_ids: list[str] = field(default_factory=list) security_group_ids: list[str] = field(default_factory=list)
instance_profile_arn: str = "" instance_profile_arn: str = ""
assume_role_arn: str = ""
@dataclass @dataclass

View file

@ -7,7 +7,9 @@ import json
import logging import logging
import random import random
import socket import socket
import threading
import time import time
from datetime import UTC, datetime
from typing import Any from typing import Any
import boto3 import boto3
@ -58,7 +60,16 @@ class EC2Runtime(RuntimeAdapter):
_client: Any = None, _client: Any = None,
_tailscale_socket_path: str = "/run/tailscale/tailscaled.sock", _tailscale_socket_path: str = "/run/tailscale/tailscaled.sock",
) -> None: ) -> None:
self._client: Any = _client or boto3.client("ec2", region_name=config.region) self._base_client: Any = _client or boto3.client("ec2", region_name=config.region)
self._region = config.region
self._assume_role_arn = config.assume_role_arn.strip()
self._assume_role_session_name = "nix-builder-autoscaler"
self._assumed_client: Any | None = None
self._assumed_client_expiration: datetime | None = None
self._assume_role_lock = threading.Lock()
self._sts_client: Any | None = None
if self._assume_role_arn != "":
self._sts_client = boto3.client("sts", region_name=config.region)
self._instance_type = config.instance_type self._instance_type = config.instance_type
self._launch_template_id = config.launch_template_id self._launch_template_id = config.launch_template_id
self._on_demand_launch_template_id = config.on_demand_launch_template_id self._on_demand_launch_template_id = config.on_demand_launch_template_id
@ -78,9 +89,12 @@ class EC2Runtime(RuntimeAdapter):
warnings so a transient permissions issue does not prevent startup. warnings so a transient permissions issue does not prevent startup.
""" """
try: try:
ec2_client = self._get_ec2_client()
target_azs: set[str] = set() target_azs: set[str] = set()
if self._subnet_ids: if self._subnet_ids:
resp = self._client.describe_subnets(SubnetIds=self._subnet_ids) resp = self._call_with_backoff(
ec2_client.describe_subnets, SubnetIds=self._subnet_ids
)
for subnet in resp.get("Subnets", []): for subnet in resp.get("Subnets", []):
az = subnet.get("AvailabilityZone") az = subnet.get("AvailabilityZone")
if az: if az:
@ -92,14 +106,15 @@ class EC2Runtime(RuntimeAdapter):
if target_azs: if target_azs:
filters.append({"Name": "location", "Values": list(target_azs)}) filters.append({"Name": "location", "Values": list(target_azs)})
resp = self._client.describe_instance_type_offerings( resp = self._call_with_backoff(
ec2_client.describe_instance_type_offerings,
LocationType="availability-zone", LocationType="availability-zone",
Filters=filters, Filters=filters,
) )
available_azs = {o["Location"] for o in resp.get("InstanceTypeOfferings", [])} available_azs = {o["Location"] for o in resp.get("InstanceTypeOfferings", [])}
if not available_azs: if not available_azs:
region = self._client.meta.region_name region = ec2_client.meta.region_name
log.error( log.error(
"preflight_misconfiguration", "preflight_misconfiguration",
extra={ extra={
@ -194,15 +209,15 @@ class EC2Runtime(RuntimeAdapter):
self._subnet_index += 1 self._subnet_index += 1
params["SubnetId"] = subnet params["SubnetId"] = subnet
resp = self._call_with_backoff(self._client.run_instances, **params) ec2_client = self._get_ec2_client()
resp = self._call_with_backoff(ec2_client.run_instances, **params)
return resp["Instances"][0]["InstanceId"] return resp["Instances"][0]["InstanceId"]
def describe_instance(self, instance_id: str) -> dict: def describe_instance(self, instance_id: str) -> dict:
"""Return normalized instance info dict.""" """Return normalized instance info dict."""
try: try:
resp = self._call_with_backoff( ec2_client = self._get_ec2_client()
self._client.describe_instances, InstanceIds=[instance_id] resp = self._call_with_backoff(ec2_client.describe_instances, InstanceIds=[instance_id])
)
except RuntimeAdapterError: except RuntimeAdapterError:
return {"state": "terminated", "tailscale_ip": None, "launch_time": None} return {"state": "terminated", "tailscale_ip": None, "launch_time": None}
@ -227,12 +242,14 @@ class EC2Runtime(RuntimeAdapter):
def terminate_instance(self, instance_id: str) -> None: def terminate_instance(self, instance_id: str) -> None:
"""Terminate the instance.""" """Terminate the instance."""
self._call_with_backoff(self._client.terminate_instances, InstanceIds=[instance_id]) ec2_client = self._get_ec2_client()
self._call_with_backoff(ec2_client.terminate_instances, InstanceIds=[instance_id])
def list_managed_instances(self) -> list[dict]: def list_managed_instances(self) -> list[dict]:
"""Return list of managed instances.""" """Return list of managed instances."""
ec2_client = self._get_ec2_client()
resp = self._call_with_backoff( resp = self._call_with_backoff(
self._client.describe_instances, ec2_client.describe_instances,
Filters=[ Filters=[
{"Name": "tag:ManagedBy", "Values": ["nix-builder-autoscaler"]}, {"Name": "tag:ManagedBy", "Values": ["nix-builder-autoscaler"]},
{ {
@ -255,6 +272,45 @@ class EC2Runtime(RuntimeAdapter):
) )
return result return result
def _get_ec2_client(self) -> Any:
if self._assume_role_arn == "":
return self._base_client
return self._get_assumed_ec2_client()
def _get_assumed_ec2_client(self) -> Any:
now = datetime.now(UTC)
with self._assume_role_lock:
if (
self._assumed_client is not None
and self._assumed_client_expiration is not None
and (self._assumed_client_expiration - now).total_seconds() > 300
):
return self._assumed_client
if self._sts_client is None:
raise RuntimeAdapterError(
"assume_role_arn is set but STS client is not initialized",
category="misconfiguration",
)
try:
response = self._sts_client.assume_role(
RoleArn=self._assume_role_arn,
RoleSessionName=self._assume_role_session_name,
)
except ClientError as exc:
raise RuntimeAdapterError(str(exc), category="misconfiguration") from exc
credentials = response["Credentials"]
self._assumed_client_expiration = credentials["Expiration"]
self._assumed_client = boto3.client(
"ec2",
region_name=self._region,
aws_access_key_id=credentials["AccessKeyId"],
aws_secret_access_key=credentials["SecretAccessKey"],
aws_session_token=credentials["SessionToken"],
)
return self._assumed_client
def _call_with_backoff(self, fn: Any, *args: Any, max_retries: int = 3, **kwargs: Any) -> Any: def _call_with_backoff(self, fn: Any, *args: Any, max_retries: int = 3, **kwargs: Any) -> Any:
"""Call *fn* with exponential backoff and full jitter on retryable errors.""" """Call *fn* with exponential backoff and full jitter on retryable errors."""
delay = 0.5 delay = 0.5

View file

@ -1,6 +1,6 @@
"""Unit tests for the EC2 runtime adapter using botocore Stubber.""" """Unit tests for the EC2 runtime adapter using botocore Stubber."""
from datetime import UTC, datetime from datetime import UTC, datetime, timedelta
from unittest.mock import patch from unittest.mock import patch
import boto3 import boto3
@ -462,3 +462,62 @@ class TestErrorClassification:
with pytest.raises(RuntimeAdapterError) as exc_info: with pytest.raises(RuntimeAdapterError) as exc_info:
runtime.launch_instance("slot001", "#!/bin/bash") runtime.launch_instance("slot001", "#!/bin/bash")
assert exc_info.value.category == "throttled" assert exc_info.value.category == "throttled"
class TestAssumeRole:
def test_uses_assumed_role_credentials_for_ec2_calls(self):
config = _make_config()
config.assume_role_arn = "arn:aws:iam::210987654321:role/buildbot-autoscaler-controller"
base_ec2 = boto3.client("ec2", region_name="us-east-1")
assumed_ec2 = boto3.client("ec2", region_name="us-east-1")
sts_client = boto3.client("sts", region_name="us-east-1")
sts_stubber = Stubber(sts_client)
sts_stubber.add_response(
"assume_role",
{
"Credentials": {
"AccessKeyId": "ASIAAAAAAAAAAAAAAAAA",
"SecretAccessKey": "s" * 40,
"SessionToken": "t" * 256,
"Expiration": datetime.now(UTC) + timedelta(hours=1),
},
"AssumedRoleUser": {
"AssumedRoleId": "AROA1234567890EXAMPLE:nix-builder-autoscaler",
"Arn": (
"arn:aws:sts::210987654321:assumed-role/"
"buildbot-autoscaler-controller/nix-builder-autoscaler"
),
},
},
{
"RoleArn": config.assume_role_arn,
"RoleSessionName": "nix-builder-autoscaler",
},
)
sts_stubber.activate()
assumed_stubber = Stubber(assumed_ec2)
assumed_stubber.add_response(
"run_instances",
{"Instances": [{"InstanceId": "i-assumed"}], "OwnerId": "210987654321"},
)
assumed_stubber.activate()
real_boto3_client = boto3.client
def _patched_client(service_name, **kwargs):
if service_name == "sts":
return sts_client
if service_name == "ec2" and kwargs.get("aws_access_key_id") == "ASIAAAAAAAAAAAAAAAAA":
return assumed_ec2
return real_boto3_client(service_name, **kwargs)
with patch("nix_builder_autoscaler.runtime.ec2.boto3.client", side_effect=_patched_client):
runtime = EC2Runtime(config, _client=base_ec2)
instance_id = runtime.launch_instance("slot001", "#!/bin/bash")
assert instance_id == "i-assumed"
sts_stubber.assert_no_pending_responses()
assumed_stubber.assert_no_pending_responses()

View file

@ -97,6 +97,12 @@ in
default = ""; default = "";
description = "Optional instance profile ARN override."; description = "Optional instance profile ARN override.";
}; };
assumeRoleArnFile = lib.mkOption {
type = lib.types.nullOr lib.types.str;
default = null;
description = "Optional file containing an IAM role ARN for cross-account autoscaler control-plane calls.";
};
}; };
haproxy = { haproxy = {
@ -329,6 +335,9 @@ in
${lib.optionalString (cfg.aws.onDemandLaunchTemplateIdFile != null) '' ${lib.optionalString (cfg.aws.onDemandLaunchTemplateIdFile != null) ''
on_demand_launch_template_id="$(tr -d '\n' < ${lib.escapeShellArg cfg.aws.onDemandLaunchTemplateIdFile})" on_demand_launch_template_id="$(tr -d '\n' < ${lib.escapeShellArg cfg.aws.onDemandLaunchTemplateIdFile})"
''} ''}
${lib.optionalString (cfg.aws.assumeRoleArnFile != null) ''
assume_role_arn="$(tr -d '\n' < ${lib.escapeShellArg cfg.aws.assumeRoleArnFile})"
''}
cat > ${generatedConfigPath} <<EOF cat > ${generatedConfigPath} <<EOF
[server] [server]
@ -346,6 +355,7 @@ in
subnet_ids = $subnet_ids_json subnet_ids = $subnet_ids_json
security_group_ids = ${tomlStringList cfg.aws.securityGroupIds} security_group_ids = ${tomlStringList cfg.aws.securityGroupIds}
instance_profile_arn = "${cfg.aws.instanceProfileArn}" instance_profile_arn = "${cfg.aws.instanceProfileArn}"
${lib.optionalString (cfg.aws.assumeRoleArnFile != null) ''assume_role_arn = "$assume_role_arn"''}
[haproxy] [haproxy]
runtime_socket = "${cfg.haproxy.runtimeSocket}" runtime_socket = "${cfg.haproxy.runtimeSocket}"