Skip to content

Instantly share code, notes, and snippets.

@ntakouris
Created January 13, 2021 09:26
Show Gist options
  • Save ntakouris/97dcba5a4429b63d802e4ce74a8c85e5 to your computer and use it in GitHub Desktop.
Save ntakouris/97dcba5a4429b63d802e4ce74a8c85e5 to your computer and use it in GitHub Desktop.
# Load a frozen GraphDef and rebuild it as one graph (`self.graph`) in which
# the generator output and the discriminator output are driven by two
# independent placeholders (`self.gi` and `self.di`).
#
# NOTE(review): this fragment sits inside a method (it reads/writes `self`).
# `modelfilename`, `model_name`, `gen_input`, `gen_output`, `gen_loss`,
# `disc_input` and `disc_output` come from the enclosing scope.
# `gen_output` and `disc_output` are passed to `get_tensor_by_name`, which
# requires full tensor names of the form "op_name:index"; confirm callers
# pass them that way.

# Parse the serialized GraphDef from disk.
with tf.gfile.GFile(modelfilename, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
self.graph_def = graph_def

# Graph 1: import the model with its generator input re-routed to a fresh
# [None, z_dim] placeholder, and fetch the generator output and loss.
with tf.Graph().as_default() as g1:
    gi = tf.placeholder(tf.float32,
                        [None, self.z_dim],
                        name='g1_gi')
    # `gl` (the loss tensor) is returned but not used further in this fragment.
    go, gl = tf.import_graph_def(self.graph_def,
                                 input_map={gen_input: gi},
                                 return_elements=[gen_output, gen_loss],
                                 op_dict=None,
                                 producer_op_list=None,
                                 name=model_name + 'g1')
    # Generator output shape without its leading (batch) dimension; this is
    # what sizes the discriminator placeholders below.
    self.image_shape = go.shape[1:].as_list()

# Graph 2: import the model again, this time with the discriminator input
# re-routed to a fresh [None] + image_shape placeholder.
with tf.Graph().as_default() as g2:
    di = tf.placeholder(tf.float32,
                        [None] + self.image_shape,
                        name='g2_di')
    do, = tf.import_graph_def(self.graph_def,
                              input_map={disc_input: di},
                              return_elements=[disc_output],
                              op_dict=None,
                              producer_op_list=None,
                              name=model_name + 'g2')

gdef_1 = g1.as_graph_def()
gdef_2 = g2.as_graph_def()

# Merge both graphs into one, feeding each through its own placeholder.
with tf.Graph().as_default() as g_combined:
    # Distinct op names for the two inputs. The generator placeholder was
    # previously also named 'comb_di' (a copy-paste slip), which TF silently
    # renamed to 'comb_di_1'.
    self.gi = tf.placeholder(tf.float32,
                             [None, self.z_dim],
                             name='comb_gi')
    self.di = tf.placeholder(tf.float32,
                             [None] + self.image_shape,
                             name='comb_di')

    # Re-import graph 1 under the 'g1_comb' scope, binding its 'g1_gi' input
    # to the combined graph's generator placeholder.
    tf.import_graph_def(gdef_1, input_map={'g1_gi': self.gi}, name='g1_comb')
    self.go = g_combined.get_tensor_by_name(
        'g1_comb/' + model_name + 'g1' + '/' + gen_output)

    # Re-import graph 2 under the 'g2_comb' scope, binding its 'g2_di' input
    # to the combined graph's discriminator placeholder.
    tf.import_graph_def(gdef_2, input_map={'g2_di': self.di}, name='g2_comb')
    self.do = g_combined.get_tensor_by_name(
        'g2_comb/' + model_name + 'g2' + '/' + disc_output)
self.graph = g_combined
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment