Last active
July 13, 2020 12:31
-
-
Save arnaldog12/5c6494d20a4a5d7b23a01975b66811b1 to your computer and use it in GitHub Desktop.
TensorFlow in C++
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once | |
#ifndef TENSORFLOW_GRAPH_H | |
#define TENSORFLOW_GRAPH_H | |
#include "TensorflowUtils.h" | |
#include "TensorflowPlaceholder.h" | |
using namespace tensorflow; | |
using deeplearning::TensorflowUtils; | |
using deeplearning::TensorflowPlaceholder; | |
namespace deeplearning | |
{ | |
class TensorflowGraph | |
{ | |
private: | |
Session *session; | |
public: | |
TensorflowGraph(std::string metaFile, std::string checkpointFolder, SessionOptions options = SessionOptions()) | |
{ | |
MetaGraphDef graphDef = this->loadGraphFromMetaFile(metaFile); | |
this->session = this->createSession(graphDef.graph_def(), options); | |
loadCheckpoint(graphDef, checkpointFolder); | |
} | |
TensorflowGraph(std::string protobufFile, SessionOptions options = SessionOptions()) | |
{ | |
GraphDef graphDef = this->loadGraphFromProtobufFile(protobufFile); | |
this->session = this->createSession(graphDef, options); | |
} | |
TensorflowGraph(std::ostringstream& protobufFile, SessionOptions options = SessionOptions()) | |
{ | |
std::string decoded = protobufFile.str(); | |
GraphDef graphDef = this->loadGraphFromString(decoded); | |
this->session = this->createSession(graphDef, options); | |
} | |
~TensorflowGraph() | |
{ | |
tensorflow::Status status = this->session->Close(); | |
delete this->session; | |
} | |
std::vector<std::vector<cv::Mat>> run(TensorflowPlaceholder::tensorDict feedDict, std::vector<std::string> outputTensorNames, std::vector<std::string> targetNodeNames = {}) | |
{ | |
std::vector<Tensor> outputsTensor; | |
TF_CHECK_OK(session->Run(feedDict, outputTensorNames, targetNodeNames, &outputsTensor)); | |
return TensorflowUtils::tensor2mat(outputsTensor); | |
} | |
private: | |
MetaGraphDef loadGraphFromMetaFile(std::string metaFile) | |
{ | |
MetaGraphDef graphDef; | |
TF_CHECK_OK(ReadBinaryProto(Env::Default(), metaFile, &graphDef)); | |
return graphDef; | |
} | |
GraphDef loadGraphFromProtobufFile(std::string protobufFile) | |
{ | |
GraphDef graphDef; | |
TF_CHECK_OK(ReadBinaryProto(Env::Default(), protobufFile, &graphDef)); | |
return graphDef; | |
} | |
GraphDef loadGraphFromString(std::string protobufFile) | |
{ | |
GraphDef graphDef; | |
if (!graphDef.ParseFromString(protobufFile)) throw "Nao foi possible carregar o modelo do Tensorflow!"; | |
return graphDef; | |
} | |
Session* createSession(GraphDef graphDef, SessionOptions options) | |
{ | |
Session *session; | |
TF_CHECK_OK(NewSession(options, &session)); | |
TF_CHECK_OK(session->Create(graphDef)); | |
return session; | |
} | |
void loadCheckpoint(MetaGraphDef& graphDef, std::string checkpointFolder) | |
{ | |
Tensor checkpointPathTensor(DT_STRING, TensorShape()); | |
checkpointPathTensor.scalar<std::string>()() = checkpointFolder; | |
TF_CHECK_OK( | |
session->Run( | |
{ { graphDef.saver_def().filename_tensor_name(), checkpointPathTensor } }, | |
{}, | |
{ graphDef.saver_def().restore_op_name() }, | |
nullptr) | |
); | |
} | |
}; | |
} | |
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once | |
#ifndef TENSORFLOW_PLACEHOLDER_H | |
#define TENSORFLOW_PLACEHOLDER_H | |
using namespace tensorflow; | |
namespace deeplearning | |
{ | |
class TensorflowPlaceholder | |
{ | |
public: | |
typedef std::pair<std::string, Tensor> placeholderType; | |
typedef std::vector<placeholderType> tensorDict; | |
static placeholderType tensor(string key, Tensor t) | |
{ | |
return { key, t }; | |
} | |
static placeholderType boolean(string key, bool value) | |
{ | |
Tensor placeholder(DT_BOOL, TensorShape()); | |
placeholder.scalar<bool>()() = value; | |
return { key, placeholder }; | |
} | |
}; | |
} | |
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once | |
#ifndef TENSORFLOW_UTILS_H | |
#define TENSORFLOW_UTILS_H | |
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "opencv2/core/core.hpp"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
using namespace tensorflow;
// Layout selector for TensorflowUtils::mat2tensor.
// TENSOR_2D produces a flat [nImages, elementsPerImage] tensor; TENSOR_4D (mat2tensor's
// default, also used for any value other than TENSOR_2D) produces [nImages, rows, cols, channels].
// NOTE(review): the numeric values are not read anywhere visible; presumably they encode the
// index of the last tensor axis — confirm before relying on them.
enum TENSOR_SHAPE
{
    TENSOR_2D = 1,
    TENSOR_4D = 3,
};
namespace deeplearning | |
{ | |
class TensorflowUtils | |
{ | |
public: | |
template <class T> | |
static Tensor mat2tensor(cv::Mat image, tensorflow::DataType type = tensorflow::DT_FLOAT, TENSOR_SHAPE shape = TENSOR_4D, int nImages = 1) | |
{ | |
T *imageData = (T *)image.data; | |
TensorShape imageShape; | |
switch (shape) | |
{ | |
case TENSOR_2D: imageShape = TensorShape{ nImages, image.rows * image.cols * image.channels() / nImages }; break; | |
default: imageShape = TensorShape{ nImages, image.rows / nImages, image.cols, image.channels() }; break; | |
} | |
Tensor imageTensor = Tensor(type, imageShape); | |
std::copy_n((char *)imageData, imageShape.num_elements() * sizeof(T), const_cast<char *>(imageTensor.tensor_data().data())); | |
return imageTensor; | |
} | |
template <class T> | |
static Tensor mat2tensor(std::vector<cv::Mat> images, tensorflow::DataType type = tensorflow::DT_FLOAT, TENSOR_SHAPE shape = TENSOR_4D) | |
{ | |
cv::Mat imagesConcat; | |
cv::vconcat(images, imagesConcat); | |
return mat2tensor<T>(imagesConcat, type, shape, images.size()); | |
} | |
static std::vector<cv::Mat> tensor2mat(Tensor tensor) | |
{ | |
TensorShape shape = tensor.shape(); | |
int nDims = shape.dims(); | |
int nImages = shape.dim_size(0); | |
int width = nDims > 2 ? shape.dim_size(2) : (nDims > 1 ? shape.dim_size(1) : shape.dim_size(0)); | |
int height = nDims > 2 ? shape.dim_size(1) : 1; | |
int channels = (nDims == 4) ? shape.dim_size(3) : 1; | |
std::vector<cv::Mat> result; | |
for (int i = 0; i < nImages; i++) | |
{ | |
Tensor slice = tensor.Slice(i, i + 1); | |
assert(slice.IsAligned() == true); | |
float *outputData = slice.flat<float>().data(); | |
cv::Mat imgOut(cv::Size(width, height), CV_32FC(channels)); | |
std::copy_n((char*)outputData, slice.shape().num_elements() * sizeof(float), (char*)imgOut.data); | |
result.push_back(imgOut); | |
} | |
return result; | |
} | |
static std::vector<std::vector<cv::Mat>> tensor2mat(std::vector<Tensor> tensors) | |
{ | |
std::vector<std::vector<cv::Mat>> results; | |
for (Tensor t : tensors) | |
results.push_back(tensor2mat(t)); | |
return results; | |
} | |
}; | |
} | |
#endif |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment