Skip to content

Instantly share code, notes, and snippets.

@jef-sure
Last active February 20, 2026 17:07
Show Gist options
  • Select an option

  • Save jef-sure/043ea4495852087495433cd6aec03171 to your computer and use it in GitHub Desktop.

Select an option

Save jef-sure/043ea4495852087495433cd6aec03171 to your computer and use it in GitHub Desktop.
/* -*- C++ -*-
* File: dht_nn_cl.h
* Copyright 2026 Anton Petrusevich
*
* OpenCL infrastructure for DHT-NN demosaicing GPU inference.
* Fused kernel: gather patch + forward pass.
* Architecture (patch_r, hidden sizes) determined at runtime from the network.
* Kernel source is generated dynamically with the correct #defines.
* One work-item per pixel.
*
* This code is licensed under one of two licenses as you choose:
*
* 1. GNU LESSER GENERAL PUBLIC LICENSE version 2.1
* (See file LICENSE.LGPL provided in LibRaw distribution archive for details).
*
* 2. COMMON DEVELOPMENT AND DISTRIBUTION LICENSE (CDDL) Version 1.0
* (See file LICENSE.CDDL provided in LibRaw distribution archive for details).
*/
#ifndef DHT_NN_CL_H
#define DHT_NN_CL_H
#ifdef USE_OPENCL
#ifndef CL_TARGET_OPENCL_VERSION
#define CL_TARGET_OPENCL_VERSION 120
#endif
#ifdef __APPLE__
#include <OpenCL/opencl.h>
#else
#include <CL/cl.h>
#endif
#ifdef _WIN32
#include <windows.h>
#else
#include <pthread.h>
#endif
#include <cmath> /* log2f() — used when emitting the MARGIN define */
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <cstring>
/* Append printf-style formatted text to a growable heap buffer.
 * *buf / *cap / *pos describe the buffer, its capacity, and the write cursor.
 * On success the formatted text (NUL-terminated) is appended, *pos advances,
 * and the buffer may be reallocated: capacity starts from a 1024-byte floor
 * and doubles until the result fits. Returns false on a formatting error or
 * allocation failure; *buf remains a valid heap pointer for the caller to free. */
static bool append_fmt(char **buf, size_t *cap, size_t *pos, const char *fmt, ...)
{
  va_list ap, probe;
  va_start(ap, fmt);
  /* Measure first (vsnprintf with a zero-size target) so we reallocate and
   * format at most once each, instead of format-grow-reformat. */
  va_copy(probe, ap);
  int len = vsnprintf(NULL, 0, fmt, probe);
  va_end(probe);
  if (len < 0)
  {
    va_end(ap);
    return false;
  }
  size_t need = *pos + (size_t)len + 1; /* +1 for the NUL terminator */
  if (need > *cap)
  {
    size_t grown = (*cap < 1024) ? 1024 : *cap;
    while (grown < need)
      grown *= 2;
    char *resized = (char *)realloc(*buf, grown);
    if (!resized)
    {
      va_end(ap);
      return false;
    }
    *buf = resized;
    *cap = grown;
  }
  vsnprintf(*buf + *pos, *cap - *pos, fmt, ap);
  va_end(ap);
  *pos += (size_t)len;
  return true;
}
/* ───────────────────────────────────────────────────────────────────
* Generate complete OpenCL kernel source for the given architecture.
* Kernel layers are unrolled at source-generation time for efficiency.
* Hidden-layer activations use two ping-pong buffers.
* ─────────────────────────────────────────────────────────────────── */
/* Emit the complete OpenCL C source for the refine kernel, specialized to
 * the given architecture. Returns a malloc'd string the caller must free,
 * or NULL on allocation/format failure. All layer loops are unrolled into
 * literal constants at generation time so the OpenCL compiler sees fixed
 * trip counts. */
static char *generate_cl_kernel_source(int patch_r, int input_size,
                                       int n_hidden, const int *hidden)
{
  int n_wt = n_hidden + 1; /* weight matrices: one per hidden layer + output layer */
  int max_h = 0;           /* widest hidden layer → sizes the ping-pong buffers */
  for (int i = 0; i < n_hidden; i++)
    if (hidden[i] > max_h) max_h = hidden[i];
  if (max_h == 0) max_h = 1; /* keep MAX_H a valid array size for zero-hidden nets */
  size_t buf_size = 16384;
  char *src = (char *)malloc(buf_size);
  size_t pos = 0;
  if (!src) return NULL;
  src[0] = '\0';
  /* ── Defines ── */
  if (!append_fmt(&src, &buf_size, &pos,
                  "#define PATCH_R %d\n"
                  "#define PATCH_SIDE %d\n"
                  "#define INPUT_SIZE %d\n"
                  "#define OUTPUT_SIZE 3\n"
                  "#define FEAT_PER_PX 16\n"
                  "#define MAX_H %d\n"
                  "#define HVSH 1\n"
                  "#define HOR 2\n"
                  "#define VER 4\n"
                  "#define HOT 64\n"
                  "#define MARGIN %.10ff\n\n",
                  patch_r, 2 * patch_r + 1, input_size, max_h,
                  log2f(1.2f) / 16.0f))
  {
    free(src);
    return NULL;
  }
  /* ── bayer_color helper (mirrors the dcraw FC() filter decode) ── */
  if (!append_fmt(&src, &buf_size, &pos,
                  "inline int bayer_color(uint filters, int row, int col)\n"
                  "{\n"
                  " return (int)((filters >> (((row << 1 & 14) | (col & 1)) << 1)) & 3u);\n"
                  "}\n\n"))
  {
    free(src);
    return NULL;
  }
  /* ── Kernel signature with per-layer W/b arguments ── */
  if (!append_fmt(&src, &buf_size, &pos,
                  "__kernel void dht_nn_refine(\n"
                  " __global const float *nraw,\n"
                  " __global const char *ndir_in,\n"
                  " __global char *ndir_out,\n"))
  {
    free(src);
    return NULL;
  }
  /* One W/b buffer pair per weight layer; names W0/b0, W1/b1, ... */
  for (int l = 0; l < n_wt; l++)
    if (!append_fmt(&src, &buf_size, &pos,
                    " __global const float *W%d, __global const float *b%d,\n", l, l))
    {
      free(src);
      return NULL;
    }
  if (!append_fmt(&src, &buf_size, &pos,
                  " const int iwidth, const int iheight,\n"
                  " const int nr_width, const int nr_topmargin,\n"
                  " const int nr_leftmargin, const uint filters,\n"
                  " __global float *green_out,\n"
                  " const float channel_min_g, const float channel_max_g)\n"
                  "{\n"))
  {
    free(src);
    return NULL;
  }
  /* ── Index + bounds check ──
   * green_out is pre-set to -1.0f for every pixel so the host can tell
   * which pixels the kernel actually processed. */
  if (!append_fmt(&src, &buf_size, &pos,
                  " int idx = get_global_id(0);\n"
                  " int total = iwidth * iheight;\n"
                  " if (idx >= total) return;\n"
                  " int i = idx / iwidth;\n"
                  " int j = idx - i * iwidth;\n"
                  " green_out[idx] = -1.0f;\n"
                  " if (i < PATCH_R || i >= iheight - PATCH_R ||\n"
                  " j < PATCH_R || j >= iwidth - PATCH_R)\n"
                  " return;\n"
                  " int x = j + nr_leftmargin;\n"
                  " int y = i + nr_topmargin;\n"
                  " int off_center = y * nr_width + x;\n"
                  " /* Skip green pixels — only non-green positions need H/V interpolation */\n"
                  " int center_ch = bayer_color(filters, i, j);\n"
                  " if (center_ch == 3) center_ch = 1;\n"
                  " if (center_ch == 1) return;\n"
                  " char d = ndir_in[off_center];\n"
                  " if (d & HVSH) return;\n\n"))
  {
    free(src);
    return NULL;
  }
  /* ── 1+2. Fused patch gather + hidden layers ── */
  if (n_hidden > 0)
  {
    int h0 = hidden[0]; /* first hidden layer width */
    /* Declare ping-pong buffers */
    if (!append_fmt(&src, &buf_size, &pos,
                    " float buf_a[MAX_H], buf_b[MAX_H];\n\n"))
    { free(src); return NULL; }
    /* Initialize first hidden layer accumulators from biases */
    if (!append_fmt(&src, &buf_size, &pos,
                    " for (int o = 0; o < %d; o++) buf_a[o] = b0[o];\n\n", h0))
    { free(src); return NULL; }
    /* Fused: gather features per-pixel and accumulate into first hidden layer.
     * Eliminates float patch[INPUT_SIZE] (6.7 KB!) from private memory.
     * All work-items in a warp access the same W0 offsets → good L1 reuse.
     * Features per patch pixel (16): f0-f2 RGB, 3-5 one-hot channel (folded
     * into w[3+ch] below), f6/f7 hue H/V, f8/f9 hue-consistency H/V,
     * f10-f15 directional green estimates. */
    if (!append_fmt(&src, &buf_size, &pos,
                    " int pidx = 0;\n"
                    " for (int dy = -PATCH_R; dy <= PATCH_R; dy++)\n"
                    " {\n"
                    " for (int dx = -PATCH_R; dx <= PATCH_R; dx++)\n"
                    " {\n"
                    " int py = (i + dy) + nr_topmargin;\n"
                    " int px = (j + dx) + nr_leftmargin;\n"
                    " int off = py * nr_width + px;\n"
                    " int off3 = off * 3;\n"
                    " float f0 = nraw[off3 + 0];\n"
                    " float f1 = nraw[off3 + 1];\n"
                    " float f2 = nraw[off3 + 2];\n"
                    " int ch = bayer_color(filters, i + dy, j + dx);\n"
                    " if (ch == 3) ch = 1;\n"
                    " float f6, f7;\n"
                    " if (ch == 1) {\n"
                    " float ng_l = nraw[(py * nr_width + (px-1)) * 3 + 0] + nraw[(py * nr_width + (px-1)) * 3 + 2];\n"
                    " float ng_r = nraw[(py * nr_width + (px+1)) * 3 + 0] + nraw[(py * nr_width + (px+1)) * 3 + 2];\n"
                    " float ng_u = nraw[((py-1) * nr_width + px) * 3 + 0] + nraw[((py-1) * nr_width + px) * 3 + 2];\n"
                    " float ng_d = nraw[((py+1) * nr_width + px) * 3 + 0] + nraw[((py+1) * nr_width + px) * 3 + 2];\n"
                    " f6 = nraw[off3 + 1] - (ng_l + ng_r) * 0.5f;\n"
                    " f7 = nraw[off3 + 1] - (ng_u + ng_d) * 0.5f;\n"
                    " } else {\n"
                    " f6 = (nraw[(py * nr_width + (px-1)) * 3 + 1] + nraw[(py * nr_width + (px+1)) * 3 + 1]) * 0.5f\n"
                    " - nraw[off3 + ch];\n"
                    " f7 = (nraw[((py-1) * nr_width + px) * 3 + 1] + nraw[((py+1) * nr_width + px) * 3 + 1]) * 0.5f\n"
                    " - nraw[off3 + ch];\n"
                    " }\n"
                    " float f8, f9;\n"
                    " if (ch == 1) {\n"
                    " float kc_l = nraw[(py * nr_width + (px-1)) * 3 + 0] + nraw[(py * nr_width + (px-1)) * 3 + 2];\n"
                    " float kc_r = nraw[(py * nr_width + (px+1)) * 3 + 0] + nraw[(py * nr_width + (px+1)) * 3 + 2];\n"
                    " float hh1g = kc_l - (nraw[(py * nr_width + (px-2)) * 3 + 1] + nraw[off3 + 1]) * 0.5f;\n"
                    " float hh2g = kc_r - (nraw[off3 + 1] + nraw[(py * nr_width + (px+2)) * 3 + 1]) * 0.5f;\n"
                    " f8 = fabs(hh1g - hh2g);\n"
                    " float kc_u = nraw[((py-1) * nr_width + px) * 3 + 0] + nraw[((py-1) * nr_width + px) * 3 + 2];\n"
                    " float kc_d = nraw[((py+1) * nr_width + px) * 3 + 0] + nraw[((py+1) * nr_width + px) * 3 + 2];\n"
                    " float hv1g = kc_u - (nraw[((py-2) * nr_width + px) * 3 + 1] + nraw[off3 + 1]) * 0.5f;\n"
                    " float hv2g = kc_d - (nraw[off3 + 1] + nraw[((py+2) * nr_width + px) * 3 + 1]) * 0.5f;\n"
                    " f9 = fabs(hv1g - hv2g);\n"
                    " } else {\n"
                    " float hh1n = nraw[(py * nr_width + (px-1)) * 3 + 1] - (nraw[(py * nr_width + (px-2)) * 3 + ch] + nraw[off3 + ch]) * 0.5f;\n"
                    " float hh2n = nraw[(py * nr_width + (px+1)) * 3 + 1] - (nraw[off3 + ch] + nraw[(py * nr_width + (px+2)) * 3 + ch]) * 0.5f;\n"
                    " f8 = fabs(hh1n - hh2n);\n"
                    " float hv1n = nraw[((py-1) * nr_width + px) * 3 + 1] - (nraw[((py-2) * nr_width + px) * 3 + ch] + nraw[off3 + ch]) * 0.5f;\n"
                    " float hv2n = nraw[((py+1) * nr_width + px) * 3 + 1] - (nraw[off3 + ch] + nraw[((py+2) * nr_width + px) * 3 + ch]) * 0.5f;\n"
                    " f9 = fabs(hv1n - hv2n);\n"
                    " }\n"
                    " float f10, f11, f12, f13, f14, f15;\n"
                    " {\n"
                    " float c0 = nraw[off3 + ch];\n"
                    " float g_u = nraw[((py-1) * nr_width + px) * 3 + 1];\n"
                    " float g_d = nraw[((py+1) * nr_width + px) * 3 + 1];\n"
                    " float g_l = nraw[(py * nr_width + (px-1)) * 3 + 1];\n"
                    " float g_r = nraw[(py * nr_width + (px+1)) * 3 + 1];\n"
                    " float c_u2 = nraw[((py-2) * nr_width + px) * 3 + ch];\n"
                    " float c_d2 = nraw[((py+2) * nr_width + px) * 3 + ch];\n"
                    " float c_l2 = nraw[(py * nr_width + (px-2)) * 3 + ch];\n"
                    " float c_r2 = nraw[(py * nr_width + (px+2)) * 3 + ch];\n"
                    " f10 = (g_u + g_d) * 0.5f - (c_u2 + c_d2 - 2.0f * c0) * 0.25f;\n"
                    " f11 = (g_l + g_r) * 0.5f - (c_l2 + c_r2 - 2.0f * c0) * 0.25f;\n"
                    " f12 = (g_l + g_u) * 0.5f - (c_l2 + c_u2 - 2.0f * c0) * 0.25f;\n"
                    " f13 = (g_l + g_d) * 0.5f - (c_l2 + c_d2 - 2.0f * c0) * 0.25f;\n"
                    " f14 = (g_r + g_u) * 0.5f - (c_r2 + c_u2 - 2.0f * c0) * 0.25f;\n"
                    " f15 = (g_r + g_d) * 0.5f - (c_r2 + c_d2 - 2.0f * c0) * 0.25f;\n"
                    " }\n"
                    " for (int o = 0; o < %d; o++)\n"
                    " {\n"
                    " __global const float *w = W0 + o * INPUT_SIZE + pidx;\n"
                    " buf_a[o] += f0*w[0] + f1*w[1] + f2*w[2]\n"
                    " + w[3 + ch]\n"
                    " + f6*w[6] + f7*w[7] + f8*w[8] + f9*w[9]\n"
                    " + f10*w[10] + f11*w[11] + f12*w[12]\n"
                    " + f13*w[13] + f14*w[14] + f15*w[15];\n"
                    " }\n"
                    " pidx += FEAT_PER_PX;\n"
                    " }\n"
                    " }\n\n", h0))
    { free(src); return NULL; }
    /* LeakyReLU (alpha=0.01) on first hidden layer */
    if (!append_fmt(&src, &buf_size, &pos,
                    " for (int o = 0; o < %d; o++) buf_a[o] = buf_a[o] > 0.0f ? buf_a[o] : 0.01f * buf_a[o];\n\n", h0))
    { free(src); return NULL; }
    /* Remaining hidden layers (l=1..n_hidden-1, same ping-pong):
     * layer l reads buf_a when l is odd, writes the other buffer. */
    for (int l = 1; l < n_hidden; l++)
    {
      int out_sz = hidden[l];
      int in_sz = hidden[l - 1];
      const char *in_name = (l & 1) ? "buf_a" : "buf_b";
      const char *out_name = (l & 1) ? "buf_b" : "buf_a";
      if (!append_fmt(&src, &buf_size, &pos,
                      " for (int o = 0; o < %d; o++)\n"
                      " {\n"
                      " float sum = b%d[o];\n"
                      " __global const float *row = W%d + o * %d;\n"
                      " for (int k = 0; k < %d; k++)\n"
                      " sum += %s[k] * row[k];\n"
                      " %s[o] = sum > 0.0f ? sum : 0.01f * sum;\n"
                      " }\n\n",
                      out_sz, l, l, in_sz, in_sz, in_name, out_name))
      { free(src); return NULL; }
    }
  }
  else
  {
    /* n_hidden == 0: direct input->output, must gather patch
     * (same 16 features per pixel as the fused path, but the one-hot
     * channel slots 3-5 are materialized explicitly). */
    if (!append_fmt(&src, &buf_size, &pos,
                    " float patch[INPUT_SIZE];\n"
                    " int pidx = 0;\n"
                    " for (int dy = -PATCH_R; dy <= PATCH_R; dy++)\n"
                    " {\n"
                    " for (int dx = -PATCH_R; dx <= PATCH_R; dx++)\n"
                    " {\n"
                    " int py = (i + dy) + nr_topmargin;\n"
                    " int px = (j + dx) + nr_leftmargin;\n"
                    " int off = py * nr_width + px;\n"
                    " int off3 = off * 3;\n"
                    " patch[pidx++] = nraw[off3 + 0];\n"
                    " patch[pidx++] = nraw[off3 + 1];\n"
                    " patch[pidx++] = nraw[off3 + 2];\n"
                    " int ch = bayer_color(filters, i + dy, j + dx);\n"
                    " if (ch == 3) ch = 1;\n"
                    " patch[pidx++] = (ch == 0) ? 1.0f : 0.0f;\n"
                    " patch[pidx++] = (ch == 1) ? 1.0f : 0.0f;\n"
                    " patch[pidx++] = (ch == 2) ? 1.0f : 0.0f;\n"
                    " float hue_h, hue_v;\n"
                    " if (ch == 1) {\n"
                    " float ng_l = nraw[(py * nr_width + (px-1)) * 3 + 0] + nraw[(py * nr_width + (px-1)) * 3 + 2];\n"
                    " float ng_r = nraw[(py * nr_width + (px+1)) * 3 + 0] + nraw[(py * nr_width + (px+1)) * 3 + 2];\n"
                    " float ng_u = nraw[((py-1) * nr_width + px) * 3 + 0] + nraw[((py-1) * nr_width + px) * 3 + 2];\n"
                    " float ng_d = nraw[((py+1) * nr_width + px) * 3 + 0] + nraw[((py+1) * nr_width + px) * 3 + 2];\n"
                    " hue_h = nraw[off3 + 1] - (ng_l + ng_r) * 0.5f;\n"
                    " hue_v = nraw[off3 + 1] - (ng_u + ng_d) * 0.5f;\n"
                    " } else {\n"
                    " hue_h = (nraw[(py * nr_width + (px-1)) * 3 + 1] + nraw[(py * nr_width + (px+1)) * 3 + 1]) * 0.5f\n"
                    " - nraw[off3 + ch];\n"
                    " hue_v = (nraw[((py-1) * nr_width + px) * 3 + 1] + nraw[((py+1) * nr_width + px) * 3 + 1]) * 0.5f\n"
                    " - nraw[off3 + ch];\n"
                    " }\n"
                    " patch[pidx++] = hue_h;\n"
                    " patch[pidx++] = hue_v;\n"
                    " float hc_h, hc_v;\n"
                    " if (ch == 1) {\n"
                    " float kc_l2 = nraw[(py * nr_width + (px-1)) * 3 + 0] + nraw[(py * nr_width + (px-1)) * 3 + 2];\n"
                    " float kc_r2 = nraw[(py * nr_width + (px+1)) * 3 + 0] + nraw[(py * nr_width + (px+1)) * 3 + 2];\n"
                    " float hh1g2 = kc_l2 - (nraw[(py * nr_width + (px-2)) * 3 + 1] + nraw[off3 + 1]) * 0.5f;\n"
                    " float hh2g2 = kc_r2 - (nraw[off3 + 1] + nraw[(py * nr_width + (px+2)) * 3 + 1]) * 0.5f;\n"
                    " hc_h = fabs(hh1g2 - hh2g2);\n"
                    " float kc_u2 = nraw[((py-1) * nr_width + px) * 3 + 0] + nraw[((py-1) * nr_width + px) * 3 + 2];\n"
                    " float kc_d2 = nraw[((py+1) * nr_width + px) * 3 + 0] + nraw[((py+1) * nr_width + px) * 3 + 2];\n"
                    " float hv1g2 = kc_u2 - (nraw[((py-2) * nr_width + px) * 3 + 1] + nraw[off3 + 1]) * 0.5f;\n"
                    " float hv2g2 = kc_d2 - (nraw[off3 + 1] + nraw[((py+2) * nr_width + px) * 3 + 1]) * 0.5f;\n"
                    " hc_v = fabs(hv1g2 - hv2g2);\n"
                    " } else {\n"
                    " float hh1n2 = nraw[(py * nr_width + (px-1)) * 3 + 1] - (nraw[(py * nr_width + (px-2)) * 3 + ch] + nraw[off3 + ch]) * 0.5f;\n"
                    " float hh2n2 = nraw[(py * nr_width + (px+1)) * 3 + 1] - (nraw[off3 + ch] + nraw[(py * nr_width + (px+2)) * 3 + ch]) * 0.5f;\n"
                    " hc_h = fabs(hh1n2 - hh2n2);\n"
                    " float hv1n2 = nraw[((py-1) * nr_width + px) * 3 + 1] - (nraw[((py-2) * nr_width + px) * 3 + ch] + nraw[off3 + ch]) * 0.5f;\n"
                    " float hv2n2 = nraw[((py+1) * nr_width + px) * 3 + 1] - (nraw[off3 + ch] + nraw[((py+2) * nr_width + px) * 3 + ch]) * 0.5f;\n"
                    " hc_v = fabs(hv1n2 - hv2n2);\n"
                    " }\n"
                    " patch[pidx++] = hc_h;\n"
                    " patch[pidx++] = hc_v;\n"
                    " {\n"
                    " float c0 = nraw[off3 + ch];\n"
                    " float g_u2 = nraw[((py-1) * nr_width + px) * 3 + 1];\n"
                    " float g_d2 = nraw[((py+1) * nr_width + px) * 3 + 1];\n"
                    " float g_l2 = nraw[(py * nr_width + (px-1)) * 3 + 1];\n"
                    " float g_r2 = nraw[(py * nr_width + (px+1)) * 3 + 1];\n"
                    " float c_u2 = nraw[((py-2) * nr_width + px) * 3 + ch];\n"
                    " float c_d2 = nraw[((py+2) * nr_width + px) * 3 + ch];\n"
                    " float c_l2 = nraw[(py * nr_width + (px-2)) * 3 + ch];\n"
                    " float c_r2 = nraw[(py * nr_width + (px+2)) * 3 + ch];\n"
                    " patch[pidx++] = (g_u2 + g_d2) * 0.5f - (c_u2 + c_d2 - 2.0f * c0) * 0.25f;\n"
                    " patch[pidx++] = (g_l2 + g_r2) * 0.5f - (c_l2 + c_r2 - 2.0f * c0) * 0.25f;\n"
                    " patch[pidx++] = (g_l2 + g_u2) * 0.5f - (c_l2 + c_u2 - 2.0f * c0) * 0.25f;\n"
                    " patch[pidx++] = (g_l2 + g_d2) * 0.5f - (c_l2 + c_d2 - 2.0f * c0) * 0.25f;\n"
                    " patch[pidx++] = (g_r2 + g_u2) * 0.5f - (c_r2 + c_u2 - 2.0f * c0) * 0.25f;\n"
                    " patch[pidx++] = (g_r2 + g_d2) * 0.5f - (c_r2 + c_d2 - 2.0f * c0) * 0.25f;\n"
                    " }\n"
                    " }\n"
                    " }\n\n"))
    { free(src); return NULL; }
  }
  /* ── 3. Output layer ── */
  {
    int l = n_hidden; /* output weight matrix index */
    int in_sz = (n_hidden == 0) ? input_size : hidden[n_hidden - 1];
    const char *in_name;
    if (n_hidden == 0)
      in_name = "patch";
    else
      /* Whichever ping-pong buffer the last hidden layer wrote into. */
      in_name = ((n_hidden - 1) & 1) ? "buf_b" : "buf_a";
    /* W is row-major [3][in_sz]: row offsets 0, in_sz, 2*in_sz. */
    if (!append_fmt(&src, &buf_size, &pos,
                    " float out0 = b%d[0];\n"
                    " float out1 = b%d[1];\n"
                    " float out2 = b%d[2];\n"
                    " for (int k = 0; k < %d; k++)\n"
                    " {\n"
                    " out0 += %s[k] * W%d[k];\n"
                    " out1 += %s[k] * W%d[%d + k];\n"
                    " out2 += %s[k] * W%d[%d + k];\n"
                    " }\n\n",
                    l, l, l, in_sz, in_name, l, in_name, l, in_sz, in_name, l, 2 * in_sz))
    {
      free(src);
      return NULL;
    }
  }
  /* ── 4. Blend green + write direction ──
   * out0/out1 → softmax H/V weights, out2 → sigmoid hard/soft mix;
   * eg_h/eg_v are gradient-weighted directional green estimates clamped
   * to neighbor range ± MARGIN and the channel limits. */
  if (!append_fmt(&src, &buf_size, &pos,
                  " int kc = bayer_color(filters, i, j);\n"
                  " if (kc == 3) kc = 1;\n"
                  " int off3c = off_center * 3;\n"
                  " float c0 = nraw[off3c + kc];\n\n"
                  " float max_hv = out0 > out1 ? out0 : out1;\n"
                  " float eh = exp(out0 - max_hv);\n"
                  " float ev = exp(out1 - max_hv);\n"
                  " float shv = eh + ev;\n"
                  " float H_w = eh / shv;\n"
                  " float V_w = ev / shv;\n"
                  " float S = 1.0f / (1.0f + exp(-out2));\n\n"
                  " float gl = nraw[(y * nr_width + (x-1)) * 3 + 1];\n"
                  " float gr = nraw[(y * nr_width + (x+1)) * 3 + 1];\n"
                  " float cl2 = nraw[(y * nr_width + (x-2)) * 3 + kc];\n"
                  " float cr2 = nraw[(y * nr_width + (x+2)) * 3 + kc];\n"
                  " float dh1 = gr - (cr2 + c0) * 0.5f;\n"
                  " float dh2 = gl - (cl2 + c0) * 0.5f;\n"
                  " float bh1 = fabs(c0 - cr2); bh1 = bh1 > 1e-6f ? 1.0f / bh1 : 1e6f; bh1 *= bh1;\n"
                  " float bh2 = fabs(c0 - cl2); bh2 = bh2 > 1e-6f ? 1.0f / bh2 : 1e6f; bh2 *= bh2;\n"
                  " float eg_h = c0 + (bh1 * dh1 + bh2 * dh2) / (bh1 + bh2);\n"
                  " eg_h = clamp(eg_h, fmax(fmin(gr, gl) - MARGIN, channel_min_g),\n"
                  " fmin(fmax(gr, gl) + MARGIN, channel_max_g));\n\n"
                  " float gu = nraw[((y-1) * nr_width + x) * 3 + 1];\n"
                  " float gd = nraw[((y+1) * nr_width + x) * 3 + 1];\n"
                  " float cu2 = nraw[((y-2) * nr_width + x) * 3 + kc];\n"
                  " float cd2 = nraw[((y+2) * nr_width + x) * 3 + kc];\n"
                  " float dv1 = gu - (cu2 + c0) * 0.5f;\n"
                  " float dv2 = gd - (cd2 + c0) * 0.5f;\n"
                  " float bv1 = fabs(c0 - cu2); bv1 = bv1 > 1e-6f ? 1.0f / bv1 : 1e6f; bv1 *= bv1;\n"
                  " float bv2 = fabs(c0 - cd2); bv2 = bv2 > 1e-6f ? 1.0f / bv2 : 1e6f; bv2 *= bv2;\n"
                  " float eg_v = c0 + (bv1 * dv1 + bv2 * dv2) / (bv1 + bv2);\n"
                  " eg_v = clamp(eg_v, fmax(fmin(gu, gd) - MARGIN, channel_min_g),\n"
                  " fmin(fmax(gu, gd) + MARGIN, channel_max_g));\n\n"
                  " float soft = H_w * eg_h + V_w * eg_v;\n"
                  " float hard = H_w > V_w ? eg_h : eg_v;\n"
                  " float green = (1.0f - S) * soft + S * hard;\n"
                  " green_out[idx] = clamp(green, channel_min_g, channel_max_g);\n\n"
                  " char new_d = d & ~(HOR | VER);\n"
                  " new_d |= (out0 > out1) ? HOR : VER;\n"
                  " ndir_out[off_center] = new_d;\n"
                  "}\n"))
  {
    free(src);
    return NULL;
  }
  return src;
}
/* ───────────────────────────────────────────────────────────────────
* OpenCL runtime wrapper — single-GPU context, lazy init, cached.
* Recompiles kernel if architecture changes.
* ─────────────────────────────────────────────────────────────────── */
struct DHTNNCLContext
{
  cl_context context;
  cl_command_queue queue;
  cl_program program;
  cl_kernel kernel;
  cl_device_id device;
  bool valid;        /* kernel compiled and ready for the cached architecture */
  bool device_ready; /* context+queue created */
  /* Cached architecture the kernel was compiled for */
  int compiled_n_layers;
  int *compiled_sizes; /* [compiled_n_layers] layer sizes for comparison */
  /* Cached weight buffers on GPU (persist across calls, re-uploaded on demand) */
  cl_mem *d_W_cached;
  cl_mem *d_b_cached;
  int cached_nw; /* number of weight layers currently cached */
  /* Cached per-image GPU buffers (persist across calls, reallocated on size change) */
  cl_mem d_nraw_cached;
  cl_mem d_ndir_in_cached;
  cl_mem d_ndir_out_cached;
  cl_mem d_green_out_cached;
  size_t cached_nraw_bytes;  /* nr_height * nr_width * 3 * sizeof(float) */
  size_t cached_ndir_bytes;  /* nr_height * nr_width */
  size_t cached_green_bytes; /* iwidth * iheight * sizeof(float) */
  /* Preferred local work group size multiple (queried after kernel build) */
  size_t preferred_wg_multiple;
#ifdef _WIN32
  CRITICAL_SECTION ctx_lock;
#else
  pthread_mutex_t ctx_lock;
#endif
  /* RAII guard serializing all GPU access through this context. */
  struct ScopedLock
  {
    DHTNNCLContext &ctx;
    ScopedLock(DHTNNCLContext &c) : ctx(c)
    {
#ifdef _WIN32
      EnterCriticalSection(&ctx.ctx_lock);
#else
      pthread_mutex_lock(&ctx.ctx_lock);
#endif
    }
    ~ScopedLock()
    {
#ifdef _WIN32
      LeaveCriticalSection(&ctx.ctx_lock);
#else
      pthread_mutex_unlock(&ctx.ctx_lock);
#endif
    }
  };
  DHTNNCLContext()
      : context(NULL), queue(NULL), program(NULL), kernel(NULL),
        device(NULL), valid(false), device_ready(false),
        compiled_n_layers(0), compiled_sizes(NULL),
        d_W_cached(NULL), d_b_cached(NULL), cached_nw(0),
        d_nraw_cached(NULL), d_ndir_in_cached(NULL),
        d_ndir_out_cached(NULL), d_green_out_cached(NULL),
        cached_nraw_bytes(0), cached_ndir_bytes(0), cached_green_bytes(0),
        preferred_wg_multiple(0)
  {
#ifdef _WIN32
    InitializeCriticalSection(&ctx_lock);
#else
    pthread_mutex_init(&ctx_lock, NULL);
#endif
  }
  ~DHTNNCLContext()
  {
    /* Release GPU objects before the queue/context that own them. */
    release_image_bufs();
    release_weights();
    release_kernel();
    if (queue) clReleaseCommandQueue(queue);
    if (context) clReleaseContext(context);
    free(compiled_sizes);
#ifdef _WIN32
    DeleteCriticalSection(&ctx_lock);
#else
    pthread_mutex_destroy(&ctx_lock);
#endif
  }
  /* Drop the compiled kernel/program; marks the context invalid. */
  void release_kernel()
  {
    if (kernel) { clReleaseKernel(kernel); kernel = NULL; }
    if (program) { clReleaseProgram(program); program = NULL; }
    valid = false;
    preferred_wg_multiple = 0;
  }
  /* Drop the cached per-image buffers (called when image size changes). */
  void release_image_bufs()
  {
    if (d_nraw_cached) { clReleaseMemObject(d_nraw_cached); d_nraw_cached = NULL; }
    if (d_ndir_in_cached) { clReleaseMemObject(d_ndir_in_cached); d_ndir_in_cached = NULL; }
    if (d_ndir_out_cached) { clReleaseMemObject(d_ndir_out_cached); d_ndir_out_cached = NULL; }
    if (d_green_out_cached) { clReleaseMemObject(d_green_out_cached); d_green_out_cached = NULL; }
    cached_nraw_bytes = cached_ndir_bytes = cached_green_bytes = 0;
  }
  /* Drop all cached weight/bias buffers and their host-side handle arrays. */
  void release_weights()
  {
    for (int l = 0; l < cached_nw; l++)
    {
      if (d_W_cached && d_W_cached[l]) clReleaseMemObject(d_W_cached[l]);
      if (d_b_cached && d_b_cached[l]) clReleaseMemObject(d_b_cached[l]);
    }
    free(d_W_cached); d_W_cached = NULL;
    free(d_b_cached); d_b_cached = NULL;
    cached_nw = 0;
  }
  /* Upload (or re-upload) weight buffers to GPU.
   * Existing buffers are reused only when the architecture is unchanged
   * (build_kernel() releases them on any arch change), so per-layer sizes
   * are guaranteed to match. Returns false → caller falls back to CPU. */
  bool upload_weights(const DHTNNNetwork &net)
  {
    int nw = net.num_weights();
    /* If the layer count changed, must reallocate the handle arrays */
    if (nw != cached_nw)
    {
      release_weights();
      d_W_cached = (cl_mem *)calloc((size_t)nw, sizeof(cl_mem));
      d_b_cached = (cl_mem *)calloc((size_t)nw, sizeof(cl_mem));
      if (!d_W_cached || !d_b_cached) { release_weights(); return false; }
      cached_nw = nw;
    }
    cl_int err;
    for (int l = 0; l < nw; l++)
    {
      int out_sz = net.weight_out(l);
      int in_sz = net.weight_in(l);
      size_t w_bytes = (size_t)out_sz * in_sz * sizeof(float);
      size_t b_bytes = (size_t)out_sz * sizeof(float);
      if (d_W_cached[l])
      {
        /* Buffer exists (same arch): update contents without reallocating.
         * Non-blocking is safe: the in-order queue orders these writes
         * before the kernel launch, and net.W/net.b outlive the call. */
        err = clEnqueueWriteBuffer(queue, d_W_cached[l], CL_FALSE, 0, w_bytes, net.W[l], 0, NULL, NULL);
        if (err != CL_SUCCESS) { release_weights(); return false; }
        err = clEnqueueWriteBuffer(queue, d_b_cached[l], CL_FALSE, 0, b_bytes, net.b[l], 0, NULL, NULL);
        if (err != CL_SUCCESS) { release_weights(); return false; }
      }
      else
      {
        /* First time: create buffers initialized from host memory */
        d_W_cached[l] = clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
                                       w_bytes, (void*)net.W[l], &err);
        if (err != CL_SUCCESS) { release_weights(); return false; }
        d_b_cached[l] = clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
                                       b_bytes, (void*)net.b[l], &err);
        if (err != CL_SUCCESS) { release_weights(); return false; }
      }
    }
    return true;
  }
  /* Lazily pick a device (prefer GPU, fall back to any) and create the
   * context + command queue. Returns false → caller falls back to CPU. */
  bool init_device()
  {
    if (device_ready) return true;
    cl_int err;
    cl_platform_id platforms[8];
    cl_uint nplatforms = 0;
    err = clGetPlatformIDs(8, platforms, &nplatforms);
    if (err != CL_SUCCESS || nplatforms == 0) return false;
    device = NULL;
    for (cl_uint p = 0; p < nplatforms && !device; p++)
    {
      cl_device_id devs[8];
      cl_uint ndevs = 0;
      if (clGetDeviceIDs(platforms[p], CL_DEVICE_TYPE_GPU, 8, devs, &ndevs) == CL_SUCCESS && ndevs > 0)
        device = devs[0];
    }
    if (!device)
    {
      /* No GPU found on any platform: accept any device type (e.g. CPU) */
      for (cl_uint p = 0; p < nplatforms && !device; p++)
      {
        cl_device_id devs[8];
        cl_uint ndevs = 0;
        if (clGetDeviceIDs(platforms[p], CL_DEVICE_TYPE_ALL, 8, devs, &ndevs) == CL_SUCCESS && ndevs > 0)
          device = devs[0];
      }
    }
    if (!device) return false;
    char name[256] = {0};
    clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(name), name, NULL);
    fprintf(stderr, "DHT-NN OpenCL: using device \"%s\"\n", name);
    context = clCreateContext(NULL, 1, &device, NULL, NULL, &err);
    if (err != CL_SUCCESS) return false;
#ifdef CL_VERSION_2_0
    queue = clCreateCommandQueueWithProperties(context, device, NULL, &err);
#else
    queue = clCreateCommandQueue(context, device, 0, &err);
#endif
    if (err != CL_SUCCESS) { clReleaseContext(context); context = NULL; return false; }
    device_ready = true;
    return true;
  }
  /* Check if kernel is already compiled for this architecture */
  bool arch_matches(const DHTNNNetwork &net) const
  {
    if (!valid) return false;
    if (!compiled_sizes) return false;
    int n_layers = net.n_hidden + 2; /* input + hidden... + output */
    if (n_layers != compiled_n_layers) return false;
    if (compiled_sizes[0] != net.input_size) return false;
    for (int i = 0; i < net.n_hidden; i++)
      if (compiled_sizes[i + 1] != net.hidden[i]) return false;
    return compiled_sizes[n_layers - 1] == DHTNNNetwork::OUTPUT_SIZE;
  }
  /* Build (or rebuild) kernel for a specific architecture */
  bool build_kernel(const DHTNNNetwork &net)
  {
    if (arch_matches(net))
      return true; /* already compiled for this arch */
    release_kernel();
    /* FIX: the architecture changed, so the cached weight buffers may have
     * stale sizes even when the layer COUNT is unchanged (same depth,
     * different widths) — upload_weights() only reallocates on a count
     * change and would then fail writing oversized data into old buffers
     * forever. Drop them here so they are recreated at the right sizes. */
    release_weights();
    /* Generate full kernel source */
    char *src = generate_cl_kernel_source(net.patch_r, net.input_size,
                                          net.n_hidden, net.hidden);
    if (!src)
    {
      fprintf(stderr, "DHT-NN OpenCL: kernel source generation failed\n");
      return false;
    }
    size_t src_len = strlen(src);
    cl_int err;
    const char *src_ptr = src;
    program = clCreateProgramWithSource(context, 1, &src_ptr, &src_len, &err);
    free(src);
    if (err != CL_SUCCESS) return false;
    err = clBuildProgram(program, 1, &device, "-cl-fast-relaxed-math", NULL, NULL);
    if (err != CL_SUCCESS)
    {
      char log[4096] = {0};
      clGetProgramBuildInfo(program, device, CL_PROGRAM_BUILD_LOG, sizeof(log), log, NULL);
      fprintf(stderr, "DHT-NN OpenCL build error:\n%s\n", log);
      clReleaseProgram(program); program = NULL;
      return false;
    }
    kernel = clCreateKernel(program, "dht_nn_refine", &err);
    if (err != CL_SUCCESS)
    {
      clReleaseProgram(program); program = NULL;
      return false;
    }
    /* Cache architecture */
    int n_layers = net.n_hidden + 2;
    free(compiled_sizes);
    compiled_sizes = (int *)malloc((size_t)n_layers * sizeof(int));
    if (!compiled_sizes) /* FIX: was dereferenced unchecked */
    {
      compiled_n_layers = 0;
      release_kernel();
      return false;
    }
    compiled_n_layers = n_layers;
    compiled_sizes[0] = net.input_size;
    for (int i = 0; i < net.n_hidden; i++)
      compiled_sizes[i + 1] = net.hidden[i];
    compiled_sizes[n_layers - 1] = DHTNNNetwork::OUTPUT_SIZE;
    valid = true;
    /* Query preferred work group size multiple for launch tuning */
    preferred_wg_multiple = 0;
    clGetKernelWorkGroupInfo(kernel, device, CL_KERNEL_PREFERRED_WORK_GROUP_SIZE_MULTIPLE,
                             sizeof(preferred_wg_multiple), &preferred_wg_multiple, NULL);
    if (preferred_wg_multiple == 0) preferred_wg_multiple = 64;
    fprintf(stderr, "DHT-NN OpenCL: compiled kernel for arch patch_r=%d n_hidden=%d (input=%d) wg_mult=%zu\n",
            net.patch_r, net.n_hidden, net.input_size, preferred_wg_multiple);
    return true;
  }
  /*
   * Run the NN inference kernel on the GPU.
   * Takes the full DHTNNNetwork by const ref, reads arch + weights.
   * green_out_ptr: host buffer [iwidth*iheight] to receive blended green values.
   * Pixels not processed by the kernel are set to -1.0f.
   * Returns true on success (ndir + green_out updated), false to fall back to CPU.
   */
  bool run_inference(
      float (*nraw_ptr)[3], char *ndir_ptr,
      float *green_out_ptr,
      int nr_height, int nr_width,
      int iwidth, int iheight,
      int nr_topmargin, int nr_leftmargin,
      unsigned int filters,
      float channel_min_g, float channel_max_g,
      const DHTNNNetwork &net)
  {
    ScopedLock guard(*this);
    /* All locals declared up front: the gotos below may not jump over
     * initializations in C++. */
    cl_int err = CL_SUCCESS;
    int nw = 0;
    cl_uint cl_filters = 0;
    cl_float cl_channel_min_g = (cl_float)channel_min_g;
    cl_float cl_channel_max_g = (cl_float)channel_max_g;
    int arg = 0;
    size_t global_size = 0;
    size_t local_size = 0;
    if (!init_device())
      return false;
    if (!build_kernel(net))
      return false;
    if (!upload_weights(net))
      return false;
    size_t nraw_bytes = (size_t)nr_height * nr_width * 3 * sizeof(float);
    size_t ndir_bytes = (size_t)nr_height * nr_width;
    size_t green_bytes = (size_t)iwidth * iheight * sizeof(float);
    /* Reallocate per-image GPU buffers only when size changes */
    if (nraw_bytes != cached_nraw_bytes || ndir_bytes != cached_ndir_bytes
        || green_bytes != cached_green_bytes)
    {
      release_image_bufs();
      d_nraw_cached = clCreateBuffer(context, CL_MEM_READ_ONLY, nraw_bytes, NULL, &err);
      if (err != CL_SUCCESS || !d_nraw_cached)
      {
        fprintf(stderr, "DHT-NN OpenCL: d_nraw buffer creation failed (err=%d)\n", err);
        release_image_bufs(); return false;
      }
      d_ndir_in_cached = clCreateBuffer(context, CL_MEM_READ_ONLY, ndir_bytes, NULL, &err);
      if (err != CL_SUCCESS || !d_ndir_in_cached)
      {
        fprintf(stderr, "DHT-NN OpenCL: d_ndir_in buffer creation failed (err=%d)\n", err);
        release_image_bufs(); return false;
      }
      d_ndir_out_cached = clCreateBuffer(context, CL_MEM_READ_WRITE, ndir_bytes, NULL, &err);
      if (err != CL_SUCCESS || !d_ndir_out_cached)
      {
        fprintf(stderr, "DHT-NN OpenCL: d_ndir_out buffer creation failed (err=%d)\n", err);
        release_image_bufs(); return false;
      }
      d_green_out_cached = clCreateBuffer(context, CL_MEM_WRITE_ONLY, green_bytes, NULL, &err);
      if (err != CL_SUCCESS || !d_green_out_cached)
      {
        fprintf(stderr, "DHT-NN OpenCL: d_green_out buffer creation failed (err=%d)\n", err);
        release_image_bufs(); return false;
      }
      cached_nraw_bytes = nraw_bytes;
      cached_ndir_bytes = ndir_bytes;
      cached_green_bytes = green_bytes;
    }
    /* Upload per-image data. ndir_out is also seeded from ndir_ptr so
     * pixels the kernel skips keep their existing direction bits. */
    err = clEnqueueWriteBuffer(queue, d_nraw_cached, CL_FALSE, 0, nraw_bytes, nraw_ptr, 0, NULL, NULL);
    if (err != CL_SUCCESS) { fprintf(stderr, "DHT-NN OpenCL: nraw upload failed (err=%d)\n", err); return false; }
    err = clEnqueueWriteBuffer(queue, d_ndir_in_cached, CL_FALSE, 0, ndir_bytes, ndir_ptr, 0, NULL, NULL);
    if (err != CL_SUCCESS) { fprintf(stderr, "DHT-NN OpenCL: ndir_in upload failed (err=%d)\n", err); return false; }
    err = clEnqueueWriteBuffer(queue, d_ndir_out_cached, CL_FALSE, 0, ndir_bytes, ndir_ptr, 0, NULL, NULL);
    if (err != CL_SUCCESS) { fprintf(stderr, "DHT-NN OpenCL: ndir_out upload failed (err=%d)\n", err); return false; }
    /* Set kernel arguments: per-image buffers + cached weights.
     * Order must match the signature emitted by generate_cl_kernel_source. */
    nw = net.num_weights();
    cl_filters = (cl_uint)filters;
    arg = 0;
    err = clSetKernelArg(kernel, arg++, sizeof(cl_mem), &d_nraw_cached);
    if (err != CL_SUCCESS) goto setarg_fail;
    err = clSetKernelArg(kernel, arg++, sizeof(cl_mem), &d_ndir_in_cached);
    if (err != CL_SUCCESS) goto setarg_fail;
    err = clSetKernelArg(kernel, arg++, sizeof(cl_mem), &d_ndir_out_cached);
    if (err != CL_SUCCESS) goto setarg_fail;
    for (int l = 0; l < nw; l++)
    {
      err = clSetKernelArg(kernel, arg++, sizeof(cl_mem), &d_W_cached[l]);
      if (err != CL_SUCCESS) goto setarg_fail;
      err = clSetKernelArg(kernel, arg++, sizeof(cl_mem), &d_b_cached[l]);
      if (err != CL_SUCCESS) goto setarg_fail;
    }
    err = clSetKernelArg(kernel, arg++, sizeof(int), &iwidth);
    if (err != CL_SUCCESS) goto setarg_fail;
    err = clSetKernelArg(kernel, arg++, sizeof(int), &iheight);
    if (err != CL_SUCCESS) goto setarg_fail;
    err = clSetKernelArg(kernel, arg++, sizeof(int), &nr_width);
    if (err != CL_SUCCESS) goto setarg_fail;
    err = clSetKernelArg(kernel, arg++, sizeof(int), &nr_topmargin);
    if (err != CL_SUCCESS) goto setarg_fail;
    err = clSetKernelArg(kernel, arg++, sizeof(int), &nr_leftmargin);
    if (err != CL_SUCCESS) goto setarg_fail;
    err = clSetKernelArg(kernel, arg++, sizeof(cl_uint), &cl_filters);
    if (err != CL_SUCCESS) goto setarg_fail;
    err = clSetKernelArg(kernel, arg++, sizeof(cl_mem), &d_green_out_cached);
    if (err != CL_SUCCESS) goto setarg_fail;
    err = clSetKernelArg(kernel, arg++, sizeof(cl_float), &cl_channel_min_g);
    if (err != CL_SUCCESS) goto setarg_fail;
    err = clSetKernelArg(kernel, arg++, sizeof(cl_float), &cl_channel_max_g);
    if (err != CL_SUCCESS) goto setarg_fail;
    /* Launch with local work size hint; the kernel's own idx>=total guard
     * makes rounding global_size up safe. */
    global_size = (size_t)iwidth * iheight;
    local_size = preferred_wg_multiple;
    /* Round global_size up to a multiple of local_size */
    if (local_size > 0 && (global_size % local_size) != 0)
      global_size += local_size - (global_size % local_size);
    err = clEnqueueNDRangeKernel(queue, kernel, 1, NULL, &global_size, &local_size, 0, NULL, NULL);
    if (err != CL_SUCCESS)
    {
      fprintf(stderr, "DHT-NN OpenCL: kernel launch failed (err=%d)\n", err);
      goto cleanup;
    }
    /* Read back (both non-blocking, then explicit sync for DMA overlap) */
    err = clEnqueueReadBuffer(queue, d_ndir_out_cached, CL_FALSE, 0, ndir_bytes, ndir_ptr, 0, NULL, NULL);
    if (err != CL_SUCCESS)
    {
      fprintf(stderr, "DHT-NN OpenCL: ndir readback failed (err=%d)\n", err);
      goto cleanup;
    }
    err = clEnqueueReadBuffer(queue, d_green_out_cached, CL_FALSE, 0, green_bytes, green_out_ptr, 0, NULL, NULL);
    if (err != CL_SUCCESS)
    {
      fprintf(stderr, "DHT-NN OpenCL: green readback failed (err=%d)\n", err);
      goto cleanup;
    }
    err = clFinish(queue);
    if (err != CL_SUCCESS)
      fprintf(stderr, "DHT-NN OpenCL: clFinish failed (err=%d)\n", err);
    goto cleanup;
  setarg_fail:
    fprintf(stderr, "DHT-NN OpenCL: clSetKernelArg failed at arg %d (err=%d)\n", arg - 1, err);
  cleanup:
    return (err == CL_SUCCESS);
  }
};
/* Lazy singleton */
static DHTNNCLContext &get_cl_context()
{
  /* Meyers singleton: constructed on first call, destroyed at program exit.
   * NOTE(review): initialization is race-free only under C++11 "magic
   * statics" — presumably the build targets C++11+; verify, since callers
   * rely on ScopedLock for everything except this first construction. */
  static DHTNNCLContext ctx;
  return ctx;
}
#endif /* USE_OPENCL */
#endif /* DHT_NN_CL_H */
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment