"""Per-tab state storage, refresh fan-out, and hash-deduplicated Datastar SSE rendering."""

from __future__ import annotations

import asyncio
import hashlib
from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass, field
from datetime import UTC, datetime, timedelta
from typing import Protocol

from datastar_py import ServerSentEventGenerator as SSE
from datastar_py.sse import DatastarEvent


class HtmlRenderable(Protocol):
    """Structural type for objects that render themselves to HTML via ``__html__``."""

    def __html__(self) -> str: ...


# A render callback may return either a raw HTML string or an ``__html__`` object.
RenderResult = str | HtmlRenderable
RenderFunction = Callable[[], Awaitable[RenderResult]]
PageState = dict[str, object]
# One tab's state: page key -> that page's state dict.
TabState = dict[str, PageState]


@dataclass
class _TabSession:
    """Bookkeeping for one tab: its page states, timestamps and open connection count."""

    data: TabState = field(default_factory=dict)
    created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
    # Refreshed by TabStateStore.update_page_state; cleanup_stale measures age from it.
    modified_at: datetime = field(default_factory=lambda: datetime.now(UTC))
    # Number of connect() calls not yet balanced by disconnect().
    connections: int = 0


class TabStateStore:
    """In-memory store of per-tab, per-page state keyed by a tab id.

    Args:
        clean_age_threshold: Sessions whose ``modified_at`` is at least this old
            are removed by :meth:`cleanup_stale`.
    """

    def __init__(self, *, clean_age_threshold: timedelta = timedelta(hours=24)) -> None:
        self._sessions: dict[str, _TabSession] = {}
        self.clean_age_threshold = clean_age_threshold

    def connect(self, tab_id: str, *, now: datetime | None = None) -> TabState:
        """Register one more connection for ``tab_id``, creating its session if needed.

        Returns the session's live ``TabState`` dict (not a copy).
        """
        session = self._sessions.get(tab_id)
        current_time = _now(now)
        if session is None:
            session = _TabSession(created_at=current_time, modified_at=current_time)
            self._sessions[tab_id] = session
        session.connections += 1
        return session.data

    def disconnect(self, tab_id: str) -> None:
        """Release one connection for ``tab_id``.

        When the count reaches zero, the session and all its state are discarded.
        Unknown tab ids are ignored.
        """
        session = self._sessions.get(tab_id)
        if session is None:
            return
        # Clamp so an unmatched disconnect cannot drive the count negative.
        session.connections = max(0, session.connections - 1)
        if session.connections == 0:
            self._sessions.pop(tab_id, None)

    def get_tab_state(self, tab_id: str) -> TabState | None:
        """Return the live state dict for ``tab_id``, or ``None`` if there is no session."""
        session = self._sessions.get(tab_id)
        return None if session is None else session.data

    def get_page_state(self, tab_id: str | None, page_key: str) -> PageState:
        """Return the state of ``page_key`` within ``tab_id``.

        A fresh empty dict is returned when ``tab_id`` is ``None``, the tab is
        unknown, or the page has no state yet. Otherwise the stored dict itself
        is returned. Mutating it bypasses ``modified_at``; prefer
        :meth:`update_page_state`.
        """
        if tab_id is None:
            return {}
        session = self._sessions.get(tab_id)
        if session is None:
            return {}
        return session.data.get(page_key, {})

    def update_page_state(
        self,
        tab_id: str,
        page_key: str,
        update: Callable[[PageState], PageState],
        *,
        now: datetime | None = None,
    ) -> PageState:
        """Replace the state of ``page_key`` with ``update(copy_of_current_state)``.

        ``update`` receives a shallow copy of the current page state (empty if
        none) and its return value is stored as the new state. A session is
        created with zero connections if ``tab_id`` is unknown. Bumps
        ``modified_at`` and returns the stored new state.
        """
        current_time = _now(now)
        session = self._sessions.get(tab_id)
        if session is None:
            session = _TabSession(created_at=current_time, modified_at=current_time)
            self._sessions[tab_id] = session
        # Shallow copy so `update` cannot partially mutate the stored state in place.
        page_state = dict(session.data.get(page_key, {}))
        session.data[page_key] = update(page_state)
        session.modified_at = current_time
        return session.data[page_key]

    def cleanup_stale(self, *, now: datetime | None = None) -> set[str]:
        """Remove sessions not modified within ``clean_age_threshold``.

        Sessions are removed regardless of their connection count. Returns the
        set of removed tab ids.
        """
        current_time = _now(now)
        removed: set[str] = set()
        # Iterate over a snapshot because entries are popped during the loop.
        for tab_id, session in tuple(self._sessions.items()):
            if current_time - session.modified_at < self.clean_age_threshold:
                continue
            self._sessions.pop(tab_id, None)
            removed.add(tab_id)
        return removed


class RefreshBroker:
    """Fan refresh events out to subscriber queues, possibly living on different event loops."""

    def __init__(self) -> None:
        # queue -> (loop that owns the queue, optional tab id filter)
        self._subscribers: dict[
            asyncio.Queue[object], tuple[asyncio.AbstractEventLoop, str | None]
        ] = {}

    def subscribe(self, *, tab_id: str | None = None) -> asyncio.Queue[object]:
        """Create and register a queue bound to the currently running event loop.

        The queue holds at most one event. A newer event replaces a pending one,
        so bursts of publishes coalesce. Must be called from within a running loop.
        """
        queue: asyncio.Queue[object] = asyncio.Queue(maxsize=1)
        self._subscribers[queue] = (asyncio.get_running_loop(), tab_id)
        return queue

    def unsubscribe(self, queue: asyncio.Queue[object]) -> None:
        """Stop delivering events to ``queue``; unknown queues are ignored."""
        self._subscribers.pop(queue, None)

    def publish(
        self, event: object = "refresh-event", *, tab_id: str | None = None
    ) -> None:
        """Deliver ``event`` to subscribers.

        With ``tab_id=None`` every subscriber receives the event. Otherwise only
        subscribers registered with exactly that ``tab_id`` do; subscribers
        registered without a tab id are skipped. Delivery is scheduled on each
        subscriber's own loop via ``call_soon_threadsafe``, so this may be called
        from any thread.

        Subscribers whose event loop has already closed can never drain their
        queue. They are dropped here instead of aborting delivery to the rest.
        """
        # Snapshot: closed-loop subscribers may be removed while iterating.
        for queue, (loop, subscriber_tab_id) in tuple(self._subscribers.items()):
            if tab_id is not None and subscriber_tab_id != tab_id:
                continue
            try:
                loop.call_soon_threadsafe(_publish_event, queue, event)
            except RuntimeError:
                # call_soon_threadsafe raises RuntimeError on a closed loop; any
                # other RuntimeError is unexpected and must not be hidden.
                if not loop.is_closed():
                    raise
                self._subscribers.pop(queue, None)


def _publish_event(queue: asyncio.Queue[object], event: object) -> None:
    """Put ``event`` on ``queue``, evicting a pending event first if the queue is full.

    Runs on the queue's own loop (scheduled by :meth:`RefreshBroker.publish`).
    """
    if queue.full():
        try:
            queue.get_nowait()
        except asyncio.QueueEmpty:
            pass
    try:
        queue.put_nowait(event)
    except asyncio.QueueFull:
        # Best effort: the queue already holds an event that will trigger a render.
        return


async def render_sse_event(
    render: RenderFunction,
    *,
    last_event_id: str | None = None,
    use_view_transition: bool = False,
) -> tuple[str | None, DatastarEvent | None]:
    """Render the view and build a Datastar patch event unless the HTML is unchanged.

    The event id is a content hash of the rendered HTML. Returns
    ``(event_id, event)``, or ``(last_event_id, None)`` when the hash matches
    ``last_event_id`` (nothing new to send).
    """
    html = _coerce_html(await render())
    event_id = _render_hash(html)
    if event_id == last_event_id:
        return last_event_id, None
    return event_id, SSE.patch_elements(
        html,
        event_id=event_id,
        use_view_transition=use_view_transition,
    )


async def render_stream(
    queue: asyncio.Queue[object],
    render: RenderFunction,
    *,
    last_event_id: str | None = None,
    render_on_connect: bool = True,
) -> AsyncGenerator[DatastarEvent, None]:
    """Yield a patch event for every refresh signal on ``queue`` whose render differs.

    When ``render_on_connect`` is true, an initial render is attempted before
    waiting on the queue. A ``"queue-reordered"`` signal requests a view
    transition for that patch. Runs until the consumer stops iterating.
    """
    if render_on_connect:
        last_event_id, event = await render_sse_event(
            render, last_event_id=last_event_id
        )
        if event is not None:
            yield event
    while True:
        event_name = await queue.get()
        last_event_id, event = await render_sse_event(
            render,
            last_event_id=last_event_id,
            use_view_transition=event_name == "queue-reordered",
        )
        if event is not None:
            yield event


def _coerce_html(view: RenderResult) -> str:
    """Return ``view`` unchanged if it is a ``str``, else the result of ``view.__html__()``."""
    if isinstance(view, str):
        return view
    return view.__html__()


def _render_hash(html: str) -> str:
    """Return a 32-character hex BLAKE2s digest (16 bytes) of the UTF-8 encoded ``html``."""
    return hashlib.blake2s(html.encode("utf-8"), digest_size=16).hexdigest()


def _now(now: datetime | None) -> datetime:
    """Return ``now`` if given, otherwise the current time as an aware UTC datetime."""
    # Explicit None check: only a missing argument should fall back to the clock.
    return datetime.now(UTC) if now is None else now