Created
July 26, 2023 04:50
-
-
Save mizoru/e76772f256ef4f26b954eb341b5d147c to your computer and use it in GitHub Desktop.
Gradient descent on the input to a model trained to recognize MNIST digits
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Optimize the *input pixels* (not the model weights): start from a batch of
# uniform grey 28x28 images with a little noise, then run gradient descent on
# the images so the (frozen) MNIST classifier scores them as the targets `y`.
x = 0.5 * torch.ones(bs, 784) + 0.02 * torch.randn(bs, 784)
# `x` is a leaf tensor we want gradients for; retain_grad keeps .grad populated.
x.requires_grad_(True).retain_grad()

for _step in range(n // bs):
    preds = model(x)
    loss = loss_func(preds, y)
    loss.backward()
    with torch.no_grad():
        # Manual SGD step on the image batch itself.
        x -= lr * x.grad
        # Detach the grad buffer from any graph before zeroing it in place,
        # so the in-place zero_ doesn't complain about autograd history.
        if x.grad.grad_fn is None:
            x.grad.requires_grad_(False)
        else:
            x.grad.detach_()
        x.grad.zero_()
    # The model's weight grads also accumulate during backward; clear them.
    model.zero_grad()
    # NOTE(review): loss is computed against `y` but report/plots use `yb`/`xb`,
    # presumably defined earlier in the notebook — confirm they are the same batch.
    report(loss, preds, yb)

# Visualize one optimized image next to the corresponding real sample.
i = 20
plt.matshow(x[i].detach().numpy().reshape(-1, 28))
plt.matshow(xb[i].detach().numpy().reshape(-1, 28)); yb[i]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment