Last active: November 20, 2018 11:55
-
-
Save kice/8ca12acaed41b06ba3d096ee3c63e143 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <assert.h>
#include <cstdio>
#include <cstring>
#include <exception>
#include <map>
#include <sstream>
#include <string>
#include <vector>
#include <opencv/cv.hpp>
// Splits `str` on every occurrence of `delim` and returns the pieces in order.
//
// Semantics follow std::getline:
//   - consecutive delimiters yield empty pieces ("a::b" -> {"a", "", "b"});
//   - a trailing delimiter does NOT add a trailing empty piece ("a:" -> {"a"});
//   - an empty input yields an empty vector.
//
// The string is taken by const reference. The previous version took a mutable
// reference it never modified, which rejected const strings and temporaries.
// Template parameter names avoid the reserved `_Uppercase` identifier form.
template<class CharT, class Traits, class Alloc>
inline std::vector<std::basic_string<CharT, Traits, Alloc>> split(
    const std::basic_string<CharT, Traits, Alloc>& str,
    const CharT delim)
{
    using string_type = std::basic_string<CharT, Traits, Alloc>;

    std::vector<string_type> elems;
    std::basic_istringstream<CharT, Traits, Alloc> in(str);  // read-only stream
    string_type item;
    while (std::getline(in, item, delim)) {
        elems.push_back(item);
    }
    return elems;
}
#include <mxnet-cpp/MxNetCpp.h> | |
using namespace mxnet::cpp; | |
int main() | |
{ | |
cv::Mat img = cv::imread("2631_x2_HR.png"); | |
img.convertTo(img, CV_32FC3, 1 / 255.); | |
std::vector<cv::Mat> rgb(3); | |
cv::split(img, rgb); | |
size_t size = img.rows * img.cols; | |
mx_float *data = new mx_float[size * img.channels()]; | |
memcpy(data, rgb[2].data, size); | |
memcpy(data + size, rgb[1].data, size); | |
memcpy(data + size*2, rgb[0].data, size); | |
auto ctx = Context::gpu(1); | |
Symbol net = Symbol::Load("int8-symbol.json"); | |
std::map<std::string, NDArray> params; | |
NDArray::Load("int8-0000.params", nullptr, ¶ms); | |
std::map<std::string, NDArray> _arg_map; | |
std::map<std::string, NDArray> _aux_map; | |
NDArray ndata = NDArray(data, | |
{ 1, index_t(img.channels()), index_t(img.rows), index_t(img.cols) }, ctx); | |
_arg_map["data"] = ndata; | |
for (auto &v : params) { | |
auto name = v.first; | |
auto data = v.second; | |
auto l = split(name, ':'); | |
if (l[0] == "arg") { | |
_arg_map[l[1]] = data; | |
} else if (l[0] == "aux") { | |
_aux_map[l[1]] = data; | |
} | |
} | |
try { | |
std::vector<const char *> map_keys; | |
std::vector<int> dev_types, dev_ids; | |
ExecutorHandle *shared_exec_handle = nullptr; | |
std::vector<const char *> arg_shape_names = { "data" }; | |
std::vector<mx_uint> input_shape_indptr = { 0, 4 }; | |
std::vector<mx_uint> input_shape_data = | |
{ | |
1, | |
static_cast<mx_uint>(img.channels()), | |
static_cast<mx_uint>(img.rows), | |
static_cast<mx_uint>(img.cols) | |
}; | |
std::vector<const char *> arg_dtypes_name = { "data" }; | |
std::vector<int> arg_dtypes = { 0 }; | |
std::vector<const char *> arg_stype_names; | |
std::vector<int> arg_stypes{}; | |
mx_uint size = 0; | |
const char **sarr = nullptr; | |
MXSymbolListArguments(net.GetHandle(), &size, &sarr); | |
if (size == 0) { | |
throw dmlc::Error("MXSymbolListArguments(net.GetHandle(), &size, &sarr); Error"); | |
} | |
std::vector<std::string> arg_names; | |
for (int i = 0; i < size; ++i) { | |
arg_names.push_back(sarr[i]); | |
} | |
++sarr; | |
--size; | |
std::vector<int> shared_buffer_len = { 0 }; | |
std::vector<const char *> shared_buffer_name_list = { nullptr }; | |
std::vector<NDArrayHandle> shared_buffer_handle_list = { nullptr }; | |
const char **updated_shared_buffer_name_list = nullptr; | |
NDArrayHandle *updated_shared_buffer_handles = nullptr; | |
mx_uint num_in_args = 0; | |
NDArrayHandle *in_args = nullptr, *arg_grads = nullptr; | |
mx_uint num_aux_states = 0; | |
NDArrayHandle *aux_states = nullptr; | |
ExecutorHandle handle = 0; | |
assert(MXExecutorSimpleBind( | |
net.GetHandle(), ctx.GetDeviceType(), ctx.GetDeviceId(), | |
0, map_keys.data(), dev_types.data(), dev_ids.data(), | |
0, nullptr, nullptr, | |
arg_shape_names.size(), arg_shape_names.data(), input_shape_data.data(), input_shape_indptr.data(), | |
arg_dtypes_name.size(), arg_dtypes_name.data(), arg_dtypes.data(), | |
0, arg_stype_names.data(), arg_stypes.data(), | |
size, sarr, | |
shared_buffer_len.data(), shared_buffer_name_list.data(), shared_buffer_handle_list.data(), | |
&updated_shared_buffer_name_list, &updated_shared_buffer_handles, | |
&num_in_args, &in_args, &arg_grads, | |
&num_aux_states, &aux_states, | |
nullptr, &handle | |
) == 0); | |
assert(handle != nullptr); | |
for (int i = 0; i < num_in_args; ++i) { | |
NDArray arg_dst = NDArray(in_args[i]); | |
NDArray arg_param = _arg_map[arg_names[i]]; | |
auto shape1 = arg_dst.GetShape(); | |
auto shape2 = arg_param.GetShape(); | |
assert(shape1 == shape2); | |
auto dtype1 = arg_dst.GetDType(); | |
auto dtype2 = arg_param.GetDType(); | |
assert(dtype1 == dtype2); | |
arg_param.CopyTo(&arg_dst); | |
} | |
auto aux_names = net.ListAuxiliaryStates(); | |
for (int i = 0; i < num_aux_states; ++i) { | |
NDArray aux_dst = NDArray(in_args[i]); | |
NDArray aux_param = _aux_map[aux_names[i]]; | |
auto shape1 = aux_dst.GetShape(); | |
auto shape2 = aux_param.GetShape(); | |
assert(shape1 == shape2); | |
auto dtype1 = aux_dst.GetDType(); | |
auto dtype2 = aux_param.GetDType(); | |
assert(dtype1 == dtype2); | |
aux_param.CopyTo(&aux_dst); | |
} | |
assert(MXExecutorForward(handle, false) == 0); | |
NDArray::WaitAll(); | |
NDArrayHandle *out; | |
mx_uint out_size; | |
assert(MXExecutorOutputs(handle, &out_size, &out) == 0); | |
NDArray res = NDArray(out[0]); | |
mx_float *res_data = new mx_float[res.Size()]; | |
memset(res_data, 0, res.Size() * sizeof(mx_float)); | |
res.SyncCopyToCPU(res_data); | |
MXExecutorFree(handle); | |
} catch (const dmlc::Error& err) { | |
printf("%s\n", MXGetLastError()); | |
} | |
return 0; | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.