Reptile meta-learning on Keras
# Reptile example on Keras
# Original code in PyTorch at https://blog.openai.com/reptile/
# Ported to Keras by shi3z, Apr. 12 2018
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras.layers import *
from keras.models import *
from keras import backend as K
from copy import deepcopy

seed = 0
plot = True
innerstepsize = 0.02  # stepsize in inner SGD
innerepochs = 1  # number of epochs of each inner SGD
outerstepsize0 = 0.1  # stepsize of outer optimization, i.e., meta-optimization
niterations = 30000  # number of outer updates; each iteration we sample one task and update on it
rng = np.random.RandomState(seed)

# Define task distribution
x_all = np.linspace(-5, 5, 50)[:, None]  # all of the x points
ntrain = 10  # size of training minibatches
def gen_task():
    "Generate a random sine-wave regression task"
    phase = rng.uniform(low=0, high=2*np.pi)
    ampl = rng.uniform(0.1, 5)
    f_randomsine = lambda x: np.sin(x + phase) * ampl
    return f_randomsine
model = Sequential()
model.add(Dense(64, input_shape=(1,), activation='tanh'))
model.add(Dense(64, activation='tanh'))
model.add(Dense(1))
model.compile(loss='mse', optimizer='sgd')

f_plot = gen_task()
xtrain_plot = x_all[rng.choice(len(x_all), size=ntrain)]
def train_on_batch(x, y):
    weights_before = model.weights
    model.train_on_batch(x, y)
    weights_after = model.weights
    for i in range(len(weights_after)):
        model.weights[i] = (weights_after[i] -
                            (weights_after[i] - weights_before[i]) * innerstepsize)
# Reptile training loop
for iteration in range(niterations):
    weights_before = model.weights
    # Generate task
    f = gen_task()
    y_all = f(x_all)
    # Do SGD on this task
    inds = rng.permutation(len(x_all))
    for _ in range(innerepochs):
        for start in range(0, len(x_all), ntrain):
            mbinds = inds[start:start+ntrain]
            train_on_batch(x_all[mbinds], y_all[mbinds])
    # Interpolate between current weights and trained weights from this task
    weights_after = model.weights
    outerstepsize = outerstepsize0 * (1 - iteration / niterations)  # linear schedule
    for i in range(len(weights_after)):
        model.weights[i] = (weights_before[i] +
                            (weights_after[i] - weights_before[i]) * outerstepsize)
    # Periodically plot the results on a particular task and minibatch
    if plot and (iteration == 0 or (iteration+1) % 10 == 0):
        plt.cla()
        f = f_plot
        weights_before = model.weights
        plt.plot(x_all, model.predict(x_all), label="pred after 0", color=(0, 0, 1))
        for inneriter in range(32):
            train_on_batch(xtrain_plot, f(xtrain_plot))
            if (inneriter+1) % 8 == 0:
                frac = (inneriter+1) / 32
                plt.plot(x_all, model.predict(x_all), label="pred after %i" % (inneriter+1), color=(frac, 0, 1-frac))
        plt.plot(x_all, f(x_all), label="true", color=(0, 1, 0))
        lossval = np.square(model.predict(x_all) - f(x_all)).mean()
        plt.plot(xtrain_plot, f(xtrain_plot), "x", label="train", color="k")
        plt.ylim(-4, 4)
        plt.legend(loc="lower right")
        plt.pause(0.01)
        for i in range(len(model.weights)):
            model.weights[i] = weights_before[i]
        print("-----------------------------")
        print(f"iteration {iteration+1}")
        print(f"loss on plotted curve {lossval:.3f}")
Did you actually run this code? I don't think it works. In particular, model.weights just returns references to the underlying tf.Variable objects, so weights_before and weights_after will hold the same values when run, and assigning to elements of that list has no effect.
What you want is to call model.get_weights(), as this gives you the values of the current weights; model.set_weights() writes updated values back.
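For concreteness, a minimal sketch of what that fix could look like inside the existing training loop, assuming the Reptile meta-update is rewritten around get_weights()/set_weights(); this is an illustration, not the gist author's code, and it reuses the script's variable names (iteration, niterations, outerstepsize0).

# Sketch: Reptile outer update on weight values rather than tf.Variable references
weights_before = model.get_weights()  # NumPy copies of the current weights
# ... inner-loop SGD on the sampled task via model.train_on_batch ...
weights_after = model.get_weights()
outerstepsize = outerstepsize0 * (1 - iteration / niterations)
model.set_weights([before + (after - before) * outerstepsize
                   for before, after in zip(weights_before, weights_after)])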
model.train_on_batch(x, y) already performs the gradient update, so you don't need to apply it manually again inside the custom def train_on_batch(x, y): wrapper. Removing that extra loop should also make this script run faster.
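A minimal sketch of that simplification: if the inner step size is handed to the compiled optimizer (keras.optimizers.SGD is used here as an assumption; the gist compiles with the default 'sgd' string), a single call to model.train_on_batch is one inner SGD step and the manual weight loop can be dropped entirely.

from keras.optimizers import SGD

# Let the compiled optimizer apply the inner-loop step size directly
model.compile(loss='mse', optimizer=SGD(lr=innerstepsize))

def train_on_batch(x, y):
    # One inner SGD step; the gradient update already happens inside Keras
    model.train_on_batch(x, y)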