Compare commits
2 commits
5092005e05
...
f0fd0f342e
| Author | SHA1 | Date | |
|---|---|---|---|
| f0fd0f342e | |||
| 4c7333ca07 |
4 changed files with 137 additions and 11 deletions
|
|
@ -28,6 +28,7 @@ class AwsConfig:
|
||||||
subnet_ids: list[str] = field(default_factory=list)
|
subnet_ids: list[str] = field(default_factory=list)
|
||||||
security_group_ids: list[str] = field(default_factory=list)
|
security_group_ids: list[str] = field(default_factory=list)
|
||||||
instance_profile_arn: str = ""
|
instance_profile_arn: str = ""
|
||||||
|
assume_role_arn: str = ""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,9 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
import socket
|
import socket
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
|
from datetime import UTC, datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
|
|
@ -58,7 +60,16 @@ class EC2Runtime(RuntimeAdapter):
|
||||||
_client: Any = None,
|
_client: Any = None,
|
||||||
_tailscale_socket_path: str = "/run/tailscale/tailscaled.sock",
|
_tailscale_socket_path: str = "/run/tailscale/tailscaled.sock",
|
||||||
) -> None:
|
) -> None:
|
||||||
self._client: Any = _client or boto3.client("ec2", region_name=config.region)
|
self._base_client: Any = _client or boto3.client("ec2", region_name=config.region)
|
||||||
|
self._region = config.region
|
||||||
|
self._assume_role_arn = config.assume_role_arn.strip()
|
||||||
|
self._assume_role_session_name = "nix-builder-autoscaler"
|
||||||
|
self._assumed_client: Any | None = None
|
||||||
|
self._assumed_client_expiration: datetime | None = None
|
||||||
|
self._assume_role_lock = threading.Lock()
|
||||||
|
self._sts_client: Any | None = None
|
||||||
|
if self._assume_role_arn != "":
|
||||||
|
self._sts_client = boto3.client("sts", region_name=config.region)
|
||||||
self._instance_type = config.instance_type
|
self._instance_type = config.instance_type
|
||||||
self._launch_template_id = config.launch_template_id
|
self._launch_template_id = config.launch_template_id
|
||||||
self._on_demand_launch_template_id = config.on_demand_launch_template_id
|
self._on_demand_launch_template_id = config.on_demand_launch_template_id
|
||||||
|
|
@ -78,9 +89,12 @@ class EC2Runtime(RuntimeAdapter):
|
||||||
warnings so a transient permissions issue does not prevent startup.
|
warnings so a transient permissions issue does not prevent startup.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
ec2_client = self._get_ec2_client()
|
||||||
target_azs: set[str] = set()
|
target_azs: set[str] = set()
|
||||||
if self._subnet_ids:
|
if self._subnet_ids:
|
||||||
resp = self._client.describe_subnets(SubnetIds=self._subnet_ids)
|
resp = self._call_with_backoff(
|
||||||
|
ec2_client.describe_subnets, SubnetIds=self._subnet_ids
|
||||||
|
)
|
||||||
for subnet in resp.get("Subnets", []):
|
for subnet in resp.get("Subnets", []):
|
||||||
az = subnet.get("AvailabilityZone")
|
az = subnet.get("AvailabilityZone")
|
||||||
if az:
|
if az:
|
||||||
|
|
@ -92,14 +106,15 @@ class EC2Runtime(RuntimeAdapter):
|
||||||
if target_azs:
|
if target_azs:
|
||||||
filters.append({"Name": "location", "Values": list(target_azs)})
|
filters.append({"Name": "location", "Values": list(target_azs)})
|
||||||
|
|
||||||
resp = self._client.describe_instance_type_offerings(
|
resp = self._call_with_backoff(
|
||||||
|
ec2_client.describe_instance_type_offerings,
|
||||||
LocationType="availability-zone",
|
LocationType="availability-zone",
|
||||||
Filters=filters,
|
Filters=filters,
|
||||||
)
|
)
|
||||||
available_azs = {o["Location"] for o in resp.get("InstanceTypeOfferings", [])}
|
available_azs = {o["Location"] for o in resp.get("InstanceTypeOfferings", [])}
|
||||||
|
|
||||||
if not available_azs:
|
if not available_azs:
|
||||||
region = self._client.meta.region_name
|
region = ec2_client.meta.region_name
|
||||||
log.error(
|
log.error(
|
||||||
"preflight_misconfiguration",
|
"preflight_misconfiguration",
|
||||||
extra={
|
extra={
|
||||||
|
|
@ -194,15 +209,15 @@ class EC2Runtime(RuntimeAdapter):
|
||||||
self._subnet_index += 1
|
self._subnet_index += 1
|
||||||
params["SubnetId"] = subnet
|
params["SubnetId"] = subnet
|
||||||
|
|
||||||
resp = self._call_with_backoff(self._client.run_instances, **params)
|
ec2_client = self._get_ec2_client()
|
||||||
|
resp = self._call_with_backoff(ec2_client.run_instances, **params)
|
||||||
return resp["Instances"][0]["InstanceId"]
|
return resp["Instances"][0]["InstanceId"]
|
||||||
|
|
||||||
def describe_instance(self, instance_id: str) -> dict:
|
def describe_instance(self, instance_id: str) -> dict:
|
||||||
"""Return normalized instance info dict."""
|
"""Return normalized instance info dict."""
|
||||||
try:
|
try:
|
||||||
resp = self._call_with_backoff(
|
ec2_client = self._get_ec2_client()
|
||||||
self._client.describe_instances, InstanceIds=[instance_id]
|
resp = self._call_with_backoff(ec2_client.describe_instances, InstanceIds=[instance_id])
|
||||||
)
|
|
||||||
except RuntimeAdapterError:
|
except RuntimeAdapterError:
|
||||||
return {"state": "terminated", "tailscale_ip": None, "launch_time": None}
|
return {"state": "terminated", "tailscale_ip": None, "launch_time": None}
|
||||||
|
|
||||||
|
|
@ -227,12 +242,14 @@ class EC2Runtime(RuntimeAdapter):
|
||||||
|
|
||||||
def terminate_instance(self, instance_id: str) -> None:
|
def terminate_instance(self, instance_id: str) -> None:
|
||||||
"""Terminate the instance."""
|
"""Terminate the instance."""
|
||||||
self._call_with_backoff(self._client.terminate_instances, InstanceIds=[instance_id])
|
ec2_client = self._get_ec2_client()
|
||||||
|
self._call_with_backoff(ec2_client.terminate_instances, InstanceIds=[instance_id])
|
||||||
|
|
||||||
def list_managed_instances(self) -> list[dict]:
|
def list_managed_instances(self) -> list[dict]:
|
||||||
"""Return list of managed instances."""
|
"""Return list of managed instances."""
|
||||||
|
ec2_client = self._get_ec2_client()
|
||||||
resp = self._call_with_backoff(
|
resp = self._call_with_backoff(
|
||||||
self._client.describe_instances,
|
ec2_client.describe_instances,
|
||||||
Filters=[
|
Filters=[
|
||||||
{"Name": "tag:ManagedBy", "Values": ["nix-builder-autoscaler"]},
|
{"Name": "tag:ManagedBy", "Values": ["nix-builder-autoscaler"]},
|
||||||
{
|
{
|
||||||
|
|
@ -255,6 +272,45 @@ class EC2Runtime(RuntimeAdapter):
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _get_ec2_client(self) -> Any:
|
||||||
|
if self._assume_role_arn == "":
|
||||||
|
return self._base_client
|
||||||
|
return self._get_assumed_ec2_client()
|
||||||
|
|
||||||
|
def _get_assumed_ec2_client(self) -> Any:
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
with self._assume_role_lock:
|
||||||
|
if (
|
||||||
|
self._assumed_client is not None
|
||||||
|
and self._assumed_client_expiration is not None
|
||||||
|
and (self._assumed_client_expiration - now).total_seconds() > 300
|
||||||
|
):
|
||||||
|
return self._assumed_client
|
||||||
|
|
||||||
|
if self._sts_client is None:
|
||||||
|
raise RuntimeAdapterError(
|
||||||
|
"assume_role_arn is set but STS client is not initialized",
|
||||||
|
category="misconfiguration",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self._sts_client.assume_role(
|
||||||
|
RoleArn=self._assume_role_arn,
|
||||||
|
RoleSessionName=self._assume_role_session_name,
|
||||||
|
)
|
||||||
|
except ClientError as exc:
|
||||||
|
raise RuntimeAdapterError(str(exc), category="misconfiguration") from exc
|
||||||
|
credentials = response["Credentials"]
|
||||||
|
self._assumed_client_expiration = credentials["Expiration"]
|
||||||
|
self._assumed_client = boto3.client(
|
||||||
|
"ec2",
|
||||||
|
region_name=self._region,
|
||||||
|
aws_access_key_id=credentials["AccessKeyId"],
|
||||||
|
aws_secret_access_key=credentials["SecretAccessKey"],
|
||||||
|
aws_session_token=credentials["SessionToken"],
|
||||||
|
)
|
||||||
|
return self._assumed_client
|
||||||
|
|
||||||
def _call_with_backoff(self, fn: Any, *args: Any, max_retries: int = 3, **kwargs: Any) -> Any:
|
def _call_with_backoff(self, fn: Any, *args: Any, max_retries: int = 3, **kwargs: Any) -> Any:
|
||||||
"""Call *fn* with exponential backoff and full jitter on retryable errors."""
|
"""Call *fn* with exponential backoff and full jitter on retryable errors."""
|
||||||
delay = 0.5
|
delay = 0.5
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""Unit tests for the EC2 runtime adapter using botocore Stubber."""
|
"""Unit tests for the EC2 runtime adapter using botocore Stubber."""
|
||||||
|
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime, timedelta
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
|
|
@ -462,3 +462,62 @@ class TestErrorClassification:
|
||||||
with pytest.raises(RuntimeAdapterError) as exc_info:
|
with pytest.raises(RuntimeAdapterError) as exc_info:
|
||||||
runtime.launch_instance("slot001", "#!/bin/bash")
|
runtime.launch_instance("slot001", "#!/bin/bash")
|
||||||
assert exc_info.value.category == "throttled"
|
assert exc_info.value.category == "throttled"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssumeRole:
    """Tests for the cross-account (assume-role) code path of EC2Runtime."""

    def test_uses_assumed_role_credentials_for_ec2_calls(self):
        """launch_instance must go through the client built from STS credentials."""
        access_key_id = "ASIAAAAAAAAAAAAAAAAA"
        region = "us-east-1"

        config = _make_config()
        config.assume_role_arn = "arn:aws:iam::210987654321:role/buildbot-autoscaler-controller"

        # Real clients are created up front, before boto3.client is patched.
        base_ec2 = boto3.client("ec2", region_name=region)
        assumed_ec2 = boto3.client("ec2", region_name=region)
        sts_client = boto3.client("sts", region_name=region)

        assume_role_response = {
            "Credentials": {
                "AccessKeyId": access_key_id,
                "SecretAccessKey": "s" * 40,
                "SessionToken": "t" * 256,
                "Expiration": datetime.now(UTC) + timedelta(hours=1),
            },
            "AssumedRoleUser": {
                "AssumedRoleId": "AROA1234567890EXAMPLE:nix-builder-autoscaler",
                "Arn": (
                    "arn:aws:sts::210987654321:assumed-role/"
                    "buildbot-autoscaler-controller/nix-builder-autoscaler"
                ),
            },
        }
        assume_role_request = {
            "RoleArn": config.assume_role_arn,
            "RoleSessionName": "nix-builder-autoscaler",
        }

        sts_stub = Stubber(sts_client)
        sts_stub.add_response("assume_role", assume_role_response, assume_role_request)
        sts_stub.activate()

        # Only the assumed-role client has a run_instances response queued.
        ec2_stub = Stubber(assumed_ec2)
        ec2_stub.add_response(
            "run_instances",
            {"Instances": [{"InstanceId": "i-assumed"}], "OwnerId": "210987654321"},
        )
        ec2_stub.activate()

        original_factory = boto3.client

        def _patched_client(service_name, **kwargs):
            # Route STS and the credentialed EC2 client to the stubbed
            # instances; everything else falls through to the real factory.
            if service_name == "sts":
                return sts_client
            if service_name == "ec2":
                if kwargs.get("aws_access_key_id") == access_key_id:
                    return assumed_ec2
            return original_factory(service_name, **kwargs)

        with patch("nix_builder_autoscaler.runtime.ec2.boto3.client", side_effect=_patched_client):
            runtime = EC2Runtime(config, _client=base_ec2)
            launched_id = runtime.launch_instance("slot001", "#!/bin/bash")

        assert launched_id == "i-assumed"
        sts_stub.assert_no_pending_responses()
        ec2_stub.assert_no_pending_responses()
|
||||||
|
|
|
||||||
|
|
@ -97,6 +97,12 @@ in
|
||||||
default = "";
|
default = "";
|
||||||
description = "Optional instance profile ARN override.";
|
description = "Optional instance profile ARN override.";
|
||||||
};
|
};
|
||||||
|
|
||||||
|
assumeRoleArnFile = lib.mkOption {
|
||||||
|
type = lib.types.nullOr lib.types.str;
|
||||||
|
default = null;
|
||||||
|
description = "Optional file containing an IAM role ARN for cross-account autoscaler control-plane calls.";
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
haproxy = {
|
haproxy = {
|
||||||
|
|
@ -329,6 +335,9 @@ in
|
||||||
${lib.optionalString (cfg.aws.onDemandLaunchTemplateIdFile != null) ''
|
${lib.optionalString (cfg.aws.onDemandLaunchTemplateIdFile != null) ''
|
||||||
on_demand_launch_template_id="$(tr -d '\n' < ${lib.escapeShellArg cfg.aws.onDemandLaunchTemplateIdFile})"
|
on_demand_launch_template_id="$(tr -d '\n' < ${lib.escapeShellArg cfg.aws.onDemandLaunchTemplateIdFile})"
|
||||||
''}
|
''}
|
||||||
|
${lib.optionalString (cfg.aws.assumeRoleArnFile != null) ''
|
||||||
|
assume_role_arn="$(tr -d '\n' < ${lib.escapeShellArg cfg.aws.assumeRoleArnFile})"
|
||||||
|
''}
|
||||||
|
|
||||||
cat > ${generatedConfigPath} <<EOF
|
cat > ${generatedConfigPath} <<EOF
|
||||||
[server]
|
[server]
|
||||||
|
|
@ -346,6 +355,7 @@ in
|
||||||
subnet_ids = $subnet_ids_json
|
subnet_ids = $subnet_ids_json
|
||||||
security_group_ids = ${tomlStringList cfg.aws.securityGroupIds}
|
security_group_ids = ${tomlStringList cfg.aws.securityGroupIds}
|
||||||
instance_profile_arn = "${cfg.aws.instanceProfileArn}"
|
instance_profile_arn = "${cfg.aws.instanceProfileArn}"
|
||||||
|
${lib.optionalString (cfg.aws.assumeRoleArnFile != null) ''assume_role_arn = "$assume_role_arn"''}
|
||||||
|
|
||||||
[haproxy]
|
[haproxy]
|
||||||
runtime_socket = "${cfg.haproxy.runtimeSocket}"
|
runtime_socket = "${cfg.haproxy.runtimeSocket}"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue