# sympy_xor_neural_network.py
# GitHub gist thomasnield/0da01052001dffc3861daf76b4e80584
# (last active: September 28, 2021 01:05). The gist can be saved to your
# computer and used in GitHub Desktop.
# GitHub notice: this file contains hidden or bidirectional Unicode text that
# may be interpreted or compiled differently than what appears below. To
# review, open the file in an editor that reveals hidden Unicode characters.
# (See GitHub's documentation on bidirectional Unicode characters to learn more.)
# Explicit imports instead of `from sympy import *`: these are the only SymPy
# names this script uses.
from sympy import Matrix, exp, latex, symbols

# Symbols used only by the symbolic variant of the network that is kept
# (disabled) in the string literal below: a..f are weights, x/y are the two
# input features of each sample, z is the expected output.
a, b, c, d, e, f = symbols('a b c d e f')
# 'x1:5' yields x1..x4, so each Python name matches its printed symbol.
# The previous 'x:4' yielded x0..x3, which made `x1` print as x_0.
x1, x2, x3, x4 = symbols('x1:5')
y1, y2, y3, y4 = symbols('y1:5')
z1, z2, z3, z4 = symbols('z1:5')
""" | |
input = Matrix([ | |
[x1, y1], | |
[x2, y2], | |
[x3, y2], | |
[x4, y4] | |
]) | |
output_actual = Matrix([ | |
z1, | |
z2, | |
z3, | |
z4 | |
]) | |
# weights | |
w_hidden = Matrix([ | |
[a, b], # w0->0, w0->1 | |
[c, d] # w1->0, w1->1 | |
]) | |
w_output = Matrix([ | |
e, | |
f | |
]).transpose() | |
""" | |
# XOR training set: each row of `input` is one sample (feature 0, feature 1).
# NOTE: `input` shadows the builtin of the same name. The name is kept because
# the forward and backward passes that follow refer to it.
input = Matrix([
    [0, 0],
    [1, 0],
    [0, 1],
    [1, 1],
])
# Expected XOR result for each row of `input`, as a 4x1 column.
output_actual = Matrix([
    0,
    1,
    1,
    0,
])

# Initial weights. `w_hidden` is applied as `w_hidden * input.transpose()`,
# so row h holds the weights feeding hidden node h, and column i holds the
# weight coming from input feature i.
w_hidden = Matrix([
    [.1, .2],  # w(in0 -> h0), w(in1 -> h0)
    [.3, .4],  # w(in0 -> h1), w(in1 -> h1)
])
# Hidden -> output weights as a 1x2 row: w(h0 -> out), w(h1 -> out).
w_output = Matrix([
    .5,
    .6,
]).transpose()
# Forward pass. Rows of `input` are samples, so transposing puts one sample
# per column: (2x2) * (2x4) -> 2x4 hidden pre-activations.
hidden_applied = w_hidden * input.transpose()
# Uncomment to print the LaTeX for this step:
# print(latex(w_hidden) + " . " + latex(input.transpose()) + " = " + latex(hidden_applied))

# NOTE(review): no activation is applied to the hidden layer, so everything
# before the final sigmoid is linear and the network as written cannot
# represent XOR. This is left unchanged to preserve the script's output;
# confirm the intent.
# (1x2) * (2x4) -> 1x4 output pre-activations, one per sample.
output_applied = w_output * hidden_applied
# Uncomment to print the LaTeX for this step:
# print(latex(w_output) + " . " + latex(w_hidden) + " . " + latex(input.transpose()) + " = " + latex(output_applied))


def sigmoid(x):
    """Return the logistic sigmoid 1 / (1 + e**-x) of a SymPy expression."""
    return 1 / (1 + exp(-x))


# Squash each output pre-activation into (0, 1), element-wise.
activated = output_applied.applyfunc(sigmoid)
# Backpropagate errors, where error = target - prediction.
# Output layer: transpose the 4x1 targets to match the 1x4 predictions,
# giving one error per sample.
output_errors = output_actual.transpose() - activated
# Uncomment to print the LaTeX for the next step's operands:
# print(latex(w_output.transpose()) + " . " + latex(output_errors))

# Output -> hidden: weight each sample's error by the hidden->output weights,
# (2x1) * (1x4) -> 2x4 (one row per hidden node).
hidden_errors = w_output.transpose() * output_errors
print(latex(hidden_errors))

# Hidden -> input: same scheme one layer back, (2x2)^T * (2x4) -> 2x4.
input_errors = w_hidden.transpose() * hidden_errors
print(f"{latex(w_hidden.transpose())} . {latex(hidden_errors)} = {latex(input_errors)}")
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.