# Created March 10, 2025 21:51
# Save richardliaw/87b1f03e1440866abe86a84fee98b118 to your computer and use it in GitHub Desktop.
# This file contains bidirectional Unicode text that may be interpreted or compiled
# differently than what appears below. To review, open the file in an editor that
# reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
#!/usr/bin/env python3
"""
dataloader_benchmark.py - A script to benchmark Ray Data performance on image datasets.
"""
import argparse
import os
import time

import numpy as np
import ray
import torch
import torch.utils.data
# import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tabulate import tabulate
def _int_at_least(minimum):
    """Return an argparse ``type=`` callable that accepts integers >= *minimum*."""
    def convert(text):
        try:
            value = int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(f"invalid integer value: {text!r}") from None
        if value < minimum:
            raise argparse.ArgumentTypeError(f"must be >= {minimum}, got {value}")
        return value
    return convert


def parse_args(argv=None):
    """Parse the benchmark's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        An ``argparse.Namespace`` with ``data_path``, ``batch_sizes``,
        ``warmup_batches``, ``measure_batches``, ``image_size`` and
        ``prefetch_factor``.

    Counts are validated up front so a misconfigured run fails fast instead of
    measuring nothing. Batch sizes, measured batches and image size must be
    >= 1. Warmup batches and prefetch depth may be 0.
    """
    parser = argparse.ArgumentParser(description='PyTorch DataLoader Benchmark')
    parser.add_argument('data_path', type=str, help='Path to the image folder')
    parser.add_argument('--batch-sizes', type=_int_at_least(1), nargs='+',
                        default=[32, 64, 128, 256],
                        help='Batch sizes to test (default: [32, 64, 128, 256])')
    parser.add_argument('--warmup-batches', type=_int_at_least(0), default=100,
                        help='Number of warmup batches before measuring (default: 100)')
    parser.add_argument('--measure-batches', type=_int_at_least(1), default=500,
                        help='Number of batches to measure (default: 500)')
    parser.add_argument('--image-size', type=_int_at_least(1), default=224,
                        help='Size to resize images to (default: 224)')
    # The value is forwarded to Ray Data's iter_torch_batches(prefetch_batches=...),
    # which prefetches ahead of the current batch; it is not a per-worker setting.
    parser.add_argument('--prefetch-factor', type=_int_at_least(0), default=2,
                        help='Number of batches Ray Data prefetches ahead of the '
                             'current batch (default: 2)')
    return parser.parse_args(argv)
def get_transforms(image_size):
    """Build the augmentation pipeline applied to every image.

    The pipeline runs these steps in order:
    1. Convert the input to a tensor.
    2. Randomly crop and resize it to ``image_size`` x ``image_size``.
    3. Flip it horizontally at random.
    4. Normalize each channel with fixed mean/std values (the ones commonly
       used for ImageNet).

    ``ToTensor`` comes first, so the crop and flip operate on a tensor rather
    than a PIL image.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        transforms.ToTensor(),
        transforms.RandomResizedCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline)
def _build_dataset(data_path, image_size):
    """Return a Ray Dataset that reads images from *data_path* and augments each row.

    The per-row transform runs only when the dataset is iterated. Exceptions
    raised while building the dataset therefore come from ``read_images`` or
    ``get_transforms``, not from the transform itself.
    """
    ds = ray.data.read_images(data_path)
    transform = get_transforms(image_size)

    def apply_transform(row):
        try:
            image = transform(row["image"])
        except Exception:
            # Best-effort fallback: if the transform rejects an image, substitute a
            # synthetic 500x500 RGB uint8 image so the benchmark keeps running.
            # NOTE(review): presumably meant for grayscale/RGBA inputs that do not
            # fit the 3-channel Normalize. Failures are not reported, so a dataset
            # full of unreadable images would still produce numbers.
            # randint yields real pixel values; rand(...).astype(uint8) would
            # truncate every value in [0, 1) to 0.
            placeholder = np.random.randint(0, 256, size=(500, 500, 3), dtype=np.uint8)
            image = transform(placeholder)
        return {"image": image}

    return ds.map(apply_transform)


def _sync_device(device):
    """Block until all queued CUDA work on *device* has finished. No-op on CPU.

    Kernel launches and non_blocking copies are asynchronous. Without this, the
    timer could stop before a batch is actually ready on the GPU.
    ``torch.cuda.synchronize`` waits on every stream, including any in-flight
    prefetch copies, so timings err on the conservative side.
    """
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def _time_batches(batches, limit, device, label):
    """Consume up to *limit* batches from the iterator *batches*, timing each one.

    Args:
        batches: Iterator yielding dicts that contain an ``"image"`` entry.
        limit: Maximum number of batches to consume (0 consumes nothing).
        device: Torch device holding the batches. It is synchronized before each
            timestamp.
        label: Prefix used in progress messages.

    Returns:
        A list of ``(seconds, num_samples)`` tuples, one per consumed batch. The
        list is shorter than *limit* if the iterator runs out first.
    """
    report_every = max(1, limit // 10)
    records = []
    tick = time.perf_counter()
    while len(records) < limit:
        batch = next(batches, None)
        if batch is None:  # dataset exhausted
            break
        images = batch["image"]
        _sync_device(device)
        records.append((time.perf_counter() - tick, len(images)))
        if len(records) % report_every == 0:
            print(f"{label} batch {len(records)}/{limit}")
        # Restart the clock after printing so console I/O is not counted.
        tick = time.perf_counter()
    return records


def measure_dataloader_performance(data_path, batch_size, warmup_batches,
                                   measure_batches, image_size,
                                   prefetch_factor):
    """Benchmark Ray Data image loading for one batch-size configuration.

    Args:
        data_path: Path passed to ``ray.data.read_images``.
        batch_size: Number of samples per batch.
        warmup_batches: Batches consumed before timing starts (may be 0).
        measure_batches: Maximum number of batches to time.
        image_size: Output side length of ``RandomResizedCrop``.
        prefetch_factor: Forwarded to ``iter_torch_batches(prefetch_batches=...)``.

    Returns:
        A dict with the keys ``batch_size``, ``total_time`` (s),
        ``avg_samples_per_sec``, ``avg_batch_time`` (ms), ``samples_per_batch``,
        ``batches_per_sec`` and ``throughput_mbs``. Returns ``None`` if the
        dataset could not be created or no batch was measured.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        ds = _build_dataset(data_path, image_size)
    except Exception as e:
        print(f"Error loading dataset from {data_path}: {e}")
        return None

    # Each transformed sample is a 3 x image_size x image_size float32 tensor
    # (4 bytes per element).
    image_size_mb = 3 * 4 * image_size * image_size / (1024 * 1024)
    print(f"\nConfiguration: batch_size={batch_size}")

    # One iterator serves both phases. A second iter_torch_batches() call would
    # start a fresh execution of the pipeline, and the measured phase would pay
    # the start-up cost that warmup is meant to absorb.
    batches = iter(ds.iter_torch_batches(
        batch_size=batch_size,
        prefetch_batches=prefetch_factor,
        drop_last=False,
        device=device,  # Ray places each batch on `device`; no extra .to() needed.
    ))

    print("Running warmup batches...")
    warmup_start = time.perf_counter()
    _time_batches(batches, warmup_batches, device, "Warmup")
    warmup_time = time.perf_counter() - warmup_start
    print(f"Warmup time: {warmup_time:.2f}s")

    print("Measuring performance...")
    records = _time_batches(batches, measure_batches, device, "Measuring")
    batch_count = len(records)
    total_time = sum(seconds for seconds, _ in records)
    # Count actual samples: with drop_last=False the final batch may be partial.
    total_samples_processed = sum(count for _, count in records)
    if batch_count == 0 or total_time <= 0:
        print("No batches were measured (dataset empty or exhausted during warmup); "
              "skipping this configuration.")
        return None

    avg_samples_per_sec = total_samples_processed / total_time
    avg_batch_time = total_time / batch_count
    throughput_mbs = avg_samples_per_sec * image_size_mb
    print(f"Time: {total_time:.2f}s, Throughput: {avg_samples_per_sec:.2f} samples/sec ({throughput_mbs:.2f} MB/s)")

    return {
        'batch_size': batch_size,
        'total_time': total_time,
        'avg_samples_per_sec': avg_samples_per_sec,
        'avg_batch_time': avg_batch_time * 1000,  # Convert to ms
        'samples_per_batch': batch_size,
        'batches_per_sec': 1.0 / avg_batch_time,
        'throughput_mbs': throughput_mbs,
    }
def main():
    """Benchmark every requested batch size and print a summary table.

    Returns early, after printing a message, if the data path does not exist or
    if no configuration produced results.
    """
    args = parse_args()

    if not os.path.exists(args.data_path):
        print(f"Error: Data path '{args.data_path}' does not exist.")
        return

    print(f"DataLoader Benchmark on: {args.data_path}")
    print(f"Device: {'GPU' if torch.cuda.is_available() else 'CPU'}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"PyTorch version: {torch.__version__}")

    # Run one benchmark per batch size; failed configurations return None.
    results = []
    for batch_size in args.batch_sizes:
        result = measure_dataloader_performance(
            args.data_path, batch_size, args.warmup_batches,
            args.measure_batches, args.image_size,
            args.prefetch_factor
        )
        if result:
            results.append(result)

    # Without this guard, results[0] below raises IndexError when every
    # configuration failed.
    if not results:
        print("\nNo benchmark results were collected.")
        return

    # Sort results by batch size (highest to lowest).
    results.sort(key=lambda r: r['batch_size'], reverse=True)

    headers = [
        'Batch Size', 'Total Time (s)',
        'Throughput (samples/s)', 'Throughput (MB/s)', 'Batch Time (ms)', 'Batches/sec'
    ]
    table_data = [
        [
            r['batch_size'], f"{r['total_time']:.2f}",
            f"{r['avg_samples_per_sec']:.2f}", f"{r['throughput_mbs']:.2f}",
            f"{r['avg_batch_time']:.2f}", f"{r['batches_per_sec']:.2f}"
        ] for r in results
    ]
    print("\nResults (sorted by batch size):")
    print(tabulate(table_data, headers=headers, tablefmt='grid'))

    # results is sorted descending, so the first entry has the largest batch size
    # (not necessarily the highest throughput).
    largest = results[0]
    print("\nLargest batch size configuration:")
    print(f"Batch size: {largest['batch_size']}")
    print(f"Throughput: {largest['avg_samples_per_sec']:.2f} samples/sec ({largest['throughput_mbs']:.2f} MB/s)")
    print(f"Average batch time: {largest['avg_batch_time']:.2f} ms")
    print(f"Total time: {largest['total_time']:.2f} s")
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.