from __future__ import annotations import asyncio import hashlib from collections.abc import AsyncGenerator, Awaitable, Callable from typing import TypedDict, cast from urllib.parse import urlparse import htpy as h from datastar_py import ServerSentEventGenerator as SSE from datastar_py.quart import DatastarResponse, read_signals from datastar_py.sse import DatastarEvent from htpy import Renderable from peewee import IntegrityError from quart import Quart, Response, request, url_for from repub.datastar import RefreshBroker, render_stream from repub.model import ( create_source, initialize_database, load_sources, source_slug_exists, ) from repub.pages import ( create_source_page, dashboard_page, execution_logs_page, runs_page, shim_page, sources_page, ) from repub.pages.sources import PANGEA_CONTENT_FORMATS, PANGEA_CONTENT_TYPES REFRESH_BROKER_KEY = "repub.refresh_broker" RenderFunction = Callable[[], Awaitable[Renderable]] class SourceFormData(TypedDict): name: str slug: str source_type: str notes: str spider_arguments: str enabled: bool cron_minute: str cron_hour: str cron_day_of_month: str cron_day_of_week: str cron_month: str feed_url: str pangea_domain: str pangea_category: str content_format: str content_type: str max_articles: int | None oldest_article: int | None only_newest: bool include_authors: bool exclude_media: bool include_content: bool DEFAULT_PANGEA_CONTENT_FORMAT = "MOBILE_3" DEFAULT_PANGEA_CONTENT_TYPE = "articles" DEFAULT_PANGEA_MAX_ARTICLES = "10" DEFAULT_PANGEA_OLDEST_ARTICLE = "3" def _render_shim_page(*, stylesheet_href: str, datastar_src: str) -> tuple[str, str]: head = ( h.title["Republisher Admin UI"], h.link(rel="stylesheet", href=stylesheet_href), ) body = str(shim_page(datastar_src=datastar_src, head=head)) etag = hashlib.sha256(body.encode("utf-8")).hexdigest() return body, etag def create_app() -> Quart: app = Quart(__name__) app.config["REPUB_DB_PATH"] = str(initialize_database()) app.extensions[REFRESH_BROKER_KEY] = RefreshBroker() @app.get("/") @app.get("/sources") @app.get("/sources/create") @app.get("/runs") @app.get("/job//execution//logs") async def page_shim( job_id: int | None = None, execution_id: int | None = None ) -> Response: del job_id, execution_id body, etag = _render_shim_page( stylesheet_href=url_for("static", filename="app.css"), datastar_src=url_for("static", filename="datastar@1.0.0-RC.8.js"), ) if request.if_none_match.contains(etag): response = Response(status=304) response.set_etag(etag) return response response = Response(body, mimetype="text/html") response.set_etag(etag) return response @app.post("/") async def dashboard_patch() -> DatastarResponse: return _page_patch_response(app, render_dashboard) @app.post("/sources") async def sources_patch() -> DatastarResponse: return _page_patch_response(app, lambda: render_sources(app)) @app.post("/sources/create") async def create_source_patch() -> DatastarResponse: return _page_patch_response(app, lambda: render_create_source(app)) @app.post("/actions/sources/create") async def create_source_action() -> DatastarResponse: signals = cast(dict[str, object], await read_signals()) source, error = validate_source_form( signals, slug_exists=source_slug_exists, ) if error is not None: return DatastarResponse( SSE.patch_signals({"_formError": error, "_formSuccess": ""}) ) assert source is not None try: create_source(**source) except IntegrityError: return DatastarResponse( SSE.patch_signals( {"_formError": "Slug must be unique.", "_formSuccess": ""} ) ) trigger_refresh(app) return DatastarResponse(SSE.redirect("/sources")) @app.post("/runs") async def runs_patch() -> DatastarResponse: return _page_patch_response(app, render_runs) @app.post("/job//execution//logs") async def logs_patch(job_id: int, execution_id: int) -> DatastarResponse: async def render() -> Renderable: return await render_execution_logs(job_id=job_id, execution_id=execution_id) return _page_patch_response(app, render) return app def get_refresh_broker(app: Quart) -> RefreshBroker: return cast(RefreshBroker, app.extensions[REFRESH_BROKER_KEY]) def trigger_refresh(app: Quart, event: object = "refresh-event") -> None: get_refresh_broker(app).publish(event) async def render_dashboard() -> Renderable: return dashboard_page() async def render_sources(app: Quart | None = None) -> Renderable: sources = None if app is None else load_sources() return sources_page(sources=sources) async def render_create_source(app: Quart | None = None) -> Renderable: del app return create_source_page() async def render_runs() -> Renderable: return runs_page() async def render_execution_logs(*, job_id: int, execution_id: int) -> Renderable: return execution_logs_page(job_id=job_id, execution_id=execution_id) def _page_patch_response(app: Quart, render: RenderFunction) -> DatastarResponse: queue = get_refresh_broker(app).subscribe() stream = render_stream( queue, render=render, last_event_id=request.headers.get("last-event-id"), ) return DatastarResponse(_unsubscribe_on_close(queue, stream, app)) async def _unsubscribe_on_close( queue: object, stream: AsyncGenerator[DatastarEvent, None], app: Quart ) -> AsyncGenerator[DatastarEvent, None]: try: async for event in stream: yield event finally: get_refresh_broker(app).unsubscribe(cast(asyncio.Queue[object], queue)) def validate_source_form( signals: dict[str, object] | None, *, slug_exists: Callable[[str], bool], ) -> tuple[SourceFormData | None, str | None]: if signals is None: return None, "Missing form data." source_name = _read_string(signals, "sourceName") source_slug = _read_string(signals, "sourceSlug") source_type = _read_string(signals, "sourceType") feed_url = _read_string(signals, "feedUrl") pangea_domain = _read_string(signals, "pangeaDomain") pangea_category = _read_string(signals, "pangeaCategory") content_format = _read_string(signals, "contentFormat") content_type = _read_string(signals, "contentType") max_articles = _read_string(signals, "maxArticles") oldest_article = _read_string(signals, "oldestArticle") source_notes = _read_string(signals, "sourceNotes") spider_arguments = _normalize_multiline(_read_string(signals, "spiderArguments")) cron_minute = _read_string(signals, "cronMinute") cron_hour = _read_string(signals, "cronHour") cron_day_of_month = _read_string(signals, "cronDayOfMonth") cron_day_of_week = _read_string(signals, "cronDayOfWeek") cron_month = _read_string(signals, "cronMonth") errors: list[str] = [] if source_name == "": errors.append("Source name is required.") if source_slug == "": errors.append("Slug is required.") elif slug_exists(source_slug): errors.append("Slug must be unique.") if source_type not in {"feed", "pangea"}: errors.append("Source type must be feed or pangea.") if source_type == "feed": if feed_url == "": errors.append("Feed URL is required for feed sources.") elif not _is_valid_url(feed_url): errors.append("Feed URL must be a valid URL.") if source_type == "pangea": content_format = content_format or DEFAULT_PANGEA_CONTENT_FORMAT content_type = content_type or DEFAULT_PANGEA_CONTENT_TYPE max_articles = max_articles or DEFAULT_PANGEA_MAX_ARTICLES oldest_article = oldest_article or DEFAULT_PANGEA_OLDEST_ARTICLE if pangea_domain == "": errors.append("Pangea domain is required.") if pangea_category == "": errors.append("Category name is required.") if content_format not in PANGEA_CONTENT_FORMATS: errors.append("Content format is invalid.") if content_type not in PANGEA_CONTENT_TYPES: errors.append("Content type is invalid.") if _parse_int(max_articles) is None: errors.append("Max articles must be an integer.") if _parse_int(oldest_article) is None: errors.append("Oldest article must be an integer.") cron_values = ( cron_minute, cron_hour, cron_day_of_month, cron_day_of_week, cron_month, ) if any(value == "" for value in cron_values): errors.append("All cron fields are required.") if errors: return None, " ".join(errors) enabled = _read_bool(signals, "jobEnabled") source: SourceFormData = { "name": source_name, "slug": source_slug, "source_type": source_type, "notes": source_notes, "spider_arguments": spider_arguments, "feed_url": feed_url, "pangea_domain": pangea_domain, "pangea_category": pangea_category, "content_format": content_format, "content_type": content_type, "max_articles": _parse_int(max_articles), "oldest_article": _parse_int(oldest_article), "enabled": enabled, "only_newest": _read_bool(signals, "onlyNewest", default=True), "include_authors": _read_bool(signals, "includeAuthors", default=True), "exclude_media": _read_bool(signals, "excludeMedia", default=False), "include_content": _read_bool(signals, "includeContent", default=True), "cron_minute": cron_minute, "cron_hour": cron_hour, "cron_day_of_month": cron_day_of_month, "cron_day_of_week": cron_day_of_week, "cron_month": cron_month, } return source, None def _read_string(signals: dict[str, object], key: str) -> str: return str(signals.get(key, "")).strip() def _read_bool(signals: dict[str, object], key: str, *, default: bool = False) -> bool: value = signals.get(key, default) if isinstance(value, bool): return value if isinstance(value, str): return value.lower() in {"true", "1", "on", "yes"} return bool(value) def _normalize_multiline(value: str) -> str: return value.replace("\r\n", "\n").replace("\r", "\n") def _parse_int(value: str) -> int | None: try: return int(value) except ValueError: return None def _is_valid_url(value: str) -> bool: parsed = urlparse(value) return parsed.scheme in {"http", "https"} and parsed.netloc != ""