Created
August 14, 2018 14:04
-
-
Save securetorobert/9312f693f1e6f52c0e437e03b29f7cfc to your computer and use it in GitHub Desktop.
Code designed to run training and evaluation on a DNNClassifier
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import shutil

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
import tensorflow_hub as hub
# TensorFlow Hub module URLs for the text embeddings (per the URLs: NNLM
# English, 50- and 128-dimensional variants).
MODULE_SPEC_50 = 'https://tfhub.dev/google/nnlm-en-dim50/1'
MODULE_SPEC_128 = 'https://tfhub.dev/google/nnlm-en-dim128/1'

# Report the installed TensorFlow version when the script starts.
print(tf.__version__)
# Input function configuration.
# Column order of the headerless CSV files: label first, then the review text.
CSV_COLUMNS = ['sentiment', 'cleanReviewText']
# Name of the column that is split off from the features and used as the label.
LABEL_COLUMN = 'sentiment'
# Per-column defaults handed to the CSV decoder; one entry per CSV column
# (an integer default for the label, a string default for the text).
DEFAULTS = [[0], [' ']]
def read_dataset(filename, mode, batch_size=512):
    """Build an Estimator input function that reads headerless CSV files.

    Args:
        filename: File path or glob pattern; expanded with ``tf.gfile.Glob``.
        mode: A ``tf.estimator.ModeKeys`` value. In ``TRAIN`` mode the data is
            shuffled and repeated indefinitely; in any other mode it is read
            exactly once.
        batch_size: Number of rows per batch.

    Returns:
        A zero-argument callable that, when invoked, builds the input
        pipeline and returns the ``(features, label)`` pair produced by a
        one-shot iterator over the batched dataset.
    """
    def _input_fn():
        def decode_csv(value_column):
            # Parse one CSV line into per-column tensors, then split the
            # label column off from the remaining features.
            columns = tf.decode_csv(value_column, record_defaults=DEFAULTS)
            features = dict(zip(CSV_COLUMNS, columns))
            label = features.pop(LABEL_COLUMN)
            return features, label

        # Create list of files that match pattern
        file_list = tf.gfile.Glob(filename)

        # Create dataset from file list
        dataset = tf.data.TextLineDataset(file_list).map(decode_csv)

        if mode == tf.estimator.ModeKeys.TRAIN:
            num_epochs = None  # indefinitely
            dataset = dataset.shuffle(buffer_size=10 * batch_size)
        else:
            num_epochs = 1  # end-of-input after this

        dataset = dataset.repeat(num_epochs).batch(batch_size)
        return dataset.make_one_shot_iterator().get_next()

    return _input_fn
# Embedding feature: a TensorFlow Hub text-embedding column keyed on the
# review-text feature, using the module named by MODULE_SPEC_50.
embedding_feat_col = hub.text_embedding_column(
    key='cleanReviewText', module_spec=MODULE_SPEC_50)
# Serving Input Function
def serving_input_fn():
    """Describe the inputs the exported model accepts at serving time.

    Returns:
        A ``tf.estimator.export.ServingInputReceiver`` whose receiver tensor
        is a batch of raw strings under the key ``'cleanReviewText'``.
    """
    feature_placeholders = {
        'cleanReviewText': tf.placeholder(tf.string, [None])
    }
    # The features handed to the model must be tensors, not feature-column
    # objects: the raw strings are passed through unchanged and the
    # estimator's own feature column performs the embedding lookup.
    features = dict(feature_placeholders)
    return tf.estimator.export.ServingInputReceiver(features, feature_placeholders)
# Train and Evaluate
def train_and_evaluate(output_dir, num_train_steps,
                       train_files='./csv_headless/amazon_instant_video.csv',
                       eval_files='./csv_headless/automotive.csv'):
    """Train a two-class DNNClassifier and evaluate it alongside training.

    Args:
        output_dir: Directory used as the estimator's ``model_dir``.
        num_train_steps: Value passed as ``max_steps`` to the training spec.
        train_files: Path or glob pattern of the training CSV file(s).
        eval_files: Path or glob pattern of the evaluation CSV file(s).
    """
    estimator = tf.estimator.DNNClassifier(
        model_dir=output_dir,
        hidden_units=[512, 1024, 512, 256],
        n_classes=2,
        feature_columns=[embedding_feat_col])

    train_spec = tf.estimator.TrainSpec(
        input_fn=read_dataset(train_files, mode=tf.estimator.ModeKeys.TRAIN),
        max_steps=num_train_steps)

    # Exporter run after evaluation, using the serving input function.
    exporter = tf.estimator.LatestExporter('exporter', serving_input_fn)

    eval_spec = tf.estimator.EvalSpec(
        input_fn=read_dataset(eval_files, mode=tf.estimator.ModeKeys.EVAL),
        steps=None,
        start_delay_secs=1,  # start evaluating after N seconds
        throttle_secs=10,  # evaluate every N seconds
        exporters=exporter)

    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
# Run training
OUTDIR = 'sentiment_trained'
# Remove any previous output directory before training; a missing directory
# is not treated as an error.
shutil.rmtree(OUTDIR, ignore_errors=True)
train_and_evaluate(OUTDIR, num_train_steps=100)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment