# This file is useful for reading the contents of the ops generated by ruby.
# You can read any graph definition in pb/pbtxt format generated by ruby
# or by python and then convert it back and forth from human readable to binary format.
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.platform import gfile

def pbtxt_to_graphdef(filename):
    # Parse the text-format GraphDef and write it back out in binary form.
    with open(filename, 'r') as f:
        graph_def = tf.GraphDef()
        file_content = f.read()
        text_format.Merge(file_content, graph_def)
        tf.import_graph_def(graph_def, name='')
        tf.train.write_graph(graph_def, 'pbtxt/', 'protobuf.pb', as_text=False)

def graphdef_to_pbtxt(filename):
    # Parse the binary GraphDef and write it back out in human-readable text form.
    with gfile.FastGFile(filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
        tf.train.write_graph(graph_def, 'pbtxt/', 'protobuf.pbtxt', as_text=True)
    return

graphdef_to_pbtxt('graph.pb')  # here you can write the name of the file to be converted
                               # and then a new file will be made in the pbtxt directory.
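A note for anyone on TensorFlow 2.x: the GraphDef class is no longer exposed at the top level, but the same round trip appears to work through the tf.compat.v1 and tf.io namespaces. A minimal sketch under that assumption (the _v2 function names and file paths are just illustrative):

import tensorflow as tf
from google.protobuf import text_format

def graphdef_to_pbtxt_v2(filename):
    # Read a binary GraphDef and re-serialize it as human-readable pbtxt.
    graph_def = tf.compat.v1.GraphDef()
    with open(filename, 'rb') as f:
        graph_def.ParseFromString(f.read())
    tf.io.write_graph(graph_def, 'pbtxt/', 'protobuf.pbtxt', as_text=True)

def pbtxt_to_graphdef_v2(filename):
    # Read a pbtxt GraphDef and re-serialize it back to the binary format.
    graph_def = tf.compat.v1.GraphDef()
    with open(filename, 'r') as f:
        text_format.Merge(f.read(), graph_def)
    tf.io.write_graph(graph_def, 'pbtxt/', 'protobuf.pb', as_text=False)

graphdef_to_pbtxt_v2('graph.pb')  # expects a plain GraphDef, not a SavedModel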
The SavedModel protobuf message is not a GraphDef, hence your error. There definitely is a way to work with protobuf messages directly without having to actually interpret them as valid TensorFlow objects - at least for converting between binary and text formats. I can't remember right now which module it was; I'll try to have a look and post a further comment if I find it.
Any solution that will lead me from saved_model.pb -> saved_model.pbtxt -> saved_model.pb, or any other way of amending the saved_model.pb, is welcome. Thanks for your help, I appreciate it!
Ok, I think I actually found it. Look up the google.protobuf module and/or see if you find a "saved_model_pb2" file you could import (that would be the generated Python wrapper for the SavedModel message definition, via which I think it should be possible to load the file and convert it between the text/binary formats).
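To make that concrete, here is a minimal sketch of the binary-to-text direction using that wrapper (it ships with TensorFlow under tensorflow.core.protobuf; the file names are just placeholders):

from google.protobuf import text_format
from tensorflow.core.protobuf import saved_model_pb2

# Parse the binary SavedModel message (not a GraphDef).
sm = saved_model_pb2.SavedModel()
with open('saved_model.pb', 'rb') as f:
    sm.ParseFromString(f.read())

# Dump it as human-readable protobuf text.
with open('saved_model.pbtxt', 'w') as f:
    f.write(text_format.MessageToString(sm))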
Could you please provide a minimal example of how to check whether the google.protobuf module has a "saved_model_pb2" and how to load this, in order to save it back to .pbtxt? Sorry, but even the terminology google.protobuf is vague to me...
I tried to implement the same, but I am getting the following error:

in graphdef_to_pbtxt(filename)
      7     with open(filename, 'rb') as f:
      8         graph_def = tf.compat.v1.GraphDef()
----> 9         graph_def.ParseFromString(f.read())
     10     with open('protobuf.txt', 'w') as fp:
     11         fp.write(str(graph_def))

DecodeError: Error parsing message

The code I used is as follows:
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.platform import gfile

def graphdef_to_pbtxt(filename):
    with open(filename, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    with open('protobuf.txt', 'w') as fp:
        fp.write(str(graph_def))

graphdef_to_pbtxt('saved_model.pb')
Can anybody help me on this?
This function, obtained from here, will do the trick:
import tensorflow as tf
import sys
from tensorflow.python.platform import gfile
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat

model_filename = 'saved_model.pb'
with gfile.FastGFile(model_filename, 'rb') as f:
    data = compat.as_bytes(f.read())
    sm = saved_model_pb2.SavedModel()
    sm.ParseFromString(data)
    g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)
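If only the graph definition is needed in text form, the graph pulled out of the SavedModel can presumably be dumped with the same helper used in the gist above; a one-line sketch continuing from that snippet (TF 1.x-style API; the output file name is just an example):

tf.train.write_graph(sm.meta_graphs[0].graph_def, 'pbtxt/', 'saved_model_graph.pbtxt', as_text=True)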
Thanks for your fast reply. So, let me explain my actual goal. I am currently re-training a new custom BERT from scratch (https://github.com/google-research/bert). BERT also comes as a TensorFlow Hub module (https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/1), which makes it really easy to load this model in both TensorFlow and Keras.

As you may speculate, the module contains a SavedModel, which means it provides the trained variables and the saved_model.pb. The saved_model.pb contains the actual graph definition, which in the case of Google's default BERT has a vocabulary of 30522, which actually means the graph includes such specifications. In order to be able to load my own custom BERT with a different vocab and the same specifications for the rest of the layers, I need to amend saved_model.pb, so I used this piece of code:

Given the saved_model.pbtxt, I can now actually see the specifications of the model in raw text format and amend all the definitions of 30522 to 32000 in my case. Now, I need to do the reverse part and wrap it up again as a .pb file, because this is what TensorFlow Hub reads... I still struggle with this step and that's how I found your code :)

I would appreciate any help in order to find a working solution!
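For the reverse step, one approach that should work in principle is to parse the edited text file back into a SavedModel message with text_format and serialize it to binary again. A minimal sketch, assuming the edited file is called saved_model.pbtxt and the output simply replaces saved_model.pb (paths are placeholders; whether TensorFlow Hub accepts the rewritten file is untested here):

from google.protobuf import text_format
from tensorflow.core.protobuf import saved_model_pb2

# Parse the hand-edited text format back into a SavedModel message.
sm = saved_model_pb2.SavedModel()
with open('saved_model.pbtxt', 'r') as f:
    text_format.Merge(f.read(), sm)

# Serialize it back to the binary wire format that TF Hub expects.
with open('saved_model.pb', 'wb') as f:
    f.write(sm.SerializeToString())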