Last active
November 15, 2020 11:43
-
-
Save takuseno/42b63b6418c26b42d0ddab5d1ac96af3 to your computer and use it in GitHub Desktop.
TensorFlow version of the Reptile sample code from https://blog.openai.com/reptile/
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
import tensorflow as tf | |
# Algorithm selection: 'reptile' or 'maml' (first-order; see the Network class).
mode = 'maml'
seed = 0
plot = True

# Inner-loop (per-task) SGD settings.
innerstepsize = 0.02  # learning rate of each inner SGD step
innerepochs = 1  # passes over the task's data in each inner SGD run

# Outer-loop (meta-optimization) settings. For Reptile the step is an
# interpolation fraction toward the adapted weights; for MAML it is a gradient
# learning rate, hence the different defaults.
outerstepsize0 = {'reptile': 0.1}.get(mode, 0.001)
niterations = 30000  # outer updates; one task is sampled and used per iteration

rng = np.random.RandomState(seed)

# Task distribution: every task is evaluated on the same evenly spaced inputs.
x_all = np.linspace(-5, 5, 50).reshape(-1, 1)  # shape (50, 1)
ntrain = 10  # minibatch size for the inner SGD
def gen_task(random_state=None):
    """Sample a random sine-wave regression task.

    The task is ``f(x) = ampl * sin(x + phase)``. ``phase`` is drawn uniformly
    from [0, 2*pi) and ``ampl`` uniformly from [0.1, 5).

    Args:
        random_state: optional ``np.random.RandomState`` to draw the task
            parameters from. Defaults to the module-level ``rng``, so existing
            callers keep consuming the shared global random stream.

    Returns:
        A callable mapping an array ``x`` to ``np.sin(x + phase) * ampl``.
    """
    source = rng if random_state is None else random_state
    phase = source.uniform(low=0, high=2 * np.pi)
    ampl = source.uniform(0.1, 5)

    def f_randomsine(x):
        return np.sin(x + phase) * ampl

    return f_randomsine
class Network:
    """Sine-regression MLP with the graph ops needed for Reptile / first-order MAML.

    Two structurally identical tanh MLPs are built on the same input:

    * the working network (variable scope ``mode``), which is trained and queried;
    * a ``backup`` copy, whose variables serve only as snapshot storage for the
      working weights around inner-loop adaptation.

    All methods run their ops in ``tf.get_default_session()``.

    Args:
        innerstepsize: learning rate of each plain-SGD inner step.
        mode: ``'reptile'`` or ``'maml'``; selects the outer update rule.

    Raises:
        ValueError: if ``mode`` is not one of the supported algorithms.
    """

    _MODES = ('reptile', 'maml')

    def __init__(self, innerstepsize, mode='reptile'):
        if mode not in self._MODES:
            # Fail fast: an unknown mode would otherwise build no outer-update
            # ops and leave ``outer_update`` undefined, so training would
            # silently skip every meta-update.
            raise ValueError(
                f"mode must be one of {self._MODES}, got {mode!r}")
        self.innerstepsize = innerstepsize
        self.mode = mode
        self.build()

    def network(self, x, scope):
        """Build the MLP on ``x`` inside variable scope ``scope``.

        The MLP has two 64-unit tanh hidden layers and a 1-unit linear output
        layer named ``y``.

        Returns:
            The output tensor of the final layer.
        """
        with tf.variable_scope(scope):
            output = tf.layers.dense(x, 64)
            output = tf.nn.tanh(output)
            output = tf.layers.dense(output, 64)
            output = tf.nn.tanh(output)
            y = tf.layers.dense(output, 1, name='y')
        return y

    def build(self):
        """Create placeholders, loss, and the inner/outer/backup/restore ops."""
        # Working network: the weights that are adapted and evaluated.
        self.x = tf.placeholder(tf.float32, [None, 1], name='x')
        self.y = self.network(self.x, self.mode)
        # Backup network: only its variables are used (as snapshot storage);
        # its output tensor is never evaluated.
        self.network(self.x, 'backup')
        # Both collections come from the same network() construction, so they
        # list corresponding variables in the same order and can be zipped.
        variables = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, self.mode)
        backup_variables = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, 'backup')

        self.label = tf.placeholder(tf.float32, [None, 1], name='label')
        self.loss = tf.reduce_mean(tf.square(self.y - self.label))
        self.gradients = tf.gradients(self.loss, variables)

        # One plain SGD step on the working weights.
        self.inner_optimize = tf.group(*[
            var.assign(var - self.innerstepsize * grad)
            for var, grad in zip(variables, self.gradients)])
        # Snapshot working -> backup, and restore backup -> working.
        self.backup_ops = tf.group(*[
            backup_var.assign(var)
            for var, backup_var in zip(variables, backup_variables)])
        self.restore_ops = tf.group(*[
            var.assign(backup_var)
            for var, backup_var in zip(variables, backup_variables)])

        self.outerstepsize = tf.placeholder(
            tf.float32, [], name='outerstepsize')
        if self.mode == 'reptile':
            # Reptile: move the pre-adaptation weights (backup) a fraction
            # ``outerstepsize`` of the way toward the adapted weights.
            outer_optimize_expr = [
                var.assign(backup_var + self.outerstepsize * (var - backup_var))
                for var, backup_var in zip(variables, backup_variables)]
            self.outer_update = self.outer_update_reptile
        else:  # 'maml' -- __init__ has already rejected any other mode
            # First-order MAML: the loss gradient evaluated at the current
            # (adapted) weights is applied to the pre-adaptation weights.
            # The inner steps ran as separate assign ops, so no gradient
            # flows back through them.
            outer_optimize_expr = [
                var.assign(backup_var - self.outerstepsize * grad)
                for var, backup_var, grad in zip(
                    variables, backup_variables, self.gradients)]
            self.outer_update = self.outer_update_maml
        self.outer_optimize = tf.group(*outer_optimize_expr)

    def get_session(self):
        """Return the default TensorFlow session.

        Raises:
            RuntimeError: if no default session has been entered.
        """
        sess = tf.get_default_session()
        if sess is None:
            raise RuntimeError(
                "Network requires a default tf.Session; enter one (e.g. "
                "`with sess.as_default():`) before calling its methods.")
        return sess

    def predict(self, x):
        """Return the working network's outputs for inputs ``x`` (shape (N, 1))."""
        return self.get_session().run(self.y, feed_dict={self.x: x})

    def inner_update(self, x, label):
        """Apply one inner SGD step on ``(x, label)``.

        Returns:
            The minibatch mean-squared error, evaluated in the same run as the step.
        """
        feed_dict = {self.x: x, self.label: label}
        loss, _ = self.get_session().run(
            [self.loss, self.inner_optimize], feed_dict=feed_dict)
        return loss

    def outer_update_reptile(self, outerstepsize):
        """Reptile meta-update: interpolate backup weights toward the adapted ones."""
        self.get_session().run(
            self.outer_optimize,
            feed_dict={self.outerstepsize: outerstepsize})

    def outer_update_maml(self, x, label, outerstepsize):
        """First-order MAML meta-update using the loss gradient on ``(x, label)``."""
        feed_dict = {
            self.x: x,
            self.label: label,
            self.outerstepsize: outerstepsize,
        }
        # The gradients are computed as a dependency of outer_optimize; they
        # do not need to be fetched.
        self.get_session().run(self.outer_optimize, feed_dict=feed_dict)

    def backup(self):
        """Copy the working weights into the backup variables."""
        self.get_session().run(self.backup_ops)

    def restore(self):
        """Copy the backup variables back into the working weights."""
        self.get_session().run(self.restore_ops)
def train_on_batch(x, y):
    """Take one inner SGD step of the global ``model`` on ``(x, y)``.

    Returns:
        The minibatch mean-squared error reported by ``model.inner_update``.
        Callers that ignore the return value are unaffected.
    """
    return model.inner_update(x, y)
def predict(x):
    """Return the global ``model``'s predictions for the inputs ``x``."""
    outputs = model.predict(x)
    return outputs
# Seed TensorFlow's graph-level RNG as well, so weight initialization is
# reproducible alongside the numpy task sampling driven by ``rng``.
tf.set_random_seed(seed)
model = Network(innerstepsize=innerstepsize, mode=mode)
# Enter the session for the rest of the script. This makes it the default
# session that the Network methods use via tf.get_default_session().
sess = tf.Session()
sess.__enter__()
sess.run(tf.global_variables_initializer())

# Choose a fixed task and minibatch for visualization
f_plot = gen_task()
xtrain_plot = x_all[rng.choice(len(x_all), size=ntrain)]

eval_every = 1000  # evaluate on the fixed task every this many iterations
eval_steps = 32  # inner SGD steps taken on the fixed task during evaluation
eval_record = 8  # record a prediction curve every this many of those steps

# Meta-training loop (Reptile or first-order MAML, depending on ``mode``)
for iteration in range(niterations):
    # Snapshot the current meta-weights before adapting to a new task.
    model.backup()
    # Generate task
    f = gen_task()
    y_all = f(x_all)
    # Do SGD on this task
    inds = rng.permutation(len(x_all))
    for _ in range(innerepochs):
        for start in range(0, len(x_all), ntrain):
            mbinds = inds[start:start + ntrain]
            train_on_batch(x_all[mbinds], y_all[mbinds])
    # Outer (meta) update with a linearly decaying step size.
    # Reptile interpolates from the snapshot toward the adapted weights, i.e.
    # (weights_before - weights_after) acts as the meta-gradient.
    # MAML applies the full-task gradient at the adapted weights to the snapshot.
    outerstepsize = outerstepsize0 * (1 - iteration / niterations)
    if mode == 'reptile':
        model.outer_update(outerstepsize)
    elif mode == 'maml':
        model.outer_update(x_all, y_all, outerstepsize)
    # Re-snapshot the new meta-weights so the evaluation below can restore
    # them after its own fine-tuning.
    model.backup()

    # Periodically fine-tune on the fixed task and report the result. The
    # original condition `plot and a or b` plotted even with plot = False;
    # now ``plot`` only controls drawing, and the progress log always prints.
    if iteration == 0 or (iteration + 1) % eval_every == 0:
        curves = [(0, predict(x_all))]
        for inneriter in range(eval_steps):
            train_on_batch(xtrain_plot, f_plot(xtrain_plot))
            if (inneriter + 1) % eval_record == 0:
                curves.append((inneriter + 1, predict(x_all)))
        lossval = np.square(predict(x_all) - f_plot(x_all)).mean()
        model.restore()  # restore from snapshot
        if plot:
            plt.cla()
            for nsteps, pred in curves:
                frac = nsteps / eval_steps
                plt.plot(x_all, pred,
                         label="pred after %i" % nsteps,
                         color=(frac, 0, 1 - frac))
            plt.plot(x_all, f_plot(x_all), label="true", color=(0, 1, 0))
            plt.plot(xtrain_plot, f_plot(xtrain_plot), "x", label="train",
                     color="k")
            plt.ylim(-4, 4)
            plt.legend(loc="lower right")
            plt.pause(0.01)
        print("-----------------------------")
        print(f"iteration {iteration+1}")
        print(f"loss on plotted curve {lossval:.3f}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The `mode` variable switches the algorithm between MAML and Reptile.