Created
February 18, 2024 16:44
-
-
Save samuela/4f2647dcfe41466586203fba2f2f5d38 to your computer and use it in GitHub Desktop.
jax-test_cuda-passthru-failure
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
ubuntu@bitbop:~/nixpkgs$ NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.test_cuda_jaxlibBin | |
warning: Git tree '/home/ubuntu/nixpkgs' is dirty | |
Traceback (most recent call last): | |
File "/nix/store/210z4yz9xpmch1fj7pzn83wdsbylmmlj-test_cuda/bin/test_cuda", line 7, in <module> | |
rng = random.PRNGKey(0) | |
^^^^^^^^^^^^^^^^^ | |
File "/nix/store/w48pp30vl0jk2wvwdqgb9zg06my5m7sy-python3-3.11.7-env/lib/python3.11/site-packages/jax/_src/random.py", line 240, in PRNGKey | |
return _return_prng_keys(True, _key('PRNGKey', seed, impl)) | |
^^^^^^^^^^^^^^^^^^^^^^^^^^^ | |
File "/nix/store/w48pp30vl0jk2wvwdqgb9zg06my5m7sy-python3-3.11.7-env/lib/python3.11/site-packages/jax/_src/random.py", line 202, in _key | |
return prng.random_seed(seed, impl=impl) | |
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | |
File "/nix/store/w48pp30vl0jk2wvwdqgb9zg06my5m7sy-python3-3.11.7-env/lib/python3.11/site-packages/jax/_src/prng.py", line 595, in random_seed | |
seeds_arr = jnp.asarray(np.int64(seeds)) | |
^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | |
File "/nix/store/w48pp30vl0jk2wvwdqgb9zg06my5m7sy-python3-3.11.7-env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 2217, in asarray | |
return array(a, dtype=dtype, copy=bool(copy), order=order) # type: ignore | |
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | |
File "/nix/store/w48pp30vl0jk2wvwdqgb9zg06my5m7sy-python3-3.11.7-env/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 2172, in array | |
out_array: Array = lax_internal._convert_element_type( | |
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | |
File "/nix/store/w48pp30vl0jk2wvwdqgb9zg06my5m7sy-python3-3.11.7-env/lib/python3.11/site-packages/jax/_src/lax/lax.py", line 560, in _convert_element_type | |
return convert_element_type_p.bind(operand, new_dtype=new_dtype, | |
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | |
File "/nix/store/w48pp30vl0jk2wvwdqgb9zg06my5m7sy-python3-3.11.7-env/lib/python3.11/site-packages/jax/_src/core.py", line 444, in bind | |
return self.bind_with_trace(find_top_trace(args), args, params) | |
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | |
File "/nix/store/w48pp30vl0jk2wvwdqgb9zg06my5m7sy-python3-3.11.7-env/lib/python3.11/site-packages/jax/_src/core.py", line 447, in bind_with_trace | |
out = trace.process_primitive(self, map(trace.full_raise, args), params) | |
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | |
File "/nix/store/w48pp30vl0jk2wvwdqgb9zg06my5m7sy-python3-3.11.7-env/lib/python3.11/site-packages/jax/_src/core.py", line 935, in process_primitive | |
return primitive.impl(*tracers, **params) | |
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | |
File "/nix/store/w48pp30vl0jk2wvwdqgb9zg06my5m7sy-python3-3.11.7-env/lib/python3.11/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive | |
outs = fun(*args) | |
^^^^^^^^^^ | |
jaxlib.xla_extension.XlaRuntimeError: NOT_FOUND: Couldn't find ptxas. The following locations were considered: /nix/store/w48pp30vl0jk2wvwdqgb9zg06my5m7sy-python3-3.11.7-env/lib/python3.11/site-packages/jaxlib/cuda/bin/ptxas, /home/ubuntu/.nix-profile/bin/ptxas, /nix/var/nix/profiles/default/bin/ptxas, /home/ubuntu/.vscode-server/bin/903b1e9d8990623e3d7da1df3d33db3e42d80eda/bin/remote-cli/ptxas, /home/ubuntu/.nix-profile/bin/ptxas, /nix/var/nix/profiles/default/bin/ptxas, /home/ubuntu/.nix-profile/bin/ptxas, /nix/var/nix/profiles/default/bin/ptxas, /usr/local/sbin/ptxas, /usr/local/bin/ptxas, /usr/sbin/ptxas, /usr/bin/ptxas, /sbin/ptxas, /bin/ptxas, /usr/games/ptxas, /usr/local/games/ptxas, /snap/bin/ptxas, /usr/local/cuda-11.8/bin/ptxas, /usr/local/cuda/bin/ptxas, /nix/store/w48pp30vl0jk2wvwdqgb9zg06my5m7sy-python3-3.11.7-env/lib/python3.11/site-packages/jaxlib/../nvidia/cuda_nvcc/bin/ptxas, /nix/store/w48pp30vl0jk2wvwdqgb9zg06my5m7sy-python3-3.11.7-env/lib/python3.11/site-packages/jaxlib/../../nvidia/cuda_nvcc/bin/ptxas | |
-------------------- | |
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Nix expression for the `jax` Python package.
#
# Takes its dependencies as function arguments and returns the `jax`
# derivation. jaxlib is deliberately only a check-time input (see the note
# next to propagatedBuildInputs). GPU smoke tests that cannot run inside the
# build sandbox are exposed as runnable scripts under `passthru`.
{ lib
, blas
, buildPythonPackage
, setuptools
, importlib-metadata
, fetchFromGitHub
, jaxlib
, jaxlib-bin
, hypothesis
, lapack
, matplotlib
, ml-dtypes
, numpy
, opt-einsum
, pkgs
, pytestCheckHook
, pytest-xdist
, pythonOlder
, scipy
, stdenv
}:

let
  # True when either BLAS or LAPACK is provided by MKL; a few tests are
  # disabled in that configuration (see disabledTests).
  usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";

  # jaxlib is broken on aarch64-* as of 2023-03-05, but the binary wheels work
  # fine. jaxlib is only used in the checkPhase, so switching backends does not
  # impact package behavior. Get rid of this once jaxlib is fixed on aarch64-*.
  jaxlib' = if jaxlib.meta.broken then jaxlib-bin else jaxlib;

  jax = buildPythonPackage rec {
    pname = "jax";
    version = "0.4.24";
    pyproject = true;

    disabled = pythonOlder "3.9";

    src = fetchFromGitHub {
      owner = "google";
      repo = "jax";
      # google/jax contains tags for jax and jaxlib. Only use jax tags!
      rev = "refs/tags/${pname}-v${version}";
      hash = "sha256-hmx7eo3pephc6BQfoJ3U0QwWBWmhkAc+7S4QmW32qQs=";
    };

    nativeBuildInputs = [
      setuptools
    ];

    # The version is automatically set to ".dev" if this variable is not set.
    # https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
    JAX_RELEASE = "1";

    # jaxlib is _not_ included in propagatedBuildInputs because there are
    # different versions of jaxlib depending on the desired target hardware. The
    # JAX project ships separate wheels for CPU, GPU, and TPU.
    propagatedBuildInputs = [
      ml-dtypes
      numpy
      opt-einsum
      scipy
    ] ++ lib.optional (pythonOlder "3.10") importlib-metadata;

    nativeCheckInputs = [
      hypothesis
      jaxlib'
      matplotlib
      pytestCheckHook
      pytest-xdist
    ];

    # high parallelism will result in the tests getting stuck
    dontUsePytestXdist = true;

    # NOTE: Don't run the tests in the experimental directory as they require flax
    # which creates a circular dependency. See https://discourse.nixos.org/t/how-to-nix-ify-python-packages-with-circular-dependencies/14648/2.
    # Not a big deal, this is how the JAX docs suggest running the test suite
    # anyhow.
    pytestFlagsArray = [
      "--numprocesses=4"
      "-W ignore::DeprecationWarning"
      "tests/"
    ];

    disabledTests = [
      # Exceeds tolerance when the machine is busy
      "test_custom_linear_solve_aux"
      # UserWarning: Explicitly requested dtype <class 'numpy.float64'>
      # requested in astype is not available, and will be truncated to
      # dtype float32. (With numpy 1.24)
      "testKde3"
      "testKde5"
      "testKde6"
      # Invokes python manually in a subprocess, which does not have the correct dependencies
      # ImportError: This version of jax requires jaxlib version >= 0.4.19.
      "test_no_log_spam"
    ] ++ lib.optionals usingMKL [
      # See
      #  * https://github.com/google/jax/issues/9705
      #  * https://discourse.nixos.org/t/getting-different-results-for-the-same-build-on-two-equally-configured-machines/17921
      #  * https://github.com/NixOS/nixpkgs/issues/161960
      "test_custom_linear_solve_cholesky"
      "test_custom_root_with_aux"
      "testEigvalsGrad_shape"
    ] ++ lib.optionals stdenv.isAarch64 [
      # See https://github.com/google/jax/issues/14793.
      "test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals_unrolled_for_loop"
      "testQdwhWithRandomMatrix3"
      "testScanGrad_jit_scan"

      # See https://github.com/google/jax/issues/17867.
      "test_array"
      "test_async"
      "test_copy0"
      "test_device_put"
      "test_make_array_from_callback"
      "test_make_array_from_single_device_arrays"

      # Fails on some hardware due to some numerical error
      # See https://github.com/google/jax/issues/18535
      "testQdwhWithOnRankDeficientInput5"
    ];

    disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
      # RuntimeWarning: invalid value encountered in cast
      "tests/lax_test.py"
    ];

    pythonImportsCheck = [ "jax" ];

    # GPU smoke tests. These are not run as part of the build; run them
    # manually on a machine with a GPU, e.g.
    #
    #   NIXPKGS_ALLOW_UNFREE=1 nixglhost -- nix run --impure .#python3Packages.jax.passthru.test_cuda_jaxlibBin
    passthru =
      let
        # Build a `test_cuda` executable that runs a tiny computation against
        # the given (CUDA-enabled) jaxlib and asserts that a GPU is in use.
        test_cuda = jaxlib: pkgs.writers.writePython3Bin "test_cuda"
          {
            libraries = [ jax jaxlib ];
            # The interpolated /nix/store path below exceeds flake8's default
            # line-length limit, which writePython3Bin enforces at build time.
            flakeIgnore = [ "E501" ];
          } ''
          import os

          import jax
          from jax import random

          # XLA looks for `ptxas` in jaxlib/cuda/bin, on $PATH and under
          # /usr/local/cuda*. None of these contain it when this script is
          # launched through its wrapper, which used to fail with
          # "NOT_FOUND: Couldn't find ptxas". Put the CUDA compiler tools on
          # $PATH before the first computation is compiled.
          # NOTE(review): this assumes pkgs.cudaPackages matches the CUDA
          # version that the supplied jaxlib was built against -- confirm.
          os.environ["PATH"] = os.pathsep.join(
              ["${lib.getBin pkgs.cudaPackages.cuda_nvcc}/bin", os.environ.get("PATH", "")]
          )

          assert jax.devices()[0].platform == "gpu"

          rng = random.PRNGKey(0)
          x = random.normal(rng, (100, 100))
          x @ x

          print("success!")
        '';
      in
      {
        test_cuda_jaxlibSource = test_cuda (jaxlib.override { cudaSupport = true; });
        test_cuda_jaxlibBin = test_cuda (jaxlib-bin.override { cudaSupport = true; });
      };

    meta = with lib; {
      description = "Differentiate, compile, and transform Numpy code";
      homepage = "https://github.com/google/jax";
      license = licenses.asl20;
      maintainers = with maintainers; [ samuela ];
    };
  };
in
jax
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment