359 lines
12 KiB
Python
359 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import queue
|
|
import re
|
|
import threading
|
|
from contextlib import contextmanager
|
|
from contextvars import ContextVar
|
|
from importlib import resources
|
|
from importlib.resources.abc import Traversable
|
|
from pathlib import Path
|
|
from typing import Iterator
|
|
|
|
from peewee import BooleanField, Check, SqliteDatabase
|
|
from playhouse.migrate import SchemaMigrator, migrate
|
|
|
|
# Fallback database file, used when neither an explicit path nor the
# REPUBLISHER_DB_PATH environment variable is supplied.
DEFAULT_DB_PATH = Path("republisher.db")

# PRAGMA settings handed to every SqliteDatabase created in this module.
# Insertion order is kept as-is in case the order of application matters.
DATABASE_PRAGMAS = {
    "busy_timeout": 5000,
    "cache_size": 15625,
    "foreign_keys": 1,
    "journal_mode": "wal",
    "page_size": 4096,
    "synchronous": "normal",
    "temp_store": "memory",
}

# Glob describing the packaged schema files.
SCHEMA_GLOB = "*.sql"
|
|
|
|
_WRITE_SQL_PREFIX = re.compile(
|
|
r"^\s*(INSERT|UPDATE|DELETE|REPLACE|CREATE|DROP|ALTER)\b"
|
|
)
|
|
_current_database: ContextVar[ManagedSqliteDatabase | None] = ContextVar(
|
|
"repub_current_database",
|
|
default=None,
|
|
)
|
|
_current_scope: ContextVar[str | None] = ContextVar(
|
|
"repub_current_database_scope",
|
|
default=None,
|
|
)
|
|
_database_connection: DatabaseConnection | None = None
|
|
|
|
|
|
def resolve_database_path(db_path: str | Path | None = None) -> Path:
|
|
raw_value = (
|
|
os.environ.get("REPUBLISHER_DB_PATH", DEFAULT_DB_PATH)
|
|
if db_path is None
|
|
else db_path
|
|
)
|
|
return Path(raw_value).expanduser().resolve()
|
|
|
|
|
|
def schema_paths() -> tuple[Traversable, ...]:
    """Return the packaged ``.sql`` schema files, ordered by file name.

    The files are looked up in the ``sql`` directory of the ``repub`` package
    resources.
    """
    schema_dir = resources.files("repub").joinpath("sql")
    sql_files = [
        entry for entry in schema_dir.iterdir() if entry.name.endswith(".sql")
    ]
    # Sorting by name gives a deterministic application order.
    sql_files.sort(key=lambda entry: entry.name)
    return tuple(sql_files)
|
|
|
|
|
|
class ManagedSqliteDatabase(SqliteDatabase):
    """A concrete SQLite connection owned by ``DatabaseConnection``.

    Always constructed with ``autoconnect=False`` and ``thread_safe=False``;
    ``check_same_thread`` defaults to ``False`` unless the caller overrides it.
    """

    def __init__(self, database: str, **kwargs) -> None:
        """Create the database handle.

        Args:
            database: Path of the SQLite file.
            **kwargs: Extra ``SqliteDatabase`` options. ``pragmas`` defaults to
                ``DATABASE_PRAGMAS``; ``check_same_thread`` defaults to ``False``.
        """
        # Caller-supplied values win over the check_same_thread default.
        options = {"check_same_thread": False, **kwargs}
        pragmas = options.pop("pragmas", DATABASE_PRAGMAS)
        super().__init__(
            database,
            autoconnect=False,
            thread_safe=False,
            pragmas=pragmas,
            **options,
        )
|
|
|
|
|
|
class RoutedSqliteDatabase(SqliteDatabase):
    """Facade that forwards database work to the context-selected connection.

    Every operation is delegated to the ``ManagedSqliteDatabase`` stored in the
    ``_current_database`` context variable. Direct connection management
    (``connect``, ``close``, ``connection_context``) is rejected; callers are
    expected to use ``reader()`` / ``writer()`` and their ``*_conn`` variants.
    """

    def __init__(self) -> None:
        # connect() on this class always raises, so ":managed:" is never
        # opened through this object.
        super().__init__(":managed:", autoconnect=False, pragmas=DATABASE_PRAGMAS)

    def _require_active_database(self) -> ManagedSqliteDatabase:
        """Return the context's active database or raise ``RuntimeError``."""
        if (selected := _current_database.get()) is None:
            raise RuntimeError(
                "Database access requires a database.reader() or database.writer() context."
            )
        return selected

    def _validate_sql(self, sql: str) -> None:
        """Raise ``RuntimeError`` if a write statement runs in a reader scope."""
        if _current_scope.get() not in ("reader", "reader_conn"):
            return
        if _WRITE_SQL_PREFIX.match(sql) is not None:
            raise RuntimeError("Write query attempted inside database.reader() scope.")

    def connect(self, reuse_if_open: bool = False):  # type: ignore[override]
        """Always raises: connections are managed by the scope helpers."""
        raise RuntimeError(
            "Do not call database.connect() directly; use database.reader() or database.writer()."
        )

    def close(self):  # type: ignore[override]
        """Always raises: connections are managed by the scope helpers."""
        raise RuntimeError(
            "Do not call database.close() directly; use database.reader() or database.writer()."
        )

    def is_closed(self) -> bool:  # type: ignore[override]
        """Report closed when no scope is active, else ask the active database."""
        selected = _current_database.get()
        if selected is None:
            return True
        return selected.is_closed()

    def connection(self):  # type: ignore[override]
        """Return the active database's underlying connection."""
        target = self._require_active_database()
        return target.connection()

    def cursor(self, named_cursor=None):  # type: ignore[override]
        """Return a cursor from the active database."""
        target = self._require_active_database()
        return target.cursor(named_cursor=named_cursor)

    def execute_sql(self, sql, params=None):  # type: ignore[override]
        """Validate ``sql`` against the current scope, then run it."""
        # Validation happens before the active-database lookup.
        self._validate_sql(str(sql))
        target = self._require_active_database()
        return target.execute_sql(sql, params)

    def execute(self, query, **context_options):  # type: ignore[override]
        """Execute a query object on the active database."""
        # NOTE(review): this hands the query to the active database directly,
        # so _validate_sql() on this class is not invoked on this path.
        # Confirm whether ORM writes inside reader() are meant to be rejected.
        target = self._require_active_database()
        return target.execute(query, **context_options)

    def atomic(self, *args, **kwargs):  # type: ignore[override]
        """Return ``atomic()`` of the active database."""
        target = self._require_active_database()
        return target.atomic(*args, **kwargs)

    def transaction(self, *args, **kwargs):  # type: ignore[override]
        """Return ``transaction()`` of the active database."""
        target = self._require_active_database()
        return target.transaction(*args, **kwargs)

    def savepoint(self, *args, **kwargs):  # type: ignore[override]
        """Return ``savepoint()`` of the active database."""
        target = self._require_active_database()
        return target.savepoint(*args, **kwargs)

    def connection_context(self):  # type: ignore[override]
        """Always raises: connections are managed by the scope helpers."""
        raise RuntimeError(
            "Do not call database.connection_context() directly; use database.reader() or database.writer()."
        )

    def reader(self):
        """Enter a reader scope on the initialized ``DatabaseConnection``."""
        manager = _require_database_connection()
        return manager.reader()

    def writer(self):
        """Enter a writer scope on the initialized ``DatabaseConnection``."""
        manager = _require_database_connection()
        return manager.writer()

    def reader_conn(self):
        """Enter a reader_conn scope on the initialized ``DatabaseConnection``."""
        manager = _require_database_connection()
        return manager.reader_conn()

    def writer_conn(self):
        """Enter a writer_conn scope on the initialized ``DatabaseConnection``."""
        manager = _require_database_connection()
        return manager.writer_conn()
|
|
|
|
|
|
class DatabaseConnection:
    """Owns one writer connection and a fixed pool of reader connections.

    Scopes are entered through the ``reader`` / ``writer`` context managers
    (which wrap the body in ``atomic()``) or their ``*_conn`` variants (which
    do not). The active database and scope name are published through the
    ``_current_database`` and ``_current_scope`` context variables, so nested
    scopes reuse the connection chosen by the outermost one.
    """

    def __init__(
        self,
        db_path: str | Path,
        *,
        pool_size: int = 4,
        pragmas: dict[str, object] | None = None,
    ) -> None:
        """Create (but do not open) the writer and reader databases.

        Args:
            db_path: Location of the SQLite file; passed through
                ``resolve_database_path``.
            pool_size: Number of reader databases to create.
            pragmas: PRAGMA settings; a copy of ``DATABASE_PRAGMAS`` when
                ``None``.
        """
        self.db_path = resolve_database_path(db_path)
        self.pool_size = pool_size
        # Copy so later mutation by the caller cannot affect this instance.
        self.pragmas = dict(DATABASE_PRAGMAS if pragmas is None else pragmas)
        self.writer_db = ManagedSqliteDatabase(str(self.db_path), pragmas=self.pragmas)
        self.reader_dbs = tuple(
            ManagedSqliteDatabase(str(self.db_path), pragmas=self.pragmas)
            for _ in range(pool_size)
        )
        self._reader_pool: queue.Queue[ManagedSqliteDatabase] = queue.Queue()
        # Re-entrant so the owning thread can take the lock again.
        self._writer_lock = threading.RLock()
        for reader_db in self.reader_dbs:
            self._reader_pool.put(reader_db)

    def initialize(self) -> Path:
        """Open all connections and bring the schema up to date.

        Creates the parent directory, opens the writer, runs every packaged
        schema script and the legacy migrations, then opens each reader.

        Returns:
            The resolved database path.

        Raises:
            Exception: Whatever the failing step raised. Every connection that
                is open at that point is closed before re-raising.
        """
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        try:
            self.writer_db.connect(reuse_if_open=True)
            for path in schema_paths():
                self.writer_db.connection().executescript(
                    path.read_text(encoding="utf-8")
                )
            _run_legacy_migrations(self.writer_db)
            for reader_db in self.reader_dbs:
                reader_db.connect(reuse_if_open=True)
        except Exception:
            # A failure at any step (including a later reader connect) must
            # not leave the writer or earlier readers open.
            self.close()
            raise
        return self.db_path

    def close(self) -> None:
        """Close every reader and the writer that is currently open."""
        for reader_db in self.reader_dbs:
            if not reader_db.is_closed():
                reader_db.close()
        if not self.writer_db.is_closed():
            self.writer_db.close()

    @contextmanager
    def reader(self) -> Iterator[ManagedSqliteDatabase]:
        """Yield a database for reading, wrapped in ``atomic()``.

        Nested inside ``reader``/``writer`` the active database is reused
        as-is. Nested inside ``reader_conn``/``writer_conn`` the active
        database is reused inside a new ``atomic()`` block. Otherwise a reader
        is leased from the pool (blocking until one is free) and returned to
        it on exit.
        """
        scope = _current_scope.get()
        if scope in {"reader", "writer"}:
            # Already inside a transactional scope: join it.
            yield _require_active_database()
            return
        if scope in {"reader_conn", "writer_conn"}:
            # A bare connection scope is active: add the atomic() wrapper.
            active_database = _require_active_database()
            with active_database.atomic():
                yield active_database
            return

        leased_database = self._reader_pool.get()
        database_token = _current_database.set(leased_database)
        scope_token = _current_scope.set("reader")
        try:
            with leased_database.atomic():
                yield leased_database
        finally:
            _current_scope.reset(scope_token)
            _current_database.reset(database_token)
            self._reader_pool.put(leased_database)

    @contextmanager
    def writer(self) -> Iterator[ManagedSqliteDatabase]:
        """Yield the writer database, wrapped in ``atomic()``.

        Nested inside ``writer`` the active database is reused as-is; nested
        inside ``writer_conn`` it is reused inside a new ``atomic()`` block.
        Otherwise the writer lock is held for the duration of the body.

        Raises:
            RuntimeError: If entered from within a reader scope.
        """
        scope = _current_scope.get()
        if scope == "writer":
            yield _require_active_database()
            return
        if scope == "writer_conn":
            active_database = _require_active_database()
            with active_database.atomic():
                yield active_database
            return
        if scope in {"reader", "reader_conn"}:
            raise RuntimeError(
                "Cannot enter database.writer() inside database.reader()."
            )

        with self._writer_lock:
            database_token = _current_database.set(self.writer_db)
            scope_token = _current_scope.set("writer")
            try:
                with self.writer_db.atomic():
                    yield self.writer_db
            finally:
                _current_scope.reset(scope_token)
                _current_database.reset(database_token)

    @contextmanager
    def reader_conn(self) -> Iterator[ManagedSqliteDatabase]:
        """Yield a reader database without opening a transaction.

        If any scope is already active its database is reused; otherwise a
        reader is leased from the pool and returned to it on exit.
        """
        scope = _current_scope.get()
        if scope is not None:
            yield _require_active_database()
            return

        leased_database = self._reader_pool.get()
        database_token = _current_database.set(leased_database)
        scope_token = _current_scope.set("reader_conn")
        try:
            yield leased_database
        finally:
            _current_scope.reset(scope_token)
            _current_database.reset(database_token)
            self._reader_pool.put(leased_database)

    @contextmanager
    def writer_conn(self) -> Iterator[ManagedSqliteDatabase]:
        """Yield the writer database without opening a transaction.

        Nested inside ``writer``/``writer_conn`` the active database is
        reused. Otherwise the writer lock is held for the duration of the body.

        Raises:
            RuntimeError: If entered from within a reader scope.
        """
        scope = _current_scope.get()
        if scope in {"writer", "writer_conn"}:
            yield _require_active_database()
            return
        if scope in {"reader", "reader_conn"}:
            raise RuntimeError(
                "Cannot enter database.writer_conn() inside database.reader()."
            )

        with self._writer_lock:
            database_token = _current_database.set(self.writer_db)
            scope_token = _current_scope.set("writer_conn")
            try:
                yield self.writer_db
            finally:
                _current_scope.reset(scope_token)
                _current_database.reset(database_token)
|
|
|
|
|
|
def _run_legacy_migrations(database: ManagedSqliteDatabase) -> None:
    """Add the boolean ``convert_*`` columns to the ``job`` table if missing.

    For each absent column, an add-column operation (boolean, default true,
    with a 0/1 check constraint) and an add-column-default operation are
    queued. All queued operations run inside one ``atomic()`` block; nothing
    is executed when both columns already exist.
    """
    existing_columns = {column.name for column in database.get_columns("job")}
    pending = []
    migrator = SchemaMigrator.from_database(database)

    # (column name, CHECK constraint) in the order they must be applied.
    boolean_columns = (
        ("convert_images", "convert_images IN (0, 1)"),
        ("convert_video", "convert_video IN (0, 1)"),
    )
    for column_name, check_sql in boolean_columns:
        if column_name in existing_columns:
            continue
        field = BooleanField(default=True, constraints=[Check(check_sql)])
        pending.append(migrator.add_column("job", column_name, field))
        pending.append(migrator.add_column_default("job", column_name, 1))

    if not pending:
        return
    with database.atomic():
        migrate(*pending)
|
|
|
|
|
|
def initialize_database(db_path: str | Path | None = None) -> Path:
    """Create and initialize the process-wide ``DatabaseConnection``.

    Any previously installed connection is closed first. The new connection
    is only published after its ``initialize()`` succeeds.

    Args:
        db_path: Database location; resolved with ``resolve_database_path``.

    Returns:
        The resolved path of the initialized database.

    Raises:
        Exception: Whatever ``DatabaseConnection.initialize()`` raises. In
            that case no connection is left installed.
    """
    global _database_connection

    previous = _database_connection
    if previous is not None:
        # Unpublish before closing so a failure below never leaves the module
        # pointing at a connection whose databases are already closed.
        _database_connection = None
        previous.close()

    connection = DatabaseConnection(resolve_database_path(db_path))
    resolved_path = connection.initialize()
    _database_connection = connection
    return resolved_path
|
|
|
|
|
|
def get_database_connection() -> DatabaseConnection | None:
    """Return the installed ``DatabaseConnection``, or ``None`` if there is none."""
    return _database_connection
|
|
|
|
|
|
def _require_database_connection() -> DatabaseConnection:
    """Return the installed ``DatabaseConnection``.

    Raises:
        RuntimeError: If no connection has been installed yet.
    """
    if (installed := get_database_connection()) is None:
        raise RuntimeError("Database has not been initialized.")
    return installed
|
|
|
|
|
|
def _require_active_database() -> ManagedSqliteDatabase:
    """Return the database selected by the current scope.

    Raises:
        RuntimeError: If no reader/writer scope is active in this context.
    """
    if (selected := _current_database.get()) is None:
        raise RuntimeError(
            "Database access requires a database.reader() or database.writer() context."
        )
    return selected
|
|
|
|
|
|
# Module-level routed database; operations on it are forwarded to whichever
# connection the active reader()/writer() scope selected.
database = RoutedSqliteDatabase()
|