Refactor database access through managed connections

This commit is contained in:
Abel Luck 2026-03-31 17:30:07 +02:00
parent f19bab6fa2
commit 3f28e46ff6
10 changed files with 1327 additions and 716 deletions

359
repub/db.py Normal file
View file

@ -0,0 +1,359 @@
from __future__ import annotations
import os
import queue
import re
import threading
from contextlib import contextmanager
from contextvars import ContextVar
from importlib import resources
from importlib.resources.abc import Traversable
from pathlib import Path
from typing import Iterator
from peewee import BooleanField, Check, SqliteDatabase
from playhouse.migrate import SchemaMigrator, migrate
# Fallback database location (relative to the working directory) used when
# neither an explicit path nor REPUBLISHER_DB_PATH is supplied.
DEFAULT_DB_PATH = Path("republisher.db")

# SQLite PRAGMAs applied to every managed connection: WAL journaling, a 5s
# busy timeout, enforced foreign keys, in-memory temp storage, etc.
DATABASE_PRAGMAS = {
    "busy_timeout": 5000,
    "cache_size": 15625,
    "foreign_keys": 1,
    "journal_mode": "wal",
    "page_size": 4096,
    "synchronous": "normal",
    "temp_store": "memory",
}

# Glob matching the packaged schema files shipped under repub/sql.
SCHEMA_GLOB = "*.sql"

# Matches SQL statements that mutate the database; used to reject writes
# attempted from inside a read-only (reader) scope.
_WRITE_SQL_PREFIX = re.compile(
    r"^\s*(INSERT|UPDATE|DELETE|REPLACE|CREATE|DROP|ALTER)\b"
)

# Context-local handle to the database currently leased to this task/thread;
# None outside any reader()/writer() scope.
_current_database: ContextVar[ManagedSqliteDatabase | None] = ContextVar(
    "repub_current_database",
    default=None,
)

# Context-local name of the active scope: "reader", "writer", "reader_conn",
# "writer_conn", or None outside any managed context.
_current_scope: ContextVar[str | None] = ContextVar(
    "repub_current_database_scope",
    default=None,
)

# Process-wide connection pool, created by initialize_database().
_database_connection: DatabaseConnection | None = None
def resolve_database_path(db_path: str | Path | None = None) -> Path:
    """Resolve *db_path* to an absolute path.

    When *db_path* is None, fall back to the REPUBLISHER_DB_PATH environment
    variable, and finally to DEFAULT_DB_PATH.
    """
    if db_path is None:
        db_path = os.environ.get("REPUBLISHER_DB_PATH", DEFAULT_DB_PATH)
    return Path(db_path).expanduser().resolve()
def schema_paths() -> tuple[Traversable, ...]:
    """Return the packaged .sql schema files under repub/sql, sorted by name."""
    sql_dir = resources.files("repub").joinpath("sql")
    sql_files = [entry for entry in sql_dir.iterdir() if entry.name.endswith(".sql")]
    sql_files.sort(key=lambda entry: entry.name)
    return tuple(sql_files)
class ManagedSqliteDatabase(SqliteDatabase):
    """SqliteDatabase configured for explicit lifecycle management.

    Autoconnect and peewee's per-thread connection state are disabled so the
    owning pool fully controls open/close, and ``check_same_thread`` defaults
    to False so the underlying sqlite3 connection may cross threads.
    """

    def __init__(self, database: str, **kwargs) -> None:
        kwargs.setdefault("check_same_thread", False)
        effective_pragmas = kwargs.pop("pragmas", DATABASE_PRAGMAS)
        super().__init__(
            database,
            autoconnect=False,
            thread_safe=False,
            pragmas=effective_pragmas,
            **kwargs,
        )
class RoutedSqliteDatabase(SqliteDatabase):
    """Proxy database bound to the models at import time.

    It never opens a connection of its own: every operation is delegated to
    the ManagedSqliteDatabase stored in the context variables, which is only
    populated inside database.reader()/database.writer() (and the *_conn
    variants).  Direct connect/close calls are rejected with guidance.
    """

    def __init__(self) -> None:
        # ":managed:" is a placeholder name; this database never connects.
        super().__init__(
            ":managed:",
            autoconnect=False,
            pragmas=DATABASE_PRAGMAS,
        )

    def _require_active_database(self) -> ManagedSqliteDatabase:
        # Resolve the context-local database or fail with a usage hint.
        active_database = _current_database.get()
        if active_database is None:
            raise RuntimeError(
                "Database access requires a database.reader() or database.writer() context."
            )
        return active_database

    def _validate_sql(self, sql: str) -> None:
        # Reject statements matching the write-prefix regex while inside a
        # read-only scope; best-effort guard, not a full SQL parser.
        scope = _current_scope.get()
        if scope in {"reader", "reader_conn"} and _WRITE_SQL_PREFIX.match(sql):
            raise RuntimeError("Write query attempted inside database.reader() scope.")

    def connect(self, reuse_if_open: bool = False):  # type: ignore[override]
        """Disabled: connection lifecycle is owned by the managed scopes."""
        raise RuntimeError(
            "Do not call database.connect() directly; use database.reader() or database.writer()."
        )

    def close(self):  # type: ignore[override]
        """Disabled: connection lifecycle is owned by the managed scopes."""
        raise RuntimeError(
            "Do not call database.close() directly; use database.reader() or database.writer()."
        )

    def is_closed(self) -> bool:  # type: ignore[override]
        # Report closed outside any scope; otherwise defer to the real db.
        active_database = _current_database.get()
        return True if active_database is None else active_database.is_closed()

    def connection(self):  # type: ignore[override]
        return self._require_active_database().connection()

    def cursor(self, named_cursor=None):  # type: ignore[override]
        return self._require_active_database().cursor(named_cursor=named_cursor)

    def execute_sql(self, sql, params=None):  # type: ignore[override]
        # Validate before delegating so reader scopes cannot issue writes.
        self._validate_sql(str(sql))
        return self._require_active_database().execute_sql(sql, params)

    def execute(self, query, **context_options):  # type: ignore[override]
        return self._require_active_database().execute(query, **context_options)

    def atomic(self, *args, **kwargs):  # type: ignore[override]
        return self._require_active_database().atomic(*args, **kwargs)

    def transaction(self, *args, **kwargs):  # type: ignore[override]
        return self._require_active_database().transaction(*args, **kwargs)

    def savepoint(self, *args, **kwargs):  # type: ignore[override]
        return self._require_active_database().savepoint(*args, **kwargs)

    def connection_context(self):  # type: ignore[override]
        """Disabled: use database.reader()/database.writer() instead."""
        raise RuntimeError(
            "Do not call database.connection_context() directly; use database.reader() or database.writer()."
        )

    def reader(self):
        """Context manager yielding a leased reader inside a transaction."""
        return _require_database_connection().reader()

    def writer(self):
        """Context manager yielding the writer inside a transaction."""
        return _require_database_connection().writer()

    def reader_conn(self):
        """Context manager yielding a leased reader, no transaction."""
        return _require_database_connection().reader_conn()

    def writer_conn(self):
        """Context manager yielding the writer, no transaction."""
        return _require_database_connection().writer_conn()
class DatabaseConnection:
    """Owns the single writer connection plus a fixed pool of readers.

    One ManagedSqliteDatabase performs all writes (serialized through a
    re-entrant lock); ``pool_size`` additional databases are leased for reads
    via a queue.  Scope bookkeeping through the module context variables makes
    the context managers re-entrant and lets writer() reject entry from
    read-only scopes.
    """

    def __init__(
        self,
        db_path: str | Path,
        *,
        pool_size: int = 4,
        pragmas: dict[str, object] | None = None,
    ) -> None:
        self.db_path = resolve_database_path(db_path)
        self.pool_size = pool_size
        # Copy so later mutation of the caller's dict cannot leak in.
        self.pragmas = dict(DATABASE_PRAGMAS if pragmas is None else pragmas)
        self.writer_db = ManagedSqliteDatabase(str(self.db_path), pragmas=self.pragmas)
        self.reader_dbs = tuple(
            ManagedSqliteDatabase(str(self.db_path), pragmas=self.pragmas)
            for _ in range(pool_size)
        )
        self._reader_pool: queue.Queue[ManagedSqliteDatabase] = queue.Queue()
        self._writer_lock = threading.RLock()
        for reader_db in self.reader_dbs:
            self._reader_pool.put(reader_db)

    def initialize(self) -> Path:
        """Create the database file, apply schema scripts and migrations.

        Opens the writer and all reader connections; on failure the writer is
        closed before the exception propagates.  Returns the resolved path.
        """
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.writer_db.connect(reuse_if_open=True)
        try:
            for path in schema_paths():
                self.writer_db.connection().executescript(
                    path.read_text(encoding="utf-8")
                )
            _run_legacy_migrations(self.writer_db)
        except Exception:
            self.writer_db.close()
            raise
        for reader_db in self.reader_dbs:
            reader_db.connect(reuse_if_open=True)
        return self.db_path

    def close(self) -> None:
        """Close every open reader connection, then the writer connection."""
        for reader_db in self.reader_dbs:
            if not reader_db.is_closed():
                reader_db.close()
        if not self.writer_db.is_closed():
            self.writer_db.close()

    @contextmanager
    def reader(self) -> Iterator[ManagedSqliteDatabase]:
        """Lease a reader database wrapped in a transaction.

        Re-entrant: inside an existing reader/writer scope the active
        database is reused (wrapping it in a transaction when entering from a
        bare ``*_conn`` scope) instead of leasing a new connection.
        """
        scope = _current_scope.get()
        if scope in {"reader", "writer"}:
            # Already inside a transaction-bearing scope; reuse it as-is.
            # (This also covers "writer" — a separate check would be dead code.)
            yield _require_active_database()
            return
        if scope in {"reader_conn", "writer_conn"}:
            # Bare connection scope: reuse the connection, add a transaction.
            active_database = _require_active_database()
            with active_database.atomic():
                yield active_database
            return
        leased_database = self._reader_pool.get()
        database_token = _current_database.set(leased_database)
        scope_token = _current_scope.set("reader")
        try:
            with leased_database.atomic():
                yield leased_database
        finally:
            _current_scope.reset(scope_token)
            _current_database.reset(database_token)
            # Return the lease even on error so the pool never shrinks.
            self._reader_pool.put(leased_database)

    @contextmanager
    def writer(self) -> Iterator[ManagedSqliteDatabase]:
        """Acquire the writer database wrapped in a transaction.

        Re-entrant within writer scopes.  Raises RuntimeError when entered
        from a reader scope so read-only contexts cannot escalate to writes.
        """
        scope = _current_scope.get()
        if scope == "writer":
            yield _require_active_database()
            return
        if scope == "writer_conn":
            active_database = _require_active_database()
            with active_database.atomic():
                yield active_database
            return
        if scope in {"reader", "reader_conn"}:
            raise RuntimeError(
                "Cannot enter database.writer() inside database.reader()."
            )
        with self._writer_lock:
            database_token = _current_database.set(self.writer_db)
            scope_token = _current_scope.set("writer")
            try:
                with self.writer_db.atomic():
                    yield self.writer_db
            finally:
                _current_scope.reset(scope_token)
                _current_database.reset(database_token)

    @contextmanager
    def reader_conn(self) -> Iterator[ManagedSqliteDatabase]:
        """Lease a reader connection WITHOUT opening a transaction.

        Inside any existing scope the active database is reused unchanged.
        """
        scope = _current_scope.get()
        if scope is not None:
            yield _require_active_database()
            return
        leased_database = self._reader_pool.get()
        database_token = _current_database.set(leased_database)
        scope_token = _current_scope.set("reader_conn")
        try:
            yield leased_database
        finally:
            _current_scope.reset(scope_token)
            _current_database.reset(database_token)
            self._reader_pool.put(leased_database)

    @contextmanager
    def writer_conn(self) -> Iterator[ManagedSqliteDatabase]:
        """Acquire the writer connection WITHOUT opening a transaction."""
        scope = _current_scope.get()
        if scope in {"writer", "writer_conn"}:
            yield _require_active_database()
            return
        if scope in {"reader", "reader_conn"}:
            raise RuntimeError(
                "Cannot enter database.writer_conn() inside database.reader()."
            )
        with self._writer_lock:
            database_token = _current_database.set(self.writer_db)
            scope_token = _current_scope.set("writer_conn")
            try:
                yield self.writer_db
            finally:
                _current_scope.reset(scope_token)
                _current_database.reset(database_token)
def _run_legacy_migrations(database: ManagedSqliteDatabase) -> None:
    """Add the boolean job columns introduced after the initial schema.

    For each of ``convert_images`` and ``convert_video`` that is missing from
    the ``job`` table, queue an add-column operation (boolean with a 0/1
    CHECK, default True) plus a default of 1, then apply all queued
    operations atomically.
    """
    existing_columns = {column.name for column in database.get_columns("job")}
    migrator = SchemaMigrator.from_database(database)
    pending = []
    for column_name in ("convert_images", "convert_video"):
        if column_name in existing_columns:
            continue
        pending.append(
            migrator.add_column(
                "job",
                column_name,
                BooleanField(
                    default=True,
                    constraints=[Check(f"{column_name} IN (0, 1)")],
                ),
            )
        )
        pending.append(migrator.add_column_default("job", column_name, 1))
    if pending:
        with database.atomic():
            migrate(*pending)
def initialize_database(db_path: str | Path | None = None) -> Path:
    """(Re)create the process-wide DatabaseConnection and return its path.

    Any previously installed connection is closed first; the new connection
    is only published globally after a successful initialize().
    """
    global _database_connection
    previous = _database_connection
    if previous is not None:
        previous.close()
    new_connection = DatabaseConnection(resolve_database_path(db_path))
    resolved = new_connection.initialize()
    _database_connection = new_connection
    return resolved
def get_database_connection() -> DatabaseConnection | None:
    """Return the process-wide DatabaseConnection, or None before init."""
    return _database_connection
def _require_database_connection() -> DatabaseConnection:
    """Return the global DatabaseConnection, raising if not yet initialized."""
    connection = get_database_connection()
    if connection is None:
        raise RuntimeError("Database has not been initialized.")
    return connection
def _require_active_database() -> ManagedSqliteDatabase:
    """Return the context-local active database, raising outside any scope."""
    active = _current_database.get()
    if active is not None:
        return active
    raise RuntimeError(
        "Database access requires a database.reader() or database.writer() context."
    )
# Module-wide proxy the models bind to; all real work is routed to the
# connection leased by database.reader()/database.writer().
database = RoutedSqliteDatabase()

View file

@ -309,7 +309,7 @@ def main(argv: list[str] | None = None) -> int:
def _load_job_source_config(*, db_path: str, job_id: int) -> JobSourceConfig: def _load_job_source_config(*, db_path: str, job_id: int) -> JobSourceConfig:
initialize_database(db_path) initialize_database(db_path)
primary_key = getattr(Job, "_meta").primary_key primary_key = getattr(Job, "_meta").primary_key
with database.connection_context(): with database.reader():
job = ( job = (
Job.select(Job, Source) Job.select(Job, Source)
.join(Source) .join(Source)

View file

@ -18,6 +18,7 @@ from apscheduler.triggers.cron import CronTrigger
from peewee import IntegrityError from peewee import IntegrityError
from repub.config import feed_output_dir, feed_output_path from repub.config import feed_output_dir, feed_output_path
from repub.db import get_database_connection
from repub.model import ( from repub.model import (
Job, Job,
JobExecution, JobExecution,
@ -173,7 +174,7 @@ class JobRuntime:
self._started = False self._started = False
def sync_jobs(self) -> None: def sync_jobs(self) -> None:
with database.connection_context(): with database.reader():
jobs = tuple(Job.select().where(Job.enabled == True)) # noqa: E712 jobs = tuple(Job.select().where(Job.enabled == True)) # noqa: E712
desired_ids = set() desired_ids = set()
@ -216,8 +217,7 @@ class JobRuntime:
return execution_id return execution_id
def _enqueue_job_run_locked(self, job_id: int) -> int | None: def _enqueue_job_run_locked(self, job_id: int) -> int | None:
with database.connection_context(): with database.writer():
with database.atomic():
job = Job.get_or_none(id=job_id) job = Job.get_or_none(id=job_id)
if job is None: if job is None:
return None return None
@ -259,14 +259,13 @@ class JobRuntime:
if claimed_execution is None: if claimed_execution is None:
return return
job = cast(Job, claimed_execution.job)
self._start_worker_for_execution( self._start_worker_for_execution(
job_id=_job_id(job), job_id=claimed_execution[0],
execution_id=_execution_id(claimed_execution), execution_id=claimed_execution[1],
) )
def _claim_next_pending_execution(self) -> JobExecution | None: def _claim_next_pending_execution(self) -> tuple[int, int] | None:
with database.connection_context(): with database.writer():
execution_primary_key = getattr(JobExecution, "_meta").primary_key execution_primary_key = getattr(JobExecution, "_meta").primary_key
pending_executions = tuple( pending_executions = tuple(
JobExecution.select(JobExecution, Job) JobExecution.select(JobExecution, Job)
@ -301,10 +300,21 @@ class JobRuntime:
.execute() .execute()
) )
if claimed: if claimed:
return JobExecution.get_by_id(_execution_id(execution)) claimed_execution = (
JobExecution.select(JobExecution, Job)
.join(Job)
.where(execution_primary_key == _execution_id(execution))
.get()
)
claimed_job = cast(Job, claimed_execution.job)
return (_job_id(claimed_job), _execution_id(claimed_execution))
return None return None
def _start_worker_for_execution(self, *, job_id: int, execution_id: int) -> None: def _start_worker_for_execution(self, *, job_id: int, execution_id: int) -> None:
database_connection = get_database_connection()
if database_connection is None:
raise RuntimeError("Database has not been initialized.")
artifacts = JobArtifacts.for_execution( artifacts = JobArtifacts.for_execution(
log_dir=self.log_dir, job_id=job_id, execution_id=execution_id log_dir=self.log_dir, job_id=job_id, execution_id=execution_id
) )
@ -324,7 +334,7 @@ class JobRuntime:
"--execution-id", "--execution-id",
str(execution_id), str(execution_id),
"--db-path", "--db-path",
str(database.database), str(database_connection.db_path),
"--out-dir", "--out-dir",
str(self.log_dir.parent), str(self.log_dir.parent),
"--stats-path", "--stats-path",
@ -342,6 +352,7 @@ class JobRuntime:
) )
def _max_concurrent_jobs_reached(self) -> bool: def _max_concurrent_jobs_reached(self) -> bool:
with database.reader():
return ( return (
JobExecution.select() JobExecution.select()
.where(JobExecution.running_status == JobExecutionStatus.RUNNING) .where(JobExecution.running_status == JobExecutionStatus.RUNNING)
@ -350,7 +361,7 @@ class JobRuntime:
) )
def request_execution_cancel(self, execution_id: int) -> bool: def request_execution_cancel(self, execution_id: int) -> bool:
with database.connection_context(): with database.writer():
execution = JobExecution.get_or_none(id=execution_id) execution = JobExecution.get_or_none(id=execution_id)
if execution is None: if execution is None:
return False return False
@ -372,7 +383,7 @@ class JobRuntime:
def cancel_queued_execution(self, execution_id: int) -> bool: def cancel_queued_execution(self, execution_id: int) -> bool:
with self._run_lock: with self._run_lock:
with database.connection_context(): with database.writer():
execution_primary_key = getattr(JobExecution, "_meta").primary_key execution_primary_key = getattr(JobExecution, "_meta").primary_key
deleted = ( deleted = (
JobExecution.delete() JobExecution.delete()
@ -392,7 +403,7 @@ class JobRuntime:
def move_queued_execution(self, execution_id: int, *, direction: str) -> bool: def move_queued_execution(self, execution_id: int, *, direction: str) -> bool:
offset = -1 if direction == "up" else 1 offset = -1 if direction == "up" else 1
with self._run_lock: with self._run_lock:
with database.connection_context(): with database.writer():
execution_primary_key = getattr(JobExecution, "_meta").primary_key execution_primary_key = getattr(JobExecution, "_meta").primary_key
queued_executions = tuple( queued_executions = tuple(
JobExecution.select() JobExecution.select()
@ -425,7 +436,6 @@ class JobRuntime:
cast(datetime | str, target_execution.created_at) cast(datetime | str, target_execution.created_at)
) )
with database.atomic():
if current_created_at == target_created_at: if current_created_at == target_created_at:
adjusted_created_at = target_created_at + timedelta( adjusted_created_at = target_created_at + timedelta(
microseconds=-1 if offset < 0 else 1 microseconds=-1 if offset < 0 else 1
@ -433,8 +443,7 @@ class JobRuntime:
( (
JobExecution.update(created_at=adjusted_created_at) JobExecution.update(created_at=adjusted_created_at)
.where( .where(
execution_primary_key execution_primary_key == _execution_id(current_execution)
== _execution_id(current_execution)
) )
.execute() .execute()
) )
@ -442,16 +451,13 @@ class JobRuntime:
( (
JobExecution.update(created_at=target_created_at) JobExecution.update(created_at=target_created_at)
.where( .where(
execution_primary_key execution_primary_key == _execution_id(current_execution)
== _execution_id(current_execution)
) )
.execute() .execute()
) )
( (
JobExecution.update(created_at=current_created_at) JobExecution.update(created_at=current_created_at)
.where( .where(execution_primary_key == _execution_id(target_execution))
execution_primary_key == _execution_id(target_execution)
)
.execute() .execute()
) )
@ -459,8 +465,7 @@ class JobRuntime:
return True return True
def set_job_enabled(self, job_id: int, *, enabled: bool) -> bool: def set_job_enabled(self, job_id: int, *, enabled: bool) -> bool:
with database.connection_context(): with database.writer():
with database.atomic():
job = Job.get_or_none(id=job_id) job = Job.get_or_none(id=job_id)
if job is None: if job is None:
return False return False
@ -471,10 +476,7 @@ class JobRuntime:
JobExecution.delete() JobExecution.delete()
.where( .where(
(JobExecution.job == job) (JobExecution.job == job)
& ( & (JobExecution.running_status == JobExecutionStatus.PENDING)
JobExecution.running_status
== JobExecutionStatus.PENDING
)
) )
.execute() .execute()
) )
@ -493,7 +495,7 @@ class JobRuntime:
continue continue
self._apply_stats(worker) self._apply_stats(worker)
with database.connection_context(): with database.writer():
execution = JobExecution.get_by_id(execution_id) execution = JobExecution.get_by_id(execution_id)
execution.ended_at = utc_now() execution.ended_at = utc_now()
execution.running_status = _worker_final_status( execution.running_status = _worker_final_status(
@ -527,7 +529,7 @@ class JobRuntime:
return return
stats = json.loads(lines[-1]) stats = json.loads(lines[-1])
with database.connection_context(): with database.writer():
execution = JobExecution.get_by_id(worker.execution_id) execution = JobExecution.get_by_id(worker.execution_id)
execution.requests_count = int(stats.get("requests_count", 0)) execution.requests_count = int(stats.get("requests_count", 0))
execution.items_count = int(stats.get("items_count", 0)) execution.items_count = int(stats.get("items_count", 0))
@ -544,7 +546,7 @@ class JobRuntime:
self._trigger_refresh() self._trigger_refresh()
def _enforce_graceful_stop(self, worker: RunningWorker) -> None: def _enforce_graceful_stop(self, worker: RunningWorker) -> None:
with database.connection_context(): with database.reader():
execution = JobExecution.get_by_id(worker.execution_id) execution = JobExecution.get_by_id(worker.execution_id)
if execution.stop_requested_at is None: if execution.stop_requested_at is None:
return return
@ -572,6 +574,7 @@ class JobRuntime:
self._trigger_refresh() self._trigger_refresh()
def _has_running_executions(self) -> bool: def _has_running_executions(self) -> bool:
with database.reader():
return ( return (
JobExecution.select() JobExecution.select()
.where(JobExecution.running_status == JobExecutionStatus.RUNNING) .where(JobExecution.running_status == JobExecutionStatus.RUNNING)
@ -581,7 +584,7 @@ class JobRuntime:
def _reconcile_stale_executions(self) -> None: def _reconcile_stale_executions(self) -> None:
live_workers = _find_live_workers() live_workers = _find_live_workers()
recovered_execution_ids: set[int] = set() recovered_execution_ids: set[int] = set()
with database.connection_context(): with database.writer():
execution_primary_key = getattr(JobExecution, "_meta").primary_key execution_primary_key = getattr(JobExecution, "_meta").primary_key
if live_workers: if live_workers:
live_executions = tuple( live_executions = tuple(
@ -669,7 +672,7 @@ def load_runs_view(
reference_time = now or datetime.now(UTC) reference_time = now or datetime.now(UTC)
resolved_log_dir = Path(log_dir) resolved_log_dir = Path(log_dir)
sanitized_page_size = max(1, completed_page_size) sanitized_page_size = max(1, completed_page_size)
with database.connection_context(): with database.reader():
execution_primary_key = getattr(JobExecution, "_meta").primary_key execution_primary_key = getattr(JobExecution, "_meta").primary_key
jobs = tuple(Job.select(Job, Source).join(Source).order_by(Source.name.asc())) jobs = tuple(Job.select(Job, Source).join(Source).order_by(Source.name.asc()))
queued_executions = tuple( queued_executions = tuple(
@ -724,7 +727,9 @@ def load_runs_view(
execution, execution,
resolved_log_dir, resolved_log_dir,
reference_time, reference_time,
queued_follow_up=queued_by_job.get(_job_id(cast(Job, execution.job))), queued_follow_up=queued_by_job.get(
_job_id(cast(Job, execution.job))
),
) )
for execution in running_executions for execution in running_executions
), ),
@ -747,7 +752,9 @@ def load_runs_view(
for job in jobs for job in jobs
), ),
"completed": tuple( "completed": tuple(
_project_completed_execution(execution, resolved_log_dir, reference_time) _project_completed_execution(
execution, resolved_log_dir, reference_time
)
for execution in completed_executions for execution in completed_executions
), ),
"completed_page": sanitized_completed_page, "completed_page": sanitized_completed_page,
@ -759,7 +766,7 @@ def load_runs_view(
def clear_completed_executions(*, log_dir: str | Path) -> int: def clear_completed_executions(*, log_dir: str | Path) -> int:
resolved_log_dir = Path(log_dir) resolved_log_dir = Path(log_dir)
with database.connection_context(): with database.writer():
execution_primary_key = getattr(JobExecution, "_meta").primary_key execution_primary_key = getattr(JobExecution, "_meta").primary_key
completed_executions = tuple( completed_executions = tuple(
JobExecution.select(JobExecution, Job) JobExecution.select(JobExecution, Job)
@ -810,7 +817,7 @@ def load_dashboard_view(
upcoming_by_job_id = { upcoming_by_job_id = {
int(cast(int, job["job_id"])): job for job in runs_view["upcoming"] int(cast(int, job["job_id"])): job for job in runs_view["upcoming"]
} }
with database.connection_context(): with database.reader():
jobs = tuple(Job.select(Job, Source).join(Source).order_by(Source.name.asc())) jobs = tuple(Job.select(Job, Source).join(Source).order_by(Source.name.asc()))
failed_last_day = ( failed_last_day = (
JobExecution.select() JobExecution.select()
@ -820,7 +827,6 @@ def load_dashboard_view(
) )
.count() .count()
) )
upcoming_ready = sum( upcoming_ready = sum(
1 for job in runs_view["upcoming"] if str(job["run_reason"]) == "Ready" 1 for job in runs_view["upcoming"] if str(job["run_reason"]) == "Ready"
) )
@ -851,10 +857,16 @@ def load_dashboard_view(
def load_execution_log_view( def load_execution_log_view(
*, log_dir: str | Path, job_id: int, execution_id: int *, log_dir: str | Path, job_id: int, execution_id: int
) -> ExecutionLogView: ) -> ExecutionLogView:
with database.connection_context():
execution = JobExecution.get_or_none(id=execution_id)
route = f"/job/{job_id}/execution/{execution_id}/logs" route = f"/job/{job_id}/execution/{execution_id}/logs"
with database.reader():
execution_primary_key = getattr(JobExecution, "_meta").primary_key
execution = (
JobExecution.select(JobExecution, Job)
.join(Job)
.where(execution_primary_key == execution_id)
.get_or_none()
)
if execution is None or _job_id(cast(Job, execution.job)) != job_id: if execution is None or _job_id(cast(Job, execution.job)) != job_id:
return ExecutionLogView( return ExecutionLogView(
job_id=job_id, job_id=job_id,

View file

@ -1,12 +1,8 @@
from __future__ import annotations from __future__ import annotations
import json import json
import os
from datetime import UTC, datetime from datetime import UTC, datetime
from enum import IntEnum from enum import IntEnum
from importlib import resources
from importlib.resources.abc import Traversable
from pathlib import Path
from typing import Any from typing import Any
from peewee import ( from peewee import (
@ -16,29 +12,24 @@ from peewee import (
ForeignKeyField, ForeignKeyField,
IntegerField, IntegerField,
Model, Model,
SqliteDatabase,
TextField, TextField,
) )
from playhouse.migrate import SchemaMigrator, migrate
DEFAULT_DB_PATH = Path("republisher.db") from repub import db as db_module
DATABASE_PRAGMAS = {
"busy_timeout": 5000, DEFAULT_DB_PATH = db_module.DEFAULT_DB_PATH
"cache_size": 15625, DATABASE_PRAGMAS = db_module.DATABASE_PRAGMAS
"foreign_keys": 1, SCHEMA_GLOB = db_module.SCHEMA_GLOB
"journal_mode": "wal", database = db_module.database
"page_size": 4096, initialize_database = db_module.initialize_database
"synchronous": "normal", resolve_database_path = db_module.resolve_database_path
"temp_store": "memory", schema_paths = db_module.schema_paths
}
SCHEMA_GLOB = "*.sql"
MAX_CONCURRENT_JOBS_SETTING_KEY = "max_concurrent_jobs" MAX_CONCURRENT_JOBS_SETTING_KEY = "max_concurrent_jobs"
DEFAULT_MAX_CONCURRENT_JOBS = 1 DEFAULT_MAX_CONCURRENT_JOBS = 1
FEED_URL_SETTING_KEY = "feed_url" FEED_URL_SETTING_KEY = "feed_url"
DEFAULT_FEED_URL = "" DEFAULT_FEED_URL = ""
database = SqliteDatabase(None, pragmas=DATABASE_PRAGMAS)
class JobExecutionStatus(IntEnum): class JobExecutionStatus(IntEnum):
PENDING = 0 PENDING = 0
@ -52,91 +43,14 @@ def utc_now() -> datetime:
return datetime.now(UTC) return datetime.now(UTC)
def resolve_database_path(db_path: str | Path | None = None) -> Path:
raw_value = (
os.environ.get("REPUBLISHER_DB_PATH", DEFAULT_DB_PATH)
if db_path is None
else db_path
)
raw_path = Path(raw_value)
return raw_path.expanduser().resolve()
def schema_paths() -> tuple[Traversable, ...]:
schema_dir = resources.files("repub").joinpath("sql")
return tuple(
sorted(
(path for path in schema_dir.iterdir() if path.name.endswith(".sql")),
key=lambda path: path.name,
)
)
def initialize_database(db_path: str | Path | None = None) -> Path:
resolved_path = resolve_database_path(db_path)
resolved_path.parent.mkdir(parents=True, exist_ok=True)
if not database.is_closed():
database.close()
database.init(str(resolved_path), pragmas=DATABASE_PRAGMAS)
database.connect(reuse_if_open=True)
try:
for path in schema_paths():
database.connection().executescript(path.read_text(encoding="utf-8"))
_run_legacy_migrations()
finally:
database.close()
return resolved_path
def _run_legacy_migrations() -> None:
job_columns = {column.name for column in database.get_columns("job")}
operations = []
migrator = SchemaMigrator.from_database(database)
if "convert_images" not in job_columns:
operations.extend(
(
migrator.add_column(
"job",
"convert_images",
BooleanField(
default=True,
constraints=[Check("convert_images IN (0, 1)")],
),
),
migrator.add_column_default("job", "convert_images", 1),
)
)
if "convert_video" not in job_columns:
operations.extend(
(
migrator.add_column(
"job",
"convert_video",
BooleanField(
default=True,
constraints=[Check("convert_video IN (0, 1)")],
),
),
migrator.add_column_default("job", "convert_video", 1),
)
)
if operations:
with database.atomic():
migrate(*operations)
def source_slug_exists(slug: str) -> bool: def source_slug_exists(slug: str) -> bool:
with database.connection_context(): with database.reader():
return Source.select().where(Source.slug == slug).exists() return Source.select().where(Source.slug == slug).exists()
def save_setting(key: str, value: Any) -> None: def save_setting(key: str, value: Any) -> None:
payload = json.dumps(value, sort_keys=True) payload = json.dumps(value, sort_keys=True)
with database.connection_context(): with database.writer():
with database.atomic():
setting = AppSetting.get_or_none(AppSetting.key == key) setting = AppSetting.get_or_none(AppSetting.key == key)
if setting is None: if setting is None:
AppSetting.create(key=key, value=payload) AppSetting.create(key=key, value=payload)
@ -146,7 +60,7 @@ def save_setting(key: str, value: Any) -> None:
def load_setting(key: str, default: Any) -> Any: def load_setting(key: str, default: Any) -> Any:
with database.connection_context(): with database.reader():
setting = AppSetting.get_or_none(AppSetting.key == key) setting = AppSetting.get_or_none(AppSetting.key == key)
if setting is None: if setting is None:
return default return default
@ -177,8 +91,14 @@ def load_settings_form() -> dict[str, object]:
} }
def load_job_enabled(job_id: int) -> bool | None:
with database.reader():
job = Job.get_or_none(id=job_id)
return None if job is None else job.enabled
def load_source_form(slug: str) -> dict[str, object] | None: def load_source_form(slug: str) -> dict[str, object] | None:
with database.connection_context(): with database.reader():
source = Source.get_or_none(Source.slug == slug) source = Source.get_or_none(Source.slug == slug)
if source is None: if source is None:
return None return None
@ -259,8 +179,7 @@ def create_source(
include_content: bool = True, include_content: bool = True,
content_format: str = "", content_format: str = "",
) -> Source: ) -> Source:
with database.connection_context(): with database.writer():
with database.atomic():
source = Source.create( source = Source.create(
name=name, name=name,
slug=slug, slug=slug,
@ -329,8 +248,7 @@ def update_source(
include_content: bool = True, include_content: bool = True,
content_format: str = "", content_format: str = "",
) -> Source | None: ) -> Source | None:
with database.connection_context(): with database.writer():
with database.atomic():
source = Source.get_or_none(Source.slug == source_slug) source = Source.get_or_none(Source.slug == source_slug)
if source is None: if source is None:
return None return None
@ -394,8 +312,7 @@ def update_source(
def delete_job_source(job_id: int) -> bool: def delete_job_source(job_id: int) -> bool:
with database.connection_context(): with database.writer():
with database.atomic():
job = Job.get_or_none(id=job_id) job = Job.get_or_none(id=job_id)
if job is None: if job is None:
return False return False
@ -404,8 +321,7 @@ def delete_job_source(job_id: int) -> bool:
def delete_source(slug: str) -> bool: def delete_source(slug: str) -> bool:
with database.connection_context(): with database.writer():
with database.atomic():
source = Source.get_or_none(Source.slug == slug) source = Source.get_or_none(Source.slug == slug)
if source is None: if source is None:
return False return False
@ -413,7 +329,7 @@ def delete_source(slug: str) -> bool:
def load_sources() -> tuple[dict[str, object], ...]: def load_sources() -> tuple[dict[str, object], ...]:
with database.connection_context(): with database.reader():
sources = tuple(Source.select().order_by(Source.created_at.desc())) sources = tuple(Source.select().order_by(Source.created_at.desc()))
source_ids = tuple(int(source.get_id()) for source in sources) source_ids = tuple(int(source.get_id()) for source in sources)
if not source_ids: if not source_ids:

View file

@ -27,11 +27,11 @@ from repub.jobs import (
load_runs_view, load_runs_view,
) )
from repub.model import ( from repub.model import (
Job,
create_source, create_source,
delete_job_source, delete_job_source,
delete_source, delete_source,
initialize_database, initialize_database,
load_job_enabled,
load_settings_form, load_settings_form,
load_source_form, load_source_form,
load_sources, load_sources,
@ -329,9 +329,9 @@ def create_app(*, dev_mode: bool = False) -> Quart:
@app.post("/actions/jobs/<int:job_id>/toggle-enabled") @app.post("/actions/jobs/<int:job_id>/toggle-enabled")
async def toggle_job_enabled_action(job_id: int) -> Response: async def toggle_job_enabled_action(job_id: int) -> Response:
job = Job.get_or_none(id=job_id) enabled = load_job_enabled(job_id)
if job is not None: if enabled is not None:
get_job_runtime(app).set_job_enabled(job_id, enabled=not job.enabled) get_job_runtime(app).set_job_enabled(job_id, enabled=not enabled)
trigger_refresh(app) trigger_refresh(app)
return Response(status=204) return Response(status=204)

111
tests/test_db.py Normal file
View file

@ -0,0 +1,111 @@
from __future__ import annotations
import threading
import time
from pathlib import Path
import pytest
from peewee import InterfaceError
from repub.db import get_database_connection
from repub.model import AppSetting, database, initialize_database
def test_queries_require_managed_database_context(tmp_path: Path) -> None:
initialize_database(tmp_path / "managed-context.db")
with pytest.raises(
RuntimeError, match="database.reader\\(\\)|database.writer\\(\\)"
):
AppSetting.select().count()
def test_writer_and_reader_contexts_allow_persisted_queries(tmp_path: Path) -> None:
initialize_database(tmp_path / "reader-writer.db")
with database.writer():
AppSetting.create(key="feed_url", value='"https://mirror.example"')
with database.reader():
setting = AppSetting.get(AppSetting.key == "feed_url")
assert setting.value == '"https://mirror.example"'
def test_managed_connections_disable_peewee_autoconnect(tmp_path: Path) -> None:
initialize_database(tmp_path / "autoconnect-disabled.db")
connection = get_database_connection()
assert connection is not None
assert connection.writer_db.autoconnect is False
assert all(reader_db.autoconnect is False for reader_db in connection.reader_dbs)
reader_db = connection.reader_dbs[0]
reader_db.close()
with pytest.raises(InterfaceError, match="database connection not opened"):
reader_db.execute_sql("SELECT 1")
def test_database_connection_initializes_four_readers_and_one_writer(
tmp_path: Path,
) -> None:
initialize_database(tmp_path / "pool-shape.db")
connection = get_database_connection()
assert connection is not None
assert connection.pool_size == 4
assert len(connection.reader_dbs) == 4
assert connection._reader_pool.qsize() == 4
assert connection.writer_db is not None
def test_reader_lease_is_returned_to_the_pool_after_use(tmp_path: Path) -> None:
initialize_database(tmp_path / "reader-lease.db")
connection = get_database_connection()
assert connection is not None
initial_size = connection._reader_pool.qsize()
with database.reader():
assert connection._reader_pool.qsize() == initial_size - 1
assert connection._reader_pool.qsize() == initial_size
def test_writer_contexts_serialize_through_the_single_writer(tmp_path: Path) -> None:
initialize_database(tmp_path / "single-writer.db")
events: list[str] = []
entered_first_writer = threading.Event()
allow_first_writer_to_exit = threading.Event()
def first_writer() -> None:
with database.writer():
events.append("first-entered")
entered_first_writer.set()
allow_first_writer_to_exit.wait(timeout=1)
events.append("first-exiting")
def second_writer() -> None:
entered_first_writer.wait(timeout=1)
with database.writer():
events.append("second-entered")
first_thread = threading.Thread(target=first_writer)
second_thread = threading.Thread(target=second_writer)
first_thread.start()
second_thread.start()
assert entered_first_writer.wait(timeout=1) is True
time.sleep(0.05)
assert events == ["first-entered"]
allow_first_writer_to_exit.set()
first_thread.join(timeout=1)
second_thread.join(timeout=1)
assert events == ["first-entered", "first-exiting", "second-entered"]

View file

@ -9,6 +9,7 @@ from repub.model import (
JobExecution, JobExecution,
JobExecutionStatus, JobExecutionStatus,
create_source, create_source,
database,
initialize_database, initialize_database,
) )
@ -31,6 +32,7 @@ def test_load_runs_view_humanizes_completed_execution_summary_bytes(
cron_month="*", cron_month="*",
feed_url="https://example.com/completed.xml", feed_url="https://example.com/completed.xml",
) )
with database.writer():
job = Job.get(Job.source == source) job = Job.get(Job.source == source)
JobExecution.create( JobExecution.create(
job=job, job=job,
@ -67,6 +69,7 @@ def test_load_runs_view_projects_completed_execution_duration(
cron_month="*", cron_month="*",
feed_url="https://example.com/completed.xml", feed_url="https://example.com/completed.xml",
) )
with database.writer():
job = Job.get(Job.source == source) job = Job.get(Job.source == source)
JobExecution.create( JobExecution.create(
job=job, job=job,
@ -101,6 +104,7 @@ def test_load_runs_view_humanizes_running_execution_summary_bytes(
cron_month="*", cron_month="*",
feed_url="https://example.com/running.xml", feed_url="https://example.com/running.xml",
) )
with database.writer():
job = Job.get(Job.source == source) job = Job.get(Job.source == source)
JobExecution.create( JobExecution.create(
job=job, job=job,
@ -137,6 +141,7 @@ def test_load_runs_view_projects_running_execution_duration(
cron_month="*", cron_month="*",
feed_url="https://example.com/running.xml", feed_url="https://example.com/running.xml",
) )
with database.writer():
job = Job.get(Job.source == source) job = Job.get(Job.source == source)
JobExecution.create( JobExecution.create(
job=job, job=job,
@ -184,11 +189,12 @@ def test_load_runs_view_projects_queued_executions_in_fifo_order(
cron_month="*", cron_month="*",
feed_url="https://example.com/second.xml", feed_url="https://example.com/second.xml",
) )
first_job = Job.get(Job.source == first_source)
second_job = Job.get(Job.source == second_source)
reference_time = datetime(2026, 3, 30, 12, 30, tzinfo=UTC) reference_time = datetime(2026, 3, 30, 12, 30, tzinfo=UTC)
first_created_at = reference_time - timedelta(minutes=7) first_created_at = reference_time - timedelta(minutes=7)
second_created_at = reference_time - timedelta(minutes=3) second_created_at = reference_time - timedelta(minutes=3)
with database.writer():
first_job = Job.get(Job.source == first_source)
second_job = Job.get(Job.source == second_source)
first_execution = JobExecution.create( first_execution = JobExecution.create(
job=first_job, job=first_job,
created_at=first_created_at, created_at=first_created_at,
@ -258,6 +264,7 @@ def test_load_runs_view_keeps_queued_jobs_in_scheduled_jobs(
cron_month="*", cron_month="*",
feed_url="https://example.com/scheduled.xml", feed_url="https://example.com/scheduled.xml",
) )
with database.writer():
queued_job = Job.get(Job.source == queued_source) queued_job = Job.get(Job.source == queued_source)
Job.get(Job.source == scheduled_source) Job.get(Job.source == scheduled_source)
JobExecution.create( JobExecution.create(
@ -299,6 +306,7 @@ def test_load_runs_view_running_row_targets_queued_follow_up_cancel(
cron_month="*", cron_month="*",
feed_url="https://example.com/running.xml", feed_url="https://example.com/running.xml",
) )
with database.writer():
job = Job.get(Job.source == source) job = Job.get(Job.source == source)
JobExecution.create( JobExecution.create(
job=job, job=job,
@ -341,6 +349,7 @@ def test_load_runs_view_paginates_completed_executions_after_20_rows(
cron_month="*", cron_month="*",
feed_url="https://example.com/completed.xml", feed_url="https://example.com/completed.xml",
) )
with database.writer():
job = Job.get(Job.source == source) job = Job.get(Job.source == source)
base_time = datetime(2026, 3, 30, 12, 0, tzinfo=UTC) base_time = datetime(2026, 3, 30, 12, 0, tzinfo=UTC)
for offset in range(21): for offset in range(21):

View file

@ -112,8 +112,7 @@ def test_initialize_database_configures_sqlite_pragmas(tmp_path: Path) -> None:
initialize_database(db_path) initialize_database(db_path)
database.connect(reuse_if_open=True) with database.reader_conn():
try:
pragma_values = { pragma_values = {
"cache_size": database.execute_sql("PRAGMA cache_size").fetchone()[0], "cache_size": database.execute_sql("PRAGMA cache_size").fetchone()[0],
"page_size": database.execute_sql("PRAGMA page_size").fetchone()[0], "page_size": database.execute_sql("PRAGMA page_size").fetchone()[0],
@ -132,8 +131,6 @@ def test_initialize_database_configures_sqlite_pragmas(tmp_path: Path) -> None:
"foreign_keys": 1, "foreign_keys": 1,
"busy_timeout": 5000, "busy_timeout": 5000,
} }
finally:
database.close()
def test_initialize_database_creates_scheduler_and_execution_indexes( def test_initialize_database_creates_scheduler_and_execution_indexes(
@ -208,6 +205,7 @@ def test_initialize_database_creates_run_queue_indexes(tmp_path: Path) -> None:
def test_job_table_allows_exactly_one_job_per_source(tmp_path: Path) -> None: def test_job_table_allows_exactly_one_job_per_source(tmp_path: Path) -> None:
initialize_database(tmp_path / "jobs.db") initialize_database(tmp_path / "jobs.db")
with database.writer():
source = Source.create( source = Source.create(
name="Guardian feed mirror", name="Guardian feed mirror",
slug="guardian-feed", slug="guardian-feed",
@ -248,6 +246,7 @@ def test_save_setting_persists_json_value(tmp_path: Path) -> None:
save_setting("max_concurrent_jobs", 4) save_setting("max_concurrent_jobs", 4)
with database.reader():
row = AppSetting.get(AppSetting.key == "max_concurrent_jobs") row = AppSetting.get(AppSetting.key == "max_concurrent_jobs")
assert row.value == "4" assert row.value == "4"

View file

@ -19,6 +19,7 @@ from repub.model import (
JobExecutionStatus, JobExecutionStatus,
Source, Source,
create_source, create_source,
database,
initialize_database, initialize_database,
save_setting, save_setting,
) )
@ -29,6 +30,16 @@ FIXTURE_FEED_PATH = (
).resolve() ).resolve()
def _db_reader(callable_):
with database.reader():
return callable_()
def _db_writer(callable_):
with database.writer():
return callable_()
def initialize_runtime_database(db_path: Path) -> None: def initialize_runtime_database(db_path: Path) -> None:
initialize_database(db_path) initialize_database(db_path)
save_setting("feed_url", "http://localhost:8080") save_setting("feed_url", "http://localhost:8080")
@ -64,6 +75,7 @@ def test_job_runtime_syncs_enabled_jobs_into_apscheduler(tmp_path: Path) -> None
cron_month="*", cron_month="*",
feed_url="https://example.com/disabled.xml", feed_url="https://example.com/disabled.xml",
) )
with database.reader():
enabled_job = Job.get(Job.source == enabled_source) enabled_job = Job.get(Job.source == enabled_source)
disabled_job = Job.get(Job.source == disabled_source) disabled_job = Job.get(Job.source == disabled_source)
@ -77,6 +89,8 @@ def test_job_runtime_syncs_enabled_jobs_into_apscheduler(tmp_path: Path) -> None
assert f"job-{enabled_job.id}" in scheduled_ids assert f"job-{enabled_job.id}" in scheduled_ids
assert f"job-{disabled_job.id}" not in scheduled_ids assert f"job-{disabled_job.id}" not in scheduled_ids
with database.writer():
enabled_job = Job.get_by_id(enabled_job.id)
enabled_job.enabled = False enabled_job.enabled = False
enabled_job.save() enabled_job.save()
runtime.sync_jobs() runtime.sync_jobs()
@ -105,6 +119,7 @@ def test_job_runtime_run_now_writes_log_and_stats_and_marks_success(
cron_month="*", cron_month="*",
feed_url=FIXTURE_FEED_PATH.as_uri(), feed_url=FIXTURE_FEED_PATH.as_uri(),
) )
with database.reader():
job = Job.get(Job.source == source) job = Job.get(Job.source == source)
runtime = JobRuntime(log_dir=tmp_path / "out" / "logs") runtime = JobRuntime(log_dir=tmp_path / "out" / "logs")
@ -178,6 +193,7 @@ def test_job_runtime_respects_max_concurrent_jobs_setting(tmp_path: Path) -> Non
cron_month="*", cron_month="*",
feed_url=feed_url, feed_url=feed_url,
) )
with database.reader():
first_job = Job.get(Job.source == first_source) first_job = Job.get(Job.source == first_source)
second_job = Job.get(Job.source == second_source) second_job = Job.get(Job.source == second_source)
@ -197,16 +213,20 @@ def test_job_runtime_respects_max_concurrent_jobs_setting(tmp_path: Path) -> Non
JobExecutionStatus.PENDING, JobExecutionStatus.PENDING,
) )
assert ( assert (
JobExecution.select() _db_reader(
lambda: JobExecution.select()
.where(JobExecution.running_status == JobExecutionStatus.RUNNING) .where(JobExecution.running_status == JobExecutionStatus.RUNNING)
.count() .count()
)
== 1 == 1
) )
assert second_execution.started_at is None assert second_execution.started_at is None
assert ( assert (
JobExecution.select() _db_reader(
lambda: JobExecution.select()
.where(JobExecution.running_status == JobExecutionStatus.PENDING) .where(JobExecution.running_status == JobExecutionStatus.PENDING)
.count() .count()
)
== 1 == 1
) )
runtime.request_execution_cancel(first_execution_id) runtime.request_execution_cancel(first_execution_id)
@ -253,6 +273,7 @@ def test_job_runtime_starts_queued_execution_after_capacity_opens(
cron_month="*", cron_month="*",
feed_url=FIXTURE_FEED_PATH.as_uri(), feed_url=FIXTURE_FEED_PATH.as_uri(),
) )
with database.reader():
first_job = Job.get(Job.source == first_source) first_job = Job.get(Job.source == first_source)
second_job = Job.get(Job.source == second_source) second_job = Job.get(Job.source == second_source)
@ -314,6 +335,7 @@ def test_job_runtime_deduplicates_manual_queue_requests(tmp_path: Path) -> None:
cron_month="*", cron_month="*",
feed_url="https://example.com/queued.xml", feed_url="https://example.com/queued.xml",
) )
with database.reader():
blocking_job = Job.get(Job.source == blocking_source) blocking_job = Job.get(Job.source == blocking_source)
queued_job = Job.get(Job.source == queued_source) queued_job = Job.get(Job.source == queued_source)
@ -332,12 +354,14 @@ def test_job_runtime_deduplicates_manual_queue_requests(tmp_path: Path) -> None:
assert first_pending_id is not None assert first_pending_id is not None
assert second_pending_id == first_pending_id assert second_pending_id == first_pending_id
assert ( assert (
JobExecution.select() _db_reader(
lambda: JobExecution.select()
.where( .where(
(JobExecution.job == queued_job) (JobExecution.job == queued_job)
& (JobExecution.running_status == JobExecutionStatus.PENDING) & (JobExecution.running_status == JobExecutionStatus.PENDING)
) )
.count() .count()
)
== 1 == 1
) )
finally: finally:
@ -367,6 +391,7 @@ def test_job_runtime_allows_one_running_and_one_pending_per_job(
cron_month="*", cron_month="*",
feed_url=feed_url, feed_url=feed_url,
) )
with database.reader():
job = Job.get(Job.source == source) job = Job.get(Job.source == source)
runtime = JobRuntime(log_dir=log_dir) runtime = JobRuntime(log_dir=log_dir)
@ -383,17 +408,21 @@ def test_job_runtime_allows_one_running_and_one_pending_per_job(
assert pending_execution_id is not None assert pending_execution_id is not None
assert duplicate_pending_id == pending_execution_id assert duplicate_pending_id == pending_execution_id
assert ( assert (
JobExecution.select() _db_reader(
lambda: JobExecution.select()
.where(JobExecution.job == job) .where(JobExecution.job == job)
.where(JobExecution.running_status == JobExecutionStatus.RUNNING) .where(JobExecution.running_status == JobExecutionStatus.RUNNING)
.count() .count()
)
== 1 == 1
) )
assert ( assert (
JobExecution.select() _db_reader(
lambda: JobExecution.select()
.where(JobExecution.job == job) .where(JobExecution.job == job)
.where(JobExecution.running_status == JobExecutionStatus.PENDING) .where(JobExecution.running_status == JobExecutionStatus.PENDING)
.count() .count()
)
== 1 == 1
) )
finally: finally:
@ -420,6 +449,7 @@ def test_job_runtime_start_drains_pending_rows_created_before_start(
cron_month="*", cron_month="*",
feed_url=FIXTURE_FEED_PATH.as_uri(), feed_url=FIXTURE_FEED_PATH.as_uri(),
) )
with database.writer():
job = Job.get(Job.source == source) job = Job.get(Job.source == source)
pending_execution = JobExecution.create( pending_execution = JobExecution.create(
job=job, job=job,
@ -477,6 +507,7 @@ def test_job_runtime_scheduled_runs_use_the_persistent_queue(
cron_month="*", cron_month="*",
feed_url="https://example.com/second-scheduled.xml", feed_url="https://example.com/second-scheduled.xml",
) )
with database.reader():
first_job = Job.get(Job.source == first_source) first_job = Job.get(Job.source == first_source)
second_job = Job.get(Job.source == second_source) second_job = Job.get(Job.source == second_source)
@ -484,11 +515,15 @@ def test_job_runtime_scheduled_runs_use_the_persistent_queue(
try: try:
runtime.start() runtime.start()
runtime.run_scheduled_job(first_job.id) runtime.run_scheduled_job(first_job.id)
first_execution = JobExecution.get(JobExecution.job == first_job) first_execution = _db_reader(
lambda: JobExecution.get(JobExecution.job == first_job)
)
_wait_for_running_execution(int(first_execution.get_id())) _wait_for_running_execution(int(first_execution.get_id()))
runtime.run_scheduled_job(second_job.id) runtime.run_scheduled_job(second_job.id)
second_execution = JobExecution.get(JobExecution.job == second_job) second_execution = _db_reader(
lambda: JobExecution.get(JobExecution.job == second_job)
)
assert second_execution.running_status == JobExecutionStatus.PENDING assert second_execution.running_status == JobExecutionStatus.PENDING
assert second_execution.started_at is None assert second_execution.started_at is None
@ -519,6 +554,7 @@ def test_job_runtime_cancel_pending_follow_up_keeps_running_worker_alive(
cron_month="*", cron_month="*",
feed_url=feed_url, feed_url=feed_url,
) )
with database.reader():
job = Job.get(Job.source == source) job = Job.get(Job.source == source)
runtime = JobRuntime(log_dir=log_dir) runtime = JobRuntime(log_dir=log_dir)
@ -533,9 +569,14 @@ def test_job_runtime_cancel_pending_follow_up_keeps_running_worker_alive(
_wait_for_execution_status(pending_execution_id, JobExecutionStatus.PENDING) _wait_for_execution_status(pending_execution_id, JobExecutionStatus.PENDING)
assert runtime.cancel_queued_execution(pending_execution_id) is True assert runtime.cancel_queued_execution(pending_execution_id) is True
assert JobExecution.get_or_none(id=pending_execution_id) is None
assert ( assert (
JobExecution.get_by_id(running_execution_id).running_status _db_reader(lambda: JobExecution.get_or_none(id=pending_execution_id))
is None
)
assert (
_db_reader(
lambda: JobExecution.get_by_id(running_execution_id).running_status
)
== JobExecutionStatus.RUNNING == JobExecutionStatus.RUNNING
) )
finally: finally:
@ -559,6 +600,7 @@ def test_job_runtime_cancel_marks_execution_canceled(tmp_path: Path) -> None:
cron_month="*", cron_month="*",
feed_url=feed_url, feed_url=feed_url,
) )
with database.reader():
job = Job.get(Job.source == source) job = Job.get(Job.source == source)
runtime = JobRuntime(log_dir=tmp_path / "out" / "logs") runtime = JobRuntime(log_dir=tmp_path / "out" / "logs")
@ -602,6 +644,7 @@ def test_job_runtime_start_reconciles_stale_running_execution(tmp_path: Path) ->
cron_month="*", cron_month="*",
feed_url="https://example.com/stale.xml", feed_url="https://example.com/stale.xml",
) )
with database.writer():
job = Job.get(Job.source == source) job = Job.get(Job.source == source)
execution = JobExecution.create( execution = JobExecution.create(
job=job, job=job,
@ -622,7 +665,9 @@ def test_job_runtime_start_reconciles_stale_running_execution(tmp_path: Path) ->
runtime = JobRuntime(log_dir=tmp_path / "out" / "logs") runtime = JobRuntime(log_dir=tmp_path / "out" / "logs")
try: try:
runtime.start() runtime.start()
reconciled_execution = JobExecution.get_by_id(execution.get_id()) reconciled_execution = _db_reader(
lambda: JobExecution.get_by_id(execution.get_id())
)
assert reconciled_execution.running_status == JobExecutionStatus.FAILED assert reconciled_execution.running_status == JobExecutionStatus.FAILED
assert reconciled_execution.ended_at is not None assert reconciled_execution.ended_at is not None
@ -649,6 +694,7 @@ def test_job_runtime_publishes_refresh_while_jobs_are_running(tmp_path: Path) ->
cron_month="*", cron_month="*",
feed_url="https://example.com/running.xml", feed_url="https://example.com/running.xml",
) )
with database.writer():
job = Job.get(Job.source == source) job = Job.get(Job.source == source)
JobExecution.create( JobExecution.create(
job=job, job=job,
@ -688,6 +734,7 @@ def test_job_runtime_start_reattaches_live_worker_after_app_restart(
cron_month="*", cron_month="*",
feed_url=feed_url, feed_url=feed_url,
) )
with database.writer():
job = Job.get(Job.source == source) job = Job.get(Job.source == source)
execution = JobExecution.create( execution = JobExecution.create(
job=job, job=job,
@ -728,7 +775,9 @@ def test_job_runtime_start_reattaches_live_worker_after_app_restart(
time.sleep(0.1) time.sleep(0.1)
runtime.start() runtime.start()
running_execution = JobExecution.get_by_id(execution.get_id()) running_execution = _db_reader(
lambda: JobExecution.get_by_id(execution.get_id())
)
assert running_execution.running_status == JobExecutionStatus.RUNNING assert running_execution.running_status == JobExecutionStatus.RUNNING
assert running_execution.ended_at is None assert running_execution.ended_at is None
@ -764,6 +813,7 @@ def test_job_runtime_start_restores_live_worker_marked_failed_by_restart_bug(
cron_month="*", cron_month="*",
feed_url=feed_url, feed_url=feed_url,
) )
with database.writer():
job = Job.get(Job.source == source) job = Job.get(Job.source == source)
execution = JobExecution.create( execution = JobExecution.create(
job=job, job=job,
@ -805,7 +855,9 @@ def test_job_runtime_start_restores_live_worker_marked_failed_by_restart_bug(
time.sleep(0.1) time.sleep(0.1)
runtime.start() runtime.start()
restored_execution = JobExecution.get_by_id(execution.get_id()) restored_execution = _db_reader(
lambda: JobExecution.get_by_id(execution.get_id())
)
assert restored_execution.running_status == JobExecutionStatus.RUNNING assert restored_execution.running_status == JobExecutionStatus.RUNNING
assert restored_execution.ended_at is None assert restored_execution.ended_at is None
@ -895,6 +947,7 @@ def test_load_runs_view_humanizes_completed_execution_end_time(
cron_month="*", cron_month="*",
feed_url="https://example.com/completed.xml", feed_url="https://example.com/completed.xml",
) )
with database.writer():
job = Job.get(Job.source == source) job = Job.get(Job.source == source)
reference_time = datetime(2026, 1, 15, 12, 0, tzinfo=UTC) reference_time = datetime(2026, 1, 15, 12, 0, tzinfo=UTC)
ended_at = reference_time - timedelta(hours=2) ended_at = reference_time - timedelta(hours=2)
@ -934,6 +987,7 @@ def test_load_runs_view_humanizes_running_execution_start_time(
cron_month="*", cron_month="*",
feed_url="https://example.com/running.xml", feed_url="https://example.com/running.xml",
) )
with database.writer():
job = Job.get(Job.source == source) job = Job.get(Job.source == source)
reference_time = datetime(2026, 1, 15, 12, 0, tzinfo=UTC) reference_time = datetime(2026, 1, 15, 12, 0, tzinfo=UTC)
started_at = reference_time - timedelta(hours=2) started_at = reference_time - timedelta(hours=2)
@ -974,6 +1028,7 @@ def test_render_runs_uses_database_backed_jobs_and_executions(
cron_month="*", cron_month="*",
feed_url=FIXTURE_FEED_PATH.as_uri(), feed_url=FIXTURE_FEED_PATH.as_uri(),
) )
with database.reader():
job = Job.get(Job.source == source) job = Job.get(Job.source == source)
runtime = get_job_runtime(app) runtime = get_job_runtime(app)
runtime.start() runtime.start()
@ -1021,6 +1076,7 @@ def test_render_execution_logs_handles_missing_execution_and_missing_log_file(
cron_month="*", cron_month="*",
feed_url="https://example.com/log-source.xml", feed_url="https://example.com/log-source.xml",
) )
with database.writer():
job = Job.get(Job.source == source) job = Job.get(Job.source == source)
execution = JobExecution.create( execution = JobExecution.create(
job=job, job=job,
@ -1067,6 +1123,7 @@ def test_delete_job_action_removes_source_job_and_execution_history(
cron_month="*", cron_month="*",
feed_url="https://example.com/delete.xml", feed_url="https://example.com/delete.xml",
) )
with database.writer():
job = Job.get(Job.source == source) job = Job.get(Job.source == source)
execution = JobExecution.create( execution = JobExecution.create(
job=job, job=job,
@ -1076,9 +1133,15 @@ def test_delete_job_action_removes_source_job_and_execution_history(
response = await client.post(f"/actions/jobs/{job.id}/delete") response = await client.post(f"/actions/jobs/{job.id}/delete")
assert response.status_code == 204 assert response.status_code == 204
assert Source.get_or_none(Source.slug == "delete-source") is None assert (
assert Job.get_or_none(id=job.id) is None _db_reader(lambda: Source.get_or_none(Source.slug == "delete-source"))
assert JobExecution.get_or_none(id=int(execution.get_id())) is None is None
)
assert _db_reader(lambda: Job.get_or_none(id=job.id)) is None
assert (
_db_reader(lambda: JobExecution.get_or_none(id=int(execution.get_id())))
is None
)
asyncio.run(run()) asyncio.run(run())
@ -1107,6 +1170,7 @@ def test_delete_source_action_removes_source_job_and_execution_history(
cron_month="*", cron_month="*",
feed_url="https://example.com/delete-source-row.xml", feed_url="https://example.com/delete-source-row.xml",
) )
with database.writer():
job = Job.get(Job.source == source) job = Job.get(Job.source == source)
execution = JobExecution.create( execution = JobExecution.create(
job=job, job=job,
@ -1116,9 +1180,15 @@ def test_delete_source_action_removes_source_job_and_execution_history(
response = await client.post("/actions/sources/delete-source-row/delete") response = await client.post("/actions/sources/delete-source-row/delete")
assert response.status_code == 204 assert response.status_code == 204
assert Source.get_or_none(Source.slug == "delete-source-row") is None assert (
assert Job.get_or_none(id=job.id) is None _db_reader(lambda: Source.get_or_none(Source.slug == "delete-source-row"))
assert JobExecution.get_or_none(id=int(execution.get_id())) is None is None
)
assert _db_reader(lambda: Job.get_or_none(id=job.id)) is None
assert (
_db_reader(lambda: JobExecution.get_or_none(id=int(execution.get_id())))
is None
)
asyncio.run(run()) asyncio.run(run())
@ -1128,7 +1198,7 @@ def _wait_for_running_execution(
) -> JobExecution: ) -> JobExecution:
deadline = time.monotonic() + timeout_seconds deadline = time.monotonic() + timeout_seconds
while time.monotonic() < deadline: while time.monotonic() < deadline:
execution = JobExecution.get_by_id(execution_id) execution = _db_reader(lambda: JobExecution.get_by_id(execution_id))
if execution.running_status == JobExecutionStatus.RUNNING: if execution.running_status == JobExecutionStatus.RUNNING:
return execution return execution
time.sleep(0.02) time.sleep(0.02)
@ -1143,7 +1213,7 @@ def _wait_for_execution_status(
) -> JobExecution: ) -> JobExecution:
deadline = time.monotonic() + timeout_seconds deadline = time.monotonic() + timeout_seconds
while time.monotonic() < deadline: while time.monotonic() < deadline:
execution = JobExecution.get_by_id(execution_id) execution = _db_reader(lambda: JobExecution.get_by_id(execution_id))
if execution.running_status == status: if execution.running_status == status:
return execution return execution
time.sleep(0.02) time.sleep(0.02)
@ -1155,7 +1225,7 @@ def _wait_for_terminal_execution(
) -> JobExecution: ) -> JobExecution:
deadline = time.monotonic() + timeout_seconds deadline = time.monotonic() + timeout_seconds
while time.monotonic() < deadline: while time.monotonic() < deadline:
execution = JobExecution.get_by_id(execution_id) execution = _db_reader(lambda: JobExecution.get_by_id(execution_id))
if execution.running_status in { if execution.running_status in {
JobExecutionStatus.SUCCEEDED, JobExecutionStatus.SUCCEEDED,
JobExecutionStatus.FAILED, JobExecutionStatus.FAILED,

View file

@ -21,6 +21,7 @@ from repub.model import (
SourceFeed, SourceFeed,
SourcePangea, SourcePangea,
create_source, create_source,
database,
load_max_concurrent_jobs, load_max_concurrent_jobs,
load_settings_form, load_settings_form,
save_setting, save_setting,
@ -43,6 +44,35 @@ from repub.web import (
) )
def _db_reader(fn):
with database.reader():
return fn()
def _db_writer(fn):
with database.writer():
return fn()
def test_web_routes_do_not_access_peewee_models_directly() -> None:
web_source = Path("repub/web.py").read_text(encoding="utf-8")
assert (
re.search(
r"\b(Job|Source|JobExecution|SourceFeed|SourcePangea)\.get",
web_source,
)
is None
)
assert (
re.search(
r"\b(Job|Source|JobExecution|SourceFeed|SourcePangea)\.select",
web_source,
)
is None
)
def test_status_badge_uses_green_done_tone() -> None: def test_status_badge_uses_green_done_tone() -> None:
badge = str(status_badge(label="Succeeded", tone="done")) badge = str(status_badge(label="Succeeded", tone="done"))
@ -790,8 +820,12 @@ def test_load_dashboard_view_lists_source_feed_artifacts(
updated_at = reference_time - timedelta(minutes=32) updated_at = reference_time - timedelta(minutes=32)
updated_at_epoch = updated_at.timestamp() updated_at_epoch = updated_at.timestamp()
os.utime(feed_path, (updated_at_epoch, updated_at_epoch)) os.utime(feed_path, (updated_at_epoch, updated_at_epoch))
available_job = Job.get(Job.source == available_source) available_job, missing_job = _db_reader(
missing_job = Job.get(Job.source == missing_source) lambda: (
Job.get(Job.source == available_source),
Job.get(Job.source == missing_source),
)
)
source_feeds = cast( source_feeds = cast(
tuple[dict[str, object], ...], tuple[dict[str, object], ...],
@ -871,16 +905,18 @@ def test_load_dashboard_view_projects_feed_status_from_job_runtime(
feed_url="https://example.com/queued.xml", feed_url="https://example.com/queued.xml",
) )
running_job = Job.get(Job.source == running_source) _db_writer(
queued_job = Job.get(Job.source == queued_source) lambda: (
JobExecution.create( JobExecution.create(
job=running_job, job=Job.get(Job.source == running_source),
running_status=JobExecutionStatus.RUNNING, running_status=JobExecutionStatus.RUNNING,
started_at=reference_time - timedelta(minutes=2), started_at=reference_time - timedelta(minutes=2),
) ),
JobExecution.create( JobExecution.create(
job=queued_job, job=Job.get(Job.source == queued_source),
running_status=JobExecutionStatus.PENDING, running_status=JobExecutionStatus.PENDING,
),
)
) )
source_feeds = cast( source_feeds = cast(
@ -938,8 +974,12 @@ def test_render_dashboard_shows_source_feed_links_and_statuses(
published_feed = tmp_path / "out" / "feeds" / "published-source" / "feed.rss" published_feed = tmp_path / "out" / "feeds" / "published-source" / "feed.rss"
published_feed.parent.mkdir(parents=True) published_feed.parent.mkdir(parents=True)
published_feed.write_text("<rss/>\n", encoding="utf-8") published_feed.write_text("<rss/>\n", encoding="utf-8")
published_job = Job.get(Job.source == published_source) published_job, missing_job = _db_reader(
missing_job = Job.get(Job.source == missing_source) lambda: (
Job.get(Job.source == published_source),
Job.get(Job.source == missing_source),
)
)
body = str(await render_dashboard(app)) body = str(await render_dashboard(app))
@ -1253,9 +1293,15 @@ def test_create_source_action_creates_pangea_source_and_job_in_database(
assert response.status_code == 200 assert response.status_code == 200
assert "window.location = '/sources'" in body assert "window.location = '/sources'" in body
source = Source.get(Source.slug == "kenya-health") source, pangea, job = _db_reader(
pangea = SourcePangea.get(SourcePangea.source == source) lambda: (
job = Job.get(Job.source == source) Source.get(Source.slug == "kenya-health"),
SourcePangea.get(
SourcePangea.source == Source.get(Source.slug == "kenya-health")
),
Job.get(Job.source == Source.get(Source.slug == "kenya-health")),
)
)
rendered_sources = str(await render_sources(app)) rendered_sources = str(await render_sources(app))
assert source.name == "Kenya health desk" assert source.name == "Kenya health desk"
@ -1307,9 +1353,15 @@ def test_create_source_action_creates_feed_source_and_job_in_database(
assert response.status_code == 200 assert response.status_code == 200
assert "window.location = '/sources'" in body assert "window.location = '/sources'" in body
source = Source.get(Source.slug == "nasa-feed") source, feed, job = _db_reader(
feed = SourceFeed.get(SourceFeed.source == source) lambda: (
job = Job.get(Job.source == source) Source.get(Source.slug == "nasa-feed"),
SourceFeed.get(
SourceFeed.source == Source.get(Source.slug == "nasa-feed")
),
Job.get(Job.source == Source.get(Source.slug == "nasa-feed")),
)
)
rendered_sources = str(await render_sources(app)) rendered_sources = str(await render_sources(app))
assert source.source_type == "feed" assert source.source_type == "feed"
@ -1390,9 +1442,15 @@ def test_edit_source_action_updates_existing_source_and_job_in_database(
assert response.status_code == 200 assert response.status_code == 200
assert "window.location = '/sources'" in body assert "window.location = '/sources'" in body
source = Source.get(Source.slug == "kenya-health") source, pangea, job = _db_reader(
pangea = SourcePangea.get(SourcePangea.source == source) lambda: (
job = Job.get(Job.source == source) Source.get(Source.slug == "kenya-health"),
SourcePangea.get(
SourcePangea.source == Source.get(Source.slug == "kenya-health")
),
Job.get(Job.source == Source.get(Source.slug == "kenya-health")),
)
)
rendered_sources = str(await render_sources(app)) rendered_sources = str(await render_sources(app))
assert source.name == "Kenya health desk nightly" assert source.name == "Kenya health desk nightly"
@ -1477,8 +1535,18 @@ def test_edit_source_action_rejects_slug_changes(monkeypatch, tmp_path: Path) ->
assert response.status_code == 200 assert response.status_code == 200
assert "Slug is immutable." in body assert "Slug is immutable." in body
assert Source.get(Source.slug == "kenya-health").name == "Kenya health desk" assert (
assert Source.select().where(Source.slug == "kenya-health-renamed").count() == 0 _db_reader(lambda: Source.get(Source.slug == "kenya-health").name)
== "Kenya health desk"
)
assert (
_db_reader(
lambda: Source.select()
.where(Source.slug == "kenya-health-renamed")
.count()
)
== 0
)
asyncio.run(run()) asyncio.run(run())
@ -1491,11 +1559,13 @@ def test_create_source_action_validates_duplicate_slug_and_pangea_type(
async def run() -> None: async def run() -> None:
app = create_app() app = create_app()
Source.create( _db_writer(
lambda: Source.create(
name="Guardian feed mirror", name="Guardian feed mirror",
slug="guardian-feed", slug="guardian-feed",
source_type="feed", source_type="feed",
) )
)
client = app.test_client() client = app.test_client()
response = await client.post( response = await client.post(
@ -1526,7 +1596,14 @@ def test_create_source_action_validates_duplicate_slug_and_pangea_type(
assert "Content format is invalid." in body assert "Content format is invalid." in body
assert "Content type is invalid." in body assert "Content type is invalid." in body
assert "Max articles must be an integer." in body assert "Max articles must be an integer." in body
assert Source.select().where(Source.name == "Duplicate guardian").count() == 0 assert (
_db_reader(
lambda: Source.select()
.where(Source.name == "Duplicate guardian")
.count()
)
== 0
)
asyncio.run(run()) asyncio.run(run())
@ -1629,10 +1706,14 @@ def test_render_runs_shows_running_scheduled_and_completed_tables(
cron_month="*", cron_month="*",
feed_url="https://example.com/runs.xml", feed_url="https://example.com/runs.xml",
) )
job = Job.get(Job.source == source) job, execution = _db_writer(
execution = JobExecution.create( lambda: (
job=job, Job.get(Job.source == source),
JobExecution.create(
job=Job.get(Job.source == source),
running_status=JobExecutionStatus.SUCCEEDED, running_status=JobExecutionStatus.SUCCEEDED,
),
)
) )
body = str(await render_runs(app)) body = str(await render_runs(app))
@ -1704,14 +1785,16 @@ def test_runs_pagination_action_updates_only_the_current_tab(
cron_month="*", cron_month="*",
feed_url="https://example.com/paged-runs.xml", feed_url="https://example.com/paged-runs.xml",
) )
job = Job.get(Job.source == source) _db_writer(
lambda: tuple(
for minute in range(21):
JobExecution.create( JobExecution.create(
job=job, job=Job.get(Job.source == source),
ended_at=datetime(2026, 3, 30, 12, minute, tzinfo=UTC), ended_at=datetime(2026, 3, 30, 12, minute, tzinfo=UTC),
running_status=JobExecutionStatus.SUCCEEDED, running_status=JobExecutionStatus.SUCCEEDED,
) )
for minute in range(21)
)
)
async with client.request( async with client.request(
"/runs?u=shim", "/runs?u=shim",
@ -1853,10 +1936,14 @@ def test_render_runs_keeps_queued_execution_in_scheduled_jobs_table(
cron_month="*", cron_month="*",
feed_url="https://example.com/scheduled.xml", feed_url="https://example.com/scheduled.xml",
) )
queued_job = Job.get(Job.source == queued_source) queued_job, queued_execution = _db_writer(
queued_execution = JobExecution.create( lambda: (
job=queued_job, Job.get(Job.source == queued_source),
JobExecution.create(
job=Job.get(Job.source == queued_source),
running_status=JobExecutionStatus.PENDING, running_status=JobExecutionStatus.PENDING,
),
)
) )
async def run() -> None: async def run() -> None:
@ -1899,15 +1986,19 @@ def test_render_runs_shows_cancel_button_for_running_row_with_queued_follow_up(
cron_month="*", cron_month="*",
feed_url="https://example.com/busy.xml", feed_url="https://example.com/busy.xml",
) )
job = Job.get(Job.source == source) job, running_execution, pending_execution = _db_writer(
running_execution = JobExecution.create( lambda: (
job=job, Job.get(Job.source == source),
JobExecution.create(
job=Job.get(Job.source == source),
started_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), started_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC),
running_status=JobExecutionStatus.RUNNING, running_status=JobExecutionStatus.RUNNING,
) ),
pending_execution = JobExecution.create( JobExecution.create(
job=job, job=Job.get(Job.source == source),
running_status=JobExecutionStatus.PENDING, running_status=JobExecutionStatus.PENDING,
),
)
) )
async def run() -> None: async def run() -> None:
@ -2036,15 +2127,19 @@ def test_cancel_queued_execution_action_deletes_pending_row_without_touching_run
cron_month="*", cron_month="*",
feed_url="https://example.com/busy.xml", feed_url="https://example.com/busy.xml",
) )
job = Job.get(Job.source == source) job, running_execution, pending_execution = _db_writer(
running_execution = JobExecution.create( lambda: (
job=job, Job.get(Job.source == source),
JobExecution.create(
job=Job.get(Job.source == source),
started_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), started_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC),
running_status=JobExecutionStatus.RUNNING, running_status=JobExecutionStatus.RUNNING,
) ),
pending_execution = JobExecution.create( JobExecution.create(
job=job, job=Job.get(Job.source == source),
running_status=JobExecutionStatus.PENDING, running_status=JobExecutionStatus.PENDING,
),
)
) )
response = await client.post( response = await client.post(
@ -2052,9 +2147,18 @@ def test_cancel_queued_execution_action_deletes_pending_row_without_touching_run
) )
assert response.status_code == 204 assert response.status_code == 204
assert JobExecution.get_or_none(id=int(pending_execution.get_id())) is None
assert ( assert (
JobExecution.get_by_id(int(running_execution.get_id())).running_status _db_reader(
lambda: JobExecution.get_or_none(id=int(pending_execution.get_id()))
)
is None
)
assert (
_db_reader(
lambda: JobExecution.get_by_id(
int(running_execution.get_id())
).running_status
)
== JobExecutionStatus.RUNNING == JobExecutionStatus.RUNNING
) )
@ -2087,16 +2191,20 @@ def test_clear_completed_executions_action_removes_history_and_log_artifacts(
cron_month="*", cron_month="*",
feed_url="https://example.com/history.xml", feed_url="https://example.com/history.xml",
) )
job = Job.get(Job.source == source) job, completed_execution, running_execution = _db_writer(
completed_execution = JobExecution.create( lambda: (
job=job, Job.get(Job.source == source),
JobExecution.create(
job=Job.get(Job.source == source),
running_status=JobExecutionStatus.SUCCEEDED, running_status=JobExecutionStatus.SUCCEEDED,
ended_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), ended_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC),
) ),
running_execution = JobExecution.create( JobExecution.create(
job=job, job=Job.get(Job.source == source),
running_status=JobExecutionStatus.RUNNING, running_status=JobExecutionStatus.RUNNING,
started_at=datetime(2026, 3, 30, 12, 5, tzinfo=UTC), started_at=datetime(2026, 3, 30, 12, 5, tzinfo=UTC),
),
)
) )
log_dir.mkdir(parents=True, exist_ok=True) log_dir.mkdir(parents=True, exist_ok=True)
completed_prefix = ( completed_prefix = (
@ -2112,8 +2220,18 @@ def test_clear_completed_executions_action_removes_history_and_log_artifacts(
response = await client.post("/actions/completed-executions/clear") response = await client.post("/actions/completed-executions/clear")
assert response.status_code == 204 assert response.status_code == 204
assert JobExecution.get_or_none(id=int(completed_execution.get_id())) is None assert (
assert JobExecution.get_or_none(id=int(running_execution.get_id())) is not None _db_reader(
lambda: JobExecution.get_or_none(id=int(completed_execution.get_id()))
)
is None
)
assert (
_db_reader(
lambda: JobExecution.get_or_none(id=int(running_execution.get_id()))
)
is not None
)
for suffix in (".log", ".jsonl", ".pygea.log"): for suffix in (".log", ".jsonl", ".pygea.log"):
assert not completed_prefix.with_suffix(suffix).exists() assert not completed_prefix.with_suffix(suffix).exists()
assert running_log_path.exists() assert running_log_path.exists()
@ -2161,17 +2279,21 @@ def test_move_queued_execution_action_reorders_queue(
cron_month="*", cron_month="*",
feed_url="https://example.com/second.xml", feed_url="https://example.com/second.xml",
) )
first_job = Job.get(Job.source == first_source) first_job, second_job, first_execution, second_execution = _db_writer(
second_job = Job.get(Job.source == second_source) lambda: (
first_execution = JobExecution.create( Job.get(Job.source == first_source),
job=first_job, Job.get(Job.source == second_source),
JobExecution.create(
job=Job.get(Job.source == first_source),
created_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), created_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC),
running_status=JobExecutionStatus.PENDING, running_status=JobExecutionStatus.PENDING,
) ),
second_execution = JobExecution.create( JobExecution.create(
job=second_job, job=Job.get(Job.source == second_source),
created_at=datetime(2026, 3, 30, 12, 5, tzinfo=UTC), created_at=datetime(2026, 3, 30, 12, 5, tzinfo=UTC),
running_status=JobExecutionStatus.PENDING, running_status=JobExecutionStatus.PENDING,
),
)
) )
response = await client.post( response = await client.post(
@ -2217,17 +2339,26 @@ def test_toggle_job_enabled_action_removes_queued_execution(
cron_month="*", cron_month="*",
feed_url="https://example.com/queued.xml", feed_url="https://example.com/queued.xml",
) )
job = Job.get(Job.source == source) job, queued_execution = _db_writer(
queued_execution = JobExecution.create( lambda: (
job=job, Job.get(Job.source == source),
JobExecution.create(
job=Job.get(Job.source == source),
running_status=JobExecutionStatus.PENDING, running_status=JobExecutionStatus.PENDING,
),
)
) )
response = await client.post(f"/actions/jobs/{job.id}/toggle-enabled") response = await client.post(f"/actions/jobs/{job.id}/toggle-enabled")
assert response.status_code == 204 assert response.status_code == 204
assert Job.get_by_id(job.id).enabled is False assert _db_reader(lambda: Job.get_by_id(job.id).enabled) is False
assert JobExecution.get_or_none(id=int(queued_execution.get_id())) is None assert (
_db_reader(
lambda: JobExecution.get_or_none(id=int(queued_execution.get_id()))
)
is None
)
body = str(await render_runs(app)) body = str(await render_runs(app))
assert ( assert (
f"/actions/queued-executions/{int(queued_execution.get_id())}/cancel" f"/actions/queued-executions/{int(queued_execution.get_id())}/cancel"
@ -2279,10 +2410,14 @@ def test_render_execution_logs_uses_app_route(monkeypatch, tmp_path: Path) -> No
cron_month="*", cron_month="*",
feed_url="https://example.com/logs.xml", feed_url="https://example.com/logs.xml",
) )
job = Job.get(Job.source == source) job, execution = _db_writer(
execution = JobExecution.create( lambda: (
job=job, Job.get(Job.source == source),
JobExecution.create(
job=Job.get(Job.source == source),
running_status=JobExecutionStatus.RUNNING, running_status=JobExecutionStatus.RUNNING,
),
)
) )
log_path = log_dir / f"job-{job.id}-execution-{execution.get_id()}.log" log_path = log_dir / f"job-{job.id}-execution-{execution.get_id()}.log"
log_path.parent.mkdir(parents=True, exist_ok=True) log_path.parent.mkdir(parents=True, exist_ok=True)