Created
March 14, 2025 05:11
-
-
Save malfet/e7785e4b572e7740887a83a2386ef769 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Fails with Error Domain=AGXMetalG14X Code=3 "Compiler encountered an internal error" on M1/M2 (macOS 15.3.1).
// Works on M4 (and possibly M3).
// Metal Shading Language source compiled by this repro.
//
// Keep the shader logic exactly as it is. This is a minimized reproducer for
// an AGX compiler internal error, and restructuring the code (even in
// mathematically equivalent ways) may stop the crash from reproducing.
// Only the comments inside the string are safe to edit.
//
//   bessel_j0_forward(x) - approximates J0(x), Bessel function of the first kind, order 0.
//   bessel_y0_forward(x) - approximates Y0(x), Bessel function of the second kind, order 0.
//   bessel_y0            - kernel: output[i] = bessel_y0_forward(float(input[i])),
//                          reading a bool input buffer.
let shader_source = """
template <typename T>
float bessel_j0_forward(T x) {
  // Rational-approximation coefficients. They appear to match Cephes j0.c
  // (TODO confirm). RP/RQ are used for x <= 5; PP/PQ and QP/QQ for x > 5.
  constexpr float PP[] = {
      +7.96936729297347051624e-04,
      +8.28352392107440799803e-02,
      +1.23953371646414299388e+00,
      +5.44725003058768775090e+00,
      +8.74716500199817011941e+00,
      +5.30324038235394892183e+00,
      +9.99999999999999997821e-01,
  };
  constexpr float PQ[] = {
      +9.24408810558863637013e-04,
      +8.56288474354474431428e-02,
      +1.25352743901058953537e+00,
      +5.47097740330417105182e+00,
      +8.76190883237069594232e+00,
      +5.30605288235394617618e+00,
      +1.00000000000000000218e+00,
  };
  constexpr float QP[] = {
      -1.13663838898469149931e-02,
      -1.28252718670509318512e+00,
      -1.95539544257735972385e+01,
      -9.32060152123768231369e+01,
      -1.77681167980488050595e+02,
      -1.47077505154951170175e+02,
      -5.14105326766599330220e+01,
      -6.05014350600728481186e+00,
  };
  constexpr float QQ[] = {
      +6.43178256118178023184e+01,
      +8.56430025976980587198e+02,
      +3.88240183605401609683e+03,
      +7.24046774195652478189e+03,
      +5.93072701187316984827e+03,
      +2.06209331660327847417e+03,
      +2.42005740240291393179e+02,
  };
  constexpr float RP[] = {
      -4.79443220978201773821e+09,
      +1.95617491946556577543e+12,
      -2.49248344360967716204e+14,
      +9.70862251047306323952e+15,
  };
  constexpr float RQ[] = {
      +4.99563147152651017219e+02,
      +1.73785401676374683123e+05,
      +4.84409658339962045305e+07,
      +1.11855537045356834862e+10,
      +2.11277520115489217587e+12,
      +3.10518229857422583814e+14,
      +3.18121955943204943306e+16,
      +1.71086294081043136091e+18,
  };
  // J0 is even, so work with abs(x).
  if (x < T(0)) {
    x = -x;
  }
  if (x <= T(5.0)) {
    // Tiny x: first two terms of the power series 1 - x^2/4 + ...
    if (x < T(0.00001)) {
      return 1.0 - x * x / 4.0;
    }
    // Horner evaluation of polynomials in z = x * x.
    float rp = 0.0;
    for (auto index = 0; index <= 3; index++) {
      rp = rp * (x * x) + RP[index];
    }
    // NOTE(review): Cephes uses p1evl(z, RQ, 8) here, i.e. an implied leading
    // coefficient of 1.0, which would mean starting rq at 1.0. Starting at 0.0
    // drops the z^8 term. Left unchanged so the repro stays intact.
    float rq = 0.0;
    for (auto index = 0; index <= 7; index++) {
      rq = rq * (x * x) + RQ[index];
    }
    // The two constants appear to be the squares of the first two zeros of J0
    // (2.4048^2, 5.5201^2).
    return (x * x - 5.78318596294678452118e+00) *
        (x * x - T(3.04712623436620863991e+01)) * rp / rq;
  }
  // x > 5: asymptotic form
  //   J0(x) ~ sqrt(2/pi) / sqrt(x) * (P cos(x - pi/4) - (5/x) Q sin(x - pi/4)),
  // with P = pp/pq and Q = qp/qq evaluated in 25 / (x * x).
  float pp = 0.0;
  for (auto index = 0; index <= 6; index++) {
    pp = pp * (25.0 / (x * x)) + PP[index];
  }
  float pq = 0.0;
  for (auto index = 0; index <= 6; index++) {
    pq = pq * (25.0 / (x * x)) + PQ[index];
  }
  float qp = 0.0;
  for (auto index = 0; index <= 7; index++) {
    qp = qp * (25.0 / (x * x)) + QP[index];
  }
  // NOTE(review): Cephes uses p1evl(q, QQ, 7) (implied leading 1.0), so qq
  // would start at 1.0. Left unchanged so the repro stays intact.
  float qq = 0.0;
  for (auto index = 0; index <= 6; index++) {
    qq = qq * (25.0 / (x * x)) + QQ[index];
  }
  // 0.7853981633... = pi/4 and 0.7978845608... = sqrt(2/pi).
  return (pp / pq *
              ::metal::precise::cos(
                  x - T(0.785398163397448309615660845819875721)) -
          5.0 / x * (qp / qq) *
              ::metal::precise::sin(
                  x - 0.785398163397448309615660845819875721)) *
      0.797884560802865355879892119868763737 / ::metal::precise::sqrt(x);
} // bessel_j0_forward(T x)
template <typename T>
float bessel_y0_forward(T x) {
  // Coefficients for x > 5 (identical to those in bessel_j0_forward).
  constexpr float PP[] = {
      +7.96936729297347051624e-04,
      +8.28352392107440799803e-02,
      +1.23953371646414299388e+00,
      +5.44725003058768775090e+00,
      +8.74716500199817011941e+00,
      +5.30324038235394892183e+00,
      +9.99999999999999997821e-01,
  };
  constexpr float PQ[] = {
      +9.24408810558863637013e-04,
      +8.56288474354474431428e-02,
      +1.25352743901058953537e+00,
      +5.47097740330417105182e+00,
      +8.76190883237069594232e+00,
      +5.30605288235394617618e+00,
      +1.00000000000000000218e+00,
  };
  constexpr float QP[] = {
      -1.13663838898469149931e-02,
      -1.28252718670509318512e+00,
      -1.95539544257735972385e+01,
      -9.32060152123768231369e+01,
      -1.77681167980488050595e+02,
      -1.47077505154951170175e+02,
      -5.14105326766599330220e+01,
      -6.05014350600728481186e+00,
  };
  constexpr float QQ[] = {
      +6.43178256118178023184e+01,
      +8.56430025976980587198e+02,
      +3.88240183605401609683e+03,
      +7.24046774195652478189e+03,
      +5.93072701187316984827e+03,
      +2.06209331660327847417e+03,
      +2.42005740240291393179e+02,
  };
  // Coefficients for 0 < x <= 5.
  constexpr float YP[] = {
      +1.55924367855235737965e+04,
      -1.46639295903971606143e+07,
      +5.43526477051876500413e+09,
      -9.82136065717911466409e+11,
      +8.75906394395366999549e+13,
      -3.46628303384729719441e+15,
      +4.42733268572569800351e+16,
      -1.84950800436986690637e+16,
  };
  constexpr float YQ[] = {
      +1.04128353664259848412e+03,
      +6.26107330137134956842e+05,
      +2.68919633393814121987e+08,
      +8.64002487103935000337e+10,
      +2.02979612750105546709e+13,
      +3.17157752842975028269e+15,
      +2.50596256172653059228e+17,
  };
  if (x <= T(5.0)) {
    // Y0 has a log singularity at 0 and is undefined for negative x.
    if (x == T(0.0)) {
      return -INFINITY;
    }
    if (x < T(0.0)) {
      return NAN;
    }
    float yp = 0.0;
    for (auto index = 0; index <= 7; index++) {
      yp = yp * (x * x) + YP[index];
    }
    // NOTE(review): Cephes uses p1evl(z, YQ, 7) (implied leading 1.0), so yq
    // would start at 1.0. Left unchanged so the repro stays intact.
    float yq = 0.0;
    for (auto index = 0; index <= 6; index++) {
      yq = yq * (x * x) + YQ[index];
    }
    // Y0(x) = rational(x^2) + (2/pi) * log(x) * J0(x); 0.6366197723... = 2/pi.
    return yp / yq +
        (0.636619772367581343075535053490057448 * ::metal::precise::log(x) *
         bessel_j0_forward(x));
  }
  // x > 5: asymptotic form
  //   Y0(x) ~ sqrt(2/pi) / sqrt(x) * (P sin(x - pi/4) + (5/x) Q cos(x - pi/4)).
  float pp = 0.0;
  for (auto index = 0; index <= 6; index++) {
    pp = pp * (25.0 / (x * x)) + PP[index];
  }
  float pq = 0.0;
  for (auto index = 0; index <= 6; index++) {
    pq = pq * (25.0 / (x * x)) + PQ[index];
  }
  float qp = 0.0;
  for (auto index = 0; index <= 7; index++) {
    qp = qp * (25.0 / (x * x)) + QP[index];
  }
  // NOTE(review): same implied-leading-1.0 issue as in bessel_j0_forward.
  float qq = 0.0;
  for (auto index = 0; index <= 6; index++) {
    qq = qq * (25.0 / (x * x)) + QQ[index];
  }
  return (pp / pq *
              ::metal::precise::sin(
                  x - 0.785398163397448309615660845819875721) +
          5.0 / x * (qp / qq) *
              ::metal::precise::cos(
                  x - 0.785398163397448309615660845819875721)) *
      0.797884560802865355879892119868763737 / ::metal::precise::sqrt(x);
} // bessel_y0_forward(T x)
// Entry point. `input` is bool, so x is 0.0 or 1.0 here. The bool input type is
// presumably part of the minimized repro (TODO confirm before changing it).
// There is no bounds check on `index`, so this assumes the grid matches the buffers.
kernel void bessel_y0(
    device float* output [[buffer(0)]],
    constant bool* input [[buffer(1)]],
    uint index [[thread_position_in_grid]]) {
  output[index] = bessel_y0_forward(float(input[index]));
}
"""
import Metal | |
// Host side: pick the GPU, compile `shader_source`, look up the `bessel_y0`
// kernel and build a compute pipeline for it. On affected GPUs one of the
// `try!` calls aborts the script, and the runtime prints the thrown NSError
// (e.g. the AGXMetalG14X "Compiler encountered an internal error").
guard let device = MTLCreateSystemDefaultDevice() else { fatalError("No Metal device found") }
// Print the GPU name so reports show which hardware was tested.
print(device.name)
// Compile the MSL source with default options.
let library = try! device.makeLibrary(source: shader_source, options: MTLCompileOptions())
guard let mfunc = library.makeFunction(name: "bessel_y0") else { fatalError("Can't find function") }
// Pipeline creation is presumably where the GPU-specific backend compile runs
// (TODO confirm which call throws on M1/M2).
let state = try! device.makeComputePipelineState(function: mfunc)
// Reaching this line means the compiler bug did not reproduce on this GPU.
print("Compiled bessel_y0 pipeline (maxTotalThreadsPerThreadgroup = \(state.maxTotalThreadsPerThreadgroup))")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment