From 3f28e46ff65d28e7ad11b2964de19dcf8e0c6da5 Mon Sep 17 00:00:00 2001 From: Abel Luck Date: Tue, 31 Mar 2026 17:30:07 +0200 Subject: [PATCH] Refactor database access through managed connections --- repub/db.py | 359 ++++++++++++++++++++++++++++ repub/job_runner.py | 2 +- repub/jobs.py | 398 ++++++++++++++++---------------- repub/model.py | 346 +++++++++++---------------- repub/web.py | 8 +- tests/test_db.py | 111 +++++++++ tests/test_jobs.py | 145 ++++++------ tests/test_model.py | 47 ++-- tests/test_scheduler_runtime.py | 296 +++++++++++++++--------- tests/test_web.py | 331 ++++++++++++++++++-------- 10 files changed, 1327 insertions(+), 716 deletions(-) create mode 100644 repub/db.py create mode 100644 tests/test_db.py diff --git a/repub/db.py b/repub/db.py new file mode 100644 index 0000000..c1573a8 --- /dev/null +++ b/repub/db.py @@ -0,0 +1,359 @@ +from __future__ import annotations + +import os +import queue +import re +import threading +from contextlib import contextmanager +from contextvars import ContextVar +from importlib import resources +from importlib.resources.abc import Traversable +from pathlib import Path +from typing import Iterator + +from peewee import BooleanField, Check, SqliteDatabase +from playhouse.migrate import SchemaMigrator, migrate + +DEFAULT_DB_PATH = Path("republisher.db") +DATABASE_PRAGMAS = { + "busy_timeout": 5000, + "cache_size": 15625, + "foreign_keys": 1, + "journal_mode": "wal", + "page_size": 4096, + "synchronous": "normal", + "temp_store": "memory", +} +SCHEMA_GLOB = "*.sql" + +_WRITE_SQL_PREFIX = re.compile( + r"^\s*(INSERT|UPDATE|DELETE|REPLACE|CREATE|DROP|ALTER)\b" +) +_current_database: ContextVar[ManagedSqliteDatabase | None] = ContextVar( + "repub_current_database", + default=None, +) +_current_scope: ContextVar[str | None] = ContextVar( + "repub_current_database_scope", + default=None, +) +_database_connection: DatabaseConnection | None = None + + +def resolve_database_path(db_path: str | Path | None = None) -> Path: + raw_value = ( + os.environ.get("REPUBLISHER_DB_PATH", DEFAULT_DB_PATH) + if db_path is None + else db_path + ) + return Path(raw_value).expanduser().resolve() + + +def schema_paths() -> tuple[Traversable, ...]: + schema_dir = resources.files("repub").joinpath("sql") + return tuple( + sorted( + (path for path in schema_dir.iterdir() if path.name.endswith(".sql")), + key=lambda path: path.name, + ) + ) + + +class ManagedSqliteDatabase(SqliteDatabase): + def __init__(self, database: str, **kwargs) -> None: + pragmas = kwargs.pop("pragmas", DATABASE_PRAGMAS) + kwargs.setdefault("check_same_thread", False) + super().__init__( + database, + autoconnect=False, + thread_safe=False, + pragmas=pragmas, + **kwargs, + ) + + +class RoutedSqliteDatabase(SqliteDatabase): + def __init__(self) -> None: + super().__init__( + ":managed:", + autoconnect=False, + pragmas=DATABASE_PRAGMAS, + ) + + def _require_active_database(self) -> ManagedSqliteDatabase: + active_database = _current_database.get() + if active_database is None: + raise RuntimeError( + "Database access requires a database.reader() or database.writer() context." 
+            )
+        return active_database
+
+    def _validate_sql(self, sql: str) -> None:
+        scope = _current_scope.get()
+        if scope in {"reader", "reader_conn"} and _WRITE_SQL_PREFIX.match(sql):
+            raise RuntimeError("Write query attempted inside database.reader() scope.")
+
+    def connect(self, reuse_if_open: bool = False):  # type: ignore[override]
+        raise RuntimeError(
+            "Do not call database.connect() directly; use database.reader() or database.writer()."
+        )
+
+    def close(self):  # type: ignore[override]
+        raise RuntimeError(
+            "Do not call database.close() directly; use database.reader() or database.writer()."
+        )
+
+    def is_closed(self) -> bool:  # type: ignore[override]
+        active_database = _current_database.get()
+        return True if active_database is None else active_database.is_closed()
+
+    def connection(self):  # type: ignore[override]
+        return self._require_active_database().connection()
+
+    def cursor(self, named_cursor=None):  # type: ignore[override]
+        return self._require_active_database().cursor(named_cursor=named_cursor)
+
+    def execute_sql(self, sql, params=None):  # type: ignore[override]
+        self._validate_sql(str(sql))
+        return self._require_active_database().execute_sql(sql, params)
+
+    def execute(self, query, **context_options):  # type: ignore[override]
+        return self._require_active_database().execute(query, **context_options)
+
+    def atomic(self, *args, **kwargs):  # type: ignore[override]
+        return self._require_active_database().atomic(*args, **kwargs)
+
+    def transaction(self, *args, **kwargs):  # type: ignore[override]
+        return self._require_active_database().transaction(*args, **kwargs)
+
+    def savepoint(self, *args, **kwargs):  # type: ignore[override]
+        return self._require_active_database().savepoint(*args, **kwargs)
+
+    def connection_context(self):  # type: ignore[override]
+        raise RuntimeError(
+            "Do not call database.connection_context() directly; use database.reader() or database.writer()."
+ ) + + def reader(self): + return _require_database_connection().reader() + + def writer(self): + return _require_database_connection().writer() + + def reader_conn(self): + return _require_database_connection().reader_conn() + + def writer_conn(self): + return _require_database_connection().writer_conn() + + +class DatabaseConnection: + def __init__( + self, + db_path: str | Path, + *, + pool_size: int = 4, + pragmas: dict[str, object] | None = None, + ) -> None: + self.db_path = resolve_database_path(db_path) + self.pool_size = pool_size + self.pragmas = dict(DATABASE_PRAGMAS if pragmas is None else pragmas) + self.writer_db = ManagedSqliteDatabase(str(self.db_path), pragmas=self.pragmas) + self.reader_dbs = tuple( + ManagedSqliteDatabase(str(self.db_path), pragmas=self.pragmas) + for _ in range(pool_size) + ) + self._reader_pool: queue.Queue[ManagedSqliteDatabase] = queue.Queue() + self._writer_lock = threading.RLock() + for reader_db in self.reader_dbs: + self._reader_pool.put(reader_db) + + def initialize(self) -> Path: + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self.writer_db.connect(reuse_if_open=True) + try: + for path in schema_paths(): + self.writer_db.connection().executescript( + path.read_text(encoding="utf-8") + ) + _run_legacy_migrations(self.writer_db) + except Exception: + self.writer_db.close() + raise + + for reader_db in self.reader_dbs: + reader_db.connect(reuse_if_open=True) + return self.db_path + + def close(self) -> None: + for reader_db in self.reader_dbs: + if not reader_db.is_closed(): + reader_db.close() + if not self.writer_db.is_closed(): + self.writer_db.close() + + @contextmanager + def reader(self) -> Iterator[ManagedSqliteDatabase]: + scope = _current_scope.get() + if scope in {"reader", "writer"}: + yield _require_active_database() + return + if scope == "reader_conn": + active_database = _require_active_database() + with active_database.atomic(): + yield active_database + return + if scope == "writer_conn": + active_database = _require_active_database() + with active_database.atomic(): + yield active_database + return + if scope == "writer": + yield _require_active_database() + return + + leased_database = self._reader_pool.get() + database_token = _current_database.set(leased_database) + scope_token = _current_scope.set("reader") + try: + with leased_database.atomic(): + yield leased_database + finally: + _current_scope.reset(scope_token) + _current_database.reset(database_token) + self._reader_pool.put(leased_database) + + @contextmanager + def writer(self) -> Iterator[ManagedSqliteDatabase]: + scope = _current_scope.get() + if scope == "writer": + yield _require_active_database() + return + if scope == "writer_conn": + active_database = _require_active_database() + with active_database.atomic(): + yield active_database + return + if scope in {"reader", "reader_conn"}: + raise RuntimeError( + "Cannot enter database.writer() inside database.reader()." 
+ ) + + with self._writer_lock: + database_token = _current_database.set(self.writer_db) + scope_token = _current_scope.set("writer") + try: + with self.writer_db.atomic(): + yield self.writer_db + finally: + _current_scope.reset(scope_token) + _current_database.reset(database_token) + + @contextmanager + def reader_conn(self) -> Iterator[ManagedSqliteDatabase]: + scope = _current_scope.get() + if scope is not None: + yield _require_active_database() + return + + leased_database = self._reader_pool.get() + database_token = _current_database.set(leased_database) + scope_token = _current_scope.set("reader_conn") + try: + yield leased_database + finally: + _current_scope.reset(scope_token) + _current_database.reset(database_token) + self._reader_pool.put(leased_database) + + @contextmanager + def writer_conn(self) -> Iterator[ManagedSqliteDatabase]: + scope = _current_scope.get() + if scope in {"writer", "writer_conn"}: + yield _require_active_database() + return + if scope in {"reader", "reader_conn"}: + raise RuntimeError( + "Cannot enter database.writer_conn() inside database.reader()." + ) + + with self._writer_lock: + database_token = _current_database.set(self.writer_db) + scope_token = _current_scope.set("writer_conn") + try: + yield self.writer_db + finally: + _current_scope.reset(scope_token) + _current_database.reset(database_token) + + +def _run_legacy_migrations(database: ManagedSqliteDatabase) -> None: + job_columns = {column.name for column in database.get_columns("job")} + operations = [] + migrator = SchemaMigrator.from_database(database) + if "convert_images" not in job_columns: + operations.extend( + ( + migrator.add_column( + "job", + "convert_images", + BooleanField( + default=True, + constraints=[Check("convert_images IN (0, 1)")], + ), + ), + migrator.add_column_default("job", "convert_images", 1), + ) + ) + if "convert_video" not in job_columns: + operations.extend( + ( + migrator.add_column( + "job", + "convert_video", + BooleanField( + default=True, + constraints=[Check("convert_video IN (0, 1)")], + ), + ), + migrator.add_column_default("job", "convert_video", 1), + ) + ) + if operations: + with database.atomic(): + migrate(*operations) + + +def initialize_database(db_path: str | Path | None = None) -> Path: + global _database_connection + + if _database_connection is not None: + _database_connection.close() + + connection = DatabaseConnection(resolve_database_path(db_path)) + resolved_path = connection.initialize() + _database_connection = connection + return resolved_path + + +def get_database_connection() -> DatabaseConnection | None: + return _database_connection + + +def _require_database_connection() -> DatabaseConnection: + database_connection = get_database_connection() + if database_connection is None: + raise RuntimeError("Database has not been initialized.") + return database_connection + + +def _require_active_database() -> ManagedSqliteDatabase: + active_database = _current_database.get() + if active_database is None: + raise RuntimeError( + "Database access requires a database.reader() or database.writer() context." 
+ ) + return active_database + + +database = RoutedSqliteDatabase() diff --git a/repub/job_runner.py b/repub/job_runner.py index 90a8d96..68b3be1 100644 --- a/repub/job_runner.py +++ b/repub/job_runner.py @@ -309,7 +309,7 @@ def main(argv: list[str] | None = None) -> int: def _load_job_source_config(*, db_path: str, job_id: int) -> JobSourceConfig: initialize_database(db_path) primary_key = getattr(Job, "_meta").primary_key - with database.connection_context(): + with database.reader(): job = ( Job.select(Job, Source) .join(Source) diff --git a/repub/jobs.py b/repub/jobs.py index a8ee9fd..2664d10 100644 --- a/repub/jobs.py +++ b/repub/jobs.py @@ -18,6 +18,7 @@ from apscheduler.triggers.cron import CronTrigger from peewee import IntegrityError from repub.config import feed_output_dir, feed_output_path +from repub.db import get_database_connection from repub.model import ( Job, JobExecution, @@ -173,7 +174,7 @@ class JobRuntime: self._started = False def sync_jobs(self) -> None: - with database.connection_context(): + with database.reader(): jobs = tuple(Job.select().where(Job.enabled == True)) # noqa: E712 desired_ids = set() @@ -216,35 +217,34 @@ class JobRuntime: return execution_id def _enqueue_job_run_locked(self, job_id: int) -> int | None: - with database.connection_context(): - with database.atomic(): - job = Job.get_or_none(id=job_id) - if job is None: - return None + with database.writer(): + job = Job.get_or_none(id=job_id) + if job is None: + return None + pending_execution = JobExecution.get_or_none( + (JobExecution.job == job) + & (JobExecution.running_status == JobExecutionStatus.PENDING) + ) + if pending_execution is not None: + return _execution_id(pending_execution) + + try: + execution = JobExecution.create( + job=job, + running_status=JobExecutionStatus.PENDING, + ) + except IntegrityError: pending_execution = JobExecution.get_or_none( (JobExecution.job == job) & (JobExecution.running_status == JobExecutionStatus.PENDING) ) - if pending_execution is not None: - return _execution_id(pending_execution) - - try: - execution = JobExecution.create( - job=job, - running_status=JobExecutionStatus.PENDING, - ) - except IntegrityError: - pending_execution = JobExecution.get_or_none( - (JobExecution.job == job) - & (JobExecution.running_status == JobExecutionStatus.PENDING) - ) - return ( - _execution_id(pending_execution) - if pending_execution is not None - else None - ) - return _execution_id(execution) + return ( + _execution_id(pending_execution) + if pending_execution is not None + else None + ) + return _execution_id(execution) def _start_queued_jobs(self) -> None: with self._run_lock: @@ -259,14 +259,13 @@ class JobRuntime: if claimed_execution is None: return - job = cast(Job, claimed_execution.job) self._start_worker_for_execution( - job_id=_job_id(job), - execution_id=_execution_id(claimed_execution), + job_id=claimed_execution[0], + execution_id=claimed_execution[1], ) - def _claim_next_pending_execution(self) -> JobExecution | None: - with database.connection_context(): + def _claim_next_pending_execution(self) -> tuple[int, int] | None: + with database.writer(): execution_primary_key = getattr(JobExecution, "_meta").primary_key pending_executions = tuple( JobExecution.select(JobExecution, Job) @@ -301,10 +300,21 @@ class JobRuntime: .execute() ) if claimed: - return JobExecution.get_by_id(_execution_id(execution)) + claimed_execution = ( + JobExecution.select(JobExecution, Job) + .join(Job) + .where(execution_primary_key == _execution_id(execution)) + .get() + ) + 
claimed_job = cast(Job, claimed_execution.job) + return (_job_id(claimed_job), _execution_id(claimed_execution)) return None def _start_worker_for_execution(self, *, job_id: int, execution_id: int) -> None: + database_connection = get_database_connection() + if database_connection is None: + raise RuntimeError("Database has not been initialized.") + artifacts = JobArtifacts.for_execution( log_dir=self.log_dir, job_id=job_id, execution_id=execution_id ) @@ -324,7 +334,7 @@ class JobRuntime: "--execution-id", str(execution_id), "--db-path", - str(database.database), + str(database_connection.db_path), "--out-dir", str(self.log_dir.parent), "--stats-path", @@ -342,15 +352,16 @@ class JobRuntime: ) def _max_concurrent_jobs_reached(self) -> bool: - return ( - JobExecution.select() - .where(JobExecution.running_status == JobExecutionStatus.RUNNING) - .count() - >= load_max_concurrent_jobs() - ) + with database.reader(): + return ( + JobExecution.select() + .where(JobExecution.running_status == JobExecutionStatus.RUNNING) + .count() + >= load_max_concurrent_jobs() + ) def request_execution_cancel(self, execution_id: int) -> bool: - with database.connection_context(): + with database.writer(): execution = JobExecution.get_or_none(id=execution_id) if execution is None: return False @@ -372,7 +383,7 @@ class JobRuntime: def cancel_queued_execution(self, execution_id: int) -> bool: with self._run_lock: - with database.connection_context(): + with database.writer(): execution_primary_key = getattr(JobExecution, "_meta").primary_key deleted = ( JobExecution.delete() @@ -392,7 +403,7 @@ class JobRuntime: def move_queued_execution(self, execution_id: int, *, direction: str) -> bool: offset = -1 if direction == "up" else 1 with self._run_lock: - with database.connection_context(): + with database.writer(): execution_primary_key = getattr(JobExecution, "_meta").primary_key queued_executions = tuple( JobExecution.select() @@ -425,59 +436,50 @@ class JobRuntime: cast(datetime | str, target_execution.created_at) ) - with database.atomic(): - if current_created_at == target_created_at: - adjusted_created_at = target_created_at + timedelta( - microseconds=-1 if offset < 0 else 1 + if current_created_at == target_created_at: + adjusted_created_at = target_created_at + timedelta( + microseconds=-1 if offset < 0 else 1 + ) + ( + JobExecution.update(created_at=adjusted_created_at) + .where( + execution_primary_key == _execution_id(current_execution) ) - ( - JobExecution.update(created_at=adjusted_created_at) - .where( - execution_primary_key - == _execution_id(current_execution) - ) - .execute() - ) - else: - ( - JobExecution.update(created_at=target_created_at) - .where( - execution_primary_key - == _execution_id(current_execution) - ) - .execute() - ) - ( - JobExecution.update(created_at=current_created_at) - .where( - execution_primary_key == _execution_id(target_execution) - ) - .execute() + .execute() + ) + else: + ( + JobExecution.update(created_at=target_created_at) + .where( + execution_primary_key == _execution_id(current_execution) ) + .execute() + ) + ( + JobExecution.update(created_at=current_created_at) + .where(execution_primary_key == _execution_id(target_execution)) + .execute() + ) self._trigger_refresh() return True def set_job_enabled(self, job_id: int, *, enabled: bool) -> bool: - with database.connection_context(): - with database.atomic(): - job = Job.get_or_none(id=job_id) - if job is None: - return False - job.enabled = enabled - job.save() - if not enabled: - ( - JobExecution.delete() - 
.where( - (JobExecution.job == job) - & ( - JobExecution.running_status - == JobExecutionStatus.PENDING - ) - ) - .execute() + with database.writer(): + job = Job.get_or_none(id=job_id) + if job is None: + return False + job.enabled = enabled + job.save() + if not enabled: + ( + JobExecution.delete() + .where( + (JobExecution.job == job) + & (JobExecution.running_status == JobExecutionStatus.PENDING) ) + .execute() + ) self.sync_jobs() self._trigger_refresh() return True @@ -493,7 +495,7 @@ class JobRuntime: continue self._apply_stats(worker) - with database.connection_context(): + with database.writer(): execution = JobExecution.get_by_id(execution_id) execution.ended_at = utc_now() execution.running_status = _worker_final_status( @@ -527,7 +529,7 @@ class JobRuntime: return stats = json.loads(lines[-1]) - with database.connection_context(): + with database.writer(): execution = JobExecution.get_by_id(worker.execution_id) execution.requests_count = int(stats.get("requests_count", 0)) execution.items_count = int(stats.get("items_count", 0)) @@ -544,7 +546,7 @@ class JobRuntime: self._trigger_refresh() def _enforce_graceful_stop(self, worker: RunningWorker) -> None: - with database.connection_context(): + with database.reader(): execution = JobExecution.get_by_id(worker.execution_id) if execution.stop_requested_at is None: return @@ -572,16 +574,17 @@ class JobRuntime: self._trigger_refresh() def _has_running_executions(self) -> bool: - return ( - JobExecution.select() - .where(JobExecution.running_status == JobExecutionStatus.RUNNING) - .exists() - ) + with database.reader(): + return ( + JobExecution.select() + .where(JobExecution.running_status == JobExecutionStatus.RUNNING) + .exists() + ) def _reconcile_stale_executions(self) -> None: live_workers = _find_live_workers() recovered_execution_ids: set[int] = set() - with database.connection_context(): + with database.writer(): execution_primary_key = getattr(JobExecution, "_meta").primary_key if live_workers: live_executions = tuple( @@ -669,7 +672,7 @@ def load_runs_view( reference_time = now or datetime.now(UTC) resolved_log_dir = Path(log_dir) sanitized_page_size = max(1, completed_page_size) - with database.connection_context(): + with database.reader(): execution_primary_key = getattr(JobExecution, "_meta").primary_key jobs = tuple(Job.select(Job, Source).join(Source).order_by(Source.name.asc())) queued_executions = tuple( @@ -718,48 +721,52 @@ def load_runs_view( _job_id(cast(Job, execution.job)): execution for execution in queued_executions } - return { - "running": tuple( - _project_running_execution( - execution, - resolved_log_dir, - reference_time, - queued_follow_up=queued_by_job.get(_job_id(cast(Job, execution.job))), - ) - for execution in running_executions - ), - "queued": tuple( - _project_queued_execution( - execution, - reference_time, - position=position, - total_count=len(queued_executions), - ) - for position, execution in enumerate(queued_executions, start=1) - ), - "upcoming": tuple( - _project_upcoming_job( - job, - running_by_job.get(job.id), - queued_by_job.get(job.id), - reference_time, - ) - for job in jobs - ), - "completed": tuple( - _project_completed_execution(execution, resolved_log_dir, reference_time) - for execution in completed_executions - ), - "completed_page": sanitized_completed_page, - "completed_page_size": sanitized_page_size, - "completed_total_count": completed_total_count, - "completed_total_pages": completed_total_pages, - } + return { + "running": tuple( + _project_running_execution( + 
execution, + resolved_log_dir, + reference_time, + queued_follow_up=queued_by_job.get( + _job_id(cast(Job, execution.job)) + ), + ) + for execution in running_executions + ), + "queued": tuple( + _project_queued_execution( + execution, + reference_time, + position=position, + total_count=len(queued_executions), + ) + for position, execution in enumerate(queued_executions, start=1) + ), + "upcoming": tuple( + _project_upcoming_job( + job, + running_by_job.get(job.id), + queued_by_job.get(job.id), + reference_time, + ) + for job in jobs + ), + "completed": tuple( + _project_completed_execution( + execution, resolved_log_dir, reference_time + ) + for execution in completed_executions + ), + "completed_page": sanitized_completed_page, + "completed_page_size": sanitized_page_size, + "completed_total_count": completed_total_count, + "completed_total_pages": completed_total_pages, + } def clear_completed_executions(*, log_dir: str | Path) -> int: resolved_log_dir = Path(log_dir) - with database.connection_context(): + with database.writer(): execution_primary_key = getattr(JobExecution, "_meta").primary_key completed_executions = tuple( JobExecution.select(JobExecution, Job) @@ -810,7 +817,7 @@ def load_dashboard_view( upcoming_by_job_id = { int(cast(int, job["job_id"])): job for job in runs_view["upcoming"] } - with database.connection_context(): + with database.reader(): jobs = tuple(Job.select(Job, Source).join(Source).order_by(Source.name.asc())) failed_last_day = ( JobExecution.select() @@ -820,80 +827,85 @@ def load_dashboard_view( ) .count() ) - - upcoming_ready = sum( - 1 for job in runs_view["upcoming"] if str(job["run_reason"]) == "Ready" - ) - footprint_bytes = _directory_size(output_dir) - return { - "running": runs_view["running"], - "queued": runs_view["queued"], - "source_feeds": tuple( - _project_source_feed( - cast(Job, job), - output_dir, - reference_time, - running_execution=running_by_job_id.get(_job_id(cast(Job, job))), - queued_execution=queued_by_job_id.get(_job_id(cast(Job, job))), - upcoming_job=upcoming_by_job_id.get(_job_id(cast(Job, job))), - ) - for job in jobs - ), - "snapshot": { - "running_now": str(len(runs_view["running"])), - "upcoming_today": str(upcoming_ready), - "failures_24h": str(failed_last_day), - "artifact_footprint": _format_bytes(footprint_bytes), - }, - } + upcoming_ready = sum( + 1 for job in runs_view["upcoming"] if str(job["run_reason"]) == "Ready" + ) + footprint_bytes = _directory_size(output_dir) + return { + "running": runs_view["running"], + "queued": runs_view["queued"], + "source_feeds": tuple( + _project_source_feed( + cast(Job, job), + output_dir, + reference_time, + running_execution=running_by_job_id.get(_job_id(cast(Job, job))), + queued_execution=queued_by_job_id.get(_job_id(cast(Job, job))), + upcoming_job=upcoming_by_job_id.get(_job_id(cast(Job, job))), + ) + for job in jobs + ), + "snapshot": { + "running_now": str(len(runs_view["running"])), + "upcoming_today": str(upcoming_ready), + "failures_24h": str(failed_last_day), + "artifact_footprint": _format_bytes(footprint_bytes), + }, + } def load_execution_log_view( *, log_dir: str | Path, job_id: int, execution_id: int ) -> ExecutionLogView: - with database.connection_context(): - execution = JobExecution.get_or_none(id=execution_id) - route = f"/job/{job_id}/execution/{execution_id}/logs" - if execution is None or _job_id(cast(Job, execution.job)) != job_id: - return ExecutionLogView( - job_id=job_id, - execution_id=execution_id, - title=f"Job {job_id} / execution {execution_id}", 
- description="Plain text log view routed through the app.", - status_label="Unavailable", - status_tone="failed", - log_text="", - error_message="Execution does not exist.", + with database.reader(): + execution_primary_key = getattr(JobExecution, "_meta").primary_key + execution = ( + JobExecution.select(JobExecution, Job) + .join(Job) + .where(execution_primary_key == execution_id) + .get_or_none() ) - artifacts = JobArtifacts.for_execution( - log_dir=Path(log_dir), - job_id=job_id, - execution_id=execution_id, - ) - if not artifacts.log_path.exists(): + if execution is None or _job_id(cast(Job, execution.job)) != job_id: + return ExecutionLogView( + job_id=job_id, + execution_id=execution_id, + title=f"Job {job_id} / execution {execution_id}", + description="Plain text log view routed through the app.", + status_label="Unavailable", + status_tone="failed", + log_text="", + error_message="Execution does not exist.", + ) + + artifacts = JobArtifacts.for_execution( + log_dir=Path(log_dir), + job_id=job_id, + execution_id=execution_id, + ) + if not artifacts.log_path.exists(): + return ExecutionLogView( + job_id=job_id, + execution_id=execution_id, + title=f"Job {job_id} / execution {execution_id}", + description="Plain text log view routed through the app.", + status_label=_execution_status_label(execution), + status_tone=_execution_status_tone(execution), + log_text="", + error_message="Log file has not been created yet.", + ) + return ExecutionLogView( job_id=job_id, execution_id=execution_id, title=f"Job {job_id} / execution {execution_id}", - description="Plain text log view routed through the app.", + description=f"Route: {route}", status_label=_execution_status_label(execution), status_tone=_execution_status_tone(execution), - log_text="", - error_message="Log file has not been created yet.", + log_text=artifacts.log_path.read_text(encoding="utf-8"), ) - return ExecutionLogView( - job_id=job_id, - execution_id=execution_id, - title=f"Job {job_id} / execution {execution_id}", - description=f"Route: {route}", - status_label=_execution_status_label(execution), - status_tone=_execution_status_tone(execution), - log_text=artifacts.log_path.read_text(encoding="utf-8"), - ) - def _job_trigger(job: Job) -> CronTrigger: expression = " ".join( diff --git a/repub/model.py b/repub/model.py index 6ee5ae9..33b3f11 100644 --- a/repub/model.py +++ b/repub/model.py @@ -1,12 +1,8 @@ from __future__ import annotations import json -import os from datetime import UTC, datetime from enum import IntEnum -from importlib import resources -from importlib.resources.abc import Traversable -from pathlib import Path from typing import Any from peewee import ( @@ -16,29 +12,24 @@ from peewee import ( ForeignKeyField, IntegerField, Model, - SqliteDatabase, TextField, ) -from playhouse.migrate import SchemaMigrator, migrate -DEFAULT_DB_PATH = Path("republisher.db") -DATABASE_PRAGMAS = { - "busy_timeout": 5000, - "cache_size": 15625, - "foreign_keys": 1, - "journal_mode": "wal", - "page_size": 4096, - "synchronous": "normal", - "temp_store": "memory", -} -SCHEMA_GLOB = "*.sql" +from repub import db as db_module + +DEFAULT_DB_PATH = db_module.DEFAULT_DB_PATH +DATABASE_PRAGMAS = db_module.DATABASE_PRAGMAS +SCHEMA_GLOB = db_module.SCHEMA_GLOB +database = db_module.database +initialize_database = db_module.initialize_database +resolve_database_path = db_module.resolve_database_path +schema_paths = db_module.schema_paths + MAX_CONCURRENT_JOBS_SETTING_KEY = "max_concurrent_jobs" DEFAULT_MAX_CONCURRENT_JOBS = 1 
FEED_URL_SETTING_KEY = "feed_url" DEFAULT_FEED_URL = "" -database = SqliteDatabase(None, pragmas=DATABASE_PRAGMAS) - class JobExecutionStatus(IntEnum): PENDING = 0 @@ -52,101 +43,24 @@ def utc_now() -> datetime: return datetime.now(UTC) -def resolve_database_path(db_path: str | Path | None = None) -> Path: - raw_value = ( - os.environ.get("REPUBLISHER_DB_PATH", DEFAULT_DB_PATH) - if db_path is None - else db_path - ) - raw_path = Path(raw_value) - return raw_path.expanduser().resolve() - - -def schema_paths() -> tuple[Traversable, ...]: - schema_dir = resources.files("repub").joinpath("sql") - return tuple( - sorted( - (path for path in schema_dir.iterdir() if path.name.endswith(".sql")), - key=lambda path: path.name, - ) - ) - - -def initialize_database(db_path: str | Path | None = None) -> Path: - resolved_path = resolve_database_path(db_path) - resolved_path.parent.mkdir(parents=True, exist_ok=True) - - if not database.is_closed(): - database.close() - - database.init(str(resolved_path), pragmas=DATABASE_PRAGMAS) - database.connect(reuse_if_open=True) - try: - for path in schema_paths(): - database.connection().executescript(path.read_text(encoding="utf-8")) - _run_legacy_migrations() - finally: - database.close() - - return resolved_path - - -def _run_legacy_migrations() -> None: - job_columns = {column.name for column in database.get_columns("job")} - operations = [] - migrator = SchemaMigrator.from_database(database) - if "convert_images" not in job_columns: - operations.extend( - ( - migrator.add_column( - "job", - "convert_images", - BooleanField( - default=True, - constraints=[Check("convert_images IN (0, 1)")], - ), - ), - migrator.add_column_default("job", "convert_images", 1), - ) - ) - if "convert_video" not in job_columns: - operations.extend( - ( - migrator.add_column( - "job", - "convert_video", - BooleanField( - default=True, - constraints=[Check("convert_video IN (0, 1)")], - ), - ), - migrator.add_column_default("job", "convert_video", 1), - ) - ) - if operations: - with database.atomic(): - migrate(*operations) - - def source_slug_exists(slug: str) -> bool: - with database.connection_context(): + with database.reader(): return Source.select().where(Source.slug == slug).exists() def save_setting(key: str, value: Any) -> None: payload = json.dumps(value, sort_keys=True) - with database.connection_context(): - with database.atomic(): - setting = AppSetting.get_or_none(AppSetting.key == key) - if setting is None: - AppSetting.create(key=key, value=payload) - return - setting.value = payload - setting.save() + with database.writer(): + setting = AppSetting.get_or_none(AppSetting.key == key) + if setting is None: + AppSetting.create(key=key, value=payload) + return + setting.value = payload + setting.save() def load_setting(key: str, default: Any) -> Any: - with database.connection_context(): + with database.reader(): setting = AppSetting.get_or_none(AppSetting.key == key) if setting is None: return default @@ -177,8 +91,14 @@ def load_settings_form() -> dict[str, object]: } +def load_job_enabled(job_id: int) -> bool | None: + with database.reader(): + job = Job.get_or_none(id=job_id) + return None if job is None else job.enabled + + def load_source_form(slug: str) -> dict[str, object] | None: - with database.connection_context(): + with database.reader(): source = Source.get_or_none(Source.slug == slug) if source is None: return None @@ -259,46 +179,45 @@ def create_source( include_content: bool = True, content_format: str = "", ) -> Source: - with 
database.connection_context(): - with database.atomic(): - source = Source.create( - name=name, - slug=slug, - source_type=source_type, - notes=notes, - ) - if source_type == "feed": - SourceFeed.create( - source=source, - feed_url=feed_url, - ) - else: - SourcePangea.create( - source=source, - domain=pangea_domain, - category_name=pangea_category, - content_type=content_type, - only_newest=only_newest, - max_articles=max_articles, - oldest_article=oldest_article, - include_authors=include_authors, - exclude_media=exclude_media, - include_content=include_content, - content_format=content_format, - ) - Job.create( + with database.writer(): + source = Source.create( + name=name, + slug=slug, + source_type=source_type, + notes=notes, + ) + if source_type == "feed": + SourceFeed.create( source=source, - enabled=enabled, - convert_images=convert_images, - convert_video=convert_video, - spider_arguments=spider_arguments, - cron_minute=cron_minute, - cron_hour=cron_hour, - cron_day_of_month=cron_day_of_month, - cron_day_of_week=cron_day_of_week, - cron_month=cron_month, + feed_url=feed_url, ) - return source + else: + SourcePangea.create( + source=source, + domain=pangea_domain, + category_name=pangea_category, + content_type=content_type, + only_newest=only_newest, + max_articles=max_articles, + oldest_article=oldest_article, + include_authors=include_authors, + exclude_media=exclude_media, + include_content=include_content, + content_format=content_format, + ) + Job.create( + source=source, + enabled=enabled, + convert_images=convert_images, + convert_video=convert_video, + spider_arguments=spider_arguments, + cron_minute=cron_minute, + cron_hour=cron_hour, + cron_day_of_month=cron_day_of_month, + cron_day_of_week=cron_day_of_week, + cron_month=cron_month, + ) + return source def update_source( @@ -329,91 +248,88 @@ def update_source( include_content: bool = True, content_format: str = "", ) -> Source | None: - with database.connection_context(): - with database.atomic(): - source = Source.get_or_none(Source.slug == source_slug) - if source is None: - return None + with database.writer(): + source = Source.get_or_none(Source.slug == source_slug) + if source is None: + return None - source.name = name - source.notes = notes - source.source_type = source_type - source.save() + source.name = name + source.notes = notes + source.source_type = source_type + source.save() - job = Job.get(Job.source == source) - job.enabled = enabled - job.convert_images = convert_images - job.convert_video = convert_video - job.spider_arguments = spider_arguments - job.cron_minute = cron_minute - job.cron_hour = cron_hour - job.cron_day_of_month = cron_day_of_month - job.cron_day_of_week = cron_day_of_week - job.cron_month = cron_month - job.save() + job = Job.get(Job.source == source) + job.enabled = enabled + job.convert_images = convert_images + job.convert_video = convert_video + job.spider_arguments = spider_arguments + job.cron_minute = cron_minute + job.cron_hour = cron_hour + job.cron_day_of_month = cron_day_of_month + job.cron_day_of_week = cron_day_of_week + job.cron_month = cron_month + job.save() - if source_type == "feed": - SourcePangea.delete().where(SourcePangea.source == source).execute() - feed = SourceFeed.get_or_none(SourceFeed.source == source) - if feed is None: - SourceFeed.create(source=source, feed_url=feed_url) - else: - feed.feed_url = feed_url - feed.save() + if source_type == "feed": + SourcePangea.delete().where(SourcePangea.source == source).execute() + feed = 
SourceFeed.get_or_none(SourceFeed.source == source) + if feed is None: + SourceFeed.create(source=source, feed_url=feed_url) else: - SourceFeed.delete().where(SourceFeed.source == source).execute() - pangea = SourcePangea.get_or_none(SourcePangea.source == source) - if pangea is None: - SourcePangea.create( - source=source, - domain=pangea_domain, - category_name=pangea_category, - content_type=content_type, - only_newest=only_newest, - max_articles=max_articles, - oldest_article=oldest_article, - include_authors=include_authors, - exclude_media=exclude_media, - include_content=include_content, - content_format=content_format, - ) - else: - pangea.domain = pangea_domain - pangea.category_name = pangea_category - pangea.content_type = content_type - pangea.only_newest = only_newest - pangea.max_articles = max_articles - pangea.oldest_article = oldest_article - pangea.include_authors = include_authors - pangea.exclude_media = exclude_media - pangea.include_content = include_content - pangea.content_format = content_format - pangea.save() + feed.feed_url = feed_url + feed.save() + else: + SourceFeed.delete().where(SourceFeed.source == source).execute() + pangea = SourcePangea.get_or_none(SourcePangea.source == source) + if pangea is None: + SourcePangea.create( + source=source, + domain=pangea_domain, + category_name=pangea_category, + content_type=content_type, + only_newest=only_newest, + max_articles=max_articles, + oldest_article=oldest_article, + include_authors=include_authors, + exclude_media=exclude_media, + include_content=include_content, + content_format=content_format, + ) + else: + pangea.domain = pangea_domain + pangea.category_name = pangea_category + pangea.content_type = content_type + pangea.only_newest = only_newest + pangea.max_articles = max_articles + pangea.oldest_article = oldest_article + pangea.include_authors = include_authors + pangea.exclude_media = exclude_media + pangea.include_content = include_content + pangea.content_format = content_format + pangea.save() - return source + return source def delete_job_source(job_id: int) -> bool: - with database.connection_context(): - with database.atomic(): - job = Job.get_or_none(id=job_id) - if job is None: - return False - source = Source.get_by_id(job.source_id) - return source.delete_instance() > 0 + with database.writer(): + job = Job.get_or_none(id=job_id) + if job is None: + return False + source = Source.get_by_id(job.source_id) + return source.delete_instance() > 0 def delete_source(slug: str) -> bool: - with database.connection_context(): - with database.atomic(): - source = Source.get_or_none(Source.slug == slug) - if source is None: - return False - return source.delete_instance() > 0 + with database.writer(): + source = Source.get_or_none(Source.slug == slug) + if source is None: + return False + return source.delete_instance() > 0 def load_sources() -> tuple[dict[str, object], ...]: - with database.connection_context(): + with database.reader(): sources = tuple(Source.select().order_by(Source.created_at.desc())) source_ids = tuple(int(source.get_id()) for source in sources) if not source_ids: diff --git a/repub/web.py b/repub/web.py index 1c50228..372e121 100644 --- a/repub/web.py +++ b/repub/web.py @@ -27,11 +27,11 @@ from repub.jobs import ( load_runs_view, ) from repub.model import ( - Job, create_source, delete_job_source, delete_source, initialize_database, + load_job_enabled, load_settings_form, load_source_form, load_sources, @@ -329,9 +329,9 @@ def create_app(*, dev_mode: bool = False) -> Quart: 
@app.post("/actions/jobs//toggle-enabled") async def toggle_job_enabled_action(job_id: int) -> Response: - job = Job.get_or_none(id=job_id) - if job is not None: - get_job_runtime(app).set_job_enabled(job_id, enabled=not job.enabled) + enabled = load_job_enabled(job_id) + if enabled is not None: + get_job_runtime(app).set_job_enabled(job_id, enabled=not enabled) trigger_refresh(app) return Response(status=204) diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 0000000..b0c476f --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import threading +import time +from pathlib import Path + +import pytest +from peewee import InterfaceError + +from repub.db import get_database_connection +from repub.model import AppSetting, database, initialize_database + + +def test_queries_require_managed_database_context(tmp_path: Path) -> None: + initialize_database(tmp_path / "managed-context.db") + + with pytest.raises( + RuntimeError, match="database.reader\\(\\)|database.writer\\(\\)" + ): + AppSetting.select().count() + + +def test_writer_and_reader_contexts_allow_persisted_queries(tmp_path: Path) -> None: + initialize_database(tmp_path / "reader-writer.db") + + with database.writer(): + AppSetting.create(key="feed_url", value='"https://mirror.example"') + + with database.reader(): + setting = AppSetting.get(AppSetting.key == "feed_url") + + assert setting.value == '"https://mirror.example"' + + +def test_managed_connections_disable_peewee_autoconnect(tmp_path: Path) -> None: + initialize_database(tmp_path / "autoconnect-disabled.db") + + connection = get_database_connection() + + assert connection is not None + assert connection.writer_db.autoconnect is False + assert all(reader_db.autoconnect is False for reader_db in connection.reader_dbs) + + reader_db = connection.reader_dbs[0] + reader_db.close() + + with pytest.raises(InterfaceError, match="database connection not opened"): + reader_db.execute_sql("SELECT 1") + + +def test_database_connection_initializes_four_readers_and_one_writer( + tmp_path: Path, +) -> None: + initialize_database(tmp_path / "pool-shape.db") + + connection = get_database_connection() + + assert connection is not None + assert connection.pool_size == 4 + assert len(connection.reader_dbs) == 4 + assert connection._reader_pool.qsize() == 4 + assert connection.writer_db is not None + + +def test_reader_lease_is_returned_to_the_pool_after_use(tmp_path: Path) -> None: + initialize_database(tmp_path / "reader-lease.db") + + connection = get_database_connection() + + assert connection is not None + initial_size = connection._reader_pool.qsize() + + with database.reader(): + assert connection._reader_pool.qsize() == initial_size - 1 + + assert connection._reader_pool.qsize() == initial_size + + +def test_writer_contexts_serialize_through_the_single_writer(tmp_path: Path) -> None: + initialize_database(tmp_path / "single-writer.db") + + events: list[str] = [] + entered_first_writer = threading.Event() + allow_first_writer_to_exit = threading.Event() + + def first_writer() -> None: + with database.writer(): + events.append("first-entered") + entered_first_writer.set() + allow_first_writer_to_exit.wait(timeout=1) + events.append("first-exiting") + + def second_writer() -> None: + entered_first_writer.wait(timeout=1) + with database.writer(): + events.append("second-entered") + + first_thread = threading.Thread(target=first_writer) + second_thread = threading.Thread(target=second_writer) + first_thread.start() + 
second_thread.start() + + assert entered_first_writer.wait(timeout=1) is True + time.sleep(0.05) + assert events == ["first-entered"] + + allow_first_writer_to_exit.set() + first_thread.join(timeout=1) + second_thread.join(timeout=1) + + assert events == ["first-entered", "first-exiting", "second-entered"] diff --git a/tests/test_jobs.py b/tests/test_jobs.py index db2b3a3..7e2a447 100644 --- a/tests/test_jobs.py +++ b/tests/test_jobs.py @@ -9,6 +9,7 @@ from repub.model import ( JobExecution, JobExecutionStatus, create_source, + database, initialize_database, ) @@ -31,15 +32,16 @@ def test_load_runs_view_humanizes_completed_execution_summary_bytes( cron_month="*", feed_url="https://example.com/completed.xml", ) - job = Job.get(Job.source == source) - JobExecution.create( - job=job, - running_status=JobExecutionStatus.SUCCEEDED, - ended_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), - requests_count=14, - items_count=11, - bytes_count=16_410_269, - ) + with database.writer(): + job = Job.get(Job.source == source) + JobExecution.create( + job=job, + running_status=JobExecutionStatus.SUCCEEDED, + ended_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), + requests_count=14, + items_count=11, + bytes_count=16_410_269, + ) view = load_runs_view( log_dir=tmp_path / "out" / "logs", @@ -67,13 +69,14 @@ def test_load_runs_view_projects_completed_execution_duration( cron_month="*", feed_url="https://example.com/completed.xml", ) - job = Job.get(Job.source == source) - JobExecution.create( - job=job, - running_status=JobExecutionStatus.SUCCEEDED, - started_at=datetime(2026, 3, 30, 11, 59, 12, tzinfo=UTC), - ended_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), - ) + with database.writer(): + job = Job.get(Job.source == source) + JobExecution.create( + job=job, + running_status=JobExecutionStatus.SUCCEEDED, + started_at=datetime(2026, 3, 30, 11, 59, 12, tzinfo=UTC), + ended_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), + ) view = load_runs_view( log_dir=tmp_path / "out" / "logs", @@ -101,15 +104,16 @@ def test_load_runs_view_humanizes_running_execution_summary_bytes( cron_month="*", feed_url="https://example.com/running.xml", ) - job = Job.get(Job.source == source) - JobExecution.create( - job=job, - running_status=JobExecutionStatus.RUNNING, - started_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), - requests_count=14, - items_count=11, - bytes_count=1_536, - ) + with database.writer(): + job = Job.get(Job.source == source) + JobExecution.create( + job=job, + running_status=JobExecutionStatus.RUNNING, + started_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), + requests_count=14, + items_count=11, + bytes_count=1_536, + ) view = load_runs_view( log_dir=tmp_path / "out" / "logs", @@ -137,12 +141,13 @@ def test_load_runs_view_projects_running_execution_duration( cron_month="*", feed_url="https://example.com/running.xml", ) - job = Job.get(Job.source == source) - JobExecution.create( - job=job, - running_status=JobExecutionStatus.RUNNING, - started_at=datetime(2026, 3, 30, 11, 59, 12, tzinfo=UTC), - ) + with database.writer(): + job = Job.get(Job.source == source) + JobExecution.create( + job=job, + running_status=JobExecutionStatus.RUNNING, + started_at=datetime(2026, 3, 30, 11, 59, 12, tzinfo=UTC), + ) view = load_runs_view( log_dir=tmp_path / "out" / "logs", @@ -184,21 +189,22 @@ def test_load_runs_view_projects_queued_executions_in_fifo_order( cron_month="*", feed_url="https://example.com/second.xml", ) - first_job = Job.get(Job.source == first_source) - second_job = Job.get(Job.source == second_source) reference_time = 
datetime(2026, 3, 30, 12, 30, tzinfo=UTC) first_created_at = reference_time - timedelta(minutes=7) second_created_at = reference_time - timedelta(minutes=3) - first_execution = JobExecution.create( - job=first_job, - created_at=first_created_at, - running_status=JobExecutionStatus.PENDING, - ) - second_execution = JobExecution.create( - job=second_job, - created_at=second_created_at, - running_status=JobExecutionStatus.PENDING, - ) + with database.writer(): + first_job = Job.get(Job.source == first_source) + second_job = Job.get(Job.source == second_source) + first_execution = JobExecution.create( + job=first_job, + created_at=first_created_at, + running_status=JobExecutionStatus.PENDING, + ) + second_execution = JobExecution.create( + job=second_job, + created_at=second_created_at, + running_status=JobExecutionStatus.PENDING, + ) view = load_runs_view( log_dir=tmp_path / "out" / "logs", @@ -258,12 +264,13 @@ def test_load_runs_view_keeps_queued_jobs_in_scheduled_jobs( cron_month="*", feed_url="https://example.com/scheduled.xml", ) - queued_job = Job.get(Job.source == queued_source) - Job.get(Job.source == scheduled_source) - JobExecution.create( - job=queued_job, - running_status=JobExecutionStatus.PENDING, - ) + with database.writer(): + queued_job = Job.get(Job.source == queued_source) + Job.get(Job.source == scheduled_source) + JobExecution.create( + job=queued_job, + running_status=JobExecutionStatus.PENDING, + ) view = load_runs_view( log_dir=tmp_path / "out" / "logs", @@ -299,17 +306,18 @@ def test_load_runs_view_running_row_targets_queued_follow_up_cancel( cron_month="*", feed_url="https://example.com/running.xml", ) - job = Job.get(Job.source == source) - JobExecution.create( - job=job, - started_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), - running_status=JobExecutionStatus.RUNNING, - ) - pending_execution = JobExecution.create( - job=job, - created_at=datetime(2026, 3, 30, 12, 5, tzinfo=UTC), - running_status=JobExecutionStatus.PENDING, - ) + with database.writer(): + job = Job.get(Job.source == source) + JobExecution.create( + job=job, + started_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), + running_status=JobExecutionStatus.RUNNING, + ) + pending_execution = JobExecution.create( + job=job, + created_at=datetime(2026, 3, 30, 12, 5, tzinfo=UTC), + running_status=JobExecutionStatus.PENDING, + ) view = load_runs_view( log_dir=tmp_path / "out" / "logs", @@ -341,14 +349,15 @@ def test_load_runs_view_paginates_completed_executions_after_20_rows( cron_month="*", feed_url="https://example.com/completed.xml", ) - job = Job.get(Job.source == source) - base_time = datetime(2026, 3, 30, 12, 0, tzinfo=UTC) - for offset in range(21): - JobExecution.create( - job=job, - running_status=JobExecutionStatus.SUCCEEDED, - ended_at=base_time - timedelta(minutes=offset), - ) + with database.writer(): + job = Job.get(Job.source == source) + base_time = datetime(2026, 3, 30, 12, 0, tzinfo=UTC) + for offset in range(21): + JobExecution.create( + job=job, + running_status=JobExecutionStatus.SUCCEEDED, + ended_at=base_time - timedelta(minutes=offset), + ) first_page = load_runs_view( log_dir=tmp_path / "out" / "logs", diff --git a/tests/test_model.py b/tests/test_model.py index 450a654..46d1f64 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -112,8 +112,7 @@ def test_initialize_database_configures_sqlite_pragmas(tmp_path: Path) -> None: initialize_database(db_path) - database.connect(reuse_if_open=True) - try: + with database.reader_conn(): pragma_values = { "cache_size": 
database.execute_sql("PRAGMA cache_size").fetchone()[0], "page_size": database.execute_sql("PRAGMA page_size").fetchone()[0], @@ -132,8 +131,6 @@ def test_initialize_database_configures_sqlite_pragmas(tmp_path: Path) -> None: "foreign_keys": 1, "busy_timeout": 5000, } - finally: - database.close() def test_initialize_database_creates_scheduler_and_execution_indexes( @@ -208,34 +205,35 @@ def test_initialize_database_creates_run_queue_indexes(tmp_path: Path) -> None: def test_job_table_allows_exactly_one_job_per_source(tmp_path: Path) -> None: initialize_database(tmp_path / "jobs.db") - source = Source.create( - name="Guardian feed mirror", - slug="guardian-feed", - source_type="feed", - ) - Job.create( - source=source, - enabled=True, - spider_arguments="", - cron_minute="15", - cron_hour="*", - cron_day_of_month="*", - cron_day_of_week="*", - cron_month="*", - ) - - with pytest.raises(IntegrityError): + with database.writer(): + source = Source.create( + name="Guardian feed mirror", + slug="guardian-feed", + source_type="feed", + ) Job.create( source=source, enabled=True, - spider_arguments="language=en", - cron_minute="30", + spider_arguments="", + cron_minute="15", cron_hour="*", cron_day_of_month="*", cron_day_of_week="*", cron_month="*", ) + with pytest.raises(IntegrityError): + Job.create( + source=source, + enabled=True, + spider_arguments="language=en", + cron_minute="30", + cron_hour="*", + cron_day_of_month="*", + cron_day_of_week="*", + cron_month="*", + ) + def test_load_max_concurrent_jobs_defaults_to_one(tmp_path: Path) -> None: initialize_database(tmp_path / "settings-defaults.db") @@ -248,7 +246,8 @@ def test_save_setting_persists_json_value(tmp_path: Path) -> None: save_setting("max_concurrent_jobs", 4) - row = AppSetting.get(AppSetting.key == "max_concurrent_jobs") + with database.reader(): + row = AppSetting.get(AppSetting.key == "max_concurrent_jobs") assert row.value == "4" assert load_max_concurrent_jobs() == 4 diff --git a/tests/test_scheduler_runtime.py b/tests/test_scheduler_runtime.py index c53b2da..def747c 100644 --- a/tests/test_scheduler_runtime.py +++ b/tests/test_scheduler_runtime.py @@ -19,6 +19,7 @@ from repub.model import ( JobExecutionStatus, Source, create_source, + database, initialize_database, save_setting, ) @@ -29,6 +30,16 @@ FIXTURE_FEED_PATH = ( ).resolve() +def _db_reader(callable_): + with database.reader(): + return callable_() + + +def _db_writer(callable_): + with database.writer(): + return callable_() + + def initialize_runtime_database(db_path: Path) -> None: initialize_database(db_path) save_setting("feed_url", "http://localhost:8080") @@ -64,8 +75,9 @@ def test_job_runtime_syncs_enabled_jobs_into_apscheduler(tmp_path: Path) -> None cron_month="*", feed_url="https://example.com/disabled.xml", ) - enabled_job = Job.get(Job.source == enabled_source) - disabled_job = Job.get(Job.source == disabled_source) + with database.reader(): + enabled_job = Job.get(Job.source == enabled_source) + disabled_job = Job.get(Job.source == disabled_source) runtime = JobRuntime(log_dir=tmp_path / "out" / "logs") try: @@ -77,8 +89,10 @@ def test_job_runtime_syncs_enabled_jobs_into_apscheduler(tmp_path: Path) -> None assert f"job-{enabled_job.id}" in scheduled_ids assert f"job-{disabled_job.id}" not in scheduled_ids - enabled_job.enabled = False - enabled_job.save() + with database.writer(): + enabled_job = Job.get_by_id(enabled_job.id) + enabled_job.enabled = False + enabled_job.save() runtime.sync_jobs() scheduled_ids = {job.id for job in 
runtime.scheduler.get_jobs()} @@ -105,7 +119,8 @@ def test_job_runtime_run_now_writes_log_and_stats_and_marks_success( cron_month="*", feed_url=FIXTURE_FEED_PATH.as_uri(), ) - job = Job.get(Job.source == source) + with database.reader(): + job = Job.get(Job.source == source) runtime = JobRuntime(log_dir=tmp_path / "out" / "logs") try: @@ -178,8 +193,9 @@ def test_job_runtime_respects_max_concurrent_jobs_setting(tmp_path: Path) -> Non cron_month="*", feed_url=feed_url, ) - first_job = Job.get(Job.source == first_source) - second_job = Job.get(Job.source == second_source) + with database.reader(): + first_job = Job.get(Job.source == first_source) + second_job = Job.get(Job.source == second_source) runtime = JobRuntime(log_dir=log_dir) try: @@ -197,16 +213,20 @@ def test_job_runtime_respects_max_concurrent_jobs_setting(tmp_path: Path) -> Non JobExecutionStatus.PENDING, ) assert ( - JobExecution.select() - .where(JobExecution.running_status == JobExecutionStatus.RUNNING) - .count() + _db_reader( + lambda: JobExecution.select() + .where(JobExecution.running_status == JobExecutionStatus.RUNNING) + .count() + ) == 1 ) assert second_execution.started_at is None assert ( - JobExecution.select() - .where(JobExecution.running_status == JobExecutionStatus.PENDING) - .count() + _db_reader( + lambda: JobExecution.select() + .where(JobExecution.running_status == JobExecutionStatus.PENDING) + .count() + ) == 1 ) runtime.request_execution_cancel(first_execution_id) @@ -253,8 +273,9 @@ def test_job_runtime_starts_queued_execution_after_capacity_opens( cron_month="*", feed_url=FIXTURE_FEED_PATH.as_uri(), ) - first_job = Job.get(Job.source == first_source) - second_job = Job.get(Job.source == second_source) + with database.reader(): + first_job = Job.get(Job.source == first_source) + second_job = Job.get(Job.source == second_source) runtime = JobRuntime(log_dir=log_dir) try: @@ -314,8 +335,9 @@ def test_job_runtime_deduplicates_manual_queue_requests(tmp_path: Path) -> None: cron_month="*", feed_url="https://example.com/queued.xml", ) - blocking_job = Job.get(Job.source == blocking_source) - queued_job = Job.get(Job.source == queued_source) + with database.reader(): + blocking_job = Job.get(Job.source == blocking_source) + queued_job = Job.get(Job.source == queued_source) runtime = JobRuntime(log_dir=log_dir) try: @@ -332,12 +354,14 @@ def test_job_runtime_deduplicates_manual_queue_requests(tmp_path: Path) -> None: assert first_pending_id is not None assert second_pending_id == first_pending_id assert ( - JobExecution.select() - .where( - (JobExecution.job == queued_job) - & (JobExecution.running_status == JobExecutionStatus.PENDING) + _db_reader( + lambda: JobExecution.select() + .where( + (JobExecution.job == queued_job) + & (JobExecution.running_status == JobExecutionStatus.PENDING) + ) + .count() ) - .count() == 1 ) finally: @@ -367,7 +391,8 @@ def test_job_runtime_allows_one_running_and_one_pending_per_job( cron_month="*", feed_url=feed_url, ) - job = Job.get(Job.source == source) + with database.reader(): + job = Job.get(Job.source == source) runtime = JobRuntime(log_dir=log_dir) try: @@ -383,17 +408,21 @@ def test_job_runtime_allows_one_running_and_one_pending_per_job( assert pending_execution_id is not None assert duplicate_pending_id == pending_execution_id assert ( - JobExecution.select() - .where(JobExecution.job == job) - .where(JobExecution.running_status == JobExecutionStatus.RUNNING) - .count() + _db_reader( + lambda: JobExecution.select() + .where(JobExecution.job == job) + 
.where(JobExecution.running_status == JobExecutionStatus.RUNNING) + .count() + ) == 1 ) assert ( - JobExecution.select() - .where(JobExecution.job == job) - .where(JobExecution.running_status == JobExecutionStatus.PENDING) - .count() + _db_reader( + lambda: JobExecution.select() + .where(JobExecution.job == job) + .where(JobExecution.running_status == JobExecutionStatus.PENDING) + .count() + ) == 1 ) finally: @@ -420,11 +449,12 @@ def test_job_runtime_start_drains_pending_rows_created_before_start( cron_month="*", feed_url=FIXTURE_FEED_PATH.as_uri(), ) - job = Job.get(Job.source == source) - pending_execution = JobExecution.create( - job=job, - running_status=JobExecutionStatus.PENDING, - ) + with database.writer(): + job = Job.get(Job.source == source) + pending_execution = JobExecution.create( + job=job, + running_status=JobExecutionStatus.PENDING, + ) runtime = JobRuntime(log_dir=log_dir) try: @@ -477,18 +507,23 @@ def test_job_runtime_scheduled_runs_use_the_persistent_queue( cron_month="*", feed_url="https://example.com/second-scheduled.xml", ) - first_job = Job.get(Job.source == first_source) - second_job = Job.get(Job.source == second_source) + with database.reader(): + first_job = Job.get(Job.source == first_source) + second_job = Job.get(Job.source == second_source) runtime = JobRuntime(log_dir=log_dir) try: runtime.start() runtime.run_scheduled_job(first_job.id) - first_execution = JobExecution.get(JobExecution.job == first_job) + first_execution = _db_reader( + lambda: JobExecution.get(JobExecution.job == first_job) + ) _wait_for_running_execution(int(first_execution.get_id())) runtime.run_scheduled_job(second_job.id) - second_execution = JobExecution.get(JobExecution.job == second_job) + second_execution = _db_reader( + lambda: JobExecution.get(JobExecution.job == second_job) + ) assert second_execution.running_status == JobExecutionStatus.PENDING assert second_execution.started_at is None @@ -519,7 +554,8 @@ def test_job_runtime_cancel_pending_follow_up_keeps_running_worker_alive( cron_month="*", feed_url=feed_url, ) - job = Job.get(Job.source == source) + with database.reader(): + job = Job.get(Job.source == source) runtime = JobRuntime(log_dir=log_dir) try: @@ -533,9 +569,14 @@ def test_job_runtime_cancel_pending_follow_up_keeps_running_worker_alive( _wait_for_execution_status(pending_execution_id, JobExecutionStatus.PENDING) assert runtime.cancel_queued_execution(pending_execution_id) is True - assert JobExecution.get_or_none(id=pending_execution_id) is None assert ( - JobExecution.get_by_id(running_execution_id).running_status + _db_reader(lambda: JobExecution.get_or_none(id=pending_execution_id)) + is None + ) + assert ( + _db_reader( + lambda: JobExecution.get_by_id(running_execution_id).running_status + ) == JobExecutionStatus.RUNNING ) finally: @@ -559,7 +600,8 @@ def test_job_runtime_cancel_marks_execution_canceled(tmp_path: Path) -> None: cron_month="*", feed_url=feed_url, ) - job = Job.get(Job.source == source) + with database.reader(): + job = Job.get(Job.source == source) runtime = JobRuntime(log_dir=tmp_path / "out" / "logs") try: @@ -602,12 +644,13 @@ def test_job_runtime_start_reconciles_stale_running_execution(tmp_path: Path) -> cron_month="*", feed_url="https://example.com/stale.xml", ) - job = Job.get(Job.source == source) - execution = JobExecution.create( - job=job, - started_at="2026-03-30 12:30:00+00:00", - running_status=JobExecutionStatus.RUNNING, - ) + with database.writer(): + job = Job.get(Job.source == source) + execution = JobExecution.create( + 
job=job, + started_at="2026-03-30 12:30:00+00:00", + running_status=JobExecutionStatus.RUNNING, + ) artifacts = JobArtifacts.for_execution( log_dir=tmp_path / "out" / "logs", job_id=job.id, @@ -622,7 +665,9 @@ def test_job_runtime_start_reconciles_stale_running_execution(tmp_path: Path) -> runtime = JobRuntime(log_dir=tmp_path / "out" / "logs") try: runtime.start() - reconciled_execution = JobExecution.get_by_id(execution.get_id()) + reconciled_execution = _db_reader( + lambda: JobExecution.get_by_id(execution.get_id()) + ) assert reconciled_execution.running_status == JobExecutionStatus.FAILED assert reconciled_execution.ended_at is not None @@ -649,12 +694,13 @@ def test_job_runtime_publishes_refresh_while_jobs_are_running(tmp_path: Path) -> cron_month="*", feed_url="https://example.com/running.xml", ) - job = Job.get(Job.source == source) - JobExecution.create( - job=job, - started_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), - running_status=JobExecutionStatus.RUNNING, - ) + with database.writer(): + job = Job.get(Job.source == source) + JobExecution.create( + job=job, + started_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), + running_status=JobExecutionStatus.RUNNING, + ) events: list[object] = [] runtime = JobRuntime( @@ -688,12 +734,13 @@ def test_job_runtime_start_reattaches_live_worker_after_app_restart( cron_month="*", feed_url=feed_url, ) - job = Job.get(Job.source == source) - execution = JobExecution.create( - job=job, - started_at=datetime.now(UTC), - running_status=JobExecutionStatus.RUNNING, - ) + with database.writer(): + job = Job.get(Job.source == source) + execution = JobExecution.create( + job=job, + started_at=datetime.now(UTC), + running_status=JobExecutionStatus.RUNNING, + ) artifacts = JobArtifacts.for_execution( log_dir=log_dir, job_id=job.id, @@ -728,7 +775,9 @@ def test_job_runtime_start_reattaches_live_worker_after_app_restart( time.sleep(0.1) runtime.start() - running_execution = JobExecution.get_by_id(execution.get_id()) + running_execution = _db_reader( + lambda: JobExecution.get_by_id(execution.get_id()) + ) assert running_execution.running_status == JobExecutionStatus.RUNNING assert running_execution.ended_at is None @@ -764,13 +813,14 @@ def test_job_runtime_start_restores_live_worker_marked_failed_by_restart_bug( cron_month="*", feed_url=feed_url, ) - job = Job.get(Job.source == source) - execution = JobExecution.create( - job=job, - started_at=datetime.now(UTC), - ended_at=datetime.now(UTC), - running_status=JobExecutionStatus.FAILED, - ) + with database.writer(): + job = Job.get(Job.source == source) + execution = JobExecution.create( + job=job, + started_at=datetime.now(UTC), + ended_at=datetime.now(UTC), + running_status=JobExecutionStatus.FAILED, + ) artifacts = JobArtifacts.for_execution( log_dir=log_dir, job_id=job.id, @@ -805,7 +855,9 @@ def test_job_runtime_start_restores_live_worker_marked_failed_by_restart_bug( time.sleep(0.1) runtime.start() - restored_execution = JobExecution.get_by_id(execution.get_id()) + restored_execution = _db_reader( + lambda: JobExecution.get_by_id(execution.get_id()) + ) assert restored_execution.running_status == JobExecutionStatus.RUNNING assert restored_execution.ended_at is None @@ -895,14 +947,15 @@ def test_load_runs_view_humanizes_completed_execution_end_time( cron_month="*", feed_url="https://example.com/completed.xml", ) - job = Job.get(Job.source == source) - reference_time = datetime(2026, 1, 15, 12, 0, tzinfo=UTC) - ended_at = reference_time - timedelta(hours=2) - JobExecution.create( - job=job, - 
running_status=JobExecutionStatus.SUCCEEDED, - ended_at=ended_at, - ) + with database.writer(): + job = Job.get(Job.source == source) + reference_time = datetime(2026, 1, 15, 12, 0, tzinfo=UTC) + ended_at = reference_time - timedelta(hours=2) + JobExecution.create( + job=job, + running_status=JobExecutionStatus.SUCCEEDED, + ended_at=ended_at, + ) view = load_runs_view(log_dir=app.config["REPUB_LOG_DIR"], now=reference_time) completed = view["completed"][0] @@ -934,14 +987,15 @@ def test_load_runs_view_humanizes_running_execution_start_time( cron_month="*", feed_url="https://example.com/running.xml", ) - job = Job.get(Job.source == source) - reference_time = datetime(2026, 1, 15, 12, 0, tzinfo=UTC) - started_at = reference_time - timedelta(hours=2) - JobExecution.create( - job=job, - running_status=JobExecutionStatus.RUNNING, - started_at=started_at, - ) + with database.writer(): + job = Job.get(Job.source == source) + reference_time = datetime(2026, 1, 15, 12, 0, tzinfo=UTC) + started_at = reference_time - timedelta(hours=2) + JobExecution.create( + job=job, + running_status=JobExecutionStatus.RUNNING, + started_at=started_at, + ) view = load_runs_view(log_dir=app.config["REPUB_LOG_DIR"], now=reference_time) running = view["running"][0] @@ -974,7 +1028,8 @@ def test_render_runs_uses_database_backed_jobs_and_executions( cron_month="*", feed_url=FIXTURE_FEED_PATH.as_uri(), ) - job = Job.get(Job.source == source) + with database.reader(): + job = Job.get(Job.source == source) runtime = get_job_runtime(app) runtime.start() try: @@ -1021,11 +1076,12 @@ def test_render_execution_logs_handles_missing_execution_and_missing_log_file( cron_month="*", feed_url="https://example.com/log-source.xml", ) - job = Job.get(Job.source == source) - execution = JobExecution.create( - job=job, - running_status=JobExecutionStatus.FAILED, - ) + with database.writer(): + job = Job.get(Job.source == source) + execution = JobExecution.create( + job=job, + running_status=JobExecutionStatus.FAILED, + ) async def run() -> None: missing_execution = str( @@ -1067,18 +1123,25 @@ def test_delete_job_action_removes_source_job_and_execution_history( cron_month="*", feed_url="https://example.com/delete.xml", ) - job = Job.get(Job.source == source) - execution = JobExecution.create( - job=job, - running_status=JobExecutionStatus.SUCCEEDED, - ) + with database.writer(): + job = Job.get(Job.source == source) + execution = JobExecution.create( + job=job, + running_status=JobExecutionStatus.SUCCEEDED, + ) response = await client.post(f"/actions/jobs/{job.id}/delete") assert response.status_code == 204 - assert Source.get_or_none(Source.slug == "delete-source") is None - assert Job.get_or_none(id=job.id) is None - assert JobExecution.get_or_none(id=int(execution.get_id())) is None + assert ( + _db_reader(lambda: Source.get_or_none(Source.slug == "delete-source")) + is None + ) + assert _db_reader(lambda: Job.get_or_none(id=job.id)) is None + assert ( + _db_reader(lambda: JobExecution.get_or_none(id=int(execution.get_id()))) + is None + ) asyncio.run(run()) @@ -1107,18 +1170,25 @@ def test_delete_source_action_removes_source_job_and_execution_history( cron_month="*", feed_url="https://example.com/delete-source-row.xml", ) - job = Job.get(Job.source == source) - execution = JobExecution.create( - job=job, - running_status=JobExecutionStatus.SUCCEEDED, - ) + with database.writer(): + job = Job.get(Job.source == source) + execution = JobExecution.create( + job=job, + running_status=JobExecutionStatus.SUCCEEDED, + ) response = await 
client.post("/actions/sources/delete-source-row/delete") assert response.status_code == 204 - assert Source.get_or_none(Source.slug == "delete-source-row") is None - assert Job.get_or_none(id=job.id) is None - assert JobExecution.get_or_none(id=int(execution.get_id())) is None + assert ( + _db_reader(lambda: Source.get_or_none(Source.slug == "delete-source-row")) + is None + ) + assert _db_reader(lambda: Job.get_or_none(id=job.id)) is None + assert ( + _db_reader(lambda: JobExecution.get_or_none(id=int(execution.get_id()))) + is None + ) asyncio.run(run()) @@ -1128,7 +1198,7 @@ def _wait_for_running_execution( ) -> JobExecution: deadline = time.monotonic() + timeout_seconds while time.monotonic() < deadline: - execution = JobExecution.get_by_id(execution_id) + execution = _db_reader(lambda: JobExecution.get_by_id(execution_id)) if execution.running_status == JobExecutionStatus.RUNNING: return execution time.sleep(0.02) @@ -1143,7 +1213,7 @@ def _wait_for_execution_status( ) -> JobExecution: deadline = time.monotonic() + timeout_seconds while time.monotonic() < deadline: - execution = JobExecution.get_by_id(execution_id) + execution = _db_reader(lambda: JobExecution.get_by_id(execution_id)) if execution.running_status == status: return execution time.sleep(0.02) @@ -1155,7 +1225,7 @@ def _wait_for_terminal_execution( ) -> JobExecution: deadline = time.monotonic() + timeout_seconds while time.monotonic() < deadline: - execution = JobExecution.get_by_id(execution_id) + execution = _db_reader(lambda: JobExecution.get_by_id(execution_id)) if execution.running_status in { JobExecutionStatus.SUCCEEDED, JobExecutionStatus.FAILED, diff --git a/tests/test_web.py b/tests/test_web.py index 0b9098a..e23d2cb 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -21,6 +21,7 @@ from repub.model import ( SourceFeed, SourcePangea, create_source, + database, load_max_concurrent_jobs, load_settings_form, save_setting, @@ -43,6 +44,35 @@ from repub.web import ( ) +def _db_reader(fn): + with database.reader(): + return fn() + + +def _db_writer(fn): + with database.writer(): + return fn() + + +def test_web_routes_do_not_access_peewee_models_directly() -> None: + web_source = Path("repub/web.py").read_text(encoding="utf-8") + + assert ( + re.search( + r"\b(Job|Source|JobExecution|SourceFeed|SourcePangea)\.get", + web_source, + ) + is None + ) + assert ( + re.search( + r"\b(Job|Source|JobExecution|SourceFeed|SourcePangea)\.select", + web_source, + ) + is None + ) + + def test_status_badge_uses_green_done_tone() -> None: badge = str(status_badge(label="Succeeded", tone="done")) @@ -790,8 +820,12 @@ def test_load_dashboard_view_lists_source_feed_artifacts( updated_at = reference_time - timedelta(minutes=32) updated_at_epoch = updated_at.timestamp() os.utime(feed_path, (updated_at_epoch, updated_at_epoch)) - available_job = Job.get(Job.source == available_source) - missing_job = Job.get(Job.source == missing_source) + available_job, missing_job = _db_reader( + lambda: ( + Job.get(Job.source == available_source), + Job.get(Job.source == missing_source), + ) + ) source_feeds = cast( tuple[dict[str, object], ...], @@ -871,16 +905,18 @@ def test_load_dashboard_view_projects_feed_status_from_job_runtime( feed_url="https://example.com/queued.xml", ) - running_job = Job.get(Job.source == running_source) - queued_job = Job.get(Job.source == queued_source) - JobExecution.create( - job=running_job, - running_status=JobExecutionStatus.RUNNING, - started_at=reference_time - timedelta(minutes=2), - ) - JobExecution.create( - 
job=queued_job, - running_status=JobExecutionStatus.PENDING, + _db_writer( + lambda: ( + JobExecution.create( + job=Job.get(Job.source == running_source), + running_status=JobExecutionStatus.RUNNING, + started_at=reference_time - timedelta(minutes=2), + ), + JobExecution.create( + job=Job.get(Job.source == queued_source), + running_status=JobExecutionStatus.PENDING, + ), + ) ) source_feeds = cast( @@ -938,8 +974,12 @@ def test_render_dashboard_shows_source_feed_links_and_statuses( published_feed = tmp_path / "out" / "feeds" / "published-source" / "feed.rss" published_feed.parent.mkdir(parents=True) published_feed.write_text("\n", encoding="utf-8") - published_job = Job.get(Job.source == published_source) - missing_job = Job.get(Job.source == missing_source) + published_job, missing_job = _db_reader( + lambda: ( + Job.get(Job.source == published_source), + Job.get(Job.source == missing_source), + ) + ) body = str(await render_dashboard(app)) @@ -1253,9 +1293,15 @@ def test_create_source_action_creates_pangea_source_and_job_in_database( assert response.status_code == 200 assert "window.location = '/sources'" in body - source = Source.get(Source.slug == "kenya-health") - pangea = SourcePangea.get(SourcePangea.source == source) - job = Job.get(Job.source == source) + source, pangea, job = _db_reader( + lambda: ( + Source.get(Source.slug == "kenya-health"), + SourcePangea.get( + SourcePangea.source == Source.get(Source.slug == "kenya-health") + ), + Job.get(Job.source == Source.get(Source.slug == "kenya-health")), + ) + ) rendered_sources = str(await render_sources(app)) assert source.name == "Kenya health desk" @@ -1307,9 +1353,15 @@ def test_create_source_action_creates_feed_source_and_job_in_database( assert response.status_code == 200 assert "window.location = '/sources'" in body - source = Source.get(Source.slug == "nasa-feed") - feed = SourceFeed.get(SourceFeed.source == source) - job = Job.get(Job.source == source) + source, feed, job = _db_reader( + lambda: ( + Source.get(Source.slug == "nasa-feed"), + SourceFeed.get( + SourceFeed.source == Source.get(Source.slug == "nasa-feed") + ), + Job.get(Job.source == Source.get(Source.slug == "nasa-feed")), + ) + ) rendered_sources = str(await render_sources(app)) assert source.source_type == "feed" @@ -1390,9 +1442,15 @@ def test_edit_source_action_updates_existing_source_and_job_in_database( assert response.status_code == 200 assert "window.location = '/sources'" in body - source = Source.get(Source.slug == "kenya-health") - pangea = SourcePangea.get(SourcePangea.source == source) - job = Job.get(Job.source == source) + source, pangea, job = _db_reader( + lambda: ( + Source.get(Source.slug == "kenya-health"), + SourcePangea.get( + SourcePangea.source == Source.get(Source.slug == "kenya-health") + ), + Job.get(Job.source == Source.get(Source.slug == "kenya-health")), + ) + ) rendered_sources = str(await render_sources(app)) assert source.name == "Kenya health desk nightly" @@ -1477,8 +1535,18 @@ def test_edit_source_action_rejects_slug_changes(monkeypatch, tmp_path: Path) -> assert response.status_code == 200 assert "Slug is immutable." 
in body - assert Source.get(Source.slug == "kenya-health").name == "Kenya health desk" - assert Source.select().where(Source.slug == "kenya-health-renamed").count() == 0 + assert ( + _db_reader(lambda: Source.get(Source.slug == "kenya-health").name) + == "Kenya health desk" + ) + assert ( + _db_reader( + lambda: Source.select() + .where(Source.slug == "kenya-health-renamed") + .count() + ) + == 0 + ) asyncio.run(run()) @@ -1491,10 +1559,12 @@ def test_create_source_action_validates_duplicate_slug_and_pangea_type( async def run() -> None: app = create_app() - Source.create( - name="Guardian feed mirror", - slug="guardian-feed", - source_type="feed", + _db_writer( + lambda: Source.create( + name="Guardian feed mirror", + slug="guardian-feed", + source_type="feed", + ) ) client = app.test_client() @@ -1526,7 +1596,14 @@ def test_create_source_action_validates_duplicate_slug_and_pangea_type( assert "Content format is invalid." in body assert "Content type is invalid." in body assert "Max articles must be an integer." in body - assert Source.select().where(Source.name == "Duplicate guardian").count() == 0 + assert ( + _db_reader( + lambda: Source.select() + .where(Source.name == "Duplicate guardian") + .count() + ) + == 0 + ) asyncio.run(run()) @@ -1629,10 +1706,14 @@ def test_render_runs_shows_running_scheduled_and_completed_tables( cron_month="*", feed_url="https://example.com/runs.xml", ) - job = Job.get(Job.source == source) - execution = JobExecution.create( - job=job, - running_status=JobExecutionStatus.SUCCEEDED, + job, execution = _db_writer( + lambda: ( + Job.get(Job.source == source), + JobExecution.create( + job=Job.get(Job.source == source), + running_status=JobExecutionStatus.SUCCEEDED, + ), + ) ) body = str(await render_runs(app)) @@ -1704,14 +1785,16 @@ def test_runs_pagination_action_updates_only_the_current_tab( cron_month="*", feed_url="https://example.com/paged-runs.xml", ) - job = Job.get(Job.source == source) - - for minute in range(21): - JobExecution.create( - job=job, - ended_at=datetime(2026, 3, 30, 12, minute, tzinfo=UTC), - running_status=JobExecutionStatus.SUCCEEDED, + _db_writer( + lambda: tuple( + JobExecution.create( + job=Job.get(Job.source == source), + ended_at=datetime(2026, 3, 30, 12, minute, tzinfo=UTC), + running_status=JobExecutionStatus.SUCCEEDED, + ) + for minute in range(21) ) + ) async with client.request( "/runs?u=shim", @@ -1853,10 +1936,14 @@ def test_render_runs_keeps_queued_execution_in_scheduled_jobs_table( cron_month="*", feed_url="https://example.com/scheduled.xml", ) - queued_job = Job.get(Job.source == queued_source) - queued_execution = JobExecution.create( - job=queued_job, - running_status=JobExecutionStatus.PENDING, + queued_job, queued_execution = _db_writer( + lambda: ( + Job.get(Job.source == queued_source), + JobExecution.create( + job=Job.get(Job.source == queued_source), + running_status=JobExecutionStatus.PENDING, + ), + ) ) async def run() -> None: @@ -1899,15 +1986,19 @@ def test_render_runs_shows_cancel_button_for_running_row_with_queued_follow_up( cron_month="*", feed_url="https://example.com/busy.xml", ) - job = Job.get(Job.source == source) - running_execution = JobExecution.create( - job=job, - started_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), - running_status=JobExecutionStatus.RUNNING, - ) - pending_execution = JobExecution.create( - job=job, - running_status=JobExecutionStatus.PENDING, + job, running_execution, pending_execution = _db_writer( + lambda: ( + Job.get(Job.source == source), + JobExecution.create( + 
job=Job.get(Job.source == source), + started_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), + running_status=JobExecutionStatus.RUNNING, + ), + JobExecution.create( + job=Job.get(Job.source == source), + running_status=JobExecutionStatus.PENDING, + ), + ) ) async def run() -> None: @@ -2036,15 +2127,19 @@ def test_cancel_queued_execution_action_deletes_pending_row_without_touching_run cron_month="*", feed_url="https://example.com/busy.xml", ) - job = Job.get(Job.source == source) - running_execution = JobExecution.create( - job=job, - started_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), - running_status=JobExecutionStatus.RUNNING, - ) - pending_execution = JobExecution.create( - job=job, - running_status=JobExecutionStatus.PENDING, + job, running_execution, pending_execution = _db_writer( + lambda: ( + Job.get(Job.source == source), + JobExecution.create( + job=Job.get(Job.source == source), + started_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), + running_status=JobExecutionStatus.RUNNING, + ), + JobExecution.create( + job=Job.get(Job.source == source), + running_status=JobExecutionStatus.PENDING, + ), + ) ) response = await client.post( @@ -2052,9 +2147,18 @@ def test_cancel_queued_execution_action_deletes_pending_row_without_touching_run ) assert response.status_code == 204 - assert JobExecution.get_or_none(id=int(pending_execution.get_id())) is None assert ( - JobExecution.get_by_id(int(running_execution.get_id())).running_status + _db_reader( + lambda: JobExecution.get_or_none(id=int(pending_execution.get_id())) + ) + is None + ) + assert ( + _db_reader( + lambda: JobExecution.get_by_id( + int(running_execution.get_id()) + ).running_status + ) == JobExecutionStatus.RUNNING ) @@ -2087,16 +2191,20 @@ def test_clear_completed_executions_action_removes_history_and_log_artifacts( cron_month="*", feed_url="https://example.com/history.xml", ) - job = Job.get(Job.source == source) - completed_execution = JobExecution.create( - job=job, - running_status=JobExecutionStatus.SUCCEEDED, - ended_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), - ) - running_execution = JobExecution.create( - job=job, - running_status=JobExecutionStatus.RUNNING, - started_at=datetime(2026, 3, 30, 12, 5, tzinfo=UTC), + job, completed_execution, running_execution = _db_writer( + lambda: ( + Job.get(Job.source == source), + JobExecution.create( + job=Job.get(Job.source == source), + running_status=JobExecutionStatus.SUCCEEDED, + ended_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), + ), + JobExecution.create( + job=Job.get(Job.source == source), + running_status=JobExecutionStatus.RUNNING, + started_at=datetime(2026, 3, 30, 12, 5, tzinfo=UTC), + ), + ) ) log_dir.mkdir(parents=True, exist_ok=True) completed_prefix = ( @@ -2112,8 +2220,18 @@ def test_clear_completed_executions_action_removes_history_and_log_artifacts( response = await client.post("/actions/completed-executions/clear") assert response.status_code == 204 - assert JobExecution.get_or_none(id=int(completed_execution.get_id())) is None - assert JobExecution.get_or_none(id=int(running_execution.get_id())) is not None + assert ( + _db_reader( + lambda: JobExecution.get_or_none(id=int(completed_execution.get_id())) + ) + is None + ) + assert ( + _db_reader( + lambda: JobExecution.get_or_none(id=int(running_execution.get_id())) + ) + is not None + ) for suffix in (".log", ".jsonl", ".pygea.log"): assert not completed_prefix.with_suffix(suffix).exists() assert running_log_path.exists() @@ -2161,17 +2279,21 @@ def test_move_queued_execution_action_reorders_queue( 
cron_month="*", feed_url="https://example.com/second.xml", ) - first_job = Job.get(Job.source == first_source) - second_job = Job.get(Job.source == second_source) - first_execution = JobExecution.create( - job=first_job, - created_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), - running_status=JobExecutionStatus.PENDING, - ) - second_execution = JobExecution.create( - job=second_job, - created_at=datetime(2026, 3, 30, 12, 5, tzinfo=UTC), - running_status=JobExecutionStatus.PENDING, + first_job, second_job, first_execution, second_execution = _db_writer( + lambda: ( + Job.get(Job.source == first_source), + Job.get(Job.source == second_source), + JobExecution.create( + job=Job.get(Job.source == first_source), + created_at=datetime(2026, 3, 30, 12, 0, tzinfo=UTC), + running_status=JobExecutionStatus.PENDING, + ), + JobExecution.create( + job=Job.get(Job.source == second_source), + created_at=datetime(2026, 3, 30, 12, 5, tzinfo=UTC), + running_status=JobExecutionStatus.PENDING, + ), + ) ) response = await client.post( @@ -2217,17 +2339,26 @@ def test_toggle_job_enabled_action_removes_queued_execution( cron_month="*", feed_url="https://example.com/queued.xml", ) - job = Job.get(Job.source == source) - queued_execution = JobExecution.create( - job=job, - running_status=JobExecutionStatus.PENDING, + job, queued_execution = _db_writer( + lambda: ( + Job.get(Job.source == source), + JobExecution.create( + job=Job.get(Job.source == source), + running_status=JobExecutionStatus.PENDING, + ), + ) ) response = await client.post(f"/actions/jobs/{job.id}/toggle-enabled") assert response.status_code == 204 - assert Job.get_by_id(job.id).enabled is False - assert JobExecution.get_or_none(id=int(queued_execution.get_id())) is None + assert _db_reader(lambda: Job.get_by_id(job.id).enabled) is False + assert ( + _db_reader( + lambda: JobExecution.get_or_none(id=int(queued_execution.get_id())) + ) + is None + ) body = str(await render_runs(app)) assert ( f"/actions/queued-executions/{int(queued_execution.get_id())}/cancel" @@ -2279,10 +2410,14 @@ def test_render_execution_logs_uses_app_route(monkeypatch, tmp_path: Path) -> No cron_month="*", feed_url="https://example.com/logs.xml", ) - job = Job.get(Job.source == source) - execution = JobExecution.create( - job=job, - running_status=JobExecutionStatus.RUNNING, + job, execution = _db_writer( + lambda: ( + Job.get(Job.source == source), + JobExecution.create( + job=Job.get(Job.source == source), + running_status=JobExecutionStatus.RUNNING, + ), + ) ) log_path = log_dir / f"job-{job.id}-execution-{execution.get_id()}.log" log_path.parent.mkdir(parents=True, exist_ok=True)