Skip to content

Instantly share code, notes, and snippets.

@unixpickle
Created October 12, 2019 19:08
Show Gist options
  • Save unixpickle/0981d4cd8efead8b40ab27de1af0733c to your computer and use it in GitHub Desktop.
Save unixpickle/0981d4cd8efead8b40ab27de1af0733c to your computer and use it in GitHub Desktop.
MAML in PyTorch
import torch
import torch.nn.functional as F
def maml_grad(model, inputs, outputs, lr, batch=1):
    """
    Update a model's gradient using MAML.

    The large batch is split into consecutive mini-batches of size
    `batch`. For each mini-batch the loss is evaluated at the current
    parameters and one SGD step with learning rate `lr` is applied.
    The gradient of the sum of those losses with respect to the
    starting parameters (differentiating through every SGD step) is
    then accumulated into each parameter's `.grad`, so it points in
    the direction that improves the total loss across all inner-loop
    mini-batches.

    The loss is chosen from the target dtype: floating-point targets
    use `binary_cross_entropy_with_logits`, all others use
    `cross_entropy`.

    The model's parameter values are the same after the call as they
    were before it, including when an exception is raised part-way
    through; only `.grad` is modified.

    Args:
        model: an nn.Module for training. It must have at least one
          parameter; the device of the first one is used for the data.
        inputs: a large batch of model inputs.
        outputs: a large batch of model outputs.
        lr: the inner-loop SGD learning rate.
        batch: the inner-loop batch size (must be >= 1).

    Returns:
        A list of Python floats, the loss of each inner-loop
        mini-batch in order.

    Raises:
        ValueError: if `batch` is less than 1.
    """
    if batch < 1:
        raise ValueError(f"batch must be a positive integer, got {batch!r}")

    params = list(model.parameters())
    device = params[0].device

    # Snapshot of the starting point so the model can always be put back,
    # even if a forward or backward pass fails half-way through.
    starting_values = [p.detach().clone() for p in params]

    initial_values = []  # parameter values before each inner step
    final_values = []    # differentiable parameter values after each step
    losses = []          # loss tensors (graphs kept for the backward sweep)
    scalar_losses = []   # the same losses as plain floats, for the caller

    try:
        # Forward sweep: run the inner-loop SGD, recording each step.
        for start in range(0, inputs.shape[0], batch):
            x = inputs[start:start + batch]
            y = outputs[start:start + batch]
            target = y.to(device)
            out = model(x.to(device))
            if target.dtype.is_floating_point:
                loss = F.binary_cross_entropy_with_logits(out, target)
            else:
                loss = F.cross_entropy(out, target)
            losses.append(loss)
            scalar_losses.append(loss.item())
            initial_values.append([p.clone().detach() for p in params])

            # create_graph=True keeps the step differentiable so that the
            # backward sweep can take second-order terms through it.
            grads = torch.autograd.grad(
                loss, params, create_graph=True, retain_graph=True
            )
            updated = []
            for grad, param in zip(grads, params):
                new_value = param - lr * grad
                updated.append(new_value)
                # Written through .data on purpose: the values must change
                # for the next mini-batch without autograd treating it as
                # an in-place modification of the leaf parameter.
                param.data.copy_(new_value)
            final_values.append(updated)

        # Backward sweep: walk the steps in reverse, carrying the gradient
        # of all later losses with respect to the parameters of each step.
        gradient = [torch.zeros_like(p) for p in params]
        steps = list(zip(losses, initial_values, final_values))
        for loss, initial, final in reversed(steps):
            # Put the parameters back to the values this step's graph was
            # built with before differentiating through it.
            for p, value in zip(params, initial):
                p.data.copy_(value)
            # Direct contribution of this step's own loss.
            grad1 = torch.autograd.grad(loss, params, retain_graph=True)
            # Contribution of later losses, pulled back through the update.
            grad2 = torch.autograd.grad(
                final, params, grad_outputs=gradient, retain_graph=True
            )
            gradient = [v1 + v2 for v1, v2 in zip(grad1, grad2)]
    finally:
        # No-op on the success path (the sweep already ends at the starting
        # values); on failure it undoes any partial inner-loop updates.
        for p, saved in zip(params, starting_values):
            p.data.copy_(saved)

    # Accumulate rather than overwrite so repeated calls sum their gradients.
    for p, g in zip(params, gradient):
        if p.grad is None:
            p.grad = g
        else:
            p.grad.add_(g)
    return scalar_losses
@hzyjerry
Copy link

I see!! Very neat!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment