// Source: GitHub Gist d661fa6b05ec2612089c127aeb5fdf03 by @harieamjari,
// created October 28, 2025.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>
#define cimg_display 3
#define cimg_use_png 1
#include <CImg.h>
#include <onnxruntime_cxx_api.h>
int main(int argc, char *argv[]){
std::vector<std::string> args;
if (argc == 1){
std::cout << "no input file" << std::endl;
return 0;
}
for (int i = 0; i < argc; i++) {
args.push_back(std::string(argv[i]));
std::cout << args[i] << std::endl;
}
cimg_library::CImg<float> input_image(std::string(args[1]).c_str());
std::cout << " read " << args[1] << ", spectrum " << input_image.spectrum() << std::endl;
if (input_image.spectrum() != 4)
input_image.resize(input_image.width(), input_image.height(), 1, 4);
std::cout << " read new " << args[1] << ", spectrum " << input_image.spectrum() << std::endl;
cimg_library::CImg<float> output_image = input_image;
input_image.normalize(0.0,1.0);
// Setup onnx
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "rembg");
Ort::SessionOptions session_options;
session_options.SetIntraOpNumThreads(1);
Ort::Session session(env, ORT_TSTR("u2net_human_seg.onnx"), session_options);
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeCPU);
Ort::AllocatorWithDefaultOptions allocator = Ort::AllocatorWithDefaultOptions();
std::vector<Ort::AllocatedStringPtr> input_names, output_names;
std::vector<int64_t> input_shape = session.GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
std::vector<int64_t> output_shape = session.GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();
for (size_t i = 0; i < session.GetInputCount(); i++) {
input_names.push_back(session.GetInputNameAllocated(i, allocator));
std::cout << "input " << i << " " << input_names[i].get() << std::endl;
for (int64_t dim: session.GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape())
std::cout << " " << dim << std::endl;
}
for (size_t i = 0; i < session.GetOutputCount(); i++) {
output_names.push_back(session.GetOutputNameAllocated(i, allocator));
std::cout << "output " << i << " " << output_names[i].get() << std::endl;
for (int64_t dim: session.GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape())
std::cout << " " << dim << std::endl;
}
input_image.resize(input_shape[2], input_shape[3]);
Ort::Value input_tensor;
if (input_image.spectrum() == 4)
input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_image.data(), input_image.size()-320*320, input_shape.data(), input_shape.size());
else if (input_image.spectrum() == 3)
input_tensor = Ort::Value::CreateTensor<float>(memory_info, input_image.data(), input_image.size(), input_shape.data(), input_shape.size());
else {
std::cout << "input image: invalid spectrum " << input_image.spectrum() << std::endl;
return 0;
}
std::vector<float> output_data(std::accumulate(std::begin(output_shape), std::end(output_shape), 1.0, std::multiplies<float>()), 1.0);
std::cout << "created output_data.size " << output_data.size() << std::endl;
Ort::Value output_tensor = Ort::Value::CreateTensor<float>(memory_info, output_data.data(), output_data.size(), output_shape.data(), output_shape.size());
Ort::RunOptions run_options;
std::array<const char *, 1> input_names_ = {input_names[0].get()};
std::array<const char *, 1> output_names_ = {output_names[0].get()};
//input_image.display();
session.Run(run_options, input_names_.data(), &input_tensor, 1, output_names_.data(), &output_tensor, 1);
// memcpy(input_image.data(0,0,0,3), output_data.data(), output_data.size()*sizeof(float));
for (size_t y = 0, i = 0; y < input_image.height(); y++)
for (size_t x = 0; x < input_image.width(); x++)
input_image(x, y, 0, 3) = output_data[i++];
input_image.normalize(0.0,255.0).resize(output_image.width(), output_image.height());
for (size_t y = 0; y < input_image.height(); y++)
for (size_t x = 0; x < input_image.width(); x++)
output_image(x, y, 0, 3) = input_image(x, y, 0, 3);
output_image.save("o.png");
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment