@richardliaw
Created March 10, 2025 19:24
#!/usr/bin/env python3
"""
dataloader_benchmark.py - A script to benchmark PyTorch DataLoader performance on image datasets.
"""
import argparse
import os
import time
import torch
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tabulate import tabulate


def parse_args():
    parser = argparse.ArgumentParser(description='PyTorch DataLoader Benchmark')
    parser.add_argument('data_path', type=str, help='Path to the image folder')
    parser.add_argument('--batch-sizes', type=int, nargs='+', default=[32, 64, 128, 256],
                        help='Batch sizes to test (default: [32, 64, 128, 256])')
    parser.add_argument('--workers', type=int, nargs='+', default=[1, 2, 4, 8],
                        help='Number of workers to test (default: [1, 2, 4, 8])')
    parser.add_argument('--epochs', type=int, default=3,
                        help='Number of epochs for each configuration (default: 3)')
    parser.add_argument('--warmup-epochs', type=int, default=1,
                        help='Number of warmup epochs before measuring (default: 1)')
    # store_true with default=True made this flag a no-op; BooleanOptionalAction
    # keeps pin_memory enabled by default while allowing --no-pin-memory to disable it.
    parser.add_argument('--pin-memory', action=argparse.BooleanOptionalAction, default=True,
                        help='Use pin_memory in the DataLoader (default: enabled)')
    parser.add_argument('--image-size', type=int, default=224,
                        help='Size to resize images to (default: 224)')
    parser.add_argument('--use-gpu', action='store_true',
                        help='Transfer data to the GPU to simulate a training scenario')
    parser.add_argument('--prefetch-factor', type=int, default=2,
                        help='Number of batches loaded in advance by each worker (default: 2)')
    return parser.parse_args()


def get_transforms(image_size):
    """Returns standard ImageNet-style training transformations."""
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.RandomResizedCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
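# Optional: to isolate DataLoader overhead without a real image folder, a
# synthetic dataset could be substituted (sketch only; not used by this script):
#
#   dataset = datasets.FakeData(size=10000,
#                               image_size=(3, image_size, image_size),
#                               transform=get_transforms(image_size))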
def measure_dataloader_performance(data_path, batch_size, num_workers, num_epochs,
                                   warmup_epochs, pin_memory, image_size, use_gpu,
                                   prefetch_factor):
    """Measures DataLoader performance for a specific configuration."""
    # Set up device
    device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")

    # Create dataset
    try:
        dataset = datasets.ImageFolder(data_path, get_transforms(image_size))
    except Exception as e:
        print(f"Error loading dataset from {data_path}: {e}")
        return None

    # Create dataloader
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        prefetch_factor=prefetch_factor if num_workers > 0 else None,
        persistent_workers=num_workers > 0,
    )

    total_samples = len(dataset)
    num_batches = len(data_loader)
    print(f"\nConfiguration: batch_size={batch_size}, workers={num_workers}")
    print(f"Total samples: {total_samples}, Batches per epoch: {num_batches}")

    # Warmup
    print("Running warmup epochs...")
    for epoch in range(warmup_epochs):
        start_time = time.time()
        for i, (images, _) in enumerate(data_loader):
            if use_gpu and torch.cuda.is_available():
                images = images.to(device, non_blocking=True)
            if i % max(1, num_batches // 10) == 0:
                print(f"Warmup Epoch {epoch+1}/{warmup_epochs}, Batch {i+1}/{num_batches}")
        warmup_time = time.time() - start_time
        print(f"Warmup epoch {epoch+1} time: {warmup_time:.2f}s")

    # Measurement
    print("Measuring performance...")
    epoch_times = []
    batch_times = []
    for epoch in range(num_epochs):
        start_time = time.time()
        batch_start = time.time()
        for i, (images, _) in enumerate(data_loader):
            if use_gpu and torch.cuda.is_available():
                images = images.to(device, non_blocking=True)
                # Touch the data and synchronize so the asynchronous host-to-device
                # copy is actually finished before the batch time is recorded.
                _ = images.mean()
                torch.cuda.synchronize()
            batch_end = time.time()
            batch_times.append(batch_end - batch_start)
            if i % max(1, num_batches // 10) == 0:
                print(f"Epoch {epoch+1}/{num_epochs}, Batch {i+1}/{num_batches}")
            batch_start = time.time()
        epoch_time = time.time() - start_time
        epoch_times.append(epoch_time)

        # Calculate statistics for this epoch
        samples_per_sec = total_samples / epoch_time
        avg_batch_time = sum(batch_times[-num_batches:]) / min(len(batch_times), num_batches)
        print(f"Epoch {epoch+1} - Time: {epoch_time:.2f}s, "
              f"Throughput: {samples_per_sec:.2f} samples/sec, "
              f"Avg batch time: {avg_batch_time * 1000:.2f} ms")

    # Overall statistics
    avg_epoch_time = sum(epoch_times) / len(epoch_times)
    avg_samples_per_sec = total_samples / avg_epoch_time
    avg_batch_time = sum(batch_times) / len(batch_times)

    return {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'avg_epoch_time': avg_epoch_time,
        'avg_samples_per_sec': avg_samples_per_sec,
        'avg_batch_time': avg_batch_time * 1000,  # Convert to ms
        'samples_per_batch': batch_size,
        'batches_per_sec': 1.0 / avg_batch_time,
    }
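# Example: benchmark a single configuration programmatically (hypothetical path;
# assumes an ImageFolder-style directory):
#
#   stats = measure_dataloader_performance(
#       "/data/imagenet/train", batch_size=64, num_workers=4, num_epochs=1,
#       warmup_epochs=1, pin_memory=True, image_size=224, use_gpu=False,
#       prefetch_factor=2)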
def main():
    args = parse_args()

    # Check that the data path exists
    if not os.path.exists(args.data_path):
        print(f"Error: Data path '{args.data_path}' does not exist.")
        return

    print(f"DataLoader Benchmark on: {args.data_path}")
    print(f"Device: {'GPU' if args.use_gpu and torch.cuda.is_available() else 'CPU'}")
    if args.use_gpu and torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"PyTorch version: {torch.__version__}")

    results = []

    # Test every combination of batch size and worker count
    for batch_size in args.batch_sizes:
        for num_workers in args.workers:
            result = measure_dataloader_performance(
                args.data_path, batch_size, num_workers, args.epochs,
                args.warmup_epochs, args.pin_memory, args.image_size,
                args.use_gpu, args.prefetch_factor
            )
            if result:
                results.append(result)

    # Guard against the case where every configuration failed to load the dataset
    if not results:
        print("No configurations completed successfully; nothing to report.")
        return

    # Sort results by throughput (samples/sec)
    results.sort(key=lambda x: x['avg_samples_per_sec'], reverse=True)

    # Display results in a table
    headers = [
        'Batch Size', 'Workers', 'Epoch Time (s)',
        'Throughput (samples/s)', 'Batch Time (ms)', 'Batches/sec'
    ]
    table_data = [
        [
            r['batch_size'], r['num_workers'], f"{r['avg_epoch_time']:.2f}",
            f"{r['avg_samples_per_sec']:.2f}", f"{r['avg_batch_time']:.2f}",
            f"{r['batches_per_sec']:.2f}"
        ] for r in results
    ]
    print("\nResults (sorted by throughput):")
    print(tabulate(table_data, headers=headers, tablefmt='grid'))

    # Print the best configuration
    best = results[0]
    print("\nBest configuration:")
    print(f"Batch size: {best['batch_size']}")
    print(f"Workers: {best['num_workers']}")
    print(f"Throughput: {best['avg_samples_per_sec']:.2f} samples/sec")
    print(f"Average batch time: {best['avg_batch_time']:.2f} ms")
    print(f"Average epoch time: {best['avg_epoch_time']:.2f} s")


if __name__ == '__main__':
    main()