[arxiv] LazyLLM: Dynamic Token Pruning for Efficient Long Context LLM Inference

Source

Qichen Fu, Minsik Cho, Thomas Merth, Sachin Mehta, Mohammad Rastegari, Mahyar Najibi

The inference of transformer-based large language models consists of two sequential stages: 1) a prefilling stage to compute the KV cache of prompts and generate the first token, and 2) a decoding stage to generate subsequent tokens. For long prompts, the KV cache must be computed for all tokens during the prefilling stage, which can significantly increase the time needed to generate the first token. Consequently, the prefilling stage may become a bottleneck in the generation process. An open question remains whether all prompt tokens are essential for generating the first token. To answer this, we introduce a novel method, LazyLLM, that selectively computes the KV for tokens important for the next token prediction in both the prefilling and decoding stages. Contrary to static pruning approaches that prune the prompt at once, LazyLLM allows language models to dynamically select different subsets of tokens from the context in different generation steps, even though they might be pruned in previous steps. Extensive experiments on standard datasets across various tasks demonstrate that LazyLLM is a generic method that can be seamlessly integrated with existing language models to significantly accelerate the generation without fine-tuning. For instance, in the multi-document question-answering task, LazyLLM accelerates the prefilling stage of the LLama 2 7B model by 2.34x while maintaining accuracy.

URL: https://huggingface.co/papers/2407.14057

Conversation

▲ Please summarize the paper into bullet points.

Here is a summary of the paper in bullet points:

  • Problem: LLMs with long prompts have slow "time-to-first-token" (TTFT) due to the need to compute the KV cache for all tokens in the prompt.
  • Solution: LazyLLM dynamically prunes tokens during both the prefilling and decoding stages, only computing the KV for tokens deemed important for the next token prediction.
  • Key Features:
    • Dynamic Pruning: Unlike static pruning, LazyLLM allows the model to revive previously pruned tokens, which is crucial for maintaining accuracy.
    • Progressive Pruning: Prunes more tokens in later transformer layers, where they are less critical for performance.
    • Aux Cache: Caches hidden states of pruned tokens to avoid repetitive computation when they are revived.
  • Benefits:
    • Universal: Can be integrated with any transformer-based LLM.
    • Training-free: No fine-tuning required.
    • Effective: Significantly accelerates inference without sacrificing accuracy.
  • Results:
    • LazyLLM achieves significant TTFT speedup (up to 2.34x) with negligible accuracy loss on various tasks.
    • LazyLLM outperforms baselines (random token drop, static token pruning, prompt compression) in terms of speedup vs. accuracy trade-off.
    • LazyLLM reduces the overall computation by not computing all tokens in the prompt, leading to further speedup in the generation process.
  • Conclusion: LazyLLM is a simple yet effective technique for accelerating LLM inference with long prompts, offering a significant improvement in efficiency without requiring any model retraining.

▲ Tell me more about dynamic pruning and progressive pruning, what are they and how do they work?

Here's a breakdown of dynamic pruning and progressive pruning as described in the LazyLLM paper:

Dynamic Pruning

  • What it is: Dynamic pruning is a method of selectively removing tokens from the input sequence during each generation step of an LLM. This means the model can choose different subsets of tokens to use in different steps, even if a token was pruned in a previous step.
  • How it works:
    1. Importance Assessment: LazyLLM uses the attention scores from the previous transformer layer to determine the importance of each token. Tokens with low attention scores are considered less important.
    2. Token Selection: A cutoff is applied to the attention scores using a top-k percentile rule, so tokens whose scores fall below the chosen percentile at a given pruning layer are pruned; the percentile can differ between pruning layers.
    3. Revival: Crucially, LazyLLM allows the model to revive previously pruned tokens if they become relevant in later steps. This is done by storing the hidden states of pruned tokens in an "Aux Cache."

Progressive Pruning

  • What it is: Progressive pruning is a strategy for applying dynamic pruning in a way that gradually reduces the number of tokens used as the model progresses through the transformer layers.
  • How it works:
    1. Layer-wise Pruning: Pruning is applied at multiple layers within the transformer stack.
    2. Increasing Pruning Ratio: The percentage of tokens pruned increases as the model moves to later layers. This is because later layers are generally less sensitive to token pruning.
    3. Balance: This approach balances the need for efficiency (by pruning more tokens) with the need for accuracy (by preserving more tokens in the early layers).

Why Dynamic and Progressive Pruning are Important

  • Accuracy: Dynamic pruning allows the model to adapt to the changing context of the generation process, ensuring that important tokens are not permanently removed.
  • Efficiency: Progressive pruning optimizes the computation by gradually reducing the number of tokens processed, leading to significant speedup.
  • Flexibility: The pruning parameters (number of layers, pruning ratio) can be adjusted to fine-tune the trade-off between speed and accuracy.

In essence, dynamic and progressive pruning work together to enable LazyLLM to efficiently process long prompts while maintaining high accuracy.
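
To make the interplay concrete, below is a minimal, self-contained Python sketch of how dynamic and progressive pruning could combine at inference time. It uses random numbers in place of real attention scores and a made-up keep-ratio schedule, so it illustrates the control flow only, not the paper's implementation.

```python
# Toy illustration of dynamic + progressive pruning (not the authors' code).
# Importance scores are random stand-ins for the attention scores LazyLLM uses;
# keep_ratios is a hypothetical schedule that prunes more at deeper layers.
import numpy as np

rng = np.random.default_rng(0)
num_prompt_tokens = 12
keep_ratios = {8: 0.7, 16: 0.4, 24: 0.2}   # layer index -> fraction of tokens kept

def tokens_used_at_step(step):
    """Select a (possibly different) subset of prompt tokens at each generation step."""
    kept = np.arange(num_prompt_tokens)            # start each step from the full context
    for layer, ratio in sorted(keep_ratios.items()):
        scores = rng.random(len(kept))             # stand-in for averaged attention scores
        k = max(1, int(len(kept) * ratio))
        kept = kept[np.argsort(scores)[-k:]]       # progressive: later layers keep fewer tokens
    return set(kept.tolist())

for t in range(4):
    print(f"step {t}: tokens reaching the last layers = {sorted(tokens_used_at_step(t))}")
# A token absent at step t can reappear at step t+1 -- that is the "dynamic" part.
```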

▲ How many tokens can this method save?

The paper doesn't explicitly state a fixed percentage of tokens saved by LazyLLM. However, it provides several insights that help us understand the potential savings:

  • Sparse Attention: The paper highlights that attention scores for input tokens are often sparse, meaning many tokens have low attention scores and are likely redundant. This suggests significant potential for pruning.
  • Cumulative Token Usage: Figure 7 in the paper shows that LazyLLM often uses less than 100% of the prompt tokens during the entire generation process. This indicates that a substantial number of tokens are never selected, implying significant savings.
  • Task-Specific: The amount of token saving will likely vary depending on the specific task and the length of the prompt. Tasks with more complex relationships between tokens might require more tokens to be processed.

Overall, while a precise percentage of token savings is not provided, the paper strongly suggests that LazyLLM can significantly reduce the number of tokens processed, leading to substantial computational savings.

To get a more concrete idea of the savings, you could:

  • Analyze the attention scores: Examine the attention scores for different tasks and prompts to estimate the percentage of tokens with low scores.
  • Experiment with LazyLLM: Run experiments with different pruning parameters and measure the percentage of tokens computed.

By doing so, you can gain a better understanding of the potential token savings achievable with LazyLLM for your specific use case.
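
For the first suggestion, a rough sketch of how one could inspect attention sparsity with Hugging Face Transformers is shown below. The model name, the placeholder prompt, and the 1% cutoff are illustrative choices, not values from the paper.

```python
# Hedged sketch: estimating how sparse the prompt attention is for your own model/prompt.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "meta-llama/Llama-2-7b-hf"  # any causal LM works; this one is just an example
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.float16, device_map="auto")

prompt = "..."  # placeholder: put your own long prompt here
inputs = tok(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    out = model(**inputs, output_attentions=True)

# out.attentions: one (batch, heads, seq, seq) tensor per layer.
for layer_idx, attn in enumerate(out.attentions):
    to_last = attn[0, :, -1, :].mean(dim=0)           # last token's attention to each prompt token
    frac_low = (to_last < 0.01).float().mean().item() # fraction of tokens getting <1% attention
    print(f"layer {layer_idx}: {frac_low:.0%} of prompt tokens receive <1% attention")
```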


LazyLLM: Dynamic Token Pruning for Efficient Long Context LLM Inference

Qichen Fu (Apple), Minsik Cho (Apple), Thomas Merth (Apple), Sachin Mehta (Apple), Mohammad Rastegari, Mahyar Najibi (Apple)

Abstract

The inference of transformer-based large language models consists of two sequential stages: 1) a prefilling stage to compute the KV cache of prompts and generate the first token, and 2) a decoding stage to generate subsequent tokens. For long prompts, the KV cache must be computed for all tokens during the prefilling stage, which can significantly increase the time needed to generate the first token. Consequently, the prefilling stage may become a bottleneck in the generation process. An open question remains whether all prompt tokens are essential for generating the first token. To answer this, we introduce a novel method, LazyLLM, that selectively computes the KV for tokens important for the next token prediction in both the prefilling and decoding stages. Contrary to static pruning approaches that prune the prompt at once, LazyLLM allows language models to dynamically select different subsets of tokens from the context in different generation steps, even though they might be pruned in previous steps. Extensive experiments on standard datasets across various tasks demonstrate that LazyLLM is a generic method that can be seamlessly integrated with existing language models to significantly accelerate the generation without fine-tuning. For instance, in the multi-document question-answering task, LazyLLM accelerates the prefilling stage of the Llama 2 7B model by 2.34× while maintaining accuracy.

1 Introduction

Standard prompt-based LLM inference has two sequential stages: prefilling and decoding, as shown in Figure 1. During the prefilling stage, the model computes and saves the KV cache of each token from the prompt, and predicts the first token. We refer to the time taken during the prefilling stage as “time-to-first-token” (TTFT). Following the prefilling stage is the decoding stage, where the model reuses the cached KVs to decode the next token iteratively until the stop criteria are met.

During the prefilling stage, all tokens from the prompt are used by all transformer layers. For long prompts, TTFT could be slow because state-of-the-art transformer-based LLMs are both deep and wide (Pope et al., 2023; Kim et al., 2023; Aminabadi et al., 2022), and the cost of computing attention increases quadratically with the number of tokens in the prompts. For instance, Llama 2 (Touvron et al., 2023), with 7 billion parameters, stacks 32 transformer layers with a model dimension of 4096. In this scenario, TTFT requires 21× the walltime of each subsequent decoding step, and accounts for approximately 23% of the total generation time on the LongBench benchmark (Bai et al., 2023), where the average prompt length is 3,376 tokens and the average generation length is 68 tokens. Therefore, optimizing TTFT is a critical path toward efficient LLM inference (NVIDIA, 2024).
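
As a quick sanity check on the 23% figure, a back-of-envelope calculation (assuming each of the remaining 67 generated tokens costs one decoding step of equal walltime) gives:

```latex
\frac{T_{\text{TTFT}}}{T_{\text{total}}}
  \approx \frac{21\,t_{\text{dec}}}{21\,t_{\text{dec}} + 67\,t_{\text{dec}}}
  = \frac{21}{88} \approx 24\%,
```

which is consistent with the roughly 23% reported above.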

While optimizing LLM inference is an active area of research, many methods (Leviathan et al., 2023; Cai et al., 2024; Zhang et al., 2024; Bhendawade et al., 2024; Li et al., 2024) have focused on improving inference speed during the decoding stage. Yet, there is little attention given to improving TTFT. We note that some compression-based works implicitly improve the TTFT by reducing the size of LLMs (Frantar et al., 2022; Sun et al., 2023; Ma et al., 2023). However, an orthogonal line of research (Li et al., 2023; Jiang et al., 2023; Dao et al., 2022) investigates how TTFT can be improved given a static transformer architecture. Within this line of research, a natural question arises: Are all prompt tokens essential for generating the first token?

LLM profiling on the LongBench benchmark (Bai et al., 2023) in Figure 2 reveals that the attention scores of input tokens w.r.t. the first generated token are very sparse, indicating that many tokens in the input prompt are redundant and can be removed without affecting the next token prediction. To this end, we propose LazyLLM, a novel, simple, yet effective technique tailored for speeding up prefilling. As depicted in Figure 3, in each generation step, LazyLLM selectively computes the KV for tokens important for the next token prediction and “lazily” defers the computation of remaining tokens to later steps when they become relevant. We propose using the attention score of the prior transformer layer to measure the importance of tokens and progressively prune tokens along the depth of the transformer. In contrast to prompt compression works (Li et al., 2023; Jiang et al., 2023; Xu et al., 2023), which permanently reduce the prompt for all the following generation steps, our method allows the model to revive previously pruned tokens, which we found crucial to retain accuracy. Extending progressive token pruning to all generation steps is non-trivial. Specifically, if a token is pruned at generation step t and revived at a later generation step t' > t, some hidden states would need to be recomputed during step t'. To avoid such repetitive computation, we employ an additional caching mechanism, Aux Cache, to cache the hidden states of pruned tokens. This enables a computationally efficient pathway to revive pruned tokens, and ensures that the worst runtime of LazyLLM is never slower than the baseline.

In summary, the advantages of LazyLLM are: (1) Universal: LazyLLM can be seamlessly integrated with any existing transformer-based LLM to improve inference speed; (2) Training-free: LazyLLM doesn't require any fine-tuning and can be directly integrated without any parameter modification; (3) Effective: empirical results on 16 standard datasets across 6 different language tasks show that LazyLLM can improve the inference speed of the LLM during both the prefilling and decoding stages.

2 Related Work

The increase in the scale of large language models (LLMs) has greatly enhanced their performance but also introduced challenges with respect to their inference efficiency. The inference of generative LLMs consists of two distinct stages as depicted in Figure 1. In particular, extensive computation is needed under long context scenarios to calculate the full KV cache during the prefilling stage, resulting in a long time-to-first-token (TTFT). This delay causes users to wait several seconds after submitting a prompt before receiving any response from the agent, leading to a poor user experience.

Efficient Long Context Inference. Extensive work (Merth et al., 2024; Chen et al., 2023; Beltagy et al., 2020; Kitaev et al., 2020) has been proposed to improve inference efficiency for long context applications by reducing the memory footprint and total computations. Some works have focused on tailoring the architecture of the transformer for long context input. For instance, (Beltagy et al., 2020) introduces a drop-in replacement for standard self-attention and combines local windowed attention with task-motivated global attention. In parallel, Reformer (Kitaev et al., 2020) replaces dot-product attention by one that uses locality-sensitive hashing to reduce its computational complexity. Though the above methods can speed up long context inference, they require significant model architecture change and re-training. This drawback makes them impractical to be applied to existing pre-trained LLMs. Closer to our work are efficient techniques that optimize the KV cache (Zhang et al., 2024; Li et al., 2024; Anagnostidis et al., 2024; Nawrot et al., 2024) by minimizing the KV cache size and data transfer. However, these works only focus on accelerating decoding steps, which are not applicable to reducing TTFT.

Token Pruning. Previous studies on the sentence classification task (Kim et al., 2022; Anagnostidis et al., 2024; He et al., 2021) have shown that not all tokens (i.e. words) in an input sequence are necessary to make a successful prediction. This opens up several possibilities for token pruning, which minimizes computational demands by selectively removing less important tokens during inference. For example, (Kim et al., 2022) presents Learned Token Pruning, which adaptively removes unimportant tokens as an input sequence passes through transformer layers. In parallel, (He et al., 2021) proposes to reduce width-wise computation via token pruning for transformer-based models such as BERT (Devlin et al., 2018). These approaches were designed for tasks requiring only a single iteration of processing, such as text classification. In this work, we extend the idea of token pruning to generative LLMs. Specifically, our method allows the model to dynamically choose different sets of tokens at each generation step, which is crucial to retaining performance. Furthermore, we also introduce the Aux Cache to ensure that each token is computed at most once during the whole generation, which guarantees that the worst-case runtime of our method is not slower than the baseline.

3 LazyLLM

3.1 Background on LLM Inference

Generative LLM inference consists of two stages: prefilling and decoding (see Figure 1). In the prefilling stage, the model receives the prompt, a sequence of tokens T = {t_i}_{i=1}^N of length N, where t_i denotes a token and N denotes the length of the prompt; it then computes and saves the KV cache of each token and produces the first token t_{N+1}. The transformer architecture commonly used in LLMs is a stack of layers, each sharing the same structure: a multi-head self-attention mechanism followed by a multi-layer perceptron (MLP). The time taken by prefilling is referred to as time-to-first-token (a.k.a. TTFT). Following prefilling are the decoding steps, where the model appends the generated token t_{N+1} to the input and subsequently decodes the following token. The decoding step is repeated until the stop criteria are met. While the computation of each decoding step is similar to prefilling, the amount of computation is significantly lower thanks to the KV cache. Specifically, with the KV cache saved from prefilling, all previous tokens do not need to pass through any linear layers in the model.
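
A minimal sketch of these two stages using the Hugging Face Transformers API (standard inference, no pruning; the model name is an arbitrary example) is shown below:

```python
# Prefilling builds the KV cache and yields the first token; decoding reuses the cache.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "meta-llama/Llama-2-7b-hf"
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.float16, device_map="auto")
input_ids = tok("Summarize the following document: ...", return_tensors="pt").input_ids.to(model.device)

with torch.no_grad():
    # Prefilling: every prompt token passes through every layer, the KV cache is built,
    # and the first token is produced. This step dominates TTFT for long prompts.
    out = model(input_ids, use_cache=True)
    past = out.past_key_values
    next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)

    # Decoding: each step feeds only the newly generated token and reuses the KV cache.
    generated = [next_id]
    for _ in range(32):
        out = model(next_id, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)
        generated.append(next_id)

print(tok.decode(torch.cat(generated, dim=-1)[0]))
```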

3.2 Inference with LazyLLM

The overview of the proposed LazyLLM framework is illustrated in Figure 4. LazyLLM starts with the full context and progressively prunes tokens to gradually reduce the number of computations towards the end of the model. Note, LazyLLM allows the model to select different subsets of tokens from the context in different generation steps, even though some of them may be pruned in previous steps. Compared to static pruning which prunes all the tokens at once, dynamic pruning optimizes the next token prediction in each generation step, which is crucial to retaining the performance.

Progressive Token Pruning. Prior to this work, token pruning has been successfully applied to optimize LLM inference (Zhang et al., 2024; Li et al., 2024; Adnan et al., 2024; Nawrot et al., 2024). However, these approaches require accumulating the full attention maps from predicting the first few tokens to profile the importance of prompt tokens before starting pruning. Consequently, they are not applicable to reducing TTFT, as they still require computing the full KV cache at the prefilling stage.

In contrast, LazyLLM only “lazily” computes the tokens that are important for predicting the next token, starting from the first iteration of inference (the prefilling step). A key challenge to pruning tokens in the first iteration is determining their importance. Inspired by early-exiting work (Elhoushi et al., 2024), which shows that token hidden states gradually evolve through the transformer layers, we apply layer-wise token pruning in each generation step. Specifically, we use the attention map of layer l, A^l ∈ R^{H×N×N}, to determine the importance of input token t_i w.r.t. the next token to be predicted as

s_i^l = (1/H) \sum_{h=1}^{H} A^l_{h,i,N}    (1)

where H𝐻Hitalic_H denotes number of attention heads, N𝑁Nitalic_N is the sequence length, and Ah,i,jsubscript𝐴ℎ𝑖𝑗A_{h,i,j}italic_A start_POSTSUBSCRIPT italic_h , italic_i , italic_j end_POSTSUBSCRIPT is the attention probability of the token tjsubscript𝑡𝑗t_{j}italic_t start_POSTSUBSCRIPT italic_j end_POSTSUBSCRIPT attending to token tisubscript𝑡𝑖t_{i}italic_t start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT at ht⁢hsuperscriptℎ𝑡ℎh^{th}italic_h start_POSTSUPERSCRIPT italic_t italic_h end_POSTSUPERSCRIPT head.

After computing the confidence scores of the tokens, it is challenging to determine the threshold value for pruning. Concretely, the threshold can change as the distribution of the attention scores varies between different layers and different tasks. We address this challenge by using a top-k percentile selection strategy to prune tokens. Specifically, token t_i is pruned at layer l+1 if its confidence score s_i^l is smaller than the k^l-th percentile among the input tokens. Once a token is pruned, it is excluded from the computation of all successive layers. In other words, the tokens used in later layers are a subset of those used in previous layers.
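
A hedged sketch of Eq. (1) and the top-k percentile selection follows, written against a generic attention tensor in the standard [head, query, key] layout (the function and argument names are ours, not the paper's):

```python
# Sketch of the layer-wise importance score (Eq. 1) and top-k percentile pruning.
import torch

def prune_by_attention(hidden, attn_probs, keep_fraction):
    """
    hidden:        (N, d) hidden states entering layer l+1
    attn_probs:    (H, N, N) attention probabilities from layer l, [head, query, key] layout
    keep_fraction: fraction of tokens to keep (e.g. 0.5 keeps the top 50%)
    """
    # Eq. 1: average over heads of the last token's attention to each token t_i.
    scores = attn_probs[:, -1, :].mean(dim=0)                # shape (N,)
    k = max(1, int(scores.numel() * keep_fraction))
    keep_idx = torch.topk(scores, k).indices.sort().values   # preserve token order
    return hidden[keep_idx], keep_idx

# Toy usage with random tensors:
H, N, d = 8, 16, 64
hidden, attn = torch.randn(N, d), torch.softmax(torch.randn(H, N, N), dim=-1)
pruned_hidden, kept = prune_by_attention(hidden, attn, keep_fraction=0.5)
print(kept.tolist())  # indices of tokens passed on to later layers
```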

Our study in Section 5.4 shows the performance changes with different locations of pruning layers and the number of tokens pruned. In particular, when pruning at the same transformer layer, the model’s performance gradually decreases as fewer tokens are kept. We also found pruning at later transformer layers consistently has better performance than pruning at earlier layers, suggesting that later layers are less sensitive to token pruning. To achieve a better balance of speedup and accuracy, as shown in Figure 4, we apply progressive pruning that keeps more tokens at earlier transformer layers and gradually reduces the number of tokens towards the end of the transformer.
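
For illustration, a hypothetical progressive schedule might look like the following (the layer indices and keep fractions are made up; here each fraction is applied to the tokens that survive the previous pruning layer):

```python
# Hypothetical progressive-pruning schedule, not the paper's tuned values.
keep_schedule = {10: 0.9, 20: 0.6, 30: 0.3}   # layer index -> fraction of surviving tokens kept

def tokens_surviving(num_prompt_tokens, schedule):
    """Cumulative effect: each pruning layer keeps a subset of what the previous one kept."""
    n = num_prompt_tokens
    for layer in sorted(schedule):
        n = max(1, int(n * schedule[layer]))
        print(f"after layer {layer:2d}: {n} tokens continue")
    return n

tokens_surviving(3376, keep_schedule)   # 3,376 is the average LongBench prompt length
```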

Aux Cache. In the prefilling stage, there is no KV cache and every token is represented by hidden states. Thus, progressive token pruning can be implemented by removing pruned tokens' hidden states. However, extending progressive token pruning to the following decoding steps is non-trivial. This is because each decoding step leverages the KV cache computed during prefilling to compute attention. Since LazyLLM performs progressive token pruning at the prefilling stage, the KV of tokens pruned at layer l (e.g. T4 in Figure 4) will not exist in the KV cache of layer l+1. As a reminder, the LazyLLM framework allows each generation step to pick a different subset of tokens from the full input token sequence, regardless of whether they were pruned in previous generation steps. For example, during the following decoding steps, pruned tokens (e.g. T4) that do not exist in the KV cache of layer l+1 may be re-selected to compute attention. In such cases, the model cannot retrieve the KV cache of these tokens. An intuitive solution is to pass those tokens again from the beginning of the transformer. However, that would cause repetitive computation for the same token and eventually slow down the whole generation.

To tackle this challenge, we introduce the Aux Cache in addition to the original KV cache, which stores the hidden states of pruned tokens (e.g. T4 and T7 in Figure 4) whenever their KV is not present in the following layer's KV cache, so that they can be retrieved in later iterations. As shown in Figure 4, in each decoding step, each transformer layer (e.g. layer l+1) first retrieves the KV cache of past tokens if they exist (e.g. T1 and T8). For tokens that do not exist in the KV cache (e.g. T3), we retrieve their hidden states from the Aux Cache of the previous layer directly, instead of passing them through previous layers again. The introduction of the Aux Cache ensures that each token is computed at most once in every transformer layer, and that the worst-case runtime of LazyLLM is not slower than the baseline.
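
The lookup order this enables at each decoding step can be sketched as follows (the data structures are simplified stand-ins for illustration, not the authors' implementation):

```python
# Per-layer lookup order during decoding: KV cache first, then Aux Cache, never recompute from scratch.
from dataclasses import dataclass, field

@dataclass
class LayerCaches:
    kv_cache: dict = field(default_factory=dict)    # token index -> (K, V) for tokens kept at this layer
    aux_cache: dict = field(default_factory=dict)   # token index -> hidden state of tokens pruned before this layer

def gather_token(layer: LayerCaches, token_idx, compute_kv_from_hidden):
    """Fetch a selected token's KV for this layer, computing it at most once."""
    if token_idx in layer.kv_cache:                 # case 1: token was never pruned here (e.g. T1, T8)
        return layer.kv_cache[token_idx]
    if token_idx in layer.aux_cache:                # case 2: pruned earlier, revived now (e.g. T3)
        kv = compute_kv_from_hidden(layer.aux_cache.pop(token_idx))
        layer.kv_cache[token_idx] = kv              # cache it so it is never recomputed again
        return kv
    raise KeyError("hidden state should have been stored in the Aux Cache when the token was pruned")
```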

4 Implementation Details

We implement LazyLLM on Llama 2 (Touvron et al., 2023) and XGen (Nijkamp et al., 2023) and evaluate it on LongBench (Bai et al., 2023) using HuggingFace Transformers (https://github.com/huggingface/transformers). We follow the official GitHub repository of LongBench (https://github.com/THUDM/LongBench) for data preprocessing and prompting in all experiments. The LongBench benchmark consists of multiple datasets across different tasks, where each task may use different metrics, including ROUGE-L, F1, Accuracy, and Edit Sim. Following the official evaluation pipeline, we aggregate all results over the major task categories by computing the macro-average score.

As previously noted, LazyLLM doesn't require any training. Thus, LazyLLM uses exactly the same existing checkpoints as the baseline for all models. For inference, we conduct all experiments on NVIDIA A100 GPUs. We measure and report speedup based on the empirical walltime improvement. Specifically, for TTFT Speedup, we measure the empirical walltime between when the prompt is fed to the model and when the model generates the first token. For Generation Speedup, we measure the empirical walltime between when the prompt is fed to the model and when the model finishes generating all output tokens. We add 5 warmup runs for each experiment before starting the time measurement to remove noise such as loading model parameters.
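
A generic sketch of this timing protocol (warmup runs followed by averaged wall-clock measurements; the helper names and run counts are our choices, not the paper's harness) could look like:

```python
# TTFT measured as wall-clock time from feeding the prompt to emitting the first token.
import time
import torch

def measure_ttft(model, input_ids, warmup=5, runs=10):
    def first_token():
        with torch.no_grad():
            model.generate(input_ids, max_new_tokens=1, do_sample=False)
        if torch.cuda.is_available():
            torch.cuda.synchronize()               # ensure GPU work is finished before stopping the clock

    for _ in range(warmup):                        # remove one-off costs such as loading model parameters
        first_token()
    times = []
    for _ in range(runs):
        start = time.perf_counter()
        first_token()
        times.append(time.perf_counter() - start)
    return sum(times) / len(times)
```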

5 Experiments

We examine our method using two large language models: Llama 2 7B and XGen 7B. We compare our method with baselines using the same publicly released pretrained checkpoints, without any additional training. We perform experiments on LongBench, a multi-task benchmark for long context understanding. LongBench comprises 16 datasets and covers 6 tasks, including single-doc QA, multi-doc QA, summarization, few-shot learning, synthetic tasks, and code completion.

For the metrics, we primarily evaluate the effectiveness and efficiency of each method in terms of the TTFT speedup vs. accuracy trade-off. Following LongBench, the accuracy (score) denotes the macro-averaged score across the datasets in each task. The TTFT speedup measures the walltime improvement w.r.t. the baseline for generating the first token. In the analysis, we also assess the impact of our method on the % of Prompt Token Computed and the Generation speedup. The % of Prompt Token Computed measures the accumulated percentage of prompt tokens computed by the end of generation, which indicates the savings in total computation. The Generation speedup measures the walltime change w.r.t. the baseline for completing the entire generation process.

5.1 Results

Table 1 presents the TTFT speedup vs. accuracy comparisons between LazyLLM, the standard LLM, and other baselines. In the table, “baseline” refers to standard LLM inference. The “random token drop” baseline is based on (Yao et al., 2022), which randomly prunes prompt tokens before feeding them to the LLM; we report the average metrics across 5 runs for this baseline. Our “static token pruning” baseline prunes input tokens at once, based on their attention scores in the first few transformer layers, during the prefilling stage. We also compare with the prompt compression method (Li et al., 2023), which prunes redundancy in the input context using LLMs. Table 1 shows LazyLLM consistently achieves better TTFT speedup with negligible accuracy drop across multiple tasks. It is worth noting that running an LLM to compress the prompt is itself computationally expensive: even though inference on the reduced prompt is faster, the actual TTFT of the “prompt compression” baseline is longer than that of the standard baseline.

5.2 TTFT Speedup vs. Accuracy

The inference efficiency of LazyLLM is controlled by three parameters: 1) the number of pruning layers, 2) the locations of these pruning layers, and 3) the number of tokens pruned within these layers. Increasing the number of pruning layers and pruning more tokens reduce computation by processing fewer tokens, and pruning tokens at earlier layers saves computation in all successive layers. Combining these factors yields a larger overall computation reduction and a better TTFT speedup. As a side effect, excessively pruning tokens may cause information loss and eventually lead to performance degradation. Similarly, the TTFT speedup and accuracy of the baselines can vary with different hyperparameters.

We compare TTFT speedup vs. accuracy in Figure 5 under different hyperparameters. The visualization shows that, without any training, the proposed LazyLLM retains accuracy better than the baselines at the same TTFT speedup. For example, our method offers a 2.34× TTFT speedup in the multi-document question-answering task with negligible (≤1%) performance loss. By controlling the pruning parameters, LazyLLM provides a good trade-off between accuracy and inference speed compared to the baseline methods. For instance, LazyLLM can achieve a 3.0× TTFT speedup in the multi-document question-answering task with ≤10% degradation in accuracy. In contrast, the baselines' accuracy degrades significantly at similar TTFT speedups. Note that the prompt compression approaches fail to improve TTFT because of the compression overhead.

5.3 Impact on Overall Generation Speed

To evaluate the impact of the proposed method on the overall generation process, we also profile the % of Prompt Token Computed and the Generation speedup in Table 2. We find that the % of Prompt Token Computed for LazyLLM is less than 100%, indicating that not all tokens in the prompt are selected by LazyLLM by the end of generation, even though theoretically the model could use all tokens. Computations in the FFN layers increase linearly, while those in the attention layers grow quadratically, with the % of Prompt Token Computed. A lower % of Prompt Token Computed therefore means LazyLLM reduces the total computation, consequently offering additional speedup to the overall generation process across diverse tasks.
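
A back-of-envelope way to see why a lower % of Prompt Token Computed translates into overall speedup is sketched below (the FFN/attention FLOP split is an illustrative assumption, not a measured profile):

```python
# Rough model of prompt-side compute vs. fraction p of prompt tokens computed:
# FFN cost scales ~linearly in p, attention cost ~quadratically in p.
def relative_compute(p, ffn_share=2/3, attn_share=1/3):
    return ffn_share * p + attn_share * p * p

for p in (1.0, 0.8, 0.6, 0.4):
    print(f"{p:.0%} of tokens computed -> ~{relative_compute(p):.0%} of baseline prompt compute")
```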

5.4 Drop Rate in Different Layers

In this section, we analyze the effect of the locations of pruning layers, and the number of tokens pruned. In particular, we report a series of experiments using a simplified version of LazyLLM that prunes tokens just once within the transformer. For each trial, we position the pruning layer at various levels of the transformer stack and apply different pruning ratios. We perform the experiments for both Llama 2 and XGen, and visualize the results in Figure 6.

The results show both models share a similar trend. As expected, when pruning at the same transformer layer, the model’s performance gradually decreases as fewer tokens are kept. Furthermore, pruning at later transformer layers consistently yields better performance compared to pruning at earlier layers, suggesting that later layers are less sensitive to token pruning. Based on these observations, we propose progressive token pruning in Section 3.2, which strategically prunes more tokens in later layers while preserving more in the earlier layers, optimizing the balance between efficiency and performance retention.

5.5 Progressive KV Growth

In this section, we characterize the internals of the model with the token pruning logic. Specifically, we seek to understand what fractions of prompt tokens are cumulatively used and, inversely, not used. This “cumulative token usage” can be equivalently defined as the KV cache size at each given step. Figure 7 presents these cumulative prompt token usage numbers for each of the stages of the LazyLLM.

Our analysis supports the hypothesis that many tokens are never selected by the model (even though theoretically the model could use all tokens in the prompt). Since this model retains accuracy on the task(s), we can conclude that the model effectively drops the tokens which do not affect the output quality.

6 Conclusion

In this work, we proposed a novel LazyLLM technique for efficient LLM inference, in particular under long context scenarios. LazyLLM selectively computes the KV for tokens important for the next token prediction and “lazily” defers the computation of remaining tokens to later steps, when they become relevant. We carefully examine LazyLLM on various tasks, where we observed the proposed method effectively reduces TTFT with negligible performance loss. It is worth noting that our method can be seamlessly integrated with existing transformer-based LLMs to improve their inference speed without requiring any fine-tuning.

References

  • Adnan et al. (2024) Muhammad Adnan, Akhil Arunkumar, Gaurav Jain, Prashant J Nair, Ilya Soloveychik, and Purushotham Kamath. Keyformer: KV cache reduction through key tokens selection for efficient generative inference. arXiv preprint arXiv:2403.09054, 2024.
  • Aminabadi et al. (2022) Reza Yazdani Aminabadi, Samyam Rajbhandari, Ammar Ahmad Awan, Cheng Li, Du Li, Elton Zheng, Olatunji Ruwase, Shaden Smith, Minjia Zhang, Jeff Rasley, et al. DeepSpeed-Inference: Enabling efficient inference of transformer models at unprecedented scale. In SC22: International Conference for High Performance Computing, Networking, Storage and Analysis, pp. 1–15. IEEE, 2022.
  • Anagnostidis et al. (2024) Sotiris Anagnostidis, Dario Pavllo, Luca Biggio, Lorenzo Noci, Aurelien Lucchi, and Thomas Hofmann. Dynamic context pruning for efficient and interpretable autoregressive transformers. Advances in Neural Information Processing Systems, 36, 2024.
  • Bai et al. (2023) Yushi Bai, Xin Lv, Jiajie Zhang, Hongchang Lyu, Jiankai Tang, Zhidian Huang, Zhengxiao Du, Xiao Liu, Aohan Zeng, Lei Hou, et al. LongBench: A bilingual, multitask benchmark for long context understanding. arXiv preprint arXiv:2308.14508, 2023.
  • Beltagy et al. (2020) Iz Beltagy, Matthew E Peters, and Arman Cohan. Longformer: The long-document transformer. arXiv preprint arXiv:2004.05150, 2020.
  • Bhendawade et al. (2024) Nikhil Bhendawade, Irina Belousova, Qichen Fu, Henry Mason, Mohammad Rastegari, and Mahyar Najibi. Speculative streaming: Fast LLM inference without auxiliary models. arXiv preprint arXiv:2402.11131, 2024.
  • Cai et al. (2024) Tianle Cai, Yuhong Li, Zhengyang Geng, Hongwu Peng, Jason D Lee, Deming Chen, and Tri Dao. Medusa: Simple LLM inference acceleration framework with multiple decoding heads. arXiv preprint arXiv:2401.10774, 2024.
  • Chen et al. (2023) Yukang Chen, Shengju Qian, Haotian Tang, Xin Lai, Zhijian Liu, Song Han, and Jiaya Jia. LongLoRA: Efficient fine-tuning of long-context large language models. arXiv preprint arXiv:2309.12307, 2023.
  • Dao et al. (2022) Tri Dao, Dan Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. FlashAttention: Fast and memory-efficient exact attention with IO-awareness. Advances in Neural Information Processing Systems, 35:16344–16359, 2022.
  • Devlin et al. (2018) Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.
  • Elhoushi et al. (2024) Mostafa Elhoushi, Akshat Shrivastava, Diana Liskovich, Basil Hosmer, Bram Wasti, Liangzhen Lai, Anas Mahmoud, Bilge Acun, Saurabh Agarwal, Ahmed Roman, et al. LayerSkip: Enabling early exit inference and self-speculative decoding. arXiv preprint arXiv:2404.16710, 2024.
  • Frantar et al. (2022) Elias Frantar, Saleh Ashkboos, Torsten Hoefler, and Dan Alistarh. GPTQ: Accurate post-training quantization for generative pre-trained transformers. arXiv preprint arXiv:2210.17323, 2022.
  • He et al. (2021) Xuanli He, Iman Keivanloo, Yi Xu, Xiang He, Belinda Zeng, Santosh Rajagopalan, and Trishul Chilimbi. Magic Pyramid: Accelerating inference with early exiting and token pruning. arXiv preprint arXiv:2111.00230, 2021.
  • Jiang et al. (2023) Huiqiang Jiang, Qianhui Wu, Chin-Yew Lin, Yuqing Yang, and Lili Qiu. LLMLingua: Compressing prompts for accelerated inference of large language models. arXiv preprint arXiv:2310.05736, 2023.
  • Kim et al. (2022) Sehoon Kim, Sheng Shen, David Thorsley, Amir Gholami, Woosuk Kwon, Joseph Hassoun, and Kurt Keutzer. Learned token pruning for transformers. In Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining, pp. 784–794, 2022.
  • Kim et al. (2023) Sehoon Kim, Coleman Hooper, Thanakul Wattanawong, Minwoo Kang, Ruohan Yan, Hasan Genc, Grace Dinh, Qijing Huang, Kurt Keutzer, Michael W Mahoney, et al. Full stack optimization of transformer inference: A survey. arXiv preprint arXiv:2302.14017, 2023.
  • Kitaev et al. (2020) Nikita Kitaev, Łukasz Kaiser, and Anselm Levskaya. Reformer: The efficient transformer. arXiv preprint arXiv:2001.04451, 2020.
  • Leviathan et al. (2023) Yaniv Leviathan, Matan Kalman, and Yossi Matias. Fast inference from transformers via speculative decoding. In International Conference on Machine Learning, pp. 19274–19286. PMLR, 2023.
  • Li et al. (2023) Yucheng Li, Bo Dong, Chenghua Lin, and Frank Guerin. Compressing context to enhance inference efficiency of large language models. arXiv preprint arXiv:2310.06201, 2023.
  • Li et al. (2024) Yuhong Li, Yingbing Huang, Bowen Yang, Bharat Venkitesh, Acyr Locatelli, Hanchen Ye, Tianle Cai, Patrick Lewis, and Deming Chen. SnapKV: LLM knows what you are looking for before generation. arXiv preprint arXiv:2404.14469, 2024.
  • Ma et al. (2023) Xinyin Ma, Gongfan Fang, and Xinchao Wang. LLM-Pruner: On the structural pruning of large language models. Advances in Neural Information Processing Systems, 36:21702–21720, 2023.
  • Merth et al. (2024) Thomas Merth, Qichen Fu, Mohammad Rastegari, and Mahyar Najibi. Superposition prompting: Improving and accelerating retrieval-augmented generation. 2024. URL https://api.semanticscholar.org/CorpusID:269033436.
  • Nawrot et al. (2024) Piotr Nawrot, Adrian Łańcucki, Marcin Chochowski, David Tarjan, and Edoardo M Ponti. Dynamic memory compression: Retrofitting LLMs for accelerated inference. arXiv preprint arXiv:2403.09636, 2024.
  • Nijkamp et al. (2023) Erik Nijkamp, Tian Xie, Hiroaki Hayashi, Bo Pang, Congying Xia, Chen Xing, Jesse Vig, Semih Yavuz, Philippe Laban, Ben Krause, et al. XGen-7B technical report. arXiv preprint arXiv:2309.03450, 2023.
  • NVIDIA (2024) NVIDIA. NVIDIA L40S: Unparalleled AI and graphics performance for the data center. https://resources.nvidia.com/en-us-l40s/l40s-datasheet-28413, 2024. [Online; accessed 31-May-2024].
  • Pope et al. (2023) Reiner Pope, Sholto Douglas, Aakanksha Chowdhery, Jacob Devlin, James Bradbury, Jonathan Heek, Kefan Xiao, Shivani Agrawal, and Jeff Dean. Efficiently scaling transformer inference. Proceedings of Machine Learning and Systems, 5, 2023.
  • Sun et al. (2023) Mingjie Sun, Zhuang Liu, Anna Bair, and J Zico Kolter. A simple and effective pruning approach for large language models. arXiv preprint arXiv:2306.11695, 2023.
  • Touvron et al. (2023) Hugo Touvron, Louis Martin, Kevin Stone, Peter Albert, Amjad Almahairi, Yasmine Babaei, Nikolay Bashlykov, Soumya Batra, Prajjwal Bhargava, Shruti Bhosale, et al. Llama 2: Open foundation and fine-tuned chat models. arXiv preprint arXiv:2307.09288, 2023.
  • Xu et al. (2023) Zhaozhuo Xu, Zirui Liu, Beidi Chen, Yuxin Tang, Jue Wang, Kaixiong Zhou, Xia Hu, and Anshumali Shrivastava. Compress, then prompt: Improving accuracy-efficiency trade-off of LLM inference with transferable prompt. arXiv preprint arXiv:2305.11186, 2023.
  • Yao et al. (2022) Zhewei Yao, Xiaoxia Wu, Conglong Li, Connor Holmes, Minjia Zhang, Cheng Li, and Yuxiong He. Random-LTD: Random and layerwise token dropping brings efficient training for large-scale transformers. arXiv preprint arXiv:2211.11586, 2022.
  • Zhang et al. (2024) Zhenyu Zhang, Ying Sheng, Tianyi Zhou, Tianlong Chen, Lianmin Zheng, Ruisi Cai, Zhao Song, Yuandong Tian, Christopher Ré, Clark Barrett, et al. H2O: Heavy-hitter oracle for efficient generative inference of large language models. Advances in Neural Information Processing Systems, 36, 2024.