Created September 18, 2018 09:47
Save securetorobert/e9df5cb08848b7fb02c0072c5c3661fe to your computer and use it in GitHub Desktop.
Prepare data for input into an estimator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Field order of every CSV record; the label column ('medv') comes first.
CSV_COLUMNS = [
    'medv',
    'crim', 'zn', 'indus', 'chas', 'nox', 'rm', 'age',
    'dis', 'rad', 'tax', 'ptratio', 'black', 'lstat',
]

# Column that read_dataset's decoder splits off from the features as the label.
LABEL_COLUMN = 'medv'

# record_defaults for tf.decode_csv: one single-element list per column, in
# CSV_COLUMNS order. 'chas', 'rad' and 'tax' default to int 0; every other
# column defaults to float 0.0. Deriving the list from CSV_COLUMNS keeps the
# two the same length.
DEFAULTS = [
    [0] if name in ('chas', 'rad', 'tax') else [0.0]
    for name in CSV_COLUMNS
]
def read_dataset(filename, mode, batch_size=16):
    """Build an Estimator ``input_fn`` that streams batches from CSV file(s).

    Args:
      filename: File path or glob pattern, expanded with ``tf.gfile.Glob``,
        naming the CSV file(s) to read. Each line is parsed with
        ``record_defaults=DEFAULTS`` and its fields are named by
        ``CSV_COLUMNS`` in order.
      mode: A ``tf.estimator.ModeKeys`` value. ``TRAIN`` shuffles the records
        and repeats them indefinitely; any other mode reads the data once.
      batch_size: Number of records per batch (also scales the shuffle
        buffer in TRAIN mode).

    Returns:
      A zero-argument ``input_fn``. Calling it returns the next
      ``(features, label)`` batch: ``features`` is a dict mapping every name
      in ``CSV_COLUMNS`` except ``LABEL_COLUMN`` to a tensor, and ``label`` is
      the ``LABEL_COLUMN`` tensor.

    Raises:
      ValueError: When the returned ``input_fn`` is called and ``filename``
        matches no files.
    """
    def _input_fn():
        def decode_csv(value_column):
            # Parse one CSV line into per-column tensors, then split the label
            # off from the feature dict.
            columns = tf.decode_csv(value_column, record_defaults=DEFAULTS)
            features = dict(zip(CSV_COLUMNS, columns))
            label = features.pop(LABEL_COLUMN)
            return features, label

        # Create list of files that match pattern
        file_list = tf.gfile.Glob(filename)
        if not file_list:
            # Otherwise an unmatched pattern yields an empty dataset, and
            # training/evaluation would run on no data instead of failing.
            raise ValueError('No files match pattern: {}'.format(filename))

        # Create dataset from file list
        dataset = tf.data.TextLineDataset(file_list).map(decode_csv)

        if mode == tf.estimator.ModeKeys.TRAIN:
            num_epochs = None  # repeat indefinitely
            dataset = dataset.shuffle(buffer_size=10 * batch_size)
        else:
            num_epochs = 1  # end-of-input after one pass
        dataset = dataset.repeat(num_epochs).batch(batch_size)
        return dataset.make_one_shot_iterator().get_next()

    return _input_fn
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.