Last active
January 3, 2018 20:29
-
-
Save Echooff3/e0fbb868da9a02abc5b56c9e618a5c35 to your computer and use it in GitHub Desktop.
Contrived tf.estimator.DNNClassifier example (TensorFlow 1.x, tf.data input pipeline)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import numpy as np | |
import sys | |
from tensorflow.python import debug as tf_debug | |
# Interactive CLI debugger hook. It is not passed to classifier.train() in
# this script; supply hooks=hooks to train() to step through the graph.
hooks = [tf_debug.LocalCLIDebugHook()]
tf.logging.set_verbosity(tf.logging.INFO)

# Toy training data: four examples with four 0/1 features each, and one
# integer label per example.
trainX = np.array([
    [1, 1, 0, 1],
    [0, 0, 1, 0],
    [1, 0, 1, 1],
    [0, 0, 1, 1],
])
labelX = np.array([[1], [0], [1], [0]])

num_classes = 2
feature_names = ['f1', 'f2', 'f3', 'f4']
# One numeric feature column per entry in feature_names, in the same order.
feature_columns = list(map(tf.feature_column.numeric_column, feature_names))

# Single hidden layer of 10 units; n_classes sets the number of output classes.
classifier = tf.estimator.DNNClassifier(
    hidden_units=[10],
    feature_columns=feature_columns,
    n_classes=num_classes,
)
def input_fn(features=None, labels=None, num_epochs=8, batch_size=32):
    """Build the training/evaluation input pipeline for ``classifier``.

    Each row of ``features`` is split into one tensor per name in
    ``feature_names`` (matching the numeric feature columns) and paired
    with its label. The dataset is then repeated ``num_epochs`` times and
    batched.

    Args:
        features: 2-D array of examples, one column per entry in
            ``feature_names``. Defaults to the module-level ``trainX``.
        labels: labels, one row per example. Defaults to ``labelX``.
        num_epochs: number of times the dataset is repeated.
        batch_size: maximum number of examples per batch.

    Returns:
        The ``(features_dict, labels)`` tensors produced by a one-shot
        iterator's ``get_next()``.
    """
    # Resolve defaults at call time so later rebinding of the globals is seen.
    if features is None:
        features = trainX
    if labels is None:
        labels = labelX

    def to_feature_dict(row, label):
        # Split along axis 0 into one piece per feature name. Deriving the
        # count from feature_names (instead of a literal 4) keeps the split
        # in sync with the feature columns; a mismatch would otherwise let
        # zip() silently drop features.
        columns = tf.split(row, len(feature_names), axis=0)
        return dict(zip(feature_names, columns)), label

    dataset = (tf.data.Dataset.from_tensor_slices((features, labels))
               .map(to_feature_dict)
               .repeat(num_epochs)
               .batch(batch_size))
    return dataset.make_one_shot_iterator().get_next()
def input_fn_pred(in_arr, batch_size=32):
    """Build the prediction input pipeline for ``classifier.predict``.

    Args:
        in_arr: 2-D array of examples to score, one row per example and
            one column per entry in ``feature_names``.
        batch_size: maximum number of examples fed per prediction step.

    Returns:
        ``(features_dict, None)``. Prediction needs no labels.
    """
    def to_feature_dict(row):
        # One piece per feature name, as in the training pipeline.
        columns = tf.split(row, len(feature_names), axis=0)
        return dict(zip(feature_names, columns))

    # Batch explicitly so feature tensors carry a leading batch dimension,
    # matching input_fn. Without .batch() every element is a single
    # unbatched example, and the estimator would presumably read the
    # feature axis as the batch axis.
    dataset = (tf.data.Dataset.from_tensor_slices(in_arr)
               .map(to_feature_dict)
               .batch(batch_size))
    return dataset.make_one_shot_iterator().get_next(), None
# Optional sanity check: uncomment to print the first batch produced by
# input_fn() and exit before training.
# next_batch = input_fn()
# with tf.Session() as sess:
#     first_batch = sess.run(next_batch)
#     print(first_batch)
# sys.exit(0)

classifier.train(input_fn=input_fn)

# NOTE: evaluation reuses the training pipeline, so these metrics reflect
# fit on the training data, not generalization.
evaluate_result = classifier.evaluate(input_fn=input_fn)
print("Evaluation results")
for key, value in evaluate_result.items():
    print(" {}, was: {}".format(key, value))

test_set = np.array([[1, 1, 1, 1], [0, 0, 0, 0], [0, 0, 1, 1], [1, 1, 0, 0]])
predict_results = classifier.predict(
    input_fn=lambda: input_fn_pred(test_set))
for prediction in predict_results:
    # "class_ids" is indexed with [0], presumably a length-1 array holding
    # the predicted class. print() is a function call, so this works on
    # Python 3; the original Python 2 print statement was a SyntaxError.
    print(prediction["class_ids"][0])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment