Created
October 28, 2025 05:12
-
-
Save harieamjari/d661fa6b05ec2612089c127aeb5fdf03 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <array>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

#define cimg_display 3
#define cimg_use_png 1
#include <CImg.h>
#include <onnxruntime_cxx_api.h>
| int main(int argc, char *argv[]){ | |
| std::vector<std::string> args; | |
| if (argc == 1){ | |
| std::cout << "no input file" << std::endl; | |
| return 0; | |
| } | |
| for (int i = 0; i < argc; i++) { | |
| args.push_back(std::string(argv[i])); | |
| std::cout << args[i] << std::endl; | |
| } | |
| cimg_library::CImg<float> input_image(std::string(args[1]).c_str()); | |
| std::cout << " read " << args[1] << ", spectrum " << input_image.spectrum() << std::endl; | |
| if (input_image.spectrum() != 4) | |
| input_image.resize(input_image.width(), input_image.height(), 1, 4); | |
| std::cout << " read new " << args[1] << ", spectrum " << input_image.spectrum() << std::endl; | |
| cimg_library::CImg<float> output_image = input_image; | |
| input_image.normalize(0.0,1.0); | |
| // Setup onnx | |
| Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "rembg"); | |
| Ort::SessionOptions session_options; | |
| session_options.SetIntraOpNumThreads(1); | |
| Ort::Session session(env, ORT_TSTR("u2net_human_seg.onnx"), session_options); | |
| Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU); | |
| Ort::AllocatorWithDefaultOptions allocator = Ort::AllocatorWithDefaultOptions(); | |
| std::vector<Ort::AllocatedStringPtr> input_names, output_names; | |
| std::vector<int64_t> input_shape = session.GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape(); | |
| std::vector<int64_t> output_shape = session.GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape(); | |
| for (size_t i = 0; i < session.GetInputCount(); i++) { | |
| input_names.push_back(session.GetInputNameAllocated(i, allocator)); | |
| std::cout << "input " << i << " " << input_names[i].get() << std::endl; | |
| for (int64_t dim: session.GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape()) | |
| std::cout << " " << dim << std::endl; | |
| } | |
| for (size_t i = 0; i < session.GetOutputCount(); i++) { | |
| output_names.push_back(session.GetOutputNameAllocated(i, allocator)); | |
| std::cout << "output " << i << " " << output_names[i].get() << std::endl; | |
| for (int64_t dim: session.GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape()) | |
| std::cout << " " << dim << std::endl; | |
| } | |
| input_image.resize(input_shape[2], input_shape[3]); | |
| Ort::Value input_tensor; | |
| if (input_image.spectrum() == 4) | |
| input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_image.data(), input_image.size()-320*320, input_shape.data(), input_shape.size()); | |
| else if (input_image.spectrum() == 3) | |
| input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_image.data(), input_image.size(), input_shape.data(), input_shape.size()); | |
| else { | |
| std::cout << "input image: invalid spectrum " << input_image.spectrum() << std::endl; | |
| return 0; | |
| } | |
| std::vector<float> output_data(std::accumulate(std::begin(output_shape), std::end(output_shape), 1.0, std::multiplies<float>()), 1.0); | |
| std::cout << "created output_data.size " << output_data.size() << std::endl; | |
| Ort::Value output_tensor = Ort::Value::CreateTensor<float>(memory_info, output_data.data(), output_data.size(), output_shape.data(), output_shape.size()); | |
| Ort::RunOptions run_options; | |
| std::array<const char *, 1> input_names_ = {input_names[0].get()}; | |
| std::array<const char *, 1> output_names_ = {output_names[0].get()}; | |
| //input_image.display(); | |
| session.Run(run_options, input_names_.data(), &input_tensor, 1, output_names_.data(), &output_tensor, 1); | |
| // memcpy(input_image.data(0,0,0,3), output_data.data(), output_data.size()*sizeof(float)); | |
| for (size_t y = 0, i = 0; y < input_image.height(); y++) | |
| for (size_t x = 0; x < input_image.width(); x++) | |
| input_image(x, y, 0, 3) = output_data[i++]; | |
| input_image.normalize(0.0,255.0).resize(output_image.width(), output_image.height()); | |
| for (size_t y = 0; y < input_image.height(); y++) | |
| for (size_t x = 0; x < input_image.width(); x++) | |
| output_image(x, y, 0, 3) = input_image(x, y, 0, 3); | |
| output_image.save("o.png"); | |
| return 0; | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment