#!/usr/bin/env python3
"""
dataloader_benchmark.py - A script to benchmark PyTorch DataLoader performance on image datasets.
"""
import argparse
import os
import time

import torch
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tabulate import tabulate

def parse_args():
    parser = argparse.ArgumentParser(description='PyTorch DataLoader Benchmark')
    parser.add_argument('data_path', type=str, help='Path to the image folder')
    parser.add_argument('--batch-sizes', type=int, nargs='+', default=[32, 64, 128, 256],
                        help='Batch sizes to test (default: [32, 64, 128, 256])')
    parser.add_argument('--workers', type=int, nargs='+', default=[1, 2, 4, 8],
                        help='Number of workers to test (default: [1, 2, 4, 8])')
    parser.add_argument('--epochs', type=int, default=3,
                        help='Number of epochs for each configuration (default: 3)')
    parser.add_argument('--warmup-epochs', type=int, default=1,
                        help='Number of warmup epochs before measuring (default: 1)')
    # BooleanOptionalAction (Python 3.9+) also generates a --no-pin-memory flag,
    # so pinning can be switched off as well as on
    parser.add_argument('--pin-memory', action=argparse.BooleanOptionalAction, default=True,
                        help='Use pin_memory in DataLoader (default: enabled)')
    parser.add_argument('--image-size', type=int, default=224,
                        help='Size to resize images to (default: 224)')
    parser.add_argument('--use-gpu', action='store_true',
                        help='Transfer data to GPU to simulate training scenario')
    parser.add_argument('--prefetch-factor', type=int, default=2,
                        help='Number of batches loaded in advance by each worker')
    return parser.parse_args()

def get_transforms(image_size):
    """Returns standard image transformations."""
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        transforms.RandomResizedCrop(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
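# Note: the pipeline above uses standard ImageNet-style training augmentations
# (RandomResizedCrop + flip + normalize). To benchmark decode/IO cost without
# random augmentation, a deterministic variant could be substituted - a sketch,
# not part of the benchmark as written:
#
#   transforms.Compose([
#       transforms.Resize(image_size + 32),
#       transforms.CenterCrop(image_size),
#       transforms.ToTensor(),
#       normalize,
#   ])
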
def measure_dataloader_performance(data_path, batch_size, num_workers, num_epochs,
                                   warmup_epochs, pin_memory, image_size, use_gpu,
                                   prefetch_factor):
    """Measures dataloader performance with a specific configuration."""
    # Set up device
    device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")

    # Create dataset
    try:
        dataset = datasets.ImageFolder(data_path, get_transforms(image_size))
    except Exception as e:
        print(f"Error loading dataset from {data_path}: {e}")
        return None
    # Create dataloader; prefetch_factor and persistent_workers are only valid
    # when num_workers > 0, so fall back to defaults in the single-process case
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        prefetch_factor=prefetch_factor if num_workers > 0 else None,
        persistent_workers=num_workers > 0
    )
    total_samples = len(dataset)
    num_batches = len(data_loader)

    print(f"\nConfiguration: batch_size={batch_size}, workers={num_workers}")
    print(f"Total samples: {total_samples}, Batches per epoch: {num_batches}")
    # Warmup
    print("Running warmup epochs...")
    for epoch in range(warmup_epochs):
        start_time = time.time()
        for i, (images, _) in enumerate(data_loader):
            if use_gpu and torch.cuda.is_available():
                images = images.to(device, non_blocking=True)
            if i % max(1, num_batches // 10) == 0:
                print(f"Warmup Epoch {epoch+1}/{warmup_epochs}, Batch {i+1}/{num_batches}")
        warmup_time = time.time() - start_time
        print(f"Warmup epoch {epoch+1} time: {warmup_time:.2f}s")

    # Measurement
    print("Measuring performance...")
    epoch_times = []
    batch_times = []
    for epoch in range(num_epochs):
        start_time = time.time()
        batch_start = time.time()
        for i, (images, _) in enumerate(data_loader):
            if use_gpu and torch.cuda.is_available():
                images = images.to(device, non_blocking=True)
                # Small reduction plus .item() forces a sync so the batch timing
                # actually includes the host-to-device transfer
                _ = images.mean().item()
            batch_end = time.time()
            batch_times.append(batch_end - batch_start)
            if i % max(1, num_batches // 10) == 0:
                print(f"Epoch {epoch+1}/{num_epochs}, Batch {i+1}/{num_batches}")
            batch_start = time.time()
        epoch_time = time.time() - start_time
        epoch_times.append(epoch_time)

        # Calculate statistics for this epoch
        samples_per_sec = total_samples / epoch_time
        avg_batch_time = sum(batch_times[-num_batches:]) / min(len(batch_times), num_batches)
        print(f"Epoch {epoch+1} - Time: {epoch_time:.2f}s, "
              f"Throughput: {samples_per_sec:.2f} samples/sec, "
              f"Avg batch time: {avg_batch_time * 1000:.2f} ms")
    # Overall statistics
    avg_epoch_time = sum(epoch_times) / len(epoch_times)
    avg_samples_per_sec = total_samples / avg_epoch_time
    avg_batch_time = sum(batch_times) / len(batch_times)

    return {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'avg_epoch_time': avg_epoch_time,
        'avg_samples_per_sec': avg_samples_per_sec,
        'avg_batch_time': avg_batch_time * 1000,  # Convert to ms
        'samples_per_batch': batch_size,
        'batches_per_sec': 1.0 / avg_batch_time
    }

def main():
    args = parse_args()

    # Check if the data path exists
    if not os.path.exists(args.data_path):
        print(f"Error: Data path '{args.data_path}' does not exist.")
        return

    print(f"DataLoader Benchmark on: {args.data_path}")
    print(f"Device: {'GPU' if args.use_gpu and torch.cuda.is_available() else 'CPU'}")
    if args.use_gpu and torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"PyTorch version: {torch.__version__}")

    results = []
    # Test different configurations
    for batch_size in args.batch_sizes:
        for num_workers in args.workers:
            result = measure_dataloader_performance(
                args.data_path, batch_size, num_workers, args.epochs,
                args.warmup_epochs, args.pin_memory, args.image_size,
                args.use_gpu, args.prefetch_factor
            )
            if result:
                results.append(result)
    if not results:
        print("No configurations completed successfully.")
        return

    # Sort results by throughput (samples/sec)
    results.sort(key=lambda x: x['avg_samples_per_sec'], reverse=True)

    # Display results in a table
    headers = [
        'Batch Size', 'Workers', 'Epoch Time (s)',
        'Throughput (samples/s)', 'Batch Time (ms)', 'Batches/sec'
    ]
    table_data = [
        [
            r['batch_size'], r['num_workers'], f"{r['avg_epoch_time']:.2f}",
            f"{r['avg_samples_per_sec']:.2f}", f"{r['avg_batch_time']:.2f}",
            f"{r['batches_per_sec']:.2f}"
        ] for r in results
    ]
    print("\nResults (sorted by throughput):")
    print(tabulate(table_data, headers=headers, tablefmt='grid'))

    # Print best configuration
    best = results[0]
    print("\nBest configuration:")
    print(f"Batch size: {best['batch_size']}")
    print(f"Workers: {best['num_workers']}")
    print(f"Throughput: {best['avg_samples_per_sec']:.2f} samples/sec")
    print(f"Average batch time: {best['avg_batch_time']:.2f} ms")
    print(f"Average epoch time: {best['avg_epoch_time']:.2f} s")


if __name__ == '__main__':
    main()