Skip to content

Instantly share code, notes, and snippets.

@chausies
Last active July 9, 2024 08:51
Show Gist options
  • Save chausies/c453d561310317e7eda598e229aea537 to your computer and use it in GitHub Desktop.
Save chausies/c453d561310317e7eda598e229aea537 to your computer and use it in GitHub Desktop.
Simple Hermite Cubic Spline Interpolation and Integration implemented in PyTorch (with autograd support and fast runtime)
import torch as T
def h_poly_helper(tt):
    """Combine four monomial-like terms into the four Hermite basis values.

    ``tt`` is indexed at positions 0..3; entry ``j`` is weighted by column
    ``j`` of a fixed 4x4 coefficient matrix. The matrix is created with the
    dtype of ``tt[-1]``, so the last entry must expose ``.dtype``.

    Returns a list of four values, one per row of the coefficient matrix.
    """
    coeffs = T.tensor(
        [
            [1, 0, -3, 2],
            [0, 1, -2, 1],
            [0, 0, 3, -2],
            [0, 0, -1, 1],
        ],
        dtype=tt[-1].dtype,
    )
    basis = []
    for row in coeffs:
        # Start from 0 and add left to right, matching a plain sum().
        total = 0
        for j in range(4):
            total = total + row[j] * tt[j]
        basis.append(total)
    return basis
def h_poly(t):
    """Evaluate the four Hermite basis terms at ``t``.

    Builds the list ``[1, t, t**2, t**3]`` by repeated multiplication and
    hands it to ``h_poly_helper``.
    """
    powers = [1]
    for _ in range(3):
        powers.append(powers[-1] * t)
    return h_poly_helper(powers)
def H_poly(t):
    """Evaluate the antiderivatives of the four Hermite basis terms at ``t``.

    Builds ``[t, t**2/2, t**3/3, t**4/4]`` through the recurrence
    ``term[k] = term[k-1] * t * k / (k + 1)`` and hands the list to
    ``h_poly_helper``.
    """
    terms = [t]
    for k in (1, 2, 3):
        terms.append(terms[-1] * t * k / (k + 1))
    return h_poly_helper(terms)
def interp_func(x, y):
    """Return a function that evaluates the cubic Hermite spline through (x, y).

    Knot tangents are the average of the two neighbouring secant slopes at
    interior knots and the single adjacent secant slope at the two end knots.

    Args:
        x: knot locations. They are sliced and passed to
            ``torch.searchsorted``, so they are assumed to be a 1-D tensor
            sorted in ascending order (TODO confirm against callers).
        y: sample values at the knots, indexed in parallel with ``x``.

    Returns:
        A callable ``f(xs)`` returning the interpolated values at ``xs``.
        Queries outside ``[x[0], x[-1]]`` are extrapolated with the cubic of
        the nearest end segment. With a single sample, ``f`` is the constant
        ``y[0]``.
    """
    if len(y) > 1:
        secant = (y[1:] - y[:-1]) / (x[1:] - x[:-1])
        m = T.cat([secant[[0]], (secant[1:] + secant[:-1]) / 2, secant[[-1]]])

    def f(xs):
        if len(y) == 1:  # in the case of 1 point, treat as constant function
            return y[0] + T.zeros_like(xs)
        # Segment index for every query point. For xs > x[-1] searchsorted
        # returns len(x) - 1, which would make x[I + 1] index out of bounds,
        # so clamp to the last segment and extrapolate from it.
        I = T.searchsorted(x[1:], xs).clamp(max=len(x) - 2)
        dx = x[I + 1] - x[I]
        hh = h_poly((xs - x[I]) / dx)
        return (
            hh[0] * y[I]
            + hh[1] * m[I] * dx
            + hh[2] * y[I + 1]
            + hh[3] * m[I + 1] * dx
        )

    return f
def interp(x, y, xs):
    """Build the interpolant for ``(x, y)`` with ``interp_func`` and evaluate it at ``xs``."""
    spline = interp_func(x, y)
    return spline(xs)
def integ_func(x, y):
    """Return a function that evaluates the integral of the cubic Hermite spline.

    The spline uses the same knot tangents as the interpolant: the average of
    the neighbouring secant slopes at interior knots and the single adjacent
    secant slope at the end knots. The integral is measured from ``x[0]``.

    Args:
        x: knot locations. They are sliced and passed to
            ``torch.searchsorted``, so they are assumed to be a 1-D tensor
            sorted in ascending order (TODO confirm against callers).
        y: sample values at the knots, indexed in parallel with ``x``.

    Returns:
        A callable ``F(xs)`` returning the integral of the spline from
        ``x[0]`` to each ``xs``. Queries outside ``[x[0], x[-1]]`` integrate
        the extrapolated cubic of the nearest end segment. With a single
        sample the spline is the constant ``y[0]``, so
        ``F(xs) = y[0] * (xs - x[0])``.
    """
    if len(y) > 1:
        secant = (y[1:] - y[:-1]) / (x[1:] - x[:-1])
        m = T.cat([secant[[0]], (secant[1:] + secant[:-1]) / 2, secant[[-1]]])
        widths = x[1:] - x[:-1]
        # Y[k] is the integral of the spline from x[0] up to knot k: the
        # per-segment integrals are written into Y[1:] and accumulated.
        Y = T.zeros_like(y)
        Y[1:] = widths * (
            (y[:-1] + y[1:]) / 2 + (m[:-1] - m[1:]) * widths / 12
        )
        Y = Y.cumsum(0)

    def f(xs):
        if len(y) == 1:
            return y[0] * (xs - x[0])
        # Segment index for every query point. For xs > x[-1] searchsorted
        # returns len(x) - 1, which would make x[I + 1] index out of bounds,
        # so clamp to the last segment and extrapolate from it.
        I = T.searchsorted(x[1:], xs).clamp(max=len(x) - 2)
        dx = x[I + 1] - x[I]
        hh = H_poly((xs - x[I]) / dx)
        return Y[I] + dx * (
            hh[0] * y[I]
            + hh[1] * m[I] * dx
            + hh[2] * y[I + 1]
            + hh[3] * m[I + 1] * dx
        )

    return f
def integ(x, y, xs):
    """Build the spline integral for ``(x, y)`` with ``integ_func`` and evaluate it at ``xs``."""
    antiderivative = integ_func(x, y)
    return antiderivative(xs)
# Example
# See https://i.stack.imgur.com/zgA0s.png for resulting image
if __name__ == "__main__":
    import matplotlib.pylab as P  # for plotting

    # Seven samples of sin on [0, 6], evaluated on a 101-point grid.
    knots = T.linspace(0, 6, 7)
    samples = knots.sin()
    grid = T.linspace(0, 6, 101)

    spline_vals = interp(knots, samples, grid)
    spline_area = integ(knots, samples, grid)

    # Compare the spline and its integral with sin(x) and 1 - cos(x).
    P.scatter(knots, samples, label='Samples', color='purple')
    P.plot(grid, spline_vals, label='Interpolated curve')
    P.plot(grid, grid.sin(), '--', label='True Curve')
    P.plot(grid, spline_area, label='Spline Integral')
    P.plot(grid, 1 - grid.cos(), '--', label='True Integral')
    P.legend()
    P.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment