This is the validation for the NVIDIA JAX image. I was able to build the image and start the container, but my M1 Mac has no NVIDIA GPU, so the script could not actually run there. The Dockerfile, the test script, and the container output follow.
FROM ghcr.io/nvidia/jax:jax
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        build-essential \
        cmake \
        git \
        libgoogle-glog-dev \
        libgflags-dev && \
    rm -rf /var/lib/apt/lists/*
WORKDIR /workspace
COPY test_jax_backends.py .
CMD ["python3", "test_jax_backends.py"]
import os

import jax
import jax.numpy as jnp


def test_backend(backend_name: str):
    """Run a tiny computation and report whether it succeeded."""
    print("=" * 40)
    print(f"Testing backend: {backend_name}")
    # JAX is already imported here, so setting this variable now may not
    # change which backend JAX actually uses.
    os.environ["JAX_DIST_BACKEND"] = backend_name
    try:
        x = jnp.arange(10)
        y = x * 2
        print("result:", y)
        print("Devices:", jax.devices())
        print(f"{backend_name} works.")
    except Exception as e:
        print(f"{backend_name} failed:", e)


if __name__ == "__main__":
    for backend in ["gloo", "nccl", "mpi"]:
        test_backend(backend)
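The script prints `jax.devices()` on each pass, but a per-platform count is easier to compare across runs. The helper below is a minimal sketch of one way to do that; `summarize_devices` is a hypothetical name, not part of the original script, and it assumes only that each device object has a `.platform` attribute, as the objects returned by `jax.devices()` do. The demo uses stand-in objects so it runs without JAX installed.

```python
from collections import Counter
from types import SimpleNamespace


def summarize_devices(devices):
    """Count devices per platform string.

    `devices` is any iterable of objects with a `.platform` attribute,
    for example the list returned by jax.devices().
    """
    return dict(Counter(d.platform for d in devices))


if __name__ == "__main__":
    # Stand-in device objects (hypothetical) so the sketch runs without JAX.
    fake_devices = [SimpleNamespace(platform="cpu", id=0)]
    print("devices per platform:", summarize_devices(fake_devices))
    # Inside the container one would call: summarize_devices(jax.devices())
```

Calling this once per `test_backend` pass would show whether the reported platform ever changes between the three backend names.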
➜ v git:(main) ✗ docker run --rm jax-gpu
==========
== CUDA ==
==========
NVIDIA Release (build )
CUDA Version 13.0.0.044
Container image Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Various files include modifications (c) NVIDIA CORPORATION & AFFILIATES. All rights reserved.
GOVERNING TERMS: The software and materials are governed by the NVIDIA Software License Agreement
(found at https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-software-license-agreement/)
and the Product-Specific Terms for NVIDIA AI Products
(found at https://www.nvidia.com/en-us/agreements/enterprise-software/product-specific-terms-for-ai-products/).
WARNING: The NVIDIA Driver was not detected. GPU functionality will not be available.
Use the NVIDIA Container Toolkit to start this container with GPU support; see
https://docs.nvidia.com/datacenter/cloud-native/ .
NOTE: The SHMEM allocation limit is set to the default of 64MB. This may be
insufficient for CUDA. NVIDIA recommends the use of the following flags:
docker run --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 ...
ERROR:2025-08-29 09:28:02,758:jax._src.xla_bridge:487: Jax plugin configuration error: Exception when calling jax_plugins.xla_cuda13.initialize()
Traceback (most recent call last):
  File "/opt/jax/jax/_src/xla_bridge.py", line 485, in discover_pjrt_plugins
    plugin_module.initialize()
  File "/opt/jaxlibs/jax_cuda13_pjrt/jax_plugins/xla_cuda13/__init__.py", line 328, in initialize
    _check_cuda_versions(raise_on_first_error=True)
  File "/opt/jaxlibs/jax_cuda13_pjrt/jax_plugins/xla_cuda13/__init__.py", line 285, in _check_cuda_versions
    local_device_count = cuda_versions.cuda_device_count()
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:113: operation cuInit(0) failed: Unknown CUDA error 303; cuGetErrorName failed. This probably means that JAX was unable to load the CUDA libraries.
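The `cuInit(0)` failure fits the earlier warning that no NVIDIA driver was detected. For a CPU-only smoke test on a machine like this, one option is to set the `JAX_PLATFORMS` environment variable to `cpu` before JAX is imported. The sketch below is a hedged illustration, not something validated in this run: the function names are hypothetical, and checking for `/dev/nvidiactl` or `nvidia-smi` is an assumed heuristic for "a driver is visible". This log also does not show whether the plugin error would still be printed with that setting.

```python
import os
import shutil


def nvidia_driver_visible():
    """Heuristic (assumption): a driver is visible if the control device
    node exists or nvidia-smi is on PATH."""
    return os.path.exists("/dev/nvidiactl") or shutil.which("nvidia-smi") is not None


def configure_jax_platform(env, gpu_visible):
    """Request the CPU platform when no GPU is visible.

    An existing JAX_PLATFORMS value is left untouched. Returns the value
    now in `env`, or None if the variable is unset.
    """
    if not gpu_visible:
        env.setdefault("JAX_PLATFORMS", "cpu")
    return env.get("JAX_PLATFORMS")


if __name__ == "__main__":
    chosen = configure_jax_platform(os.environ, nvidia_driver_visible())
    print("JAX_PLATFORMS =", chosen)
    # Import jax only after this point so the setting is in place
    # when JAX initializes its backends.
```

Passing the environment mapping in as a parameter keeps the function easy to exercise with a plain dict, and `setdefault` means an explicit setting from `docker run -e JAX_PLATFORMS=...` still wins.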