Fix attaching to workers

This commit is contained in:
Abel Luck 2026-03-30 15:53:04 +02:00
parent 0c36ee6662
commit ec4bdf1096
3 changed files with 354 additions and 6 deletions

View file

@ -1,8 +1,11 @@
from __future__ import annotations
import json
import os
import signal
import subprocess
import sys
import time
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from pathlib import Path
@ -38,10 +41,43 @@ class JobArtifacts:
@dataclass
class RunningWorker:
    """In-memory record of a worker the runtime is currently tracking."""

    execution_id: int  # primary key of the JobExecution row this worker serves
    process: subprocess.Popen[str]
    # NOTE(review): the diff rendering shows both the old and the new
    # annotation for `process`; the union below is the current one and
    # supersedes the line above.
    process: subprocess.Popen[str] | DetachedProcess
    log_handle: TextIO  # open append handle on the execution's log file
    artifacts: JobArtifacts  # paths (log, stats) associated with this execution
    stats_offset: int = 0  # read position into the stats stream so far
    detached: bool = False  # True when reattached to a pre-existing worker pid
@dataclass(frozen=True)
class DetachedProcess:
    """Popen-compatible handle for a worker process this app did not spawn.

    Provides just the subset of the ``subprocess.Popen`` interface the
    runtime relies on (``poll``/``terminate``/``kill``/``wait``).  A real
    exit code is not observable for a reattached process, so a process
    that has gone away is reported as having exited with code 0.
    """

    pid: int  # OS process id of the detached worker

    def poll(self) -> int | None:
        """Return ``None`` while the process is alive, ``0`` once it is gone."""
        if _pid_is_running(self.pid):
            return None
        return 0

    def terminate(self) -> None:
        """Ask the worker to shut down (SIGTERM, best effort)."""
        _send_signal(self.pid, signal.SIGTERM)

    def kill(self) -> None:
        """Forcefully stop the worker (SIGKILL, best effort)."""
        _send_signal(self.pid, signal.SIGKILL)

    def wait(self, timeout: float | None = None) -> int:
        """Block until the process exits, polling at 50 ms intervals.

        Raises ``subprocess.TimeoutExpired`` when *timeout* (seconds)
        elapses before the process disappears.
        """
        give_up_at = None if timeout is None else time.monotonic() + timeout
        while True:
            status = self.poll()
            if status is not None:
                return status
            if give_up_at is not None and time.monotonic() >= give_up_at:
                raise subprocess.TimeoutExpired(
                    cmd=f"pid {self.pid}",
                    timeout=timeout if timeout is not None else 0.0,
                )
            time.sleep(0.05)
@dataclass(frozen=True)
class LiveWorker:
    """A still-running worker process discovered by scanning /proc.

    Produced by ``_find_live_workers``/``_parse_live_worker`` so the
    runtime can reattach to workers that survived an app restart.
    """

    job_id: int  # value of the worker's --job-id cmdline flag
    execution_id: int  # value of the worker's --execution-id cmdline flag
    pid: int  # OS process id of the worker
@dataclass(frozen=True)
@ -248,8 +284,9 @@ class JobRuntime:
with database.connection_context():
execution = JobExecution.get_by_id(execution_id)
execution.ended_at = utc_now()
execution.running_status = _final_status(
execution.running_status = _worker_final_status(
execution=execution,
worker=worker,
returncode=returncode,
)
execution.save()
@ -306,7 +343,51 @@ class JobRuntime:
self.refresh_callback()
def _reconcile_stale_executions(self) -> None:
live_workers = _find_live_workers()
recovered_execution_ids: set[int] = set()
with database.connection_context():
execution_primary_key = getattr(JobExecution, "_meta").primary_key
if live_workers:
live_executions = tuple(
JobExecution.select(JobExecution, Job)
.join(Job)
.where(execution_primary_key.in_(tuple(live_workers)))
)
else:
live_executions = ()
for execution in live_executions:
job = cast(Job, execution.job)
execution_id = _execution_id(execution)
live_worker = live_workers.get(execution_id)
if live_worker is None or live_worker.job_id != _job_id(job):
continue
artifacts = JobArtifacts.for_execution(
log_dir=self.log_dir,
job_id=live_worker.job_id,
execution_id=execution_id,
)
artifacts.log_path.parent.mkdir(parents=True, exist_ok=True)
if execution.running_status != JobExecutionStatus.RUNNING:
execution.running_status = JobExecutionStatus.RUNNING
execution.ended_at = None
execution.save()
message = f"scheduler: restored execution state from live worker pid {live_worker.pid} after app restart\n"
else:
message = f"scheduler: reattached to worker pid {live_worker.pid} after app restart\n"
log_handle = artifacts.log_path.open("a", encoding="utf-8", buffering=1)
log_handle.write(message)
self._workers[execution_id] = RunningWorker(
execution_id=execution_id,
process=DetachedProcess(pid=live_worker.pid),
log_handle=log_handle,
artifacts=artifacts,
detached=True,
)
recovered_execution_ids.add(execution_id)
stale_executions = tuple(
JobExecution.select(JobExecution, Job)
.join(Job)
@ -316,9 +397,12 @@ class JobRuntime:
for execution in stale_executions:
job = cast(Job, execution.job)
execution_id = _execution_id(execution)
if execution_id in recovered_execution_ids:
continue
job_id = _job_id(job)
artifacts = JobArtifacts.for_execution(
log_dir=self.log_dir,
job_id=_job_id(job),
job_id=job_id,
execution_id=execution_id,
)
artifacts.log_path.parent.mkdir(parents=True, exist_ok=True)
@ -335,7 +419,7 @@ class JobRuntime:
)
execution.save()
if stale_executions:
if stale_executions or recovered_execution_ids:
self._trigger_refresh()
@ -514,7 +598,7 @@ def _project_running_execution(
"worker": (
"graceful stop requested"
if execution.stop_requested_at
else "streaming stats from worker jsonl"
else "streaming stats from worker"
),
"log_href": f"/job/{job_id}/execution/{execution_id}/logs",
"log_exists": artifacts.log_path.exists(),
@ -675,6 +759,28 @@ def _final_status(*, execution: JobExecution, returncode: int) -> JobExecutionSt
return JobExecutionStatus.FAILED
def _worker_final_status(
    *,
    execution: JobExecution,
    worker: RunningWorker,
    returncode: int,
) -> JobExecutionStatus:
    """Resolve the terminal status for a finished worker.

    A worker we merely re-attached to after a restart has no trustworthy
    return code, so its outcome is derived from the execution record and
    the log tail instead of from *returncode*.
    """
    if not worker.detached:
        return _final_status(execution=execution, returncode=returncode)
    return _detached_final_status(execution=execution, artifacts=worker.artifacts)
def _detached_final_status(
    *, execution: JobExecution, artifacts: JobArtifacts
) -> JobExecutionStatus:
    """Infer the outcome of a detached worker without a real return code.

    A pending stop request counts as a cancellation; otherwise the log
    tail is checked for the success marker, and anything else is treated
    as a failure.
    """
    stop_was_requested = execution.stop_requested_at is not None
    if stop_was_requested:
        return JobExecutionStatus.CANCELED
    tail = _read_log_tail(artifacts.log_path)
    return (
        JobExecutionStatus.SUCCEEDED
        if "completed successfully" in tail
        else JobExecutionStatus.FAILED
    )
def _coerce_datetime(value: datetime | str) -> datetime:
if isinstance(value, datetime):
if value.tzinfo is None:
@ -745,3 +851,87 @@ def _humanize_relative_time(reference_time: datetime, target_time: datetime) ->
if delta_seconds > 0:
return f"in {absolute_delta_seconds} seconds"
return f"{absolute_delta_seconds} seconds ago"
def _find_live_workers() -> dict[int, LiveWorker]:
    """Scan /proc for still-running ``repub.job_runner`` worker processes.

    Returns a mapping of execution id -> LiveWorker for every process
    whose cmdline carries the worker flags and which is still alive.  On
    platforms without /proc the result is empty.
    """
    proc_root = Path("/proc")
    if not proc_root.exists():
        return {}
    found: dict[int, LiveWorker] = {}
    for cmdline_file in proc_root.glob("[0-9]*/cmdline"):
        try:
            raw = cmdline_file.read_bytes()
        except OSError:
            # Process vanished (or is unreadable) mid-scan; skip it.
            continue
        argv = [
            token
            for token in raw.decode("utf-8", errors="ignore").split("\x00")
            if token
        ]
        worker = _parse_live_worker(argv, pid=int(cmdline_file.parent.name))
        if worker is not None and _pid_is_running(worker.pid):
            found[worker.execution_id] = worker
    return found
def _parse_live_worker(argv: list[str], *, pid: int) -> LiveWorker | None:
    """Interpret a process cmdline as a job-runner worker, if it is one.

    Returns ``None`` when *argv* is not a ``repub.job_runner`` invocation,
    when the ``--job-id``/``--execution-id`` flags are missing, or when
    their values are not integers (e.g. an unrelated process — a grep or
    an editor — that merely mentions the module in its cmdline).
    """
    if "repub.job_runner" not in argv:
        return None
    job_id = _argv_flag_value(argv, "--job-id")
    execution_id = _argv_flag_value(argv, "--execution-id")
    if job_id is None or execution_id is None:
        return None
    try:
        # int() may raise on a foreign process's non-numeric flag values;
        # that must not abort the whole /proc scan during reconciliation.
        return LiveWorker(job_id=int(job_id), execution_id=int(execution_id), pid=pid)
    except ValueError:
        return None
def _argv_flag_value(argv: list[str], flag: str) -> str | None:
try:
index = argv.index(flag)
except ValueError:
return None
value_index = index + 1
if value_index >= len(argv):
return None
return argv[value_index]
def _pid_is_running(pid: int) -> bool:
try:
os.kill(pid, 0)
except ProcessLookupError:
return False
except PermissionError:
return True
stat_path = Path(f"/proc/{pid}/stat")
if stat_path.exists():
try:
stat_text = stat_path.read_text(encoding="utf-8")
except OSError:
return True
state_parts = stat_text.split(") ", 1)
if len(state_parts) == 2 and state_parts[1].startswith("Z"):
return False
return True
def _send_signal(pid: int, signum: int) -> None:
try:
os.kill(pid, signum)
except ProcessLookupError:
return
def _read_log_tail(path: Path, *, max_bytes: int = 8192) -> str:
if not path.exists():
return ""
with path.open("rb") as handle:
handle.seek(0, os.SEEK_END)
size = handle.tell()
handle.seek(max(0, size - max_bytes))
return handle.read().decode("utf-8", errors="replace")