Skip to content

Instantly share code, notes, and snippets.

@castano
Created June 28, 2025 00:37
Show Gist options
  • Save castano/6b64fa41ab2037611630cd5112d82415 to your computer and use it in GitHub Desktop.
Save castano/6b64fa41ab2037611630cd5112d82415 to your computer and use it in GitHub Desktop.
Recursive Implementation of the Gaussian Filter Using Truncated Cosine Functions
// Implements "Recursive Implementation of the Gaussian Filter Using Truncated Cosine Functions" by Charalampidis [2016].
// https://discovery.researcher.life/article/recursive-implementation-of-the-gaussian-filter-using-truncated-cosine-functions/dcf24675f5eb30dba93c5205cdae3c40
// This code is based on:
// https://github.com/cloudinary/ssimulacra2/blob/main/src/lib/jxl/gauss_blur.cc
// Copyright (c) the JPEG XL Project Authors. All rights reserved.
// Coefficients of the three-term recursive Gaussian approximation
// (Charalampidis [2016]). The constructor derives them once for a given sigma;
// Gaussian1D reads them for every sample. Index k of each array belongs to the
// cosine term with frequency (2k+1) * pi / (2 * radius).
struct RecursiveGaussian {
  // Derives every coefficient below for standard deviation `sigma`.
  RecursiveGaussian(float sigma);

  // Per-term weight of the input pair x[n - radius - 1] + x[n + radius - 1].
  float mul_in[3];

  // Per-term weight of that term's previous output, y[n - 1].
  float mul_prev[3];

  // Per-term weight of that term's output two steps back, y[n - 2].
  float mul_prev2[3];

  // The "N" of eq. (57) in the paper, as set by the constructor.
  size_t radius;
};
// Precomputes the IIR coefficients of the recursive Gaussian approximation for
// standard deviation `sigma` (Charalampidis [2016]; equation numbers below refer
// to that paper). Three cosine terms k = {1, 3, 5} with frequencies
// omega_k = k * pi / (2N) are combined. Their weights beta solve the 3x3 system
// (53)/(56) so that the summed filter is normalized (39) and matches sigma.
//
// Precondition (enforced with JXL_CHECK): N = round(3.2795 * sigma + 0.2546)
// must be at least 2, i.e. sigma >= ~0.38. N == 0 makes omega infinite. For
// N == 1 the p and r rows of the (56) matrix are both (1, 1, 1), so the system
// has no unique solution.
RecursiveGaussian::RecursiveGaussian(float sigma) {
  constexpr double kPi = 3.141592653589793238;

  // (57), "N". Round in double precision; roundf() would first narrow this
  // double expression to float.
  const double N = round(3.2795 * sigma + 0.2546);
  JXL_CHECK(N >= 2.0);

  // Square sigma in double. A plain `sigma * sigma` is a float product, which
  // would lose precision before entering the double-precision solve below.
  const double sigma2 = static_cast<double>(sigma) * sigma;

  // Table I, first row
  const double pi_div_2N = kPi / (2.0 * N);
  const double omega[3] = {pi_div_2N, 3.0 * pi_div_2N, 5.0 * pi_div_2N};

  // (37), k={1,3,5}
  const double p_1 = +1.0 / tan(0.5 * omega[0]);
  const double p_3 = -1.0 / tan(0.5 * omega[1]);
  const double p_5 = +1.0 / tan(0.5 * omega[2]);

  // (44), k={1,3,5}
  const double r_1 = +p_1 * p_1 / sin(omega[0]);
  const double r_3 = -p_3 * p_3 / sin(omega[1]);
  const double r_5 = +p_5 * p_5 / sin(omega[2]);

  // (50), k={1,3,5}
  const double neg_half_sigma2 = -0.5 * sigma2;
  const double recip_N = 1.0 / N;
  double rho[3];
  for (size_t i = 0; i < 3; ++i) {
    rho[i] = exp(neg_half_sigma2 * omega[i] * omega[i]) * recip_N;
  }

  // second part of (52), k1,k2 = 1,3; 3,5; 5,1
  const double D_13 = p_1 * r_3 - r_1 * p_3;
  const double D_35 = p_3 * r_5 - r_3 * p_5;
  const double D_51 = p_5 * r_1 - r_5 * p_1;

  // (52), k=5
  const double recip_d13 = 1.0 / D_13;
  const double zeta_15 = D_35 * recip_d13;
  const double zeta_35 = D_51 * recip_d13;

  double A[9] = {p_1,     p_3,     p_5,  //
                 r_1,     r_3,     r_5,  // (56)
                 zeta_15, zeta_35, 1};
  // NOTE(review): Inv3x3Matrix (JPEG XL helper) presumably replaces A with its
  // inverse in place, since MatMul below solves (53) with it. Confirm.
  JXL_CHECK(Inv3x3Matrix(A));

  const double gamma[3] = {1, N * N - sigma2,  // (55)
                           zeta_15 * rho[0] + zeta_35 * rho[1] + rho[2]};
  double beta[3];
  MatMul(A, gamma, 3, 3, 1, beta);  // (53)

  // Sanity check: correctly solved for beta (IIR filter weights are normalized).
  // Use fabs: an unqualified abs() may bind to the C `int abs(int)`, which
  // truncates the difference to 0 and would make this check always pass.
  const double sum = beta[0] * p_1 + beta[1] * p_3 + beta[2] * p_5;  // (39)
  JXL_CHECK(fabs(sum - 1) < 1E-12);

  radius = static_cast<size_t>(N);

  for (size_t i = 0; i < 3; ++i) {
    const double n2 = -beta[i] * cos(omega[i] * (N + 1.0));  // (33)
    const double d1 = -2.0 * cos(omega[i]);                  // (33)
    // Stored as float: Gaussian1D runs its recursion in single precision.
    mul_in[i] = static_cast<float>(n2);
    mul_prev[i] = static_cast<float>(-d1);
    mul_prev2[i] = -1.0f;
  }
}
static void Gaussian1D(const RecursiveGaussian& rg, const float* JXL_RESTRICT in, intptr_t width, intptr_t stride, float* JXL_RESTRICT out) {
const float mul_in_1 = rg.mul_in[0];
const float mul_in_3 = rg.mul_in[1];
const float mul_in_5 = rg.mul_in[2];
const float mul_prev_1 = rg.mul_prev[0];
const float mul_prev_3 = rg.mul_prev[1];
const float mul_prev_5 = rg.mul_prev[2];
const float mul_prev2_1 = rg.mul_prev2[0];
const float mul_prev2_3 = rg.mul_prev2[1];
const float mul_prev2_5 = rg.mul_prev2[2];
float prev_1 = 0.0f;
float prev_3 = 0.0f;
float prev_5 = 0.0f;
float prev2_1 = 0.0f;
float prev2_3 = 0.0f;
float prev2_5 = 0.0f;
const intptr_t N = rg.radius;
for (intptr_t n = -N + 1; n < width; n++) {
const intptr_t left = n - N - 1;
const intptr_t right = n + N - 1;
const float left_val = left >= 0 ? in[left * stride] : 0.0f;
const float right_val = right < width ? in[right * stride] : 0.0f;
const float sum = left_val + right_val;
float out_1 = sum * mul_in_1 + mul_prev2_1 * prev2_1 + mul_prev_1 * prev_1;
float out_3 = sum * mul_in_3 + mul_prev2_3 * prev2_3 + mul_prev_3 * prev_3;
float out_5 = sum * mul_in_5 + mul_prev2_5 * prev2_5 + mul_prev_5 * prev_5;
prev2_1 = prev_1; prev_1 = out_1;
prev2_3 = prev_3; prev_3 = out_3;
prev2_5 = prev_5; prev_5 = out_5;
if (n >= 0) {
out[n * stride] = out_1 + out_3 + out_5;
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment