
@youkaichao
Created May 21, 2024 23:03
vLLM + torch.compile
```python
from contextlib import nullcontext

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Sampling settings shared by all prompts.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# nullcontext() is a no-op placeholder. For step 3, delete this `with` line and
# uncomment the two lines below so depyf dumps the torch.compile debug output.
with nullcontext():
# import depyf
# with depyf.prepare_debug("./debug_output"):
    llm = LLM(model="facebook/opt-125m")
    outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
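Steps 1 and 2 ask for timings, but the script above does not record any. The following is a minimal sketch, assuming wall-clock time from `time.perf_counter` is an adequate measure. The helper name `timed` is hypothetical and is not part of vLLM. It times model construction and generation separately. Because `torch.compile` compiles lazily on the first call, that split helps show which phase absorbs the compilation overhead.

```python
import time
from contextlib import contextmanager
from typing import Dict, Iterator


@contextmanager
def timed(label: str, results: Dict[str, float]) -> Iterator[None]:
    """Store the wall-clock seconds spent inside the block under `label`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Recorded even if the block raises, so failed runs still get a time.
        results[label] = time.perf_counter() - start


# Hypothetical usage with the script above:
# timings = {}
# with timed("construct", timings):
#     llm = LLM(model="facebook/opt-125m")
# with timed("generate", timings):
#     outputs = llm.generate(prompts, sampling_params)
# print(timings)
```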
@youkaichao (Author)

Steps:

For each of several models that vLLM supports:

  1. Run the script as is and measure the time.
  2. Add one line, `self.model = torch.compile(self.model)`, to the function at https://github.com/vllm-project/vllm/blob/99eff67ba9155b5fec9a9abd939e3a29a1b42dce/vllm/worker/model_runner.py#L131 just before it returns, then measure the time again.
  3. Remove `with nullcontext():`, uncomment the two lines that follow it, and run again. Check for decompilation errors; if there are none, inspect the debug output in `./debug_output`.
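Step 2 edits vLLM's source file. An alternative is to patch the method at runtime. The sketch below assumes that the function at L131 is a method that assigns `self.model`. The usage comment further assumes that method is `ModelRunner.load_model`, which should be checked against the linked commit. `compile_model_after` is a hypothetical helper, not a vLLM API.

```python
import functools
from typing import Any, Callable


def compile_model_after(cls: type, method_name: str,
                        compile_fn: Callable[[Any], Any]) -> None:
    """Patch cls.method_name so that, after it returns, self.model is
    replaced by compile_fn(self.model)."""
    original = getattr(cls, method_name)

    @functools.wraps(original)
    def wrapper(self, *args, **kwargs):
        result = original(self, *args, **kwargs)
        # Same effect as adding `self.model = torch.compile(self.model)`
        # as the last line of the original method.
        self.model = compile_fn(self.model)
        return result

    setattr(cls, method_name, wrapper)


# Hypothetical usage ("load_model" is an assumed method name):
# import torch
# from vllm.worker.model_runner import ModelRunner
# compile_model_after(ModelRunner, "load_model", torch.compile)
```

One caveat is relevant to the tensor-parallel goal. A runtime patch only affects the process in which it runs. Worker processes that import vLLM independently may therefore not see it, whereas an edit to the source file reaches every process.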

Goal:

No graph breaks across the whole model, for the important models used in benchmarking.

Further goal:

Test how it works with tensor-parallel inference.
