Created January 21, 2022 14:34
-
-
Save nihui/cd8daede08b86cb8d5298ade674f7620 to your computer and use it in GitHub Desktop.
Register a custom layer with the ncnn C API.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdio.h> | |
// ncnn | |
#include <c_api.h> | |
/** test.param content | |
7767517 | |
3 3 | |
Input in0 0 1 input0 | |
Input in1 0 1 input1 | |
MyAvg avg 2 1 input0 input1 output | |
**/ | |
/*
 * Dump the contents of a mat to stdout, channel by channel.
 *
 * Each channel is printed as `h` lines of `w` values (format "%f "), and
 * every channel is followed by a dashed separator line. Channel data is
 * interpreted as float.
 */
static void pretty_print(const ncnn_mat_t m)
{
    const int width = ncnn_mat_get_w(m);
    const int height = ncnn_mat_get_h(m);
    const int channels = ncnn_mat_get_c(m);

    for (int ch = 0; ch < channels; ++ch)
    {
        const float* row = (const float*)ncnn_mat_get_channel_data(m, ch);

        // rows of one channel are contiguous, so step a whole row at a time
        for (int r = 0; r < height; ++r, row += width)
        {
            for (int col = 0; col < width; ++col)
            {
                printf("%f ", row[col]);
            }
            printf("\n");
        }

        printf("------------------------\n");
    }
}
/*
 * load_param hook for MyAvg: the layer reads nothing from the param dict,
 * so this always reports success (0).
 */
static int avg_layer_load_param(ncnn_layer_t layer, const ncnn_paramdict_t pd)
{
    (void)layer;
    (void)pd;

    return 0;
}
/*
 * load_model hook for MyAvg: the layer has no weights and reads nothing
 * from the model bin, so this always reports success (0).
 */
static int avg_layer_load_model(ncnn_layer_t layer, const ncnn_modelbin_t mb)
{
    (void)layer;
    (void)mb;

    return 0;
}
/*
 * create_pipeline hook for MyAvg: nothing needs to be prepared ahead of
 * forward, so this always reports success (0).
 */
static int avg_layer_create_pipeline(ncnn_layer_t layer, const ncnn_option_t opt)
{
    (void)layer;
    (void)opt;

    return 0;
}
/*
 * destroy_pipeline hook for MyAvg: create_pipeline allocates nothing, so
 * there is nothing to release; always reports success (0).
 */
static int avg_layer_destroy_pipeline(ncnn_layer_t layer, const ncnn_option_t opt)
{
    (void)layer;
    (void)opt;

    return 0;
}
/*
 * Put a fresh empty mat into every output slot. Used on early-error paths so
 * no top blob is left unset.
 * NOTE(review): this assumes ncnn's C-API bridge reads/destroys every top
 * slot regardless of the return code; confirm against c_api.cpp.
 */
static void avg_layer_fill_empty_tops(ncnn_mat_t* top_blobs, int n2)
{
    for (int i = 0; i < n2; i++)
    {
        top_blobs[i] = ncnn_mat_create();
    }
}

/*
 * Create an output mat with the same dimensionality (1, 2 or 3) as the
 * inputs, using the default allocator. Returns NULL for any other
 * dimensionality, which this layer does not handle.
 */
static ncnn_mat_t avg_layer_create_top(int dims, int w, int h, int c)
{
    switch (dims)
    {
    case 1:
        return ncnn_mat_create_1d(w, NULL);
    case 2:
        return ncnn_mat_create_2d(w, h, NULL);
    case 3:
        return ncnn_mat_create_3d(w, h, c, NULL);
    default:
        return NULL;
    }
}

/*
 * forward_n hook for MyAvg: top_blobs[0] = (bottom_blobs[0] + bottom_blobs[1]) * 0.5,
 * element-wise.
 *
 * Both inputs must have identical dims/w/h/c. Data is read and written as
 * float; presumably this relies on the layer not advertising packing or
 * fp16 support, so ncnn hands it unpacked fp32 blobs (verify).
 *
 * Returns 0 on success. Returns -1 if the blob counts are not exactly 2 in
 * and 1 out, if the input shapes differ, or if the dimensionality is not
 * 1-3. Returns -100 if the output buffer could not be allocated.
 */
static int avg_layer_forward_n(const ncnn_layer_t layer, const ncnn_mat_t* bottom_blobs, int n, ncnn_mat_t* top_blobs, int n2, const ncnn_option_t opt)
{
    (void)layer;

    if (n != 2 || n2 != 1)
    {
        avg_layer_fill_empty_tops(top_blobs, n2);
        return -1;
    }

    const ncnn_mat_t a = bottom_blobs[0];
    const ncnn_mat_t b = bottom_blobs[1];
    const int dims = ncnn_mat_get_dims(a);
    const int w = ncnn_mat_get_w(a);
    const int h = ncnn_mat_get_h(a);
    const int c = ncnn_mat_get_c(a);

    // b is traversed with a's geometry below; any mismatch would read out of bounds.
    if (ncnn_mat_get_dims(b) != dims || ncnn_mat_get_w(b) != w
            || ncnn_mat_get_h(b) != h || ncnn_mat_get_c(b) != c)
    {
        avg_layer_fill_empty_tops(top_blobs, n2);
        return -1;
    }

    ncnn_mat_t out = avg_layer_create_top(dims, w, h, c);
    if (!out)
    {
        avg_layer_fill_empty_tops(top_blobs, n2);
        return -1;
    }

    top_blobs[0] = out;
    if (!ncnn_mat_get_data(out))
    {
        return -100; // allocation failed
    }

    // w * h elements are contiguous within each channel.
    const int size = w * h;

    #pragma omp parallel for num_threads(ncnn_option_get_num_threads(opt))
    for (int q = 0; q < c; q++)
    {
        const float* aptr = (const float*)ncnn_mat_get_channel_data(a, q);
        const float* bptr = (const float*)ncnn_mat_get_channel_data(b, q);
        float* outptr = (float*)ncnn_mat_get_channel_data(out, q);

        for (int i = 0; i < size; i++)
        {
            outptr[i] = (aptr[i] + bptr[i]) * 0.5f;
        }
    }

    return 0;
}
static ncnn_layer_t avg_layer_creator(void* userdata) | |
{ | |
ncnn_layer_t layer = ncnn_layer_create(); | |
ncnn_layer_set_one_blob_only(layer, 0); | |
ncnn_layer_set_support_inplace(layer, 0); | |
layer->load_param = avg_layer_load_param; | |
layer->load_model = avg_layer_load_model; | |
layer->create_pipeline = avg_layer_create_pipeline; | |
layer->destroy_pipeline = avg_layer_destroy_pipeline; | |
layer->forward_1 = NULL; | |
layer->forward_n = avg_layer_forward_n; | |
layer->forward_inplace_1 = NULL; | |
layer->forward_inplace_n = NULL; | |
return layer; | |
} | |
/*
 * Layer destroyer paired with avg_layer_creator: releases the layer object
 * with ncnn_layer_destroy. userdata is not used.
 */
static void avg_layer_destroyer(ncnn_layer_t layer, void* userdata)
{
    (void)userdata;
    ncnn_layer_destroy(layer);
}
int main() | |
{ | |
ncnn_net_t net = ncnn_net_create(); | |
ncnn_net_register_custom_layer_by_type(net, "MyAvg", avg_layer_creator, avg_layer_destroyer, NULL); | |
ncnn_net_load_param(net, "test.param"); | |
ncnn_net_load_model(net, "test.bin"); | |
{ | |
ncnn_extractor_t ex = ncnn_extractor_create(net); | |
ncnn_mat_t in0 = ncnn_mat_create_3d(2, 3, 4, NULL); | |
ncnn_mat_fill_float(in0, 10.f); | |
ncnn_mat_t in1 = ncnn_mat_create_3d(2, 3, 4, NULL); | |
ncnn_mat_fill_float(in1, 60.f); | |
ncnn_extractor_input(ex, "input0", in0); | |
ncnn_extractor_input(ex, "input1", in1); | |
ncnn_mat_t out; | |
ncnn_extractor_extract(ex, "output", &out); | |
pretty_print(in0); // print 10 | |
pretty_print(in1); // print 60 | |
pretty_print(out); // print 35 <=== (10 + 60) * 0.5 | |
ncnn_mat_destroy(in0); | |
ncnn_mat_destroy(in1); | |
ncnn_mat_destroy(out); | |
ncnn_extractor_destroy(ex); | |
} | |
ncnn_net_destroy(net); | |
return 0; | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.