From 3f33994cdcd327601c705e6fb07a38b0789953c0 Mon Sep 17 00:00:00 2001 From: Abel Luck Date: Tue, 31 Mar 2026 13:03:25 +0200 Subject: [PATCH] Clean up cancelled SSE wait tasks --- repub/datastar.py | 31 +++++++++++++++++-------------- tests/test_web.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 14 deletions(-) diff --git a/repub/datastar.py b/repub/datastar.py index 49b95c0..53566d0 100644 --- a/repub/datastar.py +++ b/repub/datastar.py @@ -169,20 +169,23 @@ async def render_stream( return queue_task = asyncio.create_task(queue.get()) shutdown_task = asyncio.create_task(shutdown_event.wait()) - done, pending = await asyncio.wait( - {queue_task, shutdown_task}, - return_when=asyncio.FIRST_COMPLETED, - ) - for task in pending: - task.cancel() - for task in pending: - with suppress(asyncio.CancelledError): - await task - if shutdown_task in done: - with suppress(asyncio.CancelledError): - await queue_task - return - event_name = queue_task.result() + try: + done, _pending = await asyncio.wait( + {queue_task, shutdown_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + if shutdown_task in done: + return + event_name = queue_task.result() + finally: + for task in (queue_task, shutdown_task): + if not task.done(): + task.cancel() + for task in (queue_task, shutdown_task): + if task.done() and not task.cancelled(): + continue + with suppress(asyncio.CancelledError): + await task last_event_id, event = await render_sse_event( render, last_event_id=last_event_id, diff --git a/tests/test_web.py b/tests/test_web.py index f9747a5..0b9098a 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -575,6 +575,39 @@ def test_render_stream_stops_when_shutdown_is_requested() -> None: asyncio.run(run()) +def test_render_stream_cleans_up_child_tasks_when_cancelled() -> None: + async def run() -> None: + queue = RefreshBroker().subscribe() + shutdown_event = asyncio.Event() + + async def render() -> str: + return '
queue
' + + stream = render_stream( + queue, + render, + render_on_connect=False, + shutdown_event=shutdown_event, + ) + next_event = asyncio.create_task(anext(stream)) + await asyncio.sleep(0) + next_event.cancel() + + with pytest.raises(asyncio.CancelledError): + await next_event + + await asyncio.sleep(0) + + pending = tuple( + task + for task in asyncio.all_tasks() + if task is not asyncio.current_task() and not task.done() + ) + assert pending == () + + asyncio.run(run()) + + def test_render_dashboard_shows_dashboard_information_architecture( monkeypatch, tmp_path: Path ) -> None: