diff --git a/repub/jobs.py b/repub/jobs.py
index 5774195..de504ff 100644
--- a/repub/jobs.py
+++ b/repub/jobs.py
@@ -1,8 +1,11 @@
 from __future__ import annotations
 
 import json
+import os
+import signal
 import subprocess
 import sys
+import time
 from dataclasses import dataclass
 from datetime import UTC, datetime, timedelta
 from pathlib import Path
@@ -38,10 +41,43 @@ class JobArtifacts:
 @dataclass
 class RunningWorker:
     execution_id: int
-    process: subprocess.Popen[str]
+    process: subprocess.Popen[str] | DetachedProcess
     log_handle: TextIO
     artifacts: JobArtifacts
     stats_offset: int = 0
+    detached: bool = False
+
+
+@dataclass(frozen=True)
+class DetachedProcess:
+    pid: int
+
+    def poll(self) -> int | None:
+        return None if _pid_is_running(self.pid) else 0
+
+    def terminate(self) -> None:
+        _send_signal(self.pid, signal.SIGTERM)
+
+    def kill(self) -> None:
+        _send_signal(self.pid, signal.SIGKILL)
+
+    def wait(self, timeout: float | None = None) -> int:
+        deadline = None if timeout is None else time.monotonic() + timeout
+        while self.poll() is None:
+            if deadline is not None and time.monotonic() >= deadline:
+                raise subprocess.TimeoutExpired(
+                    cmd=f"pid {self.pid}",
+                    timeout=timeout if timeout is not None else 0.0,
+                )
+            time.sleep(0.05)
+        return 0
+
+
+@dataclass(frozen=True)
+class LiveWorker:
+    job_id: int
+    execution_id: int
+    pid: int
 
 
 @dataclass(frozen=True)
@@ -248,8 +284,9 @@ class JobRuntime:
         with database.connection_context():
             execution = JobExecution.get_by_id(execution_id)
             execution.ended_at = utc_now()
-            execution.running_status = _final_status(
+            execution.running_status = _worker_final_status(
                 execution=execution,
+                worker=worker,
                 returncode=returncode,
             )
             execution.save()
@@ -306,7 +343,51 @@ class JobRuntime:
         self.refresh_callback()
 
     def _reconcile_stale_executions(self) -> None:
+        live_workers = _find_live_workers()
+        recovered_execution_ids: set[int] = set()
         with database.connection_context():
+            execution_primary_key = getattr(JobExecution, "_meta").primary_key
+            if live_workers:
+                live_executions = tuple(
+                    JobExecution.select(JobExecution, Job)
+                    .join(Job)
+                    .where(execution_primary_key.in_(tuple(live_workers)))
+                )
+            else:
+                live_executions = ()
+
+            for execution in live_executions:
+                job = cast(Job, execution.job)
+                execution_id = _execution_id(execution)
+                live_worker = live_workers.get(execution_id)
+                if live_worker is None or live_worker.job_id != _job_id(job):
+                    continue
+
+                artifacts = JobArtifacts.for_execution(
+                    log_dir=self.log_dir,
+                    job_id=live_worker.job_id,
+                    execution_id=execution_id,
+                )
+                artifacts.log_path.parent.mkdir(parents=True, exist_ok=True)
+                if execution.running_status != JobExecutionStatus.RUNNING:
+                    execution.running_status = JobExecutionStatus.RUNNING
+                    execution.ended_at = None
+                    execution.save()
+                    message = f"scheduler: restored execution state from live worker pid {live_worker.pid} after app restart\n"
+                else:
+                    message = f"scheduler: reattached to worker pid {live_worker.pid} after app restart\n"
+
+                log_handle = artifacts.log_path.open("a", encoding="utf-8", buffering=1)
+                log_handle.write(message)
+                self._workers[execution_id] = RunningWorker(
+                    execution_id=execution_id,
+                    process=DetachedProcess(pid=live_worker.pid),
+                    log_handle=log_handle,
+                    artifacts=artifacts,
+                    detached=True,
+                )
+                recovered_execution_ids.add(execution_id)
+
             stale_executions = tuple(
                 JobExecution.select(JobExecution, Job)
                 .join(Job)
@@ -316,9 +397,12 @@ class JobRuntime:
             for execution in stale_executions:
                 job = cast(Job, execution.job)
                 execution_id = _execution_id(execution)
+                if execution_id in recovered_execution_ids:
+                    continue
+                job_id = _job_id(job)
                 artifacts = JobArtifacts.for_execution(
                     log_dir=self.log_dir,
-                    job_id=_job_id(job),
+                    job_id=job_id,
                     execution_id=execution_id,
                 )
                 artifacts.log_path.parent.mkdir(parents=True, exist_ok=True)
@@ -335,7 +419,7 @@ class JobRuntime:
                 )
                 execution.save()
 
-        if stale_executions:
+        if stale_executions or recovered_execution_ids:
             self._trigger_refresh()
 
 
@@ -514,7 +598,7 @@ def _project_running_execution(
         "worker": (
             "graceful stop requested"
             if execution.stop_requested_at
-            else "streaming stats from worker jsonl"
+            else "streaming stats from worker"
         ),
         "log_href": f"/job/{job_id}/execution/{execution_id}/logs",
         "log_exists": artifacts.log_path.exists(),
@@ -675,6 +759,28 @@ def _final_status(*, execution: JobExecution, returncode: int) -> JobExecutionSt
     return JobExecutionStatus.FAILED
 
 
+def _worker_final_status(
+    *,
+    execution: JobExecution,
+    worker: RunningWorker,
+    returncode: int,
+) -> JobExecutionStatus:
+    if worker.detached:
+        return _detached_final_status(execution=execution, artifacts=worker.artifacts)
+    return _final_status(execution=execution, returncode=returncode)
+
+
+def _detached_final_status(
+    *, execution: JobExecution, artifacts: JobArtifacts
+) -> JobExecutionStatus:
+    if execution.stop_requested_at is not None:
+        return JobExecutionStatus.CANCELED
+    log_tail = _read_log_tail(artifacts.log_path)
+    if "completed successfully" in log_tail:
+        return JobExecutionStatus.SUCCEEDED
+    return JobExecutionStatus.FAILED
+
+
 def _coerce_datetime(value: datetime | str) -> datetime:
     if isinstance(value, datetime):
         if value.tzinfo is None:
@@ -745,3 +851,87 @@ def _humanize_relative_time(reference_time: datetime, target_time: datetime) ->
     if delta_seconds > 0:
         return f"in {absolute_delta_seconds} seconds"
     return f"{absolute_delta_seconds} seconds ago"
+
+
+def _find_live_workers() -> dict[int, LiveWorker]:
+    proc_dir = Path("/proc")
+    if not proc_dir.exists():
+        return {}
+
+    live_workers: dict[int, LiveWorker] = {}
+    for cmdline_path in proc_dir.glob("[0-9]*/cmdline"):
+        try:
+            argv = [
+                part
+                for part in cmdline_path.read_bytes()
+                .decode("utf-8", errors="ignore")
+                .split("\x00")
+                if part != ""
+            ]
+        except OSError:
+            continue
+
+        live_worker = _parse_live_worker(argv, pid=int(cmdline_path.parent.name))
+        if live_worker is None or not _pid_is_running(live_worker.pid):
+            continue
+
+        live_workers[live_worker.execution_id] = live_worker
+    return live_workers
+
+
+def _parse_live_worker(argv: list[str], *, pid: int) -> LiveWorker | None:
+    if "repub.job_runner" not in argv:
+        return None
+    job_id = _argv_flag_value(argv, "--job-id")
+    execution_id = _argv_flag_value(argv, "--execution-id")
+    if job_id is None or execution_id is None:
+        return None
+    return LiveWorker(job_id=int(job_id), execution_id=int(execution_id), pid=pid)
+
+
+def _argv_flag_value(argv: list[str], flag: str) -> str | None:
+    try:
+        index = argv.index(flag)
+    except ValueError:
+        return None
+    value_index = index + 1
+    if value_index >= len(argv):
+        return None
+    return argv[value_index]
+
+
+def _pid_is_running(pid: int) -> bool:
+    try:
+        os.kill(pid, 0)
+    except ProcessLookupError:
+        return False
+    except PermissionError:
+        return True
+    stat_path = Path(f"/proc/{pid}/stat")
+    if stat_path.exists():
+        try:
+            stat_text = stat_path.read_text(encoding="utf-8")
+        except OSError:
+            return True
+        state_parts = stat_text.split(") ", 1)
+        if len(state_parts) == 2 and state_parts[1].startswith("Z"):
+            return False
+    return True
+
+
+def _send_signal(pid: int, signum: int) -> None:
+    try:
+        os.kill(pid, signum)
+    except ProcessLookupError:
+        return
+
+
+def _read_log_tail(path: Path, *, max_bytes: int = 8192) -> str:
+    if not path.exists():
+        return ""
+
+    with path.open("rb") as handle:
+        handle.seek(0, os.SEEK_END)
+        size = handle.tell()
+        handle.seek(max(0, size - max_bytes))
+        return handle.read().decode("utf-8", errors="replace")
diff --git a/tests/test_scheduler_runtime.py b/tests/test_scheduler_runtime.py
index d9964ff..31ee7fa 100644
--- a/tests/test_scheduler_runtime.py
+++ b/tests/test_scheduler_runtime.py
@@ -3,6 +3,8 @@ from __future__ import annotations
 import asyncio
 import json
 import socketserver
+import subprocess
+import sys
 import threading
 import time
 from datetime import UTC, datetime, timedelta
@@ -226,6 +228,161 @@ def test_job_runtime_start_reconciles_stale_running_execution(tmp_path: Path) ->
         runtime.shutdown()
 
 
+def test_job_runtime_start_reattaches_live_worker_after_app_restart(
+    tmp_path: Path,
+) -> None:
+    db_path = tmp_path / "live-worker.db"
+    log_dir = tmp_path / "out" / "logs"
+    initialize_database(db_path)
+    with _slow_feed_server() as feed_url:
+        source = create_source(
+            name="Live worker source",
+            slug="live-worker-source",
+            source_type="feed",
+            notes="",
+            spider_arguments="",
+            enabled=False,
+            cron_minute="*/5",
+            cron_hour="*",
+            cron_day_of_month="*",
+            cron_day_of_week="*",
+            cron_month="*",
+            feed_url=feed_url,
+        )
+        job = Job.get(Job.source == source)
+        execution = JobExecution.create(
+            job=job,
+            started_at=datetime.now(UTC),
+            running_status=JobExecutionStatus.RUNNING,
+        )
+        artifacts = JobArtifacts.for_execution(
+            log_dir=log_dir,
+            job_id=job.id,
+            execution_id=int(execution.get_id()),
+        )
+        artifacts.log_path.parent.mkdir(parents=True, exist_ok=True)
+        log_handle = artifacts.log_path.open("a", encoding="utf-8", buffering=1)
+        process = subprocess.Popen(
+            [
+                sys.executable,
+                "-u",
+                "-m",
+                "repub.job_runner",
+                "--job-id",
+                str(job.id),
+                "--execution-id",
+                str(execution.get_id()),
+                "--db-path",
+                str(db_path),
+                "--out-dir",
+                str(log_dir.parent),
+                "--stats-path",
+                str(artifacts.stats_path),
+            ],
+            stdout=log_handle,
+            stderr=subprocess.STDOUT,
+            text=True,
+        )
+
+        runtime = JobRuntime(log_dir=log_dir)
+        try:
+            time.sleep(0.1)
+            runtime.start()
+
+            running_execution = JobExecution.get_by_id(execution.get_id())
+            assert running_execution.running_status == JobExecutionStatus.RUNNING
+            assert running_execution.ended_at is None
+
+            completed_execution = _wait_for_terminal_execution(int(execution.get_id()))
+            assert completed_execution.running_status == JobExecutionStatus.SUCCEEDED
+            assert "reattached" in artifacts.log_path.read_text(encoding="utf-8")
+        finally:
+            runtime.shutdown()
+            if process.poll() is None:
+                process.kill()
+                process.wait(timeout=2)
+            log_handle.close()
+
+
+def test_job_runtime_start_restores_live_worker_marked_failed_by_restart_bug(
+    tmp_path: Path,
+) -> None:
+    db_path = tmp_path / "restore-live-worker.db"
+    log_dir = tmp_path / "out" / "logs"
+    initialize_database(db_path)
+    with _slow_feed_server() as feed_url:
+        source = create_source(
+            name="Recovered worker source",
+            slug="recovered-worker-source",
+            source_type="feed",
+            notes="",
+            spider_arguments="",
+            enabled=False,
+            cron_minute="*/5",
+            cron_hour="*",
+            cron_day_of_month="*",
+            cron_day_of_week="*",
+            cron_month="*",
+            feed_url=feed_url,
+        )
+        job = Job.get(Job.source == source)
+        execution = JobExecution.create(
+            job=job,
+            started_at=datetime.now(UTC),
+            ended_at=datetime.now(UTC),
+            running_status=JobExecutionStatus.FAILED,
+        )
+        artifacts = JobArtifacts.for_execution(
+            log_dir=log_dir,
+            job_id=job.id,
+            execution_id=int(execution.get_id()),
+        )
+        artifacts.log_path.parent.mkdir(parents=True, exist_ok=True)
+        log_handle = artifacts.log_path.open("a", encoding="utf-8", buffering=1)
+        process = subprocess.Popen(
+            [
+                sys.executable,
+                "-u",
+                "-m",
+                "repub.job_runner",
+                "--job-id",
+                str(job.id),
+                "--execution-id",
+                str(execution.get_id()),
+                "--db-path",
+                str(db_path),
+                "--out-dir",
+                str(log_dir.parent),
+                "--stats-path",
+                str(artifacts.stats_path),
+            ],
+            stdout=log_handle,
+            stderr=subprocess.STDOUT,
+            text=True,
+        )
+
+        runtime = JobRuntime(log_dir=log_dir)
+        try:
+            time.sleep(0.1)
+            runtime.start()
+
+            restored_execution = JobExecution.get_by_id(execution.get_id())
+            assert restored_execution.running_status == JobExecutionStatus.RUNNING
+            assert restored_execution.ended_at is None
+
+            completed_execution = _wait_for_terminal_execution(int(execution.get_id()))
+            assert completed_execution.running_status == JobExecutionStatus.SUCCEEDED
+            assert "restored execution state" in artifacts.log_path.read_text(
+                encoding="utf-8"
+            )
+        finally:
+            runtime.shutdown()
+            if process.poll() is None:
+                process.kill()
+                process.wait(timeout=2)
+            log_handle.close()
+
+
 def test_generate_pangea_feed_writes_pangea_rss_file(
     monkeypatch, tmp_path: Path
 ) -> None:
diff --git a/tests/test_web.py b/tests/test_web.py
index e668543..eff5922 100644
--- a/tests/test_web.py
+++ b/tests/test_web.py
@@ -440,7 +440,8 @@ def test_render_create_source_shows_dedicated_form_page() -> None:
         not in body
     )
     assert "language=en,download_media=true" not in body
-    assert "language=en\ndownload_media=true" in body
+    assert 'id="spider-arguments"' in body
+    assert "language=en\ndownload_media=true" not in body
     assert 'value="articles"' in body
     assert 'value="10"' in body
     assert 'value="3"' in body