Compare commits

2 Commits

Author SHA1 Message Date
√(noham)²
b75027c750 Lint 2026-01-09 19:49:49 +01:00
√(noham)²
279be1d6bf Add MPD manifest generation utilities 2026-01-09 19:46:24 +01:00

View File

@@ -6,12 +6,18 @@ import os
import asyncio import asyncio
import time import time
from typing import Dict, Any from typing import Dict, Any
from xml.dom import minidom
import requests import requests
import aiohttp import aiohttp
from tqdm.asyncio import tqdm from tqdm.asyncio import tqdm
from utils.logging_config import logger from utils.logging_config import logger
# Register the namespace prefixes used in MPD documents so ElementTree
# serializes namespaced names with these prefixes rather than auto-generated
# ones such as "ns0". The empty prefix maps the DASH namespace to the default
# (unprefixed) namespace.
ET.register_namespace("", "urn:mpeg:dash:schema:mpd:2011")
ET.register_namespace("xsi", "http://www.w3.org/2001/XMLSchema-instance")
ET.register_namespace("cenc", "urn:mpeg:cenc:2013")
def parse_mpd_manifest(mpd_content: str) -> Dict[str, Any]: def parse_mpd_manifest(mpd_content: str) -> Dict[str, Any]:
"""Parse an MPD manifest and extract metadata. """Parse an MPD manifest and extract metadata.
@@ -389,7 +395,9 @@ def get_init(output_folder, track_id):
return init_path return init_path
async def save_segments(output_folder, track_id, start_tick, rep_nb, duration, batch_size=64): async def save_segments(
output_folder, track_id, start_tick, rep_nb, duration, batch_size=64
):
"""Download and save multiple media segments in batches. """Download and save multiple media segments in batches.
Args: Args:
@@ -428,14 +436,18 @@ async def save_segments(output_folder, track_id, start_tick, rep_nb, duration, b
return (True, tick, rep) return (True, tick, rep)
logger.error( logger.error(
"Failed to download segment %d (tick %d): HTTP %d", "Failed to download segment %d (tick %d): HTTP %d",
rep, tick, resp.status rep,
tick,
resp.status,
) )
return (False, tick, rep) return (False, tick, rep)
except aiohttp.ClientError as e: except aiohttp.ClientError as e:
logger.warning("Error downloading segment %d (tick %d): %s", rep, tick, e) logger.warning("Error downloading segment %d (tick %d): %s", rep, tick, e)
return (False, tick, rep) return (False, tick, rep)
logger.info("Starting download of %d segments in batches of %d...", rep_nb, batch_size) logger.info(
"Starting download of %d segments in batches of %d...", rep_nb, batch_size
)
logger.debug("Track ID: %s", track_id) logger.debug("Track ID: %s", track_id)
logger.debug("Base tick: %d", start_tick) logger.debug("Base tick: %d", start_tick)
@@ -448,7 +460,9 @@ async def save_segments(output_folder, track_id, start_tick, rep_nb, duration, b
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
# Process segments in batches # Process segments in batches
with tqdm(total=len(segments_to_download), desc="Downloading segments", unit="seg") as pbar: with tqdm(
total=len(segments_to_download), desc="Downloading segments", unit="seg"
) as pbar:
for batch_start in range(0, len(segments_to_download), batch_size): for batch_start in range(0, len(segments_to_download), batch_size):
batch = segments_to_download[batch_start : batch_start + batch_size] batch = segments_to_download[batch_start : batch_start + batch_size]
tasks = [download_segment(session, tick, rep) for tick, rep in batch] tasks = [download_segment(session, tick, rep) for tick, rep in batch]
@@ -468,10 +482,14 @@ async def save_segments(output_folder, track_id, start_tick, rep_nb, duration, b
retry_successful = 0 retry_successful = 0
final_failures = [] final_failures = []
with tqdm(total=len(retry_list), desc="Retrying segments", unit="seg") as pbar: with tqdm(
total=len(retry_list), desc="Retrying segments", unit="seg"
) as pbar:
for batch_start in range(0, len(retry_list), batch_size): for batch_start in range(0, len(retry_list), batch_size):
batch = retry_list[batch_start : batch_start + batch_size] batch = retry_list[batch_start : batch_start + batch_size]
tasks = [download_segment(session, tick, rep) for tick, rep in batch] tasks = [
download_segment(session, tick, rep) for tick, rep in batch
]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
@@ -487,10 +505,12 @@ async def save_segments(output_folder, track_id, start_tick, rep_nb, duration, b
logger.warning( logger.warning(
"Failed to download %d segments after retry: %s", "Failed to download %d segments after retry: %s",
len(final_failures), len(final_failures),
[tick for tick, _ in final_failures] [tick for tick, _ in final_failures],
) )
else: else:
logger.info("All %d retried segments downloaded successfully", retry_successful) logger.info(
"All %d retried segments downloaded successfully", retry_successful
)
end_time = time.time() end_time = time.time()
elapsed = end_time - start_time elapsed = end_time - start_time
@@ -527,3 +547,216 @@ def get_kid(output_folder, track_id):
kid = kid_bytes.hex() kid = kid_bytes.hex()
return kid return kid
return None return None
def generate_mpd_manifest(manifest_info: Dict[str, Any]) -> str:
    """Generate an MPD manifest from parsed manifest information.

    The root ``MPD`` element receives the DASH default namespace declaration,
    the ``xsi:schemaLocation`` attribute, every top-level attribute present
    (and truthy) in ``manifest_info``, a fixed ``UTCTiming`` element, and one
    ``Period`` per entry of ``manifest_info["periods"]``.

    Args:
        manifest_info: A dictionary containing manifest information. Keys read
            here: ``profiles``, ``type``, ``publishTime``,
            ``availabilityStartTime``, ``minimumUpdatePeriod``,
            ``minBufferTime``, ``timeShiftBufferDepth``,
            ``suggestedPresentationDelay`` and ``periods`` (a list of dicts
            with optional ``id``, ``start`` and ``adaptation_sets``).

    Returns:
        The MPD manifest content as a string.
    """
    mpd = ET.Element("MPD")
    # Every element in this tree is created with an unqualified tag, so the
    # serializer never emits a default namespace on its own. Declare it
    # explicitly so the document actually lives in the DASH namespace that
    # the schemaLocation below refers to.
    mpd.set("xmlns", "urn:mpeg:dash:schema:mpd:2011")
    xsi_ns = "{http://www.w3.org/2001/XMLSchema-instance}"
    mpd.set(f"{xsi_ns}schemaLocation", "urn:mpeg:dash:schema:mpd:2011 DASH-MPD.xsd")

    # Copy the optional MPD attributes, skipping missing or empty values.
    for attr in (
        "profiles",
        "type",
        "publishTime",
        "availabilityStartTime",
        "minimumUpdatePeriod",
        "minBufferTime",
        "timeShiftBufferDepth",
        "suggestedPresentationDelay",
    ):
        value = manifest_info.get(attr)
        if value:
            mpd.set(attr, value)

    # Add UTCTiming element
    utc_timing = ET.SubElement(mpd, "UTCTiming")
    utc_timing.set("schemeIdUri", "urn:mpeg:dash:utc:http-iso:2014")
    utc_timing.set("value", "https://time.akamai.com/?iso")

    # Create periods
    for period_info in manifest_info.get("periods", []):
        period = ET.SubElement(mpd, "Period")
        for attr in ("id", "start"):
            value = period_info.get(attr)
            if value:
                period.set(attr, value)

        # Create adaptation sets
        for adaptation_info in period_info.get("adaptation_sets", []):
            generate_adaptation_set(period, adaptation_info)

    return format_xml_custom(mpd)
def generate_adaptation_set(
    parent: ET.Element, adaptation_info: Dict[str, Any]
) -> ET.Element:
    """Build an AdaptationSet element under ``parent``.

    Children are appended in this order: ContentProtection entries, an
    optional SupplementalProperty, an optional Role, then Representations.

    Args:
        parent: The parent XML element.
        adaptation_info: Dictionary containing adaptation set information.

    Returns:
        The created AdaptationSet element.
    """
    node = ET.SubElement(parent, "AdaptationSet")

    # Plain attributes are copied verbatim when present and truthy.
    for attr_name in (
        "id",
        "group",
        "segmentAlignment",
        "startWithSAP",
        "contentType",
        "lang",
    ):
        attr_value = adaptation_info.get(attr_name)
        if attr_value:
            node.set(attr_name, attr_value)

    for drm_entry in adaptation_info.get("drm_info", []):
        generate_content_protection(node, drm_entry)

    # Only video adaptation sets in group "1" carry the switching property.
    is_video = adaptation_info.get("contentType") == "video"
    in_group_one = adaptation_info.get("group") == "1"
    if is_video and in_group_one:
        switching = ET.SubElement(node, "SupplementalProperty")
        switching.set("schemeIdUri", "urn:mpeg:dash:adaptation-set-switching:2016")
        switching_value = adaptation_info.get("supplementalProperty")
        if switching_value:
            switching.set("value", switching_value)

    role_value = adaptation_info.get("role")
    if role_value:
        role_node = ET.SubElement(node, "Role")
        role_node.set("schemeIdUri", "urn:mpeg:dash:role:2011")
        role_node.set("value", role_value)

    for representation_entry in adaptation_info.get("representations", []):
        generate_representation(node, representation_entry)

    return node
def generate_content_protection(
    parent: ET.Element, drm_info: Dict[str, Any]
) -> ET.Element:
    """Build a ContentProtection element under ``parent``.

    Args:
        parent: The parent XML element.
        drm_info: Dictionary containing DRM information. The keys
            ``schemeIdUri``, ``value`` and ``default_KID`` become attributes;
            ``pssh`` becomes the text of a ``cenc:pssh`` child element.

    Returns:
        The created ContentProtection element.
    """
    protection = ET.SubElement(parent, "ContentProtection")

    # (key in drm_info, attribute name on the element), in output order.
    attribute_map = (
        ("schemeIdUri", "schemeIdUri"),
        ("value", "value"),
        ("default_KID", "{urn:mpeg:cenc:2013}default_KID"),
    )
    for source_key, attr_name in attribute_map:
        attr_value = drm_info.get(source_key)
        if attr_value:
            protection.set(attr_name, attr_value)

    pssh_payload = drm_info.get("pssh")
    if pssh_payload:
        pssh_node = ET.SubElement(protection, "{urn:mpeg:cenc:2013}pssh")
        pssh_node.text = pssh_payload

    return protection
def generate_representation(parent: ET.Element, rep_info: Dict[str, Any]) -> ET.Element:
    """Generate Representation element from representation information.

    Attribute values are converted with ``str()`` before being stored, so
    numeric inputs (for example an integer ``bandwidth`` or ``timescale``)
    serialize cleanly instead of making ``ET.tostring`` raise ``TypeError``.
    Missing, empty, ``None`` or zero values are omitted.

    Args:
        parent: The parent XML element.
        rep_info: Dictionary containing representation information. Keys read
            here: ``id``, ``bandwidth``, ``codecs``, ``mimeType``, ``width``,
            ``height``, ``frameRate`` and ``segments`` (a dict with optional
            ``timescale``, ``initialization``, ``media`` and ``timeline``, the
            latter a list of dicts with optional ``t``, ``d`` and ``r``).

    Returns:
        The created Representation element.
    """
    representation = ET.SubElement(parent, "Representation")
    for attr in (
        "id",
        "bandwidth",
        "codecs",
        "mimeType",
        "width",
        "height",
        "frameRate",
    ):
        value = rep_info.get(attr)
        if value:
            # ElementTree can only serialize string attribute values.
            representation.set(attr, str(value))

    segments = rep_info.get("segments", {})
    if not segments:
        return representation

    segment_template = ET.SubElement(representation, "SegmentTemplate")
    for attr in ("timescale", "initialization", "media"):
        value = segments.get(attr)
        if value:
            segment_template.set(attr, str(value))

    timeline = segments.get("timeline", [])
    if timeline:
        segment_timeline = ET.SubElement(segment_template, "SegmentTimeline")
        for timeline_info in timeline:
            s_element = ET.SubElement(segment_timeline, "S")
            for attr in ("t", "d", "r"):
                value = timeline_info.get(attr)
                # Skip 0 as before, and also None/"" so that an absent value
                # is never written out as the literal text "None".
                if value:
                    s_element.set(attr, str(value))

    return representation
def format_xml_custom(element: ET.Element) -> str:
    """Format XML element to match the original MPD style.

    The element is serialized, re-parsed with ``minidom`` and pretty-printed
    with tab indentation; blank lines are then dropped.

    Args:
        element: The XML element to format.

    Returns:
        Formatted XML string, starting with the XML declaration.
    """
    # Bind the CENC namespace to its conventional prefix before serializing.
    # This replaces a textual "ns0:" -> "cenc:" substitution, which left the
    # matching xmlns:ns0 declaration untouched (yielding an unbound prefix)
    # and could also rewrite attribute values or text containing "ns0:".
    ET.register_namespace("cenc", "urn:mpeg:cenc:2013")
    rough_string = ET.tostring(element, encoding="unicode")

    dom = minidom.parseString(rough_string)
    pretty_xml = dom.toprettyxml(indent="\t", encoding=None)

    # Remove empty lines
    return "\n".join(line for line in pretty_xml.split("\n") if line.strip())