Skip to content

Instantly share code, notes, and snippets.

1. Start the benchmark server in VS Code as shown in [this gist](https://gist.github.com/vanbasten23/dd4f3cbb314a7b9cf6c003103c23c019). Select the correct Python interpreter.
2. Then start the vllm server in debugger.
3. Wait until the server is up and running.
4. Add the breakpoint (remember to turn off dynamo and JAX jit).
5. Use the [script](https://gist.github.com/vanbasten23/726b28f072993fb7587482672b9c96a9) to send the benchmarking request. Make sure to use the correct conda/Python environment.
6. Then dump the input and output.
=========================
pip install flatbuffers
#!/bin/bash
# Usage:
# bash run_tpu_benchmark_client.sh --model Qwen/Qwen2.5-1.5B-Instruct --tp 1
LONGOPTS=model:,tp:,profile
# Parse arguments
PARSED=$(getopt --options=$OPTIONS --longoptions=$LONGOPTS --name "$0" -- "$@")
if [[ $? -ne 0 ]]; then
exit 2
{
"name": "newjax_benchmark_server",
"type": "debugpy",
"request": "launch",
"program": "/home/xiowei_google_com/miniconda3/envs/vllm_newjax/bin/vllm",
"console": "integratedTerminal",
"justMyCode": false,
"env": {
"MODEL_IMPL_TYPE": "vllm",
"TPU_BACKEND_TYPE": "jax",
import jax
from jax import export
import jax.numpy as jnp
import pickle
import time
import statistics
with open("/home/xiowei_google_com/old_exports.pkl", "rb") as f:
data = pickle.load(f)
-- Shorthand for vim.keymap.set, used for every mapping below.
local keymap = vim.keymap.set
-- Options shared by the mappings below: non-recursive and silent.
local opts = { noremap = true, silent = true }
-- Remap the leader key: bind <Space> to an empty right-hand side in normal
-- mode, then set it as both the global and the buffer-local leader.
keymap("n", "<Space>", "", opts)
vim.g.mapleader = " "
vim.g.maplocalleader = " "
-- Yank to the system clipboard (the "+ register) from normal and visual mode.
keymap({"n", "v"}, "<leader>y", '"+y', opts)
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
# Model id of the base checkpoint and local path of the PEFT adapter applied on top of it.
base = "Qwen/Qwen2.5-3B-Instruct"
adapter = "./lora-1plus1-666"
# Tokenizer is loaded from the base checkpoint id, not from the adapter directory.
tok = AutoTokenizer.from_pretrained(base)
# Load the base weights in bfloat16 and move them to CUDA when available, otherwise to the CPU.
m = AutoModelForCausalLM.from_pretrained(base, torch_dtype=torch.bfloat16).to("cuda" if torch.cuda.is_available() else "cpu")
# Rebind `m` to the PEFT wrapper that loads the adapter stored at `adapter`.
m = PeftModel.from_pretrained(m, adapter)
# minimal_lora_1plus1.py
# pip install -U transformers peft datasets accelerate
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from peft import LoraConfig, TaskType, get_peft_model
from datasets import Dataset
import torch, os
BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct"
OUT_DIR = "./lora-1plus1-666"
# This script demonstrates that under torchax, tensor.copy_(lora_tensor) will not change the sharding of `tensor`.
import jax
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchax
from torchax.interop import jax_view
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from torchax.interop import jax_view, torch_view
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
# Model id of the base checkpoint and local path of the PEFT adapter applied on top of it.
base = "Qwen/Qwen2.5-3B-Instruct"
adapter = "./lora-1plus1-666"
# Tokenizer is loaded from the base checkpoint id, not from the adapter directory.
tok = AutoTokenizer.from_pretrained(base)
# Load the base weights in bfloat16 and move them to CUDA when available, otherwise to the CPU.
m = AutoModelForCausalLM.from_pretrained(base, torch_dtype=torch.bfloat16).to("cuda" if torch.cuda.is_available() else "cpu")
# Rebind `m` to the PEFT wrapper that loads the adapter stored at `adapter`.
m = PeftModel.from_pretrained(m, adapter)
# minimal_lora_1plus1.py
# pip install -U transformers peft datasets accelerate
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from peft import LoraConfig, TaskType, get_peft_model
from datasets import Dataset
import torch, os
BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct"
OUT_DIR = "./lora-1plus1-666"