98 lines
2.7 KiB
Python
98 lines
2.7 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import hashlib
|
|
from collections.abc import AsyncGenerator, Awaitable, Callable
|
|
from typing import Protocol
|
|
|
|
from datastar_py import ServerSentEventGenerator as SSE
|
|
from datastar_py.sse import DatastarEvent
|
|
|
|
|
|
class HtmlRenderable(Protocol):
    """Structural type for objects that can render themselves as HTML.

    Any object that exposes an ``__html__()`` method returning a string
    satisfies this protocol; no inheritance is required.
    """

    def __html__(self) -> str:
        """Return the HTML markup for this object."""
        ...
|
|
|
|
|
|
# What a render callable may produce: a ready-made HTML string, or an object
# exposing ``__html__()``.
RenderResult = str | HtmlRenderable
# A zero-argument callable whose awaited result is a RenderResult.
RenderFunction = Callable[[], Awaitable[RenderResult]]
|
|
|
|
|
|
class RefreshBroker:
    """Fan refresh notifications out to per-subscriber asyncio queues.

    Every subscriber receives its own single-slot queue, remembered together
    with the event loop that was running when it subscribed. ``publish``
    hands each event to that loop with ``call_soon_threadsafe`` rather than
    touching the queue directly.
    """

    def __init__(self) -> None:
        """Create a broker with no subscribers."""
        # Maps each subscriber queue to the loop it has to be filled from.
        self._subscribers: dict[
            asyncio.Queue[object], asyncio.AbstractEventLoop
        ] = {}

    def subscribe(self) -> asyncio.Queue[object]:
        """Register a new subscriber and return its queue.

        The queue holds at most one pending event and is bound to the
        currently running event loop.

        Raises:
            RuntimeError: If no event loop is running in the calling thread
                (raised by ``asyncio.get_running_loop``).
        """
        queue: asyncio.Queue[object] = asyncio.Queue(maxsize=1)
        self._subscribers[queue] = asyncio.get_running_loop()
        return queue

    def unsubscribe(self, queue: asyncio.Queue[object]) -> None:
        """Forget *queue*; unknown or already-removed queues are ignored."""
        self._subscribers.pop(queue, None)

    def publish(self, event: object = "refresh-event") -> None:
        """Schedule delivery of *event* to every current subscriber.

        Delivery is performed by ``_publish_event`` on the subscriber's own
        loop. Subscribers whose loop has been closed can never receive
        anything, so they are dropped instead of letting the resulting
        ``RuntimeError`` abort delivery to the remaining subscribers.

        Args:
            event: Payload placed on each subscriber queue.
        """
        # Iterate over a snapshot so subscribers can be pruned while looping.
        for queue, loop in tuple(self._subscribers.items()):
            if loop.is_closed():
                self._subscribers.pop(queue, None)
                continue
            try:
                loop.call_soon_threadsafe(_publish_event, queue, event)
            except RuntimeError:
                # The loop was closed between the check above and the call.
                self._subscribers.pop(queue, None)
|
|
|
|
|
|
def _publish_event(queue: asyncio.Queue[object], event: object) -> None:
    """Put *event* on *queue*, evicting one pending item if the queue is full.

    A full queue gives up its oldest item to make room for the new event.
    If the queue still cannot accept the event, the event is silently
    dropped.
    """
    needs_room = queue.full()
    if needs_room:
        # Make space by discarding the oldest pending item.
        try:
            queue.get_nowait()
        except asyncio.QueueEmpty:
            pass

    try:
        queue.put_nowait(event)
    except asyncio.QueueFull:
        # Best effort only: nothing more to do if there is still no room.
        pass
|
|
|
|
|
|
async def render_sse_event(
    render: RenderFunction,
    *,
    last_event_id: str | None = None,
    use_view_transition: bool = False,
) -> tuple[str | None, DatastarEvent | None]:
    """Render the view and build a patch event unless the output is unchanged.

    The rendered HTML is hashed and the hash doubles as the SSE event id.
    When that hash equals *last_event_id*, nothing new needs to be sent.

    Args:
        render: Awaitable callable producing the view to send.
        last_event_id: Hash of the most recently sent render, if any.
        use_view_transition: Forwarded unchanged to ``SSE.patch_elements``.

    Returns:
        ``(event_id, event)``. ``event`` is ``None`` when the render matches
        *last_event_id*; in that case ``event_id`` is *last_event_id*.
    """
    markup = _coerce_html(await render())
    digest = _render_hash(markup)

    if digest != last_event_id:
        patch = SSE.patch_elements(
            markup,
            event_id=digest,
            use_view_transition=use_view_transition,
        )
        return digest, patch

    # Identical output: report the unchanged id and no event.
    return last_event_id, None
|
|
|
|
|
|
async def render_stream(
    queue: asyncio.Queue[object],
    render: RenderFunction,
    *,
    last_event_id: str | None = None,
    render_on_connect: bool = True,
) -> AsyncGenerator[DatastarEvent, None]:
    """Yield a patch event each time a queued notification changes the render.

    Optionally renders once up front, then loops forever: wait for the next
    item on *queue*, re-render, and yield an event only if
    ``render_sse_event`` reports a change.

    Args:
        queue: Source of notifications; each item triggers one re-render.
        render: Awaitable callable producing the view to send.
        last_event_id: Id of the last event already delivered, if any.
        render_on_connect: Render immediately before waiting on the queue.

    Yields:
        Events returned by ``render_sse_event`` whenever it returns one.
    """
    current_id = last_event_id

    if render_on_connect:
        current_id, patch = await render_sse_event(
            render, last_event_id=current_id
        )
        if patch is not None:
            yield patch

    while True:
        trigger = await queue.get()
        # Only a "queue-reordered" notification asks for a view transition.
        wants_transition = trigger == "queue-reordered"
        current_id, patch = await render_sse_event(
            render,
            last_event_id=current_id,
            use_view_transition=wants_transition,
        )
        if patch is not None:
            yield patch
|
|
|
|
|
|
def _coerce_html(view: RenderResult) -> str:
    """Return *view* as an HTML string.

    Strings (including ``str`` subclasses) are returned as they are; any
    other object is asked for its markup through ``__html__()``.
    """
    return view if isinstance(view, str) else view.__html__()
|
|
|
|
|
|
def _render_hash(html: str) -> str:
    """Return a 16-byte BLAKE2s digest of *html* as a hex string.

    The text is encoded as UTF-8 before hashing.
    """
    hasher = hashlib.blake2s(digest_size=16)
    hasher.update(html.encode("utf-8"))
    return hasher.hexdigest()
|