diff --git a/utils/stream.py b/utils/stream.py index 3994ccf..447d0bd 100644 --- a/utils/stream.py +++ b/utils/stream.py @@ -389,19 +389,25 @@ def get_init(output_folder, track_id): return init_path -async def save_segments(output_folder, track_id, start_tick, rep_nb, duration): - """Download and save multiple media segments. +async def save_segments(output_folder, track_id, start_tick, rep_nb, duration, batch_size=64): + """Download and save multiple media segments in batches. Args: + output_folder: The output folder path. track_id: The track identifier. start_tick: The starting tick value. rep_nb: The number of segments to download. duration: The duration per segment. + batch_size: Number of concurrent downloads per batch (default: 64). """ os.makedirs(f"{output_folder}/segments_{track_id}", exist_ok=True) async def download_segment(session, tick, rep): - """Download a single segment.""" + """Download a single segment. + + Returns: + Tuple of (success: bool, tick: int, rep: int) + """ url = f"https://media.stream.proxad.net/media/{track_id}_{tick}" headers = { "Accept": "*/*", @@ -419,43 +425,78 @@ async def save_segments(output_folder, track_id, start_tick, rep_nb, duration): filename = f"{output_folder}/segments_{track_id}/{tick}.m4s" with open(filename, "wb") as f: f.write(content) - return True + return (True, tick, rep) logger.error( "Failed to download segment %d (tick %d): HTTP %d", rep, tick, resp.status ) - return False + return (False, tick, rep) except aiohttp.ClientError as e: logger.warning("Error downloading segment %d (tick %d): %s", rep, tick, e) - return False + return (False, tick, rep) - logger.info("Starting download of %d segments...", rep_nb) + logger.info("Starting download of %d segments in batches of %d...", rep_nb, batch_size) logger.debug("Track ID: %s", track_id) logger.debug("Base tick: %d", start_tick) start_time = time.time() - async with aiohttp.ClientSession() as session: - tasks = [] - for i in range(rep_nb): - tick = start_tick
+ i * duration - tasks.append(download_segment(session, tick, i)) + # Build list of all segments to download + segments_to_download = [(start_tick + i * duration, i) for i in range(rep_nb)] + retry_list = [] + successful = 0 - results = [] - for coro in tqdm( - asyncio.as_completed(tasks), - total=len(tasks), - desc="Downloading segments", - unit="seg", - ): - result = await coro - results.append(result) - successful = sum(1 for r in results if r is True) + async with aiohttp.ClientSession() as session: + # Process segments in batches + with tqdm(total=len(segments_to_download), desc="Downloading segments", unit="seg") as pbar: + for batch_start in range(0, len(segments_to_download), batch_size): + batch = segments_to_download[batch_start:batch_start + batch_size] + tasks = [download_segment(session, tick, rep) for tick, rep in batch] + + results = await asyncio.gather(*tasks) + + for success, tick, rep in results: + if success: + successful += 1 + else: + retry_list.append((tick, rep)) + pbar.update(1) + + # Retry failed segments + if retry_list: + logger.info("Retrying %d failed segments...", len(retry_list)) + retry_successful = 0 + final_failures = [] + + with tqdm(total=len(retry_list), desc="Retrying segments", unit="seg") as pbar: + for batch_start in range(0, len(retry_list), batch_size): + batch = retry_list[batch_start:batch_start + batch_size] + tasks = [download_segment(session, tick, rep) for tick, rep in batch] + + results = await asyncio.gather(*tasks) + + for success, tick, rep in results: + if success: + retry_successful += 1 + successful += 1 + else: + final_failures.append((tick, rep)) + pbar.update(1) + + if final_failures: + logger.warning( + "Failed to download %d segments after retry: %s", + len(final_failures), + [tick for tick, _ in final_failures] + ) + else: + logger.info("All %d retried segments downloaded successfully", retry_successful) end_time = time.time() elapsed = end_time - start_time logger.debug("Download completed in %.2fs", 
elapsed) + logger.info("Successfully downloaded %d/%d segments", successful, rep_nb) logger.info("Files saved to %s/segments_%s/", output_folder, track_id)