-
-
Save colesbury/23083dba70334ec7126a1a2946031a96 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/2.ASimpleNeuralNetwork/numpy_like_fizbuz.py b/2.ASimpleNeuralNetwork/numpy_like_fizbuz.py | |
index 62fa394..b6a5f5f 100644 | |
--- a/2.ASimpleNeuralNetwork/numpy_like_fizbuz.py | |
+++ b/2.ASimpleNeuralNetwork/numpy_like_fizbuz.py | |
@@ -44,14 +44,18 @@ y = torch.from_numpy(trY).type(dtype) | |
print(x.grad, x.grad_fn, x) | |
# None, None, [torch.FloatTensor of size 900x10] | |
-w1 = torch.randn(input_size, hidden_units, requires_grad=True).type(dtype) | |
+w1 = torch.randn(input_size, hidden_units).type(dtype) | |
w2 = torch.randn(hidden_units, output_size).type(dtype) | |
+w1.requires_grad = True | |
+w2.requires_grad = True | |
print(w1.grad, w1.grad_fn, w1) | |
# None, None, [torch.FloatTensor of size 10x100] | |
-b1 = torch.zeros(1, hidden_units, requires_grad=True).type(dtype) | |
-b2 = torch.zeros(1, output_size, requires_grad=True).type(dtype) | |
+b1 = torch.zeros(1, hidden_units).type(dtype) | |
+b2 = torch.zeros(1, output_size).type(dtype) | |
+b1.requires_grad = True | |
+b2.requires_grad = True | |
no_of_batches = int(len(trX) / batches) | |
for epoch in range(epochs): | |
@@ -86,7 +90,7 @@ for epoch in range(epochs): | |
# Direct manipulation of data outside autograd is not allowed anymore | |
# so this code snippet won't work with pytoch version 0.4+ | |
- try: | |
+ with torch.no_grad(): | |
w1 -= lr * w1.grad | |
w2 -= lr * w2.grad | |
b1 -= lr * b1.grad | |
@@ -95,9 +99,7 @@ for epoch in range(epochs): | |
w2.grad.zero_() | |
b1.grad.zero_() | |
b2.grad.zero_() | |
- except RuntimeError as e: | |
- raise Exception('Direct manipulation of autograd Variable is not allowed in pytorch \ | |
-version 0.4+. Error thrown by pytorch: {}'.format(e)) | |
+ | |
if epoch % 10: | |
print(epoch, output.item()) | |
# traversing the graph using .grad_fn |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.