2026-03-30 12:34:38 +02:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
import asyncio
|
|
|
|
|
import hashlib
|
|
|
|
|
from collections.abc import AsyncGenerator, Awaitable, Callable
|
2026-03-31 12:12:36 +02:00
|
|
|
from dataclasses import dataclass, field
|
|
|
|
|
from datetime import UTC, datetime, timedelta
|
2026-03-30 12:34:38 +02:00
|
|
|
from typing import Protocol
|
|
|
|
|
|
|
|
|
|
from datastar_py import ServerSentEventGenerator as SSE
|
|
|
|
|
from datastar_py.sse import DatastarEvent
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HtmlRenderable(Protocol):
    """Structural type for any object that exposes an ``__html__()`` method."""

    # Return the object's HTML markup as a string.
    def __html__(self) -> str: ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# What a render callable may produce: a plain HTML string, or any object
# that can hand back its markup through ``__html__()``.
RenderResult = str | HtmlRenderable

# Zero-argument callable whose awaited result is a RenderResult.
RenderFunction = Callable[[], Awaitable[RenderResult]]
|
2026-03-31 12:12:36 +02:00
|
|
|
# State of one page: arbitrary values keyed by string.
PageState = dict[str, object]

# State of one tab: a PageState per string key.
TabState = dict[str, PageState]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
class _TabSession:
|
|
|
|
|
data: TabState = field(default_factory=dict)
|
|
|
|
|
created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
|
|
|
|
|
modified_at: datetime = field(default_factory=lambda: datetime.now(UTC))
|
|
|
|
|
connections: int = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TabStateStore:
|
|
|
|
|
def __init__(self, *, clean_age_threshold: timedelta = timedelta(hours=24)) -> None:
|
|
|
|
|
self._sessions: dict[str, _TabSession] = {}
|
|
|
|
|
self.clean_age_threshold = clean_age_threshold
|
|
|
|
|
|
|
|
|
|
def connect(self, tab_id: str, *, now: datetime | None = None) -> TabState:
|
|
|
|
|
session = self._sessions.get(tab_id)
|
|
|
|
|
current_time = _now(now)
|
|
|
|
|
if session is None:
|
|
|
|
|
session = _TabSession(created_at=current_time, modified_at=current_time)
|
|
|
|
|
self._sessions[tab_id] = session
|
|
|
|
|
session.connections += 1
|
|
|
|
|
return session.data
|
|
|
|
|
|
|
|
|
|
def disconnect(self, tab_id: str) -> None:
|
|
|
|
|
session = self._sessions.get(tab_id)
|
|
|
|
|
if session is None:
|
|
|
|
|
return
|
|
|
|
|
session.connections = max(0, session.connections - 1)
|
|
|
|
|
if session.connections == 0:
|
|
|
|
|
self._sessions.pop(tab_id, None)
|
|
|
|
|
|
|
|
|
|
def get_tab_state(self, tab_id: str) -> TabState | None:
|
|
|
|
|
session = self._sessions.get(tab_id)
|
|
|
|
|
return None if session is None else session.data
|
|
|
|
|
|
|
|
|
|
def get_page_state(self, tab_id: str | None, page_key: str) -> PageState:
|
|
|
|
|
if tab_id is None:
|
|
|
|
|
return {}
|
|
|
|
|
session = self._sessions.get(tab_id)
|
|
|
|
|
if session is None:
|
|
|
|
|
return {}
|
|
|
|
|
return session.data.get(page_key, {})
|
|
|
|
|
|
|
|
|
|
def update_page_state(
|
|
|
|
|
self,
|
|
|
|
|
tab_id: str,
|
|
|
|
|
page_key: str,
|
|
|
|
|
update: Callable[[PageState], PageState],
|
|
|
|
|
*,
|
|
|
|
|
now: datetime | None = None,
|
|
|
|
|
) -> PageState:
|
|
|
|
|
current_time = _now(now)
|
|
|
|
|
session = self._sessions.get(tab_id)
|
|
|
|
|
if session is None:
|
|
|
|
|
session = _TabSession(created_at=current_time, modified_at=current_time)
|
|
|
|
|
self._sessions[tab_id] = session
|
|
|
|
|
|
|
|
|
|
page_state = dict(session.data.get(page_key, {}))
|
|
|
|
|
session.data[page_key] = update(page_state)
|
|
|
|
|
session.modified_at = current_time
|
|
|
|
|
return session.data[page_key]
|
|
|
|
|
|
|
|
|
|
def cleanup_stale(self, *, now: datetime | None = None) -> set[str]:
|
|
|
|
|
current_time = _now(now)
|
|
|
|
|
removed: set[str] = set()
|
|
|
|
|
for tab_id, session in tuple(self._sessions.items()):
|
|
|
|
|
if current_time - session.modified_at < self.clean_age_threshold:
|
|
|
|
|
continue
|
|
|
|
|
self._sessions.pop(tab_id, None)
|
|
|
|
|
removed.add(tab_id)
|
|
|
|
|
return removed
|
2026-03-30 12:34:38 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class RefreshBroker:
    """Fan out refresh notifications to per-subscriber asyncio queues.

    Each subscriber queue is stored together with the event loop that was
    running when it subscribed and an optional tab id. Delivery is scheduled
    onto that loop with ``call_soon_threadsafe``.
    """

    def __init__(self) -> None:
        # queue -> (loop the queue belongs to, tab id filter or None)
        self._subscribers: dict[
            asyncio.Queue[object], tuple[asyncio.AbstractEventLoop, str | None]
        ] = {}

    def subscribe(self, *, tab_id: str | None = None) -> asyncio.Queue[object]:
        """Register a new subscriber and return its queue.

        Must be called while an event loop is running; that loop is recorded
        for later deliveries. The queue holds at most one pending event.
        """
        queue: asyncio.Queue[object] = asyncio.Queue(maxsize=1)
        self._subscribers[queue] = (asyncio.get_running_loop(), tab_id)
        return queue

    def unsubscribe(self, queue: asyncio.Queue[object]) -> None:
        """Forget ``queue``; unknown queues are ignored."""
        self._subscribers.pop(queue, None)

    def publish(
        self, event: object = "refresh-event", *, tab_id: str | None = None
    ) -> None:
        """Schedule ``event`` for delivery to matching subscribers.

        With ``tab_id=None`` every subscriber is notified; otherwise only
        subscribers registered with exactly that tab id.

        A subscriber whose event loop has already been closed is dropped
        instead of aborting the publish, so one abandoned subscription
        cannot block delivery to the remaining subscribers.
        """
        # Iterate over a snapshot: the mapping may change while publishing.
        for queue, (loop, subscriber_tab_id) in tuple(self._subscribers.items()):
            if tab_id is not None and subscriber_tab_id != tab_id:
                continue
            try:
                loop.call_soon_threadsafe(_publish_event, queue, event)
            except RuntimeError:
                # call_soon_threadsafe raises RuntimeError on a closed loop;
                # nothing can ever read this queue again, so discard it.
                self._subscribers.pop(queue, None)


def _publish_event(queue: asyncio.Queue[object], event: object) -> None:
    """Put ``event`` on ``queue``, evicting a pending event if it is full.

    Keeps the newest event rather than the oldest, and never raises when
    the queue has no room.
    """
    if queue.full():
        try:
            queue.get_nowait()
        except asyncio.QueueEmpty:
            pass
    try:
        queue.put_nowait(event)
    except asyncio.QueueFull:
        return
|
2026-03-30 12:34:38 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
async def render_sse_event(
    render: RenderFunction,
    *,
    last_event_id: str | None = None,
    use_view_transition: bool = False,
) -> tuple[str | None, DatastarEvent | None]:
    """Render once and build a patch event only if the markup changed.

    The awaited render result is coerced to an HTML string and hashed; the
    hash doubles as the SSE event id.

    Args:
        render: Async callable producing the current view.
        last_event_id: Hash of the markup sent previously, if any.
        use_view_transition: Forwarded to ``SSE.patch_elements``.

    Returns:
        ``(event_id, event)``. When the hash equals ``last_event_id`` the
        result is ``(last_event_id, None)``.
    """
    rendered = await render()
    markup = _coerce_html(rendered)
    digest = _render_hash(markup)
    if digest != last_event_id:
        patch = SSE.patch_elements(
            markup,
            event_id=digest,
            use_view_transition=use_view_transition,
        )
        return digest, patch
    # Identical markup: nothing to send.
    return last_event_id, None
|
2026-03-30 12:34:38 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
async def render_stream(
    queue: asyncio.Queue[object],
    render: RenderFunction,
    *,
    last_event_id: str | None = None,
    render_on_connect: bool = True,
) -> AsyncGenerator[DatastarEvent, None]:
    """Yield a patch event every time a re-render produces new markup.

    Optionally renders once up front, then loops forever: wait for an item
    on ``queue``, re-render, and yield the resulting event unless the markup
    is unchanged. A queue item equal to ``"queue-reordered"`` requests a
    view transition for that render.

    Args:
        queue: Source of refresh triggers.
        render: Async callable producing the current view.
        last_event_id: Hash of the markup already sent, if any.
        render_on_connect: Render immediately before waiting on the queue.
    """
    if render_on_connect:
        last_event_id, initial = await render_sse_event(
            render, last_event_id=last_event_id
        )
        if initial is not None:
            yield initial

    while True:
        trigger = await queue.get()
        last_event_id, patch = await render_sse_event(
            render,
            last_event_id=last_event_id,
            use_view_transition=trigger == "queue-reordered",
        )
        if patch is None:
            # Markup unchanged since the last event; wait for the next trigger.
            continue
        yield patch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _coerce_html(view: RenderResult) -> str:
    """Turn a render result into an HTML string.

    Strings (including ``str`` subclasses) are returned unchanged; anything
    else is asked for its markup through ``__html__()``.
    """
    return view if isinstance(view, str) else view.__html__()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _render_hash(html: str) -> str:
    """Return the 16-byte BLAKE2s digest of ``html`` (UTF-8) as hex text."""
    payload = html.encode("utf-8")
    hasher = hashlib.blake2s(payload, digest_size=16)
    return hasher.hexdigest()
|
2026-03-31 12:12:36 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def _now(now: datetime | None) -> datetime:
|
|
|
|
|
return now or datetime.now(UTC)
|