@justanotherminh
Created September 5, 2016 09:45
import os

import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)


def softmax(logits):
    # Subtract the row-wise max before exponentiating for numerical stability
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


X_train, y_train = mnist.train.images, mnist.train.labels
X_test, y_test = mnist.test.images, mnist.test.labels

data_size = 55000  # number of training examples in this MNIST split
batch_size = 128
input_dim = 784    # 28x28 pixels, flattened
output_dim = 10    # digit classes
lr = 0.001
vW, vb = 0., 0.    # momentum buffers
gamma = 0.5        # momentum coefficient

W = np.random.normal(0.0, 0.1, [input_dim, output_dim])
b = np.zeros([1, output_dim])


def predict(x, W, b):
    a = x.dot(W) + b
    return softmax(a)


def loss(prob, label):
    # Cross-entropy, averaged over the batch
    ce = -label * np.log(prob)
    return ce.mean(axis=0).sum()


def get_gradient(prob, label, x):
    # For a softmax classifier with cross-entropy loss, dL/dlogits = prob - label.
    # These gradients are summed (not averaged) over the batch, so the
    # learning rate absorbs the 1/batch_size factor.
    dout = prob - label
    db = dout.sum(axis=0)
    dW = x.T.dot(dout)
    return dW, db


def get_data_batch():
    # Sample a random minibatch without replacement
    mask = np.random.choice(data_size, batch_size, replace=False)
    return X_train[mask, :], y_train[mask, :]


def train():
    global W, b, vW, vb
    for i in range(10000):
        X, y = get_data_batch()
        probs = predict(X, W, b)
        if i % 100 == 0:
            print('Loss: %s' % loss(probs, y))
        dW, db = get_gradient(probs, y, X)
        # Momentum update: v <- gamma * v + lr * grad, then param <- param - v
        vW = gamma * vW + lr * dW
        vb = gamma * vb + lr * db
        W -= vW
        b -= vb
    if not os.path.isdir('saved_networks'):
        os.makedirs('saved_networks')
    np.save('saved_networks/W.npy', W)
    np.save('saved_networks/b.npy', b)


def test():
    W = np.load('saved_networks/W.npy')
    b = np.load('saved_networks/b.npy')
    y = predict(X_test, W, b)
    acc = y.argmax(axis=1) == y_test.argmax(axis=1)
    print('Loss: %s' % loss(y, y_test))
    print('Accuracy: %s' % acc.mean())


if __name__ == '__main__':
    train()  # train first so saved weights exist before test() loads them
    test()
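
The comment in get_gradient relies on the identity that for L = -sum_j y_j * log(p_j) with p = softmax(z), dL/dz = p - y. Below is a minimal finite-difference sanity check of that analytic gradient; this is a sketch, not part of the original gist (the helper name check_gradient and its parameters are made up), and it reuses predict, loss, and get_gradient from above.

def check_gradient(n=4, d=784, k=10, eps=1e-5, seed=0):
    rng = np.random.RandomState(seed)
    x = rng.rand(n, d)
    W_chk = rng.normal(0.0, 0.1, [d, k])
    b_chk = np.zeros([1, k])
    labels = np.eye(k)[rng.randint(k, size=n)]  # random one-hot targets

    probs = predict(x, W_chk, b_chk)
    dW, _ = get_gradient(probs, labels, x)

    # Compare a few random entries of dW against central differences.
    # get_gradient sums over the batch while loss() averages, so the
    # numeric estimate is scaled by n to match.
    for _ in range(5):
        i, j = rng.randint(d), rng.randint(k)
        W_chk[i, j] += eps
        hi = loss(predict(x, W_chk, b_chk), labels)
        W_chk[i, j] -= 2 * eps
        lo = loss(predict(x, W_chk, b_chk), labels)
        W_chk[i, j] += eps  # restore the perturbed entry
        numeric = (hi - lo) / (2 * eps) * n
        print('analytic %.6f vs numeric %.6f' % (dW[i, j], numeric))

Calling check_gradient() should print pairs that agree to several decimal places; a large mismatch would indicate a bug in the analytic gradient.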