from __future__ import annotations

import asyncio
import hashlib
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Protocol

from datastar_py import ServerSentEventGenerator as SSE
from datastar_py.sse import DatastarEvent


class HtmlRenderable(Protocol):
    """Structural type for objects that render themselves via ``__html__``."""

    def __html__(self) -> str: ...


# A rendered view: either an HTML string or an object exposing ``__html__``.
RenderResult = str | HtmlRenderable
# Zero-argument coroutine function that produces the current view.
RenderFunction = Callable[[], Awaitable[RenderResult]]


class RefreshBroker:
    """Fan refresh notifications out to subscriber queues.

    Each subscriber gets its own single-slot queue, remembered together with
    the event loop that was running when it subscribed. ``publish`` hands the
    event to each subscriber's loop via ``call_soon_threadsafe``, so the queue
    is only touched from its own loop. Because each queue holds at most one
    item, a burst of publishes before the subscriber wakes up collapses into
    the most recent event.
    """

    def __init__(self) -> None:
        # Maps each subscriber queue to the event loop that owns it.
        self._subscribers: dict[asyncio.Queue[object], asyncio.AbstractEventLoop] = {}

    def subscribe(self) -> asyncio.Queue[object]:
        """Register a new subscriber on the running event loop.

        Returns:
            A fresh ``maxsize=1`` queue that will receive published events.

        Raises:
            RuntimeError: if called without a running event loop
                (from ``asyncio.get_running_loop``).
        """
        queue: asyncio.Queue[object] = asyncio.Queue(maxsize=1)
        self._subscribers[queue] = asyncio.get_running_loop()
        return queue

    def unsubscribe(self, queue: asyncio.Queue[object]) -> None:
        """Stop delivering events to *queue*; unknown queues are ignored."""
        self._subscribers.pop(queue, None)

    def publish(self, event: object = "refresh-event") -> None:
        """Deliver *event* to every current subscriber.

        Subscribers whose event loop has already been closed can no longer
        receive anything; they are dropped rather than aborting delivery to
        the remaining subscribers.
        """
        # Iterate over a snapshot so (un)subscribing during delivery is safe.
        for queue, loop in tuple(self._subscribers.items()):
            try:
                loop.call_soon_threadsafe(_publish_event, queue, event)
            except RuntimeError:
                # call_soon_threadsafe raises RuntimeError on a closed loop.
                # Letting it propagate would skip every later subscriber and
                # make every future publish() fail, so prune the dead entry.
                self.unsubscribe(queue)


def _publish_event(queue: asyncio.Queue[object], event: object) -> None:
    """Put *event* on *queue*, replacing a pending item if the queue is full.

    Scheduled by ``RefreshBroker.publish`` to run on the queue's own loop.
    """
    if queue.full():
        # Discard the stale pending event so the newest one wins.
        try:
            queue.get_nowait()
        except asyncio.QueueEmpty:
            pass
    try:
        queue.put_nowait(event)
    except asyncio.QueueFull:
        return


async def render_sse_event(
    render: RenderFunction,
    *,
    last_event_id: str | None = None,
    use_view_transition: bool = False,
) -> tuple[str, DatastarEvent | None]:
    """Render the view and build a patch-elements event if the HTML changed.

    The event id is a content hash of the rendered HTML, so an unchanged
    render is detected by comparing that hash against *last_event_id*.

    Args:
        render: Coroutine function producing the current view.
        last_event_id: Hash of the previously sent render, if any.
        use_view_transition: Forwarded to ``SSE.patch_elements``.

    Returns:
        ``(event_id, event)`` where ``event_id`` is the hash of this render
        and ``event`` is ``None`` when the hash equals *last_event_id*.
    """
    html = _coerce_html(await render())
    event_id = _render_hash(html)
    if event_id == last_event_id:
        # Same content as last time: nothing to send.
        return event_id, None
    return event_id, SSE.patch_elements(
        html,
        event_id=event_id,
        use_view_transition=use_view_transition,
    )


async def render_stream(
    queue: asyncio.Queue[object],
    render: RenderFunction,
    *,
    last_event_id: str | None = None,
    render_on_connect: bool = True,
) -> AsyncGenerator[DatastarEvent, None]:
    """Yield a patch event each time a queued refresh changes the view.

    Args:
        queue: Queue whose items each trigger a re-render, such as one
            returned by ``RefreshBroker.subscribe``.
        render: Coroutine function producing the current view.
        last_event_id: Hash of the view the client already has, if any;
            a render producing the same hash is not yielded.
        render_on_connect: Render once immediately, before waiting on *queue*.

    A queue item equal to ``"queue-reordered"`` requests a view transition for
    the resulting patch. The generator loops forever; it never unsubscribes
    *queue* itself, so the caller is responsible for that.
    """
    if render_on_connect:
        last_event_id, event = await render_sse_event(
            render, last_event_id=last_event_id
        )
        if event is not None:
            yield event
    while True:
        event_name = await queue.get()
        last_event_id, event = await render_sse_event(
            render,
            last_event_id=last_event_id,
            use_view_transition=event_name == "queue-reordered",
        )
        if event is not None:
            yield event


def _coerce_html(view: RenderResult) -> str:
    """Return *view* as an HTML string, calling ``__html__`` when needed."""
    if isinstance(view, str):
        return view
    return view.__html__()


def _render_hash(html: str) -> str:
    """Return a 32-hex-character BLAKE2s digest of *html* (UTF-8 encoded)."""
    return hashlib.blake2s(html.encode("utf-8"), digest_size=16).hexdigest()