Reptile in TensorFlow
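A compact demo of the Reptile meta-learning algorithm (Nichol, Achiam & Schulman, 2018) on 1-D sine-wave regression, written against the TensorFlow 1.x API (`tf.placeholder`, `tf.layers`, `tf.Session`). The script meta-trains a small MLP initialization so that a few gradient steps on a freshly sampled sine task suffice to fit it, and visualizes how that initialization adapts on four held-out test tasks.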
#!/usr/bin/env python
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf


def get_model(hidden_units=[64, 64],
              activation=tf.tanh,
              inner_optimizer=tf.train.GradientDescentOptimizer,
              inner_learning_rate=0.01,
              outer_optimizer=tf.train.AdamOptimizer,
              outer_learning_rate=0.001):
    # Input Layer
    x = tf.placeholder(tf.float32, shape=[None, 1], name='x')
    # Hidden Layers
    net = tf.identity(x)
    for i, units in enumerate(hidden_units):
        net = tf.layers.dense(net, units=units, activation=activation,
                              name='hidden%d' % i)
    # Output Layer
    y = tf.layers.dense(net, units=1, name='y')
    # Labels
    y_ = tf.placeholder(y.dtype, shape=y.shape, name='y_')
    # Loss
    loss = tf.losses.mean_squared_error(labels=y_, predictions=y)
    # Inner Loop: plain SGD steps on a single sampled task
    in_opt = inner_optimizer(learning_rate=inner_learning_rate, name='in_opt')
    in_train = in_opt.minimize(loss, name='in_train')
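
    # Reptile needs to read, reset, and nudge the network weights from
    # Python between graph executions.  In graph-mode TF1 that means
    # building explicit placeholder-fed assign ops once, up front, rather
    # than mutating variables directly.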
    # Variable assignment operation
    assign_dict = {}  # map: variable -> placeholder for assignment
    assign_ops = []   # list of assignment ops
    for v in tf.trainable_variables():
        assign_dict[v] = tf.placeholder(v.dtype, shape=v.shape)
        assign_ops.append(v.assign(assign_dict[v]))

    def assign(phi, sess):
        # Overwrite every trainable variable with its value in phi
        sess.run(assign_ops, feed_dict={
            assign_dict[k]: v for k, v in phi.items()})

    # Outer Loop Gradient Update
    out_opt = outer_optimizer(learning_rate=outer_learning_rate)
    update_dict = {}  # map: variable -> placeholder for the pseudo-gradient
    grads_vars = []
    for v in tf.trainable_variables():
        update_dict[v] = tf.placeholder(v.dtype, shape=v.shape)
        grads_vars.append((update_dict[v], v))
    update_op = out_opt.apply_gradients(grads_vars)
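
    # The Reptile meta-update: feed (phi - W) to the outer optimizer as if
    # it were a gradient, where phi are the weights before and W the weights
    # after the inner-loop SGD.  A plain gradient step would give
    #     phi <- phi - lr * (phi - W) = phi + lr * (W - phi),
    # i.e. a step toward the adapted weights; Adam applies the same
    # direction with its usual per-parameter scaling.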
    def update(phi, W, sess):
        assign(phi, sess)
        sess.run(update_op, feed_dict={
            update_dict[v]: phi[v] - W[v] for v in phi.keys()})

    return (x, y, y_, loss, in_train, assign, update)


def get_task():
    # Sample a random sine task: amplitude in [0.1, 5.0], phase in [0, 2*pi]
    a = np.random.uniform(.1, 5.)
    b = np.random.uniform(0, np.pi * 2)
    return lambda x: a * np.sin(x + b)


def get_params(sess=None):
    # Snapshot the current values of all trainable variables
    return {v: sess.run(v) for v in tf.trainable_variables()}


def SGD(tau, x, y_, train,
        epochs=10,
        batch_size=10,
        sess=None):
    # Inner loop: take `epochs` SGD steps on task tau, each on a fresh
    # minibatch drawn from U(-5, 5), and return the adapted weights
    for _ in range(epochs):
        data_x = np.random.uniform(-5, 5, size=(batch_size, 1))
        sess.run(train, feed_dict={x: data_x, y_: tau(data_x)})
    return get_params(sess=sess)
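

# With the defaults above, one inner loop runs 10 SGD steps on minibatches
# of 10 points, i.e. 100 samples from a single task per meta-iteration.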


def render(step, x, y, y_, train, assign, tests, sess=None):
    # Lazily create one persistent 2x2 figure, one panel per test task
    if getattr(render, 'fig', None) is None:
        plt.ion()
        f, ((a1, a2), (a3, a4)) = plt.subplots(2, 2, sharex='col', sharey='row')
        render.fig = f
        render.axes = (a1, a2, a3, a4)
    full_x = np.linspace(-10, 10, 100).reshape(-1, 1)
    phi = get_params(sess=sess)
    plt.suptitle('iteration %d' % step)
    for i, (test, ax) in enumerate(zip(tests, render.axes)):
        # Reset to the current initialization, then fine-tune for 32 inner
        # steps on this test task, recording predictions along the way
        assign(phi, sess=sess)
        test_preds = {}
        test_preds[0] = sess.run(y, feed_dict={x: full_x})
        for j in range(32):
            SGD(test, x, y_, train, epochs=1, sess=sess)
            test_preds[j + 1] = sess.run(y, feed_dict={x: full_x})
        ax.cla()
        ax.set_xlim([-10, 10])
        ax.set_ylim([-5, 5])
        # Black vertical lines mark the training input range [-5, 5]
        ax.axvline(x=-5, color='k')
        ax.axvline(x=5, color='k')
        ax.set_title('test %d' % i)
        ax.plot(full_x, test(full_x), label='true', color=(0, 1, 0))
        for j, test_pred in test_preds.items():
            # Fade from blue (0 fine-tuning steps) to red (32 steps)
            color = (j / 32, 0, 1 - (j / 32), 0.5)
            label = 'after %d' % j if j % 4 == 0 else None
            ax.plot(full_x, test_pred, label=label, color=color)
    plt.pause(0.01)
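

# Outer training loop.  Each iteration: snapshot the meta-parameters phi,
# run the inner loop on one freshly sampled task to get adapted weights W,
# then move phi a little toward W via update().  Every step_test iterations
# the current initialization is also evaluated zero-shot on four fixed
# test tasks.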
def main(outer_epochs=30000,
         step_test=500,
         render_test=True):
    x, y, y_, loss, in_train, assign, update = get_model()
    tests = [get_task() for _ in range(4)]  # fixed held-out test tasks
    test_x = np.linspace(-5, 5, 50).reshape(-1, 1)
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    for i in range(outer_epochs):
        phi = get_params(sess=sess)  # snapshot the meta-parameters
        if i % step_test == 0 or i == outer_epochs - 1:
            if render_test:
                render(i, x, y, y_, in_train, assign, tests, sess=sess)
            print('outer', i)
            for j, test in enumerate(tests):
                assign(phi, sess=sess)  # reset weights before each test
                score = sess.run(loss, feed_dict={x: test_x, y_: test(test_x)})
                print('test%d %0.3f' % (j, score), end=' ')
            print()
        # One Reptile meta-iteration: adapt to a fresh task, then move the
        # meta-parameters toward the adapted weights
        assign(phi, sess=sess)
        W = SGD(get_task(), x, y_, in_train, sess=sess)
        update(phi, W, sess=sess)


if __name__ == '__main__':
    main()
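
Running the script meta-trains for 30,000 outer iterations. Every 500 iterations it prints the zero-shot loss of the current initialization on the four fixed test tasks and, with `render_test=True`, redraws a 2x2 matplotlib figure showing each task's fit after 0 through 32 fine-tuning steps (blue fading to red), with black vertical lines marking the training range [-5, 5].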