Skip to content

Instantly share code, notes, and snippets.

@securetorobert
Created September 18, 2018 09:47
Show Gist options
  • Save securetorobert/e9df5cb08848b7fb02c0072c5c3661fe to your computer and use it in GitHub Desktop.
Save securetorobert/e9df5cb08848b7fb02c0072c5c3661fe to your computer and use it in GitHub Desktop.
Prepare data for input into an estimator
# Column order of every CSV record. The first field, 'medv', is the value that
# read_dataset splits off as the label; the remaining thirteen are features.
CSV_COLUMNS = [
    'medv',
    'crim', 'zn', 'indus', 'chas', 'nox', 'rm', 'age',
    'dis', 'rad', 'tax', 'ptratio', 'black', 'lstat',
]

# Name of the column popped out of the parsed feature dict as the label.
LABEL_COLUMN = 'medv'

# Per-column record defaults passed to tf.decode_csv, aligned one-to-one with
# CSV_COLUMNS. Note the integer defaults for 'chas', 'rad' and 'tax'; every
# other column uses a float default.
DEFAULTS = [
    [0.0],  # medv
    [0.0],  # crim
    [0.0],  # zn
    [0.0],  # indus
    [0],    # chas
    [0.0],  # nox
    [0.0],  # rm
    [0.0],  # age
    [0.0],  # dis
    [0],    # rad
    [0],    # tax
    [0.0],  # ptratio
    [0.0],  # black
    [0.0],  # lstat
]
def read_dataset(filename, mode, batch_size=16):
    """Build an Estimator input_fn that yields (features, label) batches from CSV files.

    Args:
        filename: File path or glob pattern. It is expanded with tf.gfile.Glob
            each time the returned input_fn is invoked.
        mode: A tf.estimator.ModeKeys value. In TRAIN mode rows are shuffled and
            the data is repeated indefinitely; in any other mode the files are
            read exactly once.
        batch_size: Number of rows per batch (default 16).

    Returns:
        A zero-argument callable usable as an Estimator input_fn. Calling it
        returns the next-element tensors of a one-shot iterator: a dict of
        feature tensors keyed by the CSV_COLUMNS names other than LABEL_COLUMN,
        and the LABEL_COLUMN tensor.

    Raises:
        ValueError: raised by the returned input_fn when `filename` matches no
            files.
    """
    def _input_fn():
        def decode_csv(value_column):
            # Parse a CSV line into one tensor per column, using DEFAULTS as the
            # per-column record defaults.
            columns = tf.decode_csv(value_column, record_defaults=DEFAULTS)
            features = dict(zip(CSV_COLUMNS, columns))
            # Split the target column off so only model inputs remain in features.
            label = features.pop(LABEL_COLUMN)
            return features, label

        # Create list of files that match pattern
        file_list = tf.gfile.Glob(filename)
        if not file_list:
            # An empty file list silently produces an empty dataset; fail early
            # with the offending pattern instead of training/evaluating on nothing.
            raise ValueError('No files match pattern: {}'.format(filename))

        # Create dataset from file list
        dataset = tf.data.TextLineDataset(file_list).map(decode_csv)
        if mode == tf.estimator.ModeKeys.TRAIN:
            num_epochs = None  # indefinitely
            dataset = dataset.shuffle(buffer_size=10 * batch_size)
        else:
            num_epochs = 1  # end-of-input after this
        dataset = dataset.repeat(num_epochs).batch(batch_size)
        return dataset.make_one_shot_iterator().get_next()

    return _input_fn
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment