Created
January 28, 2025 20:22
-
-
Save crypdick/99cce761702c143bcf4439197ca2b5af to your computer and use it in GitHub Desktop.
Code used to compute the mean and standard deviation of all PyTorch (torchvision) datasets. Output available on my blog post.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect | |
import csv | |
import os | |
import torch | |
from torch.utils.data import DataLoader | |
from torchvision import datasets | |
from torchvision.transforms import v2 | |
# Names of torchvision dataset classes (attributes of torchvision.datasets)
# to process. Some of these require manual downloads or credentials and will
# simply be skipped by the try/except in the main loop below.
dataset_names = [
    "Caltech101", "Caltech256", "CelebA", "CIFAR10", "CIFAR100", "Country211", "DTD",
    "EMNIST", "EuroSAT", "FakeData", "FashionMNIST", "FER2013",
    "Flickr8k", "Flickr30k", "Flowers102", "Food101", "GTSRB", "INaturalist", "ImageNet",
    "Imagenette", "KMNIST", "LFWPeople", "LSUN", "MNIST", "Omniglot", "OxfordIIITPet",
    "Places365", "PCAM", "QMNIST", "RenderedSST2", "SEMEION", "SBU", "StanfordCars",
    "STL10", "SUN397", "SVHN", "USPS", "CocoDetection", "Cityscapes", "Kitti",
    "SBDataset", "VOCSegmentation", "VOCDetection", "WIDERFace"
]
def compute_global_mean(loader):
    """Compute the mean over every element of every image yielded by ``loader``.

    Accumulates in float64: summing billions of float32 pixel values in
    float32 loses precision (and can saturate), which visibly skews the
    resulting statistics for large datasets.

    Args:
        loader: iterable yielding ``(images, labels)`` batches of tensors.

    Returns:
        Scalar float64 tensor holding the global mean of all elements.

    Raises:
        ValueError: if the loader yields no elements.
    """
    global_sum = torch.zeros((), dtype=torch.float64)
    global_count = 0
    for x, _y in loader:
        # Cast before summing so the reduction itself runs in double precision.
        global_sum += x.double().sum()
        # numel() counts elements without materializing a flattened copy.
        global_count += x.numel()
    if global_count == 0:
        raise ValueError("loader produced no elements; cannot compute a mean")
    return global_sum / global_count
def compute_global_stddev(loader, global_mean):
    """Compute the population standard deviation of all elements in ``loader``.

    Accumulates squared residuals in float64 for the same precision reason
    as ``compute_global_mean``: float32 accumulation over billions of
    elements degrades the result.

    Args:
        loader: iterable yielding ``(images, labels)`` batches of tensors.
        global_mean: scalar (float or 0-d tensor) mean of all elements,
            typically the output of ``compute_global_mean``.

    Returns:
        Scalar float64 tensor holding the population stddev (divides by N,
        not N-1).

    Raises:
        ValueError: if the loader yields no elements.
    """
    residual_sum = torch.zeros((), dtype=torch.float64)
    count = 0
    for x, _y in loader:
        # Subtract the mean from each element, square, and reduce in float64.
        residual_sum += ((x.double() - global_mean) ** 2).sum()
        count += x.numel()
    if count == 0:
        raise ValueError("loader produced no elements; cannot compute a stddev")
    # Variance -> stddev.
    return torch.sqrt(residual_sum / count)
def load_existing_results(csv_path):
    """Load previously computed per-dataset stats from ``csv_path``.

    Args:
        csv_path: path to the results CSV with columns
            ``dataset``, ``mean``, ``stddev``.

    Returns:
        Dict mapping dataset name to ``{'mean': float, 'stddev': float}``;
        empty dict if the file does not exist.
    """
    results = {}
    if os.path.exists(csv_path):
        # newline='' is required by the csv module so it can handle embedded
        # newlines/quoting itself (csv docs mandate this for reader files).
        with open(csv_path, newline='') as f:
            for row in csv.DictReader(f):
                results[row['dataset']] = {
                    'mean': float(row['mean']),
                    'stddev': float(row['stddev']),
                }
    return results
def save_result(csv_path, dataset_name, mean, stddev):
    """Append one dataset's stats to ``csv_path``, writing a header if needed.

    Args:
        csv_path: path to the results CSV (created on first call).
        dataset_name: torchvision dataset class name.
        mean: global mean (float or 0-d tensor; formatted to 6 decimals).
        stddev: global stddev (float or 0-d tensor; formatted to 6 decimals).
    """
    # Write a header when the file is missing *or* empty: a zero-byte file
    # (e.g. left by an interrupted run) would otherwise get data rows with no
    # header, which load_existing_results could not parse.
    needs_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0
    with open(csv_path, 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=['dataset', 'mean', 'stddev'])
        if needs_header:
            writer.writeheader()
        writer.writerow({
            'dataset': dataset_name,
            'mean': f"{mean:.6f}",
            'stddev': f"{stddev:.6f}",
        })
# Initialize results file path
csv_path = "dataset_stats.csv"
# Resume support: results already in the CSV are not recomputed.
existing_results = load_existing_results(csv_path)
for name in dataset_names:
    if name in existing_results:
        print(f"Skipping {name} - already processed. "
              f"mean: {existing_results[name]['mean']:.4f}, "
              f"stddev: {existing_results[name]['stddev']:.4f}")
        continue
    try:
        # Look up the dataset class by name on torchvision.datasets.
        dataset_constructor = getattr(datasets, name)
        print(f"Computing stats for {name}...")
        # note: some datasets use `train=True` and others use `split="train"`
        # Get constructor signature to check available kwargs
        params = inspect.signature(dataset_constructor).parameters
        transform = v2.Compose([
            v2.ToImage(),
            # ensure all images in a minibatch are the same size so that they can be stacked
            # this has a very small effect on the stats but it's a reasonable approximation
            v2.Resize((224, 224), antialias=True),
            # scale=True maps uint8 [0, 255] pixels to float32 [0, 1].
            v2.ToDtype(torch.float32, scale=True)
        ])
        # Choose the constructor's train-selection kwarg based on its signature.
        if 'train' in params:
            training_data = dataset_constructor(
                root="~/tmp/data",
                train=True,
                download=True,
                transform=transform,
            )
        else:
            training_data = dataset_constructor(
                root="~/tmp/data",
                split="train",
                download=True,
                transform=transform,
            )
        loader = DataLoader(training_data, batch_size=500)
        # Two passes over the data: one for the mean, one for the stddev.
        mean = compute_global_mean(loader)
        stddev = compute_global_stddev(loader, mean)
        # Save result immediately after computation
        save_result(csv_path, name, mean, stddev)
        print(f"{name} - mean: {mean:.4f}, stddev: {stddev:.4f}")
    except Exception as e:
        # Broad catch is deliberate: many datasets fail (manual download
        # required, missing credentials, unsupported kwargs) and the run
        # should continue with the remaining datasets.
        print(f"Failed to process {name}: {str(e)}")
# after all datasets are processed, print the results
existing_results = load_existing_results(csv_path)
print(existing_results)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment