Skip to content

Instantly share code, notes, and snippets.

@richardliaw
Created March 10, 2025 21:51
Show Gist options
  • Save richardliaw/87b1f03e1440866abe86a84fee98b118 to your computer and use it in GitHub Desktop.
Save richardliaw/87b1f03e1440866abe86a84fee98b118 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
"""
dataloader_benchmark.py - A script to benchmark Ray Data performance on image datasets.
"""
import ray
import argparse
import os
import time
import torch
import torch.utils.data
# import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tabulate import tabulate
def _int_at_least(minimum):
    """Return an argparse ``type`` converter that accepts integers >= ``minimum``."""
    def convert(text):
        value = int(text)
        if value < minimum:
            raise argparse.ArgumentTypeError(
                f"must be an integer >= {minimum}, got {value}")
        return value
    # argparse uses the converter's __name__ in "invalid <name> value" errors.
    convert.__name__ = 'int'
    return convert


def parse_args(argv=None):
    """Parse the benchmark's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, so existing ``parse_args()`` calls behave as before.

    Returns:
        argparse.Namespace with ``data_path``, ``batch_sizes``,
        ``warmup_batches``, ``measure_batches``, ``image_size`` and
        ``prefetch_factor``.

    Out-of-range numeric values (e.g. a batch size of 0) are rejected by
    argparse with exit status 2 instead of failing later inside Ray.
    """
    positive = _int_at_least(1)
    non_negative = _int_at_least(0)

    parser = argparse.ArgumentParser(description='PyTorch DataLoader Benchmark')
    parser.add_argument('data_path', type=str, help='Path to the image folder')
    parser.add_argument('--batch-sizes', type=positive, nargs='+', default=[32, 64, 128, 256],
                        help='Batch sizes to test (default: [32, 64, 128, 256])')
    parser.add_argument('--warmup-batches', type=non_negative, default=100,
                        help='Number of warmup batches before measuring (default: 100)')
    parser.add_argument('--measure-batches', type=positive, default=500,
                        help='Number of batches to measure (default: 500)')
    parser.add_argument('--image-size', type=positive, default=224,
                        help='Size to resize images to (default: 224)')
    # The value is forwarded to Ray Data's iter_torch_batches(prefetch_batches=...).
    parser.add_argument('--prefetch-factor', type=non_negative, default=2,
                        help='Number of batches to prefetch ahead of the consumer '
                             '(passed to Ray Data as prefetch_batches; default: 2)')
    return parser.parse_args(argv)
def get_transforms(image_size):
    """Build the per-image preprocessing pipeline used by the benchmark.

    The returned ``transforms.Compose`` runs, in order: ``ToTensor``, a
    ``RandomResizedCrop`` to ``image_size``, a ``RandomHorizontalFlip``, and a
    per-channel ``Normalize``.
    """
    # Per-channel (R, G, B) statistics: the standard ImageNet normalization
    # values used throughout the torchvision examples.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    pipeline = [
        # Converted to a tensor first so the crop/flip below operate on a
        # tensor (presumably because the rows arrive as NumPy arrays — confirm).
        transforms.ToTensor(),
        transforms.RandomResizedCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline)
def measure_dataloader_performance(data_path, batch_size, warmup_batches,
                                   measure_batches, image_size,
                                   prefetch_factor):
    """Benchmark Ray Data image loading for a single batch-size configuration.

    Reads images under ``data_path`` with ``ray.data.read_images``, maps
    ``get_transforms(image_size)`` over each row, then iterates the dataset with
    ``iter_torch_batches``: first ``warmup_batches`` untimed batches, then up to
    ``measure_batches`` timed batches from a newly created iterator.

    Args:
        data_path: Path handed to ``ray.data.read_images``.
        batch_size: Rows per batch requested from Ray Data.
        warmup_batches: Batches consumed before timing starts (0 skips warmup).
        measure_batches: Maximum number of batches to time (at least one batch
            is timed if any is available).
        image_size: Crop size for ``get_transforms``; also used for the MB/s
            estimate, which assumes 3 float32 channels of image_size x image_size.
        prefetch_factor: Forwarded to ``iter_torch_batches(prefetch_batches=...)``.

    Returns:
        Dict with keys ``batch_size``, ``total_time``, ``avg_samples_per_sec``,
        ``avg_batch_time`` (ms), ``samples_per_batch``, ``batches_per_sec`` and
        ``throughput_mbs``; or ``None`` if the dataset could not be built or no
        batches were produced.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"

    # Build the (Ray Data) dataset and attach the per-row transform.
    try:
        ds = ray.data.read_images(data_path)
        transform = get_transforms(image_size)

        def test_transform(x):
            try:
                transformed = transform(x["image"])
            except Exception:
                # Deliberate best-effort fallback: rows the transform rejects
                # are replaced by an all-zero 500x500x3 uint8 image so the
                # benchmark keeps running. (The previous
                # np.random.rand(...).astype(np.uint8) was also all zeros,
                # because rand() returns values in [0, 1).)
                import numpy as np
                transformed = transform(np.zeros((500, 500, 3), dtype=np.uint8))
            return {"image": transformed}

        ds = ds.map(test_transform)
    except Exception as e:
        print(f"Error loading dataset from {data_path}: {e}")
        return None

    # Size of one transformed sample in MB: 3 channels x float32 (4 bytes) x H x W.
    image_size_mb = 3 * 4 * image_size * image_size / (1024 * 1024)
    print(f"\nConfiguration: batch_size={batch_size}")

    def make_batch_iter():
        # Batches are already placed on `device` by Ray Data via device=...,
        # so no extra .to(device) is needed by the consumer.
        return ds.iter_torch_batches(
            batch_size=batch_size,
            prefetch_batches=prefetch_factor,
            drop_last=False,
            device=device,
        )

    # Warmup (untimed except for the overall warmup duration).
    print("Running warmup batches...")
    data_iter = make_batch_iter()
    start_time = time.time()
    if warmup_batches > 0:
        warmup_progress_every = max(1, warmup_batches // 10)
        warmup_count = 0
        for _batch in data_iter:
            warmup_count += 1
            if warmup_count >= warmup_batches:
                break
            if warmup_count % warmup_progress_every == 0:
                print(f"Warmup batch {warmup_count}/{warmup_batches}")
    warmup_time = time.time() - start_time
    print(f"Warmup time: {warmup_time:.2f}s")

    # Measurement on a new iterator.
    print("Measuring performance...")
    data_iter = make_batch_iter()
    measure_progress_every = max(1, measure_batches // 10)
    batch_times = []
    total_samples_processed = 0
    batch_start = time.time()
    for batch in data_iter:
        images = batch["image"]
        if use_cuda:
            # CUDA copies/kernels run asynchronously; wait for them so the
            # recorded time actually includes getting the batch onto the GPU.
            torch.cuda.synchronize(device)
        batch_times.append(time.time() - batch_start)
        # drop_last=False: the final batch may hold fewer than batch_size rows.
        total_samples_processed += len(images)
        if len(batch_times) >= measure_batches:
            break
        if len(batch_times) % measure_progress_every == 0:
            print(f"Measuring batch {len(batch_times)}/{measure_batches}")
        batch_start = time.time()

    batch_count = len(batch_times)
    total_time = sum(batch_times)
    if batch_count == 0 or total_time <= 0:
        # Avoid ZeroDivisionError on an empty dataset; main() skips None results.
        print(f"No batches were produced from {data_path}; skipping batch_size={batch_size}.")
        return None

    avg_samples_per_sec = total_samples_processed / total_time
    avg_batch_time = total_time / batch_count
    throughput_mbs = avg_samples_per_sec * image_size_mb
    print(f"Time: {total_time:.2f}s, Throughput: {avg_samples_per_sec:.2f} samples/sec ({throughput_mbs:.2f} MB/s)")
    return {
        'batch_size': batch_size,
        'total_time': total_time,
        'avg_samples_per_sec': avg_samples_per_sec,
        'avg_batch_time': avg_batch_time * 1000,  # Convert to ms
        'samples_per_batch': batch_size,
        'batches_per_sec': 1.0 / avg_batch_time,
        'throughput_mbs': throughput_mbs,
    }
def main():
    """Run the benchmark for every requested batch size and print a summary.

    Parses the CLI arguments and calls ``measure_dataloader_performance`` once
    per batch size. It then prints a grid table of the successful runs (largest
    batch size first), followed by the details of the largest-batch-size run.
    It returns early with a message if a local data path does not exist or if
    no configuration produced results.
    """
    args = parse_args()

    # os.path.exists only understands local paths; remote URIs such as
    # "s3://bucket/prefix" are passed through to Ray Data unchecked.
    is_remote_uri = "://" in args.data_path
    if not is_remote_uri and not os.path.exists(args.data_path):
        print(f"Error: Data path '{args.data_path}' does not exist.")
        return

    has_cuda = torch.cuda.is_available()
    print(f"DataLoader Benchmark on: {args.data_path}")
    print(f"Device: {'GPU' if has_cuda else 'CPU'}")
    if has_cuda:
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"PyTorch version: {torch.__version__}")

    # One benchmark run per batch size; failed/empty runs return None.
    results = []
    for batch_size in args.batch_sizes:
        result = measure_dataloader_performance(
            args.data_path, batch_size, args.warmup_batches,
            args.measure_batches, args.image_size,
            args.prefetch_factor,
        )
        if result:
            results.append(result)

    if not results:
        # Without this guard results[0] below would raise IndexError.
        print("\nNo benchmark configuration completed successfully; no results to report.")
        return

    # Largest batch size first.
    results.sort(key=lambda r: r['batch_size'], reverse=True)

    headers = [
        'Batch Size', 'Total Time (s)',
        'Throughput (samples/s)', 'Throughput (MB/s)', 'Batch Time (ms)', 'Batches/sec',
    ]
    table_data = [
        [
            r['batch_size'], f"{r['total_time']:.2f}",
            f"{r['avg_samples_per_sec']:.2f}", f"{r['throughput_mbs']:.2f}",
            f"{r['avg_batch_time']:.2f}", f"{r['batches_per_sec']:.2f}",
        ]
        for r in results
    ]
    print("\nResults (sorted by batch size):")
    print(tabulate(table_data, headers=headers, tablefmt='grid'))

    # Details for the largest batch size (not necessarily the fastest run).
    largest = results[0]
    print("\nLargest batch size configuration:")
    print(f"Batch size: {largest['batch_size']}")
    print(f"Throughput: {largest['avg_samples_per_sec']:.2f} samples/sec ({largest['throughput_mbs']:.2f} MB/s)")
    print(f"Average batch time: {largest['avg_batch_time']:.2f} ms")
    print(f"Total time: {largest['total_time']:.2f} s")
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment