Created
March 1, 2017 08:14
-
-
Save lmatt-bit/726401a1bf2733ed60df25fb8864b8eb to your computer and use it in GitHub Desktop.
TensorFlow: save a model in Python, then load the model graph and weights in C++
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/public/session.h"
#include "tensorflow/core/platform/env.h"
using namespace tensorflow;
int main(int argc, char* argv[]) { | |
// Initialize a tensorflow session | |
Session* session; | |
Status status = NewSession(SessionOptions(), &session); | |
if (!status.ok()) { | |
std::cout << status.ToString() << "\n"; | |
return 1; | |
} | |
GraphDef graph_def; | |
std::string graph_path = std::string(argv[1]) + ".pb"; | |
status = ReadBinaryProto(Env::Default(), graph_path.c_str(), &graph_def); | |
std::cout << "0" << std::endl; | |
if (!status.ok()) { | |
std::cout << status.ToString() << "\n"; | |
return 1; | |
} | |
// Add the graph to the session | |
std::cout << "1" << std::endl; | |
status = session->Create(graph_def); | |
if (!status.ok()) { | |
std::cout << status.ToString() << "\n"; | |
return 1; | |
} | |
Tensor fpath(DT_STRING, TensorShape({1, 1})); | |
//fpath.matrix<std::string>()(0,0) = "/root/models/model.ckpt"; | |
fpath.matrix<std::string>()(0,0) = argv[1]; | |
std::vector<std::pair<string, tensorflow::Tensor>> inputs = { | |
//{"save/Const:0", fpath}, | |
{"save/Const", fpath}, | |
}; | |
status = session->Run(inputs, {}, {"save/restore_all"}, nullptr); | |
std::cout << "2" << std::endl; | |
if (!status.ok()) { | |
std::cout << status.ToString() << "\n"; | |
return 1; | |
} | |
std::vector<tensorflow::Tensor> outputs; | |
status = session->Run({}, {"c"}, {}, &outputs); | |
std::cout << "3" << std::endl; | |
if (!status.ok()) { | |
std::cout << status.ToString() << "\n"; | |
return 1; | |
} | |
// Grab the first output (we only evaluated one graph node: "c") | |
// and convert the node to a scalar representation. | |
auto output_c = outputs[0].scalar<float>(); | |
// (There are similar methods for vectors and matrices here: | |
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/public/tensor.h) | |
// Print the results | |
std::cout << "4" << std::endl; | |
std::cout << outputs[0].DebugString() << "\n"; // Tensor<type: float shape: [] values: 30> | |
std::cout << output_c() << "\n"; // 30 | |
// Free any resources used by the session | |
session->Close(); | |
return 0; | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def save_model(self, model_path):
    """Save the session's variables and graph so a C++ program can load them.

    Writes a checkpoint (plus a ``.meta`` meta-graph) with prefix
    ``model_path`` and a binary GraphDef at ``model_path + ".pb"``. The C++
    loader reads the ``.pb`` file, then restores the variables by feeding
    ``model_path`` to the Saver's filename tensor and running its restore op.
    The names of those tensors/ops are printed here for reference.

    Args:
        model_path: checkpoint path prefix, e.g. "/root/models/model.ckpt".
            A bare file name (no directory part) writes into the current
            directory.
    """
    import os  # local import: keeps this helper self-contained

    saver = tf.train.Saver()
    saver.save(self.sess, model_path, meta_graph_suffix="meta", write_meta_graph=True)

    # Report the names the C++ side needs in order to restore variables.
    saver_def = saver.as_saver_def()
    print("saver_def: filename_tensor_name, restore_op_name, save_tensor_name")
    print(saver_def.filename_tensor_name)
    print(saver_def.restore_op_name)
    print(saver_def.save_tensor_name)

    # os.path.dirname() returns "" for a bare file name; use "." instead so
    # write_graph receives a real directory.
    model_dir = os.path.dirname(model_path) or "."
    model_name = os.path.basename(model_path)
    # This must run after tf.train.Saver() is constructed: the Saver adds the
    # save/* nodes to the graph. Writing the graph earlier omits them, and the
    # C++ loader then fails with:
    #   Not found: FeedInputs: unable to find feed output save/Const:0
    tf.train.write_graph(self.sess.graph_def, model_dir, model_name + ".pb", as_text=False)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment