Created March 5, 2019 02:10
Save crawles/0f583d6e96d702cb82a4d0f75729bdc3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# The dataset is small, so feed the entire training set as a single batch.
NUM_EXAMPLES = len(y_train)


def make_input_fn(X, y, n_epochs=None, shuffle=True, batch_size=None):
    """Build a zero-argument ``input_fn`` that streams ``(X, y)`` via tf.data.

    Args:
        X: Feature table; converted with ``dict(X)`` into a mapping of
            feature name -> column (presumably a pandas DataFrame — confirm
            against callers).
        y: Labels, aligned row-by-row with ``X``.
        n_epochs: Number of passes over the data. ``None`` repeats
            indefinitely, which is what training uses.
        shuffle: Whether to shuffle examples before they are repeated and
            batched.
        batch_size: Examples per batch. Defaults to ``NUM_EXAMPLES`` (the
            training-set size), so one batch covers the whole training set.

    Returns:
        A function taking no arguments that returns a ``tf.data.Dataset``
        yielding ``(features_dict, labels)`` batches.
    """
    if batch_size is None:
        batch_size = NUM_EXAMPLES

    def input_fn():
        dataset = tf.data.Dataset.from_tensor_slices((dict(X), y))
        if shuffle:
            # A buffer of NUM_EXAMPLES covers the whole training set, so the
            # shuffle is uniform rather than windowed. Shuffling before
            # repeat() keeps examples from different epochs from mixing.
            dataset = dataset.shuffle(NUM_EXAMPLES)
        # For training, cycle through the dataset as many times as needed
        # (n_epochs=None repeats indefinitely).
        dataset = dataset.repeat(n_epochs)
        # Full-batch feeding: by default each batch is the entire training
        # set rather than a small mini-batch.
        dataset = dataset.batch(batch_size)
        return dataset

    return input_fn
# Input functions for the estimator. Training uses the defaults (shuffled,
# repeated indefinitely); evaluation makes one ordered pass over the
# held-out data.
train_input_fn = make_input_fn(dftrain, y_train)
eval_input_fn = make_input_fn(dfeval, y_eval, n_epochs=1, shuffle=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.