fix: reattach live workers after app restart

Abel Luck 2026-03-30 15:53:04 +02:00
parent 0c36ee6662
commit ec4bdf1096
3 changed files with 354 additions and 6 deletions

@@ -1,8 +1,11 @@
from __future__ import annotations
import json
import os
import signal
import subprocess
import sys
import time
from dataclasses import dataclass
from datetime import UTC, datetime, timedelta
from pathlib import Path
@@ -38,10 +41,43 @@ class JobArtifacts:
@dataclass
class RunningWorker:
execution_id: int
process: subprocess.Popen[str]
process: subprocess.Popen[str] | DetachedProcess
log_handle: TextIO
artifacts: JobArtifacts
stats_offset: int = 0
detached: bool = False
@dataclass(frozen=True)
class DetachedProcess:
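"""Popen-compatible handle for a worker process this runtime did not spawn.
Liveness is the only thing observable through the pid, so wait() reports
exit code 0 and the real outcome is recovered from the log tail later.
"""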
pid: int
def poll(self) -> int | None:
return None if _pid_is_running(self.pid) else 0
def terminate(self) -> None:
_send_signal(self.pid, signal.SIGTERM)
def kill(self) -> None:
_send_signal(self.pid, signal.SIGKILL)
def wait(self, timeout: float | None = None) -> int:
deadline = None if timeout is None else time.monotonic() + timeout
while self.poll() is None:
if deadline is not None and time.monotonic() >= deadline:
raise subprocess.TimeoutExpired(
cmd=f"pid {self.pid}",
timeout=timeout if timeout is not None else 0.0,
)
time.sleep(0.05)
return 0
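# Identity of a worker process discovered in /proc by _find_live_workers.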
@dataclass(frozen=True)
class LiveWorker:
job_id: int
execution_id: int
pid: int
@dataclass(frozen=True)
@@ -248,8 +284,9 @@ class JobRuntime:
with database.connection_context():
execution = JobExecution.get_by_id(execution_id)
execution.ended_at = utc_now()
execution.running_status = _final_status(
execution.running_status = _worker_final_status(
execution=execution,
worker=worker,
returncode=returncode,
)
execution.save()
@@ -306,7 +343,51 @@ class JobRuntime:
self.refresh_callback()
def _reconcile_stale_executions(self) -> None:
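# Scan /proc for worker processes that survived an app restart and reattach
# to them before sweeping the remaining RUNNING executions as stale.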
live_workers = _find_live_workers()
recovered_execution_ids: set[int] = set()
with database.connection_context():
execution_primary_key = getattr(JobExecution, "_meta").primary_key
if live_workers:
live_executions = tuple(
JobExecution.select(JobExecution, Job)
.join(Job)
.where(execution_primary_key.in_(tuple(live_workers)))
)
else:
live_executions = ()
for execution in live_executions:
job = cast(Job, execution.job)
execution_id = _execution_id(execution)
live_worker = live_workers.get(execution_id)
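# Only reattach when the discovered worker's --job-id matches this
# execution's job.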
if live_worker is None or live_worker.job_id != _job_id(job):
continue
artifacts = JobArtifacts.for_execution(
log_dir=self.log_dir,
job_id=live_worker.job_id,
execution_id=execution_id,
)
artifacts.log_path.parent.mkdir(parents=True, exist_ok=True)
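# A restart may have already marked this execution as terminal while its
# worker stayed alive; flip it back to RUNNING before reattaching.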
if execution.running_status != JobExecutionStatus.RUNNING:
execution.running_status = JobExecutionStatus.RUNNING
execution.ended_at = None
execution.save()
message = f"scheduler: restored execution state from live worker pid {live_worker.pid} after app restart\n"
else:
message = f"scheduler: reattached to worker pid {live_worker.pid} after app restart\n"
log_handle = artifacts.log_path.open("a", encoding="utf-8", buffering=1)
log_handle.write(message)
self._workers[execution_id] = RunningWorker(
execution_id=execution_id,
process=DetachedProcess(pid=live_worker.pid),
log_handle=log_handle,
artifacts=artifacts,
detached=True,
)
recovered_execution_ids.add(execution_id)
stale_executions = tuple(
JobExecution.select(JobExecution, Job)
.join(Job)
@@ -316,9 +397,12 @@ class JobRuntime:
for execution in stale_executions:
job = cast(Job, execution.job)
execution_id = _execution_id(execution)
if execution_id in recovered_execution_ids:
continue
job_id = _job_id(job)
artifacts = JobArtifacts.for_execution(
log_dir=self.log_dir,
job_id=_job_id(job),
job_id=job_id,
execution_id=execution_id,
)
artifacts.log_path.parent.mkdir(parents=True, exist_ok=True)
@@ -335,7 +419,7 @@
)
execution.save()
if stale_executions:
if stale_executions or recovered_execution_ids:
self._trigger_refresh()
@@ -514,7 +598,7 @@ def _project_running_execution(
"worker": (
"graceful stop requested"
if execution.stop_requested_at
else "streaming stats from worker jsonl"
else "streaming stats from worker"
),
"log_href": f"/job/{job_id}/execution/{execution_id}/logs",
"log_exists": artifacts.log_path.exists(),
@@ -675,6 +759,28 @@ def _final_status(*, execution: JobExecution, returncode: int) -> JobExecutionStatus:
return JobExecutionStatus.FAILED
def _worker_final_status(
*,
execution: JobExecution,
worker: RunningWorker,
returncode: int,
) -> JobExecutionStatus:
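# Detached workers expose no meaningful returncode, so derive the final
# status from the stop request and the log tail instead.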
if worker.detached:
return _detached_final_status(execution=execution, artifacts=worker.artifacts)
return _final_status(execution=execution, returncode=returncode)
def _detached_final_status(
*, execution: JobExecution, artifacts: JobArtifacts
) -> JobExecutionStatus:
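# No real exit code is available: an explicit stop request maps to CANCELED,
# a success marker in the log tail to SUCCEEDED, anything else to FAILED.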
if execution.stop_requested_at is not None:
return JobExecutionStatus.CANCELED
log_tail = _read_log_tail(artifacts.log_path)
if "completed successfully" in log_tail:
return JobExecutionStatus.SUCCEEDED
return JobExecutionStatus.FAILED
def _coerce_datetime(value: datetime | str) -> datetime:
if isinstance(value, datetime):
if value.tzinfo is None:
@@ -745,3 +851,87 @@ def _humanize_relative_time(reference_time: datetime, target_time: datetime) -> str:
if delta_seconds > 0:
return f"in {absolute_delta_seconds} seconds"
return f"{absolute_delta_seconds} seconds ago"
def _find_live_workers() -> dict[int, LiveWorker]:
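# Enumerate repub.job_runner processes via /proc cmdlines and index them by
# execution id; returns an empty mapping on platforms without /proc.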
proc_dir = Path("/proc")
if not proc_dir.exists():
return {}
live_workers: dict[int, LiveWorker] = {}
for cmdline_path in proc_dir.glob("[0-9]*/cmdline"):
try:
argv = [
part
for part in cmdline_path.read_bytes()
.decode("utf-8", errors="ignore")
.split("\x00")
if part != ""
]
except OSError:
continue
live_worker = _parse_live_worker(argv, pid=int(cmdline_path.parent.name))
if live_worker is None or not _pid_is_running(live_worker.pid):
continue
live_workers[live_worker.execution_id] = live_worker
return live_workers
def _parse_live_worker(argv: list[str], *, pid: int) -> LiveWorker | None:
if "repub.job_runner" not in argv:
return None
job_id = _argv_flag_value(argv, "--job-id")
execution_id = _argv_flag_value(argv, "--execution-id")
if job_id is None or execution_id is None:
return None
return LiveWorker(job_id=int(job_id), execution_id=int(execution_id), pid=pid)
def _argv_flag_value(argv: list[str], flag: str) -> str | None:
try:
index = argv.index(flag)
except ValueError:
return None
value_index = index + 1
if value_index >= len(argv):
return None
return argv[value_index]
def _pid_is_running(pid: int) -> bool:
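# Signal 0 probes for existence without delivering anything; PermissionError
# means the pid exists under another user. On Linux, a "Z" state in
# /proc/<pid>/stat marks a zombie, which we treat as not running.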
try:
os.kill(pid, 0)
except ProcessLookupError:
return False
except PermissionError:
return True
stat_path = Path(f"/proc/{pid}/stat")
if stat_path.exists():
try:
stat_text = stat_path.read_text(encoding="utf-8")
except OSError:
return True
state_parts = stat_text.split(") ", 1)
if len(state_parts) == 2 and state_parts[1].startswith("Z"):
return False
return True
def _send_signal(pid: int, signum: int) -> None:
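# Best-effort delivery: the worker may have already exited.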
try:
os.kill(pid, signum)
except ProcessLookupError:
return
def _read_log_tail(path: Path, *, max_bytes: int = 8192) -> str:
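# Read at most max_bytes from the end of the log; used for final-status
# detection on detached workers.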
if not path.exists():
return ""
with path.open("rb") as handle:
handle.seek(0, os.SEEK_END)
size = handle.tell()
handle.seek(max(0, size - max_bytes))
return handle.read().decode("utf-8", errors="replace")

@@ -3,6 +3,8 @@ from __future__ import annotations
import asyncio
import json
import socketserver
import subprocess
import sys
import threading
import time
from datetime import UTC, datetime, timedelta
@@ -226,6 +228,161 @@ def test_job_runtime_start_reconciles_stale_running_execution(tmp_path: Path) -> None:
runtime.shutdown()
def test_job_runtime_start_reattaches_live_worker_after_app_restart(
tmp_path: Path,
) -> None:
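# Launch a worker subprocess outside the runtime, then verify that start()
# reattaches to the live pid and records a successful finish.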
db_path = tmp_path / "live-worker.db"
log_dir = tmp_path / "out" / "logs"
initialize_database(db_path)
with _slow_feed_server() as feed_url:
source = create_source(
name="Live worker source",
slug="live-worker-source",
source_type="feed",
notes="",
spider_arguments="",
enabled=False,
cron_minute="*/5",
cron_hour="*",
cron_day_of_month="*",
cron_day_of_week="*",
cron_month="*",
feed_url=feed_url,
)
job = Job.get(Job.source == source)
execution = JobExecution.create(
job=job,
started_at=datetime.now(UTC),
running_status=JobExecutionStatus.RUNNING,
)
artifacts = JobArtifacts.for_execution(
log_dir=log_dir,
job_id=job.id,
execution_id=int(execution.get_id()),
)
artifacts.log_path.parent.mkdir(parents=True, exist_ok=True)
log_handle = artifacts.log_path.open("a", encoding="utf-8", buffering=1)
process = subprocess.Popen(
[
sys.executable,
"-u",
"-m",
"repub.job_runner",
"--job-id",
str(job.id),
"--execution-id",
str(execution.get_id()),
"--db-path",
str(db_path),
"--out-dir",
str(log_dir.parent),
"--stats-path",
str(artifacts.stats_path),
],
stdout=log_handle,
stderr=subprocess.STDOUT,
text=True,
)
runtime = JobRuntime(log_dir=log_dir)
try:
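# Brief pause so the worker is visible in /proc before the runtime scans it.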
time.sleep(0.1)
runtime.start()
running_execution = JobExecution.get_by_id(execution.get_id())
assert running_execution.running_status == JobExecutionStatus.RUNNING
assert running_execution.ended_at is None
completed_execution = _wait_for_terminal_execution(int(execution.get_id()))
assert completed_execution.running_status == JobExecutionStatus.SUCCEEDED
assert "reattached" in artifacts.log_path.read_text(encoding="utf-8")
finally:
runtime.shutdown()
if process.poll() is None:
process.kill()
process.wait(timeout=2)
log_handle.close()
def test_job_runtime_start_restores_live_worker_marked_failed_by_restart_bug(
tmp_path: Path,
) -> None:
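# Simulate the restart bug: the row is already FAILED while the worker is
# still alive; start() should restore RUNNING and let it finish cleanly.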
db_path = tmp_path / "restore-live-worker.db"
log_dir = tmp_path / "out" / "logs"
initialize_database(db_path)
with _slow_feed_server() as feed_url:
source = create_source(
name="Recovered worker source",
slug="recovered-worker-source",
source_type="feed",
notes="",
spider_arguments="",
enabled=False,
cron_minute="*/5",
cron_hour="*",
cron_day_of_month="*",
cron_day_of_week="*",
cron_month="*",
feed_url=feed_url,
)
job = Job.get(Job.source == source)
execution = JobExecution.create(
job=job,
started_at=datetime.now(UTC),
ended_at=datetime.now(UTC),
running_status=JobExecutionStatus.FAILED,
)
artifacts = JobArtifacts.for_execution(
log_dir=log_dir,
job_id=job.id,
execution_id=int(execution.get_id()),
)
artifacts.log_path.parent.mkdir(parents=True, exist_ok=True)
log_handle = artifacts.log_path.open("a", encoding="utf-8", buffering=1)
process = subprocess.Popen(
[
sys.executable,
"-u",
"-m",
"repub.job_runner",
"--job-id",
str(job.id),
"--execution-id",
str(execution.get_id()),
"--db-path",
str(db_path),
"--out-dir",
str(log_dir.parent),
"--stats-path",
str(artifacts.stats_path),
],
stdout=log_handle,
stderr=subprocess.STDOUT,
text=True,
)
runtime = JobRuntime(log_dir=log_dir)
try:
time.sleep(0.1)
runtime.start()
restored_execution = JobExecution.get_by_id(execution.get_id())
assert restored_execution.running_status == JobExecutionStatus.RUNNING
assert restored_execution.ended_at is None
completed_execution = _wait_for_terminal_execution(int(execution.get_id()))
assert completed_execution.running_status == JobExecutionStatus.SUCCEEDED
assert "restored execution state" in artifacts.log_path.read_text(
encoding="utf-8"
)
finally:
runtime.shutdown()
if process.poll() is None:
process.kill()
process.wait(timeout=2)
log_handle.close()
def test_generate_pangea_feed_writes_pangea_rss_file(
monkeypatch, tmp_path: Path
) -> None:

@@ -440,7 +440,8 @@ def test_render_create_source_shows_dedicated_form_page() -> None:
not in body
)
assert "language=en,download_media=true" not in body
assert "language=en\ndownload_media=true" in body
assert 'id="spider-arguments"' in body
assert "language=en\ndownload_media=true" not in body
assert 'value="articles"' in body
assert 'value="10"' in body
assert 'value="3"' in body