Refactor database access through managed connections
This commit is contained in:
parent
f19bab6fa2
commit
3f28e46ff6
10 changed files with 1327 additions and 716 deletions
359
repub/db.py
Normal file
359
repub/db.py
Normal file
|
|
@ -0,0 +1,359 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import queue
|
||||
import re
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from importlib import resources
|
||||
from importlib.resources.abc import Traversable
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
from peewee import BooleanField, Check, SqliteDatabase
|
||||
from playhouse.migrate import SchemaMigrator, migrate
|
||||
|
||||
# Fallback database file name, relative to the current working directory.
DEFAULT_DB_PATH = Path("republisher.db")
# SQLite pragmas applied to every connection (writer and readers alike).
DATABASE_PRAGMAS = {
    "busy_timeout": 5000,  # ms to wait on a locked database before erroring
    "cache_size": 15625,  # pages; ~61 MiB at the 4096-byte page_size below
    "foreign_keys": 1,
    "journal_mode": "wal",  # WAL permits concurrent readers with one writer
    "page_size": 4096,
    "synchronous": "normal",
    "temp_store": "memory",
}
# Glob pattern identifying the schema files shipped with the package.
SCHEMA_GLOB = "*.sql"

# Statements matching this prefix are rejected inside read-only scopes.
_WRITE_SQL_PREFIX = re.compile(
    r"^\s*(INSERT|UPDATE|DELETE|REPLACE|CREATE|DROP|ALTER)\b"
)
# Database instance leased to the current execution context, if any.
_current_database: ContextVar[ManagedSqliteDatabase | None] = ContextVar(
    "repub_current_database",
    default=None,
)
# Active scope name ("reader", "writer", "reader_conn", "writer_conn"),
# or None when no reader()/writer() context is entered.
_current_scope: ContextVar[str | None] = ContextVar(
    "repub_current_database_scope",
    default=None,
)
# Process-wide connection manager; set by initialize_database().
_database_connection: DatabaseConnection | None = None
|
||||
|
||||
|
||||
def resolve_database_path(db_path: str | Path | None = None) -> Path:
|
||||
raw_value = (
|
||||
os.environ.get("REPUBLISHER_DB_PATH", DEFAULT_DB_PATH)
|
||||
if db_path is None
|
||||
else db_path
|
||||
)
|
||||
return Path(raw_value).expanduser().resolve()
|
||||
|
||||
|
||||
def schema_paths() -> tuple[Traversable, ...]:
    """Return the packaged schema files from repub/sql, sorted by name."""
    sql_dir = resources.files("repub").joinpath("sql")
    sql_files = [entry for entry in sql_dir.iterdir() if entry.name.endswith(".sql")]
    sql_files.sort(key=lambda entry: entry.name)
    return tuple(sql_files)
|
||||
|
||||
|
||||
class ManagedSqliteDatabase(SqliteDatabase):
    """SqliteDatabase configured for explicit, externally managed connections."""

    def __init__(self, database: str, **kwargs) -> None:
        """Configure *database* with manual-connect semantics and shared pragmas."""
        # Connections are handed across threads by DatabaseConnection, so the
        # sqlite3 same-thread check must be off unless the caller overrides it.
        if "check_same_thread" not in kwargs:
            kwargs["check_same_thread"] = False
        super().__init__(
            database,
            autoconnect=False,
            thread_safe=False,
            # pop() runs before **kwargs is unpacked, so "pragmas" is never
            # passed twice.
            pragmas=kwargs.pop("pragmas", DATABASE_PRAGMAS),
            **kwargs,
        )
|
||||
|
||||
|
||||
class RoutedSqliteDatabase(SqliteDatabase):
    """Facade database that models bind to.

    Every operation is routed to the ManagedSqliteDatabase leased by the
    active database.reader()/database.writer() scope; direct connection
    management on this object is forbidden.
    """

    def __init__(self) -> None:
        # ":managed:" is a placeholder name -- this facade never opens a
        # connection of its own (autoconnect=False, connect() raises).
        super().__init__(
            ":managed:",
            autoconnect=False,
            pragmas=DATABASE_PRAGMAS,
        )

    def _require_active_database(self) -> ManagedSqliteDatabase:
        # Delegate to the module-level helper so the lookup logic and the
        # error message live in exactly one place.
        return _require_active_database()

    def _validate_sql(self, sql: str) -> None:
        # Reject write statements while a read-only scope is active.
        scope = _current_scope.get()
        if scope in {"reader", "reader_conn"} and _WRITE_SQL_PREFIX.match(sql):
            raise RuntimeError("Write query attempted inside database.reader() scope.")

    def connect(self, reuse_if_open: bool = False):  # type: ignore[override]
        raise RuntimeError(
            "Do not call database.connect() directly; use database.reader() or database.writer()."
        )

    def close(self):  # type: ignore[override]
        raise RuntimeError(
            "Do not call database.close() directly; use database.reader() or database.writer()."
        )

    def is_closed(self) -> bool:  # type: ignore[override]
        # Outside any scope the facade reports "closed" rather than raising,
        # so callers can probe state safely.
        active_database = _current_database.get()
        return True if active_database is None else active_database.is_closed()

    def connection(self):  # type: ignore[override]
        return self._require_active_database().connection()

    def cursor(self, named_cursor=None):  # type: ignore[override]
        return self._require_active_database().cursor(named_cursor=named_cursor)

    def execute_sql(self, sql, params=None):  # type: ignore[override]
        self._validate_sql(str(sql))
        return self._require_active_database().execute_sql(sql, params)

    def execute(self, query, **context_options):  # type: ignore[override]
        return self._require_active_database().execute(query, **context_options)

    def atomic(self, *args, **kwargs):  # type: ignore[override]
        return self._require_active_database().atomic(*args, **kwargs)

    def transaction(self, *args, **kwargs):  # type: ignore[override]
        return self._require_active_database().transaction(*args, **kwargs)

    def savepoint(self, *args, **kwargs):  # type: ignore[override]
        return self._require_active_database().savepoint(*args, **kwargs)

    def connection_context(self):  # type: ignore[override]
        raise RuntimeError(
            "Do not call database.connection_context() directly; use database.reader() or database.writer()."
        )

    def reader(self):
        """Lease a pooled reader inside a read-only transactional scope."""
        return _require_database_connection().reader()

    def writer(self):
        """Acquire the single writer inside a transactional scope."""
        return _require_database_connection().writer()

    def reader_conn(self):
        """Lease a pooled reader connection without an implicit transaction."""
        return _require_database_connection().reader_conn()

    def writer_conn(self):
        """Acquire the writer connection without an implicit transaction."""
        return _require_database_connection().writer_conn()
|
||||
|
||||
|
||||
class DatabaseConnection:
    """Owns the single writer connection and a pool of reader connections
    for one SQLite file, handing them out through scoped context managers.

    reader()/writer() wrap the leased connection in a transaction;
    reader_conn()/writer_conn() lease a bare connection. Scopes nest:
    re-entering a compatible scope reuses the connection already recorded in
    the context variables, and writer scopes are serialized by a re-entrant
    lock.
    """

    def __init__(
        self,
        db_path: str | Path,
        *,
        pool_size: int = 4,
        pragmas: dict[str, object] | None = None,
    ) -> None:
        """Create (but do not open) one writer and *pool_size* readers."""
        self.db_path = resolve_database_path(db_path)
        self.pool_size = pool_size
        # Copy so callers cannot mutate our pragma set after construction.
        self.pragmas = dict(DATABASE_PRAGMAS if pragmas is None else pragmas)
        self.writer_db = ManagedSqliteDatabase(str(self.db_path), pragmas=self.pragmas)
        self.reader_dbs = tuple(
            ManagedSqliteDatabase(str(self.db_path), pragmas=self.pragmas)
            for _ in range(pool_size)
        )
        self._reader_pool: queue.Queue[ManagedSqliteDatabase] = queue.Queue()
        # RLock: nested writer scopes in one thread must not deadlock.
        self._writer_lock = threading.RLock()
        for reader_db in self.reader_dbs:
            self._reader_pool.put(reader_db)

    def initialize(self) -> Path:
        """Open all connections and apply schema scripts plus legacy migrations.

        Returns the resolved database path. On failure the writer is closed
        before the exception propagates (readers were not yet opened).
        """
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.writer_db.connect(reuse_if_open=True)
        try:
            for path in schema_paths():
                # executescript lets each .sql file hold many statements.
                self.writer_db.connection().executescript(
                    path.read_text(encoding="utf-8")
                )
            _run_legacy_migrations(self.writer_db)
        except Exception:
            self.writer_db.close()
            raise

        for reader_db in self.reader_dbs:
            reader_db.connect(reuse_if_open=True)
        return self.db_path

    def close(self) -> None:
        """Close every open reader connection and the writer."""
        for reader_db in self.reader_dbs:
            if not reader_db.is_closed():
                reader_db.close()
        if not self.writer_db.is_closed():
            self.writer_db.close()

    @contextmanager
    def reader(self) -> Iterator[ManagedSqliteDatabase]:
        """Yield a reader database wrapped in a transaction.

        Inside an existing transactional scope the active database is reused
        as-is; inside a bare *_conn scope a transaction is opened on it.
        """
        scope = _current_scope.get()
        if scope in {"reader", "writer"}:
            # Already transactional: reuse the active database directly.
            yield _require_active_database()
            return
        if scope in {"reader_conn", "writer_conn"}:
            # Bare connection scope: add the transaction this scope promises.
            # (The original code had two identical branches here plus an
            # unreachable `scope == "writer"` check; consolidated.)
            active_database = _require_active_database()
            with active_database.atomic():
                yield active_database
            return

        leased_database = self._reader_pool.get()
        database_token = _current_database.set(leased_database)
        scope_token = _current_scope.set("reader")
        try:
            with leased_database.atomic():
                yield leased_database
        finally:
            _current_scope.reset(scope_token)
            _current_database.reset(database_token)
            # Return the connection to the pool only after the context
            # variables no longer reference it.
            self._reader_pool.put(leased_database)

    @contextmanager
    def writer(self) -> Iterator[ManagedSqliteDatabase]:
        """Yield the writer database wrapped in a transaction.

        Raises RuntimeError when entered from a read-only scope.
        """
        scope = _current_scope.get()
        if scope == "writer":
            yield _require_active_database()
            return
        if scope == "writer_conn":
            # Bare writer connection: add the transaction this scope promises.
            active_database = _require_active_database()
            with active_database.atomic():
                yield active_database
            return
        if scope in {"reader", "reader_conn"}:
            raise RuntimeError(
                "Cannot enter database.writer() inside database.reader()."
            )

        with self._writer_lock:
            database_token = _current_database.set(self.writer_db)
            scope_token = _current_scope.set("writer")
            try:
                with self.writer_db.atomic():
                    yield self.writer_db
            finally:
                _current_scope.reset(scope_token)
                _current_database.reset(database_token)

    @contextmanager
    def reader_conn(self) -> Iterator[ManagedSqliteDatabase]:
        """Yield a reader database without opening a transaction."""
        scope = _current_scope.get()
        if scope is not None:
            # Any active scope already holds a connection; reuse it.
            yield _require_active_database()
            return

        leased_database = self._reader_pool.get()
        database_token = _current_database.set(leased_database)
        scope_token = _current_scope.set("reader_conn")
        try:
            yield leased_database
        finally:
            _current_scope.reset(scope_token)
            _current_database.reset(database_token)
            self._reader_pool.put(leased_database)

    @contextmanager
    def writer_conn(self) -> Iterator[ManagedSqliteDatabase]:
        """Yield the writer database without opening a transaction.

        Raises RuntimeError when entered from a read-only scope.
        """
        scope = _current_scope.get()
        if scope in {"writer", "writer_conn"}:
            yield _require_active_database()
            return
        if scope in {"reader", "reader_conn"}:
            raise RuntimeError(
                "Cannot enter database.writer_conn() inside database.reader()."
            )

        with self._writer_lock:
            database_token = _current_database.set(self.writer_db)
            scope_token = _current_scope.set("writer_conn")
            try:
                yield self.writer_db
            finally:
                _current_scope.reset(scope_token)
                _current_database.reset(database_token)
||||
|
||||
|
||||
def _run_legacy_migrations(database: ManagedSqliteDatabase) -> None:
    """Add boolean `job` columns that predate the packaged schema files.

    Idempotent: each column is added only when it is missing, so running
    this on an already-migrated database is a no-op.
    """
    job_columns = {column.name for column in database.get_columns("job")}
    migrator = SchemaMigrator.from_database(database)
    operations = []
    # The two legacy columns are identical except for their name, so build
    # their migration operations in one loop instead of duplicated stanzas.
    for column_name in ("convert_images", "convert_video"):
        if column_name in job_columns:
            continue
        operations.extend(
            (
                migrator.add_column(
                    "job",
                    column_name,
                    BooleanField(
                        default=True,
                        constraints=[Check(f"{column_name} IN (0, 1)")],
                    ),
                ),
                # Backfill existing rows with the default of "enabled".
                migrator.add_column_default("job", column_name, 1),
            )
        )
    if operations:
        # Apply all column additions in one transaction.
        with database.atomic():
            migrate(*operations)
|
||||
|
||||
|
||||
def initialize_database(db_path: str | Path | None = None) -> Path:
    """(Re)create the process-wide DatabaseConnection and apply the schema.

    Any previously registered connection is closed first. Returns the
    resolved database path.
    """
    global _database_connection

    if _database_connection is not None:
        _database_connection.close()
        # Drop the stale handle immediately so a failure during the new
        # initialization below does not leave a closed connection registered
        # as the active one.
        _database_connection = None

    connection = DatabaseConnection(resolve_database_path(db_path))
    resolved_path = connection.initialize()
    # Publish the new connection only after initialization succeeded.
    _database_connection = connection
    return resolved_path
|
||||
|
||||
|
||||
def get_database_connection() -> DatabaseConnection | None:
    """Return the active DatabaseConnection, or None before initialization."""
    return _database_connection
|
||||
|
||||
|
||||
def _require_database_connection() -> DatabaseConnection:
    """Return the active DatabaseConnection, raising if none is registered."""
    if (connection := get_database_connection()) is None:
        raise RuntimeError("Database has not been initialized.")
    return connection
|
||||
|
||||
|
||||
def _require_active_database() -> ManagedSqliteDatabase:
    """Return the database leased to the current context, raising outside any scope."""
    if (active := _current_database.get()) is None:
        raise RuntimeError(
            "Database access requires a database.reader() or database.writer() context."
        )
    return active
|
||||
|
||||
|
||||
# Module-level facade that models bind to; actual connections are leased
# per-scope from the DatabaseConnection created by initialize_database().
database = RoutedSqliteDatabase()
|
||||
Loading…
Add table
Add a link
Reference in a new issue