# Scikit-learn iris data set classification using TensorFlow
#
# GitHub gist jbott/9e039cac0d849f30d90a, created November 12, 2015 15:35.
# GitHub notice: this file may contain bidirectional Unicode text that could be
# interpreted or compiled differently than it appears; to review it, open the
# file in an editor that reveals hidden Unicode characters.
# Scikit-learn iris classification with a single-layer softmax regression in TensorFlow.
#
# Splits the iris data into train/test sets and trains a weight matrix plus bias
# vector with plain full-batch gradient descent on the summed cross-entropy.
# Prints the test-set accuracy after every training step.
# Written against the TensorFlow 0.x/1.x graph API (placeholders + Session).
# print() is always called with a single argument, so the script parses on both
# Python 2 and Python 3.
import sklearn
import tensorflow as tf
from sklearn import datasets

try:
    # scikit-learn >= 0.18
    from sklearn.model_selection import train_test_split
except ImportError:
    # Older scikit-learn; ``sklearn.cross_validation`` was removed in 0.20.
    from sklearn.cross_validation import train_test_split

LEARNING_RATE = 0.01  # gradient-descent step size
NUM_STEPS = 10        # number of full-batch training steps
RANDOM_SEED = None    # set to an int for a reproducible train/test split

iris = datasets.load_iris()
NUM_FEATURES = iris.data.shape[1]     # input measurements per sample (4 for iris)
NUM_CLASSES = len(iris.target_names)  # distinct species labels (3 for iris)

# Convert integer class labels to one-hot rows, e.g. 2 -> [0, 0, 1].
targets_onehot = [[1 if k == target else 0 for k in range(NUM_CLASSES)]
                  for target in iris.target]

data_train, data_test, target_train, target_test = train_test_split(
    iris.data, targets_onehot, random_state=RANDOM_SEED)
print("Training: {}\tTesting: {}".format(len(data_train), len(data_test)))

# Model: softmax(x . W + b), with weights and biases initialised to zero.
tf_in = tf.placeholder("float", [None, NUM_FEATURES])
tf_weight = tf.Variable(tf.zeros([NUM_FEATURES, NUM_CLASSES]))
tf_bias = tf.Variable(tf.zeros([NUM_CLASSES]))
tf_logits = tf.matmul(tf_in, tf_weight) + tf_bias
tf_softmax = tf.nn.softmax(tf_logits)

# Loss: cross-entropy summed over the batch. Taking tf.log(tf_softmax) by hand
# yields log(0) = -inf (and NaN gradients) once a predicted probability
# underflows. The fused op works from the logits and avoids that.
tf_softmax_correct = tf.placeholder("float", [None, NUM_CLASSES])
tf_cross_entropy = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits(
        logits=tf_logits, labels=tf_softmax_correct))

tf_train_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(
    tf_cross_entropy)

# Accuracy: fraction of samples whose most probable class matches the label.
tf_correct_prediction = tf.equal(tf.argmax(tf_softmax, 1),
                                 tf.argmax(tf_softmax_correct, 1))
tf_accuracy = tf.reduce_mean(tf.cast(tf_correct_prediction, "float"))

init = tf.initialize_all_variables()
train_feed = {tf_in: data_train, tf_softmax_correct: target_train}
test_feed = {tf_in: data_test, tf_softmax_correct: target_test}

# The context manager closes the session when training is finished.
with tf.Session() as sess:
    sess.run(init)
    for i in range(NUM_STEPS):
        sess.run(tf_train_step, feed_dict=train_feed)
        # Report test-set accuracy after each step.
        print("Run {}".format(i))
        print(sess.run(tf_accuracy, feed_dict=test_feed))
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.