@zhangqiaorjc
Last active December 8, 2021 21:55
install tpu vm
# get a new VM
TPU_NAME=lingvo-jax-128
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --zone us-central2-b --accelerator-type v4-128 --version v2-nightly-tpuv4 --project tpu-prod-env-one-vm
# ssh into vm
gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone us-central2-b --project tpu-prod-env-one-vm
# upgrade pip
python3 -m pip install -U pip
# install latest jax
python3 -m pip install -U jax jaxlib
# fetch a custom libtpu build and install it over the system libtpu.so
gsutil cp gs://zhangqiaorjc/libtpu.so.bc_offload .
sudo cp libtpu.so.bc_offload /lib/libtpu.so
sudo cp libtpu.so.bc_offload /usr/lib/libtpu.so
# install lingvo-jax
python3 -m pip install lingvo-jax tensorflow-text
# get profiler working
python3 -m pip install tensorboard-plugin-profile
python3 -m pip uninstall -y tb-nightly
# download experiment params (run from the home directory)
mkdir -p ~/.local/lib/python3.8/site-packages/lingvo/jax/tasks/lm/params
git clone https://github.com/tensorflow/lingvo.git
cp lingvo/lingvo/jax/tasks/lm/params/* ~/.local/lib/python3.8/site-packages/lingvo/jax/tasks/lm/params/
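# Sketch: derive the params directory instead of hardcoding python3.8.
# Assumes lingvo-jax was installed into the per-user site-packages, as the
# .local path above implies. USER_SITE and PARAMS_DIR are names introduced
# here for illustration only.
USER_SITE="$(python3 -c 'import site; print(site.getusersitepackages())')"
PARAMS_DIR="${USER_SITE}/lingvo/jax/tasks/lm/params"
echo "${PARAMS_DIR}"
# The mkdir/cp step can then target "${PARAMS_DIR}" on any Python 3 minor version.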
# train (bash does not expand ~ inside --flag=~/path, so use ${HOME} for the log dir)
LIBTPU_INIT_ARGS="--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_tpu_spmd_rewrite_einsum_with_reshape=true --xla_enable_async_all_gather=true --jax_enable_async_collective_offload=true --xla_tpu_enable_latency_hiding_scheduler=true" python3 ~/.local/lib/python3.8/site-packages/lingvo/jax/main.py --model=lm.lm_cloud.LmCloudSpmd2B --jax_profiler_port=9999 --job_log_dir="${HOME}/expts/jax_2b_1/"
# view training curves and profiles
tensorboard --logdir ~/expts/jax_2b_1/
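# Sketch: the same libtpu flags as in the train command, built one per line so
# they are easier to read and toggle. Plain POSIX sh. Exporting the variable is
# a choice made here for illustration, so that later commands inherit it.
LIBTPU_INIT_ARGS="--xla_jf_spmd_threshold_for_windowed_einsum_mib=0"
LIBTPU_INIT_ARGS="${LIBTPU_INIT_ARGS} --xla_tpu_spmd_threshold_for_allgather_cse=10000"
LIBTPU_INIT_ARGS="${LIBTPU_INIT_ARGS} --xla_tpu_spmd_rewrite_einsum_with_reshape=true"
LIBTPU_INIT_ARGS="${LIBTPU_INIT_ARGS} --xla_enable_async_all_gather=true"
LIBTPU_INIT_ARGS="${LIBTPU_INIT_ARGS} --jax_enable_async_collective_offload=true"
LIBTPU_INIT_ARGS="${LIBTPU_INIT_ARGS} --xla_tpu_enable_latency_hiding_scheduler=true"
export LIBTPU_INIT_ARGS
# With the variable exported, the inline LIBTPU_INIT_ARGS="..." prefix on the
# python3 command can be dropped for subsequent runs in the same shell.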