Created
March 10, 2019 07:10
-
-
Save rjenc29/3d3e70693f2fe68be1cf315d3d3a5092 to your computer and use it in GitHub Desktop.
Small change to np.interp so that its behaviour matches the installed NumPy version
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@register_jitable
def np_interp_impl_inner(x, xp, fp, dtype, non_finite_per_116):
    """One-dimensional piecewise-linear interpolation mirroring ``np.interp``.

    Parameters
    ----------
    x : array_like or scalar
        The x-coordinates at which to evaluate the interpolant.
    xp : array_like
        The x-coordinates of the sample points. When it has more than one
        element it must be strictly increasing (checked below).
    fp : array_like
        The y-coordinates of the sample points; same length as ``xp``.
    dtype : dtype
        dtype of the returned array.
    non_finite_per_116 : bool
        If False, replicate the bug present in NumPy versions before 1.16:
        an ``x`` lying exactly on an interior sample point whose following
        interval has a non-finite slope evaluates to NaN. If True, replicate
        NumPy 1.16+ behaviour, which returns the sample value ``fp`` there.

    Returns
    -------
    ndarray
        Array with the shape of ``np.asarray(x)``. With a single sample
        point every entry is ``fp[0]``. Otherwise, values of ``x`` at or
        beyond either end of ``xp`` take the corresponding end value of
        ``fp``, and NaN values of ``x`` yield NaN.

    Raises
    ------
    ValueError
        If ``xp`` is empty, if ``xp`` and ``fp`` differ in length, or if
        ``xp`` (with two or more points) is not strictly increasing.
    """
    x_arr = np.asarray(x)
    xp_arr = np.asarray(xp)
    fp_arr = np.asarray(fp)

    if len(xp_arr) == 0:
        raise ValueError('array of sample points is empty')
    if len(xp_arr) != len(fp_arr):
        raise ValueError('fp and xp are not of the same size.')
    if xp_arr.size == 1:
        # A single sample point: the interpolant is that constant everywhere.
        return np.full(x_arr.shape, fill_value=fp_arr[0], dtype=dtype)

    # note: NumPy docs suggest this is required but it is not
    # checked for or enforced; see:
    # https://github.com/numpy/numpy/issues/10448
    # This check is quite expensive.
    if not np.all(xp_arr[1:] > xp_arr[:-1]):
        msg = 'xp must be monotonically increasing'
        raise ValueError(msg)

    n_points = len(xp_arr)
    out = np.empty(x_arr.shape, dtype=dtype)

    # Pre-compute the slope of every interval [xp[k], xp[k + 1]].
    slopes = (fp_arr[1:] - fp_arr[:-1]) / (xp_arr[1:] - xp_arr[:-1])

    # ``idx`` is a search hint: the right-hand index of the interval that
    # held the previous in-range x value, so that runs of sorted or nearby
    # x values avoid a full binary search. It is kept within
    # [1, n_points - 1]. It must NOT start at 0: ``xp_arr[idx - 2]`` would
    # wrap round to ``xp_arr[-2]`` and could set idx to -1, after which an
    # x in the last interval was interpolated with ``slopes[-2]`` (the
    # slope of the wrong interval).
    idx = 1
    for i in range(x_arr.size):
        x_val = x_arr.flat[i]

        # Shortcuts: NaN input, and values outside the open range
        # (xp[0], xp[-1]).
        if np.isnan(x_val):
            out.flat[i] = np.nan
            continue
        if x_val >= xp_arr[-1]:
            out.flat[i] = fp_arr[-1]
            continue
        if x_val <= xp_arr[0]:
            out.flat[i] = fp_arr[0]
            continue

        # Find idx with xp[idx - 1] < x_val <= xp[idx]: try the cached
        # interval and its two neighbours before falling back to a binary
        # search. The bounds guards keep every index non-negative and
        # in range.
        if xp_arr[idx - 1] < x_val <= xp_arr[idx]:
            pass
        elif idx + 1 < n_points and xp_arr[idx] < x_val <= xp_arr[idx + 1]:
            idx += 1
        elif idx > 1 and xp_arr[idx - 2] < x_val <= xp_arr[idx - 1]:
            idx -= 1
        else:
            # x_val is strictly inside (xp[0], xp[-1]), so this lands in
            # [1, n_points - 1].
            idx = np.searchsorted(xp_arr, x_val)

        if x_val == xp_arr[idx]:
            # Exactly on an interior sample point.
            if non_finite_per_116 or np.isfinite(slopes[idx]):
                out.flat[i] = fp_arr[idx]
            else:
                # Replicate the pre-1.16 NumPy result for a non-finite
                # slope on the following interval.
                out.flat[i] = np.nan
        else:
            delta_x = x_val - xp_arr[idx - 1]
            out.flat[i] = fp_arr[idx - 1] + slopes[idx - 1] * delta_x
    return out
if numpy_version >= (1, 10):
    # replicate behaviour change of 1.10+
    @overload(np.interp)
    def np_interp(x, xp, fp):
        """Typing-time overload of ``np.interp(x, xp, fp)``.

        Validates the argument types and returns an implementation:
        ``np_interp_scalar_impl`` (returns a single value) when ``x`` is a
        scalar number, otherwise ``np_interp_impl`` (returns an array shaped
        like ``x``). Both delegate to ``np_interp_impl_inner``.

        Raises ``TypingError`` if ``xp`` or ``fp`` has more than one
        dimension, or if ``xp`` or ``x`` is complex-valued.
        """
        if hasattr(xp, 'ndim') and xp.ndim > 1:
            raise TypingError('xp must be 1D')
        if hasattr(fp, 'ndim') and fp.ndim > 1:
            raise TypingError('fp must be 1D')

        complex_dtype_msg = (
            "Cannot cast array data from complex dtype to float64 dtype"
        )

        # The sample x-coordinates and the query points must be real; only
        # fp may be complex.
        xp_dt = determine_dtype(xp)
        if np.issubdtype(xp_dt, np.complexfloating):
            raise TypingError(complex_dtype_msg)
        # Check the element type so that complex arrays are rejected just
        # like complex scalars (previously only scalars were checked).
        x_elem = x.dtype if isinstance(x, types.Array) else x
        if isinstance(x_elem, types.Complex):
            raise TypingError(complex_dtype_msg)

        # Output dtype: fp's dtype promoted with float64.
        fp_dt = determine_dtype(fp)
        dtype = np.result_type(fp_dt, np.float64)
        NON_FINITE_PER_116 = numpy_version >= (1, 16)

        def np_interp_impl(x, xp, fp):
            return np_interp_impl_inner(x, xp, fp, dtype, NON_FINITE_PER_116)

        def np_interp_scalar_impl(x, xp, fp):
            return np_interp_impl_inner(
                x, xp, fp, dtype, NON_FINITE_PER_116
            ).flat[0]

        if isinstance(x, types.Number):
            return np_interp_scalar_impl
        return np_interp_impl
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment