Last active
October 9, 2018 20:00
-
-
Save versatran01/04216a35d44b76011275e5bf6478f880 to your computer and use it in GitHub Desktop.
convert gqn dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from collections import namedtuple | |
import tensorflow as tf | |
import torch | |
import gzip | |
import multiprocessing as mp | |
from functools import partial | |
def collect_files(path, ext=None, key=None):
    """Return the sorted entries of ``path`` as joined paths.

    Args:
        path: Directory listed with ``os.listdir``.
        ext: If given, only entries whose ``os.path.splitext`` extension
            equals this string (leading dot included, e.g. ``'.tfrecord'``)
            are kept.
        key: Optional sort key forwarded to ``sorted``.

    Returns:
        List of ``os.path.join(path, name)`` strings in sorted order.
    """
    # sorted() treats key=None as "no key function", so no branching needed.
    names = sorted(os.listdir(path), key=key)
    return [
        os.path.join(path, name)
        for name in names
        if ext is None or os.path.splitext(name)[-1] == ext
    ]
# Restrict TensorFlow's logger to messages of ERROR severity.
tf.logging.set_verbosity(tf.logging.ERROR)
# "-1" is not a valid device index, so libraries that honour this variable
# see no CUDA device (presumably to keep the conversion on the CPU).
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
# Per-dataset layout: edge length of the square frames and the number of
# frames stored in each serialized example.
DatasetInfo = namedtuple('DatasetInfo', 'image_size seq_length')

# Known datasets, keyed by name (the script uses the name as a directory).
all_datasets = {
    'jaco': DatasetInfo(image_size=64, seq_length=11),
    'mazes': DatasetInfo(image_size=84, seq_length=300),
    'rooms_free_camera_with_object_rotations': DatasetInfo(
        image_size=128, seq_length=10),
    'rooms_ring_camera': DatasetInfo(image_size=64, seq_length=10),
    'rooms_free_camera_no_object_rotations': DatasetInfo(
        image_size=64, seq_length=10),
    'shepard_metzler_5_parts': DatasetInfo(image_size=64, seq_length=15),
    'shepard_metzler_7_parts': DatasetInfo(image_size=64, seq_length=15),
}

# Number of float values describing one camera pose in the 'cameras' feature.
_pose_dim = 5
def convert_record(record, info, batch_size=None):
    """Convert one tfrecord file into a gzipped torch file beside it.

    The result is written to the record's own directory under the record's
    full file name (extension included) with ``.pt.gz`` appended.

    Args:
        record: Path of the tfrecord file; printed as a progress line.
        info: Dataset description forwarded to ``process_record``.
        batch_size: Forwarded to ``process_record``; ``None`` leaves the
            number of converted examples unbounded.
    """
    print(record)
    # os.path.split already yields the bare file name as its second item.
    directory, filename = os.path.split(record)
    scenes = process_record(record, info, batch_size)
    save_to_dist(scenes, os.path.join(directory, filename + '.pt.gz'))
def save_to_dist(scenes, path):
    """Serialize ``scenes`` with ``torch.save`` into a gzip file at ``path``.

    Args:
        scenes: Object to store; callers in this file pass a list of
            (images, poses) tuples.
        path: Destination file, opened in binary write mode via ``gzip``.
    """
    with gzip.open(path, mode='wb') as sink:
        torch.save(scenes, sink)
def process_record(record, info, batch_size=None):
    """Read serialized examples from a tfrecord file and convert each one.

    Args:
        record: Path of the tfrecord file to iterate.
        info: Dataset description passed on to ``convert_to_numpy``.
        batch_size: Stop once this many examples have been converted. With
            the default ``None`` the comparison never matches, so the whole
            file is read.

    Returns:
        List with one ``convert_to_numpy`` result per example read.
    """
    converted = []
    for index, serialized in enumerate(tf.python_io.tf_record_iterator(record)):
        if index == batch_size:
            break
        converted.append(convert_to_numpy(serialized, info))
    return converted
def process_images(example, seq_length, image_size):
    """Decode the JPEG strings in ``example['frames']`` into a uint8 tensor.

    Args:
        example: Mapping with a 'frames' entry holding encoded JPEG strings.
        seq_length: Number of frames per sequence.
        image_size: Height and width used when reshaping the decoded frames.

    Returns:
        Tensor reshaped to ``(-1, seq_length, image_size, image_size, 3)``.
    """
    # Flatten to a 1-D tensor of strings so map_fn decodes one JPEG per step.
    encoded = tf.reshape(tf.concat(example['frames'], axis=0), [-1])
    decoded = tf.map_fn(tf.image.decode_jpeg, encoded,
                        dtype=tf.uint8, back_prop=False)
    target_shape = (-1, seq_length, image_size, image_size, 3)
    return tf.reshape(decoded, target_shape)
def process_poses(example, seq_length):
    """Reshape ``example['cameras']`` into per-frame pose vectors.

    Args:
        example: Mapping with a 'cameras' entry holding the pose values.
        seq_length: Number of frames per sequence.

    Returns:
        Tensor reshaped to ``(-1, seq_length, _pose_dim)``.
    """
    return tf.reshape(example['cameras'], (-1, seq_length, _pose_dim))
def convert_to_numpy(raw_data, info):
    """Parse one serialized example into squeezed numpy images and poses.

    Args:
        raw_data: Serialized example as produced by the tfrecord iterator.
        info: Object exposing ``seq_length`` and ``image_size``.

    Returns:
        Tuple ``(images, poses)`` of numpy arrays with size-1 axes removed.
    """
    frames_per_scene = info.seq_length
    frame_edge = info.image_size
    # Each example stores one JPEG string and _pose_dim floats per frame.
    schema = {
        'frames': tf.FixedLenFeature(shape=frames_per_scene, dtype=tf.string),
        'cameras': tf.FixedLenFeature(shape=frames_per_scene * _pose_dim,
                                      dtype=tf.float32),
    }
    parsed = tf.parse_single_example(raw_data, schema)
    image_tensor = process_images(parsed, frames_per_scene, frame_edge)
    pose_tensor = process_poses(parsed, frames_per_scene)
    # NOTE(review): .numpy() presumably relies on eager execution, which is
    # only enabled in the script entry point; confirm for other callers.
    return image_tensor.numpy().squeeze(), pose_tensor.numpy().squeeze()
if __name__ == '__main__':
    # NOTE(review): eager mode is enabled here, in the parent process only.
    # The pool workers presumably inherit it via the fork start method;
    # confirm on platforms where multiprocessing spawns fresh interpreters.
    tf.enable_eager_execution()

    # ---- settings ----
    base_dir = os.path.expanduser('~/Workspace/dataset/gqn_dataset')
    dataset = 'shepard_metzler_5_parts'
    mode = 'train'
    n = 10  # only convert the first n tfrecord files (None converts all)
    batch_size = 32  # number of sequences stored in each converted file

    # With these settings the first 10 tfrecord files become 10 .pt.gz files.
    # Each holds 32 (batch_size) sequences, and each sequence is a tuple of
    # images (15, 64, 64, 3) and poses (15, 5).
    print(f'dataset: {dataset}')
    print(f'base_dir: {base_dir}')

    info = all_datasets[dataset]
    record_dir = os.path.join(base_dir, dataset, mode)
    records = collect_files(record_dir, '.tfrecord')
    if n is not None:
        records = records[:n]

    # The worker count is fixed at 4; mp.cpu_count() would use every core.
    worker = partial(convert_record, info=info, batch_size=batch_size)
    with mp.Pool(processes=4) as pool:
        pool.map(worker, records)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment