Skip to content

Instantly share code, notes, and snippets.

@matthewfeickert
Last active November 11, 2021 07:00
Show Gist options
  • Save matthewfeickert/d8364201ddad6653315a62dc9b921318 to your computer and use it in GitHub Desktop.
Save matthewfeickert/d8364201ddad6653315a62dc9b921318 to your computer and use it in GitHub Desktop.
floating point deviation in jax.numpy.percentile between v0.2.20 and v0.2.21

Hi. There are some (very minor) deviations in the output of jax.numpy.percentile between jax v0.2.20 and v0.2.21 when linear interpolation (the default) is used. Interestingly, the deviation appears in jax.numpy.percentile but not in jax.numpy.quantile, as the included example shows (for convenience, this Issue also exists as a GitHub Gist).

Minimal failing example

# example.py
"""Minimal reproducer for a floating-point deviation in jax.numpy.percentile.

Every check below compares exact values against NumPy-consistent results.
With jax v0.2.20 all checks pass. With jax v0.2.21 the jax.numpy.percentile
checks using the default linear interpolation fail. The values observed on
v0.2.21 are recorded in trailing comments on those lines.
"""
import jax
import jax.numpy as jnp
import numpy as np
from jax.config import config

# Enable double precision so jnp arrays can be float64, matching NumPy.
config.update("jax_enable_x64", True)

if __name__ == "__main__":

    # percentile interpolation options:
    # This optional parameter specifies the interpolation method to use when
    # the desired percentile lies between two data points i < j:
    # * 'linear': i + (j - i) * fraction, where fraction is the fractional
    #   part of the index surrounded by i and j.
    # * 'lower': i.
    # * 'higher': j.
    # * 'nearest': i or j, whichever is nearest.
    # * 'midpoint': (i + j) / 2.

    # Named input_list (not `input`) to avoid shadowing the builtin.
    input_list = [[10, 7, 4], [3, 2, 1]]
    print(f"input list: {input_list}")
    print(f"input list ravel: {np.asarray(input_list).ravel()}")
    # [10  7  4  3  2  1]
    # Sorted: [1, 2, 3, 4, 7, 10] -> median = (3 + 4) / 2 = 3.5

    print(f"\nNumPy v{np.__version__}")
    print(f"JAX v{jax.__version__}\n")
    numpy_array = np.asarray(input_list)
    print(f"{numpy_array=}")
    # dtype="float" yields float64 here because jax_enable_x64 is set above.
    jax_array = jnp.asarray(input_list, dtype="float")
    print(f"{jax_array=}")

    print("\n# Checking quantile\n")
    assert np.quantile(numpy_array, 0) == 1.0
    assert np.quantile(numpy_array, 0.50) == 3.5
    assert np.quantile(numpy_array, 1) == 10
    assert np.quantile(numpy_array, 0.50, axis=1).tolist() == [7.0, 2.0]

    assert np.quantile(numpy_array, 0.50, interpolation="linear") == 3.5
    assert np.quantile(numpy_array, 0.50, interpolation="nearest") == 3.0
    assert np.quantile(numpy_array, 0.50, interpolation="lower") == 3.0
    assert np.quantile(numpy_array, 0.50, interpolation="midpoint") == 3.5
    assert np.quantile(numpy_array, 0.50, interpolation="higher") == 4.0

    assert jnp.quantile(jax_array, 0) == 1.0
    assert jnp.quantile(jax_array, 0.50) == 3.5
    assert jnp.quantile(jax_array, 1) == 10
    assert jnp.quantile(jax_array, 0.50, axis=1).tolist() == [7.0, 2.0]

    assert jnp.quantile(jax_array, 0.50, interpolation="linear") == 3.5
    assert jnp.quantile(jax_array, 0.50, interpolation="nearest") == 3.0
    assert jnp.quantile(jax_array, 0.50, interpolation="lower") == 3.0
    assert jnp.quantile(jax_array, 0.50, interpolation="midpoint") == 3.5
    assert jnp.quantile(jax_array, 0.50, interpolation="higher") == 4.0

    print("# Checking percentile")
    assert np.percentile(numpy_array, 0) == 1.0
    assert np.percentile(numpy_array, 50) == 3.5
    assert np.percentile(numpy_array, 100) == 10
    assert np.percentile(numpy_array, 50, axis=1).tolist() == [7.0, 2.0]

    assert np.percentile(numpy_array, 50, interpolation="linear") == 3.5
    assert np.percentile(numpy_array, 50, interpolation="nearest") == 3.0
    assert np.percentile(numpy_array, 50, interpolation="lower") == 3.0
    assert np.percentile(numpy_array, 50, interpolation="midpoint") == 3.5
    assert np.percentile(numpy_array, 50, interpolation="higher") == 4.0

    # default interpolation method is "linear"
    # Trailing comments give the values observed with jax v0.2.21.
    assert jnp.percentile(jax_array, 0) == 1.0
    assert jnp.percentile(jax_array, 50) == 3.5  # 3.499999761581421
    assert jnp.percentile(jax_array, 100) == 10  # 9.999998092651367
    assert jnp.percentile(jax_array, 50, axis=1).tolist() == [7.0, 2.0]

    assert jnp.percentile(jax_array, 50, interpolation="linear") == 3.5  # 3.499999761581421
    assert jnp.percentile(jax_array, 50, interpolation="nearest") == 3.0
    assert jnp.percentile(jax_array, 50, interpolation="lower") == 3.0
    assert jnp.percentile(jax_array, 50, interpolation="midpoint") == 3.5
    assert jnp.percentile(jax_array, 50, interpolation="higher") == 4.0
$ python --version
Python 3.9.6
$ python -m venv /tmp/venv && . /tmp/venv/bin/activate
(venv) $ python -m pip install --upgrade pip setuptools wheel
(venv) $ cat requirements_passing.txt
jax==0.2.20
jaxlib==0.1.69
(venv) $ python -m pip install -r requirements_passing.txt
(venv) $ python example.py
input list: [[10, 7, 4], [3, 2, 1]]
input list ravel: [10  7  4  3  2  1]

NumPy v1.21.4
JAX v0.2.20

numpy_array=array([[10,  7,  4],
       [ 3,  2,  1]])
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax_array=DeviceArray([[10.,  7.,  4.],
             [ 3.,  2.,  1.]], dtype=float64)

# Checking quantile

# Checking percentile
(venv) $ cat requirements_failing.txt
jax==0.2.21
jaxlib==0.1.69
(venv) $ python -m pip install -r requirements_failing.txt
$ python example.py
input list: [[10, 7, 4], [3, 2, 1]]
input list ravel: [10  7  4  3  2  1]

NumPy v1.21.4
JAX v0.2.21

numpy_array=array([[10,  7,  4],
       [ 3,  2,  1]])
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax_array=DeviceArray([[10.,  7.,  4.],
             [ 3.,  2.,  1.]], dtype=float64)

# Checking quantile

# Checking percentile
Traceback (most recent call last):
  File "/home/feickert/Code/debug/jax-percentile-drift/example.py", line 67, in <module>
    assert jnp.percentile(jax_array, 50) == 3.5  # 3.499999761581421
AssertionError

Notes

Comparing the code for v0.2.20

https://github.com/google/jax/blob/jax-v0.2.20/jax/_src/numpy/lax_numpy.py#L5905-L5912

# Excerpt of jax.numpy.percentile from jax v0.2.20 (lax_numpy.py).
# No jit decorator is applied here, and q is converted with asarray() before
# the division.
@_wraps(np.percentile, skip_params=['out', 'overwrite_input'])
def percentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
               out=None, overwrite_input=False, interpolation="linear",
               keepdims=False):
  # Only `a` goes through the array-like check in this version; `q` does not.
  _check_arraylike("percentile", a)
  # Rescale q from percent to a fraction, then delegate to quantile().
  # NOTE(review): the divisor is a float32 constant even when jax_enable_x64
  # is set; the resulting dtype of q follows JAX's type promotion rules —
  # confirm.
  q = true_divide(asarray(q), float32(100.0))
  return quantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
                  interpolation=interpolation, keepdims=keepdims)

and v0.2.21

https://github.com/google/jax/blob/jax-v0.2.21/jax/_src/numpy/lax_numpy.py#L6420-L6429

# Excerpt of jax.numpy.percentile from jax v0.2.21 (lax_numpy.py).
@_wraps(np.percentile, skip_params=['out', 'overwrite_input'])
# Compared with v0.2.20, the function is now wrapped in jit, with axis,
# overwrite_input, interpolation and keepdims marked static.
@partial(jit, static_argnames=('axis', 'overwrite_input', 'interpolation',
                               'keepdims'))
def percentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
               out=None, overwrite_input=False, interpolation="linear",
               keepdims=False):
  # Both `a` and `q` now go through the array-like check.
  _check_arraylike("percentile", a, q)
  # Unlike v0.2.20, q is no longer passed through asarray() first; since q is
  # not a static argument, it arrives here as a traced value under jit.
  # NOTE(review): the divisor is still a float32 constant even with
  # jax_enable_x64 set; together with jit this is the suspected source of
  # results such as 3.499999761581421 instead of 3.5 — confirm.
  q = true_divide(q, float32(100.0))
  return quantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input,
                  interpolation=interpolation, keepdims=keepdims)

