Add remote autoscaler daemon endpoint support
This commit is contained in:
parent
95021a4253
commit
679b5c8d07
11 changed files with 291 additions and 22 deletions
|
|
@ -218,16 +218,23 @@ def main() -> None:
|
|||
reconciler_thread.start()
|
||||
metrics_thread.start()
|
||||
|
||||
socket_path = Path(config.server.socket_path)
|
||||
socket_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if socket_path.exists():
|
||||
socket_path.unlink()
|
||||
|
||||
uvicorn_config = uvicorn.Config(
|
||||
app=app,
|
||||
uds=config.server.socket_path,
|
||||
log_level=config.server.log_level.lower(),
|
||||
)
|
||||
if config.server.listen_port > 0:
|
||||
uvicorn_config = uvicorn.Config(
|
||||
app=app,
|
||||
host=config.server.listen_host,
|
||||
port=config.server.listen_port,
|
||||
log_level=config.server.log_level.lower(),
|
||||
)
|
||||
else:
|
||||
socket_path = Path(config.server.socket_path)
|
||||
socket_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if socket_path.exists():
|
||||
socket_path.unlink()
|
||||
uvicorn_config = uvicorn.Config(
|
||||
app=app,
|
||||
uds=config.server.socket_path,
|
||||
log_level=config.server.log_level.lower(),
|
||||
)
|
||||
server = uvicorn.Server(uvicorn_config)
|
||||
|
||||
def _handle_signal(signum: int, _: FrameType | None) -> None:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import hmac
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
|
|
@ -118,6 +119,8 @@ def create_app(
|
|||
app.state.runtime = runtime
|
||||
app.state.haproxy = haproxy
|
||||
|
||||
auth_token = config.server.auth_token.strip()
|
||||
|
||||
@app.middleware("http")
|
||||
async def request_id_middleware(request: Request, call_next: Callable) -> Response:
|
||||
request.state.request_id = str(uuid.uuid4())
|
||||
|
|
@ -125,6 +128,25 @@ def create_app(
|
|||
response.headers["x-request-id"] = request.state.request_id
|
||||
return response
|
||||
|
||||
@app.middleware("http")
|
||||
async def auth_middleware(request: Request, call_next: Callable) -> Response:
|
||||
path = request.url.path
|
||||
if auth_token != "" and (path.startswith("/v1/") or path == "/metrics"):
|
||||
expected = f"Bearer {auth_token}"
|
||||
provided = request.headers.get("authorization", "")
|
||||
if not hmac.compare_digest(provided, expected):
|
||||
request_id = getattr(request.state, "request_id", str(uuid.uuid4()))
|
||||
payload = ErrorResponse(
|
||||
error=ErrorDetail(
|
||||
code="unauthorized",
|
||||
message="Missing or invalid bearer token",
|
||||
retryable=False,
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
return JSONResponse(status_code=401, content=payload.model_dump(mode="json"))
|
||||
return await call_next(request)
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
|
||||
detail = exc.detail
|
||||
|
|
|
|||
|
|
@ -13,6 +13,9 @@ class ServerConfig:
|
|||
"""[server] section."""
|
||||
|
||||
socket_path: str = "/run/nix-builder-autoscaler/daemon.sock"
|
||||
listen_host: str = "127.0.0.1"
|
||||
listen_port: int = 0
|
||||
auth_token: str = ""
|
||||
log_level: str = "info"
|
||||
db_path: str = "/var/lib/nix-builder-autoscaler/state.db"
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from typing import Any
|
|||
from fastapi.testclient import TestClient
|
||||
|
||||
from nix_builder_autoscaler.api import create_app
|
||||
from nix_builder_autoscaler.config import AppConfig, CapacityConfig
|
||||
from nix_builder_autoscaler.config import AppConfig, CapacityConfig, ServerConfig
|
||||
from nix_builder_autoscaler.metrics import MetricsRegistry
|
||||
from nix_builder_autoscaler.models import SlotState
|
||||
from nix_builder_autoscaler.providers.clock import FakeClock
|
||||
|
|
@ -18,12 +18,16 @@ from nix_builder_autoscaler.state_db import StateDB
|
|||
def _make_client(
|
||||
*,
|
||||
reconcile_now: Any = None, # noqa: ANN401
|
||||
auth_token: str = "",
|
||||
) -> tuple[TestClient, StateDB, FakeClock, MetricsRegistry]:
|
||||
clock = FakeClock()
|
||||
db = StateDB(":memory:", clock=clock)
|
||||
db.init_schema()
|
||||
db.init_slots("slot", 3, "x86_64-linux", "all")
|
||||
config = AppConfig(capacity=CapacityConfig(reservation_ttl_seconds=1200))
|
||||
config = AppConfig(
|
||||
server=ServerConfig(auth_token=auth_token),
|
||||
capacity=CapacityConfig(reservation_ttl_seconds=1200),
|
||||
)
|
||||
metrics = MetricsRegistry()
|
||||
app = create_app(db, config, clock, metrics, reconcile_now=reconcile_now)
|
||||
return TestClient(app), db, clock, metrics
|
||||
|
|
@ -245,3 +249,20 @@ def test_admin_reconcile_now_success() -> None:
|
|||
assert response.json()["status"] == "accepted"
|
||||
assert response.json()["triggered"] is True
|
||||
assert called["value"] is True
|
||||
|
||||
|
||||
def test_auth_token_required_for_v1_when_configured() -> None:
|
||||
client, _, _, _ = _make_client(auth_token="test-token")
|
||||
response = client.post("/v1/reservations", json={"system": "x86_64-linux", "reason": "test"})
|
||||
assert response.status_code == 401
|
||||
assert response.json()["error"]["code"] == "unauthorized"
|
||||
|
||||
|
||||
def test_auth_token_allows_v1_when_header_matches() -> None:
|
||||
client, _, _, _ = _make_client(auth_token="test-token")
|
||||
response = client.post(
|
||||
"/v1/reservations",
|
||||
json={"system": "x86_64-linux", "reason": "test"},
|
||||
headers={"Authorization": "Bearer test-token"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue