Refactor database access through managed connections

Abel Luck 2026-03-31 17:30:07 +02:00
parent f19bab6fa2
commit 3f28e46ff6
10 changed files with 1327 additions and 716 deletions
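Note: repub/db.py is among the 10 changed files but is not part of the excerpt below. Judging from the call sites, the new surface is a process-wide connection handle with reader() and writer() context managers and a db_path that is handed to worker subprocesses on the command line. A minimal sketch of what that module might look like, assuming peewee over SQLite (only get_database_connection(), reader(), writer(), and db_path are confirmed by the diff; every other name here is an assumption):

# Hypothetical sketch of repub/db.py, reconstructed from call sites only.
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator

from peewee import SqliteDatabase


class DatabaseConnection:
    def __init__(self, db_path: Path) -> None:
        self.db_path = db_path  # confirmed: passed to workers as --db-path
        self.database = SqliteDatabase(str(db_path))

    @contextmanager
    def reader(self) -> Iterator[None]:
        # A connection scoped to a read-only block.
        with self.database.connection_context():
            yield

    @contextmanager
    def writer(self) -> Iterator[None]:
        # A connection plus a transaction: this would explain why the
        # explicit database.atomic() wrappers disappear in this commit.
        with self.database.connection_context():
            with self.database.atomic():
                yield


_connection: DatabaseConnection | None = None


def get_database_connection() -> DatabaseConnection | None:
    return _connection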


@@ -18,6 +18,7 @@ from apscheduler.triggers.cron import CronTrigger
 from peewee import IntegrityError
 from repub.config import feed_output_dir, feed_output_path
+from repub.db import get_database_connection
 from repub.model import (
     Job,
     JobExecution,
@@ -173,7 +174,7 @@ class JobRuntime:
         self._started = False
 
     def sync_jobs(self) -> None:
-        with database.connection_context():
+        with database.reader():
             jobs = tuple(Job.select().where(Job.enabled == True))  # noqa: E712
 
             desired_ids = set()
@@ -216,35 +217,34 @@ class JobRuntime:
         return execution_id
 
     def _enqueue_job_run_locked(self, job_id: int) -> int | None:
-        with database.connection_context():
-            with database.atomic():
-                job = Job.get_or_none(id=job_id)
-                if job is None:
-                    return None
+        with database.writer():
+            job = Job.get_or_none(id=job_id)
+            if job is None:
+                return None
 
-                pending_execution = JobExecution.get_or_none(
-                    (JobExecution.job == job)
-                    & (JobExecution.running_status == JobExecutionStatus.PENDING)
-                )
-                if pending_execution is not None:
-                    return _execution_id(pending_execution)
-                try:
-                    execution = JobExecution.create(
-                        job=job,
-                        running_status=JobExecutionStatus.PENDING,
-                    )
-                except IntegrityError:
-                    pending_execution = JobExecution.get_or_none(
-                        (JobExecution.job == job)
-                        & (JobExecution.running_status == JobExecutionStatus.PENDING)
-                    )
-                    return (
-                        _execution_id(pending_execution)
-                        if pending_execution is not None
-                        else None
-                    )
-                return _execution_id(execution)
+            pending_execution = JobExecution.get_or_none(
+                (JobExecution.job == job)
+                & (JobExecution.running_status == JobExecutionStatus.PENDING)
+            )
+            if pending_execution is not None:
+                return _execution_id(pending_execution)
+            try:
+                execution = JobExecution.create(
+                    job=job,
+                    running_status=JobExecutionStatus.PENDING,
+                )
+            except IntegrityError:
+                pending_execution = JobExecution.get_or_none(
+                    (JobExecution.job == job)
+                    & (JobExecution.running_status == JobExecutionStatus.PENDING)
+                )
+                return (
+                    _execution_id(pending_execution)
+                    if pending_execution is not None
+                    else None
+                )
+            return _execution_id(execution)
 
     def _start_queued_jobs(self) -> None:
         with self._run_lock:
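The try/except IntegrityError fallback in _enqueue_job_run_locked only makes sense if the schema enforces at most one PENDING execution per job, so the losing side of a concurrent enqueue can recover the winner's row. The model definition is not shown in this commit; a sketch of the kind of partial unique index that would produce this behavior, using peewee's ModelIndex (the index itself is an assumption, the model and status names are from the diff):

# Assumed constraint behind the IntegrityError fallback above: at most one
# PENDING execution per job, while any number of finished rows may coexist.
# repub.model is not part of this excerpt, so this is a sketch.
pending_unique = JobExecution.index(
    JobExecution.job,
    unique=True,
    where=(JobExecution.running_status == JobExecutionStatus.PENDING),
)
JobExecution.add_index(pending_unique)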
@@ -259,14 +259,13 @@ class JobRuntime:
             if claimed_execution is None:
                 return
-            job = cast(Job, claimed_execution.job)
             self._start_worker_for_execution(
-                job_id=_job_id(job),
-                execution_id=_execution_id(claimed_execution),
+                job_id=claimed_execution[0],
+                execution_id=claimed_execution[1],
             )
 
-    def _claim_next_pending_execution(self) -> JobExecution | None:
-        with database.connection_context():
+    def _claim_next_pending_execution(self) -> tuple[int, int] | None:
+        with database.writer():
             execution_primary_key = getattr(JobExecution, "_meta").primary_key
             pending_executions = tuple(
                 JobExecution.select(JobExecution, Job)
@@ -301,10 +300,21 @@ class JobRuntime:
                 .execute()
             )
             if claimed:
-                return JobExecution.get_by_id(_execution_id(execution))
+                claimed_execution = (
+                    JobExecution.select(JobExecution, Job)
+                    .join(Job)
+                    .where(execution_primary_key == _execution_id(execution))
+                    .get()
+                )
+                claimed_job = cast(Job, claimed_execution.job)
+                return (_job_id(claimed_job), _execution_id(claimed_execution))
             return None
 
     def _start_worker_for_execution(self, *, job_id: int, execution_id: int) -> None:
+        database_connection = get_database_connection()
+        if database_connection is None:
+            raise RuntimeError("Database has not been initialized.")
+
         artifacts = JobArtifacts.for_execution(
             log_dir=self.log_dir, job_id=job_id, execution_id=execution_id
         )
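Two changes land together in this hunk: _claim_next_pending_execution now returns plain (job_id, execution_id) ints instead of a live JobExecution, and the claimed row is re-read with select(JobExecution, Job).join(Job) so the related Job is loaded eagerly. Both guard the same boundary: once writer() exits, the connection is closed, and a model handed to the caller would invite lazy foreign-key loads against a dead connection. Roughly, under the reader()/writer() semantics sketched above:

# The hazard the tuple return avoids (sketch; `database` is the assumed
# managed connection from the earlier sketch).
with database.writer():
    execution = JobExecution.get_by_id(execution_id)
# The connection is closed here. A lazy foreign-key access now issues a
# fresh query outside any managed block:
job = execution.job  # may hit the database after the block -- unsafe

# Safe pattern, as in the diff: read everything needed inside the block
# and cross the boundary with plain ints only.
with database.writer():
    execution = JobExecution.get_by_id(execution_id)
    claimed = (execution.job_id, execution.id)  # materialized in-block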
@@ -324,7 +334,7 @@ class JobRuntime:
             "--execution-id",
             str(execution_id),
             "--db-path",
-            str(database.database),
+            str(database_connection.db_path),
             "--out-dir",
             str(self.log_dir.parent),
             "--stats-path",
@@ -342,15 +352,16 @@ class JobRuntime:
         )
 
     def _max_concurrent_jobs_reached(self) -> bool:
-        return (
-            JobExecution.select()
-            .where(JobExecution.running_status == JobExecutionStatus.RUNNING)
-            .count()
-            >= load_max_concurrent_jobs()
-        )
+        with database.reader():
+            return (
+                JobExecution.select()
+                .where(JobExecution.running_status == JobExecutionStatus.RUNNING)
+                .count()
+                >= load_max_concurrent_jobs()
+            )
 
     def request_execution_cancel(self, execution_id: int) -> bool:
-        with database.connection_context():
+        with database.writer():
             execution = JobExecution.get_or_none(id=execution_id)
             if execution is None:
                 return False
@@ -372,7 +383,7 @@ class JobRuntime:
 
     def cancel_queued_execution(self, execution_id: int) -> bool:
         with self._run_lock:
-            with database.connection_context():
+            with database.writer():
                 execution_primary_key = getattr(JobExecution, "_meta").primary_key
                 deleted = (
                     JobExecution.delete()
@@ -392,7 +403,7 @@ class JobRuntime:
     def move_queued_execution(self, execution_id: int, *, direction: str) -> bool:
         offset = -1 if direction == "up" else 1
         with self._run_lock:
-            with database.connection_context():
+            with database.writer():
                 execution_primary_key = getattr(JobExecution, "_meta").primary_key
                 queued_executions = tuple(
                     JobExecution.select()
@@ -425,59 +436,50 @@ class JobRuntime:
                     cast(datetime | str, target_execution.created_at)
                 )
 
-                with database.atomic():
-                    if current_created_at == target_created_at:
-                        adjusted_created_at = target_created_at + timedelta(
-                            microseconds=-1 if offset < 0 else 1
-                        )
-                        (
-                            JobExecution.update(created_at=adjusted_created_at)
-                            .where(
-                                execution_primary_key
-                                == _execution_id(current_execution)
-                            )
-                            .execute()
-                        )
-                    else:
-                        (
-                            JobExecution.update(created_at=target_created_at)
-                            .where(
-                                execution_primary_key
-                                == _execution_id(current_execution)
-                            )
-                            .execute()
-                        )
-                        (
-                            JobExecution.update(created_at=current_created_at)
-                            .where(
-                                execution_primary_key == _execution_id(target_execution)
-                            )
-                            .execute()
-                        )
+                if current_created_at == target_created_at:
+                    adjusted_created_at = target_created_at + timedelta(
+                        microseconds=-1 if offset < 0 else 1
+                    )
+                    (
+                        JobExecution.update(created_at=adjusted_created_at)
+                        .where(
+                            execution_primary_key == _execution_id(current_execution)
+                        )
+                        .execute()
+                    )
+                else:
+                    (
+                        JobExecution.update(created_at=target_created_at)
+                        .where(
+                            execution_primary_key == _execution_id(current_execution)
+                        )
+                        .execute()
+                    )
+                    (
+                        JobExecution.update(created_at=current_created_at)
+                        .where(execution_primary_key == _execution_id(target_execution))
+                        .execute()
+                    )
             self._trigger_refresh()
         return True
 
     def set_job_enabled(self, job_id: int, *, enabled: bool) -> bool:
-        with database.connection_context():
-            with database.atomic():
-                job = Job.get_or_none(id=job_id)
-                if job is None:
-                    return False
-                job.enabled = enabled
-                job.save()
-                if not enabled:
-                    (
-                        JobExecution.delete()
-                        .where(
-                            (JobExecution.job == job)
-                            & (
-                                JobExecution.running_status
-                                == JobExecutionStatus.PENDING
-                            )
-                        )
-                        .execute()
-                    )
+        with database.writer():
+            job = Job.get_or_none(id=job_id)
+            if job is None:
+                return False
+            job.enabled = enabled
+            job.save()
+            if not enabled:
+                (
+                    JobExecution.delete()
+                    .where(
+                        (JobExecution.job == job)
+                        & (JobExecution.running_status == JobExecutionStatus.PENDING)
+                    )
+                    .execute()
+                )
         self.sync_jobs()
         self._trigger_refresh()
         return True
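This is the pattern repeated through the file: the old two-level connection_context()/atomic() nesting collapses into a single writer() block, which presumably opens both the connection and the transaction (as in the sketch near the top). Note that set_job_enabled keeps sync_jobs() and _trigger_refresh() outside the block, so the write has committed before the scheduler re-reads job state:

# Before: explicit connection, explicit transaction.
with database.connection_context():
    with database.atomic():
        job.save()

# After: one managed block. Equivalent only if writer() opens both the
# connection and the transaction, as assumed above; the commit lands on
# block exit, so the sync_jobs() call that follows reads committed state.
with database.writer():
    job.save()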
@@ -493,7 +495,7 @@ class JobRuntime:
                 continue
             self._apply_stats(worker)
 
-            with database.connection_context():
+            with database.writer():
                 execution = JobExecution.get_by_id(execution_id)
                 execution.ended_at = utc_now()
                 execution.running_status = _worker_final_status(
@@ -527,7 +529,7 @@ class JobRuntime:
             return
         stats = json.loads(lines[-1])
 
-        with database.connection_context():
+        with database.writer():
            execution = JobExecution.get_by_id(worker.execution_id)
            execution.requests_count = int(stats.get("requests_count", 0))
            execution.items_count = int(stats.get("items_count", 0))
@@ -544,7 +546,7 @@ class JobRuntime:
         self._trigger_refresh()
 
     def _enforce_graceful_stop(self, worker: RunningWorker) -> None:
-        with database.connection_context():
+        with database.reader():
             execution = JobExecution.get_by_id(worker.execution_id)
             if execution.stop_requested_at is None:
                 return
@@ -572,16 +574,17 @@ class JobRuntime:
         self._trigger_refresh()
 
     def _has_running_executions(self) -> bool:
-        return (
-            JobExecution.select()
-            .where(JobExecution.running_status == JobExecutionStatus.RUNNING)
-            .exists()
-        )
+        with database.reader():
+            return (
+                JobExecution.select()
+                .where(JobExecution.running_status == JobExecutionStatus.RUNNING)
+                .exists()
+            )
 
     def _reconcile_stale_executions(self) -> None:
         live_workers = _find_live_workers()
         recovered_execution_ids: set[int] = set()
-        with database.connection_context():
+        with database.writer():
             execution_primary_key = getattr(JobExecution, "_meta").primary_key
             if live_workers:
                 live_executions = tuple(
@@ -669,7 +672,7 @@ def load_runs_view(
     reference_time = now or datetime.now(UTC)
     resolved_log_dir = Path(log_dir)
     sanitized_page_size = max(1, completed_page_size)
-    with database.connection_context():
+    with database.reader():
         execution_primary_key = getattr(JobExecution, "_meta").primary_key
         jobs = tuple(Job.select(Job, Source).join(Source).order_by(Source.name.asc()))
         queued_executions = tuple(
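load_runs_view wraps each query in tuple(...) inside the reader() block, and that detail is load-bearing: peewee selects are lazy and only execute when iterated, so an unmaterialized query escaping the block would run against a closed connection. Schematically, with reader() as assumed above and names from the diff:

with database.reader():
    lazy = Job.select().where(Job.enabled == True)  # noqa: E712 -- no SQL yet
    eager = tuple(lazy)  # query executes here, while the connection is open

for job in eager:
    print(job.id)  # fine: rows were already fetched inside the block
for job in lazy:
    print(job.id)  # executes the query now, after the block -- unsafe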
@@ -718,48 +721,52 @@ def load_runs_view(
             _job_id(cast(Job, execution.job)): execution
             for execution in queued_executions
         }
-    return {
-        "running": tuple(
-            _project_running_execution(
-                execution,
-                resolved_log_dir,
-                reference_time,
-                queued_follow_up=queued_by_job.get(_job_id(cast(Job, execution.job))),
-            )
-            for execution in running_executions
-        ),
-        "queued": tuple(
-            _project_queued_execution(
-                execution,
-                reference_time,
-                position=position,
-                total_count=len(queued_executions),
-            )
-            for position, execution in enumerate(queued_executions, start=1)
-        ),
-        "upcoming": tuple(
-            _project_upcoming_job(
-                job,
-                running_by_job.get(job.id),
-                queued_by_job.get(job.id),
-                reference_time,
-            )
-            for job in jobs
-        ),
-        "completed": tuple(
-            _project_completed_execution(execution, resolved_log_dir, reference_time)
-            for execution in completed_executions
-        ),
-        "completed_page": sanitized_completed_page,
-        "completed_page_size": sanitized_page_size,
-        "completed_total_count": completed_total_count,
-        "completed_total_pages": completed_total_pages,
-    }
+        return {
+            "running": tuple(
+                _project_running_execution(
+                    execution,
+                    resolved_log_dir,
+                    reference_time,
+                    queued_follow_up=queued_by_job.get(
+                        _job_id(cast(Job, execution.job))
+                    ),
+                )
+                for execution in running_executions
+            ),
+            "queued": tuple(
+                _project_queued_execution(
+                    execution,
+                    reference_time,
+                    position=position,
+                    total_count=len(queued_executions),
+                )
+                for position, execution in enumerate(queued_executions, start=1)
+            ),
+            "upcoming": tuple(
+                _project_upcoming_job(
+                    job,
+                    running_by_job.get(job.id),
+                    queued_by_job.get(job.id),
+                    reference_time,
+                )
+                for job in jobs
+            ),
+            "completed": tuple(
+                _project_completed_execution(
+                    execution, resolved_log_dir, reference_time
+                )
+                for execution in completed_executions
+            ),
+            "completed_page": sanitized_completed_page,
+            "completed_page_size": sanitized_page_size,
+            "completed_total_count": completed_total_count,
+            "completed_total_pages": completed_total_pages,
+        }
 
 def clear_completed_executions(*, log_dir: str | Path) -> int:
     resolved_log_dir = Path(log_dir)
-    with database.connection_context():
+    with database.writer():
         execution_primary_key = getattr(JobExecution, "_meta").primary_key
         completed_executions = tuple(
             JobExecution.select(JobExecution, Job)
@@ -810,7 +817,7 @@ def load_dashboard_view(
     upcoming_by_job_id = {
         int(cast(int, job["job_id"])): job for job in runs_view["upcoming"]
     }
-    with database.connection_context():
+    with database.reader():
         jobs = tuple(Job.select(Job, Source).join(Source).order_by(Source.name.asc()))
         failed_last_day = (
             JobExecution.select()
@@ -820,80 +827,85 @@ def load_dashboard_view(
             )
             .count()
         )
-    upcoming_ready = sum(
-        1 for job in runs_view["upcoming"] if str(job["run_reason"]) == "Ready"
-    )
-    footprint_bytes = _directory_size(output_dir)
-    return {
-        "running": runs_view["running"],
-        "queued": runs_view["queued"],
-        "source_feeds": tuple(
-            _project_source_feed(
-                cast(Job, job),
-                output_dir,
-                reference_time,
-                running_execution=running_by_job_id.get(_job_id(cast(Job, job))),
-                queued_execution=queued_by_job_id.get(_job_id(cast(Job, job))),
-                upcoming_job=upcoming_by_job_id.get(_job_id(cast(Job, job))),
-            )
-            for job in jobs
-        ),
-        "snapshot": {
-            "running_now": str(len(runs_view["running"])),
-            "upcoming_today": str(upcoming_ready),
-            "failures_24h": str(failed_last_day),
-            "artifact_footprint": _format_bytes(footprint_bytes),
-        },
-    }
+        upcoming_ready = sum(
+            1 for job in runs_view["upcoming"] if str(job["run_reason"]) == "Ready"
+        )
+        footprint_bytes = _directory_size(output_dir)
+        return {
+            "running": runs_view["running"],
+            "queued": runs_view["queued"],
+            "source_feeds": tuple(
+                _project_source_feed(
+                    cast(Job, job),
+                    output_dir,
+                    reference_time,
+                    running_execution=running_by_job_id.get(_job_id(cast(Job, job))),
+                    queued_execution=queued_by_job_id.get(_job_id(cast(Job, job))),
+                    upcoming_job=upcoming_by_job_id.get(_job_id(cast(Job, job))),
+                )
+                for job in jobs
+            ),
+            "snapshot": {
+                "running_now": str(len(runs_view["running"])),
+                "upcoming_today": str(upcoming_ready),
+                "failures_24h": str(failed_last_day),
+                "artifact_footprint": _format_bytes(footprint_bytes),
+            },
+        }
 
 def load_execution_log_view(
     *, log_dir: str | Path, job_id: int, execution_id: int
 ) -> ExecutionLogView:
-    with database.connection_context():
-        execution = JobExecution.get_or_none(id=execution_id)
-        if execution is None or _job_id(cast(Job, execution.job)) != job_id:
-            return ExecutionLogView(
-                job_id=job_id,
-                execution_id=execution_id,
-                title=f"Job {job_id} / execution {execution_id}",
-                description="Plain text log view routed through the app.",
-                status_label="Unavailable",
-                status_tone="failed",
-                log_text="",
-                error_message="Execution does not exist.",
-            )
-        artifacts = JobArtifacts.for_execution(
-            log_dir=Path(log_dir),
-            job_id=job_id,
-            execution_id=execution_id,
-        )
-        if not artifacts.log_path.exists():
-            return ExecutionLogView(
-                job_id=job_id,
-                execution_id=execution_id,
-                title=f"Job {job_id} / execution {execution_id}",
-                description="Plain text log view routed through the app.",
-                status_label=_execution_status_label(execution),
-                status_tone=_execution_status_tone(execution),
-                log_text="",
-                error_message="Log file has not been created yet.",
-            )
-        return ExecutionLogView(
-            job_id=job_id,
-            execution_id=execution_id,
-            title=f"Job {job_id} / execution {execution_id}",
-            description="Plain text log view routed through the app.",
-            status_label=_execution_status_label(execution),
-            status_tone=_execution_status_tone(execution),
-            log_text=artifacts.log_path.read_text(encoding="utf-8"),
-        )
+    route = f"/job/{job_id}/execution/{execution_id}/logs"
+    with database.reader():
+        execution_primary_key = getattr(JobExecution, "_meta").primary_key
+        execution = (
+            JobExecution.select(JobExecution, Job)
+            .join(Job)
+            .where(execution_primary_key == execution_id)
+            .get_or_none()
+        )
+    if execution is None or _job_id(cast(Job, execution.job)) != job_id:
+        return ExecutionLogView(
+            job_id=job_id,
+            execution_id=execution_id,
+            title=f"Job {job_id} / execution {execution_id}",
+            description="Plain text log view routed through the app.",
+            status_label="Unavailable",
+            status_tone="failed",
+            log_text="",
+            error_message="Execution does not exist.",
+        )
+    artifacts = JobArtifacts.for_execution(
+        log_dir=Path(log_dir),
+        job_id=job_id,
+        execution_id=execution_id,
+    )
+    if not artifacts.log_path.exists():
+        return ExecutionLogView(
+            job_id=job_id,
+            execution_id=execution_id,
+            title=f"Job {job_id} / execution {execution_id}",
+            description="Plain text log view routed through the app.",
+            status_label=_execution_status_label(execution),
+            status_tone=_execution_status_tone(execution),
+            log_text="",
+            error_message="Log file has not been created yet.",
+        )
+    return ExecutionLogView(
+        job_id=job_id,
+        execution_id=execution_id,
+        title=f"Job {job_id} / execution {execution_id}",
+        description=f"Route: {route}",
+        status_label=_execution_status_label(execution),
+        status_tone=_execution_status_tone(execution),
+        log_text=artifacts.log_path.read_text(encoding="utf-8"),
+    )
 
 def _job_trigger(job: Job) -> CronTrigger:
     expression = " ".join(