republisher/repub/datastar.py

98 lines
2.7 KiB
Python

from __future__ import annotations
import asyncio
import hashlib
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Protocol
from datastar_py import ServerSentEventGenerator as SSE
from datastar_py.sse import DatastarEvent
class HtmlRenderable(Protocol):
    """Structural type for any object that exposes an ``__html__`` method."""

    def __html__(self) -> str:
        """Return this object's HTML markup as a string."""
        ...
# A rendered view: either a plain string or an object exposing ``__html__``.
RenderResult = str | HtmlRenderable
# A zero-argument callable whose awaited result is a ``RenderResult``.
RenderFunction = Callable[[], Awaitable[RenderResult]]
class RefreshBroker:
    """Fan refresh notifications out to per-subscriber asyncio queues.

    Every subscriber gets its own single-slot queue and is stored together
    with the event loop that was running when it subscribed. ``publish``
    hands delivery to that loop through ``call_soon_threadsafe`` rather than
    touching the queue directly.
    """

    def __init__(self) -> None:
        # Maps each subscriber queue to the event loop it subscribed from.
        self._subscribers: dict[asyncio.Queue[object], asyncio.AbstractEventLoop] = {}

    def subscribe(self) -> asyncio.Queue[object]:
        """Register a new subscriber and return its single-slot queue.

        Raises:
            RuntimeError: If no event loop is running in the current thread.
        """
        queue: asyncio.Queue[object] = asyncio.Queue(maxsize=1)
        self._subscribers[queue] = asyncio.get_running_loop()
        return queue

    def unsubscribe(self, queue: asyncio.Queue[object]) -> None:
        """Forget *queue*; unknown or already-removed queues are ignored."""
        self._subscribers.pop(queue, None)

    def publish(self, event: object = "refresh-event") -> None:
        """Schedule delivery of *event* to every current subscriber.

        A subscriber whose event loop has been closed can never receive
        anything again, so it is dropped instead of letting the
        ``RuntimeError`` raised by ``call_soon_threadsafe`` abort delivery
        to the remaining subscribers.
        """
        # Iterate over a snapshot: entries may be removed while looping.
        for queue, loop in tuple(self._subscribers.items()):
            if loop.is_closed():
                self._subscribers.pop(queue, None)
                continue
            try:
                loop.call_soon_threadsafe(_publish_event, queue, event)
            except RuntimeError:
                # The loop was closed between the check above and the call.
                self._subscribers.pop(queue, None)
def _publish_event(queue: asyncio.Queue[object], event: object) -> None:
if queue.full():
try:
queue.get_nowait()
except asyncio.QueueEmpty:
pass
try:
queue.put_nowait(event)
except asyncio.QueueFull:
return
async def render_sse_event(
    render: RenderFunction,
    *,
    last_event_id: str | None = None,
    use_view_transition: bool = False,
) -> tuple[str | None, DatastarEvent | None]:
    """Render the view and build a patch event unless the output is unchanged.

    The awaited result of *render* is coerced to a string and hashed; the
    hash serves as the event id.

    Args:
        render: Zero-argument async callable producing the view.
        last_event_id: Id of the previously produced event, if any.
        use_view_transition: Forwarded to ``SSE.patch_elements``.

    Returns:
        ``(event_id, event)``. When the new hash equals *last_event_id* the
        pair is ``(last_event_id, None)``; otherwise it is the new hash and
        the event returned by ``SSE.patch_elements``.
    """
    markup = _coerce_html(await render())
    digest = _render_hash(markup)
    if digest != last_event_id:
        patch = SSE.patch_elements(
            markup,
            event_id=digest,
            use_view_transition=use_view_transition,
        )
        return digest, patch
    # Same content as last time: nothing to send.
    return last_event_id, None
async def render_stream(
    queue: asyncio.Queue[object],
    render: RenderFunction,
    *,
    last_event_id: str | None = None,
    render_on_connect: bool = True,
) -> AsyncGenerator[DatastarEvent, None]:
    """Yield patch events: optionally once up front, then after each queue item.

    Every item taken from *queue* triggers a re-render through
    ``render_sse_event``; renders whose output is unchanged yield nothing.
    An item equal to ``"queue-reordered"`` requests a view transition for
    that render. The generator loops forever.

    Args:
        queue: Source of trigger items; each ``get`` leads to one render.
        render: Zero-argument async callable producing the view.
        last_event_id: Id of the last event produced earlier, if any.
        render_on_connect: Whether to render once before waiting on *queue*.
    """
    should_render = render_on_connect
    with_transition = False  # The up-front render never uses a transition.
    while True:
        if should_render:
            last_event_id, patch = await render_sse_event(
                render,
                last_event_id=last_event_id,
                use_view_transition=with_transition,
            )
            if patch is not None:
                yield patch
        trigger = await queue.get()
        with_transition = trigger == "queue-reordered"
        should_render = True
def _coerce_html(view: RenderResult) -> str:
if isinstance(view, str):
return view
return view.__html__()
def _render_hash(html: str) -> str:
return hashlib.blake2s(html.encode("utf-8"), digest_size=16).hexdigest()