Skip to content

Instantly share code, notes, and snippets.

@geffy
Created May 27, 2017 16:49
Show Gist options
  • Save geffy/091634cb3475f65a77e74dc7f64d08b0 to your computer and use it in GitHub Desktop.
Save geffy/091634cb3475f65a77e74dc7f64d08b0 to your computer and use it in GitHub Desktop.
import numpy as np
import tensorflow as tf
import os
import glob
def _pack_labels(labels):
    """Pack per-record label arrays into a single array ``np.savez`` accepts.

    When every record has the same number of labels (or there are no records
    at all) this is exactly ``np.array(labels)``.  When the lengths differ,
    a 1-D object array holding one label array per record is built explicitly:
    recent NumPy releases refuse to create such ragged arrays implicitly and
    raise ``ValueError`` instead.
    """
    if len({len(record_labels) for record_labels in labels}) <= 1:
        return np.array(labels)
    packed = np.empty(len(labels), dtype=object)
    # Element-wise assignment keeps each label array intact as one object.
    for i, record_labels in enumerate(labels):
        packed[i] = record_labels
    return packed


def tf2npz(tf_path, export_folder='/ssd/yt8m/data_npz/'):
    """Convert one ``.tfrecord`` file into a ``.npz`` archive.

    Every record is parsed as a ``tf.train.Example`` and its ``video_id``,
    ``mean_rgb`` and ``mean_audio`` features are collected (the two feature
    vectors are cast to ``float16``).  The ``labels`` feature is collected
    as well unless ``tf_path`` contains ``'/test'``.

    The archive is written to ``export_folder`` under the input's base name
    with ``.tfrecord`` replaced by ``.npz`` and holds four arrays: ``rgb``,
    ``audio``, ``ids`` and ``labels``.  ``labels`` is an object array (which
    needs ``allow_pickle=True`` to load) when records carry different numbers
    of labels.

    Args:
        tf_path: Path of the ``.tfrecord`` file to read.
        export_folder: Directory that receives the ``.npz`` file; it is
            created if it does not exist.
    """
    vid_ids = []
    labels = []
    mean_rgb = []
    mean_audio = []

    tf_basename = os.path.basename(tf_path)
    # Only strip the suffix when it is really there, so an unexpected file
    # name is not silently truncated.
    if tf_basename.endswith('.tfrecord'):
        stem = tf_basename[:-len('.tfrecord')]
    else:
        stem = tf_basename
    npz_basename = stem + '.npz'

    # Files whose path contains '/test' are treated as having no labels.
    is_train = '/test' not in tf_path

    # NOTE(review): tf.python_io is the TensorFlow 1.x namespace; newer
    # releases expose this iterator under tf.compat.v1.io -- confirm the
    # installed version before changing it.
    for record in tf.python_io.tf_record_iterator(tf_path):
        features = tf.train.Example.FromString(record).features.feature
        vid_ids.append(
            features['video_id'].bytes_list.value[0].decode(encoding='UTF-8'))
        if is_train:
            labels.append(np.array(features['labels'].int64_list.value))
        mean_rgb.append(
            np.array(features['mean_rgb'].float_list.value).astype(np.float16))
        mean_audio.append(
            np.array(features['mean_audio'].float_list.value).astype(np.float16))

    # exist_ok=True keeps this safe when several worker processes race to
    # create the same folder.
    if export_folder:
        os.makedirs(export_folder, exist_ok=True)
    save_path = os.path.join(export_folder, npz_basename)
    np.savez(
        save_path,
        rgb=np.array(mean_rgb),
        audio=np.array(mean_audio),
        ids=np.array(vid_ids),
        labels=_pack_labels(labels),
    )
from multiprocessing import Pool

# The guard is required for multiprocessing: with the "spawn" start method
# every worker re-imports this module, and an unguarded Pool would be
# re-created inside each worker.
if __name__ == '__main__':
    # Convert all .tfrecord files in the input folder using 6 worker
    # processes, one file per task.
    with Pool(6) as p:
        p.map(tf2npz, glob.glob('/ssd/yt8m/data_tfrecord/*.tfrecord'))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment