@mehdidc
Created November 8, 2025 18:26
#!/usr/bin/env python3
"""
Compute resolution statistics for WebDataset.
This script processes a WebDataset in parallel and counts how many samples
meet various resolution thresholds. It outputs statistics in the format:
>= 256x256 1_055_309_295
>= 384x384 698_616_282
...
Based on a WebDataset filtering script, but modified to only compute statistics.
"""
import argparse
import glob
import io
import multiprocessing as mp
import time

import webdataset as wds
from PIL import Image

def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    print(f'Handling webdataset error ({repr(exn)}). Ignoring.')
    return True

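# Per webdataset's handler convention, a truthy return value means "swallow the
# error and keep iterating"; tarfile_to_samples(handler=log_and_continue) below
# relies on this to skip corrupt samples instead of aborting the whole shard.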
def get_image_resolution(sample):
    """Extract image resolution from sample. Returns (width, height) or None."""
    try:
        # Get image from sample (support jpg, webp, png)
        image_data = None
        for ext in ['jpg', 'jpeg', 'webp', 'png']:
            if ext in sample:
                image_data = sample[ext]
                break
        if image_data is None:
            return None
        image = Image.open(io.BytesIO(image_data))
        return image.size  # (width, height)
    except Exception:
        return None

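# Image.open() only parses the header before .size is read, so this check does
# not decode full pixel data. Quick sanity check (illustrative; "example.jpg"
# is a placeholder path):
#
#   with open("example.jpg", "rb") as f:
#       print(get_image_resolution({"jpg": f.read()}))  # e.g. (1920, 1080)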
def process_multiple_shards(args):
    """Process multiple tar file shards in a single process and count resolutions."""
    shard_paths, resolution_list, process_id = args
    try:
        # Initialize counters for each resolution threshold
        resolution_counts = [0] * len(resolution_list)
        total_samples = 0
        dataset = wds.DataPipeline(
            wds.SimpleShardList(shard_paths),
            wds.tarfile_to_samples(handler=log_and_continue),
        )
        print(f"Process {process_id} started")
        start_time = time.time()
        for sample in dataset:
            total_samples += 1
            # Get image resolution
            resolution = get_image_resolution(sample)
            if resolution is None:
                continue
            width, height = resolution
            # Check against each threshold and increment counters
            for i, min_size in enumerate(resolution_list):
                if width >= min_size and height >= min_size:
                    resolution_counts[i] += 1
        elapsed_time = time.time() - start_time
        print(f"Process {process_id}: {total_samples} samples processed")
        return {
            'process_id': process_id,
            'processed_shards': len(shard_paths),
            'total_samples': total_samples,
            'resolution_counts': resolution_counts,
            'elapsed_time': elapsed_time,
        }
    except Exception as e:
        return {
            'process_id': process_id,
            'error': str(e),
        }

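# Each worker builds an independent DataPipeline over its own shard subset, so
# no state is shared between processes; results come back as plain dicts that
# the parent aggregates once pool.map() returns. Since main() sorts the
# thresholds ascending, the inner threshold loop could also break at the first
# failed comparison (a possible micro-optimization, omitted here for clarity).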
def process_webdataset_parallel(input_pattern, resolution_list, num_processes=None):
    """Process WebDataset tar files in parallel and compute resolution statistics."""
    # Get list of all tar files
    tar_files = sorted(glob.glob(input_pattern))
    if not tar_files:
        print(f"No tar files found matching pattern: {input_pattern}")
        return
    print(f"Found {len(tar_files)} tar files to process")
    # Set number of processes
    if num_processes is None:
        num_processes = mp.cpu_count()
    print(f"Using {num_processes} processes")
    # Distribute tar files across processes
    files_per_process = len(tar_files) // num_processes
    remainder = len(tar_files) % num_processes
    process_args = []
    start_idx = 0
    for i in range(num_processes):
        end_idx = start_idx + files_per_process
        if i < remainder:  # Distribute remainder files
            end_idx += 1
        process_files = tar_files[start_idx:end_idx]
        if process_files:  # Only create process if it has files to process
            process_args.append((process_files, resolution_list, i))
        start_idx = end_idx
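    # Worked example: 10 tar files over 4 processes gives files_per_process=2
    # and remainder=2, so the first two workers get 3 files each and the rest
    # get 2, i.e. a [3, 3, 2, 2] split.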
print(f"Files per process: {[len(args[0]) for args in process_args]}")
# Process shards in parallel
start_time = time.time()
with mp.Pool(processes=len(process_args)) as pool:
results = pool.map(process_multiple_shards, process_args)
total_elapsed = time.time() - start_time
# Aggregate results
total_samples = sum(r['total_samples'] for r in results if 'error' not in r)
total_shards = sum(r['processed_shards'] for r in results if 'error' not in r)
# Aggregate resolution counts from all processes
total_counts = [0] * len(resolution_list)
for r in results:
if 'error' not in r:
for i, count in enumerate(r['resolution_counts']):
total_counts[i] += count
errors = [r for r in results if 'error' in r]
# Print statistics in the requested format
print(f"\nResolution statistics:")
print("=" * 40)
for min_size, count in zip(resolution_list, total_counts):
print(f">= {min_size}x{min_size} {count:_}")
print(f"\nSummary:")
print(f"Total files processed: {total_shards}")
print(f"Total samples processed: {total_samples:,}")
print(f"Total time: {total_elapsed:.2f} seconds")
print(f"Throughput: {total_samples/total_elapsed:.0f} samples/second")
if errors:
print(f"\nErrors encountered: {len(errors)}")
for error in errors[:5]: # Show first 5 errors
print(f" Process {error['process_id']}: {error['error']}")
def main():
    parser = argparse.ArgumentParser(description="Compute resolution statistics for WebDataset")
    parser.add_argument("--input_pattern",
                        default="/p/data1/mmlaion/datacomp/datacomp_1B/flat/*.tar",
                        help="Input WebDataset pattern")
    parser.add_argument("--resolutions",
                        type=int,
                        nargs='+',
                        default=[256, 384, 448, 512, 640, 784, 896, 1024],
                        help="List of resolution thresholds")
    parser.add_argument("--num_processes", type=int, default=None,
                        help="Number of parallel processes (default: CPU count)")
    args = parser.parse_args()
    # Sort resolutions to ensure consistent output order
    resolution_list = sorted(args.resolutions)
    print(f"Computing resolution statistics for thresholds: {resolution_list}")
    print(f"Input pattern: {args.input_pattern}")
    process_webdataset_parallel(args.input_pattern, resolution_list, args.num_processes)


if __name__ == "__main__":
    main()