Last active
May 15, 2024 13:06
-
-
Save madmann91/2ae76df7b49cdd5a11c0f94c1aeb9ee6 to your computer and use it in GitHub Desktop.
Solver for quartic and cubic polynomials
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdio.h> | |
#include <tgmath.h> | |
#include <stdint.h> | |
#include <stdbool.h> | |
#include <assert.h> | |
#include <string.h> | |
#define PI 3.141592653589793f | |
#define SEARCH_RADIUS 1.e-5f // Initial search radius for the local search procedure | |
static inline float sq(float x) {
    // Returns x squared.
    const float squared = x * x;
    return squared;
}
static inline float cb(float x) {
    // Returns x cubed, evaluated as (x * x) * x.
    const float squared = x * x;
    return squared * x;
}
static inline int solve_linear(float a0, float* z) {
    // Solves the monic linear equation x + a0 = 0.
    // The single root is stored in z[0]; the return value is the root count (always 1).
    *z = -a0;
    return 1;
}
static inline int solve_quadratic(float a1, float a0, float* z) { | |
// Solves x^2 + a1 * x + a0 = 0 | |
const float d = sq(a1) - 4.f * a0; | |
if (d < 0.f) | |
return 0; | |
const float s = sqrt(d); | |
z[0] = 0.5f * (-a1 + s); | |
z[1] = 0.5f * (-a1 - s); | |
return d == 0.f ? 1 : 2; | |
} | |
static inline int solve_cubic(float a2, float a1, float a0, float* z) { | |
// Solves x^3 + a2 * x^2 + a1 * x + a0 = 0 | |
// Inspired from "Practical Algorithm for Solving the Cubic Equation", D. J. Wolters, 2021 | |
const float q = (3.f * a1 - sq(a2)) / 9.f; | |
const float r = (9.f * a1 * a2 - 27.f * a0 - 2.f * cb(a2)) / 54.f; | |
if (sq(r) + cb(q) > 0.f) { | |
const float a = cbrt(fabs(r) + sqrt(sq(r) + cb(q))); | |
const float t = a - q / a; | |
const float t1 = r < 0 ? -t : t; | |
z[0] = t1 - a2 / 3.f; | |
return 1; | |
} | |
const float theta = q == 0.f ? 0.f : acos(r / sqrt(cb(-q))); | |
const float phi1 = theta / 3.f; | |
const float phi2 = phi1 - 2.f * PI / 3.f; | |
const float phi3 = phi1 + 2.f * PI / 3.f; | |
const float k1 = 2.f * sqrt(-q); | |
const float k2 = a2 / 3.f; | |
z[0] = k1 * cos(phi1) - k2; | |
z[1] = k1 * cos(phi2) - k2; | |
z[2] = k1 * cos(phi3) - k2; | |
return 3; | |
} | |
static inline float find_largest_cubic_root(float a2, float a1, float a0) {
    // Returns the largest real root of x^3 + a2 * x^2 + a1 * x + a0 = 0 as reported by
    // solve_cubic. The running maximum starts at 0, so the result is never negative.
    float roots[3];
    const int count = solve_cubic(a2, a1, a0, roots);
    float largest = 0.f;
    for (int i = 0; i < count && i < 3; ++i)
        largest = fmax(largest, roots[i]);
    return largest;
}
static inline int solve_quartic(float a3, float a2, float a1, float a0, float* z) {
    // Solves x^4 + a3 * x^3 + a2 * x^2 + a1 * x + a0 = 0
    // Inspired from "Practical Algorithms for Solving the Quartic Equation", D. J. Wolters, 2020
    //
    // Returns the number of real roots found (0, 2 or 4) and stores them in z[0..count-1].
    // z must have room for 4 values.

    // Substituting x = y - c removes the cubic term: y^4 + b2 * y^2 + b1 * y + b0 = 0.
    const float c = a3 * 0.25f;
    const float b2 = a2 - 6.f * sq(c);
    const float b1 = a1 + c * (-2.f * a2 + 8.f * sq(c));
    const float b0 = a0 + c * (-a1 + c * (a2 - 3.f * sq(c)));

    // Largest (non-negative) root of the resolvent cubic.
    const float m = find_largest_cubic_root(b2, sq(b2) * 0.25f - b0, sq(b1) * -0.125f);

    // The radicand is mathematically >= 0, but it can round to a tiny negative value (it is
    // exactly 0 for many biquadratics). Without the clamp, sqrt would return NaN and every root
    // would be silently dropped.
    const float radicand = sq(m) + b2 * m + sq(b2) * 0.25f - b0;
    const float r1 = sqrt(fmax(radicand, 0.f));
    const float r = b1 > 0.f ? r1 : -r1;
    const float l = sqrt(m * 0.5f);
    const float k = m * -0.5f - b2 * 0.5f;

    // Each quadratic factor contributes two real roots when its own discriminant is >= 0.
    const bool has_k12 = k - r >= 0.f;
    const bool has_k34 = k + r >= 0.f;

    int count = 0;
    if (has_k12) {
        const float k12 = sqrt(k - r);
        z[count++] = l - c + k12;
        z[count++] = l - c - k12;
    }
    if (has_k34) {
        const float k34 = sqrt(k + r);
        z[count++] = -l - c + k34;
        z[count++] = -l - c - k34;
    }
    return count;
}
float eval(float a4, float a3, float a2, float a1, float a0, float x) {
    // Evaluates a4 * x^4 + a3 * x^3 + a2 * x^2 + a1 * x + a0 with Horner's scheme.
    const float lower[4] = { a3, a2, a1, a0 };
    float acc = a4;
    for (int i = 0; i < 4; ++i)
        acc = lower[i] + x * acc;
    return acc;
}
float eval_diff(float a4, float a3, float a2, float a1, float x) {
    // Evaluates the derivative of a4 * x^4 + a3 * x^3 + a2 * x^2 + a1 * x + a0, that is
    // 4 * a4 * x^3 + 3 * a3 * x^2 + 2 * a2 * x + a1, with Horner's scheme.
    float acc = 4.f * a4;
    acc = 3.f * a3 + x * acc;
    acc = 2.f * a2 + x * acc;
    return a1 + x * acc;
}
float local_search(float a4, float a3, float a2, float a1, float a0, float x, size_t max_iters) { | |
// Finds a zero of q(x) = a4 * x^4 + a3 * x^3 + a2 * x^2 + a1 * x^1 + a0 using a local search | |
// starting at the given point. This can be used to improve the quality of an initial estimate. | |
// Find initial bracket around x such that the signs of q(a) and q(b) are different. This | |
// interval is guaranteed to contain a zero for q(x). | |
float qx = eval(a4, a3, a2, a1, a0, x); | |
float a, b, qa, qb; | |
for (float radius = SEARCH_RADIUS;; radius *= 2.f) { | |
a = x - radius; | |
b = x + radius; | |
qa = eval(a4, a3, a2, a1, a0, a); | |
qb = eval(a4, a3, a2, a1, a0, b); | |
// Move x to the point closest to 0 | |
if (fabs(qa) < fabs(qx)) | |
x = a, qx = qa; | |
if (fabs(qb) < fabs(qx)) | |
x = b, qx = qb; | |
if (signbit(qa) != signbit(qb)) | |
break; | |
} | |
// Tighten the bracket around 0: Use [x, b] or [a, x] instead of [a, b]. | |
if (signbit(qa) == signbit(qx)) | |
a = x, qa = qx; | |
else | |
b = x, qb = qx; | |
// Apply several rounds of bisection or Newton-Raphson, whichever is best | |
for (size_t iters = 0; iters < max_iters; ++iters) { | |
float m = fabs(qa) < fabs(qb) | |
? a - qa / eval_diff(a4, a3, a2, a1, a) | |
: b - qb / eval_diff(a4, a3, a2, a1, b); | |
// Use bisection if Newton-Raphson takes us outside the interval | |
if (m < a || m > b) | |
m = (a + b) / 2; | |
const float qm = eval(a4, a3, a2, a1, a0, m); | |
if (signbit(qm) == signbit(qa)) | |
a = m, qa = qm; | |
else | |
b = m, qb = qm; | |
// Pick the value closest to 0 | |
if (fabs(qm) < fabs(qx)) | |
x = m, qx = qm; | |
} | |
return x; | |
} | |
int solve(float a4, float a3, float a2, float a1, float a0, float* z) {
    // Finds the real roots of the polynomial a4 * x^4 + a3 * x^3 + a2 * x^2 + a1 * x + a0
    // Note: a4, a3, a2, a1 and a0 can each be zero
    //
    // The polynomial is divided by its leading non-zero coefficient and handed to the monic
    // solver of matching degree. Returns the number of roots written to z, or 0 when
    // a4, a3, a2 and a1 are all zero.
    if (a4 != 0.f) {
        const float scale = 1.f / a4;
        return solve_quartic(a3 * scale, a2 * scale, a1 * scale, a0 * scale, z);
    }
    if (a3 != 0.f) {
        const float scale = 1.f / a3;
        return solve_cubic(a2 * scale, a1 * scale, a0 * scale, z);
    }
    if (a2 != 0.f) {
        const float scale = 1.f / a2;
        return solve_quadratic(a1 * scale, a0 * scale, z);
    }
    if (a1 != 0.f)
        return solve_linear(a0 / a1, z);
    return 0;
}
static inline float next_ulp(float x) {
    // Returns the smallest float strictly greater than x.
    // NaN and +infinity are returned unchanged; -0 and +0 both map to the smallest positive
    // subnormal.
    uint32_t bits;
    memcpy(&bits, &x, sizeof(bits));
    if (x != x || bits == UINT32_C(0x7F800000))
        return x;
    if (bits == UINT32_C(0x80000000))
        bits = 0; // Treat -0 as +0 so that the step crosses zero correctly
    // The bit pattern is sign-magnitude: moving towards +infinity means growing the magnitude
    // of a positive value but shrinking the magnitude of a negative one.
    if (bits & UINT32_C(0x80000000))
        bits--;
    else
        bits++;
    memcpy(&x, &bits, sizeof(x));
    return x;
}
static inline float prev_ulp(float x) {
    // Returns the largest float strictly smaller than x.
    // NaN and -infinity are returned unchanged; +0 and -0 both map to the negative subnormal of
    // smallest magnitude.
    uint32_t bits;
    memcpy(&bits, &x, sizeof(bits));
    if (x != x || bits == UINT32_C(0xFF800000))
        return x;
    if (bits == 0)
        bits = UINT32_C(0x80000000); // Treat +0 as -0 so that the step crosses zero correctly
    // The bit pattern is sign-magnitude: moving towards -infinity means growing the magnitude
    // of a negative value but shrinking the magnitude of a positive one.
    if (bits & UINT32_C(0x80000000))
        bits++;
    else
        bits--;
    memcpy(&x, &bits, sizeof(x));
    return x;
}
static inline complex float eval_complex(complex float a3, complex float a2, complex float a1, complex float a0, complex float x) {
    // Evaluates the monic quartic x^4 + a3 * x^3 + a2 * x^2 + a1 * x + a0 at a complex point
    // with Horner's scheme.
    complex float acc = a3 + x;
    acc = a2 + x * acc;
    acc = a1 + x * acc;
    return a0 + x * acc;
}
void solve_complex(complex float p[], complex float z[], size_t n, size_t iters) {
    // Solve sum(p[i] * x^i, i = 0..n) = 0 with Aberth-Ehrlich
    //
    // p holds the n + 1 coefficients in increasing degree order (p[n] is the leading one),
    // z receives the n root estimates, and iters is the number of sweeps over all estimates.
    static const complex float b = 0.4f + 0.9f * I;
    assert(n > 0);

    // Initial estimates: successive powers of b.
    z[0] = 1.f;
    for (size_t i = 1; i < n; ++i)
        z[i] = z[i - 1] * b;

    for (size_t i = 0; i < iters; ++i) {
        for (size_t j = 0; j < n; ++j) {
            const complex float x = z[j];
            // Horner evaluation of the polynomial (y) and of its derivative (d) at x.
            complex float y = p[n];
            complex float d = p[n] * (float)n;
            for (size_t k = n - 1; k > 0; --k) {
                y = y * x + p[k];
                d = d * x + p[k] * (float)k;
            }
            y = y * x + p[0];
            // The estimate is already an exact zero: its correction is zero, and dividing by y
            // would only work if the platform recovers from a complex division by zero.
            if (y == 0.f)
                continue;
            // Aberth correction: 1 / (p'(x) / p(x) - sum over the other estimates of
            // 1 / (x - z[k])).
            complex float denom = d / y;
            for (size_t k = 0; k < n; ++k) {
                if (k == j)
                    continue;
                denom -= 1.f / (z[j] - z[k]);
            }
            z[j] -= 1.f / denom;
        }
    }
}
int main(void) {
    // Demo driver: solves x^4 + x^3 - 1000 x^2 + 4 x - 40 = 0 with the complex solver and with
    // the closed-form real solver, polishes the real roots with local_search, and reports
    // whether each polished root is the best float among its two immediate neighbours.
#define COEFFS 1.f, -1000.f, 4.f, -40.f
    complex float zc[4] = {0};
    complex float pc[5] = { 1.f, COEFFS };
    // solve_complex expects coefficients in increasing degree order, so reverse the array.
    for (size_t i = 0, count = sizeof(pc) / sizeof(pc[0]); i < count / 2; ++i) {
        complex float p = pc[i];
        pc[i] = pc[count - i - 1];
        pc[count - i - 1] = p;
    }
    solve_complex(pc, zc, 4, 8);
    for (size_t i = 0; i < 4; ++i) {
        complex float y = eval_complex(COEFFS, zc[i]);
        printf("p(%f + %fi) = %f + %fi\n", creal(zc[i]), cimag(zc[i]), creal(y), cimag(y));
    }
    printf("----\n");

    float z[4] = {0};
    const int n = solve(1.f, COEFFS, z);
    for (int i = 0; i < n; ++i) {
        z[i] = local_search(1.f, COEFFS, z[i], 4);
        printf("p(%f) = %f\n", z[i], eval(1.f, COEFFS, z[i]));
    }
    printf("----\n");

    for (int i = 0; i < n; ++i) {
        // Compare the residual at the root with the residuals at the adjacent floats.
        float ref = fabs(eval(1.f, COEFFS, z[i]));
        float next = fabs(eval(1.f, COEFFS, next_ulp(z[i])));
        float prev = fabs(eval(1.f, COEFFS, prev_ulp(z[i])));
        printf("|p(%f)| = %f", z[i], ref);
        if (next >= ref && prev >= ref)
            printf(" IS OPTIMAL\n");
        else {
            printf(" IS SUB-OPTIMAL:\n");
            if (prev < ref)
                printf("|p(%f)| = %f\n", prev_ulp(z[i]), prev);
            if (next < ref)
                printf("|p(%f)| = %f\n", next_ulp(z[i]), next);
        }
    }
    return 0;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment