Created
October 22, 2021 12:52
-
-
Save YugeTen/57bc50cd786f4bd472df592570fca0cf to your computer and use it in GitHub Desktop.
cdsprites.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import io | |
import subprocess | |
import numpy as np | |
from scipy.sparse.linalg import svds | |
from scipy.io import loadmat | |
import torch | |
from torchvision import datasets, transforms | |
# Expected shape of one coloured sprite canvas: (channels, height, width).
CDspritesDataSize = torch.Size([3, 64, 64])
class CDspritesDataset(torch.utils.data.Dataset):
    """2D shapes dataset, WITH COLORS.

    Every dSprites image is replicated three times and paired with a randomly
    drawn colour index in {0, 1, 2}, stored as the last latent. ``__getitem__``
    paints the sprite into the channel selected by that index on an otherwise
    all-zero 3-channel canvas.

    More info here:
    https://github.com/deepmind/dsprites-dataset/blob/master/dsprites_reloading_example.ipynb
    """

    # All file locations are class-level and derived from this fixed root.
    data_root = '../data'
    filename = 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'
    npz_file = data_root + '/' + filename
    npzc_file = data_root + '/' + 'color_' + filename
    npzc_train_file = data_root + '/train_color_' + filename
    npzc_test_file = data_root + '/test_color_' + filename
    pca_filename = data_root + '/pca_color_' + filename

    def download_dataset(self, npz_file):
        """Download the raw dSprites archive and store it at ``npz_file``.

        Args:
            npz_file (string): Destination path of the downloaded archive.
        """
        import shutil
        from urllib import request

        url = 'https://github.com/deepmind/dsprites-dataset/blob/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz?raw=true'
        print('Downloading ' + url)
        target_dir = os.path.dirname(npz_file)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        # Stream into a temporary file first so an interrupted download never
        # leaves a truncated archive that later passes the ``isfile`` check.
        partial_file = npz_file + '.part'
        with request.urlopen(url) as response, open(partial_file, 'wb') as f:
            shutil.copyfileobj(response, f)
        os.replace(partial_file, npz_file)

    def compute_pca(self, path, X):
        """Save the squared 20 largest singular values of ``X`` under key 'pca'.

        Args:
            path (string): Output path handed to ``np.savez``.
            X (ndarray): Data array; everything after the first axis is
                flattened so each entry becomes one row.
        """
        print('Compute singular values...')
        _, s, _ = svds(X.reshape(X.shape[0], -1).astype(np.float32), k=20)
        s = s**2
        np.savez(path, pca=s)

    @classmethod
    def load_pca(cls, latent_dim):
        """Return the first ``latent_dim`` stored 'pca' values as a tensor.

        NOTE(review): the values are kept in the order ``svds`` returned them,
        which is not guaranteed to be descending -- confirm that taking the
        leading entries is the intended selection.
        """
        with np.load(cls.pca_filename, mmap_mode='r') as archive:
            pca = archive['pca']
        return torch.Tensor(pca[:latent_dim])

    def gen_and_split_dataset(self, npz_file, npzc_train_file, npzc_test_file, npzc_file,
                              train_fract):
        """Build the coloured dataset and write train / test / full archives.

        Args:
            npz_file (string): Path to the raw dSprites archive.
            npzc_train_file (string): Output path of the training split.
            npzc_test_file (string): Output path of the test split.
            npzc_file (string): Output path of the complete coloured dataset.
            train_fract (float): Fraction of samples assigned to the training
                split.
        """
        print('Computing dataset split')
        with np.load(npz_file, encoding='latin1', mmap_mode='r') as rdataset:
            base_latents = rdataset['latents_values'][:, 1:]  # drop colour
            base_images = rdataset['imgs']

        # Three copies of the latents, each with its own random colour column.
        latents = []
        for _ in range(3):
            colors = np.random.choice(3, (len(base_latents), 1))
            latents.append(np.concatenate([base_latents, colors], axis=-1))
        latents = np.concatenate(latents, axis=0)
        # The tiled images do not depend on the drawn colours: build them once.
        images = np.tile(base_images, (3, 1, 1))

        # ``np.int`` was removed from NumPy; the builtin truncates identically.
        split_idx = int(train_fract * len(latents))
        shuffled_range = np.random.permutation(len(latents))
        train_idx = shuffled_range[:split_idx]
        test_idx = shuffled_range[split_idx:]

        np.savez(npzc_train_file, images=images[train_idx], latents=latents[train_idx])
        np.savez(npzc_test_file, images=images[test_idx], latents=latents[test_idx])
        np.savez(npzc_file, images=images, latents=latents)

    def __init__(self, data_root, train=True, train_fract=0.8, split=True):
        """Load the coloured dSprites data, creating the files when needed.

        Args:
            data_root (string): Currently unused; every path is taken from the
                class attributes. NOTE(review): confirm whether this argument
                was meant to override the class-level ``data_root``.
            train (bool): With ``split=True``, load the training split when
                true and the test split otherwise.
            train_fract (float): Training fraction used if the split has to be
                generated.
            split (bool): Load a train/test split when true, otherwise the
                complete coloured dataset.
        """
        if not os.path.isfile(self.npz_file):
            self.download_dataset(self.npz_file)

        # Regenerate when ANY derived file is missing; the generator always
        # writes all three, and loading a missing one would fail below.
        derived_files = (self.npzc_train_file, self.npzc_test_file, self.npzc_file)
        if not all(os.path.isfile(path) for path in derived_files):
            self.gen_and_split_dataset(self.npz_file, self.npzc_train_file, self.npzc_test_file,
                                       self.npzc_file, train_fract)

        if split:
            selected_file = self.npzc_train_file if train else self.npzc_test_file
        else:
            selected_file = self.npzc_file
        # Read both arrays, then release the archive's file handle.
        with np.load(selected_file, mmap_mode='r') as dataset:
            self.latents, self.images = dataset['latents'], dataset['images']

    def __len__(self):
        """Return the number of samples."""
        return len(self.latents)

    def __getitem__(self, idx):
        """Return ``(canvas, latent)`` for sample ``idx``.

        ``canvas`` has three leading channels; the sprite is written into the
        channel indexed by the last latent and the other channels stay zero.
        """
        image = torch.Tensor(self.images[idx])
        latent = torch.Tensor(self.latents[idx])
        canvas = torch.zeros_like(image).unsqueeze(0).repeat(3, 1, 1)
        canvas[latent[-1].long(), ...] = image
        return (canvas, latent)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment