Skip to content

Instantly share code, notes, and snippets.

@YugeTen
Created October 22, 2021 12:52
Show Gist options
  • Save YugeTen/57bc50cd786f4bd472df592570fca0cf to your computer and use it in GitHub Desktop.
Save YugeTen/57bc50cd786f4bd472df592570fca0cf to your computer and use it in GitHub Desktop.
cdsprites.py
import os
import io
import subprocess
import numpy as np
from scipy.sparse.linalg import svds
from scipy.io import loadmat
import torch
from torchvision import datasets, transforms
# Shape of one sample image as (channels, height, width). The 3 channels match
# the 3-slot colour canvas built in CDspritesDataset.__getitem__; 64x64 is
# presumably the sprite resolution named in the dataset filename (TODO confirm).
CDspritesDataSize = torch.Size([3, 64, 64])
class CDspritesDataset(torch.utils.data.Dataset):
    """2D shapes dataset, WITH COLORS.

    The raw dSprites archive is replicated three times; every replicated
    sample is assigned a random colour index in {0, 1, 2} that is appended
    as the last latent. ``__getitem__`` renders a sample by writing its
    image into the colour-selected channel of an otherwise zero 3-channel
    canvas.

    More info here:
    https://github.com/deepmind/dsprites-dataset/blob/master/dsprites_reloading_example.ipynb
    """

    # All on-disk locations are class-level and derived from ``data_root``;
    # ``load_pca`` (a classmethod) relies on them being available without an
    # instance.
    data_root = '../data'
    filename = 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'
    npz_file = data_root + '/' + filename                    # raw download
    npzc_file = data_root + '/' + 'color_' + filename        # full coloured set
    npzc_train_file = data_root + '/train_color_' + filename  # train split
    npzc_test_file = data_root + '/test_color_' + filename    # test split
    pca_filename = data_root + '/pca_color_' + filename      # compute_pca output

    def download_dataset(self, npz_file):
        """Download the raw dSprites archive and store it at ``npz_file``.

        Args:
            npz_file: Destination path of the downloaded ``.npz`` archive.
                Its parent directory is created if it does not exist.

        Raises:
            urllib.error.URLError: If the download fails. No file is
                written in that case.
        """
        from urllib import request

        url = 'https://github.com/deepmind/dsprites-dataset/blob/master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz?raw=true'
        print('Downloading ' + url)

        # Fetch the whole payload before touching the filesystem so that a
        # failed download cannot leave an empty file behind (which would
        # make later runs skip the download).
        with request.urlopen(url) as response:
            payload = response.read()

        target_dir = os.path.dirname(npz_file)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        with open(npz_file, 'wb') as f:
            f.write(payload)

    def compute_pca(self, path, X):
        """Compute 20 squared singular values of ``X`` and save them.

        Args:
            path: Destination passed to ``np.savez``; the values are stored
                under the key ``'pca'``.
            X: Array whose first axis indexes samples. Every sample is
                flattened and the matrix is cast to float32 before the
                decomposition.
        """
        print('Compute singular values...')
        flat = X.reshape(X.shape[0], -1).astype(np.float32)
        _, s, _ = svds(flat, k=20)
        # NOTE(review): the values are stored in whatever order ``svds``
        # returns them, and ``load_pca`` slices the first ``latent_dim``
        # entries of that order -- confirm this is the intended selection.
        np.savez(path, pca=s ** 2)

    @classmethod
    def load_pca(cls, latent_dim):
        """Load the first ``latent_dim`` values written by ``compute_pca``.

        Args:
            latent_dim: Number of leading entries to return.

        Returns:
            A 1-D ``torch.Tensor`` with at most ``latent_dim`` entries.
        """
        # ``np.load`` on an .npz returns an archive handle; close it once
        # the array has been read.
        with np.load(cls.pca_filename, mmap_mode='r') as archive:
            pca = archive['pca']
        return torch.Tensor(pca[:latent_dim])

    def gen_and_split_dataset(self, npz_file, npzc_train_file, npzc_test_file, npzc_file,
                              train_fract):
        """Build the coloured dataset from the raw archive and split it.

        The raw samples are replicated three times, each replica receiving
        an independently drawn random colour index in {0, 1, 2} appended as
        the last latent column. The result is shuffled and written as a
        train split, a test split and the full (unshuffled) set.

        Args:
            npz_file: Path of the raw dSprites archive to read.
            npzc_train_file: Output path of the training split.
            npzc_test_file: Output path of the test split.
            npzc_file: Output path of the full coloured dataset.
            train_fract: Fraction of the samples assigned to the training
                split; the remainder forms the test split.
        """
        print('Computing dataset split')
        with np.load(npz_file, encoding='latin1', mmap_mode='r') as rdataset:
            base_latents = rdataset['latents_values'][:, 1:]  # drop colour
            base_images = rdataset['imgs']

        n_copies = 3
        n_base = len(base_latents)
        # One independent colour draw per replica, appended as last column.
        latents = np.concatenate(
            [
                np.concatenate(
                    [base_latents, np.random.choice(3, (n_base, 1))], axis=-1
                )
                for _ in range(n_copies)
            ],
            axis=0,
        )
        images = np.tile(base_images, (n_copies, 1, 1))

        # Builtin ``int``: the ``np.int`` alias was removed from NumPy.
        split_idx = int(train_fract * len(latents))
        shuffled_range = np.random.permutation(len(latents))
        train_idx = shuffled_range[:split_idx]
        test_idx = shuffled_range[split_idx:]

        np.savez(npzc_train_file, images=images[train_idx], latents=latents[train_idx])
        np.savez(npzc_test_file, images=images[test_idx], latents=latents[test_idx])
        np.savez(npzc_file, images=images, latents=latents)

    def __init__(self, data_root, train=True, train_fract=0.8, split=True):
        """Load the coloured dSprites data, creating it first if needed.

        Args:
            data_root: Accepted for interface compatibility but not used;
                every path comes from the class-level attributes.
            train: With ``split=True``, load the training split when true
                and the test split otherwise.
            train_fract: Training fraction used if the split files have to
                be generated.
            split: If true load a train/test split, otherwise load the full
                coloured dataset.
        """
        if not os.path.isfile(self.npz_file):
            self.download_dataset(self.npz_file)

        # Generation only runs when none of the three derived files exist;
        # a partially populated data directory is left untouched.
        derived_files = (self.npzc_train_file, self.npzc_test_file, self.npzc_file)
        if not any(os.path.isfile(path) for path in derived_files):
            self.gen_and_split_dataset(self.npz_file, self.npzc_train_file,
                                       self.npzc_test_file, self.npzc_file,
                                       train_fract)

        if split:
            source = self.npzc_train_file if train else self.npzc_test_file
        else:
            source = self.npzc_file
        with np.load(source, mmap_mode='r') as dataset:
            self.latents, self.images = dataset['latents'], dataset['images']

    def __len__(self):
        """Return the number of samples (rows of the latent array)."""
        return len(self.latents)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a ``(canvas, latent)`` pair.

        Args:
            idx: Index into the loaded images and latents.

        Returns:
            A tuple ``(canvas, latent)``. ``latent`` is the latent vector
            as a tensor. ``canvas`` stacks three zero planes shaped like
            the image, with the image written into the plane selected by
            the last latent entry (the colour index).
        """
        image = torch.Tensor(self.images[idx])
        latent = torch.Tensor(self.latents[idx])
        canvas = torch.zeros_like(image).unsqueeze(0).repeat(3, 1, 1)
        colour = latent[-1].long()
        canvas[colour, ...] = image
        return (canvas, latent)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment