For my GSoC project this year, I'm working on an integration between JAX and Kubeflow Trainer v2.
There are multiple challenges that require consideration. One of them is how to keep the setup for multiple backends in a single container, since I don't want to maintain multiple templates / blueprints. I need to validate this approach.
My current Dockerfile:
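One way a single image could serve several backends is to pick the JAX platform when the container starts, before `jax` is imported. The sketch below is only an illustration under stated assumptions: `TRAINER_ACCELERATOR` and `pick_jax_platform` are hypothetical names (not part of JAX or Kubeflow), and the `nvidia-smi` lookup is a rough heuristic. `JAX_PLATFORMS` is the environment variable JAX reads to choose its platform.

```python
import os
import shutil

# Hypothetical variable a job template could set; not defined by JAX or Kubeflow.
ACCELERATOR_ENV = "TRAINER_ACCELERATOR"


def pick_jax_platform(env=None, has_nvidia=None):
    """Return the platform name to request from JAX: "cpu", "cuda" or "tpu"."""
    env = os.environ if env is None else env

    # An explicit JAX_PLATFORMS setting always wins.
    if env.get("JAX_PLATFORMS"):
        return env["JAX_PLATFORMS"]

    # Otherwise honor the (hypothetical) hint from the job template.
    requested = env.get(ACCELERATOR_ENV, "").lower()
    if requested in {"cpu", "cuda", "tpu"}:
        return requested

    # Heuristic fallback: assume CUDA if the NVIDIA tooling is on PATH.
    if has_nvidia is None:
        has_nvidia = shutil.which("nvidia-smi") is not None
    return "cuda" if has_nvidia else "cpu"


if __name__ == "__main__":
    platform = pick_jax_platform()
    # Export the choice before jax is imported, so jax sees it at import time.
    os.environ["JAX_PLATFORMS"] = platform
    print(f"JAX_PLATFORMS={platform}")
    # The training script would import jax only after this point.
```

With this approach the same image could be started with different environment variables per job instead of needing one template per backend.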
FROM python:3.10-bullseye

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        build-essential \
        cmake \
        git \
        curl \
        libgoogle-glog-dev \
        libgflags-dev \
        libprotobuf-dev \
        protobuf-compiler && \
    rm -rf /var/lib/apt/lists/*

RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir \
        numpy \
        jax \
        jaxlib \
        "jax[cuda12_pip]" \
        "jax[tpu]" && \
    pip install --no-cache-dir libtpu-nightly || echo "libtpu-nightly not available, continuing without it"

RUN pip install absl-py kubernetes

RUN git clone https://github.com/facebookincubator/gloo.git \
    && cd gloo \
    && git checkout 43b7acbf372cdce14075f3526e39153b7e433b53 \
    && mkdir build \
    && cd build \
    && cmake ../ \
    && make \
    && make install

WORKDIR /workspace
COPY test_jax_backends.py .
CMD ["python3", "test_jax_backends.py"]
The Docker image builds successfully; however, I need proper hardware to test its backends. For now I use this script (test_jax_backends.py):
import os

import jax
import jax.numpy as jnp


def test_backend(backend_name: str):
    """Run a trivial computation and report whether it succeeded."""
    print("=" * 40)
    print(f"Testing backend: {backend_name}")
    # JAX_DIST_BACKEND is this script's own convention; whether JAX honors it
    # (especially when set after `import jax`) is unverified.
    os.environ["JAX_DIST_BACKEND"] = backend_name
    try:
        x = jnp.arange(10)
        y = x * 2
        print("result:", y)
        print("Devices:", jax.devices())
        print(f"{backend_name} works.")
    except Exception as e:
        print(f"{backend_name} failed:", e)


if __name__ == "__main__":
    for backend in ["gloo", "nccl", "mpi"]:
        test_backend(backend)
On my Mac M1, this results in:
➜ v git:(main) ✗ docker run --rm jax-multi-backendpp
Traceback (most recent call last):
File "/workspace/test_jax_backends.py", line 2, in <module>
import jax
File "/usr/local/lib/python3.10/site-packages/jax/__init__.py", line 37, in <module>
from . import config as _config_module
File "/usr/local/lib/python3.10/site-packages/jax/config.py", line 18, in <module>
from jax._src.config import config
File "/usr/local/lib/python3.10/site-packages/jax/_src/config.py", line 26, in <module>
from jax import lib
File "/usr/local/lib/python3.10/site-packages/jax/lib/__init__.py", line 67, in <module>
from jaxlib import pocketfft
ImportError: cannot import name 'pocketfft' from 'jaxlib' (/usr/local/lib/python3.10/site-packages/jaxlib/__init__.py)
I guess this is because I may not be able to use custom backends here. I need to experiment more and propose a solution.
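The ImportError above is raised while `jax` itself is being imported, before any backend is selected, so it may be worth first checking whether the installed `jax` and `jaxlib` versions match. A minimal sketch using only the standard library, so it runs even when `import jax` fails; the helper name `installed_versions` is hypothetical.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib")):
    """Return {package: version or None} without importing the packages."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Running this inside the container (for example by overriding the image's CMD) would show whether pip resolved `jax` and `jaxlib` to releases that belong together.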