Created November 11, 2020 11:49 by devil-cyber (gist 4ca9a65b217b9741c984757e3a9d4a30)
Backpropagation using PyTorch
import torch

'''
Backpropagation

Example:
    x = [1., 1., 1.]   -> fixed input feature values
    y = x + 2          -> a(x), the intermediate value
    z = y * y * 3      -> b(y), stands in for the loss function

Computation graph: x -> a(x) -> b(y) -> z
Backpropagation walks this graph in reverse and applies the chain rule:
    dz/dx = dz/dy * dy/dx
'''

# Compute the gradient of the loss with respect to w for the given x, y and initial w
x = torch.tensor(1.0)
y = torch.tensor(2.0)
w = torch.tensor(-2.0, requires_grad=True)

# Forward pass: prediction and squared-error loss
y_hat = w * x
loss = (y_hat - y) ** 2
print(loss)    # value 16: (w*x - y)^2 = (-2 - 2)^2

# Backward pass: autograd stores dloss/dw in w.grad
loss.backward()
print(w.grad)  # value -8: dloss/dw = 2 * (w*x - y) * x

## TODO: update the weight
## TODO: run the next forward and backward passes
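The chain-rule example in the docstring can be checked numerically. This is a minimal sketch, assuming the vector z is reduced to a scalar with `z.sum()` before calling `backward()` (the docstring does not say how z becomes a scalar loss, and `backward()` without arguments needs a scalar). By hand, dz/dy = 6y and dy/dx = 1, so dz/dx = 6(x + 2), which is 18 for every element of x = [1., 1., 1.].

```python
import torch

# Input features from the docstring example; requires_grad makes autograd track x
x = torch.tensor([1.0, 1.0, 1.0], requires_grad=True)

y = x + 2          # a(x)
z = y * y * 3      # b(y), computed element-wise

# Assumption: sum z into a scalar so backward() can be called without arguments
z.sum().backward()

# Chain rule by hand: dz/dx = dz/dy * dy/dx = (6 * y) * 1
manual_grad = 6 * (x.detach() + 2)

print(x.grad)        # gradient computed by autograd
print(manual_grad)   # gradient derived by hand
```

Both prints should agree, since summing z does not mix elements: each x_i only influences its own z_i.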
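The two TODOs (update the weight, then repeat the forward and backward passes) amount to a gradient-descent loop. This is a minimal sketch, assuming plain gradient descent with a learning rate of 0.1 and 20 iterations; both values are arbitrary choices, not taken from the original. The update runs inside `torch.no_grad()` so autograd does not record it, and `w.grad.zero_()` clears the gradient after each step because `backward()` adds to `.grad` rather than overwriting it.

```python
import torch

x = torch.tensor(1.0)
y = torch.tensor(2.0)
w = torch.tensor(-2.0, requires_grad=True)

learning_rate = 0.1   # assumed value
n_iters = 20          # assumed value

for step in range(n_iters):
    # Forward pass
    y_hat = w * x
    loss = (y_hat - y) ** 2

    # Backward pass: accumulates dloss/dw into w.grad
    loss.backward()

    # Update the weight without recording the operation in the autograd graph
    with torch.no_grad():
        w -= learning_rate * w.grad

    # Reset the gradient so the next backward() starts from zero
    w.grad.zero_()

    if step % 5 == 0:
        print(f"step {step}: w = {w.item():.4f}, loss = {loss.item():.4f}")

print(f"final w = {w.item():.4f}")
```

With these assumed settings, each step multiplies the error w - 2 by 1 - 2 * 0.1 * x^2 = 0.8, so w moves from -2.0 toward 2.0, the value at which w * x equals y.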