Last active
November 12, 2015 08:57
-
-
Save fdeheeger/642ba27c666a497d039d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
################################################################################
# Interpolation class
# TODO: use a nicer n-dim method (like multilinear interpolation)
import numpy as np
from scipy.interpolate import RectBivariateSpline, UnivariateSpline

from dolointerpolation.multilinear_cython import multilinear_interpolation
class UnivariateSpline(UnivariateSpline):
    '''Extended scipy ``UnivariateSpline`` whose evaluation accepts inputs
    of any shape and returns an output reshaped to match them.

    The class intentionally reuses the name of its scipy parent, so inside
    this module ``UnivariateSpline`` refers to this subclass once the class
    statement has run.
    '''

    #@profile
    def __call__(self, *x, **kwargs):
        '''Evaluate the spline at the points given in ``x``.

        Parameters
        ----------
        *x : array_like
            Evaluation points. All positional arguments are packed together
            with ``np.array(x)``, so several arguments must share the same
            shape (otherwise numpy cannot stack them).
        **kwargs
            Forwarded unchanged to scipy's ``UnivariateSpline.__call__``,
            e.g. ``nu`` (derivative order) or ``ext`` (override of the
            extrapolation mode chosen at construction).

        Returns
        -------
        ndarray
            Spline values with shape ``np.array(x).shape``. A single argument
            of shape ``S`` therefore gives an output of shape ``(1,) + S``
            (a scalar gives shape ``(1,)``).
            NOTE(review): the leading length-1 axis comes from the ``*x``
            packing and is kept for backward compatibility; confirm whether
            callers rely on it.
        '''
        points = np.array(x)
        # scipy evaluates on a flat 1-D array; evaluate on the flattened
        # points, then restore the packed input shape.
        values = super(UnivariateSpline, self).__call__(points.ravel(), **kwargs)
        return values.reshape(points.shape)
#-----
#-----
def interp_on_state(self, A):
    '''Return an interpolating function of the array `A`, assuming that `A`
    is expressed on the state grid `self.state_grid`.

    Parameters
    ----------
    A : array_like
        Values on the state grid. Its shape must equal
        ``self._state_grid_shape``, i.e. ``(len(g) for g in self.state_grid)``.
        Any array-like (e.g. nested lists) is accepted; it is converted with
        ``np.asanyarray`` (ndarrays and their subclasses pass through as-is).

    Returns
    -------
    callable
        For a 1-D state: a ``UnivariateSpline`` through the points
        ``(self.state_grid[0], A)``. For 2 to 5 dimensions: an
        ``MlinInterpolator`` built on ``self.state_grid`` with ``A`` set as
        its values.

    Raises
    ------
    ValueError
        If the shape of `A` differs from ``self._state_grid_shape``.
    NotImplementedError
        If the state has more than 5 dimensions.
    '''
    A = np.asanyarray(A)

    # Check the dimension of A:
    expect_shape = self._state_grid_shape
    if A.shape != expect_shape:
        raise ValueError('array `A` should be of shape {:s}, not {:s}'.format(
                         str(expect_shape), str(A.shape)))

    ndim = len(expect_shape)
    if ndim == 1:
        # s=0 makes the spline pass through every data point; scipy's
        # default (s=None) fits a smoothing spline, which does not
        # interpolate. ext=3 returns the boundary value outside the grid
        # instead of extrapolating.
        return UnivariateSpline(self.state_grid[0], A, s=0, ext=3)
    if ndim <= 5:
        A_interp = MlinInterpolator(*self.state_grid)
        A_interp.set_values(A)
        return A_interp
    raise NotImplementedError('interpolation for state dimension >5'
                              ' is not implemented.')
# end interp_on_state()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
As for the UnivariateSpline interpolator: it is a very useful addition, since it removes the need to compile the multilinear_cython module, at least for the 1D case. However, it would be better to have a way to choose which interpolator to use. What do you think?