@mizoru
Created July 26, 2023 04:50
Gradient descent on the input to a model trained to recognize MNIST digits
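The snippet relies on notebook state that is not part of the gist: a trained MNIST classifier model, a loss_func, a real batch xb, yb to compare against, the hyperparameters bs, n, lr, and a report helper that prints metrics. A minimal sketch of that setup follows; every definition here is a hypothetical stand-in, not the author's code:

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Hypothetical stand-ins for the notebook state the gist assumes.
bs, n, lr = 50, 2500, 0.5           # batch size, step budget, learning rate
loss_func = F.cross_entropy
model = torch.nn.Sequential(        # assume this was already trained on MNIST
    torch.nn.Linear(784, 50), torch.nn.ReLU(), torch.nn.Linear(50, 10)
)
xb = torch.rand(bs, 784)            # a real image batch would come from the dataset
yb = torch.randint(0, 10, (bs,))    # the digit labels the inputs are optimized toward

def report(loss, preds, yb):
    # Print loss and batch accuracy, as a course-style helper might.
    acc = (preds.argmax(dim=1) == yb).float().mean()
    print(f"loss: {loss.item():.3f}, accuracy: {acc.item():.3f}")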
# Start from a flat grey image batch with a little noise, and optimize the pixels directly.
x = torch.ones(bs, 784) * 0.5 + 0.02 * torch.randn(bs, 784)
x.requires_grad_(True).retain_grad()
for i in range(0, n // bs):
    preds = model(x)
    loss = loss_func(preds, yb)
    loss.backward()
    with torch.no_grad():
        # Gradient step on the input pixels rather than the model weights.
        x -= x.grad * lr
        # Clear the input's gradient for the next step: detach it first if it
        # is still attached to the graph, otherwise drop its requires_grad flag.
        if x.grad.grad_fn is not None:
            x.grad.detach_()
        else:
            x.grad.requires_grad_(False)
        x.grad.zero_()
    model.zero_grad()
    report(loss, preds, yb)
# Compare one optimized input with a real training image of the same class.
i = 20
plt.matshow(x[i].detach().numpy().reshape(-1, 28))
plt.matshow(xb[i].detach().numpy().reshape(-1, 28)); yb[i]
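
The manual x.grad bookkeeping in the loop reproduces what an optimizer does internally. An equivalent, arguably clearer variant (a sketch under the same assumptions, run in place of the loop above with a freshly initialized x) registers the input tensor with a standard SGD optimizer:

opt = torch.optim.SGD([x], lr=lr)   # optimize the pixels, not the model weights
for _ in range(n // bs):
    loss = loss_func(model(x), yb)
    loss.backward()
    opt.step()
    opt.zero_grad()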