#!/usr/bin/env python3

"""Computes the channel-wise means, standard deviations, and covariance
matrix of a dataset of images."""

import argparse

import torch
from torch.utils import data
from torchvision import datasets, transforms as T
from tqdm import tqdm


def get_raw_moments(x):
    """Return the first and second raw moments of one image's channels.

    Every dimension of ``x`` after the first is flattened, so each row of
    the working matrix holds all values of one leading-dimension entry
    (one channel, for a channels-first image tensor). The computation is
    carried out in double precision.

    Returns:
        A ``(first, second)`` tuple where ``first[i]`` is the mean of row
        ``i`` and ``second[i, j]`` is the mean of the element-wise product
        of rows ``i`` and ``j``.
    """
    pixels = x.flatten(1).double()
    # Number of values per row, i.e. the number of pixels per channel.
    count = pixels[0].numel()
    first = pixels.mean(1)
    second = pixels @ pixels.T / count
    return first, second


def _to_rgb(image):
    """Return ``image`` converted to RGB mode via its ``convert`` method."""
    # Defined at module level (not as a lambda inside main) so that the
    # transform pipeline can be pickled when DataLoader worker processes
    # are started with the 'spawn' method.
    return image.convert('RGB')


def main():
    """Compute and print channel statistics for an image folder dataset.

    Parses the command line (dataset root and number of worker processes),
    reduces every image to its per-channel raw moments inside the data
    loading pipeline, accumulates those moments over the whole dataset, and
    prints the image count, channel means, channel standard deviations and
    the 3x3 channel covariance matrix.

    Each image contributes its own per-image moments with equal weight, so
    images of different sizes count equally in the final statistics.
    """
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument('dataset', type=str,
                   help='the root of the dataset')
    p.add_argument('--num-workers', type=int, default=16,
                   help='the number of worker processes')
    args = p.parse_args()

    # The moments are computed inside the transform, so each sample handed
    # to the loader is a small (mean, second-moment) pair instead of a full
    # image; this also lets images of different sizes share a batch.
    tf = T.Compose([
        T.Lambda(_to_rgb),
        T.ToTensor(),
        T.Lambda(get_raw_moments),
    ])
    dataset = datasets.ImageFolder(args.dataset, transform=tf)
    dataloader = data.DataLoader(dataset, batch_size=1000,
                                 num_workers=args.num_workers)

    n = 0
    mom_1_accum = torch.zeros([3], dtype=torch.double)
    mom_2_accum = torch.zeros([3, 3], dtype=torch.double)

    print(f'Using {args.num_workers} worker processes...')

    with tqdm(total=len(dataset), unit='images') as pbar:
        # The class labels produced by ImageFolder are not needed.
        for (mom_1, mom_2), _ in dataloader:
            n += len(mom_1)
            mom_1_accum += mom_1.sum(0)
            mom_2_accum += mom_2.sum(0)
            pbar.update(len(mom_1))

    mean = mom_1_accum / n
    # Covariance from raw moments: E[x x^T] - E[x] E[x]^T.
    cov = mom_2_accum / n - mean[:, None] @ mean[None, :]
    # Floating-point cancellation can leave a variance slightly below zero
    # (e.g. for a constant channel); clamp so the square root is not NaN.
    std = cov.diag().clamp(min=0).sqrt()

    print(f'Number of images: {n}')
    print(f'Means: {mean.tolist()}')
    print(f'Standard deviations: {std.tolist()}')
    print('Covariance matrix:')
    # Unbind along dim 0 so each printed line is an actual row of cov.
    rows = [str(row.tolist()) for row in cov.unbind(0)]
    print(f'[{rows[0]},\n {rows[1]},\n {rows[2]}]')


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()