fix attach workers
This commit is contained in:
parent
0c36ee6662
commit
ec4bdf1096
3 changed files with 354 additions and 6 deletions
200
repub/jobs.py
200
repub/jobs.py
|
|
@ -1,8 +1,11 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
|
@ -38,10 +41,43 @@ class JobArtifacts:
|
|||
@dataclass
class RunningWorker:
    """Bookkeeping for a worker process supervised by the scheduler.

    ``process`` is either a real ``subprocess.Popen`` (spawned by this
    scheduler instance) or a ``DetachedProcess`` shim for a worker that was
    reattached by pid after an app restart (see ``detached``).
    """

    # Primary key of the JobExecution row this worker is running.
    execution_id: int
    # Live child handle, or a pid-only shim when reattached after restart.
    process: subprocess.Popen[str] | DetachedProcess
    # Open append handle to the execution's log file.
    log_handle: TextIO
    # Log/stats paths associated with this execution.
    artifacts: JobArtifacts
    # Offset of stats output already consumed — assumed by name; confirm
    # against the stats-streaming code outside this view.
    stats_offset: int = 0
    # True when ``process`` is a DetachedProcess recovered from /proc.
    detached: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DetachedProcess:
    """Popen-compatible shim for a worker process we did not spawn.

    After an app restart the scheduler reattaches to still-running workers
    by pid only.  This class mimics the subset of the ``subprocess.Popen``
    interface the runtime uses (``poll``, ``terminate``, ``kill``,
    ``wait``).  The real exit status of a non-child process is not
    observable, so ``poll``/``wait`` report a placeholder ``0`` once the
    pid is gone.
    """

    pid: int

    def poll(self) -> int | None:
        """Return ``None`` while the process is alive, else a placeholder 0."""
        return None if _pid_is_running(self.pid) else 0

    def terminate(self) -> None:
        """Ask the process to exit (SIGTERM); no-op if it is already gone."""
        _send_signal(self.pid, signal.SIGTERM)

    def kill(self) -> None:
        """Force-kill the process (SIGKILL); no-op if it is already gone."""
        _send_signal(self.pid, signal.SIGKILL)

    def wait(self, timeout: float | None = None) -> int:
        """Poll until the pid disappears, then return 0.

        Raises ``subprocess.TimeoutExpired`` if *timeout* seconds elapse
        first; waits indefinitely when *timeout* is ``None``.
        """
        deadline = None if timeout is None else time.monotonic() + timeout
        while self.poll() is None:
            if deadline is not None and time.monotonic() >= deadline:
                raise subprocess.TimeoutExpired(
                    cmd=f"pid {self.pid}",
                    timeout=timeout if timeout is not None else 0.0,
                )
            # Cheap busy-wait; 50 ms keeps reattached-worker shutdown snappy
            # without burning CPU.
            time.sleep(0.05)
        return 0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LiveWorker:
    """A ``repub.job_runner`` process discovered in /proc after an app restart."""

    # Primary keys of the Job / JobExecution rows parsed from the worker's argv.
    job_id: int
    execution_id: int
    # OS process id the worker is running under.
    pid: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
@ -248,8 +284,9 @@ class JobRuntime:
|
|||
with database.connection_context():
|
||||
execution = JobExecution.get_by_id(execution_id)
|
||||
execution.ended_at = utc_now()
|
||||
execution.running_status = _final_status(
|
||||
execution.running_status = _worker_final_status(
|
||||
execution=execution,
|
||||
worker=worker,
|
||||
returncode=returncode,
|
||||
)
|
||||
execution.save()
|
||||
|
|
@ -306,7 +343,51 @@ class JobRuntime:
|
|||
self.refresh_callback()
|
||||
|
||||
def _reconcile_stale_executions(self) -> None:
|
||||
live_workers = _find_live_workers()
|
||||
recovered_execution_ids: set[int] = set()
|
||||
with database.connection_context():
|
||||
execution_primary_key = getattr(JobExecution, "_meta").primary_key
|
||||
if live_workers:
|
||||
live_executions = tuple(
|
||||
JobExecution.select(JobExecution, Job)
|
||||
.join(Job)
|
||||
.where(execution_primary_key.in_(tuple(live_workers)))
|
||||
)
|
||||
else:
|
||||
live_executions = ()
|
||||
|
||||
for execution in live_executions:
|
||||
job = cast(Job, execution.job)
|
||||
execution_id = _execution_id(execution)
|
||||
live_worker = live_workers.get(execution_id)
|
||||
if live_worker is None or live_worker.job_id != _job_id(job):
|
||||
continue
|
||||
|
||||
artifacts = JobArtifacts.for_execution(
|
||||
log_dir=self.log_dir,
|
||||
job_id=live_worker.job_id,
|
||||
execution_id=execution_id,
|
||||
)
|
||||
artifacts.log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if execution.running_status != JobExecutionStatus.RUNNING:
|
||||
execution.running_status = JobExecutionStatus.RUNNING
|
||||
execution.ended_at = None
|
||||
execution.save()
|
||||
message = f"scheduler: restored execution state from live worker pid {live_worker.pid} after app restart\n"
|
||||
else:
|
||||
message = f"scheduler: reattached to worker pid {live_worker.pid} after app restart\n"
|
||||
|
||||
log_handle = artifacts.log_path.open("a", encoding="utf-8", buffering=1)
|
||||
log_handle.write(message)
|
||||
self._workers[execution_id] = RunningWorker(
|
||||
execution_id=execution_id,
|
||||
process=DetachedProcess(pid=live_worker.pid),
|
||||
log_handle=log_handle,
|
||||
artifacts=artifacts,
|
||||
detached=True,
|
||||
)
|
||||
recovered_execution_ids.add(execution_id)
|
||||
|
||||
stale_executions = tuple(
|
||||
JobExecution.select(JobExecution, Job)
|
||||
.join(Job)
|
||||
|
|
@ -316,9 +397,12 @@ class JobRuntime:
|
|||
for execution in stale_executions:
|
||||
job = cast(Job, execution.job)
|
||||
execution_id = _execution_id(execution)
|
||||
if execution_id in recovered_execution_ids:
|
||||
continue
|
||||
job_id = _job_id(job)
|
||||
artifacts = JobArtifacts.for_execution(
|
||||
log_dir=self.log_dir,
|
||||
job_id=_job_id(job),
|
||||
job_id=job_id,
|
||||
execution_id=execution_id,
|
||||
)
|
||||
artifacts.log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -335,7 +419,7 @@ class JobRuntime:
|
|||
)
|
||||
execution.save()
|
||||
|
||||
if stale_executions:
|
||||
if stale_executions or recovered_execution_ids:
|
||||
self._trigger_refresh()
|
||||
|
||||
|
||||
|
|
@ -514,7 +598,7 @@ def _project_running_execution(
|
|||
"worker": (
|
||||
"graceful stop requested"
|
||||
if execution.stop_requested_at
|
||||
else "streaming stats from worker jsonl"
|
||||
else "streaming stats from worker"
|
||||
),
|
||||
"log_href": f"/job/{job_id}/execution/{execution_id}/logs",
|
||||
"log_exists": artifacts.log_path.exists(),
|
||||
|
|
@ -675,6 +759,28 @@ def _final_status(*, execution: JobExecution, returncode: int) -> JobExecutionSt
|
|||
return JobExecutionStatus.FAILED
|
||||
|
||||
|
||||
def _worker_final_status(
    *,
    execution: JobExecution,
    worker: RunningWorker,
    returncode: int,
) -> JobExecutionStatus:
    """Resolve the terminal status for a worker that has stopped.

    A detached worker (reattached after an app restart) is not our child,
    so its ``returncode`` is a meaningless placeholder; its status is
    inferred from the execution row and log tail instead.
    """
    if worker.detached:
        return _detached_final_status(execution=execution, artifacts=worker.artifacts)
    return _final_status(execution=execution, returncode=returncode)
|
||||
|
||||
|
||||
def _detached_final_status(
    *, execution: JobExecution, artifacts: JobArtifacts
) -> JobExecutionStatus:
    """Infer a detached worker's terminal status without a real returncode.

    A pending stop request means the exit was a cancellation.  Otherwise
    the log tail is scanned for a success marker — assumes the worker
    writes "completed successfully" on clean exit; confirm against
    ``repub.job_runner`` output.
    """
    if execution.stop_requested_at is not None:
        return JobExecutionStatus.CANCELED
    log_tail = _read_log_tail(artifacts.log_path)
    if "completed successfully" in log_tail:
        return JobExecutionStatus.SUCCEEDED
    return JobExecutionStatus.FAILED
|
||||
|
||||
|
||||
def _coerce_datetime(value: datetime | str) -> datetime:
|
||||
if isinstance(value, datetime):
|
||||
if value.tzinfo is None:
|
||||
|
|
@ -745,3 +851,87 @@ def _humanize_relative_time(reference_time: datetime, target_time: datetime) ->
|
|||
if delta_seconds > 0:
|
||||
return f"in {absolute_delta_seconds} seconds"
|
||||
return f"{absolute_delta_seconds} seconds ago"
|
||||
|
||||
|
||||
def _find_live_workers() -> dict[int, LiveWorker]:
    """Scan /proc for running ``repub.job_runner`` worker processes.

    Returns a mapping of ``execution_id`` -> ``LiveWorker``.  On hosts
    without /proc (non-Linux) an empty mapping is returned.  Processes
    that disappear between the glob and the cmdline read are skipped.
    """
    proc_dir = Path("/proc")
    if not proc_dir.exists():
        return {}

    live_workers: dict[int, LiveWorker] = {}
    for cmdline_path in proc_dir.glob("[0-9]*/cmdline"):
        try:
            raw_cmdline = cmdline_path.read_bytes()
        except OSError:
            # Process exited (or became unreadable) mid-scan.
            continue

        # /proc/<pid>/cmdline is NUL-separated; drop empty tokens.
        argv = [
            part
            for part in raw_cmdline.decode("utf-8", errors="ignore").split("\x00")
            if part != ""
        ]

        live_worker = _parse_live_worker(argv, pid=int(cmdline_path.parent.name))
        if live_worker is None or not _pid_is_running(live_worker.pid):
            continue

        live_workers[live_worker.execution_id] = live_worker
    return live_workers
|
||||
|
||||
|
||||
def _parse_live_worker(argv: list[str], *, pid: int) -> LiveWorker | None:
    """Build a ``LiveWorker`` from a process argv, or ``None`` if not a worker.

    A worker is identified by the ``repub.job_runner`` module name plus
    ``--job-id`` and ``--execution-id`` flags in its command line.
    """
    if "repub.job_runner" not in argv:
        return None
    raw_job_id = _argv_flag_value(argv, "--job-id")
    raw_execution_id = _argv_flag_value(argv, "--execution-id")
    if raw_job_id is None or raw_execution_id is None:
        return None
    return LiveWorker(
        job_id=int(raw_job_id),
        execution_id=int(raw_execution_id),
        pid=pid,
    )
|
||||
|
||||
|
||||
def _argv_flag_value(argv: list[str], flag: str) -> str | None:
|
||||
try:
|
||||
index = argv.index(flag)
|
||||
except ValueError:
|
||||
return None
|
||||
value_index = index + 1
|
||||
if value_index >= len(argv):
|
||||
return None
|
||||
return argv[value_index]
|
||||
|
||||
|
||||
def _pid_is_running(pid: int) -> bool:
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
except ProcessLookupError:
|
||||
return False
|
||||
except PermissionError:
|
||||
return True
|
||||
stat_path = Path(f"/proc/{pid}/stat")
|
||||
if stat_path.exists():
|
||||
try:
|
||||
stat_text = stat_path.read_text(encoding="utf-8")
|
||||
except OSError:
|
||||
return True
|
||||
state_parts = stat_text.split(") ", 1)
|
||||
if len(state_parts) == 2 and state_parts[1].startswith("Z"):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _send_signal(pid: int, signum: int) -> None:
|
||||
try:
|
||||
os.kill(pid, signum)
|
||||
except ProcessLookupError:
|
||||
return
|
||||
|
||||
|
||||
def _read_log_tail(path: Path, *, max_bytes: int = 8192) -> str:
|
||||
if not path.exists():
|
||||
return ""
|
||||
|
||||
with path.open("rb") as handle:
|
||||
handle.seek(0, os.SEEK_END)
|
||||
size = handle.tell()
|
||||
handle.seek(max(0, size - max_bytes))
|
||||
return handle.read().decode("utf-8", errors="replace")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue