#!/usr/bin/env python3

"""Computes the channel-wise means, standard deviations, and covariance matrix of a dataset of images."""

import argparse

import torch
from torch.utils import data


def _to_rgb(image):
    """Return ``image`` converted to RGB mode via its ``convert`` method.

    This is a module-level function rather than a lambda so that the
    transform pipeline remains picklable, which DataLoader worker processes
    require when they are started with the ``spawn`` or ``forkserver``
    multiprocessing start methods.
    """
    return image.convert('RGB')


def get_raw_moments(x):
    """Compute the first and second raw channel moments of a single image.

    Args:
        x: A tensor whose first dimension indexes channels. All remaining
            dimensions are flattened into one axis of samples.

    Returns:
        A tuple ``(mom_1, mom_2)`` of double-precision tensors, where
        ``mom_1[i]`` is the mean of channel ``i`` and ``mom_2[i, j]`` is the
        mean of the elementwise product of channels ``i`` and ``j``.
    """
    flat = x.flatten(1).double()
    num_samples = flat.shape[1]
    return flat.mean(1), flat @ flat.T / num_samples


def finalize_moments(n, mom_1_sum, mom_2_sum):
    """Turn summed per-image raw moments into dataset statistics.

    Args:
        n: The number of images that were accumulated.
        mom_1_sum: The sum over images of the per-image channel means.
        mom_2_sum: The sum over images of the per-image second raw moments.

    Returns:
        A tuple ``(mean, std, cov)`` holding the channel means, the channel
        standard deviations, and the channel covariance matrix.

    Raises:
        ValueError: If ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError('cannot compute statistics of zero images')
    mean = mom_1_sum / n
    cov = mom_2_sum / n - torch.outer(mean, mean)
    # E[x^2] - E[x]^2 can come out marginally negative through rounding when
    # a channel is (nearly) constant; clamp so the square root is never NaN.
    std = cov.diag().clamp(min=0).sqrt()
    return mean, std, cov


def main():
    """Parse the command line, scan the dataset once, and print the results."""
    # Imported here so that the numeric helpers above can be imported on
    # their own without torchvision or tqdm being installed.
    from torchvision import datasets, transforms as T
    from tqdm import tqdm

    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument('dataset', type=str,
                   help='the root of the dataset')
    p.add_argument('--batch-size', type=int, default=1000,
                   help='the number of images per batch')
    p.add_argument('--num-workers', type=int, default=16,
                   help='the number of worker processes')
    args = p.parse_args()

    # Each image is reduced to its raw moments inside the transform, so the
    # workers return small fixed-size tensors regardless of image size.
    tf = T.Compose([
        T.Lambda(_to_rgb),
        T.ToTensor(),
        T.Lambda(get_raw_moments),
    ])
    dataset = datasets.ImageFolder(args.dataset, transform=tf)
    dataloader = data.DataLoader(dataset, batch_size=args.batch_size,
                                 num_workers=args.num_workers)

    n = 0
    mom_1_accum = torch.zeros([3], dtype=torch.double)
    mom_2_accum = torch.zeros([3, 3], dtype=torch.double)

    print(f'Using {args.num_workers} worker processes...')
    with tqdm(total=len(dataset), unit='images') as pbar:
        for (mom_1, mom_2), _ in dataloader:
            batch_len = len(mom_1)
            n += batch_len
            mom_1_accum += mom_1.sum(0)
            mom_2_accum += mom_2.sum(0)
            pbar.update(batch_len)

    mean, std, cov = finalize_moments(n, mom_1_accum, mom_2_accum)

    print(f'Number of images: {n}')
    print(f'Means: {mean.tolist()}')
    print(f'Standard deviations: {std.tolist()}')
    print('Covariance matrix:')
    # tolist() on the matrix yields its rows (dim 0), one list per row.
    rows = [str(row) for row in cov.tolist()]
    print('[' + ',\n '.join(rows) + ']')


if __name__ == '__main__':
    main()