Created January 17, 2022 19:29
-
-
Save crowsonkb/f3c76d1605260c55d569e1f247565a06 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
#!/usr/bin/env python3

"""Computes the channel-wise means, standard deviations, and covariance
matrix of a dataset of images."""
import argparse
import operator

import torch
from torch.utils import data
from torchvision import datasets, transforms as T
from tqdm import tqdm
def get_raw_moments(x):
    """Return the per-channel first and second raw moments of an image.

    Args:
        x: a tensor whose first dimension indexes channels -- presumably a
            (C, H, W) image produced by ``T.ToTensor()``. All remaining
            dimensions are treated as pixel positions.

    Returns:
        A ``(first, second)`` pair of float64 tensors. ``first`` has shape
        ``(C,)`` and holds each channel's mean over pixels. ``second`` has
        shape ``(C, C)`` and holds the mean over pixels of x_i * x_j for
        every channel pair (i, j).
    """
    pixels = x.double().flatten(1)
    num_pixels = pixels.shape[1]
    first = pixels.mean(dim=1)
    second = torch.mm(pixels, pixels.t()) / num_pixels
    return first, second
def main():
    """Compute and print RGB channel statistics for an ImageFolder dataset.

    Each image is converted to RGB and reduced to its per-channel raw
    moments by ``get_raw_moments`` inside the DataLoader workers. The main
    process therefore only sums small ``(3,)`` and ``(3, 3)`` tensors.
    Every image is weighted equally, regardless of its resolution.

    The reported statistics are population (biased) estimates. They divide
    by the pixel and image counts, with no -1 correction.
    """
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument('dataset', type=str,
                   help='the root of the dataset')
    p.add_argument('--batch-size', type=int, default=1000,
                   help='the number of images per batch')
    p.add_argument('--num-workers', type=int, default=16,
                   help='the number of worker processes')
    args = p.parse_args()

    tf = T.Compose([
        # operator.methodcaller is picklable, unlike a lambda defined inside
        # main(). DataLoader workers started with the 'spawn' method (the
        # default on macOS and Windows) must pickle the dataset and its
        # transforms.
        T.Lambda(operator.methodcaller('convert', 'RGB')),
        T.ToTensor(),
        T.Lambda(get_raw_moments),
    ])
    dataset = datasets.ImageFolder(args.dataset, transform=tf)
    dataloader = data.DataLoader(dataset, batch_size=args.batch_size,
                                 num_workers=args.num_workers)

    n = 0
    mom_1_accum = torch.zeros([3], dtype=torch.double)
    mom_2_accum = torch.zeros([3, 3], dtype=torch.double)

    print(f'Using {args.num_workers} worker processes...')
    with tqdm(total=len(dataset), unit='images') as pbar:
        # Each batch unpacks as ((stacked mom_1, stacked mom_2), labels).
        # The class labels are not needed.
        for (mom_1, mom_2), _ in dataloader:
            n += len(mom_1)
            mom_1_accum += mom_1.sum(0)
            mom_2_accum += mom_2.sum(0)
            pbar.update(len(mom_1))

    # Cov[x] = E[x x^T] - E[x] E[x]^T. Both expectations are the average
    # of the per-image moments.
    mean = mom_1_accum / n
    cov = mom_2_accum / n - mean[:, None] @ mean[None, :]

    print(f'Number of images: {n}')
    print(f'Means: {mean.tolist()}')
    # Clamp tiny negative variances caused by floating-point cancellation
    # (e.g. a constant channel) so that sqrt() does not return NaN.
    print(f'Standard deviations: {cov.diag().clamp(min=0).sqrt().tolist()}')
    print('Covariance matrix:')
    rows = [str(row.tolist()) for row in cov]
    print('[' + ',\n '.join(rows) + ']')
if __name__ == '__main__':
    # Run main() only when this file is executed as a script, not when it
    # is imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.