Skip to content

Instantly share code, notes, and snippets.

@dsalaj
Created February 7, 2020 09:47
Show Gist options
  • Save dsalaj/85cda6a6e819cc5b9136fc8d45fe1e99 to your computer and use it in GitHub Desktop.
Save dsalaj/85cda6a6e819cc5b9136fc8d45fe1e99 to your computer and use it in GitHub Desktop.
Example of `tf.data.Dataset.from_generator` usage with a parametrized generator
import tensorflow as tf
# Toy data: inputs are the even (train) / odd (validation) integers below 20,
# and each label is the square of its input. Label lists are derived from the
# input lists, so each input/label pair of lists has equal length.
x_train = list(range(0, 20, 2))  # even
x_val = list(range(1, 20, 2))  # odd
y_train = [i**2 for i in x_train]  # squared
y_val = [i**2 for i in x_val]
def gen_data_epoch(test=False):  # parametrized generator
    """Yield ``(input, label)`` pairs for one pass over the selected split.

    Args:
        test: If truthy, iterate the validation split (``x_val`` / ``y_val``);
            otherwise iterate the training split (``x_train`` / ``y_train``).
            When this generator is driven by
            ``tf.data.Dataset.from_generator(..., args=(test,))``, TensorFlow
            converts the argument before passing it in, so it may not be a
            plain Python ``bool`` (TODO confirm for your TF version). Only
            its truthiness is used here.

    Yields:
        ``(x, y)`` tuples taken pairwise, in order, from the selected input
        and label lists.
    """
    inputs = x_val if test else x_train
    labels = y_val if test else y_train
    # The label lists are built from the input lists, so they have the same
    # length and zip pairs every element. This replaces the index loop and
    # its unused ``n_tests`` local.
    yield from zip(inputs, labels)
def get_dataset(test=False):
    """Wrap ``gen_data_epoch`` in a ``tf.data.Dataset``.

    Args:
        test: Forwarded to ``gen_data_epoch`` through ``from_generator``'s
            ``args`` parameter. Truthy selects the validation split;
            falsy selects the training split.

    Returns:
        A ``tf.data.Dataset`` whose elements are ``(x, y)`` pairs of scalar
        ``tf.int32`` tensors.
    """
    return tf.data.Dataset.from_generator(
        # Pass the parameter through ``args`` rather than a closure;
        # from_generator re-invokes the generator with these args each time
        # the dataset is iterated.
        gen_data_epoch,
        args=(test,),
        output_types=(tf.int32, tf.int32),
        # The generator yields plain Python ints, so both components are
        # scalars. Declaring that gives downstream ops a fully known
        # element_spec instead of an unknown shape.
        output_shapes=((), ()),
    )
if __name__ == "__main__":
    # Demo: show the first five (input, label) pairs of each split. Tensors
    # are converted to plain ints so the output reads "(0, 0)" rather than
    # NumPy 2's "(np.int32(0), np.int32(0))".
    print("Train:", [(int(x.numpy()), int(y.numpy()))
                     for x, y in get_dataset().take(5)])
    print("Test: ", [(int(x.numpy()), int(y.numpy()))
                     for x, y in get_dataset(test=True).take(5)])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment