@mehdidc
Created November 8, 2025 07:52
#!/usr/bin/env python3
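"""Filter a WebDataset (here: DataComp 1B tar shards) by minimum image resolution.

Tar shards are split across worker processes; each worker re-writes the samples
that pass the size check into its own set of output shards.
"""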
import os
import argparse
import webdataset as wds
from PIL import Image
import io
import glob
import multiprocessing as mp
import time

def log_and_continue(exn):
"""Call in an exception handler to ignore any exception, issue a warning, and continue."""
print(f'Handling webdataset error ({repr(exn)}). Ignoring.')
return True

def filter_by_resolution(sample, min_size=512):
"""Filter samples by minimum image resolution."""
try:
        # Get the raw image bytes from the sample (jpg, jpeg, webp, or png).
image_data = None
for ext in ['jpg', 'jpeg', 'webp', 'png']:
if ext in sample:
image_data = sample[ext]
break
if image_data is None:
return False
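        # PIL opens images lazily: .size comes from the header, so the full
        # image is never decoded just to check its resolution.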
image = Image.open(io.BytesIO(image_data))
width, height = image.size
        # Keep the sample only if both dimensions meet the minimum size.
        return width >= min_size and height >= min_size
    except Exception:
        # Treat unreadable or corrupted images as failing the filter.
        return False

def process_multiple_shards(args):
"""Process multiple tar file shards in a single process."""
shard_paths, output_dir, min_size, process_id = args
# Create process-specific output pattern
output_pattern = os.path.join(output_dir, f"filtered_process_{process_id}_%06d.tar")
# Statistics
total_samples = 0
filtered_samples = 0
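    # Each worker writes its own output shards; ShardWriter rolls over to a new
    # tar file once the current one holds `maxcount` samples.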
sink = wds.ShardWriter(output_pattern, maxcount=10000)
start_time = time.time()
dataset = wds.DataPipeline(
wds.SimpleShardList(shard_paths),
wds.tarfile_to_samples(handler=log_and_continue)
)
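    # The pipeline streams every tar assigned to this worker and groups files
    # sharing a basename into one dict per sample (keys are the extensions plus
    # "__key__"); log_and_continue skips entries that fail to parse.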
    print(f"Process {process_id}: filtering {len(shard_paths)} shards")
for sample in dataset:
total_samples += 1
# Filter by resolution
if filter_by_resolution(sample, min_size):
filtered_samples += 1
# Write filtered sample to output
key = sample["__key__"]
#key = f"sample_{process_id}_{total_samples:06d}"
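            # Assumes every sample carries a caption ("txt") and metadata
            # ("json"); a sample missing either key would raise KeyError here.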
output_sample = {
"__key__": key,
"txt": sample["txt"],
"json": sample["json"]
}
# Add image data (preserve original format)
for ext in ['jpg', 'jpeg', 'webp', 'png']:
if ext in sample:
output_sample[ext] = sample[ext]
break
sink.write(output_sample)
print(f"Process {process_id}: {total_samples} samples, kept {filtered_samples}")
sink.close()
elapsed_time = time.time() - start_time
return {
'process_id': process_id,
'processed_shards': len(shard_paths),
'total_samples': total_samples,
'filtered_samples': filtered_samples,
'elapsed_time': elapsed_time
}

def process_webdataset_parallel(input_pattern, output_dir, min_size=512, num_processes=None):
"""Process WebDataset tar files in parallel."""
# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)
# Get list of all tar files
tar_files = sorted(glob.glob(input_pattern))
if not tar_files:
print(f"No tar files found matching pattern: {input_pattern}")
return
print(f"Found {len(tar_files)} tar files to process")
# Set number of processes
if num_processes is None:
num_processes = mp.cpu_count()
print(f"Using {num_processes} processes")
# Distribute tar files across processes
files_per_process = len(tar_files) // num_processes
remainder = len(tar_files) % num_processes
process_args = []
start_idx = 0
for i in range(num_processes):
end_idx = start_idx + files_per_process
if i < remainder: # Distribute remainder files
end_idx += 1
process_files = tar_files[start_idx:end_idx]
if process_files: # Only create process if it has files to process
process_args.append((process_files, output_dir, min_size, i))
start_idx = end_idx
print(f"Files per process: {[len(args[0]) for args in process_args]}")
# Process shards in parallel
start_time = time.time()
with mp.Pool(processes=len(process_args)) as pool:
results = pool.map(process_multiple_shards, process_args)
total_elapsed = time.time() - start_time
# Aggregate results
total_samples = sum(r['total_samples'] for r in results if 'error' not in r)
total_filtered = sum(r['filtered_samples'] for r in results if 'error' not in r)
total_shards = sum(r['processed_shards'] for r in results if 'error' not in r)
errors = [r for r in results if 'error' in r]
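    # Note: process_multiple_shards never returns an 'error' entry at the
    # moment; an unhandled exception in a worker would instead propagate out
    # of pool.map and abort the run.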
print(f"\nParallel filtering complete:")
print(f"Total files processed: {total_shards}")
print(f"Total samples processed: {total_samples:,}")
print(f"Samples kept (>= {min_size}x{min_size}): {total_filtered:,}")
    if total_samples > 0:
        print(f"Keep rate: {total_filtered/total_samples*100:.2f}%")
    print(f"Total time: {total_elapsed:.2f} seconds")
    print(f"Throughput: {total_samples/total_elapsed:.0f} samples/second")
print(f"Filtered samples written to: {output_dir}")
if errors:
print(f"\nErrors encountered: {len(errors)}")
for error in errors[:5]: # Show first 5 errors
print(f" Process {error['process_id']}: {error['error']}")

def main():
parser = argparse.ArgumentParser(description="Filter DataComp 1B by image resolution (parallel)")
parser.add_argument("--input_pattern",
default="/p/data1/mmlaion/datacomp/datacomp_1B/flat/*.tar",
help="Input WebDataset pattern")
parser.add_argument("--output_dir",
default="/p/data1/mmlaion/datacomp/datacomp_512x512/filtered",
help="Output directory for filtered data")
parser.add_argument("--min_size", type=int, default=512,
help="Minimum image size (width and height)")
parser.add_argument("--num_processes", type=int, default=None,
help="Number of parallel processes (default: CPU count)")
args = parser.parse_args()
print(f"Filtering DataComp 1B with minimum size: {args.min_size}x{args.min_size}")
print(f"Input pattern: {args.input_pattern}")
print(f"Output directory: {args.output_dir}")
process_webdataset_parallel(args.input_pattern, args.output_dir,
args.min_size, args.num_processes)

if __name__ == "__main__":
main()
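# Example invocation (the script filename is a placeholder; the paths are the
# argparse defaults above, and 64 processes is just an illustrative choice):
#
#   python filter_by_resolution.py \
#       --input_pattern "/p/data1/mmlaion/datacomp/datacomp_1B/flat/*.tar" \
#       --output_dir /p/data1/mmlaion/datacomp/datacomp_512x512/filtered \
#       --min_size 512 \
#       --num_processes 64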