# republisher/repub/datastar.py
#
# 209 lines
# 6.7 KiB
# Python

from __future__ import annotations
import asyncio
import hashlib
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import suppress
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from typing import Protocol
from datastar_py import ServerSentEventGenerator as SSE
from datastar_py.sse import DatastarEvent
class HtmlRenderable(Protocol):
    """Structural type for objects that render themselves to HTML.

    Any object with an ``__html__()`` method returning ``str`` matches this
    protocol; no inheritance is required.
    """

    def __html__(self) -> str:
        """Return this object's HTML markup."""
# State held per tab: page key -> that page's state mapping.
PageState = dict[str, object]
TabState = dict[str, PageState]
# What a view renderer may produce, and the zero-argument coroutine
# function type that produces it.
RenderResult = str | HtmlRenderable
RenderFunction = Callable[[], Awaitable[RenderResult]]
@dataclass
class _TabSession:
data: TabState = field(default_factory=dict)
created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
modified_at: datetime = field(default_factory=lambda: datetime.now(UTC))
connections: int = 0
class TabStateStore:
    """In-memory per-tab state, keyed by tab id and then by page key.

    A session is created by the first ``connect`` (or the first
    ``update_page_state``) for a tab id and is reference-counted by
    ``connect``/``disconnect``. It is removed when its connection count drops
    to zero, or by ``cleanup_stale`` once it has gone unmodified for at least
    ``clean_age_threshold``.
    """

    def __init__(self, *, clean_age_threshold: timedelta = timedelta(hours=24)) -> None:
        self._sessions: dict[str, _TabSession] = {}
        # cleanup_stale purges sessions whose modified_at is at least this old.
        self.clean_age_threshold = clean_age_threshold

    def connect(self, tab_id: str, *, now: datetime | None = None) -> TabState:
        """Register one connection for ``tab_id`` and return its state mapping.

        Creates an empty session stamped with ``now`` (default: current UTC
        time) if the tab is unknown. The returned mapping is the store's own
        object, not a copy. Connecting does not touch ``modified_at``.
        """
        session = self._sessions.get(tab_id)
        current_time = _now(now)
        if session is None:
            session = _TabSession(created_at=current_time, modified_at=current_time)
            self._sessions[tab_id] = session
        session.connections += 1
        return session.data

    def disconnect(self, tab_id: str) -> None:
        """Release one connection for ``tab_id``; unknown ids are ignored.

        When the count reaches zero, the session is discarded immediately,
        including all of its page state.
        """
        # NOTE(review): since state is dropped at zero connections, a
        # reconnect that arrives after the last disconnect starts from empty
        # state. Confirm this is intended.
        session = self._sessions.get(tab_id)
        if session is None:
            return
        session.connections = max(0, session.connections - 1)
        if session.connections == 0:
            self._sessions.pop(tab_id, None)

    def get_tab_state(self, tab_id: str) -> TabState | None:
        """Return the live state mapping for ``tab_id``, or None if unknown."""
        session = self._sessions.get(tab_id)
        return None if session is None else session.data

    def get_page_state(self, tab_id: str | None, page_key: str) -> PageState:
        """Return a snapshot of the state stored for ``page_key`` in ``tab_id``.

        Returns an empty dict when ``tab_id`` is None or unknown, or when the
        page has no stored state. The result is always a fresh shallow copy,
        so mutating it never changes the store. Write through
        ``update_page_state`` instead.
        """
        if tab_id is None:
            return {}
        session = self._sessions.get(tab_id)
        if session is None:
            return {}
        # Copy so an existing page and a missing page behave the same for
        # callers that mutate the result: neither leaks into the store.
        return dict(session.data.get(page_key, {}))

    def update_page_state(
        self,
        tab_id: str,
        page_key: str,
        update: Callable[[PageState], PageState],
        *,
        now: datetime | None = None,
    ) -> PageState:
        """Replace ``page_key``'s state for ``tab_id`` with ``update(copy)``.

        ``update`` receives a shallow copy of the current page state (empty if
        there is none), and its return value becomes the stored state. The tab
        session is created (with zero connections) if needed, and
        ``modified_at`` is stamped with ``now`` (default: current UTC time).
        Returns the newly stored mapping.
        """
        current_time = _now(now)
        session = self._sessions.get(tab_id)
        if session is None:
            session = _TabSession(created_at=current_time, modified_at=current_time)
            self._sessions[tab_id] = session
        page_state = dict(session.data.get(page_key, {}))
        session.data[page_key] = update(page_state)
        session.modified_at = current_time
        return session.data[page_key]

    def cleanup_stale(self, *, now: datetime | None = None) -> set[str]:
        """Remove sessions left unmodified for at least ``clean_age_threshold``.

        ``now`` defaults to the current UTC time. Returns the set of removed
        tab ids.
        """
        # NOTE(review): sessions with live connections are removed too. A later
        # disconnect for such a tab is then a no-op, or, if the tab has
        # reconnected in the meantime, it decrements the *new* session's
        # count. Confirm whether connected sessions should be skipped.
        current_time = _now(now)
        removed: set[str] = set()
        for tab_id, session in tuple(self._sessions.items()):
            if current_time - session.modified_at < self.clean_age_threshold:
                continue
            self._sessions.pop(tab_id, None)
            removed.add(tab_id)
        return removed
class RefreshBroker:
    """Fan out refresh notifications to per-subscriber asyncio queues.

    Each subscriber queue is stored together with the event loop it was
    created on. ``publish`` schedules delivery onto that loop with
    ``call_soon_threadsafe``, so it may be called from any thread.
    """

    def __init__(self) -> None:
        # queue -> (owning event loop, subscriber's tab id or None)
        self._subscribers: dict[
            asyncio.Queue[object], tuple[asyncio.AbstractEventLoop, str | None]
        ] = {}

    def subscribe(self, *, tab_id: str | None = None) -> asyncio.Queue[object]:
        """Create and register a single-slot queue bound to the running loop.

        Must be called while an event loop is running (``get_running_loop``
        raises otherwise). A subscriber with a ``tab_id`` receives publishes
        for that tab plus untargeted publishes. A subscriber without one
        receives only untargeted publishes.
        """
        queue: asyncio.Queue[object] = asyncio.Queue(maxsize=1)
        self._subscribers[queue] = (asyncio.get_running_loop(), tab_id)
        return queue

    def unsubscribe(self, queue: asyncio.Queue[object]) -> None:
        """Stop delivering to ``queue``; unknown queues are ignored."""
        self._subscribers.pop(queue, None)

    def publish(
        self, event: object = "refresh-event", *, tab_id: str | None = None
    ) -> None:
        """Deliver ``event`` to matching subscribers (all of them if ``tab_id`` is None).

        A subscriber whose event loop has already closed is dropped, so it
        cannot abort delivery to the remaining subscribers.
        """
        for queue, (loop, subscriber_tab_id) in tuple(self._subscribers.items()):
            if tab_id is not None and subscriber_tab_id != tab_id:
                continue
            if loop.is_closed():
                # The subscriber outlived its loop without unsubscribing.
                self._subscribers.pop(queue, None)
                continue
            try:
                loop.call_soon_threadsafe(_publish_event, queue, event)
            except RuntimeError:
                # The loop closed between the check above and this call.
                self._subscribers.pop(queue, None)
def _publish_event(queue: asyncio.Queue[object], event: object) -> None:
    """Enqueue ``event``, evicting the oldest pending item if ``queue`` is full.

    This keeps the most recent notification instead of dropping it.
    """
    if queue.full():
        # Coalesce: discard the stale pending item to make room.
        with suppress(asyncio.QueueEmpty):
            queue.get_nowait()
    with suppress(asyncio.QueueFull):
        queue.put_nowait(event)
async def render_sse_event(
    render: RenderFunction,
    *,
    last_event_id: str | None = None,
    use_view_transition: bool = False,
) -> tuple[str | None, DatastarEvent | None]:
    """Render the view and build a patch-elements event if the output changed.

    The event id is a content hash of the rendered HTML. If it equals
    ``last_event_id``, ``(last_event_id, None)`` is returned. Otherwise the
    result is ``(new_id, event)``.
    """
    markup = _coerce_html(await render())
    digest = _render_hash(markup)
    if digest != last_event_id:
        patch = SSE.patch_elements(
            markup,
            event_id=digest,
            use_view_transition=use_view_transition,
        )
        return digest, patch
    return last_event_id, None
async def render_stream(
    queue: asyncio.Queue[object],
    render: RenderFunction,
    *,
    last_event_id: str | None = None,
    render_on_connect: bool = True,
    shutdown_event: asyncio.Event | None = None,
) -> AsyncGenerator[DatastarEvent, None]:
    """Yield Datastar patch events whenever a re-render changes the output.

    Args:
        queue: Notification queue (presumably one returned by
            ``RefreshBroker.subscribe``). Each item received triggers one
            re-render.
        render: Zero-argument coroutine function producing the view.
        last_event_id: Content-hash event id the client already has. A
            render producing the same id is not yielded.
        render_on_connect: If true, render once before waiting on ``queue``.
        shutdown_event: Optional event. Once it is set, the generator returns
            instead of waiting for further notifications.

    Yields:
        Patch-elements events for renders whose id differs from the last
        id seen. A notification equal to ``"queue-reordered"`` requests a
        view transition.
    """
    if render_on_connect:
        # Initial render; no event is produced if it matches last_event_id.
        last_event_id, event = await render_sse_event(
            render, last_event_id=last_event_id
        )
        if event is not None:
            yield event
    while True:
        if shutdown_event is None:
            # No shutdown signal configured: block until the next notification.
            event_name = await queue.get()
        else:
            # Honor a shutdown requested before this wait began.
            if shutdown_event.is_set():
                return
            # Race the next notification against the shutdown signal.
            queue_task = asyncio.create_task(queue.get())
            shutdown_task = asyncio.create_task(shutdown_event.wait())
            try:
                done, _pending = await asyncio.wait(
                    {queue_task, shutdown_task},
                    return_when=asyncio.FIRST_COMPLETED,
                )
                # Shutdown takes precedence. A notification dequeued in the
                # same round is discarded.
                if shutdown_task in done:
                    return
                event_name = queue_task.result()
            finally:
                # Cancel whichever task is still pending (both of them if
                # this coroutine was interrupted while waiting).
                for task in (queue_task, shutdown_task):
                    if not task.done():
                        task.cancel()
                # Await every task that did not finish normally so none is
                # left pending; the resulting CancelledError is suppressed.
                # NOTE(review): this suppress could also absorb a
                # cancellation delivered to this coroutine during the await.
                for task in (queue_task, shutdown_task):
                    if task.done() and not task.cancelled():
                        continue
                    with suppress(asyncio.CancelledError):
                        await task
        # Re-render; "queue-reordered" notifications request a view transition.
        last_event_id, event = await render_sse_event(
            render,
            last_event_id=last_event_id,
            use_view_transition=event_name == "queue-reordered",
        )
        if event is not None:
            yield event
def _coerce_html(view: RenderResult) -> str:
    """Return ``view`` unchanged if it is a ``str``; otherwise return ``view.__html__()``."""
    return view if isinstance(view, str) else view.__html__()
def _render_hash(html: str) -> str:
    """Return the hex BLAKE2s digest (16-byte digest) of ``html`` encoded as UTF-8."""
    hasher = hashlib.blake2s(digest_size=16)
    hasher.update(html.encode("utf-8"))
    return hasher.hexdigest()
def _now(now: datetime | None) -> datetime:
return now or datetime.now(UTC)