@shi3z
Last active February 8, 2023 15:59
Reptile meta learning on Keras
# Reptile example on Keras
# Original code in PyTorch at https://blog.openai.com/reptile/
# Ported to Keras by shi3z, Apr. 12 2018
import numpy as np
import matplotlib.pyplot as plt
from keras.layers import Dense
from keras.models import Sequential
from keras.optimizers import SGD
seed = 0
plot = True
innerstepsize = 0.02 # stepsize in inner SGD
innerepochs = 1 # number of epochs of each inner SGD
outerstepsize0 = 0.1 # stepsize of outer optimization, i.e., meta-optimization
niterations = 30000 # number of outer updates; each iteration we sample one task and update on it
rng = np.random.RandomState(seed)
# Define task distribution
x_all = np.linspace(-5, 5, 50)[:,None] # All of the x points
ntrain = 10 # Size of training minibatches
def gen_task():
    "Generate a random sine-wave regression task"
    phase = rng.uniform(low=0, high=2 * np.pi)
    ampl = rng.uniform(0.1, 5)
    f_randomsine = lambda x: np.sin(x + phase) * ampl
    return f_randomsine
model = Sequential()
model.add(Dense(64, input_shape=(1,), activation='tanh'))
model.add(Dense(64, activation='tanh'))
model.add(Dense(1))
# Compile with SGD at the inner-loop stepsize; train_on_batch then performs
# a full inner SGD step on its own, so no manual weight arithmetic is needed
model.compile(loss='mse', optimizer=SGD(lr=innerstepsize))
# Choose a fixed task and minibatch for visualization
f_plot = gen_task()
xtrain_plot = x_all[rng.choice(len(x_all), size=ntrain)]
# Reptile training loop
for iteration in range(niterations):
    # get_weights() returns copies of the current weight values;
    # model.weights would only return references to the underlying variables
    weights_before = model.get_weights()
    # Generate task
    f = gen_task()
    y_all = f(x_all)
    # Do SGD on this task
    inds = rng.permutation(len(x_all))
    for _ in range(innerepochs):
        for start in range(0, len(x_all), ntrain):
            mbinds = inds[start:start + ntrain]
            model.train_on_batch(x_all[mbinds], y_all[mbinds])
    # Interpolate between current weights and trained weights from this task
    weights_after = model.get_weights()
    outerstepsize = outerstepsize0 * (1 - iteration / niterations)  # linear schedule
    model.set_weights([before + (after - before) * outerstepsize
                       for before, after in zip(weights_before, weights_after)])
    # Periodically plot the results on a particular task and minibatch
    if plot and (iteration == 0 or (iteration + 1) % 10 == 0):
        plt.cla()
        f = f_plot
        weights_before = model.get_weights()  # save snapshot before fine-tuning
        plt.plot(x_all, model.predict(x_all), label="pred after 0", color=(0, 0, 1))
        for inneriter in range(32):
            model.train_on_batch(xtrain_plot, f(xtrain_plot))
            if (inneriter + 1) % 8 == 0:
                frac = (inneriter + 1) / 32
                plt.plot(x_all, model.predict(x_all),
                         label="pred after %i" % (inneriter + 1), color=(frac, 0, 1 - frac))
        plt.plot(x_all, f(x_all), label="true", color=(0, 1, 0))
        lossval = np.square(model.predict(x_all) - f(x_all)).mean()
        plt.plot(xtrain_plot, f(xtrain_plot), "x", label="train", color="k")
        plt.ylim(-4, 4)
        plt.legend(loc="lower right")
        plt.pause(0.01)
        model.set_weights(weights_before)  # restore from snapshot
        print("-----------------------------")
        print(f"iteration {iteration+1}")
        print(f"loss on plotted curve {lossval:.3f}")  # would be better to average loss over a set of examples, but this is optimized for brevity
@victorychain

model.train_on_batch(x, y) already applies the gradient update, so you don't need to do it manually inside def train_on_batch(x, y) again. This should also make the script run faster.
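
For instance, a minimal sketch of that simplification (assuming the model, innerstepsize, and a minibatch x, y as defined in the gist above): compile with SGD at the inner-loop stepsize and call the built-in method directly.

# Sketch: let Keras apply the inner SGD step itself.
# Assumes model, innerstepsize, and a minibatch (x, y) from the gist above.
from keras.optimizers import SGD

model.compile(loss='mse', optimizer=SGD(lr=innerstepsize))
model.train_on_batch(x, y)  # one inner SGD step; the gradient update is applied internally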

@EugenHotaj

Did you actually run this code? I don't think it works. In particular, model.weights just returns references to the underlying tf.Variable objects, so weights_before and weights_after will hold the same values when run.

What you want to do is call model.get_weights(), as this returns the current weight values rather than references.
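
For example, a small sketch of the difference (assuming a compiled Keras model and a minibatch x, y):

# model.weights returns references to the live variables, so the list
# silently tracks every update; model.get_weights() returns numpy copies.
refs_before = model.weights           # references to tf.Variable objects
vals_before = model.get_weights()     # numpy copies of the current values
model.train_on_batch(x, y)            # applies one SGD update
# refs_before now reflects the updated parameters (same objects),
# while vals_before still holds the pre-update values
model.set_weights(vals_before)        # restore the snapshot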
