Created
February 10, 2025 07:21
-
-
Save shubham-kaushal/fb2ec8075bca94577f75c8ce3b2b5148 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# SkyPilot-style task: run the JetStream serving benchmark against a
# llama-2 model. `envs` values are exported as shell variables inside `run`.
envs:
  model_name: llama-2
  # SentencePiece tokenizer shipped with the original llama2-7b checkpoint.
  tokenizer_path: /home/gcpuser/seed_workdir/ckpt/llama2-7b/original/tokenizer.model

# Assumes the JetStream repo was already cloned by a prior setup step
# (see the companion serving task) — TODO confirm working directory.
run: |
  cd JetStream
  python benchmarks/benchmark_serving.py \
    --tokenizer=$tokenizer_path --num-prompts=100 \
    --dataset openorca --save-request-outputs \
    --warmup-mode=sampled --model=$model_name
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# SkyPilot-style task: provision a TPU VM, install JetStream +
# jetstream-pytorch, download and convert the llama2-7b checkpoint,
# then launch the serving endpoint. `envs` values are exported as
# shell variables inside `setup` and `run`.
envs:
  HF_TOKEN: # fill in your huggingface token
  HF_REPO_ID: meta-llama/Llama-2-7b
  model_name: llama-2
  input_ckpt_dir: /home/gcpuser/seed_workdir/ckpt/llama2-7b/original
  output_ckpt_dir: /home/gcpuser/seed_workdir/ckpt/llama2-7b/converted
  tokenizer_path: /home/gcpuser/seed_workdir/ckpt/llama2-7b/original/tokenizer.model

setup: |
  pip3 install huggingface_hub
  python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"
  # Setup TPU
  pip3 install cloud-tpu-client
  sudo apt update
  sudo apt install -y libopenblas-base
  # Nightly CPU torch wheels pinned to the same date as the torch_xla wheel below.
  pip3 install --pre torch==2.6.0.dev20240916+cpu torchvision==0.20.0.dev20240916+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
  pip install "torch_xla[tpu]@https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20240916-cp310-cp310-linux_x86_64.whl" \
    -f https://storage.googleapis.com/libtpu-releases/index.html
  # Quote the extra so the shell cannot glob-expand the brackets.
  pip install "torch_xla[pallas]" \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
  # Setup runtime for serving
  # NOTE(review): the source had mirror-mangled URLs (github.com);
  # restored to the canonical github.com repositories.
  git clone https://github.com/google/JetStream.git
  cd JetStream
  git checkout main
  git pull origin main
  pip install -e .
  cd benchmarks
  pip install -r requirements.in
  cd ../..
  git clone https://github.com/google/jetstream-pytorch.git
  cd jetstream-pytorch/
  git checkout jetstream-v0.2.3
  source install_everything.sh
  pip3 install -U --pre jax jaxlib libtpu-nightly requests \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  # Prepare checkpoint, inside jetstream-pytorch repo
  mkdir -p ${input_ckpt_dir}
  python3 -c "import huggingface_hub; huggingface_hub.snapshot_download('${HF_REPO_ID}', local_dir='${input_ckpt_dir}')"
  mkdir -p ${output_ckpt_dir}
  python -m convert_checkpoints --model_name=$model_name \
    --input_checkpoint_dir=$input_ckpt_dir \
    --output_checkpoint_dir=$output_ckpt_dir

run: |
  cd jetstream-pytorch
  python run_server.py --model_name=$model_name \
    --size=7b --batch_size=24 --max_cache_length=2048 \
    --checkpoint_path=$output_ckpt_dir \
    --tokenizer_path=$tokenizer_path \
    --sharding_config="default_shardings/llama.yaml"
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.