√(noham)²
2026-01-09 19:49:49 +01:00
parent 279be1d6bf
commit b75027c750

@@ -6,6 +6,7 @@ import os
import asyncio
import time
from typing import Dict, Any
from xml.dom import minidom
import requests
import aiohttp
@@ -13,9 +14,9 @@ from tqdm.asyncio import tqdm
from utils.logging_config import logger
# Register namespaces to preserve prefixes in output
ET.register_namespace('', 'urn:mpeg:dash:schema:mpd:2011')
ET.register_namespace('xsi', 'http://www.w3.org/2001/XMLSchema-instance')
ET.register_namespace('cenc', 'urn:mpeg:cenc:2013')
ET.register_namespace("", "urn:mpeg:dash:schema:mpd:2011")
ET.register_namespace("xsi", "http://www.w3.org/2001/XMLSchema-instance")
ET.register_namespace("cenc", "urn:mpeg:cenc:2013")
def parse_mpd_manifest(mpd_content: str) -> Dict[str, Any]:
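These registrations only affect serialization: without them, ElementTree assigns autogenerated ns0:/ns1: prefixes when the tree is written back out, so the cenc: and default prefixes would not survive in the regenerated MPD. A standalone sketch of the effect (not part of this commit; the pssh payload is a placeholder):

import xml.etree.ElementTree as ET

ET.register_namespace("", "urn:mpeg:dash:schema:mpd:2011")
ET.register_namespace("cenc", "urn:mpeg:cenc:2013")

root = ET.Element("{urn:mpeg:dash:schema:mpd:2011}MPD")
ET.SubElement(root, "{urn:mpeg:cenc:2013}pssh").text = "AAAA"

# With the registrations, the root serializes without a prefix and the child as cenc:pssh;
# without them, ElementTree would emit ns0:MPD and ns1:pssh instead.
print(ET.tostring(root, encoding="unicode"))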
@@ -394,7 +395,9 @@ def get_init(output_folder, track_id):
return init_path
async def save_segments(output_folder, track_id, start_tick, rep_nb, duration, batch_size=64):
async def save_segments(
output_folder, track_id, start_tick, rep_nb, duration, batch_size=64
):
"""Download and save multiple media segments in batches.
Args:
@@ -409,7 +412,7 @@ async def save_segments(output_folder, track_id, start_tick, rep_nb, duration, b
async def download_segment(session, tick, rep):
"""Download a single segment.
Returns:
Tuple of (success: bool, tick: int, rep: int)
"""
@@ -433,14 +436,18 @@ async def save_segments(output_folder, track_id, start_tick, rep_nb, duration, b
return (True, tick, rep)
logger.error(
"Failed to download segment %d (tick %d): HTTP %d",
rep, tick, resp.status
rep,
tick,
resp.status,
)
return (False, tick, rep)
except aiohttp.ClientError as e:
logger.warning("Error downloading segment %d (tick %d): %s", rep, tick, e)
return (False, tick, rep)
logger.info("Starting download of %d segments in batches of %d...", rep_nb, batch_size)
logger.info(
"Starting download of %d segments in batches of %d...", rep_nb, batch_size
)
logger.debug("Track ID: %s", track_id)
logger.debug("Base tick: %d", start_tick)
@@ -453,13 +460,15 @@ async def save_segments(output_folder, track_id, start_tick, rep_nb, duration, b
async with aiohttp.ClientSession() as session:
# Process segments in batches
with tqdm(total=len(segments_to_download), desc="Downloading segments", unit="seg") as pbar:
with tqdm(
total=len(segments_to_download), desc="Downloading segments", unit="seg"
) as pbar:
for batch_start in range(0, len(segments_to_download), batch_size):
batch = segments_to_download[batch_start:batch_start + batch_size]
batch = segments_to_download[batch_start : batch_start + batch_size]
tasks = [download_segment(session, tick, rep) for tick, rep in batch]
results = await asyncio.gather(*tasks)
for success, tick, rep in results:
if success:
successful += 1
@@ -472,14 +481,18 @@ async def save_segments(output_folder, track_id, start_tick, rep_nb, duration, b
logger.info("Retrying %d failed segments...", len(retry_list))
retry_successful = 0
final_failures = []
with tqdm(total=len(retry_list), desc="Retrying segments", unit="seg") as pbar:
with tqdm(
total=len(retry_list), desc="Retrying segments", unit="seg"
) as pbar:
for batch_start in range(0, len(retry_list), batch_size):
batch = retry_list[batch_start:batch_start + batch_size]
tasks = [download_segment(session, tick, rep) for tick, rep in batch]
batch = retry_list[batch_start : batch_start + batch_size]
tasks = [
download_segment(session, tick, rep) for tick, rep in batch
]
results = await asyncio.gather(*tasks)
for success, tick, rep in results:
if success:
retry_successful += 1
@@ -487,15 +500,17 @@ async def save_segments(output_folder, track_id, start_tick, rep_nb, duration, b
else:
final_failures.append((tick, rep))
pbar.update(1)
if final_failures:
logger.warning(
"Failed to download %d segments after retry: %s",
len(final_failures),
[tick for tick, _ in final_failures]
[tick for tick, _ in final_failures],
)
else:
logger.info("All %d retried segments downloaded successfully", retry_successful)
logger.info(
"All %d retried segments downloaded successfully", retry_successful
)
end_time = time.time()
elapsed = end_time - start_time
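The download loop above follows a common pattern: slice the segment list into fixed-size batches, run each batch concurrently with asyncio.gather, and give the failures a single retry pass. A condensed, self-contained sketch of that pattern, using placeholder URLs rather than the project's real segment naming:

import asyncio
import aiohttp


async def fetch_all(urls, batch_size=64):
    """Fetch URLs in fixed-size batches, with one retry pass for failures."""

    async def fetch(session, url):
        try:
            async with session.get(url) as resp:
                if resp.status == 200:
                    return url, await resp.read()
        except aiohttp.ClientError:
            pass
        return url, None

    async def run(session, work):
        ok, failed = {}, []
        for start in range(0, len(work), batch_size):
            batch = work[start : start + batch_size]
            results = await asyncio.gather(*(fetch(session, u) for u in batch))
            for url, body in results:
                if body is not None:
                    ok[url] = body
                else:
                    failed.append(url)
        return ok, failed

    async with aiohttp.ClientSession() as session:
        ok, failed = await run(session, list(urls))
        if failed:  # single retry pass, mirroring the retry_list loop above
            retried, failed = await run(session, failed)
            ok.update(retried)
    return ok, failed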
@@ -544,8 +559,9 @@ def generate_mpd_manifest(manifest_info: Dict[str, Any]) -> str:
The MPD manifest content as a string.
"""
mpd = ET.Element("MPD")
mpd.set("{http://www.w3.org/2001/XMLSchema-instance}schemaLocation", "urn:mpeg:dash:schema:mpd:2011 DASH-MPD.xsd")
xsi_ns = "{http://www.w3.org/2001/XMLSchema-instance}"
mpd.set(f"{xsi_ns}schemaLocation", "urn:mpeg:dash:schema:mpd:2011 DASH-MPD.xsd")
# Set MPD attributes
if manifest_info.get("profiles"):
mpd.set("profiles", manifest_info["profiles"])
@@ -562,13 +578,15 @@ def generate_mpd_manifest(manifest_info: Dict[str, Any]) -> str:
if manifest_info.get("timeShiftBufferDepth"):
mpd.set("timeShiftBufferDepth", manifest_info["timeShiftBufferDepth"])
if manifest_info.get("suggestedPresentationDelay"):
mpd.set("suggestedPresentationDelay", manifest_info["suggestedPresentationDelay"])
mpd.set(
"suggestedPresentationDelay", manifest_info["suggestedPresentationDelay"]
)
# Add UTCTiming element
utc_timing = ET.SubElement(mpd, "UTCTiming")
utc_timing.set("schemeIdUri", "urn:mpeg:dash:utc:http-iso:2014")
utc_timing.set("value", "https://time.akamai.com/?iso")
# Create periods
for period_info in manifest_info.get("periods", []):
period = ET.SubElement(mpd, "Period")
@@ -576,11 +594,11 @@ def generate_mpd_manifest(manifest_info: Dict[str, Any]) -> str:
period.set("id", period_info["id"])
if period_info.get("start"):
period.set("start", period_info["start"])
# Create adaptation sets
for adaptation_info in period_info.get("adaptation_sets", []):
adaptation_set = generate_adaptation_set(period, adaptation_info)
generate_adaptation_set(period, adaptation_info)
return format_xml_custom(mpd)
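Because every attribute and child is guarded with .get(), a small input dictionary already produces a valid skeleton. A hypothetical manifest_info limited to keys visible in these generator functions (plus the segment duration "d", which a SegmentTimeline requires); all values are illustrative, not taken from a real manifest:

manifest_info = {
    "profiles": "urn:mpeg:dash:profile:isoff-live:2011",
    "periods": [
        {
            "id": "p0",
            "adaptation_sets": [
                {
                    "id": "1",
                    "contentType": "video",
                    "representations": [
                        {
                            "id": "video-2400k",
                            "bandwidth": "2400000",
                            "segments": {
                                "timescale": "10000000",
                                "initialization": "init-$RepresentationID$.mp4",
                                "media": "seg-$RepresentationID$-$Time$.m4s",
                                "timeline": [{"t": 0, "d": 20000000, "r": 4}],
                            },
                        }
                    ],
                }
            ],
        }
    ],
}

mpd_xml = generate_mpd_manifest(manifest_info)  # pretty-printed MPD string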
@@ -597,7 +615,7 @@ def generate_adaptation_set(
The created AdaptationSet element.
"""
adaptation_set = ET.SubElement(parent, "AdaptationSet")
if adaptation_info.get("id"):
adaptation_set.set("id", adaptation_info["id"])
if adaptation_info.get("group"):
@@ -610,26 +628,28 @@ def generate_adaptation_set(
adaptation_set.set("contentType", adaptation_info["contentType"])
if adaptation_info.get("lang"):
adaptation_set.set("lang", adaptation_info["lang"])
for drm_info in adaptation_info.get("drm_info", []):
generate_content_protection(adaptation_set, drm_info)
# Add SupplementalProperty if it's a video adaptation set with group="1"
if (adaptation_info.get("contentType") == "video" and
adaptation_info.get("group") == "1"):
if (
adaptation_info.get("contentType") == "video"
and adaptation_info.get("group") == "1"
):
supplemental = ET.SubElement(adaptation_set, "SupplementalProperty")
supplemental.set("schemeIdUri", "urn:mpeg:dash:adaptation-set-switching:2016")
if adaptation_info.get("supplementalProperty"):
supplemental.set("value", adaptation_info["supplementalProperty"])
if adaptation_info.get("role"):
role = ET.SubElement(adaptation_set, "Role")
role.set("schemeIdUri", "urn:mpeg:dash:role:2011")
role.set("value", adaptation_info["role"])
for rep_info in adaptation_info.get("representations", []):
generate_representation(adaptation_set, rep_info)
return adaptation_set
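The SupplementalProperty branch above advertises adaptation-set switching; in DASH the descriptor's value is conventionally a comma-separated list of the AdaptationSet ids that can be switched to. A standalone sketch of the fragment being built (ids and value are hypothetical):

import xml.etree.ElementTree as ET

aset = ET.Element("AdaptationSet", {"id": "1", "group": "1", "contentType": "video"})
switching = ET.SubElement(aset, "SupplementalProperty")
switching.set("schemeIdUri", "urn:mpeg:dash:adaptation-set-switching:2016")
switching.set("value", "2")  # hypothetical id of a compatible AdaptationSet
print(ET.tostring(aset, encoding="unicode"))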
@@ -646,25 +666,25 @@ def generate_content_protection(
The created ContentProtection element.
"""
content_protection = ET.SubElement(parent, "ContentProtection")
if drm_info.get("schemeIdUri"):
content_protection.set("schemeIdUri", drm_info["schemeIdUri"])
if drm_info.get("pssh"):
pssh = ET.SubElement(content_protection, "{urn:mpeg:cenc:2013}pssh")
pssh.text = drm_info["pssh"]
if drm_info.get("value"):
content_protection.set("value", drm_info["value"])
if drm_info.get("default_KID"):
content_protection.set("{urn:mpeg:cenc:2013}default_KID", drm_info["default_KID"])
content_protection.set(
"{urn:mpeg:cenc:2013}default_KID", drm_info["default_KID"]
)
return content_protection
def generate_representation(
parent: ET.Element, rep_info: Dict[str, Any]
) -> ET.Element:
def generate_representation(parent: ET.Element, rep_info: Dict[str, Any]) -> ET.Element:
"""Generate Representation element from representation information.
Args:
@@ -675,7 +695,7 @@ def generate_representation(
The created Representation element.
"""
representation = ET.SubElement(parent, "Representation")
if rep_info.get("id"):
representation.set("id", rep_info["id"])
if rep_info.get("bandwidth"):
@@ -690,22 +710,22 @@ def generate_representation(
representation.set("height", rep_info["height"])
if rep_info.get("frameRate"):
representation.set("frameRate", rep_info["frameRate"])
segments = rep_info.get("segments", {})
if segments:
segment_template = ET.SubElement(representation, "SegmentTemplate")
if segments.get("timescale"):
segment_template.set("timescale", segments["timescale"])
if segments.get("initialization"):
segment_template.set("initialization", segments["initialization"])
if segments.get("media"):
segment_template.set("media", segments["media"])
timeline = segments.get("timeline", [])
if timeline:
segment_timeline = ET.SubElement(segment_template, "SegmentTimeline")
for timeline_info in timeline:
s_element = ET.SubElement(segment_timeline, "S")
t = timeline_info.get("t", 0)
@@ -717,7 +737,7 @@ def generate_representation(
r = timeline_info.get("r", 0)
if r != 0:
s_element.set("r", str(r))
return representation
@@ -731,11 +751,12 @@ def format_xml_custom(element: ET.Element) -> str:
Formatted XML string.
"""
rough_string = ET.tostring(element, encoding="unicode")
rough_string = rough_string.replace('ns0:', 'cenc:') # Fix namespace prefix
# Fix namespace prefix (ns0 -> cenc)
rough_string = rough_string.replace("ns0:", "cenc:")
from xml.dom import minidom
dom = minidom.parseString(rough_string)
pretty_xml = dom.toprettyxml(indent="\t", encoding=None)
lines = [line for line in pretty_xml.split('\n') if line.strip()] # Remove empty lines
return '\n'.join(lines)
# Remove empty lines
lines = [line for line in pretty_xml.split("\n") if line.strip()]
return "\n".join(lines)
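A quick, hypothetical way to exercise format_xml_custom once the module is imported (the cenc prefix is already registered at import time):

mpd = ET.Element("MPD")
cp = ET.SubElement(mpd, "ContentProtection")
cp.set("schemeIdUri", "urn:mpeg:dash:mp4protection:2011")
ET.SubElement(cp, "{urn:mpeg:cenc:2013}pssh").text = "AAAA"

print(format_xml_custom(mpd))
# Tab-indented XML with blank lines stripped; the pssh child serializes as <cenc:pssh>.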