Last active
March 2, 2022 05:26
-
-
Save Unbinilium/2d08fc3984c4177e5ddaf0338ff14861 to your computer and use it in GitHub Desktop.
MNIST ONNX inference using OpenCV DNN
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma once

#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <opencv2/core.hpp>
#include <opencv2/dnn.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/imgproc.hpp>
namespace dnn { | |
template <typename T> | |
class mnist_onnx { | |
public: | |
mnist_onnx( | |
const std::string& model_path, | |
const std::vector<T>& labels, | |
int backend_id = cv::dnn::DNN_BACKEND_OPENCV, | |
int target_id = cv::dnn::DNN_TARGET_CPU | |
) : _model_path(model_path), _labels(labels) { | |
_net = cv::dnn::readNetFromONNX(_model_path); | |
_net.setPreferableBackend(backend_id); | |
_net.setPreferableTarget(target_id); | |
_gray = cv::Mat(cv::Size(28, 28), CV_8UC1); | |
} | |
// input image should be cv_8uc3 | |
std::pair<T, double> inferring(const cv::Mat& image) noexcept { | |
cv::cvtColor(image, _gray, cv::COLOR_BGR2GRAY); | |
auto blob = cv::dnn::blobFromImage(_gray, 1.f, cv::Size(28, 28), cv::Scalar(0, 0, 0), false, false, CV_32F); | |
_net.setInput(blob); | |
auto prob = _net.forward(); | |
cv::Mat softmax_prob; | |
float max_prob = 0.f; | |
float sum = 0.f; | |
double confidence = 0.f; | |
max_prob = *std::max_element(prob.begin<float>(), prob.end<float>()); | |
cv::exp(prob - max_prob, softmax_prob); | |
sum = cv::sum(softmax_prob)[0]; | |
softmax_prob /= sum; | |
cv::minMaxLoc(softmax_prob.reshape(1, 1), 0, &confidence, 0, &_class_id); | |
return std::pair<T, double>(_labels.at(_class_id.x), confidence); | |
} | |
private: | |
const std::string _model_path; | |
const std::vector<T> _labels; | |
cv::dnn::Net _net; | |
cv::Mat _gray; | |
cv::Point _class_id; | |
}; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment