@shubham-kaushal
Created February 10, 2025 07:21
# Benchmark task: run JetStream's serving benchmark against an already-running
# JetStream-PyTorch server (started by the serving task below).
envs:
  model_name: llama-2
  tokenizer_path: /home/gcpuser/seed_workdir/ckpt/llama2-7b/original/tokenizer.model

run: |
  cd JetStream
  python benchmarks/benchmark_serving.py \
    --tokenizer=$tokenizer_path --num-prompts=100 \
    --dataset openorca --save-request-outputs \
    --warmup-mode=sampled --model=$model_name
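# For reference, with the envs values above substituted, the run section is
# equivalent to the following (illustration only, not an extra step):
#
#   cd JetStream
#   python benchmarks/benchmark_serving.py \
#     --tokenizer=/home/gcpuser/seed_workdir/ckpt/llama2-7b/original/tokenizer.model \
#     --num-prompts=100 --dataset openorca --save-request-outputs \
#     --warmup-mode=sampled --model=llama-2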
# Serving task: provision the TPU environment, download and convert the
# Llama 2 checkpoint, and serve it with JetStream-PyTorch.
envs:
  HF_TOKEN: # fill in your Hugging Face token
  HF_REPO_ID: meta-llama/Llama-2-7b
  model_name: llama-2
  input_ckpt_dir: /home/gcpuser/seed_workdir/ckpt/llama2-7b/original
  output_ckpt_dir: /home/gcpuser/seed_workdir/ckpt/llama2-7b/converted
  tokenizer_path: /home/gcpuser/seed_workdir/ckpt/llama2-7b/original/tokenizer.model

setup: |
  pip3 install huggingface_hub
  python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"
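  # Note: meta-llama/Llama-2-7b is a gated repo, so HF_TOKEN must belong to an
  # account that has been granted access. The login above caches the token for
  # the snapshot_download call further down.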

  # Setup TPU
  pip3 install cloud-tpu-client
  sudo apt update
  sudo apt install -y libopenblas-base
  pip3 install --pre torch==2.6.0.dev20240916+cpu torchvision==0.20.0.dev20240916+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
  pip install "torch_xla[tpu]@https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20240916-cp310-cp310-linux_x86_64.whl" \
    -f https://storage.googleapis.com/libtpu-releases/index.html
  pip install torch_xla[pallas] \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
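  # Note: torch, torchvision, and the torch_xla wheel above are all pinned to
  # the same 20240916 nightly date so the CPU torch build and the TPU torch_xla
  # build stay in sync.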

  # Setup runtime for serving
  git clone https://github.com/google/JetStream.git
  cd JetStream
  git checkout main
  git pull origin main
  pip install -e .
  cd benchmarks
  pip install -r requirements.in
  cd ../..
  git clone https://github.com/google/jetstream-pytorch.git
  cd jetstream-pytorch/
  git checkout jetstream-v0.2.3
  source install_everything.sh
  pip3 install -U --pre jax jaxlib libtpu-nightly requests \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
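  # Note: the jax/jaxlib/libtpu-nightly reinstall above runs after
  # install_everything.sh, presumably to override whatever JAX version that
  # script pins (assumption; the gist does not explain the ordering).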

  # Prepare checkpoint, inside jetstream-pytorch repo
  mkdir -p ${input_ckpt_dir}
  python3 -c "import huggingface_hub; huggingface_hub.snapshot_download('${HF_REPO_ID}', local_dir='${input_ckpt_dir}')"
  mkdir -p ${output_ckpt_dir}
  python -m convert_checkpoints --model_name=$model_name \
    --input_checkpoint_dir=$input_ckpt_dir \
    --output_checkpoint_dir=$output_ckpt_dir
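  # The converted checkpoint written to ${output_ckpt_dir} is what run_server.py
  # loads below via --checkpoint_path; ${input_ckpt_dir} keeps the raw Hugging
  # Face download, including the tokenizer.model referenced by $tokenizer_path.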

run: |
  cd jetstream-pytorch
  python run_server.py --model_name=$model_name \
    --size=7b --batch_size=24 --max_cache_length=2048 \
    --checkpoint_path=$output_ckpt_dir \
    --tokenizer_path=$tokenizer_path \
    --sharding_config="default_shardings/llama.yaml"
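
# The envs/setup/run layout above matches SkyPilot task YAML; that is an
# assumption, since the gist itself does not say which launcher it targets.
# If it is SkyPilot, the two tasks would typically be saved as separate files
# and launched roughly like this (hypothetical file and cluster names):
#
#   sky launch -c tpu-serve serve.yaml --env HF_TOKEN=<your-token>   # serving task
#   sky exec tpu-serve bench.yaml                                    # benchmark task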