"""autoscalerctl CLI entry point.""" from __future__ import annotations import argparse import http.client import json import socket from collections.abc import Sequence from datetime import UTC, datetime from typing import Any class UnixHTTPConnection(http.client.HTTPConnection): """HTTPConnection that dials a Unix domain socket.""" def __init__(self, socket_path: str, timeout: float = 5.0) -> None: super().__init__("localhost", timeout=timeout) self._socket_path = socket_path def connect(self) -> None: self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) self.sock.connect(self._socket_path) def _uds_request( socket_path: str, method: str, path: str, body: dict[str, Any] | None = None, ) -> tuple[int, dict[str, Any] | list[dict[str, Any]] | str]: conn = UnixHTTPConnection(socket_path) headers = {"Host": "localhost", "Accept": "application/json"} payload: str | None = None if body is not None: payload = json.dumps(body) headers["Content-Type"] = "application/json" try: conn.request(method, path, body=payload, headers=headers) resp = conn.getresponse() raw = resp.read() text = raw.decode() if raw else "" content_type = resp.getheader("Content-Type", "") if text and "application/json" in content_type: parsed = json.loads(text) if isinstance(parsed, dict | list): return resp.status, parsed return resp.status, text finally: conn.close() def _print_table(headers: Sequence[str], rows: Sequence[Sequence[str]]) -> None: widths = [len(h) for h in headers] for row in rows: for idx, cell in enumerate(row): widths[idx] = max(widths[idx], len(cell)) header_line = " ".join(h.ljust(widths[idx]) for idx, h in enumerate(headers)) separator = " ".join("-" * widths[idx] for idx in range(len(headers))) print(header_line) print(separator) for row in rows: print(" ".join(cell.ljust(widths[idx]) for idx, cell in enumerate(row))) def _format_duration(seconds: float) -> str: total = int(max(0, round(seconds))) hours, rem = divmod(total, 3600) minutes, secs = divmod(rem, 60) if hours > 0: return f"{hours}h{minutes:02d}m" if minutes > 0: return f"{minutes}m{secs:02d}s" return f"{secs}s" def _format_timeout_ttl(timeout_seconds: float, age_seconds: float) -> str: remaining = timeout_seconds - age_seconds if remaining <= 0: return "due" return _format_duration(remaining) def _slot_age_seconds(slot: dict[str, Any]) -> float | None: raw = slot.get("last_state_change") if not isinstance(raw, str): return None try: dt = datetime.fromisoformat(raw) except ValueError: return None if dt.tzinfo is None: dt = dt.replace(tzinfo=UTC) return (datetime.now(UTC) - dt).total_seconds() def _slot_ttl(slot: dict[str, Any], policy: dict[str, Any], active_slots: int) -> str: capacity = policy.get("capacity") scheduler = policy.get("scheduler") if not isinstance(capacity, dict) or not isinstance(scheduler, dict): return "-" state = str(slot.get("state", "")) lease_count = int(slot.get("lease_count", 0)) age_seconds = _slot_age_seconds(slot) if state in {"empty", "error"}: return "-" if age_seconds is None: return "?" if state == "launching": return _format_timeout_ttl(float(capacity.get("launch_timeout_seconds", 0)), age_seconds) if state == "booting": return _format_timeout_ttl(float(capacity.get("boot_timeout_seconds", 0)), age_seconds) if state == "binding": return _format_timeout_ttl(float(capacity.get("binding_timeout_seconds", 0)), age_seconds) if state == "terminating": return _format_timeout_ttl( float(capacity.get("terminating_timeout_seconds", 0)), age_seconds ) if state == "draining": if lease_count == 0: return f"<={_format_duration(float(scheduler.get('reconcile_seconds', 0)))}" return _format_timeout_ttl(float(capacity.get("drain_timeout_seconds", 0)), age_seconds) if state == "ready": if lease_count > 0: return "-" min_slots = int(capacity.get("min_slots", 0)) if active_slots <= min_slots: return "pinned" return _format_timeout_ttl(float(capacity.get("idle_scale_down_seconds", 0)), age_seconds) return "-" def _print_slots(data: list[dict[str, Any]], policy: dict[str, Any]) -> None: active_slots = sum(1 for slot in data if str(slot.get("state", "")) not in {"empty", "error"}) rows: list[list[str]] = [] for slot in data: rows.append( [ str(slot.get("slot_id", "")), str(slot.get("state", "")), str(slot.get("instance_id") or "-"), str(slot.get("instance_ip") or "-"), str(slot.get("lease_count", 0)), _slot_ttl(slot, policy, active_slots), ] ) _print_table(["slot_id", "state", "instance_id", "ip", "leases", "ttl"], rows) def _print_reservations(data: list[dict[str, Any]]) -> None: rows: list[list[str]] = [] for resv in data: rows.append( [ str(resv.get("reservation_id", "")), str(resv.get("phase", "")), str(resv.get("system", "")), str(resv.get("slot") or "-"), str(resv.get("instance_id") or "-"), ] ) _print_table(["reservation_id", "phase", "system", "slot", "instance_id"], rows) def _print_status_summary(data: dict[str, Any]) -> None: slots = data.get("slots", {}) reservations = data.get("reservations", {}) ec2 = data.get("ec2", {}) haproxy = data.get("haproxy", {}) rows = [ ["slots.total", str(slots.get("total", 0))], ["slots.ready", str(slots.get("ready", 0))], ["slots.launching", str(slots.get("launching", 0))], ["slots.booting", str(slots.get("booting", 0))], ["slots.binding", str(slots.get("binding", 0))], ["slots.terminating", str(slots.get("terminating", 0))], ["slots.empty", str(slots.get("empty", 0))], ["slots.error", str(slots.get("error", 0))], ["reservations.pending", str(reservations.get("pending", 0))], ["reservations.ready", str(reservations.get("ready", 0))], ["reservations.failed", str(reservations.get("failed", 0))], ["ec2.api_ok", str(ec2.get("api_ok", False))], ["haproxy.socket_ok", str(haproxy.get("socket_ok", False))], ] _print_table(["metric", "value"], rows) def _get_effective_config(socket_path: str) -> dict[str, Any]: status, data = _uds_request(socket_path, "GET", "/v1/config/effective") if status < 200 or status >= 300: msg = "failed to fetch effective config" raise RuntimeError(msg) if isinstance(data, dict): return data msg = "invalid effective config payload" raise RuntimeError(msg) def _bulk_slot_action(socket_path: str, action: str) -> dict[str, Any]: if action == "drain": eligible_states = {"ready"} action_path = "/v1/admin/drain" elif action == "unquarantine": eligible_states = {"error"} action_path = "/v1/admin/unquarantine" else: msg = f"unknown bulk action: {action}" raise ValueError(msg) status, data = _uds_request(socket_path, "GET", "/v1/slots") if status < 200 or status >= 300 or not isinstance(data, list): msg = "failed to list slots for bulk action" raise RuntimeError(msg) results: list[dict[str, Any]] = [] summary: dict[str, Any] = { "action": action, "matched": 0, "attempted": 0, "succeeded": 0, "failed": 0, "skipped": 0, "results": results, } for slot in data: slot_id = str(slot.get("slot_id", "")) state = str(slot.get("state", "")) if not slot_id: continue if state not in eligible_states: summary["skipped"] += 1 results.append( { "slot_id": slot_id, "state": state, "result": "skipped", "reason": "ineligible_state", } ) continue summary["matched"] += 1 summary["attempted"] += 1 try: action_status, action_data = _uds_request( socket_path, "POST", action_path, body={"slot_id": slot_id}, ) except OSError as err: summary["failed"] += 1 results.append( { "slot_id": slot_id, "state": state, "result": "failed", "error": str(err), } ) continue if 200 <= action_status < 300: summary["succeeded"] += 1 results.append( { "slot_id": slot_id, "state": state, "result": "ok", } ) else: summary["failed"] += 1 results.append( { "slot_id": slot_id, "state": state, "result": "failed", "status": action_status, "response": action_data, } ) return summary def _parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace: parser = argparse.ArgumentParser(prog="autoscalerctl", description="Autoscaler CLI") parser.add_argument( "--socket", default="/run/nix-builder-autoscaler/daemon.sock", help="Daemon Unix socket path", ) parser.add_argument( "--json", action="store_true", help="Output JSON for status command.", ) subparsers = parser.add_subparsers(dest="command") subparsers.add_parser("status", help="Show state summary") subparsers.add_parser("slots", help="List slots") subparsers.add_parser("reservations", help="List reservations") parser_drain = subparsers.add_parser("drain", help="Drain a slot") parser_drain.add_argument("slot_id") parser_unq = subparsers.add_parser("unquarantine", help="Unquarantine a slot") parser_unq.add_argument("slot_id") subparsers.add_parser("drain-all", help="Drain all eligible slots (state=ready)") subparsers.add_parser("unquarantine-all", help="Unquarantine all error slots") subparsers.add_parser("reconcile-now", help="Trigger immediate reconcile tick") args = parser.parse_args(argv) if not args.command: parser.print_help() raise SystemExit(0) return args def _print_error(data: object) -> None: if isinstance(data, dict | list): print(json.dumps(data, indent=2)) else: print(str(data)) def main() -> None: """Entry point for the autoscalerctl CLI.""" args = _parse_args() if args.command in {"drain-all", "unquarantine-all"}: action = "drain" if args.command == "drain-all" else "unquarantine" try: summary = _bulk_slot_action(args.socket, action) except OSError as err: print(f"Error: cannot connect to daemon at {args.socket}") raise SystemExit(1) from err except RuntimeError as err: print(str(err)) raise SystemExit(1) from err print(json.dumps(summary, indent=2)) raise SystemExit(0 if summary["failed"] == 0 else 1) method = "GET" path = "" body: dict[str, Any] | None = None if args.command == "status": path = "/v1/state/summary" elif args.command == "slots": path = "/v1/slots" elif args.command == "reservations": path = "/v1/reservations" elif args.command == "drain": method = "POST" path = "/v1/admin/drain" body = {"slot_id": args.slot_id} elif args.command == "unquarantine": method = "POST" path = "/v1/admin/unquarantine" body = {"slot_id": args.slot_id} elif args.command == "reconcile-now": method = "POST" path = "/v1/admin/reconcile-now" else: raise SystemExit(1) try: status, data = _uds_request(args.socket, method, path, body=body) except OSError as err: print(f"Error: cannot connect to daemon at {args.socket}") raise SystemExit(1) from err if status < 200 or status >= 300: _print_error(data) raise SystemExit(1) if args.command == "status": if not isinstance(data, dict): _print_error(data) raise SystemExit(1) if args.json: print(json.dumps(data, indent=2)) else: _print_status_summary(data) elif args.command in {"drain", "unquarantine", "reconcile-now"}: print(json.dumps(data, indent=2)) elif args.command == "slots": if isinstance(data, list): try: policy = _get_effective_config(args.socket) except RuntimeError as err: print(str(err)) raise SystemExit(1) from err _print_slots(data, policy) else: _print_error(data) raise SystemExit(1) elif args.command == "reservations": if isinstance(data, list): _print_reservations(data) else: _print_error(data) raise SystemExit(1) raise SystemExit(0) if __name__ == "__main__": main()