It seems (at first glance, as I haven't dug into this yet) that the relevant differences are that percentile is now wrapped in @partial(jit, ...) and that asarray(q) was removed from the true_divide call in jax-ml/jax#7747:

-q = true_divide(asarray(q), float32(100.0))
+q = true_divide(q, float32(100.0))

This effect is quite minor and probably of no real significance in most cases, but it deviates from the behavior described in the docstring. Perhaps the most obvious example is the extreme case q = 100 (i.e. quantile 1), which should return the maximum of the array but instead returns a floating-point approximation of it (9.999998092651367 instead of 10).

Request

Would it be possible to revert to the v0.2.20 behavior? This would be more consistent with both the docstring and NumPy.

"""Minimal reproducer for a floating-point deviation in jax.numpy.percentile.

Every check below compares exact values against NumPy-consistent results.
With jax v0.2.20 all checks pass. With jax v0.2.21 the jax.numpy.percentile
checks using the default linear interpolation fail. The values observed on
v0.2.21 are recorded in trailing comments on those lines.
"""
import jax
import jax.numpy as jnp
import numpy as np
from jax.config import config

# Enable double precision so jnp arrays can be float64, matching NumPy.
config.update("jax_enable_x64", True)

if __name__ == "__main__":

    # percentile interpolation options:
    # This optional parameter specifies the interpolation method to use when
    # the desired percentile lies between two data points i < j:
    # * 'linear': i + (j - i) * fraction, where fraction is the fractional
    #   part of the index surrounded by i and j.
    # * 'lower': i.
    # * 'higher': j.
    # * 'nearest': i or j, whichever is nearest.
    # * 'midpoint': (i + j) / 2.

    # Named input_list (not `input`) to avoid shadowing the builtin.
    input_list = [[10, 7, 4], [3, 2, 1]]
    print(f"input list: {input_list}")
    print(f"input list ravel: {np.asarray(input_list).ravel()}")
    # [10  7  4  3  2  1]
    # Sorted: [1, 2, 3, 4, 7, 10] -> median = (3 + 4) / 2 = 3.5

    print(f"\nNumPy v{np.__version__}")
    print(f"JAX v{jax.__version__}\n")
    numpy_array = np.asarray(input_list)
    print(f"{numpy_array=}")
    # dtype="float" yields float64 here because jax_enable_x64 is set above.
    jax_array = jnp.asarray(input_list, dtype="float")
    print(f"{jax_array=}")

    print("\n# Checking quantile\n")
    assert np.quantile(numpy_array, 0) == 1.0
    assert np.quantile(numpy_array, 0.50) == 3.5
    assert np.quantile(numpy_array, 1) == 10
    assert np.quantile(numpy_array, 0.50, axis=1).tolist() == [7.0, 2.0]

    assert np.quantile(numpy_array, 0.50, interpolation="linear") == 3.5
    assert np.quantile(numpy_array, 0.50, interpolation="nearest") == 3.0
    assert np.quantile(numpy_array, 0.50, interpolation="lower") == 3.0
    assert np.quantile(numpy_array, 0.50, interpolation="midpoint") == 3.5
    assert np.quantile(numpy_array, 0.50, interpolation="higher") == 4.0

    assert jnp.quantile(jax_array, 0) == 1.0
    assert jnp.quantile(jax_array, 0.50) == 3.5
    assert jnp.quantile(jax_array, 1) == 10
    assert jnp.quantile(jax_array, 0.50, axis=1).tolist() == [7.0, 2.0]

    assert jnp.quantile(jax_array, 0.50, interpolation="linear") == 3.5
    assert jnp.quantile(jax_array, 0.50, interpolation="nearest") == 3.0
    assert jnp.quantile(jax_array, 0.50, interpolation="lower") == 3.0
    assert jnp.quantile(jax_array, 0.50, interpolation="midpoint") == 3.5
    assert jnp.quantile(jax_array, 0.50, interpolation="higher") == 4.0

    print("# Checking percentile")
    assert np.percentile(numpy_array, 0) == 1.0
    assert np.percentile(numpy_array, 50) == 3.5
    assert np.percentile(numpy_array, 100) == 10
    assert np.percentile(numpy_array, 50, axis=1).tolist() == [7.0, 2.0]

    assert np.percentile(numpy_array, 50, interpolation="linear") == 3.5
    assert np.percentile(numpy_array, 50, interpolation="nearest") == 3.0
    assert np.percentile(numpy_array, 50, interpolation="lower") == 3.0
    assert np.percentile(numpy_array, 50, interpolation="midpoint") == 3.5
    assert np.percentile(numpy_array, 50, interpolation="higher") == 4.0

    # default interpolation method is "linear"
    # Trailing comments give the values observed with jax v0.2.21.
    assert jnp.percentile(jax_array, 0) == 1.0
    assert jnp.percentile(jax_array, 50) == 3.5  # 3.499999761581421
    assert jnp.percentile(jax_array, 100) == 10  # 9.999998092651367
    assert jnp.percentile(jax_array, 50, axis=1).tolist() == [7.0, 2.0]

    assert jnp.percentile(jax_array, 50, interpolation="linear") == 3.5  # 3.499999761581421
    assert jnp.percentile(jax_array, 50, interpolation="nearest") == 3.0
    assert jnp.percentile(jax_array, 50, interpolation="lower") == 3.0
    assert jnp.percentile(jax_array, 50, interpolation="midpoint") == 3.5
    assert jnp.percentile(jax_array, 50, interpolation="higher") == 4.0
jax==0.2.21
jaxlib==0.1.69
jax==0.2.20
jaxlib==0.1.69
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment