Created July 15, 2018 21:39
Input function for estimator in TensorFlow
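import numpy as np
import tensorflow as tf

def input_fn(X, y, batch_size=16, epochs=1, shuffle=False):
    # Build a dictionary mapping each feature name to a NumPy array of its values
    features = {key: np.array(value) for key, value in dict(X).items()}
    # Create a tf.data dataset of (features, label) pairs
    d = tf.data.Dataset.from_tensor_slices((features, y))
    # Optionally shuffle individual examples before batching.
    # shuffle() returns a new dataset, so the result must be reassigned.
    if shuffle:
        d = d.shuffle(10000)
    # Group the examples into batches and repeat for the requested number of epochs
    d = d.batch(batch_size).repeat(epochs)
    # Make a one-shot iterator (graph-mode API, reached through tf.compat.v1
    # so the call also resolves in TensorFlow 1.13+ and 2.x)
    iterator = tf.compat.v1.data.make_one_shot_iterator(d)
    # Return the tensors for the next (features, labels) batch
    return iterator.get_next()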
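To see how batch_size and epochs shape the stream of data that such an input function produces, here is a minimal pure-Python sketch. It does not call TensorFlow. The name simulate_batches and its seed parameter are hypothetical, introduced only for illustration. It assumes examples are shuffled individually before batching, that the shuffle buffer covers the whole dataset, and that the last partial batch of each epoch is kept (the tf.data default for batch()).

```python
import random

def simulate_batches(rows, labels, batch_size=16, epochs=1, shuffle=False, seed=0):
    """Hypothetical stand-in for a shuffle -> batch -> repeat pipeline."""
    examples = list(zip(rows, labels))
    rng = random.Random(seed)  # seed is an assumption, used for reproducibility
    for _ in range(epochs):
        epoch = list(examples)
        if shuffle:
            # Reshuffle the individual examples at the start of every epoch
            rng.shuffle(epoch)
        # Slice the epoch into consecutive batches; the final one may be smaller
        for start in range(0, len(epoch), batch_size):
            chunk = epoch[start:start + batch_size]
            yield [r for r, _ in chunk], [l for _, l in chunk]

batches = list(simulate_batches(range(5), "abcde", batch_size=2, epochs=2))
print(len(batches))  # 6: three batches per epoch, two epochs
print(batches[2])    # ([4], ['e']): the partial batch that closes the first epoch
```

With 5 examples and batch_size=2, each epoch ends in a batch of one example, so a model consuming these batches must tolerate a variable batch dimension.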