This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import itertools as it | |
| import jax | |
| import jax.numpy as jnp | |
| from jax.experimental import maps | |
| jax.config.update('jax_enable_x64', True) | |
| jax.config.update('jax_platform_name', 'cpu') | |
| jax.config.update('experimental_xmap_spmd_lowering', True) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # get a new VM | |
| TPU_NAME=lingvo-jax-128 | |
| gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --zone us-central2-b --accelerator-type v4-128 --version v2-nightly-tpuv4 --project tpu-prod-env-one-vm | |
| # ssh into vm | |
| gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone us-central2-b --project tpu-prod-env-one-vm | |
| # upgrade pip | |
| python3 -m pip install -U pip |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| from jax.experimental.pjit import pjit, PartitionSpec as P | |
| from jax.experimental.maps import mesh | |
| print('global devices=', jax.devices()) | |
| print('local devices=', jax.local_devices()) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| def make_hlo(f, optimize=False, metadata=False, platform=None): | |
| """Utility function for printing JAX-emitted HLO and XLA-compiled HLO. | |
| Args: | |
| f: jax function to return hlo for. | |
| optimize: bool: whether to return platform-specific, XLA-optimized HLO | |
| metadata: bool: whether to include JAX metadata information | |
| platform: Optional[str]: None, 'cpu','gpu','tpu' - platform to compile for, | |
| None uses default. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import itertools as it | |
| import jax | |
| import jax.numpy as jnp | |
| jax.config.update('jax_enable_x64', True) | |
| jax.config.update('jax_platform_name', 'cpu') | |
| L = num_stages = 5 | |
| N = batch_size = 6 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # [email protected] | |
| # `jax.distributed.initialize` is available in jax-0.2.25. | |
| # $ pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html # Note: wheels only available on linux. | |
| # Run this script on 2 GPU nodes, assuming 10.128.0.6 is the master node | |
| # python nvidia_gpu_pjit.py --server_addr="10.128.0.6:1456" --num_hosts=2 --host_idx=0 | |
| # python nvidia_gpu_pjit.py --server_addr="10.128.0.6:1456" --num_hosts=2 --host_idx=1 | |
| from absl import app | |
| from absl import flags |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing import Any, Sequence | |
| from jax import core | |
| from jax.interpreters import ad | |
| from jax.interpreters import batching | |
| from jax.interpreters import partial_eval as pe | |
| from jax.interpreters import xla | |
| from jax._src import source_info_util | |
| source_info_util.register_exclusion(__file__) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import jax | |
| from jax import numpy as jnp | |
| import numpy as np | |
| import jax | |
| from ad_checkpoint import name | |
| from jax._src import source_info_util | |
| def saved_residuals(f, *args, **kwargs): | |
| jaxpr = jax.make_jaxpr(lambda *args: jax.linearize(f, *args, **kwargs)[1])(*args).jaxpr |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import jax | |
| from jax import numpy as jnp | |
| import numpy as np | |
| import jax | |
| from jax._src import source_info_util | |
| def saved_residuals(f, *args, **kwargs): | |
| jaxpr = jax.make_jaxpr(lambda *args: jax.linearize(f, *args, **kwargs)[1])(*args).jaxpr | |
| res_vars = set(jaxpr.outvars) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # [email protected] | |
| import atexit | |
| import functools | |
| from absl import app | |
| from absl import flags | |
| from absl import logging | |
| import jax | |
| from jax.lib import xla_extension as xc |