@asimshankar
Last active November 13, 2024 20:16
Training TensorFlow models in C++

Python is the primary language in which TensorFlow models are typically developed and trained. TensorFlow does have bindings for other programming languages. These bindings provide the low-level primitives required to build a more complete API, but they lack much of the higher-level richness of the Python bindings, particularly for defining the model structure.

This file demonstrates taking a model (a TensorFlow graph) created by a Python program and running the training loop in C++.

The model

The model is a trivial one, trying to learn the function f(x) = W * x + b, where W and b are model parameters. The training data is constructed so that the "true" value of W is 3 and that of b is 2, i.e., f(x) = 3 * x + 2.
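To make the training objective concrete, here is a minimal sketch in plain Python (no TensorFlow) of what a gradient-descent train step does for this kind of model. It assumes a mean-squared-error loss and inputs drawn uniformly from [0, 1). The function name `train`, the zero starting weights, and the seed are illustrative choices and not part of the gist's files; TensorFlow picks its own initial weights.

```python
import random


def train(steps, learning_rate, batch_size=10, seed=0):
    """Fit f(x) = w*x + b to samples of y = 3*x + 2 with gradient descent."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0  # assumed starting point, chosen only for illustration
    for _ in range(steps):
        # Build one batch of training data from the "true" function.
        xs = [rng.random() for _ in range(batch_size)]
        ys = [3 * x + 2 for x in xs]
        errors = [w * x + b - y for x, y in zip(xs, ys)]
        # Gradients of mean((w*x + b - y)^2) with respect to w and b.
        grad_w = 2 * sum(e * x for e, x in zip(errors, xs)) / batch_size
        grad_b = 2 * sum(errors) / batch_size
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    return w, b


if __name__ == "__main__":
    w, b = train(steps=5000, learning_rate=0.1)
    print("learned w =", w, "b =", b)
    for x in (1.0, 2.0, 3.0):
        print("x =", x, "predicted y =", w * x + b)
```

For inputs 1, 2, and 3 the true function gives 5, 8, and 11, so predictions should approach those values as training progresses. A small learning rate combined with only a few steps will usually stop short of that.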

Files

  • model.py: Python code that constructs a model and saves the computational graph in a file called graph.pb. All other files assume that model.py has been run once.
  • train.cc: C++ code that loads the model, optionally loads model weights saved in a checkpoint, trains for a few steps, and writes the updated model weights to a checkpoint.
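
The control flow described for train.cc (initialize or restore, train, then checkpoint) can be sketched without TensorFlow. The following is only an analogy in plain Python: a JSON file stands in for a TensorFlow checkpoint, and every function name here (`init_params`, `restore`, `checkpoint`, `train_step`, `run`) is hypothetical rather than taken from the gist.

```python
import json
import os
import random
import tempfile


def init_params():
    # Stand-in for running the graph's variable initializer.
    return {"w": 0.0, "b": 0.0}


def restore(prefix):
    # Stand-in for the restore operation: read weights saved earlier.
    with open(prefix + ".json") as f:
        return json.load(f)


def checkpoint(params, prefix):
    # Stand-in for the save operation: write weights under the prefix.
    os.makedirs(os.path.dirname(prefix), exist_ok=True)
    with open(prefix + ".json", "w") as f:
        json.dump(params, f)


def train_step(params, xs, ys, lr=0.1):
    # One gradient-descent step on mean((w*x + b - y)^2).
    n = len(xs)
    errs = [params["w"] * x + params["b"] - y for x, y in zip(xs, ys)]
    params["w"] -= lr * 2 * sum(e * x for e, x in zip(errs, xs)) / n
    params["b"] -= lr * 2 * sum(errs) / n


def run(checkpoint_dir, steps=200, seed=0):
    prefix = os.path.join(checkpoint_dir, "checkpoint")
    # Restore only if an earlier run already created the checkpoint directory.
    restored = os.path.isdir(checkpoint_dir)
    params = restore(prefix) if restored else init_params()
    rng = random.Random(seed)
    for _ in range(steps):
        xs = [rng.random() for _ in range(10)]
        train_step(params, xs, [3 * x + 2 for x in xs])
    checkpoint(params, prefix)
    return restored, params


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        ckpt_dir = os.path.join(tmp, "checkpoints")
        print(run(ckpt_dir))  # first run: initializes, trains, saves
        print(run(ckpt_dir))  # second run: restores and keeps training
```

Because the second run starts from the saved weights instead of from scratch, repeated runs continue to improve the same parameters, which is the point of writing a checkpoint at the end of each run.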

Noteworthy

  • The Python APIs for TensorFlow include other conveniences for training (such as MonitoredSession and tf.estimator.Estimator), which make it easier to configure checkpointing, evaluation loops, etc. The examples here aren't that sophisticated and are focused on basic model training only.
  • In this example, we use placeholders and feed dictionaries to feed input, but in a real example you probably want to use the tf.data API to construct an input pipeline for providing training data to the model.
  • Not demonstrated here, but summaries for TensorBoard can also be produced by executing the summary operations.

See Also

# model.py: builds the graph and writes it to graph.pb.
import tensorflow as tf
# Batch of input and target output (1x1 matrices)
x = tf.placeholder(tf.float32, shape=[None, 1, 1], name='input')
y = tf.placeholder(tf.float32, shape=[None, 1, 1], name='target')
# Trivial linear model
y_ = tf.identity(tf.layers.dense(x, 1), name='output')
# Optimize loss
loss = tf.reduce_mean(tf.square(y_ - y), name='loss')
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss, name='train')
init = tf.global_variables_initializer()
# tf.train.Saver.__init__ adds operations to the graph to save
# and restore variables.
saver_def = tf.train.Saver().as_saver_def()
print('Run this operation to initialize variables : ', init.name)
print('Run this operation for a train step : ', train_op.name)
print('Feed this tensor to set the checkpoint filename: ', saver_def.filename_tensor_name)
print('Run this operation to save a checkpoint : ', saver_def.save_tensor_name)
print('Run this operation to restore a checkpoint : ', saver_def.restore_op_name)
# Write the graph out to a file. SerializeToString() returns bytes, so the
# file is opened in binary mode.
with open('graph.pb', 'wb') as f:
    f.write(tf.get_default_graph().as_graph_def().SerializeToString())
// train.cc: example of training the model created by model.py in a C++ program.
//
// See also
// https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/examples/label_image/main.cc
#include <sys/stat.h>

#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Header paths follow the open-source TensorFlow source tree.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"

// Wraps a TensorFlow session around the graph written by model.py. The
// operation and tensor names used below ("init", "train", "input", "target",
// "output", "save/...") are expected to match the names printed by model.py.
class Model {
 public:
  // Loads the serialized GraphDef and creates a session for it.
  Model(const std::string& graph_def_filename) {
    tensorflow::GraphDef graph_def;
    TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
                                            graph_def_filename, &graph_def));
    session_.reset(tensorflow::NewSession(tensorflow::SessionOptions()));
    TF_CHECK_OK(session_->Create(graph_def));
  }

  // Runs the variable initializer.
  void Init() { TF_CHECK_OK(session_->Run({}, {}, {"init"}, nullptr)); }

  // Restores variables from the checkpoint with the given prefix.
  void Restore(const std::string& checkpoint_prefix) {
    SaveOrRestore(checkpoint_prefix, "save/restore_all");
  }

  // Prints the model's prediction for each input in the batch.
  void Predict(const std::vector<float>& batch) {
    std::vector<tensorflow::Tensor> out_tensors;
    TF_CHECK_OK(session_->Run({{"input", MakeTensor(batch)}}, {"output"}, {},
                              &out_tensors));
    for (int i = 0; i < batch.size(); ++i) {
      std::cout << "\t x = " << batch[i]
                << ", predicted y = " << out_tensors[0].flat<float>()(i)
                << "\n";
    }
  }

  // Runs a single training step on one batch of inputs and targets.
  void RunTrainStep(const std::vector<float>& input_batch,
                    const std::vector<float>& target_batch) {
    TF_CHECK_OK(session_->Run({{"input", MakeTensor(input_batch)},
                               {"target", MakeTensor(target_batch)}},
                              {}, {"train"}, nullptr));
  }

  // Writes the current variable values to a checkpoint.
  void Checkpoint(const std::string& checkpoint_prefix) {
    SaveOrRestore(checkpoint_prefix, "save/control_dependency");
  }

 private:
  // Packs a batch of floats into a tensor of shape [batch_size, 1, 1].
  tensorflow::Tensor MakeTensor(const std::vector<float>& batch) {
    tensorflow::Tensor t(tensorflow::DT_FLOAT,
                         tensorflow::TensorShape({(int)batch.size(), 1, 1}));
    for (int i = 0; i < batch.size(); ++i) {
      t.flat<float>()(i) = batch[i];
    }
    return t;
  }

  // Feeds the checkpoint prefix to the saver's filename tensor and runs the
  // requested save or restore operation.
  void SaveOrRestore(const std::string& checkpoint_prefix,
                     const std::string& op_name) {
    tensorflow::Tensor t(tensorflow::DT_STRING, tensorflow::TensorShape());
    t.scalar<std::string>()() = checkpoint_prefix;
    TF_CHECK_OK(session_->Run({{"save/Const", t}}, {}, {op_name}, nullptr));
  }

  std::unique_ptr<tensorflow::Session> session_;
};

// Returns true if the path exists and is a directory.
bool DirectoryExists(const std::string& dir) {
  struct stat buf;
  return stat(dir.c_str(), &buf) == 0 && S_ISDIR(buf.st_mode);
}

int main(int argc, char* argv[]) {
  // Hard-coded locations from the author's machine; adjust them to wherever
  // model.py wrote graph.pb and where checkpoints should be stored.
  const std::string graph_def_filename =
      "/usr/local/google/home/ashankar/tmp/gist/graph.pb";
  const std::string checkpoint_dir =
      "/usr/local/google/home/ashankar/tmp/gist/checkpoints";
  const std::string checkpoint_prefix = checkpoint_dir + "/checkpoint";
  // Restore from a checkpoint only if the checkpoint directory already exists.
  bool restore = DirectoryExists(checkpoint_dir);

  // Setup global state for TensorFlow.
  tensorflow::port::InitMain(argv[0], &argc, &argv);

  std::cout << "Loading graph\n";
  Model model(graph_def_filename);

  if (!restore) {
    std::cout << "Initializing model weights\n";
    model.Init();
  } else {
    std::cout << "Restoring model weights from checkpoint\n";
    model.Restore(checkpoint_prefix);
  }

  const std::vector<float> testdata({1.0, 2.0, 3.0});
  std::cout << "Initial predictions\n";
  model.Predict(testdata);

  std::cout << "Training for a few steps\n";
  for (int i = 0; i < 200; ++i) {
    // Each step uses a fresh batch of 10 random inputs in [0, 1] with
    // targets computed from the "true" function 3 * x + 2.
    std::vector<float> train_inputs, train_targets;
    for (int j = 0; j < 10; j++) {
      train_inputs.push_back(static_cast<float>(std::rand()) /
                             static_cast<float>(RAND_MAX));
      train_targets.push_back(3 * train_inputs.back() + 2);
    }
    model.RunTrainStep(train_inputs, train_targets);
  }

  std::cout << "Updated predictions\n";
  model.Predict(testdata);

  std::cout << "Saving checkpoint\n";
  model.Checkpoint(checkpoint_prefix);
  return 0;
}

@prantoran

What if I trained a model in Python and want to load it into a C++ application using Bazel? By the way, this gist was really helpful.

@cg-tester

How can I run this project?

@cg-tester

Does this include just 2 files, one Python file and one C++ file?
Can I run this project on my local machine?
Could anybody please tell me?
