Created February 7, 2020 09:47
Save dsalaj/85cda6a6e819cc5b9136fc8d45fe1e99 to your computer and use it in GitHub Desktop.
Example of tf.data.Dataset.from_generator usage with a parametrized generator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf
# Toy dataset: even integers form the training split, odd integers the
# validation split; each input is labelled with its square.
x_train = list(range(0, 20, 2))  # even: 0, 2, ..., 18
x_val = list(range(1, 20, 2))  # odd: 1, 3, ..., 19
# Labels are built element-wise from the inputs, so each y_* list has
# exactly the same length and order as its x_* list.
y_train = [i**2 for i in x_train]  # squared
y_val = [i**2 for i in x_val]
def gen_data_epoch(test=False):  # parametrized generator
    """Yield one epoch of ``(input, label)`` pairs from the chosen split.

    Args:
        test: If truthy, iterate the validation split (``x_val`` / ``y_val``);
            otherwise iterate the training split (``x_train`` / ``y_train``).
            Only the truthiness of ``test`` is used. This matters when the
            function is driven by ``tf.data.Dataset.from_generator(...,
            args=(test,))``: TensorFlow documents that ``args`` are passed to
            the generator as NumPy values, not Python ``bool``.

    Yields:
        ``(x, y)`` tuples, pairing the selected inputs and labels by
        position, in list order.
    """
    inputs = x_val if test else x_train
    labels = y_val if test else y_train
    # Pair inputs with labels positionally. The label lists are built
    # element-wise from the input lists, so the lengths always match.
    yield from zip(inputs, labels)
def get_dataset(test=False):
    """Build a ``tf.data.Dataset`` that streams ``gen_data_epoch(test)``.

    Args:
        test: Forwarded to ``gen_data_epoch`` via ``from_generator``'s
            ``args``. Truthy selects the validation split; falsy selects the
            training split.

    Returns:
        A dataset whose elements are ``(input, label)`` pairs, both typed
        ``tf.int32``. Only dtypes are declared here, not shapes.
    """
    element_types = (tf.int32, tf.int32)
    generator_args = (test,)
    dataset = tf.data.Dataset.from_generator(
        gen_data_epoch,
        output_types=element_types,
        args=generator_args,
    )
    return dataset
# Demo: print the first five (input, label) pairs of each split.
# Each element of the dataset is a 2-tuple of tensors; .numpy() converts
# each tensor to a NumPy value for printing.
print(
    "Train:",
    [(feature.numpy(), label.numpy())
     for feature, label in get_dataset().take(5)],
)
print(
    "Test: ",
    [(feature.numpy(), label.numpy())
     for feature, label in get_dataset(test=True).take(5)],
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.