# coding: utf-8

# A little demo illustrating the effect of momentum in neural network training.
# Try using different values for the MOMENTUM constant below (e.g. compare 0.0 with 0.9).
# This neural network is actually more like logistic regression, but I have used
# squared error to make the error surface more interesting.
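#
# For reference, the update rule implemented in Net.update() below is
# classical momentum (sometimes called the "heavy ball" method):
#
#     v <- momentum * v + learning_rate * dE/dW
#     W <- W - v
#
# With momentum = 0.0 the velocity v is just the current gradient step,
# so the rule reduces to plain gradient descent.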
import numpy as np
import pylab
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection
# Play with this value (0.0 = just gradient descent)
MOMENTUM = 0.0
#################

W_SPACE_LOW = -1.
W_SPACE_HIGH = 1.

np.random.seed(5)
# Learn to predict whether the first input is larger than the second.

# Inputs
X = np.array([
    [1, 4],
    [7, 2],
    [5, 1],
    [2, 4],
    [3, 9],
    [9, 8],
    [2, 5],
]).astype(np.float64)

# Targets
T = np.array([
    0,
    1,
    1,
    0,
    0,
    1,
    0,
]).astype(np.float64)
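# Note: with no bias term, any weight vector proportional to (1, -1) separates
# these examples, since sigmoid(c * (x1 - x2)) > 0.5 exactly when x1 > x2
# (for c > 0).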
class SquaredError(object):
    """Squared error between output 'y' and target 't'."""

    @staticmethod
    def E(y, t):
        return 0.5 * (t - y)**2

    @staticmethod
    def dE_dy(t, y):
        """dE/dy - error derivative with respect to output"""
        return -(t - y)


class SigmoidActivation(object):
    @staticmethod
    def y(z):
        return 1. / (1. + np.exp(-z))

    @staticmethod
    def dy_dz(y):
        """dy/dz - derivative of output with respect to input.

        Note that the sigmoid's derivative is expressed in terms of its
        output y rather than its input z: dy/dz = y * (1 - y).
        """
        return y * (1. - y)
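# Putting the pieces together, the per-example gradient of the error with
# respect to the weights is (by the chain rule)
#
#     dE/dW = dE/dy * dy/dz * dz/dW = -(t - y) * y * (1 - y) * x
#
# which Net.backward() below sums over all training examples.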
class Net(object):
    def __init__(self, learning_rate, momentum=None):
        super(Net, self).__init__()
        self.learning_rate = learning_rate
        self.W = np.random.uniform(low=W_SPACE_LOW, high=W_SPACE_HIGH, size=2)
        self.momentum = momentum
        if self.momentum is not None:
            self.W_momentum = np.zeros_like(self.W)

    def forward(self, x):
        z = np.dot(self.W, x.T)
        self.y = SigmoidActivation.y(z)

    def backward(self, x, t):
        # Sum the per-example gradients dE/dy * dy/dz * x over the batch.
        self.dE_dW = np.sum(SquaredError.dE_dy(t, self.y) * SigmoidActivation.dy_dz(self.y) * x.T, axis=1)

    def update(self):
        if self.momentum is not None:
            # Classical momentum: decay the velocity, add the current
            # gradient step, then move along the negative velocity.
            self.W_momentum *= self.momentum
            self.W_momentum += self.learning_rate * self.dE_dW
            self.W -= self.W_momentum
        else:
            # Plain gradient descent.
            self.W -= self.learning_rate * self.dE_dW
# Draw the error surface over a grid of weight values.
net = Net(learning_rate=0.1)
errors = []
W1, W2 = np.meshgrid(np.linspace(W_SPACE_LOW, W_SPACE_HIGH, num=100),
                     np.linspace(W_SPACE_LOW, W_SPACE_HIGH, num=100))
for w1, w2 in zip(np.ravel(W1), np.ravel(W2)):
    net.W = np.array([w1, w2])
    net.forward(X)
    errors.append(np.sum(SquaredError.E(net.y, T)))
E = np.array(errors).reshape(W1.shape)

fig = pylab.figure(figsize=(15, 12))
ax = fig.add_subplot(projection='3d')  # fig.gca(projection='3d') no longer works in Matplotlib >= 3.6
ax.plot_surface(W1, W2, E, cmap=pylab.cm.coolwarm, alpha=0.5, lw=0, rstride=2, cstride=2)
ax.set_xlabel('w1')
ax.set_ylabel('w2')
ax.set_zlabel('error')
pylab.draw()
# Train the network, plotting each step of the weight trajectory.
net = Net(learning_rate=0.1, momentum=MOMENTUM)
error = np.inf
i = 0
old_w0 = old_w1 = old_e = None
while True:
    net.forward(X)
    old_error = error
    error = np.sum(SquaredError.E(net.y, T))
    net.backward(X, T)
    # Stop once the error fails to improve by at least 1% in a single step.
    if error > old_error * 0.99:
        break
    ax.scatter3D(net.W[0], net.W[1], error, s=10, c='red', lw=0)
    if old_e is not None:
        ax.plot([old_w0, net.W[0]], [old_w1, net.W[1]], [old_e, error], color='b')
    pylab.draw()
    old_w0 = net.W[0]
    old_w1 = net.W[1]
    old_e = error
    net.update()
    i += 1

print("%d iterations" % i)
pylab.show()