Floating point deviation in jax.numpy.percentile
with linear interpolation between v0.2.20
and v0.2.21
Hi. There are some (very minor) deviations in the output of jax.numpy.percentile
between jax
v0.2.20
and v0.2.21
when linear interpolation is used (the default). Interestingly, the deviation appears only in jax.numpy.percentile
and not in jax.numpy.quantile
as shown in the included example (for convenience, this Issue is also available as a GitHub Gist).
# example.py
# Reproduction script: checks np/jnp quantile and percentile on a small array
# using exact floating-point equality, so any drift makes an assert fail.
import jax
import jax.numpy as jnp
import numpy as np
from jax.config import config

# Run JAX in 64-bit mode so the float array below is created as float64.
config.update("jax_enable_x64", True)

if __name__ == "__main__":
    # percentile interpolation options:
    # This optional parameter specifies the interpolation method to use when
    # the desired percentile lies between two data points i < j:
    # * 'linear': i + (j - i) * fraction, where fraction is the fractional
    #   part of the index surrounded by i and j.
    # * 'lower': i.
    # * 'higher': j.
    # * 'nearest': i or j, whichever is nearest.
    # * 'midpoint': (i + j) / 2.

    # Named `data` rather than `input` so the builtin is not shadowed; the
    # printed labels are unchanged.
    data = [[10, 7, 4], [3, 2, 1]]
    print(f"input list: {data}")
    print(f"input list ravel: {np.asarray(data).ravel()}")
    # [10 7 4 3 2 1]

    print(f"\nNumPy v{np.__version__}")
    print(f"JAX v{jax.__version__}\n")

    numpy_array = np.asarray(data)
    print(f"{numpy_array=}")
    jax_array = jnp.asarray(data, dtype="float")
    print(f"{jax_array=}")

    print("\n# Checking quantile\n")
    # NumPy reference values, q given as a fraction in [0, 1].
    assert np.quantile(numpy_array, 0) == 1.0
    assert np.quantile(numpy_array, 0.50) == 3.5
    assert np.quantile(numpy_array, 1) == 10
    assert np.quantile(numpy_array, 0.50, axis=1).tolist() == [7.0, 2.0]
    assert np.quantile(numpy_array, 0.50, interpolation="linear") == 3.5
    assert np.quantile(numpy_array, 0.50, interpolation="nearest") == 3.0
    assert np.quantile(numpy_array, 0.50, interpolation="lower") == 3.0
    assert np.quantile(numpy_array, 0.50, interpolation="midpoint") == 3.5
    assert np.quantile(numpy_array, 0.50, interpolation="higher") == 4.0

    # The same checks against jax.numpy.quantile.
    assert jnp.quantile(jax_array, 0) == 1.0
    assert jnp.quantile(jax_array, 0.50) == 3.5
    assert jnp.quantile(jax_array, 1) == 10
    assert jnp.quantile(jax_array, 0.50, axis=1).tolist() == [7.0, 2.0]
    assert jnp.quantile(jax_array, 0.50, interpolation="linear") == 3.5
    assert jnp.quantile(jax_array, 0.50, interpolation="nearest") == 3.0
    assert jnp.quantile(jax_array, 0.50, interpolation="lower") == 3.0
    assert jnp.quantile(jax_array, 0.50, interpolation="midpoint") == 3.5
    assert jnp.quantile(jax_array, 0.50, interpolation="higher") == 4.0

    print("# Checking percentile")
    # NumPy reference values, q given as a percentage in [0, 100].
    assert np.percentile(numpy_array, 0) == 1.0
    assert np.percentile(numpy_array, 50) == 3.5
    assert np.percentile(numpy_array, 100) == 10
    assert np.percentile(numpy_array, 50, axis=1).tolist() == [7.0, 2.0]
    assert np.percentile(numpy_array, 50, interpolation="linear") == 3.5
    assert np.percentile(numpy_array, 50, interpolation="nearest") == 3.0
    assert np.percentile(numpy_array, 50, interpolation="lower") == 3.0
    assert np.percentile(numpy_array, 50, interpolation="midpoint") == 3.5
    assert np.percentile(numpy_array, 50, interpolation="higher") == 4.0

    # default interpolation method is "linear"
    # Trailing values are the results observed when the comparison fails.
    assert jnp.percentile(jax_array, 0) == 1.0
    assert jnp.percentile(jax_array, 50) == 3.5  # 3.499999761581421
    assert jnp.percentile(jax_array, 100) == 10  # 9.999998092651367
    assert jnp.percentile(jax_array, 50, axis=1).tolist() == [7.0, 2.0]
    assert jnp.percentile(jax_array, 50, interpolation="linear") == 3.5  # 3.499999761581421
    assert jnp.percentile(jax_array, 50, interpolation="nearest") == 3.0
    assert jnp.percentile(jax_array, 50, interpolation="lower") == 3.0
    assert jnp.percentile(jax_array, 50, interpolation="midpoint") == 3.5
    assert jnp.percentile(jax_array, 50, interpolation="higher") == 4.0
$ python --version
Python 3.9.6
$ python -m venv /tmp/venv && . /tmp/venv/bin/activate
(venv) $ python -m pip install --upgrade pip setuptools wheel
(venv) $ cat requirements_passing.txt
jax==0.2.20
jaxlib==0.1.69
(venv) $ python -m pip install -r requirements_passing.txt
(venv) $ python example.py
input list: [[10, 7, 4], [3, 2, 1]]
input list ravel: [10 7 4 3 2 1]
NumPy v1.21.4
JAX v0.2.20
numpy_array=array([[10, 7, 4],
[ 3, 2, 1]])
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax_array=DeviceArray([[10., 7., 4.],
[ 3., 2., 1.]], dtype=float64)
# Checking quantile
# Checking percentile
(venv) $ cat requirements_failing.txt
jax==0.2.21
jaxlib==0.1.69
(venv) $ python -m pip install -r requirements_failing.txt
$ python example.py
input list: [[10, 7, 4], [3, 2, 1]]
input list ravel: [10 7 4 3 2 1]
NumPy v1.21.4
JAX v0.2.21
numpy_array=array([[10, 7, 4],
[ 3, 2, 1]])
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax_array=DeviceArray([[10., 7., 4.],
[ 3., 2., 1.]], dtype=float64)
# Checking quantile
# Checking percentile
Traceback (most recent call last):
File "/home/feickert/Code/debug/jax-percentile-drift/example.py", line 67, in <module>
assert jnp.percentile(jax_array, 50) == 3.5 # 3.499999761581421
AssertionError
Comparing the code for v0.2.20
https://github.com/google/jax/blob/jax-v0.2.20/jax/_src/numpy/lax_numpy.py#L5905-L5912
# percentile as quoted from the jax-v0.2.20 tag: q is rescaled by dividing by
# 100 and the result is handed to quantile() together with all other arguments.
@_wraps(np.percentile, skip_params=['out', 'overwrite_input'])
def percentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
               out=None, overwrite_input=False, interpolation="linear",
               keepdims=False):
  # Only `a` is passed to the array-like check in this version.
  _check_arraylike("percentile", a)
  # q goes through asarray() before the division by float32(100.0).
  q = true_divide(asarray(q), float32(100.0))
  return quantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
                  interpolation=interpolation, keepdims=keepdims)
and v0.2.21
https://github.com/google/jax/blob/jax-v0.2.21/jax/_src/numpy/lax_numpy.py#L6420-L6429
# percentile as quoted from the jax-v0.2.21 tag: q is rescaled by dividing by
# 100 and the result is handed to quantile() together with all other arguments.
@_wraps(np.percentile, skip_params=['out', 'overwrite_input'])
# In this version the function is wrapped in jit, with axis, overwrite_input,
# interpolation and keepdims declared as static argument names.
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                               'keepdims'))
def percentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
               out=None, overwrite_input=False, interpolation="linear",
               keepdims=False):
  # Both `a` and `q` are passed to the array-like check in this version.
  _check_arraylike("percentile", a, q)
  # q is divided by float32(100.0) directly; there is no asarray() call here.
  q = true_divide(q, float32(100.0))
  return quantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
                  interpolation=interpolation, keepdims=keepdims)
It seems (at first glance as I haven't dug into this yet) that the only relevant difference is the removal of asarray(q)
in the true_divide
call in jax-ml/jax#7747
-q = true_divide(asarray(q), float32(100.0))
+q = true_divide(q, float32(100.0))
This effect is quite minor and is probably insignificant in most cases, but it deviates from the behavior described in the docstring. Perhaps the most obvious example is at the extreme q = 100 (i.e. a quantile of 1), which should return the array's maximum (10 in the example) but instead returns a floating-point approximation (9.999998092651367).
Would it be possible to revert to the v0.2.20
behavior? This would be more consistent with both the docstring and NumPy.