Skip to content

Instantly share code, notes, and snippets.

@marta-sd
Created December 14, 2017 11:16
Show Gist options
  • Save marta-sd/b51e452c887868bcd1f8203f6ac05062 to your computer and use it in GitHub Desktop.
Save marta-sd/b51e452c887868bcd1f8203f6ac05062 to your computer and use it in GitHub Desktop.
Load a pretrained model as the legs of a Siamese network
import tensorflow as tf
# 1. Create and save a network
# c = a*b
#
# A one-variable graph: the placeholder 'a' is multiplied element-wise by the
# trainable variable 'b' (shape (1,)), producing the tensor 'c:0'. The saver
# writes a checkpoint with prefix 'g1'; step 2 imports its meta graph from
# 'g1.meta'.
g = tf.Graph()
with g.as_default():
    a = tf.placeholder(tf.float32, name='a')
    b = tf.Variable(initial_value=tf.truncated_normal((1,)), name='b')
    c = tf.multiply(a, b, name='c')
    s1 = tf.train.Saver()

with tf.Session(graph=g) as sess:
    sess.run(tf.global_variables_initializer())
    b_init = sess.run(b)
    # b has shape (1,); .item() extracts the single element instead of relying
    # on NumPy's deprecated implicit size-1-array -> float conversion in '%'.
    print('initial value: b=%.2f' % b_init.item())
    s1.save(sess, 'g1')
# 2. Load saved model as legs of siamese network
#
# The saved graph is imported ONCE, with its input 'a:0' replaced by the
# stacked pair [a1, a2]. Both legs therefore run as a single batched pass
# through the same ops and share the one imported variable 'g1/b'; the
# per-leg outputs are recovered by splitting 'g1/c:0' in two.
#
# Consequence: c1 and c2 are both computed from the stacked tensor, so
# evaluating either leg requires feeding BOTH a1 and a2 (with matching
# shapes). Because c = a*b is element-wise, the value fed to the unused leg
# does not change the other leg's output.
g_siamese = tf.Graph()
with g_siamese.as_default():
    # two inputs
    a1 = tf.placeholder(tf.float32)
    a2 = tf.placeholder(tf.float32)
    # Stack explicitly instead of passing a Python list in input_map: the
    # implicit list -> tensor conversion is only performed by TF when a
    # non-empty import scope is given, and spelling it out makes the shared
    # batched pass visible.
    a_pair = tf.stack([a1, a2])
    tf.train.import_meta_graph(
        'g1.meta',
        import_scope='g1',
        input_map={'a:0': a_pair})
    b_ = g_siamese.get_tensor_by_name('g1/b:0')
    c1, c2 = tf.split(g_siamese.get_tensor_by_name('g1/c:0'), 2)
    # we want a1*b - a2*b = 3
    cost = tf.abs(c1 - c2 - 3)
    train = tf.train.AdamOptimizer(0.1).minimize(cost)
    # Created after minimize() so it also covers the optimizer's variables.
    init_op = tf.global_variables_initializer()
    # Map the checkpoint's variable name 'b' onto the imported 'g1/b' so the
    # pretrained value can be restored into this graph.
    loader = tf.train.Saver(var_list={'b': b_})
# 3. Optimize b for given a1 and a2
#
# Restores the pretrained b into the siamese graph, then runs 100 Adam steps
# on cost = |a1*b - a2*b - 3|, whose minimizer is b = 3 / (a1 - a2).
feed_dict = {a1: 1.0, a2: -1.0}
optimal_b = 3.0 / (feed_dict[a1] - feed_dict[a2])
print('a1=%s and a2=%s; optimal b=%s' % (feed_dict[a1], feed_dict[a2],
                                          optimal_b))
with tf.Session(graph=g_siamese) as sess:
    # Initialize everything first, then overwrite g1/b with the checkpointed
    # value; reversing these two calls would discard the restored weight.
    sess.run(init_op)
    loader.restore(sess, './g1')
    cost_value, b_value = sess.run([cost, b_], feed_dict=feed_dict)
    # Both fetches hold a single element (shape (1,) for scalar feeds);
    # .item() extracts it without NumPy's deprecated size-1-array -> float
    # conversion in '%'.
    print('initial values: b=%.2f (restored); cost=%.2f'
          % (b_value.item(), cost_value.item()))
    for _ in range(100):
        sess.run(train, feed_dict=feed_dict)
    cost_value, b_value = sess.run([cost, b_], feed_dict=feed_dict)
    print('after training: b=%.2f; cost=%.2f'
          % (b_value.item(), cost_value.item()))
@slegall56
Copy link

Hi, after training, I would like to evaluate a single leg of the network, for example:

feed_dict = {a1: 1.0}
sess.run([c1], feed_dict=feed_dict)

But it does not work.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment