Last active
March 23, 2026 14:29
-
-
Save rosenrodt/fda6660676956edbdb71231f17f5e5fd to your computer and use it in GitHub Desktop.
qwen35_prefix_cache_probe.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import argparse | |
| import time | |
| from dataclasses import dataclass | |
| from typing import cast | |
| from tensorrt_llm import LLM | |
| from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, SamplingParams | |
| from tensorrt_llm.tokenizer.tokenizer import tokenizer_factory | |
| from transformers import PreTrainedTokenizerBase | |
# Local checkpoint directory; both the tokenizer and the LLM are loaded from it.
MODEL_PATH: str = "/mnt/llm-models/Qwen3.5-35B-A3B"

# Command-line defaults (see main's argument parser).
DEFAULT_NUM_PREFIX_WORDS: int = 1_024
DEFAULT_MAX_OUTPUT_TOKENS: int = 128

# Extra tokens reserved on top of prompt + requested output when sizing the runtime.
TOKEN_MARGIN: int = 32

# The runtime max_seq_len / max_num_tokens are rounded up to a multiple of this.
RUNTIME_TOKEN_LIMIT_GRANULARITY: int = 1_024
@dataclass(frozen=True)
class PromptSizing:
    """Immutable record of how the probe prompt was sized.

    Holds both the descriptive sizes that are reported and the runtime
    limits that are handed to the LLM constructor.
    """

    # How the synthetic prefix was built (a free-form label used in the report).
    prefix_mode: str
    # Requested number of synthetic prefix words.
    num_prefix_words: int
    # Length of the prefix text, in characters.
    prefix_text_len: int
    # Length of the full rendered prompt, in characters.
    prompt_str_len: int
    # Number of tokens in the rendered prompt.
    prompt_token_count: int
    # Prompt tokens + requested output tokens + safety margin.
    estimated_total_tokens: int
    # Value passed to LLM(max_seq_len=...).
    runtime_max_seq_len: int
    # Value passed to LLM(max_num_tokens=...).
    runtime_max_num_tokens: int
def build_token_prefix(num_prefix_words: int) -> str:
    """Return ``num_prefix_words`` synthetic words joined by single spaces.

    Words are ``n000``, ``n001``, ... (index zero-padded to at least three
    digits). A zero or negative count yields the empty string.
    """
    word_count = max(num_prefix_words, 0)
    return " ".join(map("n{:03d}".format, range(word_count)))
def build_prompt(prefix_text: str) -> str:
    """Wrap ``prefix_text`` in the fixed probe instructions.

    The result is three newline-separated parts: an introduction, the
    prefix itself, and a short reply instruction.
    """
    intro = ("This prompt is used to validate deterministic prefix cache reuse "
             "for Qwen3.5. Reuse the long prefix below when possible.")
    instruction = "Reply with exactly the string cache-ok."
    return f"{intro}\n{prefix_text}\n{instruction}"
def build_chat_prompt(tokenizer, prompt: str, enable_thinking: bool) -> str:
    """Render ``prompt`` as a single user turn through the tokenizer's chat template.

    The template is rendered to text (``tokenize=False``) with the generation
    prompt appended, and ``enable_thinking`` is forwarded to the template
    as-is.

    Raises:
        TypeError: if ``apply_chat_template`` returns anything but a ``str``.
    """
    rendered = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=enable_thinking,
    )
    if isinstance(rendered, str):
        return rendered
    raise TypeError("Tokenizer returned token IDs instead of a string prompt.")
def round_up_to_multiple(value: int, multiple: int) -> int:
    """Return the smallest multiple of ``multiple`` that is >= ``value``.

    Raises:
        ValueError: if ``multiple`` is zero or negative.
    """
    if multiple <= 0:
        raise ValueError("multiple must be positive")
    # Adding (multiple - 1) before floor division turns it into a ceiling.
    bumped = value + multiple - 1
    steps = bumped // multiple
    return steps * multiple
