Skip to content

Instantly share code, notes, and snippets.

@cmpute
Created June 12, 2020 20:29
Show Gist options
  • Save cmpute/f80182a18ed6cd836c20cf67d73209db to your computer and use it in GitHub Desktop.
Save cmpute/f80182a18ed6cd836c20cf67d73209db to your computer and use it in GitHub Desktop.
Chebyshev estimation of modified Bessel function
# This script calculates the Chebyshev approximation of Bessel functions
#
# Chebyshev approximation intro:
# - https://en.wikipedia.org/wiki/Approximation_theory#Chebyshev_approximation
# - https://mathworld.wolfram.com/ChebyshevApproximationFormula.html
# - http://www.chebfun.org/examples/cheb/ExactChebCoeffs.html
# - https://www.eeeguide.com/chebyshev-approximation/
#
# Online modified Bessel function calculator
# - https://www.wolframalpha.com/input/?i=BesselI%5B0%2Cx%5D
# - https://keisan.casio.com/exec/system/1180573473
#
import numpy as np
import mpmath as mp
import scipy.special as sps
from matplotlib import pyplot as plt
from fractions import Fraction
# configurations
# NOTE: `mp` is the mpmath *module*; the working precision is a property of
# its context object `mp.mp`. Assigning `mp.dps` would only create an unused
# module attribute and silently leave the precision at its default.
mp.mp.dps = 50  # mpmath working precision, in decimal digits
split = 8  # location to split the function
xval = np.linspace(0, 30, 500)  # evaluation locations
def i0(x):
    """Return the 0-order modified Bessel function of the first kind, I0(x).

    Mathematically this is the series sum_{k>=0} (x^2/4)^k / (k!)^2.  The
    value is delegated to mpmath's own implementation instead of summing
    that infinite series numerically, because the right-hand Chebyshev fit
    evaluates I0 at very large arguments (``split**2 / x`` for ``x`` close
    to 0) where generic series summation converges poorly.

    Args:
        x: evaluation point (anything mpmath can convert to a number).

    Returns:
        I0(x) as an mpmath number at the current working precision.
    """
    return mp.besseli(0, x)
def di0(x):
    """Return the first-order derivative of ``i0`` evaluated at ``x``.

    The derivative is obtained by numerical differentiation (``mp.diff``).
    """
    derivative = mp.diff(i0, x)
    return derivative
def chebycoeff(f, a, b, j, N):
    """Return the ``j``-th Chebyshev coefficient of ``f`` on ``[a, b]``.

    Ported from ``mpmath.chebyfit``.  ``f`` is sampled at ``N`` points
    ``cos(pi*(k - 1/2)/N)``, k = 1..N, mapped from [-1, 1] onto [a, b], and
    the coefficient is ``2/N * sum_k f(x_k) * cos(pi*j*(k - 1/2)/N)``.

    Args:
        f: callable taking and returning an mpmath number.
        a: lower end of the approximation interval.
        b: upper end of the approximation interval.
        j: index of the coefficient to compute.
        N: number of sample points.

    Returns:
        The coefficient as an mpmath number.
    """
    half = mp.mpf(0.5)
    total = mp.mpf(0)
    for k in range(1, N + 1):
        offset = k - half
        # sample point on [-1, 1], stretched and shifted onto [a, b]
        node = mp.cospi(offset / N) * (b - a) * half + (b + a) * half
        total += f(node) * mp.cospi(j * offset / N)
    return 2 * total / N
def chebyval(coeffs, x, a, b):
    """Evaluate a Chebyshev series on ``[a, b]`` at ``x``.

    The series is ``c_0/2 + sum_{k>=1} c_k * T_k(t)`` where ``t`` is ``x``
    mapped linearly from [a, b] to [-1, 1].  It is evaluated with the
    Clenshaw recurrence.

    Args:
        coeffs: coefficients ordered from the highest degree c_N down to c_0.
        x: evaluation point(s); scalars and array-likes supporting
            arithmetic both work.
        a: lower end of the interval.
        b: upper end of the interval.

    Returns:
        The series value at ``x`` (0.0 for an empty coefficient list).
    """
    # twice the mapped coordinate: the recurrence below needs 2*t
    twice_t = (2*x - a - b) / (b - a) * 2
    current = previous = before_previous = 0
    for coeff in coeffs:
        before_previous = previous
        previous = current
        current = twice_t * previous - before_previous + coeff
    return (current - before_previous) / 2
def i0_cheby_est(NL=30, NR=25):
    """Estimate i0 using Chebyshev expansions, with i0 split into two parts.

    Left part (x <= split): fits ``i0(x) * exp(-x)`` on [0, split].
    Right part (x >= split): fits ``i0(n2/y) * exp(-n2/y) * sqrt(n2/y)`` in
    the variable ``y = n2/x`` on [0, split], where ``n2 = split**2``.

    For each part the coefficients are converted to floats, the estimate is
    compared against ``scipy.special.i0`` on the matching samples of ``xval``
    and the mean absolute error is printed.

    Args:
        NL: number of Chebyshev coefficients for the left part.
        NR: number of Chebyshev coefficients for the right part.
    """
    lower, upper = mp.mpf(0), mp.mpf(split)

    # left part
    def fleft(x):
        return i0(x) * mp.exp(-x)

    dleft = [chebycoeff(fleft, lower, upper, k, NL) for k in range(NL)]
    d = [float(di) for di in dleft[::-1]]  # chebyval expects c_N ... c_0
    # use `split` (not a literal) so the samples always match the fit interval
    xleft = xval[xval <= split]
    gt = sps.i0(xleft)
    est = chebyval(d, xleft, 0, split) * np.exp(xleft)
    err = np.abs(gt - est)
    print("Mean err (left part):", np.mean(err))

    # right part
    n2 = mp.mpf(split * split)

    def fright(x):
        return i0(n2 / x) * mp.exp(-n2 / x) * mp.sqrt(n2 / x)

    dright = [chebycoeff(fright, lower, upper, k, NR) for k in range(NR)]
    d = [float(di) for di in dright[::-1]]
    xright = xval[xval >= split]
    gt = sps.i0(xright)
    # undo the scaling: i0(x) = fright(n2/x) * exp(x) / sqrt(x)
    est = chebyval(d, float(n2) / xright, 0, split) * np.exp(xright) / np.sqrt(xright)
    err = np.abs(gt - est)
    print("Mean err (right part):", np.mean(err))
def i0_taylor_est(N=16):
    """Estimate i0 with its degree-``N`` Taylor polynomial around 0.

    Prints every non-zero Taylor coefficient as a fraction (denominator
    limited to ``8**i`` for the coefficient of ``x**i``), then the maximum
    and mean absolute error against ``scipy.special.i0`` on the samples of
    ``xval`` that lie in [0, split].

    Args:
        N: degree of the Taylor polynomial.
    """
    # convert once to floats so the numpy arithmetic below stays native
    # instead of falling back to object arrays of mpmath numbers
    coeffs = [float(di) for di in mp.taylor(i0, 0, N)]
    for i, coeff in enumerate(coeffs):
        df = Fraction(coeff).limit_denominator(8**i)
        if df != 0:
            print(f"Coeff at x^{i}: {df}")

    # the reported errors are for (0, split), so only evaluate there
    xs = xval[xval <= split]
    gt = sps.i0(xs)
    est = sum(coeff * xs**i for i, coeff in enumerate(coeffs))
    err = np.abs(gt - est)
    print(f"Max error in (0, {split}): {np.max(err)}")
    print(f"Mean error in (0, {split}): {np.mean(err)}")
def i0_plots():
    """Plot i0 and the two rescaled functions used by the Chebyshev fits.

    Three stacked panels are shown:
      1. ``i0(x)`` on a logarithmic y axis,
      2. ``i0(x) * exp(-x)``,
      3. ``i0(y) * exp(-y) * sqrt(y)`` against ``y = split**2 / x``,
         with the x axis limited to (0, 30).
    """
    vi0 = sps.i0(xval)
    # i0e(x) is exp(-|x|) * i0(x), computed without overflowing for large x
    vi0e = sps.i0e(xval)
    # skip x == 0: the inverse mapping is undefined there
    xpos = xval[xval > 0]
    xinv = split * split / xpos
    vi0esq = sps.i0e(xinv) * np.sqrt(xinv)

    plt.subplot(3, 1, 1)
    plt.semilogy(xval, vi0)
    plt.grid(True)

    plt.subplot(3, 1, 2)
    plt.plot(xval, vi0e)
    plt.grid(True)

    plt.subplot(3, 1, 3)
    plt.plot(xinv, vi0esq)
    plt.grid(True)
    plt.xlim(0, 30)

    plt.show()
# Script entry: run the two-part Chebyshev estimation and print its mean errors.
i0_cheby_est()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment