Created
October 22, 2021 12:52
-
-
Save YugeTen/575065b6fc9dbff8b2563806f92e97fd to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import io | |
| import subprocess | |
| import numpy as np | |
| from scipy.sparse.linalg import svds | |
| from scipy.io import loadmat | |
| import torch | |
| from torchvision import datasets, transforms | |
# Size constant for one colored dSprites sample: 3 x 64 x 64.
# Presumably (channels, height, width), matching the 3-plane canvas built in
# CDspritesDataset.__getitem__ -- TODO confirm against the code that uses it.
CDspritesDataSize = torch.Size([3, 64, 64])
class CDspritesDataset(torch.utils.data.Dataset):
    """2D shapes dataset (dSprites), WITH COLORS.

    Every dSprites image is replicated three times and each copy is assigned
    a random colour index in {0, 1, 2}, stored as the last latent column.
    ``__getitem__`` renders the image into the canvas plane selected by that
    index.

    More info here:
    https://github.com/deepmind/dsprites-dataset/blob/master/dsprites_reloading_example.ipynb
    """

    # All paths are class-level and derived from this fixed root.
    data_root = '../data'
    filename = 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'
    npz_file = data_root + '/' + filename                     # raw download
    npzc_file = data_root + '/' + 'color_' + filename         # full coloured set
    npzc_train_file = data_root + '/train_color_' + filename  # train split
    npzc_test_file = data_root + '/test_color_' + filename    # test split
    pca_filename = data_root + '/pca_color_' + filename       # see compute_pca

    def download_dataset(self, npz_file):
        """Download the original dSprites archive and store it at ``npz_file``.

        The payload is streamed into ``npz_file + '.part'`` and moved into
        place only once complete, so an interrupted download never leaves a
        truncated file that a later run would mistake for a valid one.
        """
        import shutil
        from urllib import request

        url = 'https://github.com/deepmind/dsprites-dataset/blob/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz?raw=true'
        print('Downloading ' + url)
        # The target directory may not exist on a fresh checkout.
        os.makedirs(os.path.dirname(npz_file) or '.', exist_ok=True)
        partial_file = npz_file + '.part'
        with request.urlopen(url) as response, open(partial_file, 'wb') as f:
            shutil.copyfileobj(response, f)
        os.replace(partial_file, npz_file)

    def compute_pca(self, path, X):
        """Save the squared top-20 singular values of ``X`` to ``path``.

        ``X`` is flattened to ``(X.shape[0], -1)`` and cast to float32 before
        the truncated SVD. The result is stored under the key ``'pca'``.
        """
        print('Compute singular values...')
        flat = X.reshape(X.shape[0], -1).astype(np.float32)
        _, s, _ = svds(flat, k=20)
        np.savez(path, pca=s ** 2)

    @classmethod
    def load_pca(cls, latent_dim):
        """Return the first ``latent_dim`` values saved by ``compute_pca``.

        Reads ``cls.pca_filename``; the file must already exist.
        """
        pca = np.load(cls.pca_filename, mmap_mode='r')['pca']
        return torch.Tensor(pca[:latent_dim])

    def gen_and_split_dataset(self, npz_file, npzc_train_file, npzc_test_file, npzc_file,
                              train_fract):
        """Build the coloured dataset from ``npz_file`` and write three archives.

        Args:
            npz_file: path of the raw dSprites ``.npz`` archive (read).
            npzc_train_file: output path of the training split.
            npzc_test_file: output path of the test split.
            npzc_file: output path of the full, unsplit coloured dataset.
            train_fract: fraction of samples assigned to the training split.

        Each output archive holds ``images`` and ``latents`` arrays.
        """
        print('Computing dataset split')
        rdataset = np.load(npz_file, encoding='latin1', mmap_mode='r')
        base_latents = rdataset['latents_values'][:, 1:]  # drop colour
        base_images = rdataset['imgs']

        # Three copies of the latents, each with its own random colour column.
        latents = []
        for _ in range(3):
            colors = np.random.choice(3, (len(base_latents), 1))
            latents.append(np.concatenate([base_latents, colors], axis=-1))
        latents = np.concatenate(latents, axis=0)
        # Images are repeated in the same order as the stacked latents.
        images = np.tile(base_images, (3, 1, 1))

        # Built-in int: the ``np.int`` alias was removed in NumPy 1.24.
        split_idx = int(train_fract * len(latents))
        shuffled_range = np.random.permutation(len(latents))
        train_idx = shuffled_range[:split_idx]
        test_idx = shuffled_range[split_idx:]

        np.savez(npzc_train_file, images=images[train_idx], latents=latents[train_idx])
        np.savez(npzc_test_file, images=images[test_idx], latents=latents[test_idx])
        np.savez(npzc_file, images=images, latents=latents)

    def __init__(self, data_root, train=True, train_fract=0.8, split=True):
        """Load the coloured dataset, downloading / generating it if needed.

        Args:
            data_root: accepted for interface compatibility. NOTE: it is not
                used -- every path comes from the class-level attributes.
            train: with ``split=True``, load the train (True) or test (False)
                split.
            train_fract: training fraction used when the split files have to
                be generated.
            split: if False, load the full unsplit coloured dataset.
        """
        if not os.path.isfile(self.npz_file):
            self.download_dataset(self.npz_file)

        derived_files = (self.npzc_train_file, self.npzc_test_file, self.npzc_file)
        # Regenerate unless *all* derived files exist; a partial set would
        # otherwise make the np.load below fail on the missing file.
        if not all(os.path.isfile(path) for path in derived_files):
            self.gen_and_split_dataset(self.npz_file, self.npzc_train_file,
                                       self.npzc_test_file, self.npzc_file, train_fract)

        if split:
            chosen_file = self.npzc_train_file if train else self.npzc_test_file
        else:
            chosen_file = self.npzc_file
        dataset = np.load(chosen_file, mmap_mode='r')
        self.latents, self.images = dataset['latents'], dataset['images']

    def __len__(self):
        """Return the number of samples (rows of ``self.latents``)."""
        return len(self.latents)

    def __getitem__(self, idx):
        """Return ``(canvas, latent)`` for sample ``idx``.

        ``canvas`` is three stacked zero planes shaped like the image, with
        the image written into the plane indexed by the last latent value.
        """
        image = torch.Tensor(self.images[idx])
        latent = torch.Tensor(self.latents[idx])
        canvas = torch.zeros_like(image).unsqueeze(0).repeat(3, 1, 1)
        canvas[latent[-1].long(), ...] = image
        return (canvas, latent)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment