Last active
October 8, 2017 21:42
-
-
Save Kulbear/18fd4d059cb0930ddda4e0c5b361cced to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import timeit
from pprint import pformat

import numpy as np
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

# MNIST read from the "data" directory with integer (non one-hot) labels; the
# algorithms below pull their train/validation batches from this object.
mnist = read_data_sets("data", one_hot=False)

NUM_CLASS = 10  # number of digit classes; one binary SVM is trained per class
STEP = 200      # gradient-descent steps per class

# Hyperparameters
C = 1.0               # weight of the hinge-loss term in the SVM objective
BATCH_SIZE = 128      # training examples fed per gradient step
LEARNING_RATE = 1e-2  # gradient-descent step size
def encode_label(labels, target):
    """Encode class labels as a +1/-1 column vector for one-vs-rest SVM training.

    Args:
        labels: Array-like of integer class labels. NumPy arrays, lists and
            other sequences are all accepted.
        target: The class value treated as the positive class.

    Returns:
        Integer ``np.ndarray`` of shape ``(N, 1)``, where ``N`` is the number of
        label entries. It holds ``1`` where ``labels == target`` and ``-1``
        everywhere else.
    """
    # np.asarray makes plain Python sequences work: comparing a list to a
    # scalar would otherwise produce one bool rather than an elementwise mask.
    mask = np.asarray(labels) == target
    # Vectorised replacement for the per-element Python loop.
    return np.where(mask, 1, -1).reshape(-1, 1)
def svm(x_test, mnist=None):
    """Train and evaluate one linear soft-margin SVM per digit class (one-vs-rest).

    A single TF1 graph is built. It holds a linear score ``X @ W + b``, trained
    by plain gradient descent on
    ``reg_term * sum(W**2) + C * sum(max(0, 1 - Y * score))``.

    For every class ``i`` in ``range(NUM_CLASS)``, a fresh session does three
    things:
      * zero-initialises the variables once;
      * runs ``STEP`` mini-batch updates on training labels encoded as +1
        (class ``i``) / -1 (any other class);
      * reports accuracy on a 1000-example validation batch.

    Args:
        x_test: Unused; kept so the signature matches what ``run`` passes in.
        mnist: Dataset object providing ``train.next_batch`` and
            ``validation.next_batch``. Required despite the ``None`` default.

    Returns:
        List of per-class validation accuracies, in class order.

    Raises:
        ValueError: If ``mnist`` is ``None``.
    """
    if mnist is None:
        raise ValueError('svm() needs the mnist dataset object')

    print('Enter run...')
    reg_term = tf.constant(0.05)  # L2 penalty weight on W

    X = tf.placeholder(tf.float32, [None, 784], name='x')
    W = tf.Variable(tf.zeros([784, 1], name='weight'))
    b = tf.Variable(tf.zeros([1]), name='b')
    Y = tf.placeholder(tf.float32, [None, 1], name='y')

    y_predict = tf.add(tf.matmul(X, W), b)
    reg_loss = reg_term * tf.reduce_sum(tf.square(W))
    hinge_loss = tf.reduce_sum(tf.maximum(0., 1 - Y * y_predict))
    svm_loss = reg_loss + C * hinge_loss
    goal = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(svm_loss)

    # tf.sign gives 0 for a score of exactly 0. That never equals a +/-1 label,
    # so points on the decision boundary count as misclassified.
    predicted_class = tf.sign(y_predict)
    correct_prediction = tf.equal(Y, predicted_class)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

    # Build the initializer op once. Creating it inside a loop would add a new
    # node to the graph on every iteration.
    init_op = tf.global_variables_initializer()

    accuracies = []
    for i in range(NUM_CLASS):
        x_val, y_val = mnist.validation.next_batch(1000)
        y_val = encode_label(y_val, i)
        with tf.Session() as sess:
            print('Enter {} session...'.format(i))
            # Initialise once per class, before training. Running the
            # initializer inside the step loop would reset W and b to zero
            # before every update, so nothing would ever be learned.
            sess.run(init_op)
            for stp in range(STEP):
                batch_xs, batch_ys = mnist.train.next_batch(BATCH_SIZE)
                batch_ys = encode_label(batch_ys, i)
                feed = {X: batch_xs, Y: batch_ys}
                sess.run(goal, feed_dict=feed)
                if stp % 10 == 0:
                    print('loss: ', sess.run(svm_loss, feed_dict=feed))
            val_acc = sess.run(accuracy, feed_dict={X: x_val, Y: y_val})
            print("Class", i, "Accuracy on validation:", val_acc)
            accuracies.append(val_acc)
    return accuracies
def run(algorithm, x_test, y_test, mnist, algorithm_name='Algorithm'):
    """Seed NumPy, run ``algorithm`` once, and report its wall-clock duration.

    Args:
        algorithm: Callable invoked as ``algorithm(x_test, mnist=mnist)``.
        x_test: Forwarded positionally to ``algorithm``.
        y_test: Unused here; accepted for interface compatibility.
        mnist: Forwarded to ``algorithm`` as the ``mnist`` keyword.
        algorithm_name: Label used in the progress messages.

    Returns:
        Tuple ``(result, run_time)``. ``result`` is whatever ``algorithm``
        returned; ``run_time`` is the elapsed seconds measured with
        ``timeit.default_timer``.
    """
    print('Running {}...'.format(algorithm_name))
    start = timeit.default_timer()
    # Seed NumPy's global RNG so randomness drawn from it is repeatable.
    # NOTE(review): presumably this makes mnist next_batch shuffling
    # deterministic - confirm.
    np.random.seed(0)
    result = algorithm(x_test, mnist=mnist)
    # Previously `start` was recorded but never used, so no time was reported.
    run_time = timeit.default_timer() - start
    print('{} finished in {:.2f}s'.format(algorithm_name, run_time))
    return result, run_time
# Benchmark every algorithm in the list against the validation split.
for algorithm in [svm]:
    # Use the dataset's public `images` property rather than its private
    # `_images` attribute.
    x_valid, y_valid = mnist.validation.images, mnist.validation.labels
    run(algorithm, x_valid, y_valid, mnist, algorithm_name=algorithm.__name__)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment