import torch
from torch.optim import Optimizer

import higher
from higher.optim import DifferentiableOptimizer

# based on the paper "OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING": https://openreview.net/pdf?id=rJY0-Kcll

class EmptySimpleMetaLstm(Optimizer):

    def __init__(self, params, trainable_opt_model, trainable_opt_state, *args, **kwargs):
        defaults = {
            'trainable_opt_model': trainable_opt_model,
            'trainable_opt_state': trainable_opt_state,
            'args': args,
            'kwargs': kwargs
        }
        super().__init__(params, defaults)

class SimpleMetaLstm(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        prev_lr = self.param_groups[0]['trainable_opt_state']['prev_lr']
        eta = self.param_groups[0]['trainable_opt_model']['eta']
        # start the differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                # treat the gradient as "data"
                g = g.detach()  # gradients of gradients are not used (no Hessians)
                ## very simplified version of the meta-LSTM meta-learner
                input_metalstm = torch.stack([p, g, prev_lr.view(1, 1)]).view(1, 3)  # [p, g, prev_lr]; the original paper also feeds the loss, normalization, etc.
                lr = eta(input_metalstm).view(1)
                fg = 1 - lr  # learnable forget gate
                ## update suggested by the meta-LSTM meta-learner
                p_new = fg * p - lr * g
                group['params'][p_idx] = p_new
        # fake returns
        self.param_groups[0]['trainable_opt_state']['prev_lr'] = lr

higher.register_optim(EmptySimpleMetaLstm, SimpleMetaLstm)
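# For reference (my paraphrase of the cited paper, not code from it): the
# meta-LSTM meta-learner updates the learner's parameters with an
# LSTM-cell-style rule,
#     theta_t = f_t * theta_{t-1} - i_t * grad_t,
# where the input gate i_t and forget gate f_t are sigmoid outputs of the
# meta-LSTM. The simplified _update above predicts i_t = lr with the eta
# network and ties the forget gate to it as f_t = 1 - lr.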
def test_parametrized_inner_optimizer():
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from collections import OrderedDict

    ## training config
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    track_higher_grads = True  # if True, the graph is retained during unrolled optimization and the fast weights bear grad_fns, permitting backpropagation through the optimization process; set to False at test time for efficiency
    copy_initial_weights = False  # if False, we also train the base model's initial weights (i.e. its initialization)
    episodes = 5
    nb_inner_train_steps = 5

    ## get base model
    base_mdl = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1, 1, bias=False)),
        ('relu', nn.ReLU())
    ]))

    ## parametrization/model for the inner optimizer
    opt_mdl = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(3, 1, bias=False)),  # 3 inputs: 1 for the parameter, 1 for the gradient, 1 for the previous lr
        ('sigmoid', nn.Sigmoid())
    ]))

    ## get outer optimizer (neither differentiable nor trainable)
    outer_opt = optim.Adam([{'params': base_mdl.parameters()}, {'params': opt_mdl.parameters()}], lr=0.01)

    for episode in range(episodes):
        ## get fake support & query data (from a single task and 1 data point)
        spt_x, spt_y, qry_x, qry_y = torch.randn(1), torch.randn(1), torch.randn(1), torch.randn(1)

        ## get differentiable & trainable (parametrized) inner optimizer
        inner_opt = EmptySimpleMetaLstm(base_mdl.parameters(), trainable_opt_model={'eta': opt_mdl}, trainable_opt_state={'prev_lr': 0.9*torch.randn(1)})
        with higher.innerloop_ctx(base_mdl, inner_opt, copy_initial_weights=copy_initial_weights, track_higher_grads=track_higher_grads) as (fmodel, diffopt):
            for i_inner in range(nb_inner_train_steps):  # this version does full gradient descent on the k_shot examples (k is usually small, e.g. 5)
                fmodel.train()
                # base/child model forward pass
                inner_loss = 0.5*((fmodel(spt_x) - spt_y))**2
                # inner-opt update
                diffopt.step(inner_loss)
            ## evaluate on the query set for the current task
            qry_loss = 0.5*((fmodel(qry_x) - qry_y))**2
            qry_loss.backward()  # backward inside the context for memory-efficient computation

        ## outer update
        print(f'episode = {episode}')
        print(f'base_mdl.grad = {base_mdl.fc.weight.grad}')
        print(f'opt_mdl.grad = {opt_mdl.fc.weight.grad}')
        outer_opt.step()
        outer_opt.zero_grad()

if __name__ == '__main__':
    test_parametrized_inner_optimizer()
    print('Done \a')
""" | |
output when deep copy is uncommented (parametrized optimizer trains properly): | |
episode = 0 | |
base_mdl.grad = tensor([[-0.0351]]) | |
opt_mdl.grad = tensor([[0.0085, 0.0000, 0.0204]]) | |
episode = 1 | |
base_mdl.grad = tensor([[0.0311]]) | |
opt_mdl.grad = tensor([[-0.0086, -0.0100, 0.0358]]) | |
episode = 2 | |
base_mdl.grad = tensor([[0.]]) | |
opt_mdl.grad = tensor([[0., 0., 0.]]) | |
episode = 3 | |
base_mdl.grad = tensor([[0.0066]]) | |
opt_mdl.grad = tensor([[-0.0016, 0.0000, -0.0032]]) | |
episode = 4 | |
base_mdl.grad = tensor([[-0.0311]]) | |
opt_mdl.grad = tensor([[0.0077, 0.0000, 0.0130]]) | |
Done | |
when deep copy is on (paremeters of inner optimizer are not train, sad!): | |
episode = 0 | |
base_mdl.grad = tensor([[0.]]) | |
opt_mdl.grad = None | |
episode = 1 | |
base_mdl.grad = tensor([[0.]]) | |
opt_mdl.grad = None | |
episode = 2 | |
base_mdl.grad = tensor([[0.0069]]) | |
opt_mdl.grad = None | |
episode = 3 | |
base_mdl.grad = tensor([[0.]]) | |
opt_mdl.grad = None | |
episode = 4 | |
base_mdl.grad = tensor([[0.]]) | |
opt_mdl.grad = None | |
Done | |
The deep copy line in higher I am referencing: | |
self.param_groups = _copy.deepcopy(other.param_groups) | |
#self.param_groups = other.param_groups | |
""" |
The override version does not work:
class EmptySimpleMetaLstm(Optimizer):

    def __init__(self, params, *args, **kwargs):
        defaults = {'args': args, 'kwargs': kwargs}
        super().__init__(params, defaults)

class SimpleMetaLstm(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        prev_lr = self.override['trainable_opt_state']['prev_lr']
        simp_meta_lstm = self.override['trainable_opt_model']
        # start the differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                # treat the gradient as "data"
                g = g.detach()  # gradients of gradients are not used (no Hessians)
                ## very simplified version of the meta-LSTM meta-learner
                input_metalstm = torch.stack([p, g, prev_lr.view(1, 1)]).view(1, 3)  # [p, g, prev_lr]; the original paper also feeds the loss, normalization, etc.
                lr = simp_meta_lstm(input_metalstm).view(1)
                fg = 1 - lr  # learnable forget gate
                ## update suggested by the meta-LSTM meta-learner
                p_new = fg * p - lr * g
                group['params'][p_idx] = p_new
        # fake returns
        self.override['trainable_opt_state']['prev_lr'] = lr

higher.register_optim(EmptySimpleMetaLstm, SimpleMetaLstm)
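# Caveat (my guess at why this version fails): as far as I can tell, higher
# merges the `override` dict into `param_groups` (and expects each value to be
# a list with one entry per param group) rather than exposing a `self.override`
# attribute, so the lookups above never see the overridden model/state.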
####
####
def test_parametrized_inner_optimizer():
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from collections import OrderedDict

    ## training config
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    track_higher_grads = True  # if True, the graph is retained during unrolled optimization and the fast weights bear grad_fns, permitting backpropagation through the optimization process; set to False at test time for efficiency
    copy_initial_weights = False  # if False, we also train the base model's initial weights (i.e. its initialization)
    episodes = 5
    nb_inner_train_steps = 5

    ## get base model
    base_mdl = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1, 1, bias=False)),
        ('relu', nn.ReLU())
    ]))

    ## parametrization/model for the inner optimizer
    opt_mdl = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(3, 1, bias=False)),  # 3 inputs [p, g, prev_lr]: 1 for the parameter, 1 for the gradient, 1 for the previous lr
        ('sigmoid', nn.Sigmoid())
    ]))

    ## get outer optimizer (neither differentiable nor trainable)
    outer_opt = optim.Adam([{'params': base_mdl.parameters()}, {'params': opt_mdl.parameters()}], lr=0.01)

    for episode in range(episodes):
        ## get fake support & query data (from a single task and 1 data point)
        spt_x, spt_y, qry_x, qry_y = torch.randn(1), torch.randn(1), torch.randn(1), torch.randn(1)

        ## get differentiable & trainable (parametrized) inner optimizer
        override = {'trainable_opt_model': opt_mdl, 'trainable_opt_state': {'prev_lr': 0.9*torch.randn(1)}}
        inner_opt = EmptySimpleMetaLstm(base_mdl.parameters())
        with higher.innerloop_ctx(base_mdl, inner_opt, copy_initial_weights=copy_initial_weights, track_higher_grads=track_higher_grads, override=override) as (fmodel, diffopt):
            for i_inner in range(nb_inner_train_steps):  # this version does full gradient descent on the k_shot examples (k is usually small, e.g. 5)
                fmodel.train()
                # base/child model forward pass
                inner_loss = 0.5*((fmodel(spt_x) - spt_y))**2
                # inner-opt update
                diffopt.step(inner_loss)
            ## evaluate on the query set for the current task
            qry_loss = 0.5*((fmodel(qry_x) - qry_y))**2
            qry_loss.backward()  # backward inside the context for memory-efficient computation

        ## outer update
        print(f'episode = {episode}')
        print(f'base_mdl.grad = {base_mdl.fc.weight.grad}')
        print(f'opt_mdl.grad = {opt_mdl.fc.weight.grad}')
        outer_opt.step()
        outer_opt.zero_grad()
The real solution would be to pass an arbitrary dictionary to a differentiable optimizer and be able to do whatever I want with it. Perhaps just creating my own field once the diffopt is created is all I need, i.e. this line:

    diffopt.override = {'trainable_opt_model': opt_mdl, 'trainable_opt_state': {'prev_lr': 0.9*torch.randn(1)}}

Whole:
def test_parametrized_inner_optimizer():
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from collections import OrderedDict

    ## training config
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    track_higher_grads = True  # if True, the graph is retained during unrolled optimization and the fast weights bear grad_fns, permitting backpropagation through the optimization process; set to False at test time for efficiency
    copy_initial_weights = False  # if False, we also train the base model's initial weights (i.e. its initialization)
    episodes = 5
    nb_inner_train_steps = 5

    ## get base model
    base_mdl = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1, 1, bias=False)),
        ('act', nn.ReLU())
    ]))

    ## parametrization/model for the inner optimizer
    opt_mdl = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(3, 1, bias=False)),  # 3 inputs [p, g, prev_lr]: 1 for the parameter, 1 for the gradient, 1 for the previous lr
        ('act', nn.LeakyReLU())
    ]))

    ## get outer optimizer (neither differentiable nor trainable)
    outer_opt = optim.Adam([{'params': base_mdl.parameters()}, {'params': opt_mdl.parameters()}], lr=0.01)

    for episode in range(episodes):
        ## get fake support & query data (from a single task and 1 data point)
        spt_x, spt_y, qry_x, qry_y = torch.randn(1), torch.randn(1), torch.randn(1), torch.randn(1)

        ## get differentiable & trainable (parametrized) inner optimizer
        inner_opt = EmptySimpleMetaLstm(base_mdl.parameters())
        with higher.innerloop_ctx(base_mdl, inner_opt, copy_initial_weights=copy_initial_weights, track_higher_grads=track_higher_grads) as (fmodel, diffopt):
            diffopt.override = {'trainable_opt_model': opt_mdl, 'trainable_opt_state': {'prev_lr': 0.9*torch.randn(1)}}
            for i_inner in range(nb_inner_train_steps):  # this version does full gradient descent on the k_shot examples (k is usually small, e.g. 5)
                fmodel.train()
                # base/child model forward pass
                inner_loss = 0.5*((fmodel(spt_x) - spt_y))**2
                # inner-opt update
                diffopt.step(inner_loss)
            ## evaluate on the query set for the current task
            qry_loss = 0.5*((fmodel(qry_x) - qry_y))**2
            qry_loss.backward()  # backward inside the context for memory-efficient computation

        ## outer update
        print(f'episode = {episode}')
        print(f'base_mdl.grad = {base_mdl.fc.weight.grad}')
        print(f'opt_mdl.grad = {opt_mdl.fc.weight.grad}')
        outer_opt.step()
        outer_opt.zero_grad()
I think this is all I need:

    inner_opt = EmptySimpleMetaLstm(base_mdl.parameters())
    with higher.innerloop_ctx(base_mdl, inner_opt, copy_initial_weights=copy_initial_weights, track_higher_grads=track_higher_grads) as (fmodel, diffopt):
        diffopt.override = {'trainable_opt_model': opt_mdl, 'trainable_opt_state': {'prev_lr': 0.9*torch.randn(1)}}
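This seems to work because the custom `_update` only ever reads `self.override`, and the attribute is attached before the first `diffopt.step` call; since the dict is never deep-copied, gradients from the unrolled inner loop can flow back into the very `opt_mdl` that `outer_opt` updates.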
egrefen commented on Mar 5
That line of code (deep copy) is important, as we want to safely branch off the state of the optimizer as used in the outer loop and return to it (or not touch it in the first place, which is what we do with the copy here) at the end of the unrolled inner loop.
Use override, please.
egrefen commented on Mar 6

Override is a kwarg for differentiable optims (at creation or at step time, and you can also use it with the context manager) which allows you to use arbitrary tensors instead of values held in the optimizer state. For example, you could override the learning rate with a tensor which requires grad, which would allow you to unroll your loops, take the gradient of the meta-loss with regard to the learning rate, and update this tensor.

See https://higher.readthedocs.io/en/latest/optim.html for details, #32 (comment) for a similar explanation, and https://github.com/denisyarats/densenet_cifar10 for an example.
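To make this concrete, here is a minimal sketch (mine, not from the thread) of the learnable-learning-rate use case egrefen describes. The toy model, data, and the names lr_tensor/meta_opt are my own assumptions for illustration; the override={'lr': [lr_tensor]} pattern (a singleton list, one entry per param group) follows the higher docs:

import torch
import torch.nn as nn
import torch.optim as optim
import higher

model = nn.Linear(1, 1)
inner_opt = optim.SGD(model.parameters(), lr=0.1)  # this lr gets overridden below

lr_tensor = torch.tensor(0.1, requires_grad=True)  # learnable inner learning rate
meta_opt = optim.Adam([lr_tensor], lr=0.01)  # outer optimizer over the lr itself

x, y = torch.randn(4, 1), torch.randn(4, 1)
with higher.innerloop_ctx(model, inner_opt, override={'lr': [lr_tensor]}) as (fmodel, diffopt):
    for _ in range(3):  # unrolled inner loop
        inner_loss = ((fmodel(x) - y)**2).mean()
        diffopt.step(inner_loss)
    meta_loss = ((fmodel(x) - y)**2).mean()
    meta_loss.backward()  # d(meta_loss)/d(lr_tensor) flows through the unrolled steps

print(lr_tensor.grad)  # not None: the learning rate is now trainable
meta_opt.step()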
From this issue: facebookresearch/higher#62
SO: https://stackoverflow.com/questions/62459891/how-does-one-implemented-a-parametrized-meta-learner-in-pytorchs-higher-library
Forum: https://discuss.pytorch.org/t/how-does-one-implemented-a-parametrized-meta-learner-in-pytorchs-higher-library/85988