Skip to content

Instantly share code, notes, and snippets.

@securetorobert
Created July 15, 2018 10:28
Show Gist options
  • Save securetorobert/d06d72875a3a4535ffecef7249c01c50 to your computer and use it in GitHub Desktop.
Save securetorobert/d06d72875a3a4535ffecef7249c01c50 to your computer and use it in GitHub Desktop.
An input function for use with TensorFlow Estimators
def my_input_fn(file_path, perform_shuffle=False, repeat_count=1,
                batch_size=32, shuffle_buffer_size=256):
    """Return one batch of (features, labels) tensors read from a CSV file.

    The file is read line by line, the first line is skipped as a header,
    and every remaining line is parsed as five CSV columns: four floats
    followed by one int. The last column is used as the label; the
    remaining columns are paired with ``feature_names``.

    NOTE(review): ``feature_names`` and ``tf`` are not defined in this
    function; they are expected to exist at module level -- confirm.

    Args:
        file_path: Path of the CSV text file, passed to
            ``tf.data.TextLineDataset``.
        perform_shuffle: When true, shuffle the parsed rows before
            repeating and batching.
        repeat_count: Passed to ``Dataset.repeat``.
        batch_size: Passed to ``Dataset.batch``. Defaults to 32, the
            value that used to be hard-coded.
        shuffle_buffer_size: ``buffer_size`` passed to ``Dataset.shuffle``
            when ``perform_shuffle`` is true. Defaults to 256, the value
            that used to be hard-coded.

    Returns:
        A ``(batch_features, batch_labels)`` pair obtained from the
        ``get_next()`` of a one-shot iterator over the dataset;
        ``batch_features`` is a dict keyed by ``feature_names``.
    """
    # One default per CSV column: four floats, then one int.
    record_defaults = [[0.], [0.], [0.], [0.], [0]]

    def decode_csv(line):
        """Parse one CSV line into a (features dict, label) pair."""
        parsed_line = tf.decode_csv(line, record_defaults)
        # Slice instead of deleting in place. The label deliberately stays
        # a one-element slice (not a bare element) so the structure handed
        # to the dataset is the same as before.
        label = parsed_line[-1:]
        features = parsed_line[:-1]
        return dict(zip(feature_names, features)), label

    dataset = (tf.data.TextLineDataset(file_path)  # Read text file
               .skip(1)  # Skip header row
               .map(decode_csv))  # Parse each line
    if perform_shuffle:
        # Randomizes input using a window of shuffle_buffer_size elements.
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
    dataset = dataset.repeat(repeat_count)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment