#!/usr/bin/env python3
"""Filter DataComp 1B WebDataset shards by minimum image resolution, in parallel.

Each worker process reads a disjoint subset of the input tar shards, keeps only
samples whose image is at least --min_size pixels on both sides, and rewrites
them as new shards in --output_dir.
"""
import os
import argparse
import webdataset as wds
from PIL import Image
import io
import glob
import multiprocessing as mp
import time
def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    print(f'Handling webdataset error ({repr(exn)}). Ignoring.')
    return True
def filter_by_resolution(sample, min_size=512):
    """Return True if the sample's image is at least min_size pixels on both sides."""
    try:
        # Get the image bytes from the sample (support jpg, jpeg, webp, png)
        image_data = None
        for ext in ['jpg', 'jpeg', 'webp', 'png']:
            if ext in sample:
                image_data = sample[ext]
                break
        if image_data is None:
            return False
        image = Image.open(io.BytesIO(image_data))
        width, height = image.size
        # Keep the sample only if both dimensions meet the minimum
        return width >= min_size and height >= min_size
    except Exception:
        # Treat undecodable images as failing the filter
        return False
def process_multiple_shards(args):
    """Process a subset of tar shards in a single worker process."""
    shard_paths, output_dir, min_size, process_id = args
    # Each process writes to its own shard pattern to avoid filename collisions
    output_pattern = os.path.join(output_dir, f"filtered_process_{process_id}_%06d.tar")
    total_samples = 0
    filtered_samples = 0
    start_time = time.time()
    try:
        sink = wds.ShardWriter(output_pattern, maxcount=10000)
        dataset = wds.DataPipeline(
            wds.SimpleShardList(shard_paths),
            wds.tarfile_to_samples(handler=log_and_continue),
        )
        print(f"Process {process_id}: starting on {len(shard_paths)} shards")
        for sample in dataset:
            total_samples += 1
            # Keep only samples whose image passes the resolution filter
            if filter_by_resolution(sample, min_size):
                filtered_samples += 1
                output_sample = {"__key__": sample["__key__"]}
                # Carry over the caption and metadata when present
                for ext in ['txt', 'json']:
                    if ext in sample:
                        output_sample[ext] = sample[ext]
                # Preserve the image bytes in their original format
                for ext in ['jpg', 'jpeg', 'webp', 'png']:
                    if ext in sample:
                        output_sample[ext] = sample[ext]
                        break
                sink.write(output_sample)
            # Periodic progress report rather than one line per sample
            if total_samples % 10000 == 0:
                print(f"Process {process_id}: {total_samples} samples, kept {filtered_samples}")
        sink.close()
    except Exception as e:
        # Report the failure to the parent instead of crashing the pool
        return {'process_id': process_id, 'error': str(e)}
    elapsed_time = time.time() - start_time
    return {
        'process_id': process_id,
        'processed_shards': len(shard_paths),
        'total_samples': total_samples,
        'filtered_samples': filtered_samples,
        'elapsed_time': elapsed_time,
    }
def process_webdataset_parallel(input_pattern, output_dir, min_size=512, num_processes=None):
    """Filter WebDataset tar files in parallel across worker processes."""
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    # Collect all input shards
    tar_files = sorted(glob.glob(input_pattern))
    if not tar_files:
        print(f"No tar files found matching pattern: {input_pattern}")
        return
    print(f"Found {len(tar_files)} tar files to process")
    # Set number of processes
    if num_processes is None:
        num_processes = mp.cpu_count()
    print(f"Using {num_processes} processes")
    # Distribute tar files across processes as evenly as possible
    files_per_process = len(tar_files) // num_processes
    remainder = len(tar_files) % num_processes
    process_args = []
    start_idx = 0
    for i in range(num_processes):
        end_idx = start_idx + files_per_process
        if i < remainder:  # The first `remainder` processes take one extra file
            end_idx += 1
        process_files = tar_files[start_idx:end_idx]
        if process_files:  # Only spawn a process if it has files to work on
            process_args.append((process_files, output_dir, min_size, i))
        start_idx = end_idx
    print(f"Files per process: {[len(args[0]) for args in process_args]}")
    # Run the workers in parallel
    start_time = time.time()
    with mp.Pool(processes=len(process_args)) as pool:
        results = pool.map(process_multiple_shards, process_args)
    total_elapsed = time.time() - start_time
    # Aggregate results, skipping workers that failed
    total_samples = sum(r['total_samples'] for r in results if 'error' not in r)
    total_filtered = sum(r['filtered_samples'] for r in results if 'error' not in r)
    total_shards = sum(r['processed_shards'] for r in results if 'error' not in r)
    errors = [r for r in results if 'error' in r]
    print("\nParallel filtering complete:")
    print(f"Total files processed: {total_shards}")
    print(f"Total samples processed: {total_samples:,}")
    print(f"Samples kept (>= {min_size}x{min_size}): {total_filtered:,}")
    if total_samples > 0:
        print(f"Filtering rate: {total_filtered / total_samples * 100:.2f}%")
        print(f"Throughput: {total_samples / total_elapsed:.0f} samples/second")
    print(f"Total time: {total_elapsed:.2f} seconds")
    print(f"Filtered samples written to: {output_dir}")
    if errors:
        print(f"\nErrors encountered: {len(errors)}")
        for error in errors[:5]:  # Show the first 5 errors
            print(f"  Process {error['process_id']}: {error['error']}")
def main():
    parser = argparse.ArgumentParser(description="Filter DataComp 1B by image resolution (parallel)")
    parser.add_argument("--input_pattern",
                        default="/p/data1/mmlaion/datacomp/datacomp_1B/flat/*.tar",
                        help="Input WebDataset pattern")
    parser.add_argument("--output_dir",
                        default="/p/data1/mmlaion/datacomp/datacomp_512x512/filtered",
                        help="Output directory for filtered data")
    parser.add_argument("--min_size", type=int, default=512,
                        help="Minimum image size (width and height)")
    parser.add_argument("--num_processes", type=int, default=None,
                        help="Number of parallel processes (default: CPU count)")
    args = parser.parse_args()
    print(f"Filtering DataComp 1B with minimum size: {args.min_size}x{args.min_size}")
    print(f"Input pattern: {args.input_pattern}")
    print(f"Output directory: {args.output_dir}")
    process_webdataset_parallel(args.input_pattern, args.output_dir,
                                args.min_size, args.num_processes)

if __name__ == "__main__":
    main()
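
For reference, a minimal sketch of reading the filtered shards back with webdataset. The glob path assumes the script's default --output_dir, the shard names follow the filtered_process_{id}_%06d.tar pattern written by ShardWriter above, and the to_tuple call assumes every filtered sample carries both a txt caption and a json sidecar; adjust these to your own setup.

# Sketch: iterate over the filtered shards produced by the script above.
import glob
import webdataset as wds

shards = sorted(glob.glob(
    "/p/data1/mmlaion/datacomp/datacomp_512x512/filtered/filtered_process_*.tar"))
dataset = wds.DataPipeline(
    wds.SimpleShardList(shards),
    wds.tarfile_to_samples(),
    wds.decode("pil"),  # decode image bytes into PIL images, txt into str, json into dict
    wds.to_tuple("jpg;jpeg;webp;png", "txt", "json"),
)
for image, caption, metadata in dataset:
    print(image.size, caption[:60])
    break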