Refactor database access through managed connections

Abel Luck 2026-03-31 17:30:07 +02:00
parent f19bab6fa2
commit 3f28e46ff6
10 changed files with 1327 additions and 716 deletions

repub/db.py (new file, +359 lines)

@@ -0,0 +1,359 @@
from __future__ import annotations
import os
import queue
import re
import threading
from contextlib import contextmanager
from contextvars import ContextVar
from importlib import resources
from importlib.resources.abc import Traversable
from pathlib import Path
from typing import Iterator
from peewee import BooleanField, Check, SqliteDatabase
from playhouse.migrate import SchemaMigrator, migrate
DEFAULT_DB_PATH = Path("republisher.db")
DATABASE_PRAGMAS = {
"busy_timeout": 5000,
"cache_size": 15625,
"foreign_keys": 1,
"journal_mode": "wal",
"page_size": 4096,
"synchronous": "normal",
"temp_store": "memory",
}
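# cache_size is a page count: 15625 pages x 4096-byte pages = ~64 MB of cache per connection.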
SCHEMA_GLOB = "*.sql"
_WRITE_SQL_PREFIX = re.compile(
r"^\s*(INSERT|UPDATE|DELETE|REPLACE|CREATE|DROP|ALTER)\b"
)
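# Prefix heuristic only: a write hidden behind a CTE ("WITH ... INSERT") would
# slip past this check, so it backstops rather than replaces reader discipline.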
_current_database: ContextVar[ManagedSqliteDatabase | None] = ContextVar(
"repub_current_database",
default=None,
)
_current_scope: ContextVar[str | None] = ContextVar(
"repub_current_database_scope",
default=None,
)
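# _current_database carries the connection leased by the innermost reader()/
# writer() context; _current_scope records which kind of scope is active so
# nested entries and the reader-side write guard can be enforced.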
_database_connection: DatabaseConnection | None = None
def resolve_database_path(db_path: str | Path | None = None) -> Path:
raw_value = (
os.environ.get("REPUBLISHER_DB_PATH", DEFAULT_DB_PATH)
if db_path is None
else db_path
)
return Path(raw_value).expanduser().resolve()
def schema_paths() -> tuple[Traversable, ...]:
schema_dir = resources.files("repub").joinpath("sql")
return tuple(
sorted(
(path for path in schema_dir.iterdir() if path.name.endswith(".sql")),
key=lambda path: path.name,
)
)
class ManagedSqliteDatabase(SqliteDatabase):
def __init__(self, database: str, **kwargs) -> None:
pragmas = kwargs.pop("pragmas", DATABASE_PRAGMAS)
kwargs.setdefault("check_same_thread", False)
super().__init__(
database,
autoconnect=False,
thread_safe=False,
pragmas=pragmas,
**kwargs,
)
class RoutedSqliteDatabase(SqliteDatabase):
def __init__(self) -> None:
super().__init__(
":managed:",
autoconnect=False,
pragmas=DATABASE_PRAGMAS,
)
def _require_active_database(self) -> ManagedSqliteDatabase:
active_database = _current_database.get()
if active_database is None:
raise RuntimeError(
"Database access requires a database.reader() or database.writer() context."
)
return active_database
def _validate_sql(self, sql: str) -> None:
scope = _current_scope.get()
if scope in {"reader", "reader_conn"} and _WRITE_SQL_PREFIX.match(sql):
raise RuntimeError("Write query attempted inside database.reader() scope.")
def connect(self, reuse_if_open: bool = False): # type: ignore[override]
raise RuntimeError(
"Do not call database.connect() directly; use database.reader() or database.writer()."
)
def close(self): # type: ignore[override]
raise RuntimeError(
"Do not call database.close() directly; use database.reader() or database.writer()."
)
def is_closed(self) -> bool: # type: ignore[override]
active_database = _current_database.get()
return True if active_database is None else active_database.is_closed()
def connection(self): # type: ignore[override]
return self._require_active_database().connection()
def cursor(self, named_cursor=None): # type: ignore[override]
return self._require_active_database().cursor(named_cursor=named_cursor)
def execute_sql(self, sql, params=None): # type: ignore[override]
self._validate_sql(str(sql))
return self._require_active_database().execute_sql(sql, params)
    def execute(self, query, **context_options):  # type: ignore[override]
        # Compile here so the reader-scope write guard in execute_sql() also covers ORM queries.
        sql, params = self._require_active_database().get_sql_context(**context_options).sql(query).query()
        return self.execute_sql(sql, params)
def atomic(self, *args, **kwargs): # type: ignore[override]
return self._require_active_database().atomic(*args, **kwargs)
def transaction(self, *args, **kwargs): # type: ignore[override]
return self._require_active_database().transaction(*args, **kwargs)
def savepoint(self, *args, **kwargs): # type: ignore[override]
return self._require_active_database().savepoint(*args, **kwargs)
def connection_context(self): # type: ignore[override]
raise RuntimeError(
"Do not call database.connection_context() directly; use database.reader() or database.writer()."
)
def reader(self):
return _require_database_connection().reader()
def writer(self):
return _require_database_connection().writer()
def reader_conn(self):
return _require_database_connection().reader_conn()
def writer_conn(self):
return _require_database_connection().writer_conn()
class DatabaseConnection:
def __init__(
self,
db_path: str | Path,
*,
pool_size: int = 4,
pragmas: dict[str, object] | None = None,
) -> None:
self.db_path = resolve_database_path(db_path)
self.pool_size = pool_size
self.pragmas = dict(DATABASE_PRAGMAS if pragmas is None else pragmas)
self.writer_db = ManagedSqliteDatabase(str(self.db_path), pragmas=self.pragmas)
self.reader_dbs = tuple(
ManagedSqliteDatabase(str(self.db_path), pragmas=self.pragmas)
for _ in range(pool_size)
)
self._reader_pool: queue.Queue[ManagedSqliteDatabase] = queue.Queue()
self._writer_lock = threading.RLock()
for reader_db in self.reader_dbs:
self._reader_pool.put(reader_db)
def initialize(self) -> Path:
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self.writer_db.connect(reuse_if_open=True)
try:
for path in schema_paths():
self.writer_db.connection().executescript(
path.read_text(encoding="utf-8")
)
_run_legacy_migrations(self.writer_db)
except Exception:
self.writer_db.close()
raise
for reader_db in self.reader_dbs:
reader_db.connect(reuse_if_open=True)
return self.db_path
def close(self) -> None:
for reader_db in self.reader_dbs:
if not reader_db.is_closed():
reader_db.close()
if not self.writer_db.is_closed():
self.writer_db.close()
@contextmanager
def reader(self) -> Iterator[ManagedSqliteDatabase]:
        scope = _current_scope.get()
        if scope in {"reader", "writer"}:
            yield _require_active_database()
            return
        if scope in {"reader_conn", "writer_conn"}:
            # Bare-connection scopes carry no transaction; open one for this block.
            active_database = _require_active_database()
            with active_database.atomic():
                yield active_database
            return
leased_database = self._reader_pool.get()
database_token = _current_database.set(leased_database)
scope_token = _current_scope.set("reader")
try:
with leased_database.atomic():
yield leased_database
finally:
_current_scope.reset(scope_token)
_current_database.reset(database_token)
self._reader_pool.put(leased_database)
@contextmanager
def writer(self) -> Iterator[ManagedSqliteDatabase]:
scope = _current_scope.get()
if scope == "writer":
yield _require_active_database()
return
if scope == "writer_conn":
active_database = _require_active_database()
with active_database.atomic():
yield active_database
return
if scope in {"reader", "reader_conn"}:
raise RuntimeError(
"Cannot enter database.writer() inside database.reader()."
)
with self._writer_lock:
database_token = _current_database.set(self.writer_db)
scope_token = _current_scope.set("writer")
try:
with self.writer_db.atomic():
yield self.writer_db
finally:
_current_scope.reset(scope_token)
_current_database.reset(database_token)
@contextmanager
def reader_conn(self) -> Iterator[ManagedSqliteDatabase]:
scope = _current_scope.get()
if scope is not None:
yield _require_active_database()
return
leased_database = self._reader_pool.get()
database_token = _current_database.set(leased_database)
scope_token = _current_scope.set("reader_conn")
try:
yield leased_database
finally:
_current_scope.reset(scope_token)
_current_database.reset(database_token)
self._reader_pool.put(leased_database)
@contextmanager
def writer_conn(self) -> Iterator[ManagedSqliteDatabase]:
scope = _current_scope.get()
if scope in {"writer", "writer_conn"}:
yield _require_active_database()
return
if scope in {"reader", "reader_conn"}:
raise RuntimeError(
"Cannot enter database.writer_conn() inside database.reader()."
)
with self._writer_lock:
database_token = _current_database.set(self.writer_db)
scope_token = _current_scope.set("writer_conn")
try:
yield self.writer_db
finally:
_current_scope.reset(scope_token)
_current_database.reset(database_token)
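
# Scope-nesting rules enforced by the context managers above:
#   reader()/writer() inside an active scope reuse the active connection;
#   writer() inside a reader scope raises (readers cannot upgrade to writes);
#   *_conn() variants lease a connection without opening a transaction, and
#   reader()/writer() entered inside them wrap their block in atomic().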
def _run_legacy_migrations(database: ManagedSqliteDatabase) -> None:
job_columns = {column.name for column in database.get_columns("job")}
operations = []
migrator = SchemaMigrator.from_database(database)
if "convert_images" not in job_columns:
operations.extend(
(
migrator.add_column(
"job",
"convert_images",
BooleanField(
default=True,
constraints=[Check("convert_images IN (0, 1)")],
),
),
migrator.add_column_default("job", "convert_images", 1),
)
)
if "convert_video" not in job_columns:
operations.extend(
(
migrator.add_column(
"job",
"convert_video",
BooleanField(
default=True,
constraints=[Check("convert_video IN (0, 1)")],
),
),
migrator.add_column_default("job", "convert_video", 1),
)
)
if operations:
with database.atomic():
migrate(*operations)
def initialize_database(db_path: str | Path | None = None) -> Path:
global _database_connection
if _database_connection is not None:
_database_connection.close()
connection = DatabaseConnection(resolve_database_path(db_path))
resolved_path = connection.initialize()
_database_connection = connection
return resolved_path
def get_database_connection() -> DatabaseConnection | None:
return _database_connection
def _require_database_connection() -> DatabaseConnection:
database_connection = get_database_connection()
if database_connection is None:
raise RuntimeError("Database has not been initialized.")
return database_connection
def _require_active_database() -> ManagedSqliteDatabase:
active_database = _current_database.get()
if active_database is None:
raise RuntimeError(
"Database access requires a database.reader() or database.writer() context."
)
return active_database
database = RoutedSqliteDatabase()
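
For orientation, a minimal usage sketch of the new module (the republisher.db filename and the job id are hypothetical; Job is defined in repub/model.py below):

from repub.db import database, initialize_database
from repub.model import Job

initialize_database("republisher.db")  # builds the writer and the reader pool

with database.reader():  # leases a pooled connection; write SQL here raises RuntimeError
    enabled_jobs = tuple(Job.select().where(Job.enabled == True))  # noqa: E712

with database.writer():  # serialized behind the writer lock, wrapped in atomic()
    job = Job.get_or_none(id=1)  # hypothetical id
    if job is not None:
        job.enabled = False
        job.save()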

@@ -309,7 +309,7 @@ def main(argv: list[str] | None = None) -> int:
def _load_job_source_config(*, db_path: str, job_id: int) -> JobSourceConfig:
initialize_database(db_path)
primary_key = getattr(Job, "_meta").primary_key
with database.connection_context():
with database.reader():
job = (
Job.select(Job, Source)
.join(Source)

@@ -18,6 +18,7 @@ from apscheduler.triggers.cron import CronTrigger
from peewee import IntegrityError
from repub.config import feed_output_dir, feed_output_path
from repub.db import get_database_connection
from repub.model import (
Job,
JobExecution,
@@ -173,7 +174,7 @@ class JobRuntime:
self._started = False
def sync_jobs(self) -> None:
with database.connection_context():
with database.reader():
jobs = tuple(Job.select().where(Job.enabled == True)) # noqa: E712
desired_ids = set()
@@ -216,35 +217,34 @@ class JobRuntime:
return execution_id
def _enqueue_job_run_locked(self, job_id: int) -> int | None:
with database.connection_context():
with database.atomic():
job = Job.get_or_none(id=job_id)
if job is None:
return None
with database.writer():
job = Job.get_or_none(id=job_id)
if job is None:
return None
pending_execution = JobExecution.get_or_none(
(JobExecution.job == job)
& (JobExecution.running_status == JobExecutionStatus.PENDING)
)
if pending_execution is not None:
return _execution_id(pending_execution)
try:
execution = JobExecution.create(
job=job,
running_status=JobExecutionStatus.PENDING,
)
except IntegrityError:
pending_execution = JobExecution.get_or_none(
(JobExecution.job == job)
& (JobExecution.running_status == JobExecutionStatus.PENDING)
)
if pending_execution is not None:
return _execution_id(pending_execution)
try:
execution = JobExecution.create(
job=job,
running_status=JobExecutionStatus.PENDING,
)
except IntegrityError:
pending_execution = JobExecution.get_or_none(
(JobExecution.job == job)
& (JobExecution.running_status == JobExecutionStatus.PENDING)
)
return (
_execution_id(pending_execution)
if pending_execution is not None
else None
)
return _execution_id(execution)
return (
_execution_id(pending_execution)
if pending_execution is not None
else None
)
return _execution_id(execution)
def _start_queued_jobs(self) -> None:
with self._run_lock:
@@ -259,14 +259,13 @@ class JobRuntime:
if claimed_execution is None:
return
job = cast(Job, claimed_execution.job)
self._start_worker_for_execution(
job_id=_job_id(job),
execution_id=_execution_id(claimed_execution),
job_id=claimed_execution[0],
execution_id=claimed_execution[1],
)
def _claim_next_pending_execution(self) -> JobExecution | None:
with database.connection_context():
def _claim_next_pending_execution(self) -> tuple[int, int] | None:
with database.writer():
execution_primary_key = getattr(JobExecution, "_meta").primary_key
pending_executions = tuple(
JobExecution.select(JobExecution, Job)
@@ -301,10 +300,21 @@ class JobRuntime:
.execute()
)
if claimed:
return JobExecution.get_by_id(_execution_id(execution))
claimed_execution = (
JobExecution.select(JobExecution, Job)
.join(Job)
.where(execution_primary_key == _execution_id(execution))
.get()
)
claimed_job = cast(Job, claimed_execution.job)
return (_job_id(claimed_job), _execution_id(claimed_execution))
return None
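    # The claim returns plain ids so the caller can start the worker without
    # holding a database scope once the writer lease has been released.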
def _start_worker_for_execution(self, *, job_id: int, execution_id: int) -> None:
database_connection = get_database_connection()
if database_connection is None:
raise RuntimeError("Database has not been initialized.")
artifacts = JobArtifacts.for_execution(
log_dir=self.log_dir, job_id=job_id, execution_id=execution_id
)
@@ -324,7 +334,7 @@ class JobRuntime:
"--execution-id",
str(execution_id),
"--db-path",
str(database.database),
str(database_connection.db_path),
"--out-dir",
str(self.log_dir.parent),
"--stats-path",
@@ -342,15 +352,16 @@ class JobRuntime:
)
def _max_concurrent_jobs_reached(self) -> bool:
return (
JobExecution.select()
.where(JobExecution.running_status == JobExecutionStatus.RUNNING)
.count()
>= load_max_concurrent_jobs()
)
with database.reader():
return (
JobExecution.select()
.where(JobExecution.running_status == JobExecutionStatus.RUNNING)
.count()
>= load_max_concurrent_jobs()
)
def request_execution_cancel(self, execution_id: int) -> bool:
with database.connection_context():
with database.writer():
execution = JobExecution.get_or_none(id=execution_id)
if execution is None:
return False
@@ -372,7 +383,7 @@ class JobRuntime:
def cancel_queued_execution(self, execution_id: int) -> bool:
with self._run_lock:
with database.connection_context():
with database.writer():
execution_primary_key = getattr(JobExecution, "_meta").primary_key
deleted = (
JobExecution.delete()
@@ -392,7 +403,7 @@ class JobRuntime:
def move_queued_execution(self, execution_id: int, *, direction: str) -> bool:
offset = -1 if direction == "up" else 1
with self._run_lock:
with database.connection_context():
with database.writer():
execution_primary_key = getattr(JobExecution, "_meta").primary_key
queued_executions = tuple(
JobExecution.select()
@@ -425,59 +436,50 @@ class JobRuntime:
cast(datetime | str, target_execution.created_at)
)
with database.atomic():
if current_created_at == target_created_at:
adjusted_created_at = target_created_at + timedelta(
microseconds=-1 if offset < 0 else 1
if current_created_at == target_created_at:
adjusted_created_at = target_created_at + timedelta(
microseconds=-1 if offset < 0 else 1
)
(
JobExecution.update(created_at=adjusted_created_at)
.where(
execution_primary_key == _execution_id(current_execution)
)
(
JobExecution.update(created_at=adjusted_created_at)
.where(
execution_primary_key
== _execution_id(current_execution)
)
.execute()
)
else:
(
JobExecution.update(created_at=target_created_at)
.where(
execution_primary_key
== _execution_id(current_execution)
)
.execute()
)
(
JobExecution.update(created_at=current_created_at)
.where(
execution_primary_key == _execution_id(target_execution)
)
.execute()
.execute()
)
else:
(
JobExecution.update(created_at=target_created_at)
.where(
execution_primary_key == _execution_id(current_execution)
)
.execute()
)
(
JobExecution.update(created_at=current_created_at)
.where(execution_primary_key == _execution_id(target_execution))
.execute()
)
self._trigger_refresh()
return True
def set_job_enabled(self, job_id: int, *, enabled: bool) -> bool:
with database.connection_context():
with database.atomic():
job = Job.get_or_none(id=job_id)
if job is None:
return False
job.enabled = enabled
job.save()
if not enabled:
(
JobExecution.delete()
.where(
(JobExecution.job == job)
& (
JobExecution.running_status
== JobExecutionStatus.PENDING
)
)
.execute()
with database.writer():
job = Job.get_or_none(id=job_id)
if job is None:
return False
job.enabled = enabled
job.save()
if not enabled:
(
JobExecution.delete()
.where(
(JobExecution.job == job)
& (JobExecution.running_status == JobExecutionStatus.PENDING)
)
.execute()
)
self.sync_jobs()
self._trigger_refresh()
return True
@@ -493,7 +495,7 @@ class JobRuntime:
continue
self._apply_stats(worker)
with database.connection_context():
with database.writer():
execution = JobExecution.get_by_id(execution_id)
execution.ended_at = utc_now()
execution.running_status = _worker_final_status(
@@ -527,7 +529,7 @@ class JobRuntime:
return
stats = json.loads(lines[-1])
with database.connection_context():
with database.writer():
execution = JobExecution.get_by_id(worker.execution_id)
execution.requests_count = int(stats.get("requests_count", 0))
execution.items_count = int(stats.get("items_count", 0))
@@ -544,7 +546,7 @@ class JobRuntime:
self._trigger_refresh()
def _enforce_graceful_stop(self, worker: RunningWorker) -> None:
with database.connection_context():
with database.reader():
execution = JobExecution.get_by_id(worker.execution_id)
if execution.stop_requested_at is None:
return
@@ -572,16 +574,17 @@ class JobRuntime:
self._trigger_refresh()
def _has_running_executions(self) -> bool:
return (
JobExecution.select()
.where(JobExecution.running_status == JobExecutionStatus.RUNNING)
.exists()
)
with database.reader():
return (
JobExecution.select()
.where(JobExecution.running_status == JobExecutionStatus.RUNNING)
.exists()
)
def _reconcile_stale_executions(self) -> None:
live_workers = _find_live_workers()
recovered_execution_ids: set[int] = set()
with database.connection_context():
with database.writer():
execution_primary_key = getattr(JobExecution, "_meta").primary_key
if live_workers:
live_executions = tuple(
@@ -669,7 +672,7 @@ def load_runs_view(
reference_time = now or datetime.now(UTC)
resolved_log_dir = Path(log_dir)
sanitized_page_size = max(1, completed_page_size)
with database.connection_context():
with database.reader():
execution_primary_key = getattr(JobExecution, "_meta").primary_key
jobs = tuple(Job.select(Job, Source).join(Source).order_by(Source.name.asc()))
queued_executions = tuple(
@@ -718,48 +721,52 @@ def load_runs_view(
_job_id(cast(Job, execution.job)): execution
for execution in queued_executions
}
return {
"running": tuple(
_project_running_execution(
execution,
resolved_log_dir,
reference_time,
queued_follow_up=queued_by_job.get(_job_id(cast(Job, execution.job))),
)
for execution in running_executions
),
"queued": tuple(
_project_queued_execution(
execution,
reference_time,
position=position,
total_count=len(queued_executions),
)
for position, execution in enumerate(queued_executions, start=1)
),
"upcoming": tuple(
_project_upcoming_job(
job,
running_by_job.get(job.id),
queued_by_job.get(job.id),
reference_time,
)
for job in jobs
),
"completed": tuple(
_project_completed_execution(execution, resolved_log_dir, reference_time)
for execution in completed_executions
),
"completed_page": sanitized_completed_page,
"completed_page_size": sanitized_page_size,
"completed_total_count": completed_total_count,
"completed_total_pages": completed_total_pages,
}
return {
"running": tuple(
_project_running_execution(
execution,
resolved_log_dir,
reference_time,
queued_follow_up=queued_by_job.get(
_job_id(cast(Job, execution.job))
),
)
for execution in running_executions
),
"queued": tuple(
_project_queued_execution(
execution,
reference_time,
position=position,
total_count=len(queued_executions),
)
for position, execution in enumerate(queued_executions, start=1)
),
"upcoming": tuple(
_project_upcoming_job(
job,
running_by_job.get(job.id),
queued_by_job.get(job.id),
reference_time,
)
for job in jobs
),
"completed": tuple(
_project_completed_execution(
execution, resolved_log_dir, reference_time
)
for execution in completed_executions
),
"completed_page": sanitized_completed_page,
"completed_page_size": sanitized_page_size,
"completed_total_count": completed_total_count,
"completed_total_pages": completed_total_pages,
}
def clear_completed_executions(*, log_dir: str | Path) -> int:
resolved_log_dir = Path(log_dir)
with database.connection_context():
with database.writer():
execution_primary_key = getattr(JobExecution, "_meta").primary_key
completed_executions = tuple(
JobExecution.select(JobExecution, Job)
@@ -810,7 +817,7 @@ def load_dashboard_view(
upcoming_by_job_id = {
int(cast(int, job["job_id"])): job for job in runs_view["upcoming"]
}
with database.connection_context():
with database.reader():
jobs = tuple(Job.select(Job, Source).join(Source).order_by(Source.name.asc()))
failed_last_day = (
JobExecution.select()
@@ -820,80 +827,85 @@ def load_dashboard_view(
)
.count()
)
upcoming_ready = sum(
1 for job in runs_view["upcoming"] if str(job["run_reason"]) == "Ready"
)
footprint_bytes = _directory_size(output_dir)
return {
"running": runs_view["running"],
"queued": runs_view["queued"],
"source_feeds": tuple(
_project_source_feed(
cast(Job, job),
output_dir,
reference_time,
running_execution=running_by_job_id.get(_job_id(cast(Job, job))),
queued_execution=queued_by_job_id.get(_job_id(cast(Job, job))),
upcoming_job=upcoming_by_job_id.get(_job_id(cast(Job, job))),
)
for job in jobs
),
"snapshot": {
"running_now": str(len(runs_view["running"])),
"upcoming_today": str(upcoming_ready),
"failures_24h": str(failed_last_day),
"artifact_footprint": _format_bytes(footprint_bytes),
},
}
upcoming_ready = sum(
1 for job in runs_view["upcoming"] if str(job["run_reason"]) == "Ready"
)
footprint_bytes = _directory_size(output_dir)
return {
"running": runs_view["running"],
"queued": runs_view["queued"],
"source_feeds": tuple(
_project_source_feed(
cast(Job, job),
output_dir,
reference_time,
running_execution=running_by_job_id.get(_job_id(cast(Job, job))),
queued_execution=queued_by_job_id.get(_job_id(cast(Job, job))),
upcoming_job=upcoming_by_job_id.get(_job_id(cast(Job, job))),
)
for job in jobs
),
"snapshot": {
"running_now": str(len(runs_view["running"])),
"upcoming_today": str(upcoming_ready),
"failures_24h": str(failed_last_day),
"artifact_footprint": _format_bytes(footprint_bytes),
},
}
def load_execution_log_view(
*, log_dir: str | Path, job_id: int, execution_id: int
) -> ExecutionLogView:
with database.connection_context():
execution = JobExecution.get_or_none(id=execution_id)
route = f"/job/{job_id}/execution/{execution_id}/logs"
if execution is None or _job_id(cast(Job, execution.job)) != job_id:
return ExecutionLogView(
job_id=job_id,
execution_id=execution_id,
title=f"Job {job_id} / execution {execution_id}",
description="Plain text log view routed through the app.",
status_label="Unavailable",
status_tone="failed",
log_text="",
error_message="Execution does not exist.",
with database.reader():
execution_primary_key = getattr(JobExecution, "_meta").primary_key
execution = (
JobExecution.select(JobExecution, Job)
.join(Job)
.where(execution_primary_key == execution_id)
.get_or_none()
)
artifacts = JobArtifacts.for_execution(
log_dir=Path(log_dir),
job_id=job_id,
execution_id=execution_id,
)
if not artifacts.log_path.exists():
if execution is None or _job_id(cast(Job, execution.job)) != job_id:
return ExecutionLogView(
job_id=job_id,
execution_id=execution_id,
title=f"Job {job_id} / execution {execution_id}",
description="Plain text log view routed through the app.",
status_label="Unavailable",
status_tone="failed",
log_text="",
error_message="Execution does not exist.",
)
artifacts = JobArtifacts.for_execution(
log_dir=Path(log_dir),
job_id=job_id,
execution_id=execution_id,
)
if not artifacts.log_path.exists():
return ExecutionLogView(
job_id=job_id,
execution_id=execution_id,
title=f"Job {job_id} / execution {execution_id}",
description="Plain text log view routed through the app.",
status_label=_execution_status_label(execution),
status_tone=_execution_status_tone(execution),
log_text="",
error_message="Log file has not been created yet.",
)
return ExecutionLogView(
job_id=job_id,
execution_id=execution_id,
title=f"Job {job_id} / execution {execution_id}",
description="Plain text log view routed through the app.",
description=f"Route: {route}",
status_label=_execution_status_label(execution),
status_tone=_execution_status_tone(execution),
log_text="",
error_message="Log file has not been created yet.",
log_text=artifacts.log_path.read_text(encoding="utf-8"),
)
return ExecutionLogView(
job_id=job_id,
execution_id=execution_id,
title=f"Job {job_id} / execution {execution_id}",
description=f"Route: {route}",
status_label=_execution_status_label(execution),
status_tone=_execution_status_tone(execution),
log_text=artifacts.log_path.read_text(encoding="utf-8"),
)
def _job_trigger(job: Job) -> CronTrigger:
expression = " ".join(

@@ -1,12 +1,8 @@
from __future__ import annotations
import json
import os
from datetime import UTC, datetime
from enum import IntEnum
from importlib import resources
from importlib.resources.abc import Traversable
from pathlib import Path
from typing import Any
from peewee import (
@@ -16,29 +12,24 @@ from peewee import (
ForeignKeyField,
IntegerField,
Model,
SqliteDatabase,
TextField,
)
from playhouse.migrate import SchemaMigrator, migrate
DEFAULT_DB_PATH = Path("republisher.db")
DATABASE_PRAGMAS = {
"busy_timeout": 5000,
"cache_size": 15625,
"foreign_keys": 1,
"journal_mode": "wal",
"page_size": 4096,
"synchronous": "normal",
"temp_store": "memory",
}
SCHEMA_GLOB = "*.sql"
from repub import db as db_module
DEFAULT_DB_PATH = db_module.DEFAULT_DB_PATH
DATABASE_PRAGMAS = db_module.DATABASE_PRAGMAS
SCHEMA_GLOB = db_module.SCHEMA_GLOB
database = db_module.database
initialize_database = db_module.initialize_database
resolve_database_path = db_module.resolve_database_path
schema_paths = db_module.schema_paths
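# Backwards-compatible re-exports so existing "from repub.model import ..."
# call sites keep working after the move to repub.db.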
MAX_CONCURRENT_JOBS_SETTING_KEY = "max_concurrent_jobs"
DEFAULT_MAX_CONCURRENT_JOBS = 1
FEED_URL_SETTING_KEY = "feed_url"
DEFAULT_FEED_URL = ""
database = SqliteDatabase(None, pragmas=DATABASE_PRAGMAS)
class JobExecutionStatus(IntEnum):
PENDING = 0
@@ -52,101 +43,24 @@ def utc_now() -> datetime:
return datetime.now(UTC)
def resolve_database_path(db_path: str | Path | None = None) -> Path:
raw_value = (
os.environ.get("REPUBLISHER_DB_PATH", DEFAULT_DB_PATH)
if db_path is None
else db_path
)
raw_path = Path(raw_value)
return raw_path.expanduser().resolve()
def schema_paths() -> tuple[Traversable, ...]:
schema_dir = resources.files("repub").joinpath("sql")
return tuple(
sorted(
(path for path in schema_dir.iterdir() if path.name.endswith(".sql")),
key=lambda path: path.name,
)
)
def initialize_database(db_path: str | Path | None = None) -> Path:
resolved_path = resolve_database_path(db_path)
resolved_path.parent.mkdir(parents=True, exist_ok=True)
if not database.is_closed():
database.close()
database.init(str(resolved_path), pragmas=DATABASE_PRAGMAS)
database.connect(reuse_if_open=True)
try:
for path in schema_paths():
database.connection().executescript(path.read_text(encoding="utf-8"))
_run_legacy_migrations()
finally:
database.close()
return resolved_path
def _run_legacy_migrations() -> None:
job_columns = {column.name for column in database.get_columns("job")}
operations = []
migrator = SchemaMigrator.from_database(database)
if "convert_images" not in job_columns:
operations.extend(
(
migrator.add_column(
"job",
"convert_images",
BooleanField(
default=True,
constraints=[Check("convert_images IN (0, 1)")],
),
),
migrator.add_column_default("job", "convert_images", 1),
)
)
if "convert_video" not in job_columns:
operations.extend(
(
migrator.add_column(
"job",
"convert_video",
BooleanField(
default=True,
constraints=[Check("convert_video IN (0, 1)")],
),
),
migrator.add_column_default("job", "convert_video", 1),
)
)
if operations:
with database.atomic():
migrate(*operations)
def source_slug_exists(slug: str) -> bool:
with database.connection_context():
with database.reader():
return Source.select().where(Source.slug == slug).exists()
def save_setting(key: str, value: Any) -> None:
payload = json.dumps(value, sort_keys=True)
with database.connection_context():
with database.atomic():
setting = AppSetting.get_or_none(AppSetting.key == key)
if setting is None:
AppSetting.create(key=key, value=payload)
return
setting.value = payload
setting.save()
with database.writer():
setting = AppSetting.get_or_none(AppSetting.key == key)
if setting is None:
AppSetting.create(key=key, value=payload)
return
setting.value = payload
setting.save()
def load_setting(key: str, default: Any) -> Any:
with database.connection_context():
with database.reader():
setting = AppSetting.get_or_none(AppSetting.key == key)
if setting is None:
return default
@@ -177,8 +91,14 @@ def load_settings_form() -> dict[str, object]:
}
def load_job_enabled(job_id: int) -> bool | None:
with database.reader():
job = Job.get_or_none(id=job_id)
return None if job is None else job.enabled
def load_source_form(slug: str) -> dict[str, object] | None:
with database.connection_context():
with database.reader():
source = Source.get_or_none(Source.slug == slug)
if source is None:
return None
@@ -259,46 +179,45 @@ def create_source(
include_content: bool = True,
content_format: str = "",
) -> Source:
with database.connection_context():
with database.atomic():
source = Source.create(
name=name,
slug=slug,
source_type=source_type,
notes=notes,
)
if source_type == "feed":
SourceFeed.create(
source=source,
feed_url=feed_url,
)
else:
SourcePangea.create(
source=source,
domain=pangea_domain,
category_name=pangea_category,
content_type=content_type,
only_newest=only_newest,
max_articles=max_articles,
oldest_article=oldest_article,
include_authors=include_authors,
exclude_media=exclude_media,
include_content=include_content,
content_format=content_format,
)
Job.create(
with database.writer():
source = Source.create(
name=name,
slug=slug,
source_type=source_type,
notes=notes,
)
if source_type == "feed":
SourceFeed.create(
source=source,
enabled=enabled,
convert_images=convert_images,
convert_video=convert_video,
spider_arguments=spider_arguments,
cron_minute=cron_minute,
cron_hour=cron_hour,
cron_day_of_month=cron_day_of_month,
cron_day_of_week=cron_day_of_week,
cron_month=cron_month,
feed_url=feed_url,
)
return source
else:
SourcePangea.create(
source=source,
domain=pangea_domain,
category_name=pangea_category,
content_type=content_type,
only_newest=only_newest,
max_articles=max_articles,
oldest_article=oldest_article,
include_authors=include_authors,
exclude_media=exclude_media,
include_content=include_content,
content_format=content_format,
)
Job.create(
source=source,
enabled=enabled,
convert_images=convert_images,
convert_video=convert_video,
spider_arguments=spider_arguments,
cron_minute=cron_minute,
cron_hour=cron_hour,
cron_day_of_month=cron_day_of_month,
cron_day_of_week=cron_day_of_week,
cron_month=cron_month,
)
return source
def update_source(
@@ -329,91 +248,88 @@ def update_source(
include_content: bool = True,
content_format: str = "",
) -> Source | None:
with database.connection_context():
with database.atomic():
source = Source.get_or_none(Source.slug == source_slug)
if source is None:
return None
with database.writer():
source = Source.get_or_none(Source.slug == source_slug)
if source is None:
return None
source.name = name
source.notes = notes
source.source_type = source_type
source.save()
source.name = name
source.notes = notes
source.source_type = source_type
source.save()
job = Job.get(Job.source == source)
job.enabled = enabled
job.convert_images = convert_images
job.convert_video = convert_video
job.spider_arguments = spider_arguments
job.cron_minute = cron_minute
job.cron_hour = cron_hour
job.cron_day_of_month = cron_day_of_month
job.cron_day_of_week = cron_day_of_week
job.cron_month = cron_month
job.save()
job = Job.get(Job.source == source)
job.enabled = enabled
job.convert_images = convert_images
job.convert_video = convert_video
job.spider_arguments = spider_arguments
job.cron_minute = cron_minute
job.cron_hour = cron_hour
job.cron_day_of_month = cron_day_of_month
job.cron_day_of_week = cron_day_of_week
job.cron_month = cron_month
job.save()
if source_type == "feed":
SourcePangea.delete().where(SourcePangea.source == source).execute()
feed = SourceFeed.get_or_none(SourceFeed.source == source)
if feed is None:
SourceFeed.create(source=source, feed_url=feed_url)
else:
feed.feed_url = feed_url
feed.save()
if source_type == "feed":
SourcePangea.delete().where(SourcePangea.source == source).execute()
feed = SourceFeed.get_or_none(SourceFeed.source == source)
if feed is None:
SourceFeed.create(source=source, feed_url=feed_url)
else:
SourceFeed.delete().where(SourceFeed.source == source).execute()
pangea = SourcePangea.get_or_none(SourcePangea.source == source)
if pangea is None:
SourcePangea.create(
source=source,
domain=pangea_domain,
category_name=pangea_category,
content_type=content_type,
only_newest=only_newest,
max_articles=max_articles,
oldest_article=oldest_article,
include_authors=include_authors,
exclude_media=exclude_media,
include_content=include_content,
content_format=content_format,
)
else:
pangea.domain = pangea_domain
pangea.category_name = pangea_category
pangea.content_type = content_type
pangea.only_newest = only_newest
pangea.max_articles = max_articles
pangea.oldest_article = oldest_article
pangea.include_authors = include_authors
pangea.exclude_media = exclude_media
pangea.include_content = include_content
pangea.content_format = content_format
pangea.save()
feed.feed_url = feed_url
feed.save()
else:
SourceFeed.delete().where(SourceFeed.source == source).execute()
pangea = SourcePangea.get_or_none(SourcePangea.source == source)
if pangea is None:
SourcePangea.create(
source=source,
domain=pangea_domain,
category_name=pangea_category,
content_type=content_type,
only_newest=only_newest,
max_articles=max_articles,
oldest_article=oldest_article,
include_authors=include_authors,
exclude_media=exclude_media,
include_content=include_content,
content_format=content_format,
)
else:
pangea.domain = pangea_domain
pangea.category_name = pangea_category
pangea.content_type = content_type
pangea.only_newest = only_newest
pangea.max_articles = max_articles
pangea.oldest_article = oldest_article
pangea.include_authors = include_authors
pangea.exclude_media = exclude_media
pangea.include_content = include_content
pangea.content_format = content_format
pangea.save()
return source
return source
def delete_job_source(job_id: int) -> bool:
with database.connection_context():
with database.atomic():
job = Job.get_or_none(id=job_id)
if job is None:
return False
source = Source.get_by_id(job.source_id)
return source.delete_instance() > 0
with database.writer():
job = Job.get_or_none(id=job_id)
if job is None:
return False
source = Source.get_by_id(job.source_id)
return source.delete_instance() > 0
def delete_source(slug: str) -> bool:
with database.connection_context():
with database.atomic():
source = Source.get_or_none(Source.slug == slug)
if source is None:
return False
return source.delete_instance() > 0
with database.writer():
source = Source.get_or_none(Source.slug == slug)
if source is None:
return False
return source.delete_instance() > 0
def load_sources() -> tuple[dict[str, object], ...]:
with database.connection_context():
with database.reader():
sources = tuple(Source.select().order_by(Source.created_at.desc()))
source_ids = tuple(int(source.get_id()) for source in sources)
if not source_ids:

@@ -27,11 +27,11 @@ from repub.jobs import (
load_runs_view,
)
from repub.model import (
Job,
create_source,
delete_job_source,
delete_source,
initialize_database,
load_job_enabled,
load_settings_form,
load_source_form,
load_sources,
@@ -329,9 +329,9 @@ def create_app(*, dev_mode: bool = False) -> Quart:
@app.post("/actions/jobs/<int:job_id>/toggle-enabled")
async def toggle_job_enabled_action(job_id: int) -> Response:
job = Job.get_or_none(id=job_id)
if job is not None:
get_job_runtime(app).set_job_enabled(job_id, enabled=not job.enabled)
enabled = load_job_enabled(job_id)
if enabled is not None:
get_job_runtime(app).set_job_enabled(job_id, enabled=not enabled)
trigger_refresh(app)
return Response(status=204)