For my GSoC project this year, I'm working on an integration between JAX and Kubeflow Trainer v2.

There are several challenges to consider. One of them is how to keep multiple backend setups in a single container image: I don't want to maintain multiple templates / blueprints, so I need to validate that this approach works.

My current Dockerfile:

FROM python:3.10-bullseye

RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        build-essential \
        cmake \
        git \
        curl \
        libgoogle-glog-dev \
        libgflags-dev \
        libprotobuf-dev \
        protobuf-compiler && \
    rm -rf /var/lib/apt/lists/*

RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir \
        numpy \
        jax \
        jaxlib \
        "jax[cuda12_pip]" \
        "jax[tpu]" && \
    pip install --no-cache-dir libtpu-nightly || echo "libtpu-nightly not available, continuing without it"

RUN pip install absl-py kubernetes

RUN git clone https://github.com/facebookincubator/gloo.git \
    && cd gloo \
    && git checkout 43b7acbf372cdce14075f3526e39153b7e433b53 \
    && mkdir build \
    && cd build \
    && cmake ../ \
    && make \
    && make install

WORKDIR /workspace
COPY test_jax_backends.py .

CMD ["python3", "test_jax_backends.py"]

The Docker image builds successfully; however, I need proper hardware to test its backends.

import os
import jax
import jax.numpy as jnp


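def test_backend(backend_name: str):
    """Run a tiny computation and report whether it succeeded."""
    print("=" * 40)
    print(f"Testing backend: {backend_name}")
    # jax is already imported at this point; the variable is set for any
    # code that reads it later.
    os.environ["JAX_DIST_BACKEND"] = backend_name

    try:
        # This runs on whatever platform JAX selected by default;
        # it does not read backend_name itself.
        x = jnp.arange(10)
        y = x * 2
        print("result:", y)
        print("Devices:", jax.devices())
        print(f"{backend_name} works.")
    except Exception as e:
        print(f"{backend_name} failed:", e)


if __name__ == "__main__":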
    for backend in ["gloo", "nccl", "mpi"]:
        test_backend(backend)
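
Since the loop above runs the same computation for every backend name, it may help to also record which platform JAX actually picked. The following is a minimal sketch, not part of the original script: `describe_runtime` is a hypothetical helper name, and it only relies on `jax.default_backend()` and `jax.devices()`. It catches `ImportError` so that a broken install is reported instead of crashing the run.

```python
import importlib


def describe_runtime():
    """Return what JAX reports about the active platform, or the import error.

    Hypothetical helper for illustration; not part of the original script.
    """
    try:
        jax = importlib.import_module("jax")
    except ImportError as exc:
        # Covers both a missing package and a failing import inside jax itself.
        return {"ok": False, "error": f"{type(exc).__name__}: {exc}"}
    return {
        "ok": True,
        "platform": jax.default_backend(),
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    print(describe_runtime())
```

Printing this once per container run would show whether the image ended up on the CPU, GPU, or TPU platform, independent of the backend name passed to `test_backend`.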

On my Mac M1, this results in:

➜  v git:(main) ✗ docker run --rm jax-multi-backendpp

Traceback (most recent call last):
  File "/workspace/test_jax_backends.py", line 2, in <module>
    import jax
  File "/usr/local/lib/python3.10/site-packages/jax/__init__.py", line 37, in <module>
    from . import config as _config_module
  File "/usr/local/lib/python3.10/site-packages/jax/config.py", line 18, in <module>
    from jax._src.config import config
  File "/usr/local/lib/python3.10/site-packages/jax/_src/config.py", line 26, in <module>
    from jax import lib
  File "/usr/local/lib/python3.10/site-packages/jax/lib/__init__.py", line 67, in <module>
    from jaxlib import pocketfft
ImportError: cannot import name 'pocketfft' from 'jaxlib' (/usr/local/lib/python3.10/site-packages/jaxlib/__init__.py)

My guess is that this is because I may not be able to use custom backends. I need to experiment more and propose a solution.
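
The traceback stops at `import jax` (line 2 of the script), before `test_backend` is ever called, so one thing worth ruling out first is a mismatch between the installed `jax` and `jaxlib` versions. Below is a minimal sketch that reads the installed versions from package metadata without importing `jax`, so it also works in an image where the import fails. `installed_versions` is a hypothetical helper name; whether a mismatch is actually the cause here is an assumption that still needs checking.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib")):
    """Map each package name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Running this inside the container (for example by temporarily swapping it in for `test_jax_backends.py`) would show which versions pip actually resolved for the combined install line in the Dockerfile.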
