feat: draw the rest of the owl

This commit is contained in:
Iain Learmonth 2026-03-26 10:58:03 +00:00
parent e21b725192
commit 2ba848467f
28 changed files with 1538 additions and 448 deletions

104
alembic.ini Normal file
View file

@ -0,0 +1,104 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
file_template = %%(year)d-%%(month).2d-%%(day).2d_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python-dateutil library that can be
# installed by adding `alembic[tz]` to the pip requirements
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
hooks = ruff_format,ruff
ruff_format.type = exec
ruff_format.executable = ruff
ruff_format.options = format REVISION_SCRIPT_FILENAME
ruff.type = exec
ruff.executable = ruff
ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

93
alembic/env.py Normal file
View file

@ -0,0 +1,93 @@
from logging.config import fileConfig

from sqlalchemy import engine_from_config, pool

from alembic import context

# Imported for its side effect only: loading the app pulls in every model
# module so all tables are registered on `metadata` before autogenerate runs.
import src.main as _
from src.config import settings
from src.database import metadata

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = metadata

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.

# Alembic runs migrations with a synchronous engine: if the application is
# configured with an async driver (e.g. "postgresql+asyncpg"), strip the
# "+driver" suffix so the default sync dialect is used instead.
DATABASE_URL = str(settings.DATABASE_URL)
db_driver = settings.DATABASE_URL.scheme
db_driver_parts = db_driver.split("+")
if len(db_driver_parts) > 1:  # e.g. postgresql+asyncpg
    sync_scheme = db_driver_parts[0].strip()
    DATABASE_URL = DATABASE_URL.replace(  # replace with sync driver
        db_driver, sync_scheme
    )
config.set_main_option("sqlalchemy.url", DATABASE_URL)

# NOTE(review): setting arbitrary attributes on the alembic Config object is
# probably a no-op for autogenerate — compare_type / compare_server_default
# are keyword arguments to context.configure(). Confirm these are honoured.
config.compare_type = True
config.compare_server_default = True
def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.
    """
    # Use the (possibly async-to-sync rewritten) URL set at module import.
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,  # render bound parameters inline in the emitted SQL
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()
def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.
    """
    # NullPool: migrations are one-shot, no reason to keep connections pooled.
    connectable = engine_from_config(
        config.get_section(config.config_ini_section),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(connection=connection, target_metadata=target_metadata)

        with context.begin_transaction():
            context.run_migrations()
# Entry point: Alembic selects offline mode for SQL-script generation
# (e.g. `alembic upgrade --sql`), online mode for direct DB execution.
if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()

25
alembic/script.py.mako Normal file
View file

@ -0,0 +1,25 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View file

@ -0,0 +1,69 @@
"""initial schema
Revision ID: e723dddd82db
Revises:
Create Date: 2026-03-26 10:37:45.864627
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "e723dddd82db"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    # Initial schema: link (short links), mirror (known mirror hosts) and
    # snapshot (archived pages). All timestamps are timezone-aware; primary
    # key names go through op.f() so they follow the project naming
    # convention (e.g. "link_pkey").
    op.create_table(
        "link",
        sa.Column("url", sa.String(), nullable=False),
        sa.Column("link_domain", sa.String(), nullable=False),
        sa.Column("pool", sa.Integer(), nullable=False),
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
        sa.PrimaryKeyConstraint("id", name=op.f("link_pkey")),
    )
    op.create_table(
        "mirror",
        sa.Column("origin", sa.String(), nullable=False),
        sa.Column("pool", sa.Integer(), nullable=False),
        sa.Column("mirror", sa.String(), nullable=False),
        sa.Column("first_seen", sa.DateTime(timezone=True), nullable=False),
        sa.Column("last_seen", sa.DateTime(timezone=True), nullable=False),
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.PrimaryKeyConstraint("id", name=op.f("mirror_pkey")),
    )
    op.create_table(
        "snapshot",
        sa.Column("url", sa.String(), nullable=False),
        sa.Column("pool", sa.Integer(), nullable=False),
        sa.Column(
            "snapshot_state",
            # Creates the "snapshotstate" enum type on PostgreSQL.
            sa.Enum(
                "PENDING", "FAILED", "UPDATING", "FROZEN", "EXPIRED", name="snapshotstate"
            ),
            nullable=False,
        ),
        sa.Column("provider", sa.Enum("GOOGLE", name="snapshotprovider"), nullable=False),
        sa.Column("snapshot_published_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
        # Soft-delete marker; NULL means the snapshot is live.
        sa.Column("deleted_at", sa.DateTime(timezone=True), nullable=True),
        sa.PrimaryKeyConstraint("id", name=op.f("snapshot_pkey")),
    )
    # ### end Alembic commands ###
def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    # Drop order is the reverse of creation. NOTE(review): the
    # "snapshotstate" and "snapshotprovider" enum TYPEs created by upgrade()
    # are not dropped here — confirm whether explicit enum drops are needed
    # for a clean revert on PostgreSQL.
    op.drop_table("snapshot")
    op.drop_table("mirror")
    op.drop_table("link")
    # ### end Alembic commands ###

980
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -8,17 +8,22 @@ license = "BSD-2"
package-mode = false package-mode = false
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.12" python = "^3.14"
alembic = "^1.18.4"
babel = "^2.17" babel = "^2.17"
beautifulsoup4 = "^4.13" beautifulsoup4 = "^4.13"
fastapi = "^0.115.12" fastapi = "^0.135"
google-cloud-storage = "^3.9.0" google-cloud-storage = "^3.9.0"
hashids = "^1.3"
jinja2 = "^3.1" jinja2 = "^3.1"
lxml = "^6.0" lxml = "^6.0"
minify-html = "^0.18"
requests = "^2.32" requests = "^2.32"
psycopg2-binary = "^2.9"
pydantic = "^2.11" pydantic = "^2.11"
pydantic-settings = "^2.10" pydantic-settings = "^2.10"
pyyaml = "^6.0" pyyaml = "^6.0"
sqlalchemy = "^2.0.48"
tldextract = "^5" tldextract = "^5"
uvicorn = {extras = ["standard"], version = "^0.30.6"} uvicorn = {extras = ["standard"], version = "^0.30.6"}
@ -43,5 +48,5 @@ line-length = 92
asyncio_default_fixture_loop_scope = "module" asyncio_default_fixture_loop_scope = "module"
[tool.ruff] [tool.ruff]
target-version = "py312" target-version = "py314"
line-length = 92 line-length = 92

View file

@ -1,6 +1,7 @@
from os.path import abspath, dirname, join from os.path import abspath, dirname, join
from typing import Any from typing import Any
from pydantic import PostgresDsn
from pydantic_settings import ( from pydantic_settings import (
BaseSettings, BaseSettings,
SettingsConfigDict, SettingsConfigDict,
@ -34,11 +35,10 @@ class CustomBaseSettings(BaseSettings):
class Config(CustomBaseSettings): class Config(CustomBaseSettings):
# DATABASE_URL: PostgresDsn DATABASE_URL: PostgresDsn
# DATABASE_ASYNC_URL: PostgresDsn DATABASE_POOL_SIZE: int = 16
# DATABASE_POOL_SIZE: int = 16 DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes
# DATABASE_POOL_TTL: int = 60 * 20 # 20 minutes DATABASE_POOL_PRE_PING: bool = True
# DATABASE_POOL_PRE_PING: bool = True
ENVIRONMENT: Environment = Environment.PRODUCTION ENVIRONMENT: Environment = Environment.PRODUCTION
@ -46,6 +46,15 @@ class Config(CustomBaseSettings):
CORS_ORIGINS_REGEX: str | None = None CORS_ORIGINS_REGEX: str | None = None
CORS_HEADERS: list[str] = ["*"] CORS_HEADERS: list[str] = ["*"]
API_DOMAIN: str
LINK_DOMAIN: str
INVALID_URL: str
HASH_SECRET_KEY: str
MATOMO_HOST: str
MATOMO_SITE_ID: int
APP_VERSION: str = "0.0.0" APP_VERSION: str = "0.0.0"

View file

@ -1,5 +1,13 @@
from enum import Enum from enum import Enum
DB_NAMING_CONVENTION = {
"ix": "%(column_0_label)s_idx",
"uq": "%(table_name)s_%(column_0_name)s_key",
"ck": "%(table_name)s_%(constraint_name)s_check",
"fk": "%(table_name)s_%(column_0_name)s_fkey",
"pk": "%(table_name)s_pkey",
}
class Environment(str, Enum): class Environment(str, Enum):
LOCAL = "LOCAL" LOCAL = "LOCAL"

View file

@ -1,9 +1,11 @@
import contextlib from contextlib import contextmanager
from typing import Annotated, Iterator, Generator from typing import Annotated, Generator
from fastapi import Depends from fastapi import Depends
from sqlalchemy import ( from sqlalchemy import (
MetaData, create_engine, Connection, MetaData,
create_engine,
Connection,
) )
from sqlalchemy.orm import sessionmaker, Session from sqlalchemy.orm import sessionmaker, Session
@ -20,8 +22,8 @@ metadata = MetaData(naming_convention=DB_NAMING_CONVENTION)
sm = sessionmaker(autocommit=False, expire_on_commit=False, bind=engine) sm = sessionmaker(autocommit=False, expire_on_commit=False, bind=engine)
@contextlib.contextmanager @contextmanager
def get_db_connection() -> Iterator[Connection]: def get_db_connection() -> Generator[Connection, None, None]:
with engine.connect() as connection: with engine.connect() as connection:
try: try:
yield connection yield connection
@ -30,8 +32,8 @@ def get_db_connection() -> Iterator[Connection]:
raise raise
@contextlib.contextmanager @contextmanager
def get_db_session() -> Iterator[Session]: def get_db_session() -> Generator[Session, None, None]:
session = sm() session = sm()
try: try:
yield session yield session

11
src/link/models.py Normal file
View file

@ -0,0 +1,11 @@
from sqlalchemy.orm import Mapped

from src.models import TimestampMixin, IdMixin, CustomBase


class Link(CustomBase, IdMixin, TimestampMixin):
    """A short link: a hashid of ``id``, served on ``link_domain``, resolving to ``url``."""

    __tablename__ = "link"

    # Original (long) URL the short link resolves to.
    url: Mapped[str]
    # Domain the short link is served on (set from settings.LINK_DOMAIN at creation).
    link_domain: Mapped[str]
    # Mirror pool the link belongs to; the router currently only uses pool 0.
    pool: Mapped[int]

58
src/link/router.py Normal file
View file

@ -0,0 +1,58 @@
from fastapi import APIRouter, HTTPException, Header, Query, BackgroundTasks
from starlette import status
from starlette.responses import RedirectResponse
from src.config import settings
from src.database import DbSession
from src.link.models import Link
from src.mirrors.service import resolve_mirror
from src.security import ApiKey
from src.snapshots.router import snap
from src.utils import hashids
router = APIRouter()
@router.get("/api/v1/link")
def get_link(background_tasks: BackgroundTasks, db: DbSession, auth: ApiKey, url: str, type_: str = Query(default="auto", alias="type")):
    """Return the best available link for ``url``.

    Depending on ``type`` ("auto", "live", "live-short", "live-direct",
    "snapshot"), tries in order: an authenticated short link, an archived
    snapshot, then a direct mirror URL. Raises 403 when nothing applies.
    """
    # Short links are only minted for authenticated callers, and only when a
    # mirror currently exists for the URL.
    if auth and type_ in ["auto", "live", "live-short"]:
        s = db.query(Link).filter(Link.url == url, Link.pool == 0).first()
        if not s and resolve_mirror(db, url):
            s = Link(url=url, pool=0, link_domain=settings.LINK_DOMAIN)
            db.add(s)
            db.commit()
        if s:
            return {"url": f"https://{s.link_domain}/{hashids.encode(s.id)}"}
    # Snapshot fallback: snap() returns a dict payload on success.
    if type_ in ["auto", "snapshot"]:
        if isinstance(s := snap(background_tasks, db, auth, url), dict):
            return s
    # Last resort: hand back a mirror URL directly.
    if type_ in ["auto", "live", "live-direct"]:
        if mirror := resolve_mirror(db, url):
            return {"url": mirror}
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)
@router.get("/{hash_}")
def resolve_hash(db: DbSession, hash_: str, host: str = Header(settings.LINK_DOMAIN)):
    """Resolve a short-link hash to a 302 redirect onto a mirror.

    Decodes ``hash_`` into a link id and looks the link up for the serving
    ``host`` domain. Unknown hashes/domains redirect to the configured
    invalid-link page; the API domain itself never serves short links (404).
    """
    try:
        id_ = hashids.decode(hash_)[0]
    except IndexError:
        # Not a valid hashid for our secret/alphabet.
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
    link = (
        db.query(Link)
        .filter(Link.id == id_, Link.link_domain == host.lower().strip())
        .first()
    )
    if not link:
        return RedirectResponse(
            settings.INVALID_URL,
            status_code=status.HTTP_302_FOUND,
            headers={"Referrer-Policy": "no-referrer"},
        )
    if host.lower().strip() != settings.API_DOMAIN:
        # resolve_mirror() returns None when no mirror is currently known;
        # previously that produced a redirect with a null Location header.
        # Treat a mirrorless link like an invalid one.
        mirror = resolve_mirror(db, link.url)
        return RedirectResponse(
            mirror or settings.INVALID_URL,
            status_code=status.HTTP_302_FOUND,
            headers={"Referrer-Policy": "no-referrer"},
        )
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)

View file

@ -1,14 +1,20 @@
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import AsyncGenerator from typing import AsyncGenerator
from fastapi import FastAPI from fastapi import FastAPI, Header, HTTPException
from starlette import status
from starlette.responses import RedirectResponse
from src.config import app_configs from src.config import app_configs, settings
from src.link.router import router as link_router
from src.mirrors.tasks import update_rsf_mirrors
from src.mirrors.router import router as mirrors_router
from src.snapshots.router import router as snapshots_router from src.snapshots.router import router as snapshots_router
@asynccontextmanager @asynccontextmanager
async def lifespan(_application: FastAPI) -> AsyncGenerator: async def lifespan(_application: FastAPI) -> AsyncGenerator:
await update_rsf_mirrors()
# Startup # Startup
yield yield
# Shutdown # Shutdown
@ -16,9 +22,22 @@ async def lifespan(_application: FastAPI) -> AsyncGenerator:
app = FastAPI(**app_configs, lifespan=lifespan) app = FastAPI(**app_configs, lifespan=lifespan)
app.include_router(link_router)
app.include_router(snapshots_router) app.include_router(snapshots_router)
app.include_router(mirrors_router)
@app.get("/healthcheck", include_in_schema=False) @app.get("/")
def home(host: str = Header(settings.LINK_DOMAIN)):
if host.lower().strip() != settings.API_DOMAIN:
return RedirectResponse(
settings.INVALID_URL,
status_code=status.HTTP_302_FOUND,
headers={"Referrer-Policy": "no-referrer"},
)
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
@app.get("/api/v1/healthcheck", include_in_schema=False)
async def healthcheck() -> dict[str, str]: async def healthcheck() -> dict[str, str]:
return {"status": "ok"} return {"status": "ok"}

17
src/mirrors/models.py Normal file
View file

@ -0,0 +1,17 @@
from datetime import datetime

from sqlalchemy.orm import Mapped

from src.models import CustomBase, IdMixin


class Mirror(CustomBase, IdMixin):
    """A known mirror hostname for an origin site, grouped by pool."""

    __tablename__ = "mirror"

    # Canonical hostname being mirrored ("https://" prefix stripped on ingest).
    origin: Mapped[str]
    pool: Mapped[int]
    # Hostname of the mirror itself.
    mirror: Mapped[str]
    first_seen: Mapped[datetime]
    last_seen: Mapped[datetime]

    # TODO: Record hits when a redirect goes to the mirror
    # NOTE(review): the original `hits: Mapped[int] = 0` is not a valid
    # SQLAlchemy 2.0 declarative mapping (the right-hand side must be
    # mapped_column(), not a plain int) and the initial migration creates no
    # "hits" column. Disabled until both the mapping
    # (`mapped_column(default=0)`) and a migration are added.
    # hits: Mapped[int] = 0

27
src/mirrors/router.py Normal file
View file

@ -0,0 +1,27 @@
from urllib.parse import urlparse
from fastapi import APIRouter
from src.database import DbSession
from src.mirrors.schemas import MirrorLinks, RedirectorData
from src.mirrors.service import refresh_mirrors
from src.security import ApiKey
router = APIRouter()
@router.post("/api/v1/mirrors")
def update_mirrors(db: DbSession, auth: ApiKey, data: RedirectorData):
    """Ingest a redirector publication: refresh the mirror list of each pool.

    Pools are numbered by their position in the payload. The original loop
    rebound the ``data`` parameter as its loop variable, shadowing the
    payload it was iterating over.
    """
    for pool_number, pool_data in enumerate(data.pools):
        refresh_mirrors(db, pool_number, pool_data.origins)
    db.commit()
@router.get("/api/v1/resolve", response_model=MirrorLinks)
def resolve_mirror(db: DbSession, auth: ApiKey, url: str):
    """Return every known mirror URL for ``url``.

    Fixes two defects in the original: it recursively called itself (the
    service-layer function of the same name was never imported, so the
    two-argument call raised TypeError, which the ``except ValueError`` did
    not catch), and on "success" it returned ``{"url": ...}``, which does
    not match the ``MirrorLinks`` response model.
    """
    # Local import avoids the name collision between this route handler and
    # the service function that shares its name.
    from src.mirrors.service import get_mirrors

    parsed = urlparse(url)
    return {
        "mirrors": [
            parsed._replace(netloc=mirror).geturl()
            for mirror in get_mirrors(db, parsed.netloc)
        ]
    }

18
src/mirrors/schemas.py Normal file
View file

@ -0,0 +1,18 @@
from typing import Literal

from pydantic import BaseModel, ConfigDict


class RedirectorDataPool(BaseModel):
    """One pool entry in a redirector publication."""

    # Tolerate unknown keys in the upstream feed rather than rejecting them.
    model_config = ConfigDict(extra="ignore")

    # Mapping of origin URL/hostname -> mirror(s) for that origin.
    origins: dict[str, str]


class RedirectorData(BaseModel):
    """Top-level redirector publication payload."""

    # Only schema version "1.0" is accepted.
    version: Literal["1.0"]
    pools: list[RedirectorDataPool]


class MirrorLinks(BaseModel):
    """Response schema: mirror URLs known for a requested URL."""

    mirrors: list[str]

75
src/mirrors/service.py Normal file
View file

@ -0,0 +1,75 @@
import random
from datetime import datetime, timedelta
from urllib.parse import urlparse, urlunparse
from sqlalchemy import func
from sqlalchemy.orm import Session
from src.kldscp.client import KLDSCP_SUPPORTED_ORIGINS, get_kaleidoscope_mirror
from src.mirrors.models import Mirror
def _refresh_mirror(db: Session, mirror: str, origin: str, pool: int):
    """Record a sighting of ``mirror`` for ``origin`` within ``pool``.

    Strips a leading "https://" so only the hostname is stored, then either
    bumps ``last_seen`` on the existing row or inserts a fresh row with both
    timestamps set to the database clock.
    """
    hostname = mirror.removeprefix("https://")
    row = (
        db.query(Mirror)
        .filter_by(origin=origin, mirror=hostname, pool=pool)
        .first()
    )
    if row is None:
        db.add(
            Mirror(
                origin=origin,
                mirror=hostname,
                pool=pool,
                first_seen=func.now(),
                last_seen=func.now(),
            )
        )
    else:
        row.last_seen = func.now()
def refresh_mirrors(db: Session, pool: int, data: dict[str, str | list[str]]):
    """Refresh the mirror table for ``pool`` from a mapping of origin -> mirror(s).

    Keys may be bare hostnames or "https://host/" URLs; values are a single
    mirror or a list of mirrors. Entries not re-seen within the last five
    minutes are pruned afterwards.

    Raises:
        TypeError: if a value is neither a str nor a list.
    """
    from datetime import timezone  # local: module header imports only datetime/timedelta

    for key in data:
        if key.startswith("https://") and key.endswith("/"):
            origin = key[8:-1]
        else:
            origin = key
        if "/" in origin:
            # TODO: flag this to operator
            continue
        if isinstance(data[key], list):
            for mirror in data[key]:
                _refresh_mirror(db, mirror, origin, pool)
        elif isinstance(data[key], str):
            _refresh_mirror(db, data[key], origin, pool)
        else:
            raise TypeError("data must be dict[str, str | list[str]]")
    # last_seen is a timezone-aware column (timestamptz in the migration):
    # compare against an aware UTC "now" rather than the naive local clock,
    # which skewed the five-minute expiry window by the server's UTC offset.
    cutoff = datetime.now(timezone.utc) - timedelta(minutes=5)
    db.query(Mirror).filter(Mirror.pool == pool, Mirror.last_seen < cutoff).delete()
def get_mirrors(db: Session, origin: str, pool=None) -> list[str]:
    """Return known mirror hostnames for ``origin``.

    Args:
        db: open database session.
        origin: hostname to look up.
        pool: a pool id, a list of pool ids, or None for the defaults
            (0 and -2 — the latter is the hardcoded RSF feed pool).

    When the database has no mirrors and the origin is supported by
    Kaleidoscope, falls back to a single Kaleidoscope-provided mirror.
    """
    if pool is None:
        pool = [0, -2]
    elif isinstance(pool, int):
        pool = [pool]
    result = db.query(Mirror).filter(Mirror.origin == origin, Mirror.pool.in_(pool)).all()
    mirrors = [m.mirror for m in result]
    if not mirrors:
        if origin in KLDSCP_SUPPORTED_ORIGINS:
            if (k_mirror := get_kaleidoscope_mirror(origin)) is not None:
                mirrors.append(k_mirror)
    return mirrors
def resolve_mirror(db: Session, url: str) -> str | None:
    """Rewrite ``url`` onto a randomly chosen mirror host.

    Returns None when no mirror is known for the URL's hostname.
    """
    parts = urlparse(url)
    candidates = get_mirrors(db, parts.netloc)
    if not candidates:
        return None
    chosen = random.choice(candidates)
    return urlunparse(parts._replace(netloc=chosen))

16
src/mirrors/tasks.py Normal file
View file

@ -0,0 +1,16 @@
import requests
from src.database import get_db_session
from src.mirrors.service import refresh_mirrors
from src.utils import repeat_every
@repeat_every(seconds=600)
def update_rsf_mirrors():
    """Periodic task: pull the RSF collateral-freedom mirror list into pool -2."""
    with get_db_session() as db:
        r = requests.get(
            "https://raw.githubusercontent.com/RSF-RWB/collateralfreedom/refs/heads/main/sites.json",
            # Without a timeout, a stalled connection would wedge this
            # periodic task (and the event loop awaiting it) forever.
            timeout=30,
        )
        # An error page body must not be fed to refresh_mirrors as JSON.
        r.raise_for_status()
        mirrors = r.json()
        refresh_mirrors(db, -2, mirrors)  # Tracking as hardcoded pool -2
        db.commit()

View file

@ -13,6 +13,7 @@ class CustomBase(DeclarativeBase):
} }
metadata = metadata metadata = metadata
class ActivatedMixin: class ActivatedMixin:
active: Mapped[bool] = mapped_column(default=True) active: Mapped[bool] = mapped_column(default=True)

20
src/security.py Normal file
View file

@ -0,0 +1,20 @@
from typing import Annotated
from fastapi import Depends, Header, HTTPException
from starlette import status
from src.config import settings
def api_key(host: str = Header(), authorization: str | None = Header(None)) -> bool:
    """FastAPI dependency: validate the bearer API key.

    Requests not addressed to the API domain get a 404 (so link domains do
    not advertise the API). Returns True only when the Authorization header
    has the form "<scheme> <key>" and the key matches settings.API_KEY;
    missing or malformed headers yield False rather than an error.
    """
    if host.lower().strip() != settings.API_DOMAIN.strip():
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
    try:
        if authorization.split()[1] == settings.API_KEY:
            return True
        return False
    # Fix: the original `except AttributeError, TypeError, IndexError:` is
    # Python 2 syntax and a SyntaxError under Python 3 — the exception types
    # must be a parenthesized tuple.
    except (AttributeError, TypeError, IndexError):
        return False


ApiKey = Annotated[bool, Depends(api_key)]

View file

@ -1,19 +1,24 @@
import base64 import base64
import copy import copy
import datetime import datetime
import logging
import mimetypes import mimetypes
from typing import Any from typing import Any
from urllib.parse import urlparse, urlunparse, urljoin from urllib.parse import urlparse, urlunparse, urljoin
import minify_html
import requests import requests
from babel.dates import format_date from babel.dates import format_date
from babel.support import Translations from babel.support import Translations
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from jinja2 import Environment, PackageLoader, select_autoescape from jinja2 import Environment, PackageLoader, select_autoescape
from src.config import settings
from src.database import get_db_session
from src.mirrors.service import resolve_mirror
from src.pangea.client import pangea_expanded_image_url
from src.snapshots.config import SnapshotsConfig, config_for_url from src.snapshots.config import SnapshotsConfig, config_for_url
from src.snapshots.schemas import SnapshotContext from src.snapshots.schemas import SnapshotContext
from src.snapshots.service import resolve_snapshot
class SnapshotParseError(RuntimeError): class SnapshotParseError(RuntimeError):
@ -74,7 +79,7 @@ def fetch_url(base: str, url: str) -> str | None:
return None return None
class Snapshot: class SnapshotCamera:
config: SnapshotsConfig | None = None config: SnapshotsConfig | None = None
context: SnapshotContext | None = None context: SnapshotContext | None = None
raw: bytes | None = None raw: bytes | None = None
@ -158,9 +163,29 @@ class Snapshot:
element.decompose() element.decompose()
for image in body.select("img"): for image in body.select("img"):
image.attrs = { image.attrs = {
"src": fetch_url(self.url, image["src"]), "src": fetch_url(
"alt": image["alt"], pangea_expanded_image_url(self.url),
image.get("src", image.get("data-src", "")),
),
"alt": image.get("alt", ""),
} }
with get_db_session() as db:
for hyperlink in body.select("a"):
absolute_url = urljoin(self.url, hyperlink.get("href"))
existing_snapshot = resolve_snapshot(db, absolute_url)
if existing_snapshot:
hyperlink.attrs.update(
{"href": existing_snapshot, "class": "snap-link--snapshot"}
)
continue
mirror_url = resolve_mirror(db, absolute_url)
if mirror_url:
hyperlink.attrs.update(
{"href": mirror_url, "class": "snap-link--mirror"}
)
continue
hyperlink.attrs.update({"href": absolute_url})
return str(body) return str(body)
def preprocess(self) -> None: def preprocess(self) -> None:
@ -173,16 +198,15 @@ class Snapshot:
element.attrs.pop("style") element.attrs.pop("style")
def favicon(self): def favicon(self):
icon = fetch_url( favicon_src = self.get_attribute_value('link[rel="icon"]', "href", optional=True)
self.url, self.get_attribute_value('link[rel="icon"]', "href", optional=True) if favicon_src:
) icon = fetch_url(self.url, favicon_src)
if icon:
return icon return icon
parsed = urlparse(self.url) parsed = urlparse(self.url)
icon_url = urlunparse((parsed.scheme, parsed.netloc, "/favicon.ico", "", "", "")) icon_url = urlunparse((parsed.scheme, parsed.netloc, "/favicon.ico", "", "", ""))
return fetch_url(self.url, icon_url) return fetch_url(self.url, icon_url)
def published_time(self, locale: str = "en") -> str: def published_time(self, locale) -> str:
if self.config.article_published_selector: if self.config.article_published_selector:
if published := self.get_element_content( if published := self.get_element_content(
self.config.article_published_selector, optional=True self.config.article_published_selector, optional=True
@ -194,12 +218,28 @@ class Snapshot:
return format_date(ts, locale=locale) return format_date(ts, locale=locale)
def parse(self) -> None: def parse(self) -> None:
if not self.config:
self.config = config_for_url(self.url)
if not self.config:
return
self.soup = BeautifulSoup(self.raw, "lxml") self.soup = BeautifulSoup(self.raw, "lxml")
self.preprocess() self.preprocess()
article_image_source = self.get_attribute_value( if self.config.article_image_selector:
self.config.article_image_selector, "src" article_image_source = self.get_attribute_value(
) self.config.article_image_selector, "src"
)
article_image_source = pangea_expanded_image_url(article_image_source)
else:
article_image_source = None
page_language = self.get_attribute_value(["html", "body"], "lang", optional=True) page_language = self.get_attribute_value(["html", "body"], "lang", optional=True)
site_url = urlunparse(urlparse(self.url)._replace(path="/"))
with get_db_session() as db:
article_mirror_url = resolve_mirror(db, self.url)
site_mirror_url = (
urlunparse(urlparse(article_mirror_url)._replace(path="/"))
if article_mirror_url
else None
)
self.context = SnapshotContext( self.context = SnapshotContext(
article_author=self.get_element_content( article_author=self.get_element_content(
self.config.article_author_selector, optional=True self.config.article_author_selector, optional=True
@ -208,7 +248,9 @@ class Snapshot:
article_description=self.get_attribute_value( article_description=self.get_attribute_value(
'meta[name="description"]', "content", optional=True 'meta[name="description"]', "content", optional=True
), ),
article_image=fetch_url(self.url, article_image_source), article_image=fetch_url(self.url, article_image_source)
if article_image_source
else None,
article_image_caption=self.get_element_content( article_image_caption=self.get_element_content(
self.config.article_image_caption_selector, optional=True self.config.article_image_caption_selector, optional=True
), ),
@ -216,20 +258,25 @@ class Snapshot:
article_published=self.published_time(page_language), article_published=self.published_time(page_language),
article_title=self.get_element_content(self.config.article_title_selector), article_title=self.get_element_content(self.config.article_title_selector),
article_url=self.url, article_url=self.url,
article_mirror_url=article_mirror_url,
matomo_host=settings.MATOMO_HOST,
matomo_site_id=settings.MATOMO_SITE_ID,
page_direction=self.get_attribute_value(["html", "body"], "dir", optional=True), page_direction=self.get_attribute_value(["html", "body"], "dir", optional=True),
page_language=page_language, page_language=page_language,
site_favicon=self.favicon(), site_favicon=self.favicon(),
site_logo=fetch_file(self.config.site_logo), site_logo=fetch_file(self.config.site_logo),
site_title=self.config.site_title, site_title=self.config.site_title,
site_url=site_url,
site_mirror_url=site_mirror_url,
) )
def get_context(self) -> dict[str, Any]: def get_context(self) -> dict[str, Any] | None:
logging.info("Get content") self.config = config_for_url(self.url)
self.get_content() if self.config:
logging.info("Parse") self.get_content()
self.parse() self.parse()
logging.info("Dump") return self.context.model_dump()
return self.context.model_dump() return None
def render(self) -> str: def render(self) -> str:
context = self.get_context() context = self.get_context()
@ -246,4 +293,6 @@ class Snapshot:
translations = Translations.load("i18n", [context["page_language"], "en"]) translations = Translations.load("i18n", [context["page_language"], "en"])
jinja_env.install_gettext_translations(translations) jinja_env.install_gettext_translations(translations)
template = jinja_env.get_template("article-template.html.j2") template = jinja_env.get_template("article-template.html.j2")
return template.render(**context) return minify_html.minify(
template.render(**context), minify_js=True, minify_css=True
)

View file

@ -1,9 +1,54 @@
from datetime import datetime
from enum import Enum
from sqlalchemy.orm import Mapped from sqlalchemy.orm import Mapped
from src.models import CustomBase, IdMixin from src.models import (
CustomBase,
IdMixin,
DeletedTimestampMixin,
TimestampMixin,
)
from src.google.config import settings as google_settings
from src.utils import hashids
class Snapshot(CustomBase, IdMixin): class SnapshotProvider(Enum):
GOOGLE = "google"
# TODO: when adding make sure to update alembic migration with
# op.execute("ALTER TYPE snapshotprovider ADD VALUE 'aws'")
# AWS = "aws"
# OVH = "ovh"
# ORACLE = "oracle"
# class SnapshotConfiguration(CustomBase, IdMixin, TimestampMixin, DeletedTimestampMixin, DescriptionMixin):
# __tablename__ = "snapshot_template"
#
# domain: Mapped[str]
# path: Mapped[str]
# configuration: Mapped[dict[str, Any]]
class SnapshotState(Enum):
PENDING = "pending"
FAILED = "failed"
UPDATING = "updating"
FROZEN = "frozen"
EXPIRED = "expired"
class Snapshot(CustomBase, IdMixin, TimestampMixin, DeletedTimestampMixin):
__tablename__ = "snapshot" __tablename__ = "snapshot"
url: Mapped[str] url: Mapped[str]
pool: Mapped[int]
snapshot_state: Mapped[SnapshotState]
provider: Mapped[SnapshotProvider]
snapshot_published_at: Mapped[datetime | None]
@property
def link(self) -> str:
if self.provider == SnapshotProvider.GOOGLE:
return f"https://storage.googleapis.com/{google_settings.BUCKET_NAME}/{hashids.encode(self.id)}.html"
return "unknown-provider" # impossible because all enum options

View file

@ -2,48 +2,56 @@ from fastapi import APIRouter, HTTPException, BackgroundTasks
from starlette import status from starlette import status
from starlette.responses import HTMLResponse from starlette.responses import HTMLResponse
from src.database import DbSession
from src.security import ApiKey
from src.snapshots.config import config_for_url
from src.snapshots.models import Snapshot, SnapshotState, SnapshotProvider
from src.config import settings from src.config import settings
from src.google.config import settings as google_settings from src.snapshots.client import SnapshotCamera
from src.snapshots.client import Snapshot
from src.snapshots.schemas import SnapshotContext from src.snapshots.schemas import SnapshotContext
from src.snapshots.tasks import upload_snapshot from src.snapshots.tasks import generate_snapshot
router = APIRouter() router = APIRouter()
@router.get( @router.get(
"/debug/context", "/api/v1/snap-context",
summary="Generate the context used by the snapshot template for debugging purposes. Endpoint disabled on production deployments.", summary="Generate the context used by the snapshot template for debugging purposes.",
response_model=SnapshotContext, response_model=SnapshotContext,
) )
def context(url: str = "https://www.bbc.com/russian/articles/ckgeey4dqgxo"): def context(auth: ApiKey, url: str = "https://www.bbc.com/russian/articles/ckgeey4dqgxo"):
if settings.ENVIRONMENT.is_debug: if settings.ENVIRONMENT.is_debug or auth:
return Snapshot(url).get_context() return SnapshotCamera(url).get_context()
raise HTTPException(status.HTTP_404_NOT_FOUND) raise HTTPException(status.HTTP_404_NOT_FOUND)
@router.get( @router.get(
"/debug/demo", "/api/v1/snap-preview",
summary="Generate a rendered snapshot template for debugging purposes. Endpoint disabled on production deployments.", summary="Generate a rendered snapshot template for debugging purposes.",
response_class=HTMLResponse, response_class=HTMLResponse,
) )
def parse(url: str = "https://www.bbc.com/russian/articles/ckgeey4dqgxo"): def parse(auth: ApiKey, url: str = "https://www.bbc.com/russian/articles/ckgeey4dqgxo"):
if settings.ENVIRONMENT.is_debug: if settings.ENVIRONMENT.is_debug or auth:
return Snapshot(url).render() return SnapshotCamera(url).render()
raise HTTPException(status.HTTP_404_NOT_FOUND) raise HTTPException(status.HTTP_404_NOT_FOUND)
@router.get( @router.get(
"/debug/upload", "/api/v1/snap",
summary="Generate a rendered snapshot template for debugging purposes and upload to Google Cloud Storage. Endpoint disabled on production deployments.", summary="Generate a rendered snapshot template and upload to Google Cloud Storage.",
response_class=HTMLResponse,
) )
def upload( def snap(
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
db: DbSession,
auth: ApiKey,
url: str = "https://www.bbc.com/russian/articles/ckgeey4dqgxo", url: str = "https://www.bbc.com/russian/articles/ckgeey4dqgxo",
): ):
if settings.ENVIRONMENT.is_debug: s = db.query(Snapshot).filter(Snapshot.url == url, Snapshot.pool == 0).first()
rendered = Snapshot(url).render() if not s and config_for_url(url):
background_tasks.add_task(upload_snapshot, "debug2.html", rendered) s = Snapshot(url=url, pool=0, snapshot_state=SnapshotState.PENDING, provider=SnapshotProvider.GOOGLE)
return f'<a href="https://storage.googleapis.com/{google_settings.BUCKET_NAME}/debug.html">Google Cloud Storage</a>' db.add(s)
raise HTTPException(status.HTTP_404_NOT_FOUND) db.commit()
background_tasks.add_task(generate_snapshot, s.id)
if s:
return {"url": s.link}
return status.HTTP_403_FORBIDDEN

View file

@ -11,8 +11,13 @@ class SnapshotContext(BaseModel):
article_published: str article_published: str
article_title: str article_title: str
article_url: str article_url: str
article_mirror_url: str | None = None
matomo_host: str
matomo_site_id: int
page_direction: str | None = None page_direction: str | None = None
page_language: str | None = None page_language: str | None = None
site_favicon: str | None = None site_favicon: str | None = None
site_logo: str = None site_logo: str = None
site_title: str site_title: str
site_mirror_url: str | None = None
site_url: str

8
src/snapshots/service.py Normal file
View file

@ -0,0 +1,8 @@
from sqlalchemy.orm import Session
from src.snapshots.models import Snapshot
def resolve_snapshot(db: Session, url: str) -> str | None:
    """Return the public link of the pool-0 snapshot stored for *url*.

    Looks up the first snapshot row matching the exact URL in pool 0 and
    returns its ``link``; returns ``None`` when no such row exists yet.
    """
    snapshot = (
        db.query(Snapshot)
        .filter(Snapshot.url == url, Snapshot.pool == 0)
        .first()
    )
    if not snapshot:
        return None
    return snapshot.link

View file

@ -1,5 +1,29 @@
import logging
from datetime import datetime
from src.database import get_db_session
from src.snapshots.client import SnapshotCamera
from src.snapshots.models import Snapshot, SnapshotState
from src.google.client import upload_blob from src.google.client import upload_blob
from src.utils import hashids
def upload_snapshot(filename: str, content: str) -> None: def generate_snapshot(id_: int) -> None:
upload_blob(filename, content.encode("utf-8"), "text/html") with get_db_session() as db:
snapshot = (
db.query(Snapshot)
.filter(Snapshot.id == id_, Snapshot.snapshot_state == SnapshotState.PENDING)
.first()
)
if not snapshot:
return
try:
content = SnapshotCamera(snapshot.url).render()
upload_blob(hashids.encode(snapshot.id) + ".html", content.encode("utf-8"), "text/html")
snapshot.snapshot_state = SnapshotState.UPDATING
snapshot.snapshot_published_at = datetime.now()
db.commit()
except Exception as e:
logging.error(e)
snapshot.snapshot_state = SnapshotState.FAILED
db.commit()

View file

@ -1,9 +1,8 @@
{% from "article.css.j2" import article_css %}<!DOCTYPE html> {% from "article.css.j2" import article_css %}<!DOCTYPE html>
<html {% if page_direction %} dir="{{ page_direction }}"{% endif %}{% if page_language %} lang="{{ page_language }}"{% endif %} prefix="og: https://ogp.me/ns#"> <html {% if page_direction %} dir="{{ page_direction }}"{% endif %}{% if page_language %} lang="{{ page_language }}"{% endif %} prefix="og: https://ogp.me/ns#">
<base href="{{ article_url }}" />
<head> <head>
<meta charset="utf-8"> <meta charset="utf-8">
<title>{{ article_title }}</title> <title>{{ article_title.strip() }}</title>
<meta name="viewport" <meta name="viewport"
content="width=device-width, initial-scale=1.0, minimum-scale=1.0, maximum-scale=1.0, user-scalable=no"/> content="width=device-width, initial-scale=1.0, minimum-scale=1.0, maximum-scale=1.0, user-scalable=no"/>
<meta name="format-detection" content="telephone=no"/> <meta name="format-detection" content="telephone=no"/>
@ -12,7 +11,7 @@
<meta name="HandheldFriendly" content="True"/> <meta name="HandheldFriendly" content="True"/>
<meta property="og:type" content="article"/> <meta property="og:type" content="article"/>
<meta property="og:title" content="{{ article_title }}"/> <meta property="og:title" content="{{ article_title.strip() }}"/>
<meta property="og:site_name" content="{{ site_title }}"/> <meta property="og:site_name" content="{{ site_title }}"/>
<meta property="og:url" content="{{ article_url }}"/> <meta property="og:url" content="{{ article_url }}"/>
{% if article_image_source %} {% if article_image_source %}
@ -20,9 +19,7 @@
{% endif %} {% endif %}
<meta name="twitter:card" content="summary_large_image"/> <meta name="twitter:card" content="summary_large_image"/>
{% if article_author %}<meta property="article:author" content="{{ article_author }}"/>{% endif %} {% if article_author %}<meta property="article:author" content="{{ article_author }}"/>{% endif %}
<meta name="robots" content="noindex" />
{% if noindex %}<meta name="robots" content="noindex" />{% endif %}
{% if site_favicon %} {% if site_favicon %}
<link rel="icon" href="{{ site_favicon }}" /> <link rel="icon" href="{{ site_favicon }}" />
{% endif %} {% endif %}
@ -47,19 +44,36 @@
let target = e.target.closest("a"); let target = e.target.closest("a");
if (target) { if (target) {
// if the click was on or within an <a> // if the click was on or within an <a>
if (!target.href.includes("cloudfront.net") && if (!target.className.includes("snap-skip-link") &&
!target.href.includes("azureedge.net") && !target.className.includes("snap-link--mirror") &&
!target.href.includes("global.ssl.fastly.net")) { !target.className.includes("snap-link--snapshot")) {
e.preventDefault(); e.preventDefault();
document.body.dataset.currentLink = target.href; document.body.dataset.currentLink = target.href;
} }
} }
}); });
var _paq = window._paq = window._paq || [];
var p = "{{ article_url }}";
_paq.push(["setCustomUrl", p]);
_paq.push(["setExcludedQueryParams", ["roomName", "account", "accountnum", "address", "address1", "address2", "address3", "addressline1", "addressline2", "adres", "adresse", "age", "alter", "auth", "authpw", "bic", "billingaddress", "billingaddress1", "billingaddress2", "calle", "cardnumber", "cc", "ccc", "cccsc", "cccvc", "cccvv", "ccexpiry", "ccexpmonth", "ccexpyear", "ccname", "ccnumber", "cctype", "cell", "cellphone", "city", "clientid", "clientsecret", "company", "consumerkey", "consumersecret", "contrasenya", "contrase\u00f1a", "creditcard", "creditcardnumber", "cvc", "cvv", "dateofbirth", "debitcard", "direcci\u00f3n", "dob", "domain", "ebost", "email", "emailaddress", "emailadresse", "epos", "epost", "eposta", "exp", "familyname", "firma", "firstname", "formlogin", "fullname", "gender", "geschlecht", "gst", "gstnumber", "handynummer", "has\u0142o", "heslo", "iban", "ibanaccountnum", "ibanaccountnumber", "id", "identifier", "indirizzo", "kartakredytowa", "kennwort", "keyconsumerkey", "keyconsumersecret", "konto", "kontonr", "kontonummer", "kredietkaart", "kreditkarte", "kreditkort", "lastname", "login", "mail", "mobiili", "mobile", "mobilne", "nachname", "name", "nickname", "false", "osoite", "parole", "pass", "passord", "password", "passwort", "pasword", "paswort", "paword", "phone", "pin", "plz", "postalcode", "postcode", "postleitzahl", "privatekey", "publickey", "pw", "pwd", "pword", "pwrd", "rue", "secret", "secretq", "secretquestion", "shippingaddress", "shippingaddress1", "shippingaddress2", "socialsec", "socialsecuritynumber", "socsec", "sokak", "ssn", "steuernummer", "strasse", "street", "surname", "swift", "tax", "taxnumber", "tel", "telefon", "telefonnr", "telefonnummer", "telefono", "telephone", "token", "token_auth", "tokenauth", "t\u00e9l\u00e9phone", "ulica", "user", "username", "vat", "vatnumber", "via", "vorname", "wachtwoord", "wagwoord", "webhooksecret", "website", "zip", "zipcode"]]);
_paq.push(["trackPageView", p]);
_paq.push(['enableLinkTracking']);
(function () {
var u = "//{{ matomo_host }}/";
_paq.push(['setTrackerUrl', u + 'matomo.php']);
_paq.push(['setSiteId', '{{ matomo_site_id }}']);
var d = document, g = d.createElement('script'), s = d.getElementsByTagName('script')[0];
g.async = true;
g.src = u + 'matomo.js';
s.parentNode.insertBefore(g, s);
})();
</script> </script>
</head> </head>
<body> <body>
<div class="snap-wrapper"> <div class="snap-wrapper">
<a href="#snap-main" class="snap-skip-link">Skip to main content</a> <a href="#snap-main" class="snap-skip-link">{{ gettext("Skip to main content") }}</a>
<details class="snap-trust-header"> <details class="snap-trust-header">
<summary class="snap-trust-header__header"> <summary class="snap-trust-header__header">
@ -89,14 +103,14 @@
<header class="snap-page-header"> <header class="snap-page-header">
<nav class="snap-page-header-nav"> <nav class="snap-page-header-nav">
{% if article_mirror_url %}<a href="{{ article_mirror_url }}">{% endif %} {% if site_mirror_url %}<a href="{{ site_mirror_url }}" class="snap-link--mirror">{% endif %}
<img src="{{ site_logo }}" alt="{{ site_title }}" class="snap-page-header-logo"> <img src="{{ site_logo }}" alt="{{ site_title }}" class="snap-page-header-logo">
{% if article_mirror_url %}</a>{% endif %} {% if site_mirror_url %}</a>{% endif %}
</nav> </nav>
</header> </header>
<main id="snap-main"> <main id="snap-main">
<header class="snap-article-header"> <header class="snap-article-header">
<h1>{{ article_title }}</h1> <h1>{{ article_title.strip() }}</h1>
<div class="snap-byline"> <div class="snap-byline">
{{ article_published }} - {{ site_title }} {{ article_published }} - {{ site_title }}
</div> </div>
@ -113,8 +127,8 @@
{{ article_body }} {{ article_body }}
{% if article_mirror_url %} {% if article_mirror_url %}
<p> <p>
<a href="{{ article_mirror_url }}" class="snap-footer-link"> <a href="{{ article_mirror_url }}" class="snap-footer-link snap-link--mirror">
View the original article {{ gettext("View the original article") }}
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"> <svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M19.7212 13.0822C19.3072 13.0822 18.9712 13.4189 18.9712 13.8322V18.9712H5.02881V5.02881H10.167C10.5818 5.02881 10.917 4.69279 10.917 4.27881C10.917 3.86483 10.5818 3.52881 10.167 3.52881H4.27881C3.86405 3.52881 3.52881 3.86483 3.52881 4.27881V19.7212C3.52881 20.136 3.86405 20.4712 4.27881 20.4712H19.7212C20.136 20.4712 20.4712 20.136 20.4712 19.7212V13.8322C20.4712 13.4197 20.136 13.0822 19.7212 13.0822Z" <path d="M19.7212 13.0822C19.3072 13.0822 18.9712 13.4189 18.9712 13.8322V18.9712H5.02881V5.02881H10.167C10.5818 5.02881 10.917 4.69279 10.917 4.27881C10.917 3.86483 10.5818 3.52881 10.167 3.52881H4.27881C3.86405 3.52881 3.52881 3.86483 3.52881 4.27881V19.7212C3.52881 20.136 3.86405 20.4712 4.27881 20.4712H19.7212C20.136 20.4712 20.4712 20.136 20.4712 19.7212V13.8322C20.4712 13.4197 20.136 13.0822 19.7212 13.0822Z"
fill="#222F3A"></path> fill="#222F3A"></path>
@ -128,7 +142,7 @@
</main> </main>
<footer class="snap-footer"> <footer class="snap-footer">
<div> <div>
{% if site_mirror_url %}<a href="https://d7qg4uz16a7xs.cloudfront.net/">{% endif %} {% if site_mirror_url %}<a href="{{ site_mirror_url }}">{% endif %}
<img src="{{ site_logo }}" alt="{{ site_title }} logo"> <img src="{{ site_logo }}" alt="{{ site_title }} logo">
{% if site_mirror_url %}</a>{% endif %} {% if site_mirror_url %}</a>{% endif %}
</div> </div>

View file

@ -144,8 +144,8 @@ figcaption {
box-sizing: border-box; box-sizing: border-box;
color: #333; color: #333;
display: inline-block; display: inline-block;
max-width: 335px; width: auto;
width: 100%; max-width: 100%;
border: 1px solid #e0dfdd; border: 1px solid #e0dfdd;
border-radius: 4px; border-radius: 4px;
padding: 16px 24px; padding: 16px 24px;
@ -168,11 +168,19 @@ figcaption {
} }
.snap-footer-link svg { .snap-footer-link svg {
margin-top: 3px;
}
.snap-footer-link svg:dir(ltr) {
float: right; float: right;
margin-left: 10px;
} }
.snap-footer-link:dir(rtl) svg { .snap-footer-link:dir(rtl) svg {
float: left; float: left;
-webkit-transform: scaleX(-1);
transform: scaleX(-1);
margin-right: 10px;
} }
.snap-footer-link--disabled { .snap-footer-link--disabled {

102
src/utils.py Normal file
View file

@ -0,0 +1,102 @@
import asyncio
import logging
from functools import wraps
from traceback import format_exception
from typing import Coroutine, Callable, Any
from hashids import Hashids
from starlette.concurrency import run_in_threadpool
from src.config import settings
NoArgsNoReturnFuncT = Callable[[], None]
NoArgsNoReturnAsyncFuncT = Callable[[], Coroutine[Any, Any, None]]
ExcArgNoReturnFuncT = Callable[[Exception], None]
ExcArgNoReturnAsyncFuncT = Callable[[Exception], Coroutine[Any, Any, None]]
NoArgsNoReturnAnyFuncT = NoArgsNoReturnFuncT | NoArgsNoReturnAsyncFuncT
ExcArgNoReturnAnyFuncT = ExcArgNoReturnFuncT | ExcArgNoReturnAsyncFuncT
NoArgsNoReturnDecorator = Callable[[NoArgsNoReturnAnyFuncT], NoArgsNoReturnAsyncFuncT]
async def _handle_repeat_func(func: NoArgsNoReturnAnyFuncT) -> None:
if asyncio.iscoroutinefunction(func):
await func()
else:
await run_in_threadpool(func)
async def _handle_repeat_exc(
exc: Exception, on_exception: ExcArgNoReturnAnyFuncT | None
) -> None:
if on_exception:
if asyncio.iscoroutinefunction(on_exception):
await on_exception(exc)
else:
await run_in_threadpool(on_exception, exc)
def repeat_every(
*,
seconds: float,
wait_first: float | None = None,
max_repetitions: int | None = None,
on_complete: NoArgsNoReturnAnyFuncT | None = None,
on_exception: ExcArgNoReturnAnyFuncT | None = None,
) -> NoArgsNoReturnDecorator:
"""
This function returns a decorator that modifies a function so it is periodically re-executed after its first call.
The function it decorates should accept no arguments and return nothing. If necessary, this can be accomplished
by using `functools.partial` or otherwise wrapping the target function prior to decoration.
Parameters
----------
seconds: float
The number of seconds to wait between repeated calls
wait_first: float (default None)
If not None, the function will wait for the given duration before the first call
max_repetitions: Optional[int] (default None)
The maximum number of times to call the repeated function. If `None`, the function is repeated forever.
on_complete: Optional[Callable[[], None]] (default None)
A function to call after the final repetition of the decorated function.
on_exception: Optional[Callable[[Exception], None]] (default None)
A function to call when an exception is raised by the decorated function.
"""
def decorator(func: NoArgsNoReturnAnyFuncT) -> NoArgsNoReturnAsyncFuncT:
"""
Converts the decorated function into a repeated, periodically-called version of itself.
"""
@wraps(func)
async def wrapped() -> None:
async def loop() -> None:
if wait_first is not None:
await asyncio.sleep(wait_first)
repetitions = 0
while max_repetitions is None or repetitions < max_repetitions:
try:
await _handle_repeat_func(func)
except Exception as exc:
formatted_exception = "".join(
format_exception(type(exc), exc, exc.__traceback__)
)
logging.error(formatted_exception)
await _handle_repeat_exc(exc, on_exception)
repetitions += 1
await asyncio.sleep(seconds)
if on_complete:
await _handle_repeat_func(on_complete)
asyncio.ensure_future(loop())
return wrapped
return decorator
# Shared Hashids codec: encodes integer primary keys into short obfuscated
# tokens (at least 5 chars) used as public snapshot blob names; salted with
# the app's HASH_SECRET_KEY, so encodings are deployment-specific.
hashids = Hashids(min_length=5, salt=settings.HASH_SECRET_KEY)