# qwen35_prefix_cache_probe.py
#
# Published as GitHub gist rosenrodt/fda6660676956edbdb71231f17f5e5fd
# by @rosenrodt (last active March 23, 2026 14:29).
import argparse
import time
from dataclasses import dataclass
from typing import cast
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, SamplingParams
from tensorrt_llm.tokenizer.tokenizer import tokenizer_factory
from transformers import PreTrainedTokenizerBase
# Checkpoint path handed to both `LLM(model=...)` and `tokenizer_factory`.
MODEL_PATH: str = "/mnt/llm-models/Qwen3.5-35B-A3B"

# Command-line defaults.
DEFAULT_NUM_PREFIX_WORDS: int = 1_024
DEFAULT_MAX_OUTPUT_TOKENS: int = 128

# Slack (in tokens) added on top of prompt + requested output when sizing
# the runtime limits.
TOKEN_MARGIN: int = 32
# max_seq_len / max_num_tokens are rounded up to a multiple of this value.
RUNTIME_TOKEN_LIMIT_GRANULARITY: int = 1_024
@dataclass(frozen=True)
class PromptSizing:
    """Immutable record of how the probe prompt was sized.

    Also holds the runtime token limits derived from that sizing; these are
    passed to the LLM constructor and echoed in the final report.
    """

    # Label describing how the prefix was built (main passes "token_words").
    prefix_mode: str
    # Prefix size argument from the command line (number of synthetic words).
    num_prefix_words: int
    # Length of the generated prefix text, in characters.
    prefix_text_len: int
    # Length of the full chat-formatted prompt, in characters.
    prompt_str_len: int
    # Prompt length in tokenizer tokens (encoded with add_special_tokens=False).
    prompt_token_count: int
    # prompt_token_count + max output tokens + TOKEN_MARGIN.
    estimated_total_tokens: int
    # Value passed to LLM(max_seq_len=...).
    runtime_max_seq_len: int
    # Value passed to LLM(max_num_tokens=...).
    runtime_max_num_tokens: int
def build_token_prefix(num_prefix_words: int) -> str:
    """Return ``num_prefix_words`` synthetic words joined by single spaces.

    Words are ``n000``, ``n001``, ... (zero-padded to at least three
    digits). A non-positive count yields the empty string.
    """
    if num_prefix_words <= 0:
        return ""
    words = map("n{:03d}".format, range(num_prefix_words))
    return " ".join(words)
def build_prompt(prefix_text: str) -> str:
    """Wrap ``prefix_text`` between a fixed instruction header and footer.

    The result is three newline-separated parts: the header sentence, the
    prefix text itself, and the reply instruction.
    """
    header = ("This prompt is used to validate deterministic prefix cache reuse "
              "for Qwen3.5. Reuse the long prefix below when possible.")
    instruction = "Reply with exactly the string cache-ok."
    return f"{header}\n{prefix_text}\n{instruction}"
def build_chat_prompt(tokenizer, prompt: str, enable_thinking: bool) -> str:
    """Render ``prompt`` as a single user turn using the tokenizer's chat template.

    Args:
        tokenizer: Object exposing ``apply_chat_template``.
        prompt: User message content.
        enable_thinking: Forwarded to the chat template as ``enable_thinking``.

    Returns:
        The rendered prompt string, with the generation prompt appended.

    Raises:
        TypeError: If the template returns something other than a string
            (e.g. token IDs).
    """
    conversation = [dict(role="user", content=prompt)]
    text = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=enable_thinking,
    )
    if isinstance(text, str):
        return text
    raise TypeError("Tokenizer returned token IDs instead of a string prompt.")
def round_up_to_multiple(value: int, multiple: int) -> int:
    """Return the smallest multiple of ``multiple`` that is >= ``value``.

    Raises:
        ValueError: If ``multiple`` is zero or negative.
    """
    if multiple <= 0:
        raise ValueError("multiple must be positive")
    # Ceiling division via negated floor division (exact for ints).
    chunks = -(-value // multiple)
    return chunks * multiple
def estimate_prompt_sizing(prefix_mode: str, num_prefix_words: int,
                           prefix_text_len: int, prompt: str,
                           max_output_tokens: int,
                           tokenizer=None) -> PromptSizing:
    """Count the prompt's tokens and derive the runtime token limits from them.

    Args:
        prefix_mode: Label recorded in the result; not interpreted here.
        num_prefix_words: Prefix size argument, recorded in the result.
        prefix_text_len: Character length of the prefix, recorded in the result.
        prompt: The exact prompt string that will be sent to the model.
        max_output_tokens: Maximum number of tokens each request may generate.
        tokenizer: Optional already-loaded tokenizer for MODEL_PATH. When
            omitted (the default), one is loaded with
            ``tokenizer_factory(MODEL_PATH)``. Pass the caller's tokenizer to
            avoid loading it a second time.

    Returns:
        A PromptSizing whose ``runtime_max_seq_len`` and
        ``runtime_max_num_tokens`` both equal
        ``prompt tokens + max_output_tokens + TOKEN_MARGIN``, rounded up to a
        multiple of RUNTIME_TOKEN_LIMIT_GRANULARITY.
    """
    if tokenizer is None:
        tokenizer = tokenizer_factory(MODEL_PATH)
    tokenizer = cast(PreTrainedTokenizerBase, tokenizer)
    # main passes a prompt already rendered through the chat template, so
    # presumably no extra special tokens should be added when counting.
    prompt_token_count = len(tokenizer.encode(prompt, add_special_tokens=False))
    # Keep a margin above the visible prompt and requested output.
    estimated_total_tokens = prompt_token_count + max_output_tokens + TOKEN_MARGIN
    # WARNING: round up to a multiple of RUNTIME_TOKEN_LIMIT_GRANULARITY; the
    # original author notes that unaligned limits caused random lock-ups.
    runtime_token_limit = round_up_to_multiple(estimated_total_tokens,
                                               RUNTIME_TOKEN_LIMIT_GRANULARITY)
    return PromptSizing(prefix_mode=prefix_mode,
                        num_prefix_words=num_prefix_words,
                        prefix_text_len=prefix_text_len,
                        prompt_str_len=len(prompt),
                        prompt_token_count=prompt_token_count,
                        estimated_total_tokens=estimated_total_tokens,
                        runtime_max_seq_len=runtime_token_limit,
                        runtime_max_num_tokens=runtime_token_limit)
def run_case(prompt: str,
             enable_block_reuse: bool,
             prompt_sizing: PromptSizing,
             enable_cuda_graph: bool,
             enable_autotuner: bool,
             load_format: str,
             max_output_tokens: int) -> tuple:
    """Build a fresh LLM, send ``prompt`` to it twice, and time each phase.

    The LLM is constructed from MODEL_PATH on a single GPU (TP=1, EP=1,
    max_batch_size=1, overlap scheduler disabled). The KV cache uses 80% of
    free GPU memory, with block reuse switched by ``enable_block_reuse``.
    Sampling uses temperature 0 and requests per-request perf metrics.

    Args:
        prompt: Prompt string sent unchanged in both generate calls.
        enable_block_reuse: Value for ``KvCacheConfig.enable_block_reuse``.
        prompt_sizing: Supplies ``max_seq_len`` and ``max_num_tokens``.
        enable_cuda_graph: Pass a default CudaGraphConfig when true, else None.
        enable_autotuner: Forwarded to ``LLM(enable_autotuner=...)``.
        load_format: Forwarded to ``LLM(load_format=...)``.
        max_output_tokens: ``SamplingParams.max_tokens`` for both requests.

    Returns:
        ``(first, second, llm_init_elapsed, first_elapsed, second_elapsed)``.
        The first two items are the first completion output of each
        ``generate`` call. The last three are wall-clock seconds
        (``time.perf_counter``) for LLM construction and each call.
    """
    kv_cache_config = KvCacheConfig.model_validate(
        dict(free_gpu_memory_fraction=0.8,
             enable_block_reuse=enable_block_reuse))
    cuda_graph_config = None
    if enable_cuda_graph:
        cuda_graph_config = CudaGraphConfig()
    sampling_params = SamplingParams(max_tokens=max_output_tokens,
                                     temperature=0,
                                     return_perf_metrics=True)

    init_started = time.perf_counter()
    llm = LLM(model=MODEL_PATH,
              load_format=load_format,
              kv_cache_config=kv_cache_config,
              cuda_graph_config=cuda_graph_config,
              tensor_parallel_size=1,
              moe_expert_parallel_size=1,
              max_seq_len=prompt_sizing.runtime_max_seq_len,
              max_num_tokens=prompt_sizing.runtime_max_num_tokens,
              max_batch_size=1,
              disable_overlap_scheduler=True,
              enable_autotuner=enable_autotuner)
    with llm:
        llm_init_elapsed = time.perf_counter() - init_started

        def timed_generate():
            # One request; returns (first completion output, elapsed seconds).
            started = time.perf_counter()
            output = llm.generate([prompt], sampling_params)[0].outputs[0]
            return output, time.perf_counter() - started

        first, first_elapsed = timed_generate()
        second, second_elapsed = timed_generate()
    return first, second, llm_init_elapsed, first_elapsed, second_elapsed
# Report layout: total rule width, label column width, value column width.
_REPORT_WIDTH = 82
_SEP = "=" * _REPORT_WIDTH
_THIN = "-" * _REPORT_WIDTH
_LABEL_WIDTH = 40
_COLUMN_WIDTH = 20


def _parse_args() -> argparse.Namespace:
    """Parse and validate the probe's command-line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-prefix-words",
        type=int,
        default=DEFAULT_NUM_PREFIX_WORDS,
        help="Number of synthetic words to generate.")
    parser.add_argument("--disable-cuda-graph",
                        action="store_true",
                        help="Disable CUDA graph during LLM initialization.")
    parser.add_argument("--load-format",
                        choices=["auto", "dummy"],
                        default="auto",
                        help="LLM load_format to use for model construction.")
    parser.add_argument("--max-output-tokens",
                        type=int,
                        default=DEFAULT_MAX_OUTPUT_TOKENS,
                        help="Maximum number of tokens to generate per request.")
    parser.add_argument("--enable-thinking",
                        action=argparse.BooleanOptionalAction,
                        default=False,
                        help="Apply the Qwen chat template with enable_thinking on or off.")
    parser.add_argument("--enable-autotuner",
                        action=argparse.BooleanOptionalAction,
                        default=False,
                        help="Pass enable_autotuner to the LLM constructor.")
    args = parser.parse_args()
    # Reject a request that cannot produce output to compare, before any
    # model is loaded.
    if args.max_output_tokens < 1:
        parser.error("--max-output-tokens must be at least 1")
    return args


def _emit(line: str = "") -> None:
    """Print one report line and flush stdout immediately."""
    print(line, flush=True)


def _section(title: str) -> None:
    """Print a section banner: '=' rule, '= title', '=' rule."""
    _emit(_SEP)
    _emit(f"= {title}")
    _emit(_SEP)


def _row(label: str, value: object, width: int = _LABEL_WIDTH) -> None:
    """Print 'label:' left-aligned to ``width`` columns, followed by ``value``."""
    _emit(f"{label + ':':<{width}} {value}")


def _table_row(label: str, reuse_on: str, reuse_off: str) -> None:
    """Print one row of the reuse ON / reuse OFF comparison tables.

    Cells are pre-formatted strings, right-aligned to _COLUMN_WIDTH.
    """
    _emit(f"{label:<{_LABEL_WIDTH}} "
          f"{reuse_on:>{_COLUMN_WIDTH}} {reuse_off:>{_COLUMN_WIDTH}}")


def _table_header() -> None:
    """Print the column header of a comparison table, then a thin rule."""
    _table_row("", "Reuse ON", "Reuse OFF")
    _emit(_THIN)


def _preview_text(text: str, edge_chars: int = 96) -> str:
    """Return ``repr(text)``, eliding the middle when ``text`` is long.

    Text up to ``2 * edge_chars + 5`` characters is shown in full. Longer
    text keeps ``edge_chars`` characters from each end joined by ' ... '.
    """
    if len(text) <= edge_chars * 2 + 5:
        return repr(text)
    return repr(f"{text[:edge_chars]} ... {text[-edge_chars:]}")


def _print_report(*, prompt: str, prompt_sizing: "PromptSizing",
                  reuse_result: tuple, no_reuse_result: tuple,
                  max_output_tokens: int, enable_cuda_graph: bool,
                  enable_autotuner: bool, enable_thinking: bool,
                  load_format: str) -> None:
    """Print the comparison report from two ``run_case`` results.

    ``reuse_result`` and ``no_reuse_result`` are the tuples returned by
    ``run_case`` with block reuse enabled and disabled, respectively.
    """
    (reuse_first, reuse_second, reuse_init_s, reuse_first_s,
     reuse_second_s) = reuse_result
    (_, no_reuse_second, no_reuse_init_s, no_reuse_first_s,
     no_reuse_second_s) = no_reuse_result

    reuse_first_kv = reuse_first.request_perf_metrics.kv_cache_metrics
    reuse_second_kv = reuse_second.request_perf_metrics.kv_cache_metrics
    no_reuse_second_kv = no_reuse_second.request_perf_metrics.kv_cache_metrics

    _emit()
    _section("KV CACHE METRICS")
    _emit()
    _table_header()
    _table_row("Reused blocks (1st request):",
               f"{reuse_first_kv.num_reused_blocks:d}", "n/a")
    _table_row("Reused blocks (2nd request):",
               f"{reuse_second_kv.num_reused_blocks:d}",
               f"{no_reuse_second_kv.num_reused_blocks:d}")
    _table_row("New blocks alloc (2nd request):",
               f"{reuse_second_kv.num_new_allocated_blocks:d}",
               f"{no_reuse_second_kv.num_new_allocated_blocks:d}")

    _emit()
    _section("TIMING (seconds)")
    _emit()
    _table_header()
    _table_row("LLM init:", f"{reuse_init_s:.3f}", f"{no_reuse_init_s:.3f}")
    _table_row("1st generate:", f"{reuse_first_s:.3f}",
               f"{no_reuse_first_s:.3f}")
    _table_row("2nd generate:", f"{reuse_second_s:.3f}",
               f"{no_reuse_second_s:.3f}")

    _emit()
    # "-- Speed ratios " is 16 characters; the rule pads it to full width.
    _emit(f"-- Speed ratios {_THIN[16:]}")
    _row("1st / 2nd (reuse ON)", f"{reuse_first_s / reuse_second_s:.2f}x")
    _row("2nd reuse OFF / 2nd reuse ON",
         f"{no_reuse_second_s / reuse_second_s:.2f}x")

    _emit()
    _section("PROMPT SIZING")
    _emit()
    for label, value in (
            ("Prefix mode", prompt_sizing.prefix_mode),
            ("Prefix size argument", prompt_sizing.num_prefix_words),
            ("Prefix text length (chars)", prompt_sizing.prefix_text_len),
            ("Full prompt length (chars)", prompt_sizing.prompt_str_len),
            ("Prompt token count", prompt_sizing.prompt_token_count),
            ("Estimated total tokens", prompt_sizing.estimated_total_tokens),
    ):
        _row(label, value)

    _emit()
    _section("CONFIGURATION")
    _emit()
    for label, value in (
            ("max_seq_len", prompt_sizing.runtime_max_seq_len),
            ("max_num_tokens", prompt_sizing.runtime_max_num_tokens),
            ("max_output_tokens", max_output_tokens),
            ("CUDA graph", enable_cuda_graph),
            ("Autotuner", enable_autotuner),
            ("Enable thinking", enable_thinking),
            ("Load format", load_format),
    ):
        _row(label, value)

    _emit()
    _section("OUTPUT PARITY")
    _emit()
    _row("Prompt preview", _preview_text(prompt))
    _row("Reuse ON 1st output tokens", len(reuse_first.token_ids))
    _row("Reuse ON 2nd output tokens", len(reuse_second.token_ids))
    _row("Reuse OFF 2nd output tokens", len(no_reuse_second.token_ids))
    _row("Text match", reuse_second.text == no_reuse_second.text)
    _row("Token ID match",
         reuse_second.token_ids == no_reuse_second.token_ids)
    _row("Reuse ON output", repr(reuse_second.text))
    _row("Reuse OFF output", repr(no_reuse_second.text))
    _emit(_SEP)


def main() -> None:
    """Probe KV-cache prefix reuse for MODEL_PATH and print a comparison report.

    Builds one deterministic chat prompt: a synthetic ``n000 n001 ...``
    prefix wrapped in a fixed instruction. For block reuse ON and then OFF,
    ``run_case`` constructs a fresh LLM and sends the prompt twice. The
    report then covers KV-cache block metrics, timings, prompt sizing, the
    runtime configuration, and whether the second-request outputs of the
    two cases agree.
    """
    args = _parse_args()
    enable_cuda_graph = not args.disable_cuda_graph
    enable_autotuner = args.enable_autotuner
    load_format = args.load_format
    max_output_tokens = args.max_output_tokens
    enable_thinking = args.enable_thinking
    num_prefix_words = args.num_prefix_words

    prefix_text = build_token_prefix(num_prefix_words)
    tokenizer = tokenizer_factory(MODEL_PATH)
    raw_prompt = build_prompt(prefix_text)
    prompt = build_chat_prompt(tokenizer, raw_prompt, enable_thinking)
    prompt_sizing = estimate_prompt_sizing("token_words", num_prefix_words,
                                           len(prefix_text), prompt,
                                           max_output_tokens)

    # Everything except enable_block_reuse is identical between the two cases.
    case_kwargs = dict(prompt_sizing=prompt_sizing,
                       enable_cuda_graph=enable_cuda_graph,
                       enable_autotuner=enable_autotuner,
                       load_format=load_format,
                       max_output_tokens=max_output_tokens)
    reuse_result = run_case(prompt, enable_block_reuse=True, **case_kwargs)
    no_reuse_result = run_case(prompt, enable_block_reuse=False, **case_kwargs)

    _print_report(prompt=prompt,
                  prompt_sizing=prompt_sizing,
                  reuse_result=reuse_result,
                  no_reuse_result=no_reuse_result,
                  max_output_tokens=max_output_tokens,
                  enable_cuda_graph=enable_cuda_graph,
                  enable_autotuner=enable_autotuner,
                  enable_thinking=enable_thinking,
                  load_format=load_format)
# Run the probe only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
# Comments on this gist can be left on GitHub (sign-in required).