Skip to content

Instantly share code, notes, and snippets.

@yejingxin
Created May 8, 2023 17:30
Show Gist options
  • Save yejingxin/6f50ef72577f58ff676a9f6d2ca8d0f8 to your computer and use it in GitHub Desktop.
Save yejingxin/6f50ef72577f58ff676a9f6d2ca8d0f8 to your computer and use it in GitHub Desktop.
MaxText Ray Example
"""
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
Simple end-to-end (e2e) example that runs MaxText train.py on multiple TPU slices (multislice) using Ray.
# set up ray cluster following https://cloud.google.com/tpu/docs/ray-guide
# add ray lib to PYTHONPATH
export PYTHONPATH=${PYTHONPATH}:/home/$USER/ray_code/
# start or stop the training job:
python3 MaxText/ray_runner.py --tpu_topology=1x2x2 --mode=start
python3 MaxText/ray_runner.py --tpu_topology=1x2x2 --mode=stop --delete_tpu
"""
import getpass
import os
import socket
import time
from absl import app
from absl import flags
from absl import logging
from ray.job_submission import JobStatus
from ray_tpu_controller import _DEFAULT_RAY_PORT
from ray_tpu_controller import BASE_JAX_PIP_INSTALLS
from ray_tpu_controller import RayTpuController
from ray_tpu_controller import TpuRayJob
from tpu_api import get_default_gcp_project
# Registry through which the parsed command-line flag values are read.
FLAGS = flags.FLAGS

# Networking used when constructing the TPU controllers.
flags.DEFINE_string(name='network', default='default', help='VPC network name.')
flags.DEFINE_string(
    name='subnet', default='default', help='Subnet name under VPC.'
)

# TPU naming and sizing. The default name is derived from the current user.
flags.DEFINE_string(
    name='tpu_name',
    default=f'{getpass.getuser()}-slice',
    help='TPU vm name, it will be suffixed with slice index, like ubuntu-slice-0.',
)
flags.DEFINE_string(name='tpu_topology', default='2x2x2', help='TPU topology.')
flags.DEFINE_integer(name='num_slices', default=2, help='Number of slices.')

# What the runner should do; `mode` has no default value.
flags.DEFINE_enum(
    name='mode',
    default=None,
    enum_values=['start', 'stop'],
    help='Start or stop running jobs.',
)
flags.DEFINE_boolean(
    name='delete_tpu',
    default=False,
    help='Whether delete tpus when stop jobs.',
)
def stop(controllers):
    """Request that the jobs on every slice stop, optionally deleting the TPUs.

    Args:
        controllers: Iterable of per-slice controller objects. Each must expose
            `clean_stale_jobs`, `resource_name` and `delete_tpu`.
    """
    for ctrl in controllers:
        ctrl.clean_stale_jobs(ctrl.resource_name)
    logging.info('All Jobs are requested to stop.')

    # TPUs are only torn down when --delete_tpu was passed.
    if not FLAGS.delete_tpu:
        return
    for ctrl in controllers:
        ctrl.delete_tpu()
def start(controllers):
    """Prepare every slice, queue the MaxText training job, and wait for it.

    For each controller the TPU is created if needed and stale jobs are
    cleaned. One `TpuRayJob` is then queued per slice, and the function polls
    every 10 seconds until every slice reports a job in RUNNING status.

    Args:
        controllers: Non-empty list of per-slice controllers. The list index
            is used as the slice id, and the first IP address of the first
            controller (port 8080) is used as the megascale coordinator
            address.
    """
    for controller in controllers:
        controller.maybe_create_and_wait_for_ready()
        controller.clean_stale_jobs(controller.resource_name)

    mxla_coordinator_address = f'{controllers[0].get_ip_addresses()[0]}:8080'
    num_slices = len(controllers)

    # `${USER}` must be braced: a shell parses `$USER_` as the (unset)
    # variable `USER_`, which silently drops the user name from run_name.
    run_command = (
        'python3 MaxText/train.py MaxText/configs/base.yml '
        'run_name=${USER}_$(date +%Y-%m-%d-%H-%M-%S) '
        f'dcn_data_parallelism={num_slices}'
    )
    working_dir = os.path.expanduser('~/MaxText')
    pip_installs = {
        'packages': [
            'orbax==0.1.6',
            'absl-py',
            'argparse',
            'datetime',
            'google-cloud-storage',
            'ml-collections',
            'numpy',
            'optax',
            'portpicker',
            'protobuf==3.20.3',
            'pylint',
            'pytest',
            'sentencepiece',
            'tensorflow==2.12.0rc0',
            'tensorflow-datasets',
            'tensorboard-plugin-profile',
            'tensorflow-text==2.12.0rc0',
            'tensorboardx',
            'jaxlib==0.4.6',
            'jax[tpu]==0.4.6',
            '-f https://storage.googleapis.com/jax-releases/libtpu_releases.html',
            '-e git+https://github.com/google/flax.git@main#egg=flax',
        ],
        'pip_check': False,
        'pip_version': "==20.0.2;python_full_version=='3.8.10'",
    }

    # Every slice runs the same command; only MEGASCALE_SLICE_ID differs.
    for slice_index, controller in enumerate(controllers):
        job = TpuRayJob(
            entrypoint=run_command,
            working_dir=working_dir,
            pip_installs=pip_installs,
            env_vars={
                'LIBTPU_INIT_ARGS': '"--xla_tpu_enable_megascale_barrier=true"',
                'JAX_USE_PJRT_C_API_ON_TPU': '1',
                'MEGASCALE_COORDINATOR_ADDRESS': mxla_coordinator_address,
                'MEGASCALE_SLICE_ID': f'{slice_index}',
                'MEGASCALE_NUM_SLICES': f'{num_slices}',
                'TPU_VMODULE': 'communication_backend=10',
            },
        )
        controller.queue_tpu_workload(job)

    while True:
        # Recount from zero on every poll. Carrying the count over between
        # polls would count the same running slice repeatedly and end the
        # wait before all slices are actually running.
        num_slices_job_running = sum(
            1
            for controller in controllers
            if controller.jobs_in_status(JobStatus.RUNNING)
        )
        logging.info(
            '%d/%d slices jobs are running.', num_slices_job_running, num_slices
        )
        if num_slices_job_running >= num_slices:
            break
        time.sleep(10)

    logging.info('All Jobs are in running status, here is the list:')
    for slice_index, controller in enumerate(controllers):
        for job in controller.queued_jobs:
            logging.info('job:%s @ slice: %d', job, slice_index)
def main(_):
    """Build one TPU controller per slice and start or stop the jobs.

    The host running this script is used as the Ray head address: its
    hostname is resolved with `socket.getaddrinfo` and paired with
    `_DEFAULT_RAY_PORT`.

    Args:
        _: Unused positional command-line arguments passed by `app.run`.
    """
    project = get_default_gcp_project()
    hostname = socket.gethostname()
    # Each getaddrinfo entry is (family, type, proto, canonname, sockaddr).
    # sockaddr is (host, port) for IPv4 but a 4-tuple for IPv6, so take the
    # host by index; unpacking a 2-tuple would raise ValueError on IPv6.
    sockaddr = socket.getaddrinfo(hostname, _DEFAULT_RAY_PORT)[0][4]
    controller_ip = sockaddr[0]
    head_addr = f'{controller_ip}:{_DEFAULT_RAY_PORT}'

    # One controller per slice, named "<tpu_name>-<slice_index>".
    controllers = [
        RayTpuController(
            tpu_name=f'{FLAGS.tpu_name}-{slice_index}',
            project=project,
            zone='us-central2-b',
            accelerator_type='V4',
            accelerator_topology=FLAGS.tpu_topology,
            version='tpu-vm-v4-base',
            network=FLAGS.network,
            subnetwork=FLAGS.subnet,
            head_addr=head_addr,
        )
        for slice_index in range(FLAGS.num_slices)
    ]

    if FLAGS.mode == 'start':
        start(controllers)
    elif FLAGS.mode == 'stop':
        stop(controllers)
if __name__ == '__main__':
    # Show INFO-level messages, which start() and stop() use for progress.
    logging.set_verbosity(logging.INFO)
    # --mode has no default, so it must be supplied on the command line.
    flags.mark_flag_as_required('mode')
    app.run(main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment