Qiao Zhang zhangqiaorjc

@zhangqiaorjc
zhangqiaorjc / spmd_pipeline_xmap.py
Created January 20, 2022 03:24
spmd_pipeline_xmap.py
import itertools as it
import jax
import jax.numpy as jnp
from jax.experimental import maps
jax.config.update('jax_enable_x64', True)
jax.config.update('jax_platform_name', 'cpu')
jax.config.update('experimental_xmap_spmd_lowering', True)
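The preview stops at the imports and config flags; as a hypothetical illustration of the xmap API the gist builds on (default lowering and a made-up per-example function, not the gist's pipeline code):
import jax.numpy as jnp
from jax.experimental import maps
def per_example(x):
    # Per-example computation that xmap maps over the named 'batch' axis.
    return jnp.sum(x * x)
batched = maps.xmap(per_example, in_axes=['batch', ...], out_axes=['batch', ...])
print(batched(jnp.ones((6, 4))))  # one value per 'batch' entry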
@zhangqiaorjc
zhangqiaorjc / install_tpu_vm.sh
Last active December 8, 2021 21:55
install tpu vm
# get a new VM
TPU_NAME=lingvo-jax-128
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --zone us-central2-b --accelerator-type v4-128 --version v2-nightly-tpuv4 --project tpu-prod-env-one-vm
# ssh into vm
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone us-central2-b --project tpu-prod-env-one-vm
# upgrade pip
python3 -m pip install -U pip
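Once the packages are installed, a hypothetical sanity check (not part of the gist) to confirm the slice is visible from Python:
import jax
print(jax.devices())             # should list TpuDevice entries for the slice
print(jax.device_count())        # chips visible globally
print(jax.local_device_count())  # chips attached to this host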
@zhangqiaorjc
zhangqiaorjc / jax_pjit_einsum.py
Last active December 2, 2021 01:48
jax_pjit_einsum.py: demo of an all-gather via a pjit einsum
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.pjit import pjit, PartitionSpec as P
from jax.experimental.maps import mesh
print('global devices=', jax.devices())
print('local devices=', jax.local_devices())
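The preview cuts off before the pjit call itself. A minimal hypothetical sketch of an einsum whose compiled HLO contains an all-gather (assumed shapes and a 1-D mesh of at least two devices; not the gist's exact code):
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.pjit import pjit, PartitionSpec as P
from jax.experimental.maps import mesh
# Shard the activations over mesh axis 'x', keep the weights replicated, and
# request a fully replicated output: assembling that output forces an all-gather.
f = pjit(lambda x, w: jnp.einsum('bd,dh->bh', x, w),
         in_axis_resources=(P('x', None), P(None, None)),
         out_axis_resources=P(None, None))
with mesh(np.asarray(jax.devices()[:2]), ('x',)):
    print(f(jnp.ones((8, 4)), jnp.ones((4, 2))).shape)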
@zhangqiaorjc
zhangqiaorjc / print_hlo.py
Created November 25, 2021 22:11
make_hlo
def make_hlo(f, optimize=False, metadata=False, platform=None):
"""Utility function for printing JAX-emitted HLO and XLA-compiled HLO.
Args:
f: jax function to return hlo for.
optimize: bool: whether to return platform-specific, XLA-optimized HLO
metadata: bool: whether to include JAX metadata information
platform: Optional[str]: None, 'cpu','gpu','tpu' - platform to compile for,
None uses default.
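The preview ends mid-docstring. A rough, hypothetical sketch of the same idea, covering only the unoptimized path via jax.xla_computation (the gist's optimized and metadata options are truncated above):
import jax
import jax.numpy as jnp
def hlo_text(f, *args):
    # Unoptimized, JAX-emitted HLO for f on the given example arguments.
    return jax.xla_computation(f)(*args).as_hlo_text()
print(hlo_text(lambda x: jnp.sin(x) * 2.0, jnp.ones(3)))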
import itertools as it
import jax
import jax.numpy as jnp
jax.config.update('jax_enable_x64', True)
jax.config.update('jax_platform_name', 'cpu')
L = num_stages = 5
N = batch_size = 6
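Using L and N from the lines above, a pipeline schedule needs L + N - 1 = 10 steps; a tiny hypothetical illustration (not the gist body) of which stages are busy at each step:
# Stage i works on microbatch t - i at step t, so it is busy iff i <= t < i + N.
for t in range(L + N - 1):
    busy = [i for i in range(L) if i <= t < i + N]
    print(f'step {t}: active stages {busy}')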
# [email protected]
# `jax.distributed.initialize` is available in jax-0.2.25.
# $ pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html # Note: wheels only available on linux.
# Run this script on 2 GPU nodes, assuming 10.128.0.6 is the master node
# python nvidia_gpu_pjit.py --server_addr="10.128.0.6:1456" --num_hosts=2 --host_idx=0
# python nvidia_gpu_pjit.py --server_addr="10.128.0.6:1456" --num_hosts=2 --host_idx=1
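A minimal hypothetical sketch of that multi-host wiring (flag names taken from the commands above; jax.distributed.initialize takes the coordinator address, the number of processes, and this process's index):
import jax
from absl import app, flags
flags.DEFINE_string('server_addr', '', 'coordinator address, e.g. 10.128.0.6:1456')
flags.DEFINE_integer('num_hosts', 1, 'number of participating hosts')
flags.DEFINE_integer('host_idx', 0, 'index of this host')
FLAGS = flags.FLAGS
def main(_):
    jax.distributed.initialize(FLAGS.server_addr, FLAGS.num_hosts, FLAGS.host_idx)
    print('global devices:', jax.devices())
    print('local devices:', jax.local_devices())
if __name__ == '__main__':
    app.run(main)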
from absl import app
from absl import flags
from typing import Any, Sequence
from jax import core
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax._src import source_info_util
source_info_util.register_exclusion(__file__)  # attribute source info to user code, not this file

import jax
from jax import numpy as jnp
import numpy as np
import jax
from ad_checkpoint import name
from jax._src import source_info_util
def saved_residuals(f, *args, **kwargs):
jaxpr = jax.make_jaxpr(lambda *args: jax.linearize(f, *args, **kwargs)[1])(*args).jaxpr

import jax
from jax import numpy as jnp
import numpy as np
import jax
from jax._src import source_info_util
def saved_residuals(f, *args, **kwargs):
jaxpr = jax.make_jaxpr(lambda *args: jax.linearize(f, *args, **kwargs)[1])(*args).jaxpr
res_vars = set(jaxpr.outvars)
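Both previews above rely on the same trick: jax.linearize returns its tangent function as a pytree (a Partial) whose leaves are the residuals the forward pass must save, so the jaxpr of that returned function lists them as outputs. A small self-contained illustration with a hypothetical example function (not the gist body):
import jax
import jax.numpy as jnp
def f(x):
    return jnp.sin(x) * x
# Trace the function that returns the linearized (tangent) function; the
# jaxpr's outputs are exactly the residuals saved for the backward pass.
jaxpr = jax.make_jaxpr(lambda x: jax.linearize(f, x)[1])(jnp.ones(3)).jaxpr
print(jaxpr.outvars)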
# [email protected]
import atexit
import functools
from absl import app
from absl import flags
from absl import logging
import jax
from jax.lib import xla_extension as xc  # low-level XLA runtime bindings