Skip to content

Instantly share code, notes, and snippets.

@dpressel
Created August 12, 2020 17:33
Show Gist options
  • Save dpressel/c60e61726daed146f9b3dd69e31d156f to your computer and use it in GitHub Desktop.
Save dpressel/c60e61726daed146f9b3dd69e31d156f to your computer and use it in GitHub Desktop.
import tensorflow as tf
import sys
import time
import os
# Count the records in every `*.tfrecord` file of a directory and write the
# total to `<directory>/md.yml` as a single `num_samples: N` line.
#
# Usage: <script> <directory>
#
# The directory may be any path understood by `tf.io.gfile` (local or remote).
# Reading and writing both go through gfile, so the metadata file is written
# to the same filesystem the inputs were read from.
# NOTE(review): iterating the dataset in a Python loop assumes TF2 eager
# execution is enabled (the TF2 default) — confirm for the target TF version.

# Validate arguments explicitly: `assert` is stripped under `python -O`.
if len(sys.argv) != 2:
    sys.exit(f"Usage: {sys.argv[0]} <directory>")
dirname = sys.argv[1]
print(dirname)
md_file = os.path.join(dirname, 'md.yml')
input_files = tf.io.gfile.glob(os.path.join(dirname, '*.tfrecord'))
nfiles = len(input_files)
print(f'Loading {nfiles} files')
# Fail loudly rather than silently recording `num_samples: 0` for an empty
# or mistyped directory.
if not input_files:
    sys.exit(f'No *.tfrecord files found in {dirname}')
# perf_counter is monotonic and intended for measuring elapsed intervals.
start_time = time.perf_counter()
loader = tf.data.TFRecordDataset(input_files)
# Each element yielded by the dataset is one serialized record; count them.
num_samples = sum(1 for _ in loader)
elapsed = time.perf_counter() - start_time
print(f'Loaded {num_samples} in {elapsed} seconds')
# Use gfile (not builtin open) so remote directories such as gs:// work too.
with tf.io.gfile.GFile(md_file, 'w') as wf:
    wf.write(f'num_samples: {num_samples}\n')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment