Created
May 8, 2023 17:30
-
-
Save yejingxin/6f50ef72577f58ff676a9f6d2ca8d0f8 to your computer and use it in GitHub Desktop.
MaxText Ray Example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
""" | |
""" | |
Simple end-to-end example running MaxText train.py on multiple TPU slices (multislice).
# Set up the Ray cluster following https://cloud.google.com/tpu/docs/ray-guide
# Add the Ray TPU controller code (the ray_tpu_controller and tpu_api modules) to PYTHONPATH:
export PYTHONPATH=${PYTHONPATH}:/home/$USER/ray_code/
# Start or stop the training job:
python3 MaxText/ray_runner.py --tpu_topology=1x2x2 --mode=start
python3 MaxText/ray_runner.py --tpu_topology=1x2x2 --mode=stop --delete_tpu
""" | |
import getpass | |
import os | |
import socket | |
import time | |
from absl import app | |
from absl import flags | |
from absl import logging | |
from ray.job_submission import JobStatus | |
from ray_tpu_controller import _DEFAULT_RAY_PORT | |
from ray_tpu_controller import BASE_JAX_PIP_INSTALLS | |
from ray_tpu_controller import RayTpuController | |
from ray_tpu_controller import TpuRayJob | |
from tpu_api import get_default_gcp_project | |
# Command-line flags for the multislice launcher.
FLAGS = flags.FLAGS

flags.DEFINE_string('network', 'default', 'VPC network name.')
flags.DEFINE_string('subnet', 'default', 'Subnet name under VPC.')
flags.DEFINE_string(
    'tpu_name',
    getpass.getuser() + '-slice',
    'TPU vm name, it will be suffixed with slice index, like ubuntu-slice-0.',
)
flags.DEFINE_string('tpu_topology', '2x2x2', 'TPU topology.')
# At least one slice is required: the launcher uses slice 0 as the
# megascale coordinator, so zero (or negative) slices can never work.
flags.DEFINE_integer('num_slices', 2, 'Number of slices.', lower_bound=1)
# No default: the entry point marks --mode as required.
flags.DEFINE_enum('mode', None, ['start', 'stop'], 'Start or stop running jobs.')
flags.DEFINE_boolean('delete_tpu', False, 'Whether delete tpus when stop jobs.')
def stop(controllers):
  """Ask every slice to stop its Ray jobs, then optionally delete the TPUs.

  Args:
    controllers: Per-slice TPU controllers whose jobs are cleaned up.
  """
  for ctrl in controllers:
    ctrl.clean_stale_jobs(ctrl.resource_name)
  logging.info('All Jobs are requested to stop.')

  # TPU deletion is opt-in via --delete_tpu.
  if not FLAGS.delete_tpu:
    return
  for ctrl in controllers:
    ctrl.delete_tpu()
def _build_pip_installs():
  """Return the pip runtime spec passed to every slice's TpuRayJob.

  The '-f' and '-e' entries are pip options rather than package names; they
  assume the 'packages' list is consumed as requirements-file lines
  (NOTE(review): confirm against TpuRayJob / Ray's pip runtime_env).
  """
  return {
      'packages': [
          'orbax==0.1.6',
          'absl-py',
          'argparse',
          'datetime',
          'google-cloud-storage',
          'ml-collections',
          'numpy',
          'optax',
          'portpicker',
          'protobuf==3.20.3',
          'pylint',
          'pytest',
          'sentencepiece',
          'tensorflow==2.12.0rc0',
          'tensorflow-datasets',
          'tensorboard-plugin-profile',
          'tensorflow-text==2.12.0rc0',
          'tensorboardx',
          'jaxlib==0.4.6',
          'jax[tpu]==0.4.6',
          '-f https://storage.googleapis.com/jax-releases/libtpu_releases.html',
          '-e git+https://github.com/google/flax.git@main#egg=flax',
      ],
      'pip_check': False,
      'pip_version': "==20.0.2;python_full_version=='3.8.10'",
  }


def _build_run_command(num_slices):
  """Return the shell entrypoint that runs MaxText training on one slice.

  Args:
    num_slices: Total slice count, forwarded as dcn_data_parallelism.
  """
  # `${USER}` must be braced: in `$USER_...` the shell reads the variable
  # name as `USER_` (normally unset) and silently drops the user name.
  return (
      'python3 MaxText/train.py MaxText/configs/base.yml '
      'run_name=${USER}_$(date +%Y-%m-%d-%H-%M-%S) '
      f'dcn_data_parallelism={num_slices}'
  )


def _wait_until_all_running(controllers, poll_interval_s=10):
  """Poll until every controller reports a job in RUNNING status.

  NOTE(review): there is no timeout; a slice whose job never reaches
  RUNNING (e.g. it fails) presumably keeps this loop spinning forever.

  Args:
    controllers: Per-slice TPU controllers to poll.
    poll_interval_s: Seconds to sleep between polls.
  """
  num_slices = len(controllers)
  while True:
    # Recount from scratch on every poll so a slice that is already running
    # is not counted again on each iteration.
    num_running = sum(
        1 for controller in controllers
        if controller.jobs_in_status(JobStatus.RUNNING)
    )
    logging.info('%d/%d slices jobs are running.', num_running, num_slices)
    if num_running >= num_slices:
      return
    time.sleep(poll_interval_s)


def start(controllers):
  """Prepare every slice, queue one MaxText job per slice, and wait for them.

  Args:
    controllers: One TPU controller per slice. The list index becomes the
      job's MEGASCALE_SLICE_ID, and slice 0 hosts the megascale coordinator.

  Raises:
    ValueError: if `controllers` is empty.
  """
  if not controllers:
    raise ValueError('start() requires at least one TPU controller.')

  for controller in controllers:
    controller.maybe_create_and_wait_for_ready()
    controller.clean_stale_jobs(controller.resource_name)

  # The first address reported by slice 0 acts as the megascale coordinator.
  mxla_coordinator_address = f'{controllers[0].get_ip_addresses()[0]}:8080'
  num_slices = len(controllers)
  run_command = _build_run_command(num_slices)
  working_dir = os.path.expanduser('~/MaxText')
  pip_installs = _build_pip_installs()

  for slice_index, controller in enumerate(controllers):
    job = TpuRayJob(
        entrypoint=run_command,
        working_dir=working_dir,
        pip_installs=pip_installs,
        env_vars={
            # NOTE(review): the value contains literal double quotes; confirm
            # TpuRayJob passes env vars through a shell that strips them.
            'LIBTPU_INIT_ARGS': '"--xla_tpu_enable_megascale_barrier=true"',
            'JAX_USE_PJRT_C_API_ON_TPU': '1',
            'MEGASCALE_COORDINATOR_ADDRESS': mxla_coordinator_address,
            'MEGASCALE_SLICE_ID': f'{slice_index}',
            'MEGASCALE_NUM_SLICES': f'{num_slices}',
            'TPU_VMODULE': 'communication_backend=10',
        },
    )
    controller.queue_tpu_workload(job)

  _wait_until_all_running(controllers)

  logging.info('All Jobs are in running status, here is the list:')
  for slice_index, controller in enumerate(controllers):
    for job in controller.queued_jobs:
      logging.info('job:%s @ slice: %d', job, slice_index)
def main(_):
  """Build one RayTpuController per slice, then start or stop the jobs.

  The Ray head address is this machine's IPv4 address plus
  _DEFAULT_RAY_PORT, so this script presumably runs on the Ray head node.
  """
  project = get_default_gcp_project()
  hostname = socket.gethostname()
  # getaddrinfo yields (family, type, proto, canonname, sockaddr) tuples.
  # Restrict the lookup to AF_INET for two reasons:
  #   - an AF_INET6 sockaddr is a 4-tuple, which the old 2-element unpacking
  #     rejected;
  #   - a bare IPv6 literal would make the `host:port` string ambiguous.
  sockaddr = socket.getaddrinfo(
      hostname, _DEFAULT_RAY_PORT, family=socket.AF_INET
  )[0][4]
  controller_ip = sockaddr[0]
  head_addr = f'{controller_ip}:{_DEFAULT_RAY_PORT}'

  controllers = [
      RayTpuController(
          tpu_name=f'{FLAGS.tpu_name}-{slice_index}',
          project=project,
          zone='us-central2-b',
          accelerator_type='V4',
          accelerator_topology=FLAGS.tpu_topology,
          version='tpu-vm-v4-base',
          network=FLAGS.network,
          subnetwork=FLAGS.subnet,
          head_addr=head_addr,
      )
      for slice_index in range(FLAGS.num_slices)
  ]

  if FLAGS.mode == 'start':
    start(controllers)
  elif FLAGS.mode == 'stop':
    stop(controllers)
if __name__ == '__main__':
  # Show INFO-level progress messages (slice counts, job listings).
  logging.set_verbosity(logging.INFO)
  # --mode has no default, so refuse to run without it.
  flags.mark_flag_as_required('mode')
  app.run(main)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment