support dual launch templates: spot for normal builds, on-demand for nested virtualization
Some checks failed
buildbot/nix-eval Build done.
buildbot/nix-build gitea:ops/nix-builder-autoscaler#checks.x86_64-linux.package-nix-builder-autoscaler Build done.
buildbot/nix-build gitea:ops/nix-builder-autoscaler#checks.x86_64-linux.package-default Build done.
buildbot/nix-build gitea:ops/nix-builder-autoscaler#checks.x86_64-linux.app-autoscalerctl Build done.
buildbot/nix-build gitea:ops/nix-builder-autoscaler#checks.x86_64-linux.app-default Build done.
buildbot/nix-build gitea:ops/nix-builder-autoscaler#checks.x86_64-linux.app-nix-builder-autoscaler Build done.
buildbot/nix-build gitea:ops/nix-builder-autoscaler#checks.x86_64-linux.nix-builder-autoscaler-pyright Build done.
buildbot/nix-build gitea:ops/nix-builder-autoscaler#checks.x86_64-linux.nix-builder-autoscaler-integration-tests Build done.
buildbot/nix-build gitea:ops/nix-builder-autoscaler#checks.x86_64-linux.nix-builder-autoscaler-ruff Build done.
buildbot/nix-build gitea:ops/nix-builder-autoscaler#checks.x86_64-linux.nix-builder-autoscaler-unit-tests Build done.
buildbot/nix-build gitea:ops/nix-builder-autoscaler#checks.x86_64-linux.package-buildbot-autoscale-ext Build done.
buildbot/nix-build Build done.

AWS does not allow cpu_options.nested_virtualization with spot instances. Add a second
launch template (on-demand, cpu_options enabled) alongside the existing spot template.
The autoscaler selects the template per-system based on nested_virtualization config.

- RuntimeAdapter.launch_spot -> launch_instance(nested_virtualization=False)
- EC2Runtime: selects spot or on-demand LT; raises misconfiguration error if
  on_demand_launch_template_id is empty when nested_virtualization=True
- AwsConfig: add on_demand_launch_template_id field
- SystemConfig: add nested_virtualization field
- Scheduler: looks up system config to pass nested_virtualization flag
- NixOS module: new aws.onDemandLaunchTemplateIdFile + capacity.nestedVirtualization
  options; assertion prevents enabling nestedVirtualization without the LT ID file
This commit is contained in:
Abel Luck 2026-02-28 10:33:26 +01:00
parent 3f70094c0a
commit 02b1a063ab
9 changed files with 101 additions and 35 deletions

View file

@ -23,6 +23,7 @@ class AwsConfig:
region: str = "us-east-1" region: str = "us-east-1"
launch_template_id: str = "" launch_template_id: str = ""
on_demand_launch_template_id: str = ""
subnet_ids: list[str] = field(default_factory=list) subnet_ids: list[str] = field(default_factory=list)
security_group_ids: list[str] = field(default_factory=list) security_group_ids: list[str] = field(default_factory=list)
instance_profile_arn: str = "" instance_profile_arn: str = ""
@ -51,6 +52,7 @@ class SystemConfig:
launch_batch_size: int = 1 launch_batch_size: int = 1
scale_down_idle_seconds: int = 900 scale_down_idle_seconds: int = 900
termination_cooldown_seconds: int = 180 termination_cooldown_seconds: int = 180
nested_virtualization: bool = False
@dataclass @dataclass

View file

@ -12,7 +12,17 @@ from typing import Any
class JSONFormatter(logging.Formatter): class JSONFormatter(logging.Formatter):
"""Format log records as single-line JSON.""" """Format log records as single-line JSON."""
EXTRA_FIELDS = ("slot_id", "reservation_id", "instance_id", "request_id", "error", "category", "count", "ids", "idle_seconds") EXTRA_FIELDS = (
"slot_id",
"reservation_id",
"instance_id",
"request_id",
"error",
"category",
"count",
"ids",
"idle_seconds",
)
def format(self, record: logging.LogRecord) -> str: def format(self, record: logging.LogRecord) -> str:
"""Format a log record as JSON.""" """Format a log record as JSON."""

View file

@ -21,8 +21,14 @@ class RuntimeAdapter(ABC):
"""Interface for compute runtime backends (EC2, fake, etc.).""" """Interface for compute runtime backends (EC2, fake, etc.)."""
@abstractmethod @abstractmethod
def launch_spot(self, slot_id: str, user_data: str) -> str: def launch_instance(
"""Launch a spot instance for slot_id. Return instance_id.""" self, slot_id: str, user_data: str, *, nested_virtualization: bool = False
) -> str:
"""Launch an instance for slot_id. Return instance_id.
When nested_virtualization is True, an on-demand instance is launched using
the on-demand launch template. When False (default), a spot instance is launched.
"""
@abstractmethod @abstractmethod
def describe_instance(self, instance_id: str) -> dict: def describe_instance(self, instance_id: str) -> dict:

View file

@ -60,6 +60,7 @@ class EC2Runtime(RuntimeAdapter):
) -> None: ) -> None:
self._client: Any = _client or boto3.client("ec2", region_name=config.region) self._client: Any = _client or boto3.client("ec2", region_name=config.region)
self._launch_template_id = config.launch_template_id self._launch_template_id = config.launch_template_id
self._on_demand_launch_template_id = config.on_demand_launch_template_id
self._subnet_ids = list(config.subnet_ids) self._subnet_ids = list(config.subnet_ids)
self._security_group_ids = list(config.security_group_ids) self._security_group_ids = list(config.security_group_ids)
self._instance_profile_arn = config.instance_profile_arn self._instance_profile_arn = config.instance_profile_arn
@ -67,22 +68,32 @@ class EC2Runtime(RuntimeAdapter):
self._subnet_index = 0 self._subnet_index = 0
self._tailscale_socket_path = _tailscale_socket_path self._tailscale_socket_path = _tailscale_socket_path
def launch_spot(self, slot_id: str, user_data: str) -> str: def launch_instance(
"""Launch a spot instance for *slot_id*. Return instance ID.""" self, slot_id: str, user_data: str, *, nested_virtualization: bool = False
) -> str:
"""Launch an instance for *slot_id*. Return instance ID.
When nested_virtualization is True, an on-demand instance is launched using the
on-demand launch template (cpu_options nested virt enabled, no spot market options).
When False (default), a spot instance is launched using the spot launch template.
"""
if nested_virtualization:
if not self._on_demand_launch_template_id:
raise RuntimeAdapterError(
"nested_virtualization=True but on_demand_launch_template_id is not configured",
category="misconfiguration",
)
lt_id = self._on_demand_launch_template_id
else:
lt_id = self._launch_template_id
params: dict[str, Any] = { params: dict[str, Any] = {
"MinCount": 1, "MinCount": 1,
"MaxCount": 1, "MaxCount": 1,
"LaunchTemplate": { "LaunchTemplate": {
"LaunchTemplateId": self._launch_template_id, "LaunchTemplateId": lt_id,
"Version": "$Latest", "Version": "$Latest",
}, },
"InstanceMarketOptions": {
"MarketType": "spot",
"SpotOptions": {
"SpotInstanceType": "one-time",
"InstanceInterruptionBehavior": "terminate",
},
},
"UserData": user_data, "UserData": user_data,
"TagSpecifications": [ "TagSpecifications": [
{ {
@ -98,6 +109,15 @@ class EC2Runtime(RuntimeAdapter):
], ],
} }
if not nested_virtualization:
params["InstanceMarketOptions"] = {
"MarketType": "spot",
"SpotOptions": {
"SpotInstanceType": "one-time",
"InstanceInterruptionBehavior": "terminate",
},
}
if self._subnet_ids: if self._subnet_ids:
subnet = self._subnet_ids[self._subnet_index % len(self._subnet_ids)] subnet = self._subnet_ids[self._subnet_index % len(self._subnet_ids)]
self._subnet_index += 1 self._subnet_index += 1

View file

@ -38,8 +38,10 @@ class FakeRuntime(RuntimeAdapter):
self._tick_count: int = 0 self._tick_count: int = 0
self._next_ip_counter: int = 1 self._next_ip_counter: int = 1
def launch_spot(self, slot_id: str, user_data: str) -> str: def launch_instance(
"""Launch a fake spot instance.""" self, slot_id: str, user_data: str, *, nested_virtualization: bool = False
) -> str:
"""Launch a fake instance (nested_virtualization is accepted but ignored)."""
if slot_id in self._launch_failures: if slot_id in self._launch_failures:
self._launch_failures.discard(slot_id) self._launch_failures.discard(slot_id)
raise RuntimeAdapterError( raise RuntimeAdapterError(

View file

@ -245,8 +245,11 @@ def _launch_slot(
"""Launch a single slot. Transition to LAUNCHING on success, ERROR on failure.""" """Launch a single slot. Transition to LAUNCHING on success, ERROR on failure."""
slot_id = slot["slot_id"] slot_id = slot["slot_id"]
user_data = render_userdata(slot_id) user_data = render_userdata(slot_id)
system_name = slot.get("system", config.capacity.default_system)
sys_cfg = next((s for s in config.systems if s.name == system_name), None)
nested_virt = sys_cfg.nested_virtualization if sys_cfg else False
try: try:
instance_id = runtime.launch_spot(slot_id, user_data) instance_id = runtime.launch_instance(slot_id, user_data, nested_virtualization=nested_virt)
metrics.counter("autoscaler_ec2_launch_total", {"result": "success"}, 1.0) metrics.counter("autoscaler_ec2_launch_total", {"result": "success"}, 1.0)
db.update_slot_state(slot_id, SlotState.LAUNCHING, instance_id=instance_id) db.update_slot_state(slot_id, SlotState.LAUNCHING, instance_id=instance_id)
log.info("slot_launched", extra={"slot_id": slot_id, "instance_id": instance_id}) log.info("slot_launched", extra={"slot_id": slot_id, "instance_id": instance_id})

View file

@ -73,7 +73,7 @@ class TestLaunchSpot:
stubber.add_response("run_instances", response, expected_params) stubber.add_response("run_instances", response, expected_params)
runtime = _make_runtime(stubber, ec2_client, config=config) runtime = _make_runtime(stubber, ec2_client, config=config)
iid = runtime.launch_spot("slot001", "#!/bin/bash\necho hello") iid = runtime.launch_instance("slot001", "#!/bin/bash\necho hello")
assert iid == "i-12345678" assert iid == "i-12345678"
stubber.assert_no_pending_responses() stubber.assert_no_pending_responses()
@ -90,8 +90,8 @@ class TestLaunchSpot:
) )
runtime = _make_runtime(stubber, ec2_client, config=config) runtime = _make_runtime(stubber, ec2_client, config=config)
runtime.launch_spot("slot001", "") runtime.launch_instance("slot001", "")
runtime.launch_spot("slot002", "") runtime.launch_instance("slot002", "")
stubber.assert_no_pending_responses() stubber.assert_no_pending_responses()
@ -418,7 +418,7 @@ class TestErrorClassification:
runtime = _make_runtime(stubber, ec2_client) runtime = _make_runtime(stubber, ec2_client)
with pytest.raises(RuntimeAdapterError) as exc_info: with pytest.raises(RuntimeAdapterError) as exc_info:
runtime.launch_spot("slot001", "#!/bin/bash") runtime.launch_instance("slot001", "#!/bin/bash")
assert exc_info.value.category == "capacity_unavailable" assert exc_info.value.category == "capacity_unavailable"
@patch("nix_builder_autoscaler.runtime.ec2.time.sleep") @patch("nix_builder_autoscaler.runtime.ec2.time.sleep")
@ -439,7 +439,7 @@ class TestErrorClassification:
) )
runtime = _make_runtime(stubber, ec2_client) runtime = _make_runtime(stubber, ec2_client)
iid = runtime.launch_spot("slot001", "#!/bin/bash") iid = runtime.launch_instance("slot001", "#!/bin/bash")
assert iid == "i-retry123" assert iid == "i-retry123"
assert mock_sleep.called assert mock_sleep.called
stubber.assert_no_pending_responses() stubber.assert_no_pending_responses()
@ -460,5 +460,5 @@ class TestErrorClassification:
runtime = _make_runtime(stubber, ec2_client) runtime = _make_runtime(stubber, ec2_client)
with pytest.raises(RuntimeAdapterError) as exc_info: with pytest.raises(RuntimeAdapterError) as exc_info:
runtime.launch_spot("slot001", "#!/bin/bash") runtime.launch_instance("slot001", "#!/bin/bash")
assert exc_info.value.category == "throttled" assert exc_info.value.category == "throttled"

View file

@ -9,13 +9,13 @@ from nix_builder_autoscaler.runtime.fake import FakeRuntime
class TestLaunchSpot: class TestLaunchSpot:
def test_returns_synthetic_instance_id(self): def test_returns_synthetic_instance_id(self):
rt = FakeRuntime() rt = FakeRuntime()
iid = rt.launch_spot("slot001", "#!/bin/bash\necho hello") iid = rt.launch_instance("slot001", "#!/bin/bash\necho hello")
assert iid.startswith("i-fake-") assert iid.startswith("i-fake-")
assert len(iid) > 10 assert len(iid) > 10
def test_instance_starts_pending(self): def test_instance_starts_pending(self):
rt = FakeRuntime() rt = FakeRuntime()
iid = rt.launch_spot("slot001", "") iid = rt.launch_instance("slot001", "")
info = rt.describe_instance(iid) info = rt.describe_instance(iid)
assert info["state"] == "pending" assert info["state"] == "pending"
assert info["tailscale_ip"] is None assert info["tailscale_ip"] is None
@ -24,7 +24,7 @@ class TestLaunchSpot:
class TestTickProgression: class TestTickProgression:
def test_transitions_to_running_after_configured_ticks(self): def test_transitions_to_running_after_configured_ticks(self):
rt = FakeRuntime(launch_latency_ticks=3, ip_delay_ticks=1) rt = FakeRuntime(launch_latency_ticks=3, ip_delay_ticks=1)
iid = rt.launch_spot("slot001", "") iid = rt.launch_instance("slot001", "")
for _ in range(2): for _ in range(2):
rt.tick() rt.tick()
@ -35,7 +35,7 @@ class TestTickProgression:
def test_tailscale_ip_appears_after_configured_delay(self): def test_tailscale_ip_appears_after_configured_delay(self):
rt = FakeRuntime(launch_latency_ticks=2, ip_delay_ticks=2) rt = FakeRuntime(launch_latency_ticks=2, ip_delay_ticks=2)
iid = rt.launch_spot("slot001", "") iid = rt.launch_instance("slot001", "")
for _ in range(2): for _ in range(2):
rt.tick() rt.tick()
@ -56,7 +56,7 @@ class TestInjectedFailure:
rt = FakeRuntime() rt = FakeRuntime()
rt.inject_launch_failure("slot001") rt.inject_launch_failure("slot001")
try: try:
rt.launch_spot("slot001", "") rt.launch_instance("slot001", "")
raise AssertionError("Should have raised") raise AssertionError("Should have raised")
except RuntimeAdapterError as e: except RuntimeAdapterError as e:
assert e.category == "capacity_unavailable" assert e.category == "capacity_unavailable"
@ -65,16 +65,16 @@ class TestInjectedFailure:
rt = FakeRuntime() rt = FakeRuntime()
rt.inject_launch_failure("slot001") rt.inject_launch_failure("slot001")
with contextlib.suppress(RuntimeAdapterError): with contextlib.suppress(RuntimeAdapterError):
rt.launch_spot("slot001", "") rt.launch_instance("slot001", "")
# Second call should succeed # Second call should succeed
iid = rt.launch_spot("slot001", "") iid = rt.launch_instance("slot001", "")
assert iid.startswith("i-fake-") assert iid.startswith("i-fake-")
class TestInjectedInterruption: class TestInjectedInterruption:
def test_interruption_returns_terminated(self): def test_interruption_returns_terminated(self):
rt = FakeRuntime(launch_latency_ticks=1) rt = FakeRuntime(launch_latency_ticks=1)
iid = rt.launch_spot("slot001", "") iid = rt.launch_instance("slot001", "")
rt.tick() rt.tick()
assert rt.describe_instance(iid)["state"] == "running" assert rt.describe_instance(iid)["state"] == "running"
@ -85,7 +85,7 @@ class TestInjectedInterruption:
def test_interruption_is_one_shot(self): def test_interruption_is_one_shot(self):
"""After the interruption fires, subsequent describes stay terminated.""" """After the interruption fires, subsequent describes stay terminated."""
rt = FakeRuntime(launch_latency_ticks=1) rt = FakeRuntime(launch_latency_ticks=1)
iid = rt.launch_spot("slot001", "") iid = rt.launch_instance("slot001", "")
rt.tick() rt.tick()
rt.inject_interruption(iid) rt.inject_interruption(iid)
rt.describe_instance(iid) # consumes the injection rt.describe_instance(iid) # consumes the injection
@ -96,7 +96,7 @@ class TestInjectedInterruption:
class TestTerminate: class TestTerminate:
def test_terminate_marks_instance(self): def test_terminate_marks_instance(self):
rt = FakeRuntime(launch_latency_ticks=1) rt = FakeRuntime(launch_latency_ticks=1)
iid = rt.launch_spot("slot001", "") iid = rt.launch_instance("slot001", "")
rt.tick() rt.tick()
rt.terminate_instance(iid) rt.terminate_instance(iid)
assert rt.describe_instance(iid)["state"] == "terminated" assert rt.describe_instance(iid)["state"] == "terminated"
@ -105,8 +105,8 @@ class TestTerminate:
class TestListManaged: class TestListManaged:
def test_lists_non_terminated(self): def test_lists_non_terminated(self):
rt = FakeRuntime(launch_latency_ticks=1) rt = FakeRuntime(launch_latency_ticks=1)
iid1 = rt.launch_spot("slot001", "") iid1 = rt.launch_instance("slot001", "")
iid2 = rt.launch_spot("slot002", "") iid2 = rt.launch_instance("slot002", "")
rt.tick() rt.tick()
rt.terminate_instance(iid1) rt.terminate_instance(iid1)

View file

@ -66,7 +66,13 @@ in
launchTemplateIdFile = lib.mkOption { launchTemplateIdFile = lib.mkOption {
type = lib.types.nullOr lib.types.str; type = lib.types.nullOr lib.types.str;
default = null; default = null;
description = "Runtime file containing the EC2 launch template ID."; description = "Runtime file containing the EC2 spot launch template ID.";
};
onDemandLaunchTemplateIdFile = lib.mkOption {
type = lib.types.nullOr lib.types.str;
default = null;
description = "Runtime file containing the EC2 on-demand launch template ID (required when capacity.nestedVirtualization is true).";
}; };
subnetIdsJsonFile = lib.mkOption { subnetIdsJsonFile = lib.mkOption {
@ -216,6 +222,12 @@ in
default = 1; default = 1;
description = "Launch batch size for the default system entry."; description = "Launch batch size for the default system entry.";
}; };
nestedVirtualization = lib.mkOption {
type = lib.types.bool;
default = false;
description = "Whether slots use on-demand instances with nested virtualization. Requires aws.onDemandLaunchTemplateIdFile to be set.";
};
}; };
security = { security = {
@ -256,6 +268,10 @@ in
assertion = cfg.aws.subnetIdsJsonFile != null; assertion = cfg.aws.subnetIdsJsonFile != null;
message = "services.nix-builder-autoscaler.aws.subnetIdsJsonFile must be set."; message = "services.nix-builder-autoscaler.aws.subnetIdsJsonFile must be set.";
} }
{
assertion = !cfg.capacity.nestedVirtualization || cfg.aws.onDemandLaunchTemplateIdFile != null;
message = "services.nix-builder-autoscaler.aws.onDemandLaunchTemplateIdFile must be set when capacity.nestedVirtualization is true.";
}
]; ];
environment.systemPackages = [ cfg.package ]; environment.systemPackages = [ cfg.package ];
@ -301,6 +317,9 @@ in
install -d -m 0750 -o ${cfg.user} -g ${cfg.group} /run/nix-builder-autoscaler install -d -m 0750 -o ${cfg.user} -g ${cfg.group} /run/nix-builder-autoscaler
launch_template_id="$(tr -d '\n' < ${lib.escapeShellArg cfg.aws.launchTemplateIdFile})" launch_template_id="$(tr -d '\n' < ${lib.escapeShellArg cfg.aws.launchTemplateIdFile})"
subnet_ids_json="$(tr -d '\n' < ${lib.escapeShellArg cfg.aws.subnetIdsJsonFile})" subnet_ids_json="$(tr -d '\n' < ${lib.escapeShellArg cfg.aws.subnetIdsJsonFile})"
${lib.optionalString (cfg.aws.onDemandLaunchTemplateIdFile != null) ''
on_demand_launch_template_id="$(tr -d '\n' < ${lib.escapeShellArg cfg.aws.onDemandLaunchTemplateIdFile})"
''}
cat > ${generatedConfigPath} <<EOF cat > ${generatedConfigPath} <<EOF
[server] [server]
@ -311,6 +330,9 @@ in
[aws] [aws]
region = "${cfg.aws.region}" region = "${cfg.aws.region}"
launch_template_id = "$launch_template_id" launch_template_id = "$launch_template_id"
${lib.optionalString (
cfg.aws.onDemandLaunchTemplateIdFile != null
) ''on_demand_launch_template_id = "$on_demand_launch_template_id"''}
subnet_ids = $subnet_ids_json subnet_ids = $subnet_ids_json
security_group_ids = ${tomlStringList cfg.aws.securityGroupIds} security_group_ids = ${tomlStringList cfg.aws.securityGroupIds}
instance_profile_arn = "${cfg.aws.instanceProfileArn}" instance_profile_arn = "${cfg.aws.instanceProfileArn}"
@ -351,6 +373,7 @@ in
launch_batch_size = ${toString cfg.capacity.launchBatchSize} launch_batch_size = ${toString cfg.capacity.launchBatchSize}
scale_down_idle_seconds = ${toString cfg.capacity.idleScaleDownSeconds} scale_down_idle_seconds = ${toString cfg.capacity.idleScaleDownSeconds}
termination_cooldown_seconds = ${toString cfg.capacity.terminationCooldownSeconds} termination_cooldown_seconds = ${toString cfg.capacity.terminationCooldownSeconds}
nested_virtualization = ${lib.boolToString cfg.capacity.nestedVirtualization}
EOF EOF
chown ${cfg.user}:${cfg.group} ${generatedConfigPath} chown ${cfg.user}:${cfg.group} ${generatedConfigPath}