Last active
February 27, 2020 13:09
-
-
Save omarsar/13965950712bb210e35b56fca5f6605a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Neural_Network(nn.Module):
    """A minimal feed-forward network (input -> sigmoid hidden -> sigmoid output)
    trained with hand-written backpropagation — no autograd, no optimizer.

    NOTE(review): W1/W2 are plain tensors, not nn.Parameter, so nn.Module
    machinery (.parameters(), .to(), state_dict of weights) will not see them.
    Kept as-is to preserve the original manual-update behavior.
    """

    def __init__(self, input_size=2, hidden_size=3, output_size=1, lr=1.0):
        """Build the network.

        The sizes are now parameters (resolving the original TODO); the
        defaults reproduce the original hard-coded 2 -> 3 -> 1 topology.

        Args:
            input_size: number of input features (default 2).
            hidden_size: number of hidden units (default 3).
            output_size: number of outputs (default 1).
            lr: learning rate used by `backward` (default 1.0, matching the
                original unscaled weight update).
        """
        super(Neural_Network, self).__init__()
        self.inputSize = input_size
        self.outputSize = output_size
        self.hiddenSize = hidden_size
        self.lr = lr
        # weights, randomly initialized
        self.W1 = torch.randn(self.inputSize, self.hiddenSize)   # input_size X hidden_size (2 X 3 by default)
        self.W2 = torch.randn(self.hiddenSize, self.outputSize)  # hidden_size X output_size (3 X 1 by default)

    def forward(self, X):
        """Forward pass. X is (batch X inputSize); returns (batch X outputSize).

        Uses matmul because ".dot" does not broadcast in PyTorch.
        Intermediate activations are stashed on self for use by `backward`.
        """
        self.z = torch.matmul(X, self.W1)   # (batch X hiddenSize) — original comment said 3 X 3, which was wrong
        self.z2 = self.sigmoid(self.z)      # hidden-layer activation
        self.z3 = torch.matmul(self.z2, self.W2)
        o = self.sigmoid(self.z3)           # final activation
        return o

    def sigmoid(self, s):
        """Element-wise logistic sigmoid."""
        return 1 / (1 + torch.exp(-s))

    def sigmoidPrime(self, s):
        """Derivative of the sigmoid, expressed in terms of the sigmoid's
        OUTPUT: callers must pass sigmoid(x), not x, hence s * (1 - s)."""
        return s * (1 - s)

    def backward(self, X, y, o):
        """Manual backprop: compute deltas and apply an in-place gradient
        step (scaled by self.lr) to W1 and W2."""
        self.o_error = y - o                                  # error in output
        self.o_delta = self.o_error * self.sigmoidPrime(o)    # output delta
        self.z2_error = torch.matmul(self.o_delta, torch.t(self.W2))
        self.z2_delta = self.z2_error * self.sigmoidPrime(self.z2)
        self.W1 += self.lr * torch.matmul(torch.t(X), self.z2_delta)
        self.W2 += self.lr * torch.matmul(torch.t(self.z2), self.o_delta)

    def train(self, X, y):
        """One training step: forward pass followed by a weight update.

        NOTE(review): this shadows nn.Module.train(mode); kept for backward
        compatibility with existing callers.
        """
        o = self.forward(X)
        self.backward(X, y, o)

    def saveWeights(self, model):
        """Serialize the whole model via torch.save; reload with torch.load("NN")."""
        torch.save(model, "NN")

    def predict(self, x=None):
        """Print the prediction for `x`.

        When `x` is None, falls back to the module-level `xPredicted` global,
        exactly as the original did (that global is defined elsewhere).
        """
        if x is None:
            x = xPredicted  # original behavior: relies on an external global
        print ("Predicted data based on trained weights: ")
        print ("Input (scaled): \n" + str(x))
        print ("Output: \n" + str(self.forward(x)))
I believe the function sigmoidPrime is defined wrong. It should be sigmoid(s) * (1 - sigmoid(s)). Please correct me if I am wrong.
@Rahul-python
I admit the code can be improved and cleaned a bit. Your equation is right. But if you look closely, we calculated sigmoid in the forward
pass, that's what the variable o
represents. So we pass that o
which becomes s
in this particular sigmoidPrime
function which is the sigmoid applied already. That's why it ends up being s * (1-s)
as opposed to what you propose. Hopefully, it's a bit clearer now. Thanks for the feedback. I may work on improving this tutorial and provide more details and intuitions on design choices. :)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Indeed it is 2 X 3. Thank you for spotting that.