@mahdikhashan
Last active August 29, 2025 09:32

This is the validation for the NVIDIA JAX image. I was able to build the image and start the container, but because my M1 Mac has no NVIDIA GPU, the script could not run successfully: JAX's CUDA plugin failed to initialize.
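
Because the failure only shows up once JAX tries to initialize its CUDA plugin, a cheap preflight check can make the missing-GPU case explicit before JAX is imported. The sketch below is a heuristic and an assumption, not part of the original setup: it treats a GPU as visible when `/dev/nvidiactl` and at least one `/dev/nvidiaN` device node exist, which is what the NVIDIA Container Toolkit normally mounts into a container started with `--gpus`. The helper names `nvidia_device_nodes` and `nvidia_gpu_visible` are hypothetical.

```python
import glob
import os


def nvidia_device_nodes(dev_dir: str = "/dev") -> list[str]:
    """Return per-GPU device nodes such as /dev/nvidia0 found in dev_dir.

    Heuristic assumption: a container started through the NVIDIA Container
    Toolkit (docker run --gpus ...) sees one /dev/nvidiaN node per GPU.
    """
    # Matches nvidia0, nvidia1, ... but not nvidiactl or nvidia-uvm.
    pattern = os.path.join(dev_dir, "nvidia[0-9]*")
    return sorted(glob.glob(pattern))


def nvidia_gpu_visible(dev_dir: str = "/dev") -> bool:
    """True if the control node and at least one GPU node are present."""
    has_ctl = os.path.exists(os.path.join(dev_dir, "nvidiactl"))
    return has_ctl and bool(nvidia_device_nodes(dev_dir))


if __name__ == "__main__":
    if nvidia_gpu_visible():
        print("NVIDIA GPU device nodes:", nvidia_device_nodes())
    else:
        # On an M1 Mac (no NVIDIA GPU) this branch is expected.
        print("No NVIDIA GPU visible; JAX's CUDA plugin is likely to fail.")
```

Running this first (or as a guard at the top of `test_jax_backends.py`) turns the long CUDA traceback into a one-line explanation.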

FROM ghcr.io/nvidia/jax:jax

ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        build-essential \
        cmake \
        git \
        libgoogle-glog-dev \
        libgflags-dev && \
    rm -rf /var/lib/apt/lists/*

WORKDIR /workspace

COPY test_jax_backends.py .

CMD ["python3", "test_jax_backends.py"]
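
import os

import jax
import jax.numpy as jnp


def test_backend(backend_name: str):
    print("=" * 40)
    print(f"Testing backend: {backend_name}")
    # Note: this runs after `import jax`. JAX reads its environment-based
    # settings at import time, so setting the variable here probably does
    # not change which backend the computation below uses.
    os.environ["JAX_DIST_BACKEND"] = backend_name

    try:
        # A trivial computation that forces JAX to initialize a backend.
        x = jnp.arange(10)
        y = x * 2
        print("result:", y)
        # Lists the devices of JAX's default backend (CPU, GPU, ...).
        print("Devices:", jax.devices())
        print(f"{backend_name} works.")
    except Exception as e:
        print(f"{backend_name} failed:", e)


if __name__ == "__main__":
    for backend in ["gloo", "nccl", "mpi"]:
        test_backend(backend)
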
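
One caveat with the script above: `os.environ` is modified after `import jax`, and JAX reads its environment-based settings when it is imported, so the loop most likely runs the same computation on the same default device three times. A sketch of one workaround, assuming the goal is for each value to be visible from the start, is to run every case in a fresh interpreter with the variable already set. Whether JAX itself honors `JAX_DIST_BACKEND` is not confirmed here; this only guarantees the variable exists before `import jax` runs. `run_in_fresh_interpreter` and `JAX_SNIPPET` are hypothetical names.

```python
import os
import subprocess
import sys

# Code each child process runs; it imports JAX only after the parent has
# placed JAX_DIST_BACKEND in the child's environment.
JAX_SNIPPET = """
import os
import jax
import jax.numpy as jnp
print("JAX_DIST_BACKEND =", os.environ.get("JAX_DIST_BACKEND"))
print("result:", jnp.arange(10) * 2)
print("Devices:", jax.devices())
"""


def run_in_fresh_interpreter(
    backend_name: str, code: str = JAX_SNIPPET
) -> subprocess.CompletedProcess:
    """Run `code` in a new Python process with JAX_DIST_BACKEND preset."""
    env = dict(os.environ, JAX_DIST_BACKEND=backend_name)
    return subprocess.run(
        [sys.executable, "-c", code],
        env=env,
        capture_output=True,
        text=True,
    )


if __name__ == "__main__":
    for backend in ["gloo", "nccl", "mpi"]:
        print("=" * 40)
        print(f"Testing backend: {backend}")
        proc = run_in_fresh_interpreter(backend)
        print(proc.stdout, end="")
        # Exit code 0 only means the snippet ran; the Devices line shows
        # whether it ended up on a GPU or fell back to CPU.
        if proc.returncode == 0:
            print(f"{backend} works.")
        else:
            print(f"{backend} failed (exit {proc.returncode}):")
            print(proc.stderr, end="")
```

Spawning a process per case costs a JAX import each time, but it keeps the cases isolated, so a plugin failure in one run cannot leak into the next.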
➜  v git:(main) ✗ docker run --rm jax-gpu


==========
== CUDA ==
==========

NVIDIA Release  (build )
CUDA Version 13.0.0.044
Container image Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

Various files include modifications (c) NVIDIA CORPORATION & AFFILIATES.  All rights reserved.

GOVERNING TERMS: The software and materials are governed by the NVIDIA Software License Agreement
(found at https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-software-license-agreement/)
and the Product-Specific Terms for NVIDIA AI Products
(found at https://www.nvidia.com/en-us/agreements/enterprise-software/product-specific-terms-for-ai-products/).

WARNING: The NVIDIA Driver was not detected.  GPU functionality will not be available.
   Use the NVIDIA Container Toolkit to start this container with GPU support; see
   https://docs.nvidia.com/datacenter/cloud-native/ .

NOTE: The SHMEM allocation limit is set to the default of 64MB.  This may be
   insufficient for CUDA.  NVIDIA recommends the use of the following flags:
   docker run --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 ...

ERROR:2025-08-29 09:28:02,758:jax._src.xla_bridge:487: Jax plugin configuration error: Exception when calling jax_plugins.xla_cuda13.initialize()
Traceback (most recent call last):
  File "/opt/jax/jax/_src/xla_bridge.py", line 485, in discover_pjrt_plugins
    plugin_module.initialize()
  File "/opt/jaxlibs/jax_cuda13_pjrt/jax_plugins/xla_cuda13/__init__.py", line 328, in initialize
    _check_cuda_versions(raise_on_first_error=True)
  File "/opt/jaxlibs/jax_cuda13_pjrt/jax_plugins/xla_cuda13/__init__.py", line 285, in _check_cuda_versions
    local_device_count = cuda_versions.cuda_device_count()
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:113: operation cuInit(0) failed: Unknown CUDA error 303; cuGetErrorName failed. This probably means that JAX was unable to load the CUDA libraries.
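
The banner at the top of that output already names the way forward: run the container on a host with an NVIDIA GPU and driver, through the NVIDIA Container Toolkit, with the IPC and ulimit flags NVIDIA recommends. The sketch below only assembles that command for the `jax-gpu` image used above; `gpu_run_command` is a hypothetical helper, and it assumes a Linux host with the toolkit installed, which an M1 Mac cannot provide.

```python
import shlex
import subprocess

# Flags quoted from the NOTE in the container's startup banner.
NVIDIA_RECOMMENDED_FLAGS = [
    "--gpus", "all",
    "--ipc=host",
    "--ulimit", "memlock=-1",
    "--ulimit", "stack=67108864",
]


def gpu_run_command(image: str = "jax-gpu") -> list[str]:
    """Build the `docker run` argv for a GPU-enabled run of `image`."""
    return ["docker", "run", "--rm", *NVIDIA_RECOMMENDED_FLAGS, image]


if __name__ == "__main__":
    cmd = gpu_run_command()
    print("Running:", shlex.join(cmd))
    # Requires a Linux host with an NVIDIA driver and the Container Toolkit.
    subprocess.run(cmd, check=True)
```

Keeping the flags in a list rather than a single string avoids shell quoting issues when the command is passed to `subprocess.run`.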