
@jbott
Created November 12, 2015 15:35
Scikit-learn iris data set classification using TensorFlow
import sklearn
import tensorflow as tf
from sklearn import datasets
iris = datasets.load_iris()
# Convert targets to one hot
targets_onehot = [0] * len(iris.target)
for i, target in enumerate(iris.target):
    targets_onehot[i] = [0] * 3
    targets_onehot[i][target] = 1
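# Note: a more compact, equivalent one-hot encoding is possible with NumPy,
# e.g. np.eye(3)[iris.target] -- a sketch assuming numpy is imported as np.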
from sklearn import cross_validation
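# Note: in newer scikit-learn releases, train_test_split lives in sklearn.model_selection.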
data_train, data_test, target_train, target_test = cross_validation.train_test_split(iris.data, targets_onehot)
print "Training: {}\tTesting: {}".format(len(data_train), len(data_test))
tf_in = tf.placeholder("float", [None, 4]) # Four inputs
# Weight and bias variables
tf_weight = tf.Variable(tf.zeros([4,3]))
tf_bias = tf.Variable(tf.zeros([3]))
# Output
tf_softmax = tf.nn.softmax(tf.matmul(tf_in,tf_weight) + tf_bias)
# Training via backpropagation
tf_softmax_correct = tf.placeholder("float", [None,3])
tf_cross_entropy = -tf.reduce_sum(tf_softmax_correct*tf.log(tf_softmax))
# Train using tf.train.GradientDescentOptimizer
tf_train_step = tf.train.GradientDescentOptimizer(0.01).minimize(tf_cross_entropy)
# Add accuracy checking nodes
tf_correct_prediction = tf.equal(tf.argmax(tf_softmax,1), tf.argmax(tf_softmax_correct,1))
tf_accuracy = tf.reduce_mean(tf.cast(tf_correct_prediction, "float"))
# Initialize and run
init = tf.initialize_all_variables()
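# Note: tf.initialize_all_variables() was later renamed tf.global_variables_initializer().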
sess = tf.Session()
sess.run(init)
# Run the training
for i in range(10):
    sess.run(tf_train_step, feed_dict={tf_in: data_train, tf_softmax_correct: target_train})
    # Print accuracy
    print "Run {}".format(i)
    print sess.run(tf_accuracy, feed_dict={tf_in: data_test, tf_softmax_correct: target_test})
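After training, the same session can be reused for inference. The lines below are a minimal sketch added for illustration (not part of the original gist): they classify the first five test rows by taking the argmax of the softmax output and print them next to their one-hot targets.
# Minimal inference sketch (illustrative addition): classify a few test rows
# with the trained weights by taking the most likely class per row.
tf_predictions = tf.argmax(tf_softmax, 1)
predicted = sess.run(tf_predictions, feed_dict={tf_in: data_test[:5]})
print "Predicted classes: {}".format(predicted)
print "One-hot targets:   {}".format(target_test[:5])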