@crypdick · Created January 28, 2025
Code used to compute the global mean and standard deviation of every torchvision dataset. The output is available on my blog post.
import csv
import inspect
import os

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import v2

dataset_names = [
    "Caltech101", "Caltech256", "CelebA", "CIFAR10", "CIFAR100", "Country211", "DTD",
    "EMNIST", "EuroSAT", "FakeData", "FashionMNIST", "FER2013",
    "Flickr8k", "Flickr30k", "Flowers102", "Food101", "GTSRB", "INaturalist", "ImageNet",
    "Imagenette", "KMNIST", "LFWPeople", "LSUN", "MNIST", "Omniglot", "OxfordIIITPet",
    "Places365", "PCAM", "QMNIST", "RenderedSST2", "SEMEION", "SBU", "StanfordCars",
    "STL10", "SUN397", "SVHN", "USPS", "CocoDetection", "Cityscapes", "Kitti",
    "SBDataset", "VOCSegmentation", "VOCDetection", "WIDERFace"
]

def compute_global_mean(loader):
    """Compute the mean over every pixel value in the dataset."""
    global_sum = 0
    global_count = 0
    for batch in loader:
        x, _y = batch
        global_sum += x.flatten().sum()
        global_count += len(x.flatten())
    global_mean = global_sum / global_count
    return global_mean

def compute_global_stddev(loader, global_mean):
    """Compute the population standard deviation given the global mean."""
    residual_sum = 0
    count = 0
    for batch in loader:
        x, _y = batch
        # Subtract mean from each element, then square the differences
        residual_sum += ((x.flatten() - global_mean) ** 2).sum()
        count += len(x.flatten())
    # Calculate variance, then take square root for stddev
    global_stddev = torch.sqrt(residual_sum / count)
    return global_stddev

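# Illustrative sanity check (not part of the original gist): on a small
# in-memory batch, the two-pass computation above should agree with
# torch.std using the population convention (unbiased=False).
_x = torch.rand(4, 3, 8, 8)
_fake_loader = [(_x, None)]
assert torch.isclose(
    compute_global_stddev(_fake_loader, compute_global_mean(_fake_loader)),
    _x.std(unbiased=False),
)
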
def load_existing_results(csv_path):
    """Load previously computed stats so finished datasets can be skipped."""
    results = {}
    if os.path.exists(csv_path):
        with open(csv_path, 'r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                results[row['dataset']] = {
                    'mean': float(row['mean']),
                    'stddev': float(row['stddev'])
                }
    return results

def save_result(csv_path, dataset_name, mean, stddev):
    """Append one dataset's stats to the CSV, writing the header on first use."""
    file_exists = os.path.exists(csv_path)
    with open(csv_path, 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=['dataset', 'mean', 'stddev'])
        if not file_exists:
            writer.writeheader()
        writer.writerow({
            'dataset': dataset_name,
            'mean': f"{mean:.6f}",
            'stddev': f"{stddev:.6f}"
        })

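# The resulting CSV is a header plus one row per dataset. The numbers below
# are the widely quoted MNIST stats, shown only to illustrate the format,
# not output from this run:
#
#   dataset,mean,stddev
#   MNIST,0.1307,0.3081
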
# Initialize results file path
csv_path = "dataset_stats.csv"
existing_results = load_existing_results(csv_path)

for name in dataset_names:
    if name in existing_results:
        print(f"Skipping {name} - already processed. "
              f"mean: {existing_results[name]['mean']:.4f}, "
              f"stddev: {existing_results[name]['stddev']:.4f}")
        continue

    try:
        dataset_constructor = getattr(datasets, name)
        print(f"Computing stats for {name}...")

        # note: some datasets use `train=True` and others use `split="train"`
        # Get constructor signature to check available kwargs
        params = inspect.signature(dataset_constructor).parameters

        transform = v2.Compose([
            v2.ToImage(),
            # ensure all images in a minibatch are the same size so that they can
            # be stacked; this has a very small effect on the stats but it's a
            # reasonable approximation
            v2.Resize((224, 224), antialias=True),
            v2.ToDtype(torch.float32, scale=True)
        ])

        if 'train' in params:
            training_data = dataset_constructor(
                root="~/tmp/data",
                train=True,
                download=True,
                transform=transform,
            )
        else:
            training_data = dataset_constructor(
                root="~/tmp/data",
                split="train",
                download=True,
                transform=transform,
            )

        loader = DataLoader(training_data, batch_size=500)
        mean = compute_global_mean(loader)
        stddev = compute_global_stddev(loader, mean)

        # Save result immediately after computation
        save_result(csv_path, name, mean, stddev)
        print(f"{name} - mean: {mean:.4f}, stddev: {stddev:.4f}")
    except Exception as e:
        print(f"Failed to process {name}: {e}")

# after all datasets are processed, print the results
existing_results = load_existing_results(csv_path)
print(existing_results)
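
# Example downstream use (a sketch, not part of the original pipeline): feed a
# computed mean/stddev pair into a normalization transform. MNIST is
# single-channel, so one value per list suffices; for RGB datasets the single
# global value would be repeated three times.
if "MNIST" in existing_results:
    stats = existing_results["MNIST"]
    normalize = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[stats["mean"]], std=[stats["stddev"]]),
    ])
    print(f"MNIST normalize transform: {normalize}")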