Created: December 14, 2017 11:16
-
-
Save marta-sd/b51e452c887868bcd1f8203f6ac05062 to your computer and use it in GitHub Desktop.
Load pretrained model as legs of siamese network
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf

# NOTE: this script uses TensorFlow 1.x graph-mode APIs (tf.placeholder,
# tf.Session, tf.train.Saver).

# 1. Create and save a network
# c = a*b
g = tf.Graph()
with g.as_default():
    # The node names 'a', 'b' and 'c' matter: step 2 looks them up as
    # 'a:0', 'g1/b:0' and 'g1/c:0' after importing this graph.
    a = tf.placeholder(tf.float32, name='a')
    b = tf.Variable(initial_value=tf.truncated_normal((1,)), name='b')
    c = tf.multiply(a, b, name='c')
    s1 = tf.train.Saver()

with tf.Session(graph=g) as sess:
    sess.run(tf.global_variables_initializer())
    b_init = sess.run(b)
    # b_init is a shape-(1,) ndarray; .item() gives a Python float so '%.2f'
    # does not rely on NumPy's deprecated implicit array-to-scalar conversion.
    print('initial value: b=%.2f' % b_init.item())
    # Checkpoint prefix 'g1'; step 2 imports the graph from 'g1.meta' and
    # restores variables from this checkpoint.
    s1.save(sess, 'g1')
# 2. Load saved model as legs of siamese network
g_siamese = tf.Graph()
with g_siamese.as_default():
    # two inputs, one per leg
    a1 = tf.placeholder(tf.float32)
    a2 = tf.placeholder(tf.float32)
    # Import g1 ONCE and feed it the stacked batch [a1, a2] in place of its
    # input 'a'. Both legs therefore run through the same ops and share the
    # single variable g1/b; the per-leg outputs are recovered by splitting c.
    # input_map values are documented as Tensors, so stack explicitly instead
    # of relying on implicit conversion of a Python list.
    tf.train.import_meta_graph(
        'g1.meta',
        import_scope='g1',
        input_map={'a:0': tf.stack([a1, a2])})
    b_ = g_siamese.get_tensor_by_name('g1/b:0')
    # Split along axis 0 back into the a1 leg and the a2 leg. Because both
    # legs come from one stacked input, evaluating c1 (or c2) alone still
    # requires feeding BOTH a1 and a2.
    c1, c2 = tf.split(g_siamese.get_tensor_by_name('g1/c:0'), 2)
    # we want a1*b - a2*b = 3
    cost = tf.abs(c1 - c2 - 3)
    train = tf.train.AdamOptimizer(0.1).minimize(cost)
    init_op = tf.global_variables_initializer()
    # Restore only 'b': it is saved as 'b' in the checkpoint but lives under
    # the 'g1/' import scope in this graph.
    loader = tf.train.Saver(var_list={'b': b_})
# 3. Optimize b for given a1 and a2
feed_dict = {a1: 1.0, a2: -1.0}
# Closed form: a1*b - a2*b = 3  =>  b = 3 / (a1 - a2)  (requires a1 != a2).
print('a1=%s and a2=%s; optimal b=%s' % (feed_dict[a1], feed_dict[a2],
                                         3.0/(feed_dict[a1]-feed_dict[a2])))
with tf.Session(graph=g_siamese) as sess:
    # Initialize all variables (including the optimizer's), then overwrite b
    # with the value saved in step 1.
    sess.run(init_op)
    loader.restore(sess, './g1')
    cost_value, b_value = sess.run([cost, b_], feed_dict=feed_dict)
    # b_value and cost_value are shape-(1,) arrays; .item() gives Python
    # floats for '%.2f' without NumPy's deprecated array-to-scalar conversion.
    print('initial values: b=%.2f (restored); cost=%.2f'
          % (b_value.item(), cost_value.item()))
    for _ in range(100):
        sess.run(train, feed_dict=feed_dict)
    cost_value, b_value = sess.run([cost, b_], feed_dict=feed_dict)
    print('after training: b=%.2f; cost=%.2f'
          % (b_value.item(), cost_value.item()))
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi, after training, I would like to evaluate one leg of the network, for example:
[the code snippet from the original comment was not preserved in this copy]
But it does not work.