"""Scrapy item exporter that assembles exported items into one RSS document."""

import logging
from io import BytesIO
from typing import Any

from lxml.etree import QName
from scrapy.exporters import BaseItemExporter

from repub import rss
from repub.items import (
    ChannelElementItem,
    ElementItem,
    MediaVariant,
    TranscodedMediaFile,
)
from repub.utils import FileType, determine_file_type

logger = logging.getLogger(__name__)

MEDIA_CONTENT_TAG = QName(rss.nsmap["media"], "content").text
MEDIA_GROUP_TAG = QName(rss.nsmap["media"], "group").text

# Attributes copied from a stripped original <media:content> onto the
# <media:content> elements that replace it (variant values take precedence).
_FALLBACK_MEDIA_ATTRIBS = frozenset({"expression", "lang"})


class RssExporter(BaseItemExporter):
    """Collect exported items into a single RSS tree and write it on finish.

    A ``ChannelElementItem`` supplies the channel element. Every other item
    has its element appended to that channel. Items that arrive before the
    channel are buffered and flushed as soon as the channel is seen. Before an
    item is appended, references to its transcoded audio/video files (see
    ``apply_transcoded_media``) are rewritten to point at the transcoded
    variants.
    """

    def __init__(self, file: BytesIO, **kwargs: Any):
        """Create the exporter.

        :param file: binary stream the serialized RSS is written to by
            ``finish_exporting``.
        :param kwargs: passed through to ``BaseItemExporter``.
        """
        super().__init__(**kwargs)
        if not self.encoding:
            self.encoding = "utf-8"
        self.file: BytesIO = file
        self.rss = rss.rss()
        self.channel = None
        # Items exported before the channel element is known.
        self.item_buffer: list[Any] = []

    def start_exporting(self) -> None:
        """Do nothing; the whole document is written in ``finish_exporting``."""
        pass

    def export_item(self, item: Any):
        """Add ``item`` to the RSS tree, or buffer it until a channel exists.

        A ``ChannelElementItem`` becomes the channel. Its element is appended
        to the root, and any buffered items are flushed into it.
        """
        if isinstance(item, ChannelElementItem):
            self.channel = item.el
            self.rss.append(item.el)
            self.flush_buffer()
            return
        if self.channel is None:
            self.item_buffer.append(item)
        else:
            self.export_rss_item(item)

    def flush_buffer(self):
        """Export buffered items in arrival order, then empty the buffer."""
        for item in self.item_buffer:
            self.export_rss_item(item)
        self.item_buffer = []

    def compact_attrib(self, **attrib):
        """Return the keyword arguments stringified, dropping ``None`` and ``""`` values."""
        return {
            key: str(value)
            for key, value in attrib.items()
            if value not in (None, "")
        }

    def canonical_variant(self, media_file: TranscodedMediaFile) -> MediaVariant | None:
        """Pick the variant that should represent ``media_file``.

        Returns the first variant whose ``isDefault`` is the string ``"true"``.
        Otherwise it returns the first variant, or ``None`` when there are no
        variants.
        """
        variants = media_file["variants"]
        for variant in variants:
            if variant.get("isDefault") == "true":
                return variant
        if variants:
            return variants[0]
        return None

    def rebuild_enclosures(self, item: ElementItem) -> None:
        """Point ``<enclosure>`` children at the canonical transcoded audio variant.

        An enclosure is rewritten when its ``url`` equals the
        ``published_url`` of one of ``item.audios`` and that file has a
        canonical variant. The new ``url``, ``length`` and ``type`` come from
        the variant. ``length``/``type`` fall back to the enclosure's original
        values when the variant has no ``fileSize``/``type``. All other
        attributes on the enclosure are removed.
        """
        audio_lookup = {audio["published_url"]: audio for audio in item.audios}
        for enclosure in item.el.findall("enclosure"):
            media_file = audio_lookup.get(enclosure.get("url", ""))
            if media_file is None:
                continue
            canonical = self.canonical_variant(media_file)
            if canonical is None:
                continue
            # Build the replacement attributes *before* clearing: the fallbacks
            # read the enclosure's current length/type, which clear() erases.
            new_attrib = self.compact_attrib(
                url=canonical.get("url"),
                length=canonical.get("fileSize") or enclosure.get("length"),
                type=canonical.get("type") or enclosure.get("type"),
            )
            enclosure.attrib.clear()
            enclosure.attrib.update(new_attrib)

    def owned_media_type(self, el, managed_types: set[FileType]) -> FileType | None:
        """Return the file type of media element ``el`` if it is one we manage.

        The type is determined from the element's ``url``, ``medium`` and
        ``type`` attributes via ``determine_file_type``. ``None`` is returned
        when that type is not in ``managed_types``.
        """
        file_type = determine_file_type(
            url=el.get("url", ""),
            medium=el.get("medium"),
            mimetype=el.get("type"),
        )
        if file_type in managed_types:
            return file_type
        return None

    def _fallback_attrib(self, media_content) -> dict[str, str]:
        """Return the attributes of ``media_content`` to carry over to its replacement."""
        return {
            key: value
            for key, value in media_content.attrib.items()
            if key in _FALLBACK_MEDIA_ATTRIBS
        }

    def strip_managed_media_nodes(self, item: ElementItem) -> dict[str, dict[str, str]]:
        """Remove ``<media:content>`` elements whose media type this item manages.

        The managed types are AUDIO when ``item.audios`` is non-empty and
        VIDEO when ``item.videos`` is non-empty. Both direct
        ``<media:content>`` children and those nested in a ``<media:group>``
        are considered. A ``<media:group>`` left with no children afterwards is
        removed too.

        NOTE(review): every managed-type element is removed, including ones
        whose url matches no transcoded file's ``published_url``. Confirm this
        is intended.

        :returns: mapping of each removed element's ``url`` to its preserved
            ``expression``/``lang`` attributes.
        """
        managed_types: set[FileType] = set()
        if item.audios:
            managed_types.add(FileType.AUDIO)
        if item.videos:
            managed_types.add(FileType.VIDEO)
        fallbacks: dict[str, dict[str, str]] = {}
        if not managed_types:
            return fallbacks
        # Iterate over snapshots because elements are removed while walking.
        for child in list(item.el):
            if child.tag == MEDIA_CONTENT_TAG:
                if self.owned_media_type(child, managed_types) is not None:
                    fallbacks[child.get("url", "")] = self._fallback_attrib(child)
                    item.el.remove(child)
                continue
            if child.tag != MEDIA_GROUP_TAG:
                continue
            for media_content in list(child):
                if media_content.tag != MEDIA_CONTENT_TAG:
                    continue
                if self.owned_media_type(media_content, managed_types) is None:
                    continue
                fallbacks[media_content.get("url", "")] = self._fallback_attrib(
                    media_content
                )
                child.remove(media_content)
            # Drop groups that are now empty (this includes groups that were
            # already empty).
            if len(child) == 0:
                item.el.remove(child)
        return fallbacks

    def append_media_groups(
        self, item: ElementItem, fallbacks: dict[str, dict[str, str]]
    ):
        """Append a ``<media:group>`` per transcoded audio/video file with variants.

        Each group holds one ``<media:content>`` per variant. Attributes come
        from ``media_content_attrib``, seeded with the ``fallbacks`` entry for
        the file's ``published_url``.
        """
        for media_file in [*item.audios, *item.videos]:
            if not media_file["variants"]:
                continue
            fallback_attrib = fallbacks.get(media_file["published_url"], {})
            group = rss.MEDIA.group(
                *[
                    rss.MEDIA.content(
                        **self.media_content_attrib(variant, fallback_attrib)
                    )
                    for variant in media_file["variants"]
                ]
            )
            if group is not None:
                item.el.append(group)

    def media_content_attrib(
        self, variant: MediaVariant, fallback_attrib: dict[str, str]
    ) -> dict[str, str]:
        """Build ``<media:content>`` attributes for one transcoded ``variant``.

        Starts from a copy of ``fallback_attrib``. Then any non-empty variant
        fields are applied as strings, overriding fallback values with the
        same key.
        """
        attrib = dict(fallback_attrib)
        attrib.update(
            self.compact_attrib(
                url=variant.get("url"),
                type=variant.get("type"),
                medium=variant.get("medium"),
                isDefault=variant.get("isDefault"),
                expression=variant.get("expression"),
                bitrate=variant.get("bitrate"),
                framerate=variant.get("framerate"),
                samplingrate=variant.get("samplingrate"),
                channels=variant.get("channels"),
                duration=variant.get("duration"),
                height=variant.get("height"),
                width=variant.get("width"),
                lang=variant.get("lang"),
                fileSize=variant.get("fileSize"),
            )
        )
        return attrib

    def apply_transcoded_media(self, item: Any) -> None:
        """Rewrite an ``ElementItem``'s media references to its transcoded files.

        This is a no-op for non-``ElementItem`` items and for items with no
        audios or videos. Otherwise the enclosures are rebuilt, managed
        ``<media:content>`` elements are stripped, and fresh
        ``<media:group>`` elements are appended.
        """
        if not isinstance(item, ElementItem):
            return
        if not item.audios and not item.videos:
            return
        self.rebuild_enclosures(item)
        fallbacks = self.strip_managed_media_nodes(item)
        self.append_media_groups(item, fallbacks)

    def export_rss_item(self, item: Any):
        """Apply media rewriting to ``item`` and append its element to the channel."""
        # Callers only reach here once a channel element has been seen.
        assert self.channel is not None
        self.apply_transcoded_media(item)
        self.channel.append(item.el)

    def finish_exporting(self) -> None:
        """Serialize the RSS tree with ``rss.serialize`` and write it to ``self.file``."""
        if self.item_buffer:
            # Items are only left buffered when no ChannelElementItem ever
            # arrived, so they have nowhere to go; make the loss visible.
            logger.warning(
                "Dropping %d RSS item(s) exported without a channel element",
                len(self.item_buffer),
            )
        xml_bytes = rss.serialize(self.rss)
        self.file.write(xml_bytes)