Skip to content

Instantly share code, notes, and snippets.

@lmatt-bit
Created March 1, 2017 08:14
Show Gist options
  • Save lmatt-bit/726401a1bf2733ed60df25fb8864b8eb to your computer and use it in GitHub Desktop.
Save lmatt-bit/726401a1bf2733ed60df25fb8864b8eb to your computer and use it in GitHub Desktop.
TensorFlow: save a model in Python, then load its graph and weights in C++
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session.h"
using namespace tensorflow;
int main(int argc, char* argv[]) {
// Initialize a tensorflow session
Session* session;
Status status = NewSession(SessionOptions(), &session);
if (!status.ok()) {
std::cout << status.ToString() << "\n";
return 1;
}
GraphDef graph_def;
std::string graph_path = std::string(argv[1]) + ".pb";
status = ReadBinaryProto(Env::Default(), graph_path.c_str(), &graph_def);
std::cout << "0" << std::endl;
if (!status.ok()) {
std::cout << status.ToString() << "\n";
return 1;
}
// Add the graph to the session
std::cout << "1" << std::endl;
status = session->Create(graph_def);
if (!status.ok()) {
std::cout << status.ToString() << "\n";
return 1;
}
Tensor fpath(DT_STRING, TensorShape({1, 1}));
//fpath.matrix<std::string>()(0,0) = "/root/models/model.ckpt";
fpath.matrix<std::string>()(0,0) = argv[1];
std::vector<std::pair<string, tensorflow::Tensor>> inputs = {
//{"save/Const:0", fpath},
{"save/Const", fpath},
};
status = session->Run(inputs, {}, {"save/restore_all"}, nullptr);
std::cout << "2" << std::endl;
if (!status.ok()) {
std::cout << status.ToString() << "\n";
return 1;
}
std::vector<tensorflow::Tensor> outputs;
status = session->Run({}, {"c"}, {}, &outputs);
std::cout << "3" << std::endl;
if (!status.ok()) {
std::cout << status.ToString() << "\n";
return 1;
}
// Grab the first output (we only evaluated one graph node: "c")
// and convert the node to a scalar representation.
auto output_c = outputs[0].scalar<float>();
// (There are similar methods for vectors and matrices here:
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/public/tensor.h)
// Print the results
std::cout << "4" << std::endl;
std::cout << outputs[0].DebugString() << "\n"; // Tensor<type: float shape: [] values: 30>
std::cout << output_c() << "\n"; // 30
// Free any resources used by the session
session->Close();
return 0;
}
def save_model(self, model_path):
    """Save the session's variables and graph for loading from C++.

    Writes:
      * the checkpoint produced by ``tf.train.Saver.save`` at ``model_path``,
        plus a MetaGraph with suffix ``meta``;
      * ``<basename(model_path)>.pb`` in ``dirname(model_path)``: the binary
        GraphDef of ``self.sess``.

    Also prints a header line followed by the saver's filename tensor name,
    restore op name and save tensor name. A C++ loader must feed/run exactly
    these names.

    Args:
        model_path: checkpoint path prefix, e.g. "/root/models/model.ckpt".
    """
    saver = tf.train.Saver()
    saver.save(self.sess, model_path, meta_graph_suffix="meta", write_meta_graph=True)
    saver_def = saver.as_saver_def()
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    # The header lists the three values in the order they are printed below.
    print("saver_def: filename_tensor_name, restore_op_name, save_tensor_name")
    print(saver_def.filename_tensor_name)
    print(saver_def.restore_op_name)
    print(saver_def.save_tensor_name)
    model_dir = os.path.dirname(model_path)
    model_name = os.path.basename(model_path)
    # Must run after tf.train.Saver() above: the saver's "save/..." ops have
    # to be in the graph when it is written. Otherwise loading fails with:
    #   Not found: FeedInputs: unable to find feed output save/Const:0
    tf.train.write_graph(self.sess.graph_def, model_dir, model_name + ".pb", as_text=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment