fix attaching to live workers after app restart

Abel Luck 2026-03-30 15:53:04 +02:00
parent 0c36ee6662
commit ec4bdf1096
3 changed files with 354 additions and 6 deletions

View file

@@ -1,8 +1,11 @@
 from __future__ import annotations
 
 import json
+import os
+import signal
 import subprocess
 import sys
+import time
 from dataclasses import dataclass
 from datetime import UTC, datetime, timedelta
 from pathlib import Path
@@ -38,10 +41,43 @@ class JobArtifacts:
 @dataclass
 class RunningWorker:
     execution_id: int
-    process: subprocess.Popen[str]
+    process: subprocess.Popen[str] | DetachedProcess
     log_handle: TextIO
     artifacts: JobArtifacts
     stats_offset: int = 0
+    detached: bool = False
 
 
+@dataclass(frozen=True)
+class DetachedProcess:
+    pid: int
+
+    def poll(self) -> int | None:
+        return None if _pid_is_running(self.pid) else 0
+
+    def terminate(self) -> None:
+        _send_signal(self.pid, signal.SIGTERM)
+
+    def kill(self) -> None:
+        _send_signal(self.pid, signal.SIGKILL)
+
+    def wait(self, timeout: float | None = None) -> int:
+        deadline = None if timeout is None else time.monotonic() + timeout
+        while self.poll() is None:
+            if deadline is not None and time.monotonic() >= deadline:
+                raise subprocess.TimeoutExpired(
+                    cmd=f"pid {self.pid}",
+                    timeout=timeout if timeout is not None else 0.0,
+                )
+            time.sleep(0.05)
+        return 0
+
+
+@dataclass(frozen=True)
+class LiveWorker:
+    job_id: int
+    execution_id: int
+    pid: int
+
+
 @dataclass(frozen=True)
@@ -248,8 +284,9 @@ class JobRuntime:
         with database.connection_context():
             execution = JobExecution.get_by_id(execution_id)
             execution.ended_at = utc_now()
-            execution.running_status = _final_status(
+            execution.running_status = _worker_final_status(
                 execution=execution,
+                worker=worker,
                 returncode=returncode,
             )
             execution.save()
@@ -306,7 +343,51 @@ class JobRuntime:
         self.refresh_callback()
 
     def _reconcile_stale_executions(self) -> None:
+        live_workers = _find_live_workers()
+        recovered_execution_ids: set[int] = set()
+
         with database.connection_context():
+            execution_primary_key = getattr(JobExecution, "_meta").primary_key
+            if live_workers:
+                live_executions = tuple(
+                    JobExecution.select(JobExecution, Job)
+                    .join(Job)
+                    .where(execution_primary_key.in_(tuple(live_workers)))
+                )
+            else:
+                live_executions = ()
+
+            for execution in live_executions:
+                job = cast(Job, execution.job)
+                execution_id = _execution_id(execution)
+                live_worker = live_workers.get(execution_id)
+                if live_worker is None or live_worker.job_id != _job_id(job):
+                    continue
+                artifacts = JobArtifacts.for_execution(
+                    log_dir=self.log_dir,
+                    job_id=live_worker.job_id,
+                    execution_id=execution_id,
+                )
+                artifacts.log_path.parent.mkdir(parents=True, exist_ok=True)
+                if execution.running_status != JobExecutionStatus.RUNNING:
+                    execution.running_status = JobExecutionStatus.RUNNING
+                    execution.ended_at = None
+                    execution.save()
+                    message = f"scheduler: restored execution state from live worker pid {live_worker.pid} after app restart\n"
+                else:
+                    message = f"scheduler: reattached to worker pid {live_worker.pid} after app restart\n"
+                log_handle = artifacts.log_path.open("a", encoding="utf-8", buffering=1)
+                log_handle.write(message)
+                self._workers[execution_id] = RunningWorker(
+                    execution_id=execution_id,
+                    process=DetachedProcess(pid=live_worker.pid),
+                    log_handle=log_handle,
+                    artifacts=artifacts,
+                    detached=True,
+                )
+                recovered_execution_ids.add(execution_id)
+
             stale_executions = tuple(
                 JobExecution.select(JobExecution, Job)
                 .join(Job)
@@ -316,9 +397,12 @@ class JobRuntime:
             for execution in stale_executions:
                 job = cast(Job, execution.job)
                 execution_id = _execution_id(execution)
+                if execution_id in recovered_execution_ids:
+                    continue
+                job_id = _job_id(job)
                 artifacts = JobArtifacts.for_execution(
                     log_dir=self.log_dir,
-                    job_id=_job_id(job),
+                    job_id=job_id,
                     execution_id=execution_id,
                 )
                 artifacts.log_path.parent.mkdir(parents=True, exist_ok=True)
@@ -335,7 +419,7 @@ class JobRuntime:
                 )
                 execution.save()
 
-        if stale_executions:
+        if stale_executions or recovered_execution_ids:
             self._trigger_refresh()
@@ -514,7 +598,7 @@ def _project_running_execution(
         "worker": (
             "graceful stop requested"
             if execution.stop_requested_at
-            else "streaming stats from worker jsonl"
+            else "streaming stats from worker"
         ),
         "log_href": f"/job/{job_id}/execution/{execution_id}/logs",
         "log_exists": artifacts.log_path.exists(),
@@ -675,6 +759,28 @@ def _final_status(*, execution: JobExecution, returncode: int) -> JobExecutionStatus:
     return JobExecutionStatus.FAILED
 
 
+def _worker_final_status(
+    *,
+    execution: JobExecution,
+    worker: RunningWorker,
+    returncode: int,
+) -> JobExecutionStatus:
+    if worker.detached:
+        return _detached_final_status(execution=execution, artifacts=worker.artifacts)
+    return _final_status(execution=execution, returncode=returncode)
+
+
+def _detached_final_status(
+    *, execution: JobExecution, artifacts: JobArtifacts
+) -> JobExecutionStatus:
+    if execution.stop_requested_at is not None:
+        return JobExecutionStatus.CANCELED
+    log_tail = _read_log_tail(artifacts.log_path)
+    if "completed successfully" in log_tail:
+        return JobExecutionStatus.SUCCEEDED
+    return JobExecutionStatus.FAILED
+
+
 def _coerce_datetime(value: datetime | str) -> datetime:
     if isinstance(value, datetime):
         if value.tzinfo is None:
@@ -745,3 +851,87 @@ def _humanize_relative_time(reference_time: datetime, target_time: datetime) -> str:
     if delta_seconds > 0:
         return f"in {absolute_delta_seconds} seconds"
     return f"{absolute_delta_seconds} seconds ago"
+
+
+def _find_live_workers() -> dict[int, LiveWorker]:
+    proc_dir = Path("/proc")
+    if not proc_dir.exists():
+        return {}
+    live_workers: dict[int, LiveWorker] = {}
+    for cmdline_path in proc_dir.glob("[0-9]*/cmdline"):
+        try:
+            argv = [
+                part
+                for part in cmdline_path.read_bytes()
+                .decode("utf-8", errors="ignore")
+                .split("\x00")
+                if part != ""
+            ]
+        except OSError:
+            continue
+        live_worker = _parse_live_worker(argv, pid=int(cmdline_path.parent.name))
+        if live_worker is None or not _pid_is_running(live_worker.pid):
+            continue
+        live_workers[live_worker.execution_id] = live_worker
+    return live_workers
+
+
+def _parse_live_worker(argv: list[str], *, pid: int) -> LiveWorker | None:
+    if "repub.job_runner" not in argv:
+        return None
+    job_id = _argv_flag_value(argv, "--job-id")
+    execution_id = _argv_flag_value(argv, "--execution-id")
+    if job_id is None or execution_id is None:
+        return None
+    return LiveWorker(job_id=int(job_id), execution_id=int(execution_id), pid=pid)
+
+
+def _argv_flag_value(argv: list[str], flag: str) -> str | None:
+    try:
+        index = argv.index(flag)
+    except ValueError:
+        return None
+    value_index = index + 1
+    if value_index >= len(argv):
+        return None
+    return argv[value_index]
+
+
+def _pid_is_running(pid: int) -> bool:
+    try:
+        os.kill(pid, 0)
+    except ProcessLookupError:
+        return False
+    except PermissionError:
+        return True
+    stat_path = Path(f"/proc/{pid}/stat")
+    if stat_path.exists():
+        try:
+            stat_text = stat_path.read_text(encoding="utf-8")
+        except OSError:
+            return True
+        state_parts = stat_text.split(") ", 1)
+        if len(state_parts) == 2 and state_parts[1].startswith("Z"):
+            return False
+    return True
+
+
+def _send_signal(pid: int, signum: int) -> None:
+    try:
+        os.kill(pid, signum)
+    except ProcessLookupError:
+        return
+
+
+def _read_log_tail(path: Path, *, max_bytes: int = 8192) -> str:
+    if not path.exists():
+        return ""
+    with path.open("rb") as handle:
+        handle.seek(0, os.SEEK_END)
+        size = handle.tell()
+        handle.seek(max(0, size - max_bytes))
+        return handle.read().decode("utf-8", errors="replace")

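For orientation, here is a minimal standalone sketch (Linux-only, since it reads /proc) of the supervision pattern DetachedProcess relies on: poll a worker by pid with signal 0 instead of holding a Popen handle, and treat zombies as exited. The demo child command is hypothetical; only the liveness check mirrors _pid_is_running above.

import os
import subprocess
import sys
import time
from pathlib import Path


def pid_is_running(pid: int) -> bool:
    # Signal 0 performs the existence/permission check without delivering a signal.
    try:
        os.kill(pid, 0)
    except ProcessLookupError:
        return False
    except PermissionError:
        return True  # alive, but owned by another user
    # A zombie still accepts signal 0, so check /proc/<pid>/stat for state "Z".
    stat_text = Path(f"/proc/{pid}/stat").read_text()
    state = stat_text.rsplit(") ", 1)[-1]
    return not state.startswith("Z")


# Spawn a short-lived child and keep only its pid, as if the app had
# restarted and lost the original Popen object.
child = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(1)"])
pid = child.pid

while pid_is_running(pid):  # what DetachedProcess.poll()/wait() boil down to
    time.sleep(0.05)
child.wait()  # reap the child so the demo exits cleanly
print(f"worker {pid} exited")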
View file

@@ -3,6 +3,8 @@ from __future__ import annotations
 import asyncio
 import json
 import socketserver
+import subprocess
+import sys
 import threading
 import time
 from datetime import UTC, datetime, timedelta
@@ -226,6 +228,161 @@ def test_job_runtime_start_reconciles_stale_running_execution(tmp_path: Path) -> None:
     runtime.shutdown()
 
 
+def test_job_runtime_start_reattaches_live_worker_after_app_restart(
+    tmp_path: Path,
+) -> None:
+    db_path = tmp_path / "live-worker.db"
+    log_dir = tmp_path / "out" / "logs"
+    initialize_database(db_path)
+    with _slow_feed_server() as feed_url:
+        source = create_source(
+            name="Live worker source",
+            slug="live-worker-source",
+            source_type="feed",
+            notes="",
+            spider_arguments="",
+            enabled=False,
+            cron_minute="*/5",
+            cron_hour="*",
+            cron_day_of_month="*",
+            cron_day_of_week="*",
+            cron_month="*",
+            feed_url=feed_url,
+        )
+        job = Job.get(Job.source == source)
+        execution = JobExecution.create(
+            job=job,
+            started_at=datetime.now(UTC),
+            running_status=JobExecutionStatus.RUNNING,
+        )
+        artifacts = JobArtifacts.for_execution(
+            log_dir=log_dir,
+            job_id=job.id,
+            execution_id=int(execution.get_id()),
+        )
+        artifacts.log_path.parent.mkdir(parents=True, exist_ok=True)
+        log_handle = artifacts.log_path.open("a", encoding="utf-8", buffering=1)
+        process = subprocess.Popen(
+            [
+                sys.executable,
+                "-u",
+                "-m",
+                "repub.job_runner",
+                "--job-id",
+                str(job.id),
+                "--execution-id",
+                str(execution.get_id()),
+                "--db-path",
+                str(db_path),
+                "--out-dir",
+                str(log_dir.parent),
+                "--stats-path",
+                str(artifacts.stats_path),
+            ],
+            stdout=log_handle,
+            stderr=subprocess.STDOUT,
+            text=True,
+        )
+        runtime = JobRuntime(log_dir=log_dir)
+        try:
+            time.sleep(0.1)
+            runtime.start()
+            running_execution = JobExecution.get_by_id(execution.get_id())
+            assert running_execution.running_status == JobExecutionStatus.RUNNING
+            assert running_execution.ended_at is None
+            completed_execution = _wait_for_terminal_execution(int(execution.get_id()))
+            assert completed_execution.running_status == JobExecutionStatus.SUCCEEDED
+            assert "reattached" in artifacts.log_path.read_text(encoding="utf-8")
+        finally:
+            runtime.shutdown()
+            if process.poll() is None:
+                process.kill()
+                process.wait(timeout=2)
+            log_handle.close()
+
+
+def test_job_runtime_start_restores_live_worker_marked_failed_by_restart_bug(
+    tmp_path: Path,
+) -> None:
+    db_path = tmp_path / "restore-live-worker.db"
+    log_dir = tmp_path / "out" / "logs"
+    initialize_database(db_path)
+    with _slow_feed_server() as feed_url:
+        source = create_source(
+            name="Recovered worker source",
+            slug="recovered-worker-source",
+            source_type="feed",
+            notes="",
+            spider_arguments="",
+            enabled=False,
+            cron_minute="*/5",
+            cron_hour="*",
+            cron_day_of_month="*",
+            cron_day_of_week="*",
+            cron_month="*",
+            feed_url=feed_url,
+        )
+        job = Job.get(Job.source == source)
+        execution = JobExecution.create(
+            job=job,
+            started_at=datetime.now(UTC),
+            ended_at=datetime.now(UTC),
+            running_status=JobExecutionStatus.FAILED,
+        )
+        artifacts = JobArtifacts.for_execution(
+            log_dir=log_dir,
+            job_id=job.id,
+            execution_id=int(execution.get_id()),
+        )
+        artifacts.log_path.parent.mkdir(parents=True, exist_ok=True)
+        log_handle = artifacts.log_path.open("a", encoding="utf-8", buffering=1)
+        process = subprocess.Popen(
+            [
+                sys.executable,
+                "-u",
+                "-m",
+                "repub.job_runner",
+                "--job-id",
+                str(job.id),
+                "--execution-id",
+                str(execution.get_id()),
+                "--db-path",
+                str(db_path),
+                "--out-dir",
+                str(log_dir.parent),
+                "--stats-path",
+                str(artifacts.stats_path),
+            ],
+            stdout=log_handle,
+            stderr=subprocess.STDOUT,
+            text=True,
+        )
+        runtime = JobRuntime(log_dir=log_dir)
+        try:
+            time.sleep(0.1)
+            runtime.start()
+            restored_execution = JobExecution.get_by_id(execution.get_id())
+            assert restored_execution.running_status == JobExecutionStatus.RUNNING
+            assert restored_execution.ended_at is None
+            completed_execution = _wait_for_terminal_execution(int(execution.get_id()))
+            assert completed_execution.running_status == JobExecutionStatus.SUCCEEDED
+            assert "restored execution state" in artifacts.log_path.read_text(
+                encoding="utf-8"
+            )
+        finally:
+            runtime.shutdown()
+            if process.poll() is None:
+                process.kill()
+                process.wait(timeout=2)
+            log_handle.close()
+
+
 def test_generate_pangea_feed_writes_pangea_rss_file(
     monkeypatch, tmp_path: Path
 ) -> None:

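And a standalone sketch of the /proc scan these tests exercise, showing how a worker can be rediscovered from its command line alone. The marker string and flag name come from the diff above; the helper name is illustrative.

from pathlib import Path


def find_worker_pids(marker: str = "repub.job_runner") -> dict[int, int]:
    """Map execution id -> pid for every process whose argv mentions marker."""
    found: dict[int, int] = {}
    for cmdline in Path("/proc").glob("[0-9]*/cmdline"):
        try:
            # /proc/<pid>/cmdline is NUL-separated; drop empty fields.
            raw = cmdline.read_bytes().decode("utf-8", errors="ignore")
        except OSError:
            continue  # the process exited between glob() and read
        argv = [part for part in raw.split("\x00") if part]
        if marker not in argv:
            continue
        try:
            execution_id = int(argv[argv.index("--execution-id") + 1])
        except (ValueError, IndexError):
            continue  # flag missing or malformed; not one of ours
        found[execution_id] = int(cmdline.parent.name)
    return found


print(find_worker_pids())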
View file

@@ -440,7 +440,8 @@ def test_render_create_source_shows_dedicated_form_page() -> None:
         not in body
     )
     assert "language=en,download_media=true" not in body
-    assert "language=en\ndownload_media=true" in body
+    assert 'id="spider-arguments"' in body
+    assert "language=en\ndownload_media=true" not in body
     assert 'value="articles"' in body
     assert 'value="10"' in body
     assert 'value="3"' in body