Patch for vLLM to use a sequence length of 4096 with Llama 2
diff --git a/vllm/config.py b/vllm/config.py
index 2e8d584..83395d0 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -134,6 +134,7 @@ class ModelConfig:
             max_len_key = getattr(self.hf_config, key, None)
             if max_len_key is not None:
                 max_model_len = min(max_model_len, max_len_key)
+        return 4096
         return max_model_len
 
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 99fe593..2b1a5a7 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -25,7 +25,7 @@ class EngineArgs:
     block_size: int = 16
     swap_space: int = 4  # GiB
     gpu_memory_utilization: float = 0.90
-    max_num_batched_tokens: int = 2560
+    max_num_batched_tokens: int = 4096
     max_num_seqs: int = 256
     disable_log_stats: bool = False
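
A minimal sanity check after applying the patch and reinstalling vLLM from source, written as a sketch rather than a definitive recipe: it assumes a vLLM 0.1.x-era build, local access to the meta-llama/Llama-2-7b-hf weights (any Llama 2 variant would do), and that the first hunk above sits inside ModelConfig.get_max_model_len(), which is where that version computes the maximum length.

from vllm import LLM, SamplingParams

# Load a Llama 2 checkpoint (hypothetical model choice; substitute your own).
llm = LLM(model="meta-llama/Llama-2-7b-hf")

# With the patch applied, the model config should report 4096 as the maximum
# sequence length (assumes the patched hunk lives in get_max_model_len()).
print(llm.llm_engine.model_config.get_max_model_len())

# Prompts up to roughly 4096 tokens should now fit in a single batch, matching
# the raised max_num_batched_tokens default in the second hunk.
params = SamplingParams(temperature=0.8, max_tokens=128)
outputs = llm.generate(["Explain what a KV cache is."], params)
print(outputs[0].outputs[0].text)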