In pyhf we've noticed that one of our unit tests, which was passing with jax v0.2.7 and jaxlib v0.1.57, has started failing with the release of jaxlib v0.1.58.
We've narrowed it down to jaxlib v0.1.58 with jax_enable_x64=True in CPU mode (i.e., where our unit tests run).
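The core of what changed can be seen with gammaln alone; here is a condensed sketch of what the full reproducer below checks (batched evaluation vs. the same inputs evaluated one element at a time):

import jax.numpy as jnp
from jax.config import config
from jax.scipy.special import gammaln

config.update("jax_enable_x64", True)

# gammaln over a batched array vs. the same inputs evaluated individually
joint = gammaln(jnp.asarray([2.0, 3.0], dtype="float64")).tolist()
individual = [
    *gammaln(jnp.asarray([2.0], dtype="float64")).tolist(),
    *gammaln(jnp.asarray([3.0], dtype="float64")).tolist(),
]
assert joint == individual  # passes with jaxlib 0.1.57, fails with 0.1.58 on CPU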
In a fresh Python 3.8 virtual environment
$ python --version --version
Python 3.8.6 (default, Jan 5 2021, 00:14:15)
[GCC 9.3.0]
$ python -m pip install --quiet --upgrade pip setuptools wheel
$ python -m pip install jax jaxlib
$ python -m pip list
Package Version
----------- -------
absl-py 0.11.0
flatbuffers 1.12
jax 0.2.8
jaxlib 0.1.58
numpy 1.19.5
opt-einsum 3.3.0
pip 20.3.3
scipy 1.6.0
setuptools 51.1.2
six 1.15.0
wheel 0.36.2
then for
# jaxlib_issue.py
import jax
import jaxlib
from jax.config import config
import jax.numpy as jnp
from jax.scipy.special import gammaln


class Poisson:
    """Minimal Poisson model whose log-probability uses gammaln."""

    def __init__(self, rate):
        self.rate = jnp.asarray(rate, dtype="float64")

    def log_prob(self, n):
        n = jnp.asarray(n, dtype="float64")
        return n * jnp.log(self.rate) - self.rate - gammaln(n + 1.0)


def main():
    config.update("jax_enable_x64", True)

    print(f"jax version: {jax.__version__}")
    print(f"jaxlib version: {jaxlib.__version__}")

    # gammaln over a batched array vs. the same values one element at a time
    joint = gammaln(jnp.asarray([2.0, 3.0], dtype="float64")).tolist()
    individual = [
        *gammaln(jnp.asarray([2.0], dtype="float64")).tolist(),
        *gammaln(jnp.asarray([3.0], dtype="float64")).tolist(),
    ]
    print(f"joint: {joint}")
    print(f"individual: {individual}")
    assert joint == individual

    # This is more akin to what we're seeing
    joint = Poisson([10.0, 10.0]).log_prob([2.0, 3.0])
    poisson_1 = Poisson([10.0]).log_prob(2.0)
    poisson_2 = Poisson([10.0]).log_prob(3.0)
    print(f"\njoint: {joint.tolist()}")
    print(f"individual: {[*poisson_1.tolist(), *poisson_2.tolist()]}")
    assert joint.tolist() == [*poisson_1.tolist(), *poisson_2.tolist()]


if __name__ == "__main__":
    main()
$ python jaxlib_issue.py
jax version: 0.2.8
jaxlib version: 0.1.58
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
joint: [8.881784197001252e-16, 0.693147180559945]
individual: [8.881784197001252e-16, 0.6931471805599432]
Traceback (most recent call last):
File "jaxlib_issue.py", line 43, in <module>
main()
File "jaxlib_issue.py", line 29, in main
assert joint == individual
AssertionError
however, for
$ python -m pip install --quiet --upgrade "jaxlib<0.1.58"
$ python -m pip list | grep jax
jax 0.2.8
jaxlib 0.1.57
things are passing as before
$ python jaxlib_issue.py
jax version: 0.2.8
jaxlib version: 0.1.57
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
joint: [8.881784197001252e-16, 0.693147180559945]
individual: [8.881784197001252e-16, 0.693147180559945]
joint: [-6.087976994571854, -4.884004190245918]
individual: [-6.087976994571854, -4.884004190245918]
If the CUDA-enabled version of jaxlib is installed, the issue is not seen
$ python -m pip install --quiet --upgrade jax jaxlib==0.1.57+cuda101 --find-links https://storage.googleapis.com/jax-releases/jax_releases.html
$ python -m pip list | grep jax
jax 0.2.8
jaxlib 0.1.57+cuda101
$ python jaxlib_issue.py
jax version: 0.2.8
jaxlib version: 0.1.57
joint: [8.881784197001252e-16, 0.6931471805599441]
individual: [8.881784197001252e-16, 0.6931471805599441]
joint: [-6.087976994571852, -4.884004190245918]
individual: [-6.087976994571852, -4.884004190245918]
$ python -m pip install --quiet --upgrade jax jaxlib==0.1.58+cuda101 --find-links https://storage.googleapis.com/jax-releases/jax_releases.html
$ python -m pip list | grep jax
jax 0.2.8
jaxlib 0.1.58+cuda101
$ python jaxlib_issue.py
jax version: 0.2.8
jaxlib version: 0.1.58
joint: [8.881784197001252e-16, 0.6931471805599441]
individual: [8.881784197001252e-16, 0.6931471805599441]
joint: [-6.087976994571852, -4.884004190245918]
individual: [-6.087976994571852, -4.884004190245918]
so this seems to be a CPU-only issue.
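For reference, a quick way to confirm which backend a given run actually used is to check the platform of the default device from within the script (a small check; jax.devices() should report the default backend):

import jax

# Platform of the default device, e.g. "cpu" or "gpu", to confirm
# which backend the reproducer actually ran on
print(jax.devices()[0].platform)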
We do realize that this difference is incredibly small, but as this is a change in behavior that we didn't expect, we thought we'd still report it even if this gets a "won't fix" label.
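For completeness, loosening the comparison to a floating-point tolerance on our side would absorb a difference at this level; a sketch of that kind of workaround (not what the pyhf test currently does):

import numpy as np
import jax.numpy as jnp
from jax.config import config
from jax.scipy.special import gammaln

config.update("jax_enable_x64", True)

joint = gammaln(jnp.asarray([2.0, 3.0], dtype="float64"))
individual = jnp.concatenate(
    [
        gammaln(jnp.asarray([2.0], dtype="float64")),
        gammaln(jnp.asarray([3.0], dtype="float64")),
    ]
)
# Tolerance-based comparison instead of exact equality; this accepts the
# small difference observed with jaxlib 0.1.58 on CPU
np.testing.assert_allclose(np.asarray(joint), np.asarray(individual), rtol=1e-12)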
cc @lukasheinrich @kratsg