diff --git a/repub/web.py b/repub/web.py index b2783e5..c76d0d2 100644 --- a/repub/web.py +++ b/repub/web.py @@ -4,7 +4,7 @@ import asyncio import hashlib from collections.abc import AsyncGenerator, Awaitable, Callable from contextlib import suppress -from datetime import timedelta +from datetime import UTC, datetime, timedelta from pathlib import Path from typing import TypedDict, cast from urllib.parse import urlparse @@ -590,7 +590,23 @@ def _load_sidebar_counts(app: Quart) -> dict[str, int]: def _rss_feed_response(feed_text: str | None) -> Response: if feed_text is None: return Response(status=404) - return Response(feed_text, mimetype="application/rss+xml") + etag = hashlib.sha256(feed_text.encode("utf-8")).hexdigest() + if request.if_none_match.contains(etag): + response = Response(status=304) + else: + response = Response( + feed_text, + content_type="application/rss+xml; charset=utf-8", + ) + response.set_etag(etag) + response.cache_control.public = True + response.cache_control.max_age = 300 + response.expires = datetime.now(UTC) + timedelta(minutes=5) + response.vary.add("Host") + response.headers["Access-Control-Allow-Origin"] = "*" + response.headers["Access-Control-Allow-Methods"] = "GET, HEAD, OPTIONS" + response.headers["Access-Control-Allow-Headers"] = "*" + return response def _read_feed_text(*, feeds_dir: Path, feed_path: str) -> str | None: diff --git a/tests/test_dev_mode.py b/tests/test_dev_mode.py index ae84740..0e68f34 100644 --- a/tests/test_dev_mode.py +++ b/tests/test_dev_mode.py @@ -133,5 +133,42 @@ def test_published_rss_rewrites_feed_url_to_https_host_header( "https://example.com/article" "\n" ) + assert response.headers["Access-Control-Allow-Origin"] == "*" + assert response.headers["Access-Control-Allow-Methods"] == "GET, HEAD, OPTIONS" + assert response.headers["Access-Control-Allow-Headers"] == "*" + assert response.cache_control.public is True + assert response.cache_control.max_age == 300 + assert response.headers["ETag"] != "" + + asyncio.run(run()) + + +def test_published_rss_supports_conditional_requests( + monkeypatch, tmp_path: Path +) -> None: + db_path = tmp_path / "conditional-rss.db" + feeds_dir = tmp_path / "out" / "feeds" + monkeypatch.setenv("REPUBLISHER_DB_PATH", str(db_path)) + + async def run() -> None: + app = create_app() + app.config["REPUB_FEEDS_DIR"] = feeds_dir + feed_path = feeds_dir / "demo-source" / "feed.rss" + feed_path.parent.mkdir(parents=True) + feed_path.write_text( + "Demo\n", encoding="utf-8" + ) + + client = app.test_client() + first_response = await client.get("/feeds/demo-source/feed.rss") + etag = first_response.headers["ETag"] + second_response = await client.get( + "/feeds/demo-source/feed.rss", + headers={"If-None-Match": etag}, + ) + + assert second_response.status_code == 304 + assert await second_response.get_data(as_text=True) == "" + assert second_response.headers["ETag"] == etag asyncio.run(run())