"""SQLite connection management with routed reader/writer scopes.

A single module-level ``database`` (a :class:`RoutedSqliteDatabase`) is bound
to peewee models.  It holds no connection of its own: every query is forwarded
to the connection leased by the active ``reader()`` / ``writer()`` context,
enforcing a single-writer / pooled-readers discipline on top of SQLite WAL.
"""

from __future__ import annotations

import fnmatch
import os
import queue
import re
import threading
from contextlib import contextmanager
from contextvars import ContextVar
from importlib import resources
from importlib.resources.abc import Traversable
from pathlib import Path
from typing import Iterator

from peewee import BooleanField, Check, SqliteDatabase
from playhouse.migrate import SchemaMigrator, migrate

# Default database file, relative to the working directory unless overridden
# by the REPUBLISHER_DB_PATH environment variable.
DEFAULT_DB_PATH = Path("republisher.db")

# Pragmas applied to every connection (writer and readers).  WAL journaling
# is what allows concurrent readers alongside the single writer.
DATABASE_PRAGMAS = {
    "busy_timeout": 5000,
    "cache_size": 15625,
    "foreign_keys": 1,
    "journal_mode": "wal",
    "page_size": 4096,
    "synchronous": "normal",
    "temp_store": "memory",
}

# Glob matching packaged schema files under repub/sql.
SCHEMA_GLOB = "*.sql"

# Statement prefixes that must not run inside a reader scope.
# NOTE(review): this is a prefix heuristic only -- statements such as
# "WITH ... INSERT" or write-pragmas would bypass it; confirm acceptable.
_WRITE_SQL_PREFIX = re.compile(
    r"^\s*(INSERT|UPDATE|DELETE|REPLACE|CREATE|DROP|ALTER)\b"
)

# Connection currently leased to this thread/task, if any.
_current_database: ContextVar[ManagedSqliteDatabase | None] = ContextVar(
    "repub_current_database",
    default=None,
)
# Active scope name: "reader", "writer", "reader_conn", "writer_conn", or None.
_current_scope: ContextVar[str | None] = ContextVar(
    "repub_current_database_scope",
    default=None,
)

# Process-wide connection bundle, created by initialize_database().
_database_connection: DatabaseConnection | None = None


def resolve_database_path(db_path: str | Path | None = None) -> Path:
    """Return the absolute database path.

    When *db_path* is ``None``, falls back to the ``REPUBLISHER_DB_PATH``
    environment variable, then to :data:`DEFAULT_DB_PATH`.
    """
    raw_value = (
        os.environ.get("REPUBLISHER_DB_PATH", DEFAULT_DB_PATH)
        if db_path is None
        else db_path
    )
    return Path(raw_value).expanduser().resolve()


def schema_paths() -> tuple[Traversable, ...]:
    """Return the packaged schema files from ``repub/sql``, sorted by name.

    Sorting by name makes execution order deterministic, so schema files can
    be sequenced with numeric prefixes.
    """
    schema_dir = resources.files("repub").joinpath("sql")
    return tuple(
        sorted(
            (
                path
                for path in schema_dir.iterdir()
                if fnmatch.fnmatch(path.name, SCHEMA_GLOB)
            ),
            key=lambda path: path.name,
        )
    )


class ManagedSqliteDatabase(SqliteDatabase):
    """SqliteDatabase with the project's pragmas and an explicit lifecycle.

    ``autoconnect`` is disabled so connections are opened only by
    :class:`DatabaseConnection`; ``check_same_thread`` defaults to off because
    connections are leased to arbitrary threads through the reader pool.
    """

    def __init__(self, database: str, **kwargs) -> None:
        pragmas = kwargs.pop("pragmas", DATABASE_PRAGMAS)
        kwargs.setdefault("check_same_thread", False)
        super().__init__(
            database,
            autoconnect=False,
            thread_safe=False,
            pragmas=pragmas,
            **kwargs,
        )


class RoutedSqliteDatabase(SqliteDatabase):
    """Facade bound to peewee models that routes to the active connection.

    Every operation is forwarded to the :class:`ManagedSqliteDatabase` leased
    by the current ``reader()`` / ``writer()`` scope.  Direct ``connect()`` /
    ``close()`` calls are rejected to force scoped access.
    """

    def __init__(self) -> None:
        super().__init__(
            ":managed:",
            autoconnect=False,
            pragmas=DATABASE_PRAGMAS,
        )

    def _require_active_database(self) -> ManagedSqliteDatabase:
        # Delegate to the module-level helper so there is a single source for
        # the "no active scope" error.
        return _require_active_database()

    def _validate_sql(self, sql: str) -> None:
        # Reject obvious write statements while inside a read-only scope.
        scope = _current_scope.get()
        if scope in {"reader", "reader_conn"} and _WRITE_SQL_PREFIX.match(sql):
            raise RuntimeError("Write query attempted inside database.reader() scope.")

    def connect(self, reuse_if_open: bool = False):  # type: ignore[override]
        raise RuntimeError(
            "Do not call database.connect() directly; use database.reader() or database.writer()."
        )

    def close(self):  # type: ignore[override]
        raise RuntimeError(
            "Do not call database.close() directly; use database.reader() or database.writer()."
        )

    def is_closed(self) -> bool:  # type: ignore[override]
        # With no leased connection the facade reports itself closed.
        active_database = _current_database.get()
        return True if active_database is None else active_database.is_closed()

    def connection(self):  # type: ignore[override]
        return self._require_active_database().connection()

    def cursor(self, named_cursor=None):  # type: ignore[override]
        return self._require_active_database().cursor(named_cursor=named_cursor)

    def execute_sql(self, sql, params=None):  # type: ignore[override]
        self._validate_sql(str(sql))
        return self._require_active_database().execute_sql(sql, params)

    def execute(self, query, **context_options):  # type: ignore[override]
        return self._require_active_database().execute(query, **context_options)

    def atomic(self, *args, **kwargs):  # type: ignore[override]
        return self._require_active_database().atomic(*args, **kwargs)

    def transaction(self, *args, **kwargs):  # type: ignore[override]
        return self._require_active_database().transaction(*args, **kwargs)

    def savepoint(self, *args, **kwargs):  # type: ignore[override]
        return self._require_active_database().savepoint(*args, **kwargs)

    def connection_context(self):  # type: ignore[override]
        raise RuntimeError(
            "Do not call database.connection_context() directly; use database.reader() or database.writer()."
        )

    def reader(self):
        """Enter a pooled, transactional read scope (see DatabaseConnection.reader)."""
        return _require_database_connection().reader()

    def writer(self):
        """Enter the exclusive, transactional write scope (see DatabaseConnection.writer)."""
        return _require_database_connection().writer()

    def reader_conn(self):
        """Lease a reader connection without opening a transaction."""
        return _require_database_connection().reader_conn()

    def writer_conn(self):
        """Lease the writer connection without opening a transaction."""
        return _require_database_connection().writer_conn()


class DatabaseConnection:
    """Owns the single writer connection and a fixed pool of readers."""

    def __init__(
        self,
        db_path: str | Path,
        *,
        pool_size: int = 4,
        pragmas: dict[str, object] | None = None,
    ) -> None:
        self.db_path = resolve_database_path(db_path)
        self.pool_size = pool_size
        self.pragmas = dict(DATABASE_PRAGMAS if pragmas is None else pragmas)
        self.writer_db = ManagedSqliteDatabase(str(self.db_path), pragmas=self.pragmas)
        self.reader_dbs = tuple(
            ManagedSqliteDatabase(str(self.db_path), pragmas=self.pragmas)
            for _ in range(pool_size)
        )
        self._reader_pool: queue.Queue[ManagedSqliteDatabase] = queue.Queue()
        self._writer_lock = threading.RLock()
        for reader_db in self.reader_dbs:
            self._reader_pool.put(reader_db)

    def initialize(self) -> Path:
        """Open all connections, apply schema files and legacy migrations.

        Returns the resolved database path.  On any failure, every connection
        opened so far is closed before the exception propagates (the original
        code closed only the writer, leaking readers on a late failure).
        """
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            self.writer_db.connect(reuse_if_open=True)
            for path in schema_paths():
                self.writer_db.connection().executescript(
                    path.read_text(encoding="utf-8")
                )
            _run_legacy_migrations(self.writer_db)
            for reader_db in self.reader_dbs:
                reader_db.connect(reuse_if_open=True)
        except Exception:
            self.close()
            raise
        return self.db_path

    def close(self) -> None:
        """Close every open connection (readers first, then the writer)."""
        for reader_db in self.reader_dbs:
            if not reader_db.is_closed():
                reader_db.close()
        if not self.writer_db.is_closed():
            self.writer_db.close()

    @contextmanager
    def reader(self) -> Iterator[ManagedSqliteDatabase]:
        """Yield a read connection inside a transaction.

        Nested inside an existing transactional scope, the active connection
        is reused as-is; inside a bare ``*_conn`` scope, a transaction is
        opened on that connection.  Otherwise a reader is leased from the
        pool for the duration of the block.
        """
        scope = _current_scope.get()
        if scope in {"reader", "writer"}:
            # Already transactional; reuse the active connection directly.
            yield _require_active_database()
            return
        if scope in {"reader_conn", "writer_conn"}:
            # Bare connection scope: wrap the reuse in a transaction.
            # (Merges two previously duplicated branches; an unreachable
            # `scope == "writer"` branch that followed them was removed.)
            active_database = _require_active_database()
            with active_database.atomic():
                yield active_database
            return
        leased_database = self._reader_pool.get()
        database_token = _current_database.set(leased_database)
        scope_token = _current_scope.set("reader")
        try:
            with leased_database.atomic():
                yield leased_database
        finally:
            _current_scope.reset(scope_token)
            _current_database.reset(database_token)
            self._reader_pool.put(leased_database)

    @contextmanager
    def writer(self) -> Iterator[ManagedSqliteDatabase]:
        """Yield the writer connection inside a transaction.

        Reentrant within writer scopes.  Raises inside a reader scope: reads
        must not silently escalate to writes.
        """
        scope = _current_scope.get()
        if scope == "writer":
            yield _require_active_database()
            return
        if scope == "writer_conn":
            active_database = _require_active_database()
            with active_database.atomic():
                yield active_database
            return
        if scope in {"reader", "reader_conn"}:
            raise RuntimeError(
                "Cannot enter database.writer() inside database.reader()."
            )
        # Single-writer discipline: serialize writer scopes across threads.
        with self._writer_lock:
            database_token = _current_database.set(self.writer_db)
            scope_token = _current_scope.set("writer")
            try:
                with self.writer_db.atomic():
                    yield self.writer_db
            finally:
                _current_scope.reset(scope_token)
                _current_database.reset(database_token)

    @contextmanager
    def reader_conn(self) -> Iterator[ManagedSqliteDatabase]:
        """Lease a reader connection with no transaction management.

        Inside any existing scope, the active connection is reused.
        """
        scope = _current_scope.get()
        if scope is not None:
            yield _require_active_database()
            return
        leased_database = self._reader_pool.get()
        database_token = _current_database.set(leased_database)
        scope_token = _current_scope.set("reader_conn")
        try:
            yield leased_database
        finally:
            _current_scope.reset(scope_token)
            _current_database.reset(database_token)
            self._reader_pool.put(leased_database)

    @contextmanager
    def writer_conn(self) -> Iterator[ManagedSqliteDatabase]:
        """Lease the writer connection with no transaction management.

        Reentrant within writer scopes; raises inside reader scopes.
        """
        scope = _current_scope.get()
        if scope in {"writer", "writer_conn"}:
            yield _require_active_database()
            return
        if scope in {"reader", "reader_conn"}:
            raise RuntimeError(
                "Cannot enter database.writer_conn() inside database.reader()."
            )
        with self._writer_lock:
            database_token = _current_database.set(self.writer_db)
            scope_token = _current_scope.set("writer_conn")
            try:
                yield self.writer_db
            finally:
                _current_scope.reset(scope_token)
                _current_database.reset(database_token)


def _run_legacy_migrations(database: ManagedSqliteDatabase) -> None:
    """Add boolean ``job`` columns that predate the packaged schema files.

    Each missing column is added with default ``True`` plus a 0/1 CHECK
    constraint, and existing rows are backfilled with 1.
    """
    job_columns = {column.name for column in database.get_columns("job")}
    migrator = SchemaMigrator.from_database(database)
    operations = []
    # Both legacy columns share the identical add-column recipe.
    for column_name in ("convert_images", "convert_video"):
        if column_name in job_columns:
            continue
        operations.extend(
            (
                migrator.add_column(
                    "job",
                    column_name,
                    BooleanField(
                        default=True,
                        constraints=[Check(f"{column_name} IN (0, 1)")],
                    ),
                ),
                migrator.add_column_default("job", column_name, 1),
            )
        )
    if operations:
        with database.atomic():
            migrate(*operations)


def initialize_database(db_path: str | Path | None = None) -> Path:
    """(Re)create the process-wide DatabaseConnection; return its path.

    Any previous connection bundle is closed first.  The handle is cleared
    before re-initialization so a failed init does not leave a stale,
    closed connection registered.
    """
    global _database_connection
    if _database_connection is not None:
        _database_connection.close()
        _database_connection = None
    connection = DatabaseConnection(resolve_database_path(db_path))
    resolved_path = connection.initialize()
    _database_connection = connection
    return resolved_path


def get_database_connection() -> DatabaseConnection | None:
    """Return the active DatabaseConnection, or None before initialization."""
    return _database_connection


def _require_database_connection() -> DatabaseConnection:
    """Return the active DatabaseConnection or raise if uninitialized."""
    database_connection = get_database_connection()
    if database_connection is None:
        raise RuntimeError("Database has not been initialized.")
    return database_connection


def _require_active_database() -> ManagedSqliteDatabase:
    """Return the connection leased to this context, or raise if none."""
    active_database = _current_database.get()
    if active_database is None:
        raise RuntimeError(
            "Database access requires a database.reader() or database.writer() context."
        )
    return active_database


# Module-level handle that peewee models bind to.
database = RoutedSqliteDatabase()