-
-
Save farhaven/de396307ab1aa2449fc7abb6cbedac58 to your computer and use it in GitHub Desktop.
A Neural Network framework in 25 LOC
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
# Activation functions | |
def tanh(x):
    """Return the elementwise hyperbolic tangent of ``x`` (the layer activation)."""
    activated = np.tanh(x)
    return activated
# Derivative of tanh from its output | |
def dtanh(y): | |
return 1 - y ** 2 | |
# The neural network framework | |
class Layer:
    """A fully connected tanh layer (weights only, no bias term).

    Lifecycle per training step: ``forward`` caches ``self.output``,
    ``computeGradient`` caches ``self.delta`` from that output, and
    ``updateWeights`` consumes ``self.delta``. So the three must be called
    in that order.
    """

    def __init__(self, num_inputs, num_outputs):
        # Weight matrix of shape (num_inputs, num_outputs): row i connects
        # input i to every output. Values are drawn uniformly from the
        # half-open interval [-1, 1).
        shape = (num_inputs, num_outputs)
        self.weights = np.random.uniform(-1, 1, shape)

    def forward(self, input):
        """Return ``tanh(input . weights)`` and remember it as ``self.output``."""
        pre_activation = input.dot(self.weights)
        self.output = tanh(pre_activation)
        return self.output

    def computeGradient(self, error):
        """Record this layer's delta and return the error for the layer below.

        ``error`` is the error signal arriving at this layer's output. The
        delta scales it by the tanh derivative at the cached output. The
        returned value is that delta pushed back through the current
        (not yet updated) weights.
        """
        local_slope = dtanh(self.output)
        self.delta = error * local_slope
        return self.delta.dot(self.weights.T)

    def updateWeights(self, input, learning_rate):
        """Add ``learning_rate * input^T . delta`` to the weights in place.

        NOTE(review): the step is *added*, which moves toward the target only
        if the caller's error is ``target - output`` (as xor.py supplies).
        """
        step = input.T.dot(self.delta) * learning_rate
        self.weights += step
class Network:
    """An ordered stack of layers trained by plain backpropagation."""

    def __init__(self, layers):
        # Applied first-to-last in forward(), last-to-first in backprop().
        self.layers = layers

    def forward(self, input):
        """Feed ``input`` through every layer in order; return the final output."""
        activation = input
        for stage in self.layers:
            activation = stage.forward(activation)
        return activation

    def backprop(self, input, error, learning_rate):
        """Propagate ``error`` backwards, then update every layer's weights.

        Relies on the per-layer outputs cached by the most recent
        ``forward(input)`` call, so call ``forward`` first.
        """
        # Pass 1 (output -> input): every layer records its delta before any
        # weights change, so all gradients use the same weights as forward().
        signal = error
        for stage in reversed(self.layers):
            signal = stage.computeGradient(signal)

        # Pass 2 (input -> output): each layer's input is the network input
        # for the first layer, otherwise the previous layer's cached output.
        layer_input = input
        for stage in self.layers:
            stage.updateWeights(layer_input, learning_rate)
            layer_input = stage.output
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
$ python xor.py | |
Training: | |
Epoch 0 MSE: 1.765 | |
Epoch 100 MSE: 0.015 | |
Epoch 200 MSE: 0.005 | |
* Target MSE reached * | |
Evaluating: | |
1 XOR 0 = 1 ( 0.904) Error: 0.096 | |
0 XOR 1 = 1 ( 0.908) Error: 0.092 | |
1 XOR 1 = 0 (-0.008) Error: 0.008 | |
0 XOR 0 = 0 ( 0.000) Error: 0.000 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import nn | |
# Training examples: every pair of bits (a, b) with target a XOR b.
# Inputs are 1x2 row matrices and targets are 1x1 matrices.
examples = [
    [np.array([[a, b]]), np.array([[a ^ b]])]
    for a, b in ((0, 0), (0, 1), (1, 0), (1, 1))
]

# Fixed seed: the initial weights, and therefore the whole run, are reproducible.
np.random.seed(0)

# Topology: 2 inputs -> 20 hidden tanh units -> 1 tanh output.
num_inputs = 2
num_hidden = 20
num_output = 1
layer_sizes = [(num_inputs, num_hidden), (num_hidden, num_output)]
network = nn.Network([nn.Layer(n_in, n_out) for n_in, n_out in layer_sizes])
# Train with online (per-example) gradient descent for at most 500 epochs.
print("Training:")
learning_rate = 0.1
target_mse = 0.01
for epoch in range(500):
    errors = []
    for inputs, target in examples:
        error = target - network.forward(inputs)
        errors.append(error)
        network.backprop(inputs, error, learning_rate)

    # Only every 100th epoch is reported/checked. The MSE is over the errors
    # observed during this epoch, while the weights were still being updated.
    if epoch % 100 != 0:
        continue
    squared = np.array(errors) ** 2
    mse = squared.mean()
    print(" Epoch %3d MSE: %.3f" % (epoch, mse))
    if mse <= target_mse:
        print(" * Target MSE reached *")
        break
# Evaluate the model | |
def eval(x, y, model=None):
    """Run one XOR example through the model and format the result.

    NOTE: shadows the builtin ``eval``; the name is kept for existing callers.

    Args:
        x: Input bit pair, e.g. ``np.array([1, 0])``. A 2-D row such as
            ``np.array([[1, 0]])`` (the shape used in ``examples``) also works.
        y: Expected XOR result (0 or 1).
        model: Object with a ``forward`` method. Defaults to the module-level
            ``network``.

    Returns:
        ``"<rounded output> (<raw output>) Error: <y - raw output>"``.
    """
    if model is None:
        model = network
    # Flatten before taking the first element, so 1-D and 2-D inputs both
    # yield a plain float. With a 2-D input, output[0] would be a 1-element
    # array that int()/"%f" accept only via a deprecated array->scalar cast.
    value = float(np.ravel(model.forward(x))[0])
    normalized = int(round(value))
    error = y - value
    return "%d (% .3f) Error: %.3f" % (normalized, value, error)
print("Evaluating:")
# Same four cases, in the same order, as an explicit table of bit pairs;
# the expected value is simply a XOR b.
for a, b in ((1, 0), (0, 1), (1, 1), (0, 0)):
    print(" %d XOR %d = " % (a, b) + eval(np.array([a, b]), a ^ b))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment