This commit is contained in:
√(noham)²
2026-01-09 19:49:49 +01:00
parent 279be1d6bf
commit b75027c750

View File

@@ -6,6 +6,7 @@ import os
import asyncio import asyncio
import time import time
from typing import Dict, Any from typing import Dict, Any
from xml.dom import minidom
import requests import requests
import aiohttp import aiohttp
@@ -13,9 +14,9 @@ from tqdm.asyncio import tqdm
from utils.logging_config import logger from utils.logging_config import logger
# Register namespaces to preserve prefixes in output # Register namespaces to preserve prefixes in output
ET.register_namespace('', 'urn:mpeg:dash:schema:mpd:2011') ET.register_namespace("", "urn:mpeg:dash:schema:mpd:2011")
ET.register_namespace('xsi', 'http://www.w3.org/2001/XMLSchema-instance') ET.register_namespace("xsi", "http://www.w3.org/2001/XMLSchema-instance")
ET.register_namespace('cenc', 'urn:mpeg:cenc:2013') ET.register_namespace("cenc", "urn:mpeg:cenc:2013")
def parse_mpd_manifest(mpd_content: str) -> Dict[str, Any]: def parse_mpd_manifest(mpd_content: str) -> Dict[str, Any]:
@@ -394,7 +395,9 @@ def get_init(output_folder, track_id):
return init_path return init_path
async def save_segments(output_folder, track_id, start_tick, rep_nb, duration, batch_size=64): async def save_segments(
output_folder, track_id, start_tick, rep_nb, duration, batch_size=64
):
"""Download and save multiple media segments in batches. """Download and save multiple media segments in batches.
Args: Args:
@@ -433,14 +436,18 @@ async def save_segments(output_folder, track_id, start_tick, rep_nb, duration, b
return (True, tick, rep) return (True, tick, rep)
logger.error( logger.error(
"Failed to download segment %d (tick %d): HTTP %d", "Failed to download segment %d (tick %d): HTTP %d",
rep, tick, resp.status rep,
tick,
resp.status,
) )
return (False, tick, rep) return (False, tick, rep)
except aiohttp.ClientError as e: except aiohttp.ClientError as e:
logger.warning("Error downloading segment %d (tick %d): %s", rep, tick, e) logger.warning("Error downloading segment %d (tick %d): %s", rep, tick, e)
return (False, tick, rep) return (False, tick, rep)
logger.info("Starting download of %d segments in batches of %d...", rep_nb, batch_size) logger.info(
"Starting download of %d segments in batches of %d...", rep_nb, batch_size
)
logger.debug("Track ID: %s", track_id) logger.debug("Track ID: %s", track_id)
logger.debug("Base tick: %d", start_tick) logger.debug("Base tick: %d", start_tick)
@@ -453,7 +460,9 @@ async def save_segments(output_folder, track_id, start_tick, rep_nb, duration, b
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
# Process segments in batches # Process segments in batches
with tqdm(total=len(segments_to_download), desc="Downloading segments", unit="seg") as pbar: with tqdm(
total=len(segments_to_download), desc="Downloading segments", unit="seg"
) as pbar:
for batch_start in range(0, len(segments_to_download), batch_size): for batch_start in range(0, len(segments_to_download), batch_size):
batch = segments_to_download[batch_start : batch_start + batch_size] batch = segments_to_download[batch_start : batch_start + batch_size]
tasks = [download_segment(session, tick, rep) for tick, rep in batch] tasks = [download_segment(session, tick, rep) for tick, rep in batch]
@@ -473,10 +482,14 @@ async def save_segments(output_folder, track_id, start_tick, rep_nb, duration, b
retry_successful = 0 retry_successful = 0
final_failures = [] final_failures = []
with tqdm(total=len(retry_list), desc="Retrying segments", unit="seg") as pbar: with tqdm(
total=len(retry_list), desc="Retrying segments", unit="seg"
) as pbar:
for batch_start in range(0, len(retry_list), batch_size): for batch_start in range(0, len(retry_list), batch_size):
batch = retry_list[batch_start : batch_start + batch_size] batch = retry_list[batch_start : batch_start + batch_size]
tasks = [download_segment(session, tick, rep) for tick, rep in batch] tasks = [
download_segment(session, tick, rep) for tick, rep in batch
]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
@@ -492,10 +505,12 @@ async def save_segments(output_folder, track_id, start_tick, rep_nb, duration, b
logger.warning( logger.warning(
"Failed to download %d segments after retry: %s", "Failed to download %d segments after retry: %s",
len(final_failures), len(final_failures),
[tick for tick, _ in final_failures] [tick for tick, _ in final_failures],
) )
else: else:
logger.info("All %d retried segments downloaded successfully", retry_successful) logger.info(
"All %d retried segments downloaded successfully", retry_successful
)
end_time = time.time() end_time = time.time()
elapsed = end_time - start_time elapsed = end_time - start_time
@@ -544,7 +559,8 @@ def generate_mpd_manifest(manifest_info: Dict[str, Any]) -> str:
The MPD manifest content as a string. The MPD manifest content as a string.
""" """
mpd = ET.Element("MPD") mpd = ET.Element("MPD")
mpd.set("{http://www.w3.org/2001/XMLSchema-instance}schemaLocation", "urn:mpeg:dash:schema:mpd:2011 DASH-MPD.xsd") xsi_ns = "{http://www.w3.org/2001/XMLSchema-instance}"
mpd.set(f"{xsi_ns}schemaLocation", "urn:mpeg:dash:schema:mpd:2011 DASH-MPD.xsd")
# Set MPD attributes # Set MPD attributes
if manifest_info.get("profiles"): if manifest_info.get("profiles"):
@@ -562,7 +578,9 @@ def generate_mpd_manifest(manifest_info: Dict[str, Any]) -> str:
if manifest_info.get("timeShiftBufferDepth"): if manifest_info.get("timeShiftBufferDepth"):
mpd.set("timeShiftBufferDepth", manifest_info["timeShiftBufferDepth"]) mpd.set("timeShiftBufferDepth", manifest_info["timeShiftBufferDepth"])
if manifest_info.get("suggestedPresentationDelay"): if manifest_info.get("suggestedPresentationDelay"):
mpd.set("suggestedPresentationDelay", manifest_info["suggestedPresentationDelay"]) mpd.set(
"suggestedPresentationDelay", manifest_info["suggestedPresentationDelay"]
)
# Add UTCTiming element # Add UTCTiming element
utc_timing = ET.SubElement(mpd, "UTCTiming") utc_timing = ET.SubElement(mpd, "UTCTiming")
@@ -579,7 +597,7 @@ def generate_mpd_manifest(manifest_info: Dict[str, Any]) -> str:
# Create adaptation sets # Create adaptation sets
for adaptation_info in period_info.get("adaptation_sets", []): for adaptation_info in period_info.get("adaptation_sets", []):
adaptation_set = generate_adaptation_set(period, adaptation_info) generate_adaptation_set(period, adaptation_info)
return format_xml_custom(mpd) return format_xml_custom(mpd)
@@ -615,8 +633,10 @@ def generate_adaptation_set(
generate_content_protection(adaptation_set, drm_info) generate_content_protection(adaptation_set, drm_info)
# Add SupplementalProperty if it's a video adaptation set with group="1" # Add SupplementalProperty if it's a video adaptation set with group="1"
if (adaptation_info.get("contentType") == "video" and if (
adaptation_info.get("group") == "1"): adaptation_info.get("contentType") == "video"
and adaptation_info.get("group") == "1"
):
supplemental = ET.SubElement(adaptation_set, "SupplementalProperty") supplemental = ET.SubElement(adaptation_set, "SupplementalProperty")
supplemental.set("schemeIdUri", "urn:mpeg:dash:adaptation-set-switching:2016") supplemental.set("schemeIdUri", "urn:mpeg:dash:adaptation-set-switching:2016")
if adaptation_info.get("supplementalProperty"): if adaptation_info.get("supplementalProperty"):
@@ -657,14 +677,14 @@ def generate_content_protection(
if drm_info.get("value"): if drm_info.get("value"):
content_protection.set("value", drm_info["value"]) content_protection.set("value", drm_info["value"])
if drm_info.get("default_KID"): if drm_info.get("default_KID"):
content_protection.set("{urn:mpeg:cenc:2013}default_KID", drm_info["default_KID"]) content_protection.set(
"{urn:mpeg:cenc:2013}default_KID", drm_info["default_KID"]
)
return content_protection return content_protection
def generate_representation( def generate_representation(parent: ET.Element, rep_info: Dict[str, Any]) -> ET.Element:
parent: ET.Element, rep_info: Dict[str, Any]
) -> ET.Element:
"""Generate Representation element from representation information. """Generate Representation element from representation information.
Args: Args:
@@ -731,11 +751,12 @@ def format_xml_custom(element: ET.Element) -> str:
Formatted XML string. Formatted XML string.
""" """
rough_string = ET.tostring(element, encoding="unicode") rough_string = ET.tostring(element, encoding="unicode")
rough_string = rough_string.replace('ns0:', 'cenc:') # Fix namespace prefix # Fix namespace prefix (ns0 -> cenc)
rough_string = rough_string.replace("ns0:", "cenc:")
from xml.dom import minidom
dom = minidom.parseString(rough_string) dom = minidom.parseString(rough_string)
pretty_xml = dom.toprettyxml(indent="\t", encoding=None) pretty_xml = dom.toprettyxml(indent="\t", encoding=None)
lines = [line for line in pretty_xml.split('\n') if line.strip()] # Remove empty lines # Remove empty lines
lines = [line for line in pretty_xml.split("\n") if line.strip()]
return '\n'.join(lines) return "\n".join(lines)