@securetorobert
Created July 15, 2018 21:39
Input function for estimator in TensorFlow
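import numpy as np
import tensorflow as tf


def input_fn(X, y, batch_size=16, epochs=1, shuffle=False):
    # Create a dictionary mapping each feature name to an array of its values
    features = {key: np.array(value) for key, value in dict(X).items()}
    # Create a tf.data-compliant dataset of (features, label) pairs
    d = tf.data.Dataset.from_tensor_slices((features, y))
    # Optionally shuffle individual examples before batching; shuffle()
    # returns a new dataset, so the result must be reassigned
    if shuffle:
        d = d.shuffle(10000)
    # Batch the data, then repeat it for the requested number of epochs
    d = d.batch(batch_size).repeat(epochs)
    # Make an iterator (TF 1.x API) and return the next (features, labels) batch
    iterator = d.make_one_shot_iterator()
    return iterator.get_next()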
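
The order in which shuffling, batching, and repeating are applied determines what each batch contains. As a rough illustration only, the pure-Python stand-in below (not tf.data itself; `pipeline` and its `seed` parameter are hypothetical names introduced here) mimics one common ordering: shuffle individual examples, group them into batches, then repeat for several epochs. Reshuffling on every epoch is an assumption made for this sketch.

```python
import random


def pipeline(examples, batch_size=16, epochs=1, shuffle=False, seed=None):
    """Yield batches in shuffle -> batch -> repeat order (illustrative stand-in)."""
    rng = random.Random(seed)  # hypothetical seed parameter, for reproducibility
    for _ in range(epochs):
        data = list(examples)
        if shuffle:
            # Shuffle individual examples before they are grouped into batches
            rng.shuffle(data)
        # Group consecutive examples; the last batch of an epoch may be smaller
        for start in range(0, len(data), batch_size):
            yield data[start:start + batch_size]


batches = list(pipeline(range(10), batch_size=4, epochs=2))
print(batches)
```

With `shuffle=False`, each epoch yields three batches of sizes 4, 4, and 2. The short final batch appears once per epoch because batching happens before repeating.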