def estimate_prompt_sizing(prefix_mode: str, num_prefix_words: int,
                           prefix_text_len: int, prompt: str,
                           max_output_tokens: int,
                           *,
                           tokenizer: "PreTrainedTokenizerBase | None" = None
                           ) -> PromptSizing:
    """Count the tokens in ``prompt`` and derive runtime limits for the LLM.

    Args:
        prefix_mode: Label copied into the result (reporting only).
        num_prefix_words: Requested prefix word count (reporting only).
        prefix_text_len: Prefix length in characters (reporting only).
        prompt: Fully rendered prompt text whose tokens are counted.
        max_output_tokens: Number of tokens requested per generation.
        tokenizer: Optional already-loaded tokenizer. When ``None`` (the
            default) one is loaded from ``MODEL_PATH``, exactly as before.
            Pass one in to avoid loading the tokenizer a second time.

    Returns:
        A ``PromptSizing``. Its ``runtime_max_seq_len`` and
        ``runtime_max_num_tokens`` are both set to
        ``prompt tokens + max_output_tokens + TOKEN_MARGIN``, rounded up to a
        multiple of ``RUNTIME_TOKEN_LIMIT_GRANULARITY``.
    """
    if tokenizer is None:
        tokenizer = cast(PreTrainedTokenizerBase, tokenizer_factory(MODEL_PATH))
    # The prompt is already chat-template rendered, so don't add special tokens again.
    prompt_token_count = len(tokenizer.encode(prompt, add_special_tokens=False))
    # Keep a margin above the visible prompt and requested output.
    estimated_total_tokens = prompt_token_count + max_output_tokens + TOKEN_MARGIN
    # WARNING. Round up to 1024 token granularity to avoid random lock ups!
    runtime_token_limit = round_up_to_multiple(estimated_total_tokens,
                                               RUNTIME_TOKEN_LIMIT_GRANULARITY)
    return PromptSizing(prefix_mode=prefix_mode,
                        num_prefix_words=num_prefix_words,
                        prefix_text_len=prefix_text_len,
                        prompt_str_len=len(prompt),
                        prompt_token_count=prompt_token_count,
                        estimated_total_tokens=estimated_total_tokens,
                        runtime_max_seq_len=runtime_token_limit,
                        runtime_max_num_tokens=runtime_token_limit)
def run_case(prompt: str,
             enable_block_reuse: bool,
             prompt_sizing: PromptSizing,
             enable_cuda_graph: bool,
             enable_autotuner: bool,
             load_format: str,
             max_output_tokens: int) -> tuple:
    """Build one LLM and send the same prompt through it twice with greedy sampling.

    KV block reuse is switched on or off through ``enable_block_reuse``. The
    runtime limits come from ``prompt_sizing``.

    Returns:
        ``(first_output, second_output, init_seconds, first_seconds,
        second_seconds)``. The outputs are ``outputs[0]`` of each generate
        result. The timings are wall-clock seconds from ``time.perf_counter``.
    """
    kv_cache_config = KvCacheConfig.model_validate({
        "free_gpu_memory_fraction": 0.8,
        "enable_block_reuse": enable_block_reuse,
    })
    llm_kwargs = dict(
        model=MODEL_PATH,
        load_format=load_format,
        kv_cache_config=kv_cache_config,
        cuda_graph_config=CudaGraphConfig() if enable_cuda_graph else None,
        tensor_parallel_size=1,
        moe_expert_parallel_size=1,
        max_seq_len=prompt_sizing.runtime_max_seq_len,
        max_num_tokens=prompt_sizing.runtime_max_num_tokens,
        max_batch_size=1,
        disable_overlap_scheduler=True,
        enable_autotuner=enable_autotuner,
    )
    # temperature=0 keeps runs comparable; perf metrics expose KV cache stats.
    greedy = SamplingParams(max_tokens=max_output_tokens,
                            temperature=0,
                            return_perf_metrics=True)

    def timed_generate(engine):
        """Generate once for ``prompt``; return (first output, elapsed seconds)."""
        started = time.perf_counter()
        output = engine.generate([prompt], greedy)[0].outputs[0]
        return output, time.perf_counter() - started

    init_started = time.perf_counter()
    llm = LLM(**llm_kwargs)
    with llm:
        init_seconds = time.perf_counter() - init_started
        first, first_seconds = timed_generate(llm)
        second, second_seconds = timed_generate(llm)
    return first, second, init_seconds, first_seconds, second_seconds
_REPORT_WIDTH = 82
_REPORT_SEP = "=" * _REPORT_WIDTH
_REPORT_THIN = "-" * _REPORT_WIDTH
_LABEL_WIDTH = 40
_VALUE_WIDTH = 20


def _parse_args() -> argparse.Namespace:
    """Parse and validate the probe's command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-prefix-words",
        type=int,
        default=DEFAULT_NUM_PREFIX_WORDS,
        help="Number of synthetic words to generate.")
    parser.add_argument("--disable-cuda-graph",
                        action="store_true",
                        help="Disable CUDA graph during LLM initialization.")
    parser.add_argument("--load-format",
                        choices=["auto", "dummy"],
                        default="auto",
                        help="LLM load_format to use for model construction.")
    parser.add_argument("--max-output-tokens",
                        type=int,
                        default=DEFAULT_MAX_OUTPUT_TOKENS,
                        help="Maximum number of tokens to generate per request.")
    parser.add_argument("--enable-thinking",
                        action=argparse.BooleanOptionalAction,
                        default=False,
                        help="Apply the Qwen chat template with enable_thinking on or off.")
    parser.add_argument("--enable-autotuner",
                        action=argparse.BooleanOptionalAction,
                        default=False,
                        help="Pass enable_autotuner to the LLM constructor (off by default).")
    args = parser.parse_args()
    # Fail fast with a CLI error instead of deep inside LLM construction.
    if args.max_output_tokens < 1:
        parser.error("--max-output-tokens must be at least 1")
    return args


def _section(title: str) -> None:
    """Print a boxed section title."""
    print(_REPORT_SEP, flush=True)
    print(f"= {title}", flush=True)
    print(_REPORT_SEP, flush=True)


def _row(label: str, value: object, width: int = _LABEL_WIDTH) -> None:
    """Print a ``label: value`` line with the label padded to ``width``."""
    print(f"{label + ':':<{width}} {value}", flush=True)


def _preview_text(text: str, edge_chars: int = 96) -> str:
    """Return repr(text), eliding the middle when it is long."""
    if len(text) <= edge_chars * 2 + 5:
        return repr(text)
    return repr(f"{text[:edge_chars]} ... {text[-edge_chars:]}")


def _speed_ratio(numerator: float, denominator: float) -> str:
    """Format numerator/denominator as ``1.23x``; ``n/a`` if denominator <= 0."""
    if denominator <= 0:
        return "n/a"
    return f"{numerator / denominator:.2f}x"


def _compare_row(label: str, reuse_on: object, reuse_off: object,
                 spec: str) -> None:
    """Print one Reuse ON / Reuse OFF row; non-str cells are formatted with ``spec``."""

    def cell(value: object) -> str:
        text = value if isinstance(value, str) else format(value, spec)
        return f"{text:>{_VALUE_WIDTH}s}"

    print(f"{label:<{_LABEL_WIDTH}s} {cell(reuse_on)} {cell(reuse_off)}",
          flush=True)


def _compare_header() -> None:
    """Print the column header and rule for a Reuse ON / Reuse OFF table."""
    _compare_row("", "Reuse ON", "Reuse OFF", "s")
    print(_REPORT_THIN, flush=True)


def main() -> None:
    """Run the prefix-cache probe with KV block reuse on and off, then print a report.

    Each case builds a fresh LLM and generates the same prompt twice. The
    report compares KV cache block counts, timings and output parity. Parity
    is checked both between the reuse-ON 1st and 2nd requests and between the
    reuse-ON and reuse-OFF 2nd requests.
    """
    args = _parse_args()
    enable_cuda_graph = not args.disable_cuda_graph
    enable_autotuner = args.enable_autotuner
    load_format = args.load_format
    max_output_tokens = args.max_output_tokens
    enable_thinking = args.enable_thinking
    num_prefix_words = args.num_prefix_words

    prefix_text = build_token_prefix(num_prefix_words)
    tokenizer = tokenizer_factory(MODEL_PATH)
    raw_prompt = build_prompt(prefix_text)
    prompt = build_chat_prompt(tokenizer, raw_prompt, enable_thinking)
    prompt_sizing = estimate_prompt_sizing("token_words", num_prefix_words,
                                           len(prefix_text), prompt,
                                           max_output_tokens)

    case_kwargs = dict(prompt_sizing=prompt_sizing,
                       enable_cuda_graph=enable_cuda_graph,
                       enable_autotuner=enable_autotuner,
                       load_format=load_format,
                       max_output_tokens=max_output_tokens)
    (reuse_first, reuse_second, reuse_init_s, reuse_first_s,
     reuse_second_s) = run_case(prompt, enable_block_reuse=True, **case_kwargs)
    (_, no_reuse_second, no_reuse_init_s, no_reuse_first_s,
     no_reuse_second_s) = run_case(prompt, enable_block_reuse=False, **case_kwargs)

    reuse_first_kv = reuse_first.request_perf_metrics.kv_cache_metrics
    reuse_second_kv = reuse_second.request_perf_metrics.kv_cache_metrics
    no_reuse_second_kv = no_reuse_second.request_perf_metrics.kv_cache_metrics

    print(flush=True)
    _section("KV CACHE METRICS")
    print(flush=True)
    _compare_header()
    _compare_row("Reused blocks (1st request):",
                 reuse_first_kv.num_reused_blocks, "n/a", "d")
    _compare_row("Reused blocks (2nd request):",
                 reuse_second_kv.num_reused_blocks,
                 no_reuse_second_kv.num_reused_blocks, "d")
    _compare_row("New blocks alloc (2nd request):",
                 reuse_second_kv.num_new_allocated_blocks,
                 no_reuse_second_kv.num_new_allocated_blocks, "d")

    print(flush=True)
    _section("TIMING (seconds)")
    print(flush=True)
    _compare_header()
    _compare_row("LLM init:", reuse_init_s, no_reuse_init_s, ".3f")
    _compare_row("1st generate:", reuse_first_s, no_reuse_first_s, ".3f")
    _compare_row("2nd generate:", reuse_second_s, no_reuse_second_s, ".3f")

    print(flush=True)
    # "-- Speed ratios " is 16 chars; the rest of the thin rule fills the line.
    print(f"-- Speed ratios {_REPORT_THIN[16:]}", flush=True)
    _row("1st / 2nd (reuse ON)", _speed_ratio(reuse_first_s, reuse_second_s))
    _row("2nd reuse OFF / 2nd reuse ON",
         _speed_ratio(no_reuse_second_s, reuse_second_s))

    print(flush=True)
    _section("PROMPT SIZING")
    print(flush=True)
    _row("Prefix mode", prompt_sizing.prefix_mode)
    _row("Prefix size argument", prompt_sizing.num_prefix_words)
    _row("Prefix text length (chars)", prompt_sizing.prefix_text_len)
    _row("Full prompt length (chars)", prompt_sizing.prompt_str_len)
    _row("Prompt token count", prompt_sizing.prompt_token_count)
    _row("Estimated total tokens", prompt_sizing.estimated_total_tokens)

    print(flush=True)
    _section("CONFIGURATION")
    print(flush=True)
    _row("max_seq_len", prompt_sizing.runtime_max_seq_len)
    _row("max_num_tokens", prompt_sizing.runtime_max_num_tokens)
    _row("max_output_tokens", max_output_tokens)
    _row("CUDA graph", enable_cuda_graph)
    _row("Autotuner", enable_autotuner)
    _row("Enable thinking", enable_thinking)
    _row("Load format", load_format)

    print(flush=True)
    _section("OUTPUT PARITY")
    print(flush=True)
    _row("Prompt preview", _preview_text(prompt))
    _row("Reuse ON 1st output tokens", len(reuse_first.token_ids))
    _row("Reuse ON 2nd output tokens", len(reuse_second.token_ids))
    _row("Reuse OFF 2nd output tokens", len(no_reuse_second.token_ids))
    # Reuse ON vs OFF on the 2nd request (as before).
    _row("Text match", reuse_second.text == no_reuse_second.text)
    _row("Token ID match",
         reuse_second.token_ids == no_reuse_second.token_ids)
    # Freshly computed prefix (1st) vs reused prefix (2nd) within the reuse-ON run.
    _row("Text match (reuse ON 1st vs 2nd)",
         reuse_first.text == reuse_second.text)
    _row("Token ID match (reuse ON 1st vs 2nd)",
         reuse_first.token_ids == reuse_second.token_ids)
    _row("Reuse ON output", repr(reuse_second.text))
    _row("Reuse OFF output", repr(no_reuse_second.text))
    print(_REPORT_SEP, flush=True)
# Script entry point: run the probe only when executed directly, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment