Last active
November 11, 2021 14:10
-
-
Save thomasnield/a6a2e24139b6d846f1434644b29382a1 to your computer and use it in GitHub Desktop.
sympy_neural_network_derivatives.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sympy import * | |
# Declare weight, bias, layer outputs, and input variables | |
# W1 and W2 are hidden and output layer weights | |
# B1 and B2 are hidden and output layer biases | |
# A1 and A2 are activated layer outputs | |
# Z1 and Z2 are inactivated layer outputs | |
# X and Y are the training data | |
W1, W2, B1, B2, A1, A2, Z1, Z2, X, Y = \ | |
symbols('W1 W2 B1 B2 A1 A2 Z1 Z2 X Y') | |
# Calculate derivative of cost function with respect to A2 | |
C = (A2 - Y)**2 | |
dC_dA2 = diff(C, A2) | |
print("dC_dA2 = ", dC_dA2) # 2*A2 - 2*Y | |
# Hidden layer uses ReLU function | |
# Output layer uses logistic | |
# Calculate derivative of A2 with respect to Z2 | |
logistic = lambda x: 1 / (1 + exp(-x)) | |
_A2 = logistic(Z2) | |
dA2_dZ2 = diff(_A2, Z2) | |
print("dA2_dZ2 = ", dA2_dZ2) # exp(-Z2)/(1 + exp(-Z2))**2 | |
# Calculate derivative of Z2 with respect to A1 | |
_Z2 = A1*W2 + B2 | |
dZ2_dA1 = diff(_Z2, A1) | |
print("dZ2_dA1 = ", dZ2_dA1) # W2 | |
# Calculate derivative of Z2 with respect to W2 | |
dZ2_dW2 = diff(_Z2, W2) | |
print("dZ2_dW2 = ", dZ2_dW2) # A1 | |
# Calculate derivative of Z2 with respect to B2 | |
dZ2_dB2 = diff(_Z2, B2) | |
print("dZ2_dB2 = ", dZ2_dB2) # 1 | |
# Calculate derivative of A1 with respect to Z1 | |
relu = lambda x: Max(x, 0) | |
_A1 = relu(Z1) | |
d_relu = lambda x: x > 0 # Slope is 1 if positive, 0 if negative | |
dA1_dZ1 = d_relu(Z1) | |
print("dA1_dZ1 = ", dA1_dZ1) # Z1 > 0 | |
# Calculate derivative of Z1 with respect to W1 | |
_Z1 = X*W1 + B1 | |
dZ1_dW1 = diff(_Z1, W1) | |
print("dZ1_dW1 = ", dZ1_dW1) # X | |
# Calculate derivative of Z1 with respect to B1 | |
dZ1_dB1 = diff(_Z1, B1) | |
print("dZ1_dB1 = ", dZ1_dB1) # 1 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment