# republisher/repub/db.py
from __future__ import annotations
import os
import queue
import re
import threading
from contextlib import contextmanager
from contextvars import ContextVar
from importlib import resources
from importlib.resources.abc import Traversable
from pathlib import Path
from typing import Iterator
from peewee import BooleanField, Check, SqliteDatabase
from playhouse.migrate import SchemaMigrator, migrate
# Fallback database filename (relative to the working directory) when neither
# an explicit path nor REPUBLISHER_DB_PATH is supplied.
DEFAULT_DB_PATH = Path("republisher.db")
# Pragmas applied to every connection: WAL journaling allows one writer plus
# concurrent readers, which the reader-pool/writer-lock design below relies on.
DATABASE_PRAGMAS = {
"busy_timeout": 5000,
"cache_size": 15625,
"foreign_keys": 1,
"journal_mode": "wal",
"page_size": 4096,
"synchronous": "normal",
"temp_store": "memory",
}
# NOTE(review): schema_paths() filters with name.endswith(".sql") instead of
# using this glob — confirm whether this constant is still needed.
SCHEMA_GLOB = "*.sql"
# First-keyword test for statements that mutate the database; used by
# RoutedSqliteDatabase._validate_sql to reject writes inside reader scopes.
_WRITE_SQL_PREFIX = re.compile(
r"^\s*(INSERT|UPDATE|DELETE|REPLACE|CREATE|DROP|ALTER)\b"
)
# Database leased to the current reader()/writer() scope, per context
# (thread/task); None when no scope is active.
_current_database: ContextVar[ManagedSqliteDatabase | None] = ContextVar(
"repub_current_database",
default=None,
)
# Active scope kind: "reader", "writer", "reader_conn", "writer_conn", or None.
_current_scope: ContextVar[str | None] = ContextVar(
"repub_current_database_scope",
default=None,
)
# Process-wide connection manager installed by initialize_database().
_database_connection: DatabaseConnection | None = None
def resolve_database_path(db_path: str | Path | None = None) -> Path:
raw_value = (
os.environ.get("REPUBLISHER_DB_PATH", DEFAULT_DB_PATH)
if db_path is None
else db_path
)
return Path(raw_value).expanduser().resolve()
def schema_paths() -> tuple[Traversable, ...]:
    """Return the packaged ``repub/sql`` schema files, sorted by filename.

    Sorting gives a deterministic application order (files are expected to
    be named so that lexical order is the intended execution order).
    """
    sql_dir = resources.files("repub").joinpath("sql")
    sql_files = [entry for entry in sql_dir.iterdir() if entry.name.endswith(".sql")]
    sql_files.sort(key=lambda entry: entry.name)
    return tuple(sql_files)
class ManagedSqliteDatabase(SqliteDatabase):
    """SqliteDatabase configured for externally managed, cross-thread use.

    Connections are opened/closed by DatabaseConnection, never implicitly
    (``autoconnect=False``), and peewee's per-thread connection state is
    disabled so a single connection can be leased across threads.
    """

    def __init__(self, database: str, **kwargs) -> None:
        # Callers may override the pragmas; the managed flags are forced.
        effective_pragmas = kwargs.pop("pragmas", DATABASE_PRAGMAS)
        kwargs.setdefault("check_same_thread", False)
        super().__init__(
            database,
            autoconnect=False,
            thread_safe=False,
            pragmas=effective_pragmas,
            **kwargs,
        )
class RoutedSqliteDatabase(SqliteDatabase):
    """Facade database that routes every operation to the scope-active database.

    Peewee models bind to this object, but it never opens a connection of its
    own: each call is delegated to the ManagedSqliteDatabase installed in the
    context variables by DatabaseConnection.reader()/writer(). While a reader
    scope is active, statements that look like writes are rejected.
    """

    def __init__(self) -> None:
        # ":managed:" is a placeholder name; this database never connects.
        super().__init__(
            ":managed:",
            autoconnect=False,
            pragmas=DATABASE_PRAGMAS,
        )

    def _require_active_database(self) -> ManagedSqliteDatabase:
        """Return the database for the current scope, raising outside any scope."""
        active_database = _current_database.get()
        if active_database is None:
            raise RuntimeError(
                "Database access requires a database.reader() or database.writer() context."
            )
        return active_database

    def _validate_sql(self, sql: str) -> None:
        """Reject write statements while a reader scope is active."""
        scope = _current_scope.get()
        if scope in {"reader", "reader_conn"} and _WRITE_SQL_PREFIX.match(sql):
            raise RuntimeError("Write query attempted inside database.reader() scope.")

    def connect(self, reuse_if_open: bool = False):  # type: ignore[override]
        raise RuntimeError(
            "Do not call database.connect() directly; use database.reader() or database.writer()."
        )

    def close(self):  # type: ignore[override]
        raise RuntimeError(
            "Do not call database.close() directly; use database.reader() or database.writer()."
        )

    def is_closed(self) -> bool:  # type: ignore[override]
        # Report "closed" when no scope is active instead of raising, since
        # peewee probes this outside of query execution.
        active_database = _current_database.get()
        return True if active_database is None else active_database.is_closed()

    def connection(self):  # type: ignore[override]
        return self._require_active_database().connection()

    def cursor(self, named_cursor=None):  # type: ignore[override]
        return self._require_active_database().cursor(named_cursor=named_cursor)

    def execute_sql(self, sql, params=None):  # type: ignore[override]
        self._validate_sql(str(sql))
        return self._require_active_database().execute_sql(sql, params)

    def execute(self, query, **context_options):  # type: ignore[override]
        # BUGFIX: delegating directly to the active database's execute() had
        # its SQL run through the *active* database's execute_sql, bypassing
        # _validate_sql, so model-level writes slipped through reader scopes.
        # Using super().execute keeps compilation here and routes the SQL
        # through self.execute_sql, where the guard applies.
        self._require_active_database()
        return super().execute(query, **context_options)

    def atomic(self, *args, **kwargs):  # type: ignore[override]
        return self._require_active_database().atomic(*args, **kwargs)

    def transaction(self, *args, **kwargs):  # type: ignore[override]
        return self._require_active_database().transaction(*args, **kwargs)

    def savepoint(self, *args, **kwargs):  # type: ignore[override]
        return self._require_active_database().savepoint(*args, **kwargs)

    def connection_context(self):  # type: ignore[override]
        raise RuntimeError(
            "Do not call database.connection_context() directly; use database.reader() or database.writer()."
        )

    def reader(self):
        """Shortcut for the process connection's reader() scope."""
        return _require_database_connection().reader()

    def writer(self):
        """Shortcut for the process connection's writer() scope."""
        return _require_database_connection().writer()

    def reader_conn(self):
        """Shortcut for the process connection's reader_conn() scope."""
        return _require_database_connection().reader_conn()

    def writer_conn(self):
        """Shortcut for the process connection's writer_conn() scope."""
        return _require_database_connection().writer_conn()
class DatabaseConnection:
    """Owns one writer database and a fixed pool of reader databases.

    SQLite in WAL mode supports a single writer alongside concurrent
    readers; this class enforces that split with a lock-guarded writer and
    a queue-based reader pool. Scope context variables make nested
    reader()/writer() calls re-entrant within the same thread/task.

    Scope kinds:
      * "reader"/"writer"      — connection leased AND a transaction open.
      * "reader_conn"/"writer_conn" — connection leased, no transaction.
    """

    def __init__(
        self,
        db_path: str | Path,
        *,
        pool_size: int = 4,
        pragmas: dict[str, object] | None = None,
    ) -> None:
        self.db_path = resolve_database_path(db_path)
        self.pool_size = pool_size
        # Copy so a caller mutating their dict cannot alter our pragmas.
        self.pragmas = dict(DATABASE_PRAGMAS if pragmas is None else pragmas)
        self.writer_db = ManagedSqliteDatabase(str(self.db_path), pragmas=self.pragmas)
        self.reader_dbs = tuple(
            ManagedSqliteDatabase(str(self.db_path), pragmas=self.pragmas)
            for _ in range(pool_size)
        )
        self._reader_pool: queue.Queue[ManagedSqliteDatabase] = queue.Queue()
        self._writer_lock = threading.RLock()
        for reader_db in self.reader_dbs:
            self._reader_pool.put(reader_db)

    def initialize(self) -> Path:
        """Create the database file, apply schema scripts and legacy
        migrations, then open every pooled reader connection.

        Returns the resolved database path. On failure the writer
        connection is closed before the exception propagates.
        """
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.writer_db.connect(reuse_if_open=True)
        try:
            for path in schema_paths():
                # executescript runs each packaged .sql file as a whole.
                self.writer_db.connection().executescript(
                    path.read_text(encoding="utf-8")
                )
            _run_legacy_migrations(self.writer_db)
        except Exception:
            self.writer_db.close()
            raise
        for reader_db in self.reader_dbs:
            reader_db.connect(reuse_if_open=True)
        return self.db_path

    def close(self) -> None:
        """Close every pooled reader connection, then the writer."""
        for reader_db in self.reader_dbs:
            if not reader_db.is_closed():
                reader_db.close()
        if not self.writer_db.is_closed():
            self.writer_db.close()

    @contextmanager
    def reader(self) -> Iterator[ManagedSqliteDatabase]:
        """Lease a reader database wrapped in a transaction.

        Re-entrant: inside an existing reader/writer scope the active
        database is reused as-is; inside a *_conn scope a transaction is
        opened on the already-leased connection.
        """
        scope = _current_scope.get()
        if scope in {"reader", "writer"}:
            # Already inside a transactional scope; reuse it.
            # (A previous revision re-checked scope == "writer" further
            # down, which was unreachable — handled here.)
            yield _require_active_database()
            return
        if scope in {"reader_conn", "writer_conn"}:
            # A bare connection is leased but no transaction is open yet.
            active_database = _require_active_database()
            with active_database.atomic():
                yield active_database
            return
        leased_database = self._reader_pool.get()
        database_token = _current_database.set(leased_database)
        scope_token = _current_scope.set("reader")
        try:
            with leased_database.atomic():
                yield leased_database
        finally:
            _current_scope.reset(scope_token)
            _current_database.reset(database_token)
            self._reader_pool.put(leased_database)

    @contextmanager
    def writer(self) -> Iterator[ManagedSqliteDatabase]:
        """Acquire the single writer database inside a transaction.

        Re-entrant within writer scopes. Raises RuntimeError inside a
        reader scope — escalating a read lease to a write is disallowed.
        """
        scope = _current_scope.get()
        if scope == "writer":
            yield _require_active_database()
            return
        if scope == "writer_conn":
            active_database = _require_active_database()
            with active_database.atomic():
                yield active_database
            return
        if scope in {"reader", "reader_conn"}:
            raise RuntimeError(
                "Cannot enter database.writer() inside database.reader()."
            )
        with self._writer_lock:
            database_token = _current_database.set(self.writer_db)
            scope_token = _current_scope.set("writer")
            try:
                with self.writer_db.atomic():
                    yield self.writer_db
            finally:
                _current_scope.reset(scope_token)
                _current_database.reset(database_token)

    @contextmanager
    def reader_conn(self) -> Iterator[ManagedSqliteDatabase]:
        """Lease a reader connection WITHOUT opening a transaction.

        Inside any existing scope the active database is reused unchanged.
        """
        scope = _current_scope.get()
        if scope is not None:
            yield _require_active_database()
            return
        leased_database = self._reader_pool.get()
        database_token = _current_database.set(leased_database)
        scope_token = _current_scope.set("reader_conn")
        try:
            yield leased_database
        finally:
            _current_scope.reset(scope_token)
            _current_database.reset(database_token)
            self._reader_pool.put(leased_database)

    @contextmanager
    def writer_conn(self) -> Iterator[ManagedSqliteDatabase]:
        """Acquire the writer connection WITHOUT opening a transaction.

        Re-entrant within writer scopes; raises inside reader scopes.
        """
        scope = _current_scope.get()
        if scope in {"writer", "writer_conn"}:
            yield _require_active_database()
            return
        if scope in {"reader", "reader_conn"}:
            raise RuntimeError(
                "Cannot enter database.writer_conn() inside database.reader()."
            )
        with self._writer_lock:
            database_token = _current_database.set(self.writer_db)
            scope_token = _current_scope.set("writer_conn")
            try:
                yield self.writer_db
            finally:
                _current_scope.reset(scope_token)
                _current_database.reset(database_token)
def _run_legacy_migrations(database: ManagedSqliteDatabase) -> None:
    """Add job columns introduced after the original schema shipped.

    Each legacy column is a boolean flag defaulting to true, with a CHECK
    constraint restricting it to 0/1 plus an explicit column default so raw
    SQL inserts get the same value. All migrations run in one transaction.
    """
    job_columns = {column.name for column in database.get_columns("job")}
    migrator = SchemaMigrator.from_database(database)
    operations = []
    # The two columns previously had copy-pasted, near-identical branches;
    # a single loop keeps their definitions in lockstep.
    for column_name in ("convert_images", "convert_video"):
        if column_name in job_columns:
            continue
        operations.extend(
            (
                migrator.add_column(
                    "job",
                    column_name,
                    BooleanField(
                        default=True,
                        constraints=[Check(f"{column_name} IN (0, 1)")],
                    ),
                ),
                migrator.add_column_default("job", column_name, 1),
            )
        )
    if operations:
        with database.atomic():
            migrate(*operations)
def initialize_database(db_path: str | Path | None = None) -> Path:
    """(Re)initialize the process-wide database connection.

    Closes any existing connection first. Returns the resolved database
    path.

    BUGFIX: the global is cleared as soon as the old connection is closed,
    so a failure during re-initialization no longer leaves a closed
    connection registered as the active one.
    """
    global _database_connection
    if _database_connection is not None:
        _database_connection.close()
        _database_connection = None
    connection = DatabaseConnection(resolve_database_path(db_path))
    resolved_path = connection.initialize()
    # Publish only after initialize() succeeded.
    _database_connection = connection
    return resolved_path
def get_database_connection() -> DatabaseConnection | None:
    """Return the process-wide DatabaseConnection, or None before initialize_database()."""
    return _database_connection
def _require_database_connection() -> DatabaseConnection:
    """Return the active DatabaseConnection, raising if initialization never ran."""
    if (connection := get_database_connection()) is None:
        raise RuntimeError("Database has not been initialized.")
    return connection
def _require_active_database() -> ManagedSqliteDatabase:
    """Return the database bound to the current scope, raising outside any scope."""
    scoped_database = _current_database.get()
    if scoped_database is not None:
        return scoped_database
    raise RuntimeError(
        "Database access requires a database.reader() or database.writer() context."
    )
# Module-level facade that models bind to; all real work is routed to the
# scope-active ManagedSqliteDatabase via reader()/writer() contexts.
database = RoutedSqliteDatabase()