Last active
February 27, 2020 13:09
-
-
Save omarsar/13965950712bb210e35b56fca5f6605a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Neural_Network(nn.Module):
    """Minimal two-layer feed-forward network with a hand-written
    forward and backward pass (no autograd, no optimizer).

    Weights are plain tensors updated manually in ``backward``; this is
    a teaching implementation, not idiomatic ``nn.Module`` usage.
    """

    def __init__(self, inputSize=2, outputSize=1, hiddenSize=3):
        """Build the network.

        Args:
            inputSize: number of input features (default 2, as in the original).
            outputSize: number of output units (default 1).
            hiddenSize: number of hidden units (default 3).
        """
        super(Neural_Network, self).__init__()
        # layer sizes — parameterized (was the TODO); defaults keep the
        # original behavior for existing callers
        self.inputSize = inputSize
        self.outputSize = outputSize
        self.hiddenSize = hiddenSize
        # weights, randomly initialized; note this model has no biases
        self.W1 = torch.randn(self.inputSize, self.hiddenSize)   # inputSize x hiddenSize
        self.W2 = torch.randn(self.hiddenSize, self.outputSize)  # hiddenSize x outputSize

    def forward(self, X):
        """Forward pass: sigmoid(sigmoid(X @ W1) @ W2).

        Intermediate activations are cached on ``self`` because
        ``backward`` needs them.
        """
        # (batch x hiddenSize); ".dot" does not broadcast in PyTorch, hence matmul
        self.z = torch.matmul(X, self.W1)
        self.z2 = self.sigmoid(self.z)            # hidden-layer activation
        self.z3 = torch.matmul(self.z2, self.W2)  # (batch x outputSize)
        o = self.sigmoid(self.z3)                 # final activation
        return o

    def sigmoid(self, s):
        """Element-wise logistic function 1 / (1 + e^-s)."""
        return 1 / (1 + torch.exp(-s))

    def sigmoidPrime(self, s):
        """Derivative of the sigmoid, expressed in terms of the sigmoid
        OUTPUT: ``s`` must already be sigmoid(x), so sigma'(x) = s * (1 - s)."""
        return s * (1 - s)

    def backward(self, X, y, o):
        """Manual backprop: compute layer deltas and update W1/W2 in place.

        The update is a plain gradient step with an implicit learning
        rate of 1; error is defined as (y - o), so weights are incremented.
        """
        self.o_error = y - o                                # error in output
        self.o_delta = self.o_error * self.sigmoidPrime(o)  # delta at the output layer
        self.z2_error = torch.matmul(self.o_delta, torch.t(self.W2))
        self.z2_delta = self.z2_error * self.sigmoidPrime(self.z2)
        self.W1 += torch.matmul(torch.t(X), self.z2_delta)
        self.W2 += torch.matmul(torch.t(self.z2), self.o_delta)

    def train(self, X, y):
        """One training step: forward pass followed by manual backward.

        NOTE(review): this shadows ``nn.Module.train(mode)``; kept
        unchanged for backward compatibility with existing callers.
        """
        o = self.forward(X)
        self.backward(X, y, o)

    def saveWeights(self, model):
        """Persist the whole model via PyTorch's internal storage functions.

        Reload later with ``torch.load("NN")``.
        """
        torch.save(model, "NN")

    def predict(self):
        """Print the prediction for the module-level ``xPredicted`` tensor.

        NOTE(review): depends on a global ``xPredicted`` defined elsewhere
        in the original script — confirm it is in scope before calling.
        """
        print ("Predicted data based on trained weights: ")
        print ("Input (scaled): \n" + str(xPredicted))
        print ("Output: \n" + str(self.forward(xPredicted)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@Rahul-python
I admit the code can be improved and cleaned a bit. Your equation is right. But if you look closely, we calculated the sigmoid in the `forward` pass — that's what the variable `o` represents. So we pass that `o`, which becomes `s` in this particular `sigmoidPrime` function, where the sigmoid has already been applied. That's why it ends up being `s * (1 - s)`, as opposed to what you propose. Hopefully, it's a bit clearer now. Thanks for the feedback. I may work on improving this tutorial and provide more details and intuitions on design choices. :)