86 lines
2.4 KiB
Python
86 lines
2.4 KiB
Python
|
|
from __future__ import annotations

import asyncio
import hashlib
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Protocol, TypeAlias

from datastar_py import ServerSentEventGenerator as SSE
from datastar_py.sse import DatastarEvent
|
||
|
|
|
||
|
|
|
||
|
|
class HtmlRenderable(Protocol):
    """Structural type for objects that can render themselves as HTML.

    Any object providing a ``__html__()`` method returning ``str`` matches.
    """

    def __html__(self) -> str:
        """Return this object's HTML markup."""
        ...
|
||
|
|
|
||
|
|
|
||
|
|
# What a render callback may produce: raw markup (``str``) or any object that
# exposes ``__html__`` (see ``HtmlRenderable``).
# Declared with an explicit ``TypeAlias`` (PEP 613) so type checkers treat these
# as aliases rather than ordinary module-level variables.
RenderResult: TypeAlias = str | HtmlRenderable
# Zero-argument callable whose return value is awaited to obtain a RenderResult.
RenderFunction: TypeAlias = Callable[[], Awaitable[RenderResult]]
|
||
|
|
|
||
|
|
|
||
|
|
class RefreshBroker:
    """Broadcast refresh notifications to every subscribed queue.

    Each subscriber owns a single-slot ``asyncio.Queue``. Publishing while a
    previous notification is still unconsumed replaces it. A slow consumer
    therefore holds at most one pending event (the most recent one) rather
    than a backlog.
    """

    def __init__(self) -> None:
        # Live subscriber queues; a set makes unsubscribe a cheap discard.
        self._subscribers: set[asyncio.Queue[object]] = set()

    def subscribe(self) -> asyncio.Queue[object]:
        """Register and return a new single-slot queue for published events."""
        inbox: asyncio.Queue[object] = asyncio.Queue(maxsize=1)
        self._subscribers.add(inbox)
        return inbox

    def unsubscribe(self, queue: asyncio.Queue[object]) -> None:
        """Stop delivering events to ``queue``; unknown queues are ignored."""
        self._subscribers.discard(queue)

    def publish(self, event: object = "refresh-event") -> None:
        """Deliver ``event`` to every subscriber without blocking.

        If a subscriber's slot is already occupied, the stale event is
        dropped and replaced by ``event``.
        """
        # Defensive copy of the subscriber set before iterating.
        snapshot = tuple(self._subscribers)
        for inbox in snapshot:
            if inbox.full():
                try:
                    inbox.get_nowait()  # discard the stale, unconsumed event
                except asyncio.QueueEmpty:
                    pass
            try:
                inbox.put_nowait(event)
            except asyncio.QueueFull:
                pass  # best effort: skip this subscriber
|
||
|
|
|
||
|
|
|
||
|
|
async def render_sse_event(
    render: RenderFunction, *, last_event_id: str | None = None
) -> tuple[str | None, DatastarEvent | None]:
    """Render once and build a patch-elements event if the output changed.

    The markup is fingerprinted with ``_render_hash`` and the fingerprint is
    used as the event id.

    Returns:
        ``(new_id, event)`` when the fingerprint differs from
        ``last_event_id``. ``event`` comes from ``SSE.patch_elements``.
        ``(last_event_id, None)`` when it matches, meaning nothing needs to
        be sent.
    """
    markup = _coerce_html(await render())
    digest = _render_hash(markup)
    if digest != last_event_id:
        return digest, SSE.patch_elements(markup, event_id=digest)
    # Same content as what the client already has.
    return last_event_id, None
|
||
|
|
|
||
|
|
|
||
|
|
async def render_stream(
    queue: asyncio.Queue[object],
    render: RenderFunction,
    *,
    last_event_id: str | None = None,
    render_on_connect: bool = True,
) -> AsyncGenerator[DatastarEvent, None]:
    """Yield patch events for ``render``, re-rendering on every queue item.

    The values taken from ``queue`` are ignored; each one only triggers a
    re-render. An event is yielded only when the rendered output's hash
    differs from the previously sent one. The generator loops forever.

    Args:
        queue: Queue whose items trigger a re-render.
        render: Callable whose awaited result is the markup to send.
        last_event_id: Hash of the markup the client already has, if any.
        render_on_connect: If true, render once immediately before waiting
            for the first queue item.
    """
    skip_wait = render_on_connect
    while True:
        if skip_wait:
            # First pass: render immediately without waiting for a trigger.
            skip_wait = False
        else:
            await queue.get()
        last_event_id, event = await render_sse_event(
            render, last_event_id=last_event_id
        )
        if event is not None:
            yield event
|
||
|
|
|
||
|
|
|
||
|
|
def _coerce_html(view: RenderResult) -> str:
    """Return the markup for a render result.

    Strings, including ``str`` subclasses, are returned unchanged. Any other
    value is converted by calling its ``__html__()`` method.
    """
    if not isinstance(view, str):
        return view.__html__()
    return view
|
||
|
|
|
||
|
|
|
||
|
|
def _render_hash(html: str) -> str:
    """Return a 32-hex-character BLAKE2s (16-byte) digest of ``html``'s UTF-8 bytes."""
    hasher = hashlib.blake2s(digest_size=16)
    hasher.update(html.encode("utf-8"))
    return hasher.hexdigest()
|