Last active
July 9, 2024 08:51
-
-
Save chausies/c453d561310317e7eda598e229aea537 to your computer and use it in GitHub Desktop.
Simple Hermite Cubic Spline Interpolation and Integration implemented in PyTorch (with autograd support and fast runtime)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch as T | |
def h_poly_helper(tt):
    """Combine four monomial terms into the four cubic Hermite basis values.

    Args:
        tt: sequence of four terms, ordered by increasing power. The last
            entry must be a tensor, because its dtype is used for the
            coefficient matrix.

    Returns:
        A list of four values; entry ``i`` is the weighted sum of ``tt`` using
        row ``i`` of the Hermite coefficient matrix.
    """
    coeffs = T.tensor(
        [
            [1, 0, -3, 2],
            [0, 1, -2, 1],
            [0, 0, 3, -2],
            [0, 0, -1, 1],
        ],
        dtype=tt[-1].dtype,
    )
    basis = []
    for row in range(4):
        # Start from the integer 0 and add left to right, as sum() does.
        total = 0
        for col in range(4):
            total = total + coeffs[row, col] * tt[col]
        basis.append(total)
    return basis
def h_poly(t):
    """Evaluate the cubic Hermite basis at ``t``.

    Builds the powers ``1, t, t**2, t**3`` by repeated multiplication and
    passes them to ``h_poly_helper``, which returns the four basis values.
    """
    powers = [1]
    for _ in range(3):
        powers.append(powers[-1] * t)
    return h_poly_helper(powers)
def H_poly(t):
    """Evaluate the antiderivatives of the cubic Hermite basis at ``t``.

    Builds the terms ``t, t**2/2, t**3/3, t**4/4`` (the antiderivatives of
    ``1, t, t**2, t**3`` that vanish at ``t = 0``) and passes them to
    ``h_poly_helper``.
    """
    terms = [t]
    for k in (1, 2, 3):
        # From t**k / k to t**(k+1) / (k+1): multiply by t * k / (k + 1).
        terms.append(terms[-1] * t * k / (k + 1))
    return h_poly_helper(terms)
def interp_func(x, y):
    """Return a function that evaluates the cubic Hermite spline through (x, y).

    The tangent at each sample is estimated from the data: the secant slope of
    the adjacent segment at the two end points, and the average of the two
    neighbouring secant slopes at interior points.

    Args:
        x: 1-D tensor of sample locations. Assumed to be sorted in increasing
            order, as required by the ``searchsorted`` lookup (not checked).
        y: 1-D tensor of sample values, same length as ``x``.

    Returns:
        A function ``f(xs)`` that returns the interpolated values at the query
        tensor ``xs``. With a single sample the result is the constant
        ``y[0]``. Queries outside ``[x[0], x[-1]]`` are extrapolated with the
        cubic of the nearest end segment.
    """
    if len(y) > 1:
        secant = (y[1:] - y[:-1]) / (x[1:] - x[:-1])
        m = T.cat([secant[[0]], (secant[1:] + secant[:-1]) / 2, secant[[-1]]])

    def f(xs):
        if len(y) == 1:  # in the case of 1 point, treat as constant function
            return y[0] + T.zeros_like(xs)
        # Index of the segment [x[I], x[I+1]] used for each query. Clamp the
        # result so queries beyond x[-1] reuse the last segment instead of
        # indexing past the end of x.
        I = T.searchsorted(x[1:], xs).clamp(max=len(x) - 2)
        dx = x[I + 1] - x[I]
        hh = h_poly((xs - x[I]) / dx)
        return (
            hh[0] * y[I]
            + hh[1] * m[I] * dx
            + hh[2] * y[I + 1]
            + hh[3] * m[I + 1] * dx
        )

    return f
def interp(x, y, xs):
    """Build the function returned by ``interp_func(x, y)`` and evaluate it at ``xs``."""
    spline = interp_func(x, y)
    return spline(xs)
def integ_func(x, y):
    """Return a function giving the integral of the cubic Hermite spline.

    The spline is the one through ``(x, y)`` with tangents estimated from the
    data: the adjacent secant slope at the two end points, and the average of
    the neighbouring secant slopes at interior points. The integral is
    measured from ``x[0]``.

    Args:
        x: 1-D tensor of sample locations. Assumed to be sorted in increasing
            order, as required by the ``searchsorted`` lookup (not checked).
        y: 1-D tensor of sample values, same length as ``x``.

    Returns:
        A function ``f(xs)`` that returns the integral from ``x[0]`` to each
        entry of the query tensor ``xs``. With a single sample the integrand
        is the constant ``y[0]``. Queries outside ``[x[0], x[-1]]`` integrate
        the cubic of the nearest end segment.
    """
    if len(y) > 1:
        widths = x[1:] - x[:-1]
        secant = (y[1:] - y[:-1]) / widths
        m = T.cat([secant[[0]], (secant[1:] + secant[:-1]) / 2, secant[[-1]]])
        # Integral of the cubic over each whole segment, then a running sum,
        # so Y[k] is the integral from x[0] to x[k].
        Y = T.zeros_like(y)
        Y[1:] = widths * (
            (y[:-1] + y[1:]) / 2 + (m[:-1] - m[1:]) * widths / 12
        )
        Y = Y.cumsum(0)

    def f(xs):
        if len(y) == 1:
            return y[0] * (xs - x[0])
        # Index of the segment [x[I], x[I+1]] used for each query. Clamp the
        # result so queries beyond x[-1] reuse the last segment instead of
        # indexing past the end of x.
        I = T.searchsorted(x[1:], xs).clamp(max=len(x) - 2)
        dx = x[I + 1] - x[I]
        hh = H_poly((xs - x[I]) / dx)
        # Integral up to the segment start, plus the partial integral inside
        # the segment. The extra dx rescales from the unit parameter to x.
        return Y[I] + dx * (
            hh[0] * y[I]
            + hh[1] * m[I] * dx
            + hh[2] * y[I + 1]
            + hh[3] * m[I + 1] * dx
        )

    return f
def integ(x, y, xs):
    """Build the function returned by ``integ_func(x, y)`` and evaluate it at ``xs``."""
    antiderivative = integ_func(x, y)
    return antiderivative(xs)
# Example
# See https://i.stack.imgur.com/zgA0s.png for resulting image
if __name__ == "__main__":
    import matplotlib.pylab as P  # for plotting

    # Seven samples of sin(x) on [0, 6], queried on a finer 101-point grid.
    x = T.linspace(0, 6, 7)
    y = x.sin()
    xs = T.linspace(0, 6, 101)
    ys = interp(x, y, xs)
    Ys = integ(x, y, xs)

    P.scatter(x, y, label='Samples', color='purple')
    # (values, extra positional plot args, legend label), in draw order.
    curves = [
        (ys, (), 'Interpolated curve'),
        (xs.sin(), ('--',), 'True Curve'),
        (Ys, (), 'Spline Integral'),
        (1 - xs.cos(), ('--',), 'True Integral'),
    ]
    for values, fmt, label in curves:
        P.plot(xs, values, *fmt, label=label)
    P.legend()
    P.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment