from __future__ import annotations import asyncio import hashlib from collections.abc import AsyncGenerator, Awaitable, Callable from contextlib import suppress from dataclasses import dataclass, field from datetime import UTC, datetime, timedelta from typing import Protocol from datastar_py import ServerSentEventGenerator as SSE from datastar_py.sse import DatastarEvent class HtmlRenderable(Protocol): def __html__(self) -> str: ... RenderResult = str | HtmlRenderable RenderFunction = Callable[[], Awaitable[RenderResult]] PageState = dict[str, object] TabState = dict[str, PageState] @dataclass class _TabSession: data: TabState = field(default_factory=dict) created_at: datetime = field(default_factory=lambda: datetime.now(UTC)) modified_at: datetime = field(default_factory=lambda: datetime.now(UTC)) connections: int = 0 class TabStateStore: def __init__(self, *, clean_age_threshold: timedelta = timedelta(hours=24)) -> None: self._sessions: dict[str, _TabSession] = {} self.clean_age_threshold = clean_age_threshold def connect(self, tab_id: str, *, now: datetime | None = None) -> TabState: session = self._sessions.get(tab_id) current_time = _now(now) if session is None: session = _TabSession(created_at=current_time, modified_at=current_time) self._sessions[tab_id] = session session.connections += 1 return session.data def disconnect(self, tab_id: str) -> None: session = self._sessions.get(tab_id) if session is None: return session.connections = max(0, session.connections - 1) if session.connections == 0: self._sessions.pop(tab_id, None) def get_tab_state(self, tab_id: str) -> TabState | None: session = self._sessions.get(tab_id) return None if session is None else session.data def get_page_state(self, tab_id: str | None, page_key: str) -> PageState: if tab_id is None: return {} session = self._sessions.get(tab_id) if session is None: return {} return session.data.get(page_key, {}) def update_page_state( self, tab_id: str, page_key: str, update: Callable[[PageState], PageState], *, now: datetime | None = None, ) -> PageState: current_time = _now(now) session = self._sessions.get(tab_id) if session is None: session = _TabSession(created_at=current_time, modified_at=current_time) self._sessions[tab_id] = session page_state = dict(session.data.get(page_key, {})) session.data[page_key] = update(page_state) session.modified_at = current_time return session.data[page_key] def cleanup_stale(self, *, now: datetime | None = None) -> set[str]: current_time = _now(now) removed: set[str] = set() for tab_id, session in tuple(self._sessions.items()): if current_time - session.modified_at < self.clean_age_threshold: continue self._sessions.pop(tab_id, None) removed.add(tab_id) return removed class RefreshBroker: def __init__(self) -> None: self._subscribers: dict[ asyncio.Queue[object], tuple[asyncio.AbstractEventLoop, str | None] ] = {} def subscribe(self, *, tab_id: str | None = None) -> asyncio.Queue[object]: queue: asyncio.Queue[object] = asyncio.Queue(maxsize=1) self._subscribers[queue] = (asyncio.get_running_loop(), tab_id) return queue def unsubscribe(self, queue: asyncio.Queue[object]) -> None: self._subscribers.pop(queue, None) def publish( self, event: object = "refresh-event", *, tab_id: str | None = None ) -> None: for queue, (loop, subscriber_tab_id) in tuple(self._subscribers.items()): if tab_id is not None and subscriber_tab_id != tab_id: continue loop.call_soon_threadsafe(_publish_event, queue, event) def _publish_event(queue: asyncio.Queue[object], event: object) -> None: if queue.full(): try: queue.get_nowait() except asyncio.QueueEmpty: pass try: queue.put_nowait(event) except asyncio.QueueFull: return async def render_sse_event( render: RenderFunction, *, last_event_id: str | None = None, use_view_transition: bool = False, ) -> tuple[str | None, DatastarEvent | None]: html = _coerce_html(await render()) event_id = _render_hash(html) if event_id == last_event_id: return last_event_id, None return event_id, SSE.patch_elements( html, event_id=event_id, use_view_transition=use_view_transition, ) async def render_stream( queue: asyncio.Queue[object], render: RenderFunction, *, last_event_id: str | None = None, render_on_connect: bool = True, shutdown_event: asyncio.Event | None = None, ) -> AsyncGenerator[DatastarEvent, None]: if render_on_connect: last_event_id, event = await render_sse_event( render, last_event_id=last_event_id ) if event is not None: yield event while True: if shutdown_event is None: event_name = await queue.get() else: if shutdown_event.is_set(): return queue_task = asyncio.create_task(queue.get()) shutdown_task = asyncio.create_task(shutdown_event.wait()) done, pending = await asyncio.wait( {queue_task, shutdown_task}, return_when=asyncio.FIRST_COMPLETED, ) for task in pending: task.cancel() for task in pending: with suppress(asyncio.CancelledError): await task if shutdown_task in done: with suppress(asyncio.CancelledError): await queue_task return event_name = queue_task.result() last_event_id, event = await render_sse_event( render, last_event_id=last_event_id, use_view_transition=event_name == "queue-reordered", ) if event is not None: yield event def _coerce_html(view: RenderResult) -> str: if isinstance(view, str): return view return view.__html__() def _render_hash(html: str) -> str: return hashlib.blake2s(html.encode("utf-8"), digest_size=16).hexdigest() def _now(now: datetime | None) -> datetime: return now or datetime.now(UTC)