#!/usr/bin/env python3
"""
Compute resolution statistics for a WebDataset.

This script processes a WebDataset in parallel and counts how many samples
meet various resolution thresholds. It outputs statistics in the format:

    >= 256x256 1_055_309_295
    >= 384x384 698_616_282
    ...

Based on a WebDataset filtering script, but modified to only compute statistics.
"""
import argparse
import glob
import io
import multiprocessing as mp
import time

import webdataset as wds
from PIL import Image


def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    print(f'Handling webdataset error ({repr(exn)}). Ignoring.')
    return True


def get_image_resolution(sample):
    """Extract image resolution from a sample. Returns (width, height) or None."""
    try:
        # Get the image bytes from the sample (support jpg, jpeg, webp, png)
        image_data = None
        for ext in ['jpg', 'jpeg', 'webp', 'png']:
            if ext in sample:
                image_data = sample[ext]
                break
        if image_data is None:
            return None
        image = Image.open(io.BytesIO(image_data))
        return image.size  # (width, height)
    except Exception:
        return None
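
# Note: Image.open is lazy; it parses only the image header, so reading .size
# above never decodes the full image. For a sample such as
# {"jpg": <bytes>, "txt": b"a caption"} (illustrative), the function returns a
# (width, height) tuple, e.g. (640, 480).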


def process_multiple_shards(args):
    """Process multiple tar-file shards in a single process and count resolutions."""
    shard_paths, resolution_list, process_id = args
    try:
        # Initialize one counter per resolution threshold
        resolution_counts = [0] * len(resolution_list)
        total_samples = 0
        dataset = wds.DataPipeline(
            wds.SimpleShardList(shard_paths),
            wds.tarfile_to_samples(handler=log_and_continue)
        )
        print(f"Process {process_id} started")
        start_time = time.time()
        for sample in dataset:
            total_samples += 1
            # Get image resolution
            resolution = get_image_resolution(sample)
            if resolution is None:
                continue
            width, height = resolution
            # Check against each threshold and increment counters
            for i, min_size in enumerate(resolution_list):
                if width >= min_size and height >= min_size:
                    resolution_counts[i] += 1
        elapsed_time = time.time() - start_time
        print(f"Process {process_id}: {total_samples} samples processed in {elapsed_time:.1f}s")
        return {
            'process_id': process_id,
            'processed_shards': len(shard_paths),
            'total_samples': total_samples,
            'resolution_counts': resolution_counts,
            'elapsed_time': elapsed_time
        }
    except Exception as e:
        return {
            'process_id': process_id,
            'error': str(e)
        }


def process_webdataset_parallel(input_pattern, resolution_list, num_processes=None):
    """Process WebDataset tar files in parallel and compute resolution statistics."""
    # Get the list of all tar files
    tar_files = sorted(glob.glob(input_pattern))
    if not tar_files:
        print(f"No tar files found matching pattern: {input_pattern}")
        return
    print(f"Found {len(tar_files)} tar files to process")
    # Set the number of processes
    if num_processes is None:
        num_processes = mp.cpu_count()
    print(f"Using {num_processes} processes")
    # Distribute tar files across processes as evenly as possible
    files_per_process = len(tar_files) // num_processes
    remainder = len(tar_files) % num_processes
    process_args = []
    start_idx = 0
    for i in range(num_processes):
        end_idx = start_idx + files_per_process
        if i < remainder:  # The first `remainder` processes take one extra file
            end_idx += 1
        process_files = tar_files[start_idx:end_idx]
        if process_files:  # Only create a process if it has files to work on
            process_args.append((process_files, resolution_list, i))
        start_idx = end_idx
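    # Example (illustrative): 10 tar files across 4 processes yields chunks
    # of sizes [3, 3, 2, 2], so no process handles more than one extra shard.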
| print(f"Files per process: {[len(args[0]) for args in process_args]}") | |
| # Process shards in parallel | |
| start_time = time.time() | |
| with mp.Pool(processes=len(process_args)) as pool: | |
| results = pool.map(process_multiple_shards, process_args) | |
| total_elapsed = time.time() - start_time | |
| # Aggregate results | |
| total_samples = sum(r['total_samples'] for r in results if 'error' not in r) | |
| total_shards = sum(r['processed_shards'] for r in results if 'error' not in r) | |
| # Aggregate resolution counts from all processes | |
| total_counts = [0] * len(resolution_list) | |
| for r in results: | |
| if 'error' not in r: | |
| for i, count in enumerate(r['resolution_counts']): | |
| total_counts[i] += count | |
| errors = [r for r in results if 'error' in r] | |
| # Print statistics in the requested format | |
| print(f"\nResolution statistics:") | |
| print("=" * 40) | |
| for min_size, count in zip(resolution_list, total_counts): | |
| print(f">= {min_size}x{min_size} {count:_}") | |
| print(f"\nSummary:") | |
| print(f"Total files processed: {total_shards}") | |
| print(f"Total samples processed: {total_samples:,}") | |
| print(f"Total time: {total_elapsed:.2f} seconds") | |
| print(f"Throughput: {total_samples/total_elapsed:.0f} samples/second") | |
| if errors: | |
| print(f"\nErrors encountered: {len(errors)}") | |
| for error in errors[:5]: # Show first 5 errors | |
| print(f" Process {error['process_id']}: {error['error']}") | |


def main():
    parser = argparse.ArgumentParser(description="Compute resolution statistics for a WebDataset")
    parser.add_argument("--input_pattern",
                        default="/p/data1/mmlaion/datacomp/datacomp_1B/flat/*.tar",
                        help="Input WebDataset pattern")
    parser.add_argument("--resolutions",
                        type=int,
                        nargs='+',
                        default=[256, 384, 448, 512, 640, 784, 896, 1024],
                        help="List of resolution thresholds")
    parser.add_argument("--num_processes", type=int, default=None,
                        help="Number of parallel processes (default: CPU count)")
    args = parser.parse_args()
    # Sort resolutions to ensure a consistent output order
    resolution_list = sorted(args.resolutions)
    print(f"Computing resolution statistics for thresholds: {resolution_list}")
    print(f"Input pattern: {args.input_pattern}")
    process_webdataset_parallel(args.input_pattern, resolution_list, args.num_processes)


if __name__ == "__main__":
    main()
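
# Example invocation (the script name, pattern, and process count below are
# illustrative, not taken from the source):
#
#   python resolution_stats.py \
#       --input_pattern "/path/to/dataset/*.tar" \
#       --resolutions 256 512 1024 \
#       --num_processes 32
#
# This prints one ">= NxN <count>" line per threshold, followed by a summary
# with totals and throughput.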