Last active
September 12, 2020 18:21
-
-
Save benoitdescamps/6bc4bbe36b3b2b8b7385d31bdee12b49 to your computer and use it in GitHub Desktop.
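PyTorch implementation of the Reptile meta-learning algorithm (Nichol, Achiam & Schulman, "On First-Order Meta-Learning Algorithms", https://arxiv.org/abs/1803.02999). The originally cited paper, https://openreview.net/pdf?id=rJY0-Kcll, is Ravi & Larochelle's "Optimization as a Model for Few-Shot Learning".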
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Reptile:
    """Reptile meta-optimization of a PyTorch model.

    Reptile is the first-order meta-learning algorithm of Nichol, Achiam &
    Schulman, "On First-Order Meta-Learning Algorithms"
    (https://arxiv.org/abs/1803.02999). Each outer step resets every task
    learner to the current meta-weights, adapts the learners on their tasks,
    then moves the meta-weights a fraction ``learning_rate`` of the way
    towards the mean of the adapted weights.

    Attributes:
        model: the meta-model whose parameters are updated in place.
        metalearners: one learner per task; ``metalearners[idx].model`` holds
            the task-adapted weights and ``training_step(x, y)`` adapts them.
        n_tasks: ``len(metalearners)``.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 metalearners: List["MetaLearner"]):
        self.n_tasks = len(metalearners)
        self.model = model
        self.metalearners = metalearners

    def metatraining_step(self, x, y, idxs=None, steps: int = 1):
        """Run ``steps`` inner training steps for each task in the batch.

        Args:
            x: per-task inputs; ``x[k]`` is fed to the learner chosen by
                ``idxs[k]``.
            y: per-task targets, aligned with ``x``.
            idxs: indices into ``self.metalearners``, one per task. When
                ``None``, task ``k`` is routed to metalearner ``k``.
            steps: number of ``training_step`` calls per task.

        Returns:
            The distinct metalearners that were trained, in first-use order.
            Callers that ignored the former ``None`` return are unaffected.
        """
        if idxs is None:
            # Previously zip(None, ...) raised TypeError; default to 1:1 routing.
            idxs = range(self.n_tasks)
        # Keyed by identity so a learner hit by several tasks is reported once.
        trained = {}
        for idx, xp, yp in zip(idxs, x, y):
            metalearner = self.metalearners[idx]
            for _ in range(steps):
                metalearner.training_step(xp, yp)
            trained[id(metalearner)] = metalearner
        return list(trained.values())

    def training_step(self, x, y, idxs=None, metalearning_steps: int = 100,
                      learning_rate: float = 0.01):
        """Perform one Reptile outer-loop update of ``self.model``.

        Args:
            x, y, idxs: the task batch, as in ``metatraining_step``.
            metalearning_steps: inner training steps per task.
            learning_rate: outer (meta) step size; the meta-weights move this
                fraction of the way towards the mean adapted weights.
        """
        # NOTE(review): presumably transfer_model copies self.model's weights
        # into metalearner.model and returns it, so every learner starts from
        # the current meta-weights - TODO confirm against its definition.
        for metalearner in self.metalearners:
            metalearner.model = transfer_model(self.model, metalearner.model)
        trained = self.metatraining_step(x, y, idxs, metalearning_steps)
        # Average only the learners adapted in this step; untouched learners
        # would otherwise dilute the update towards the current weights.
        self._meta_update(trained, learning_rate)

    def _meta_update(self, learners, learning_rate: float):
        """Move ``self.model`` towards the mean weights of ``learners`` in place."""
        if not learners:
            return  # nothing was adapted, so there is nothing to move towards
        with torch.no_grad():
            # Parameters are paired positionally, so every learner's model must
            # share self.model's parameter layout (zip truncates otherwise).
            learner_params = [learner.model.parameters() for learner in learners]
            for param, *task_params in zip(self.model.parameters(), *learner_params):
                mean = torch.stack(task_params, dim=0).mean(dim=0)
                param += learning_rate * (mean - param)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment