TensorFlow - the model has been trained; now run it against the test data.
#!/usr/bin/env python
"""
This example uses the restored values. You don't have to rebuild
the graph, but it assumes you've trained it with the following
program:

    https://gist.github.com/mchirico/bcc376fb336b73f24b29

You'll get output similar to the following:

...
Run 0,0.979020953178
Correct prediction
[[ 6.17585704e-02  8.63590300e-01  7.46511072e-02]
 [ 9.98804331e-01  1.19561062e-03  3.25832108e-13]
 [ 1.52018686e-07  4.49650863e-04  9.99550164e-01]
 [ 1.05427168e-01  7.98905313e-01  9.56674740e-02]
 [ 5.85267730e-02  9.16726947e-01  2.47461870e-02]

(A small pre-flight check for the saved checkpoint is sketched after the listing.)
"""
import tensorflow as tf
import numpy as np
from numpy import genfromtxt

# Build example data in CSV format, using the Iris dataset
from sklearn import datasets
from sklearn.model_selection import train_test_split
def buildDataFromIris():
    # Write the Iris data out as "label,feature1,...,feature4" CSV rows,
    # one file for training and one for testing.
    iris = datasets.load_iris()
    X_train, X_test, y_train, y_test = train_test_split(
        iris.data, iris.target, test_size=0.95, random_state=42)

    f = open('cs-training.csv', 'w')
    for i, j in enumerate(X_train):
        k = np.append(np.array(y_train[i]), j)
        f.write(",".join([str(s) for s in k]) + '\n')
    f.close()

    f = open('cs-testing.csv', 'w')
    for i, j in enumerate(X_test):
        k = np.append(np.array(y_test[i]), j)
        f.write(",".join([str(s) for s in k]) + '\n')
    f.close()
# Convert class labels (first column of each row) to one-hot vectors
def convertOneHot(data):
    y = np.array([int(i[0]) for i in data])
    y_onehot = [0] * len(y)
    for i, j in enumerate(y):
        y_onehot[i] = [0] * (y.max() + 1)
        y_onehot[i][j] = 1
    return (y, y_onehot)
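
# Added sketch (not in the original gist): a tiny self-check of convertOneHot
# on toy rows shaped like the CSV lines above ("label,feature1,...").
# Labels 0 and 2 should map to [1,0,0] and [0,0,1]. Safe to delete.
_toy_y, _toy_onehot = convertOneHot(np.array([[0, 5.1, 3.5, 1.4, 0.2],
                                              [2, 6.3, 3.3, 6.0, 2.5]]))
assert list(_toy_y) == [0, 2] and _toy_onehot == [[1, 0, 0], [0, 0, 1]]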
buildDataFromIris()

data = genfromtxt('cs-training.csv', delimiter=',')       # Training data
test_data = genfromtxt('cs-testing.csv', delimiter=',')   # Test data

x_train = np.array([i[1::] for i in data])
y_train, y_train_onehot = convertOneHot(data)
x_test = np.array([i[1::] for i in test_data])
y_test, y_test_onehot = convertOneHot(test_data)

# A = number of features (4 in this example); note the first CSV column is y
# B = number of classes: 3 species of Iris (setosa, virginica and versicolor)
A = data.shape[1] - 1
B = len(y_train_onehot[0])
tf_in = tf.placeholder("float", [None, A])   # Features
tf_weight = tf.Variable(tf.zeros([A, B]))
tf_bias = tf.Variable(tf.zeros([B]))
tf_softmax = tf.nn.softmax(tf.matmul(tf_in, tf_weight) + tf_bias)

# Training via backpropagation
tf_softmax_correct = tf.placeholder("float", [None, B])
tf_cross_entropy = -tf.reduce_sum(tf_softmax_correct * tf.log(tf_softmax))

# Train using tf.train.GradientDescentOptimizer
tf_train_step = tf.train.GradientDescentOptimizer(0.01).minimize(tf_cross_entropy)

# Add accuracy checking nodes
tf_correct_prediction = tf.equal(tf.argmax(tf_softmax, 1), tf.argmax(tf_softmax_correct, 1))
tf_accuracy = tf.reduce_mean(tf.cast(tf_correct_prediction, "float"))

# Build the summary operation based on the TF collection of Summaries.
summary_op = tf.merge_all_summaries()

# Initialize and run
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# Restore only the weight and bias variables saved by the training gist
saver = tf.train.Saver([tf_weight, tf_bias])
print("...") | |
# Run the training | |
k=[] | |
saved=0 | |
for i in [0]: | |
# sess.run(tf_train_step, feed_dict={tf_in: x_train, tf_softmax_correct: y_train_onehot}) | |
# Print accuracy | |
saver.restore(sess, "./tenIrisSave/saveOne") | |
result = sess.run(tf_accuracy, feed_dict={tf_in: x_test, tf_softmax_correct: y_test_onehot}) | |
print "Run {},{}".format(i,result) | |
k.append(result) | |
ans = sess.run(tf_softmax, feed_dict={tf_in: x_test}) | |
print "Correct prediction\n",ans | |
k=np.array(k) | |
print(np.where(k==k.max())) | |
print "Max: {}".format(k.max()) | |