Skip to content

Instantly share code, notes, and snippets.

@crowsonkb
Created January 17, 2022 19:29
Show Gist options
  • Save crowsonkb/f3c76d1605260c55d569e1f247565a06 to your computer and use it in GitHub Desktop.
Save crowsonkb/f3c76d1605260c55d569e1f247565a06 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
"""Computes the channel-wise means, standard deviations, and covariance
matrix of a dataset of images."""
import argparse
import torch
from torch.utils import data
from torchvision import datasets, transforms as T
from tqdm import tqdm
def get_raw_moments(x):
    """Return the first and second raw channel moments of one image tensor.

    ``x`` is flattened from dimension 1 onward, so dimension 0 is treated as
    the channel axis and all remaining elements of a channel as its samples.
    All arithmetic is done in float64.

    Returns:
        A pair ``(first, second)``:
        ``first[c]`` is the mean of channel ``c`` over its samples, and
        ``second[c, d]`` is the mean over samples of the product of channels
        ``c`` and ``d`` (i.e. ``E[x x^T]``, not yet centred).
    """
    pixels = torch.flatten(x, start_dim=1).to(torch.float64)
    # Number of samples per channel (length of one flattened channel row).
    per_channel = pixels[0].numel()
    first = pixels.mean(dim=1)
    second = torch.matmul(pixels, pixels.t()) / per_channel
    return first, second
def main():
    """Command-line entry point: print channel statistics of an image dataset.

    Loads every image under the ``dataset`` root with
    ``datasets.ImageFolder``, converts it to RGB, and reduces it to its
    per-image first and second raw channel moments (``get_raw_moments``)
    inside the DataLoader workers. The moments are accumulated in float64.

    Each image contributes equally to the result regardless of its size,
    because the per-image moments are already averaged over that image's
    pixels before they are summed here.

    Prints the number of images, the per-channel means, the per-channel
    standard deviations, and the 3x3 channel covariance matrix.

    Raises:
        SystemExit: if the dataset yields no images.
    """
    import operator

    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument('dataset', type=str,
                   help='the root of the dataset')
    p.add_argument('--num-workers', type=int, default=16,
                   help='the number of worker processes')
    p.add_argument('--batch-size', type=int, default=1000,
                   help='the number of images per DataLoader batch')
    args = p.parse_args()

    tf = T.Compose([
        # methodcaller, unlike a lambda, is picklable. This lets the
        # transform be shipped to DataLoader workers under the 'spawn'
        # start method (the default on Windows and macOS).
        T.Lambda(operator.methodcaller('convert', 'RGB')),
        T.ToTensor(),
        T.Lambda(get_raw_moments),
    ])
    dataset = datasets.ImageFolder(args.dataset, transform=tf)
    dataloader = data.DataLoader(dataset, batch_size=args.batch_size,
                                 num_workers=args.num_workers)

    n = 0
    mom_1_accum = torch.zeros([3], dtype=torch.double)
    mom_2_accum = torch.zeros([3, 3], dtype=torch.double)
    print(f'Using {args.num_workers} worker processes...')
    with tqdm(total=len(dataset), unit='images') as pbar:
        # The class labels produced by ImageFolder are not needed.
        for (mom_1, mom_2), _ in dataloader:
            n += len(mom_1)
            mom_1_accum += mom_1.sum(0)
            mom_2_accum += mom_2.sum(0)
            pbar.update(len(mom_1))

    # Without this guard the divisions below would silently yield NaNs.
    if n == 0:
        raise SystemExit(f'No images found under {args.dataset!r}')

    mean = mom_1_accum / n
    # Cov = E[x x^T] - E[x] E[x]^T
    cov = mom_2_accum / n - mean[:, None] @ mean[None, :]
    # Cancellation in the subtraction above can make a (near-)zero variance
    # come out slightly negative; clamp so sqrt() cannot produce NaN.
    std = cov.diag().clamp(min=0).sqrt()

    print(f'Number of images: {n}')
    print(f'Means: {mean.tolist()}')
    print(f'Standard deviations: {std.tolist()}')
    print('Covariance matrix:')
    # Iterate actual rows (dim 0), one per output line.
    rows = [str(row) for row in cov.tolist()]
    print('[' + ',\n '.join(rows) + ']')
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not on import. DataLoader
    # worker processes started with the 'spawn' method re-import this module,
    # and must not re-run main().
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment