1. After we load the original model with `vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)`, `vllm_model` looks as follows when dumped with

   ```python
   for idx, m in enumerate(vllm_model.named_modules()):
       print(idx, '->', m)
   ```

   https://gist.github.com/vanbasten23/56a5cf844c0a527453a37af36efd3193

2. After we replace the layers with LoRA layers (via `load_lora_model`), the same dump of the model gives

   https://gist.github.com/vanbasten23/fc5ab730ea88d60605057b903a0570ea

3. Before we run `shard_model_to_tpu`, the same dump of `self.model` gives

   https://gist.github.com/vanbasten23/dab6a90283c905882647e8aa5d0b9ca1

4. After we run `shard_model_to_tpu`, the same dump of `self.model` gives

   https://gist.github.com/vanbasten23/d5b1f15645f3d358b4224927f1bdcbce

   (A small helper sketch for diffing these four snapshots follows this list.)
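To compare these snapshots without eyeballing four gists, a minimal diff helper can be used. This is a sketch, not part of the original flow; the names `module_types` and `diff_module_types` are hypothetical:

```python
# Hypothetical helpers (not from the gists above): snapshot each
# submodule's class name, then report which modules changed class
# between two stages (e.g. before/after load_lora_model).
def module_types(model):
    return {name: type(mod).__name__ for name, mod in model.named_modules()}

def diff_module_types(before, after):
    # Print only the modules whose class changed between the snapshots.
    for name, cls in before.items():
        if name in after and after[name] != cls:
            print(f"{name}: {cls} -> {after[name]}")
```

For example, snapshotting the model before and after `load_lora_model` and diffing the two dicts surfaces exactly the replacements listed in conclusion (1) below.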
Conclusion:
1. Comparing (1) with (2), we replace (see the wrapper sketch after this list):
   - `QKVParallelLinear` with `MergedQKVParallelLinearWithLoRA`
   - `MergedColumnParallelLinear` with `MergedColumnParallelLinearWithLoRA`
2. Comparing (3) with (4), we change the `base_layer` part and replace (see the swap sketch after this list):
   - `QKVParallelLinear` with `JaxQKVParallelLinear`
   - `RowParallelLinear` with `JaxRowParallelLinear`
   - `Attention` with `JaxAttention`
   - `MergedColumnParallelLinear` with `JaxMergedColumnParallelLinear`
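The `*WithLoRA` layers in (2) follow the usual LoRA wrapper pattern: the original module is kept as a `base_layer` attribute and a low-rank update is added on top. Below is a minimal sketch of that pattern, assuming a plain `nn.Linear` base; it is illustrative only, not vLLM's actual `MergedQKVParallelLinearWithLoRA`:

```python
import torch
import torch.nn as nn

# Minimal sketch of the LoRA wrapper pattern (illustrative, not vLLM's
# implementation). Keeping the original module as `base_layer` is what
# lets shard_model_to_tpu later swap the base without touching the
# LoRA weights.
class LinearWithLoRA(nn.Module):
    def __init__(self, base_layer: nn.Linear, rank: int = 8):
        super().__init__()
        self.base_layer = base_layer  # original, typically frozen, weights
        self.lora_a = nn.Parameter(torch.zeros(rank, base_layer.in_features))
        self.lora_b = nn.Parameter(torch.zeros(base_layer.out_features, rank))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = W x + B (A x): the base output plus the low-rank update
        return self.base_layer(x) + x @ self.lora_a.T @ self.lora_b.T
```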
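Likewise, the (3) to (4) change can be read as an in-place swap of each wrapper's `base_layer` for a JAX-backed twin (the `Attention` to `JaxAttention` swap presumably happens directly on the parent module, since `Attention` is not LoRA-wrapped). The traversal below is a guess at what `shard_model_to_tpu` effectively does; `swap_base_layers` and the factory-style constructors are hypothetical:

```python
# Hypothetical sketch of the base_layer swap implied by gists (3)/(4);
# not the actual shard_model_to_tpu implementation.
def swap_base_layers(model, replacements):
    """replacements maps an original layer class to a factory that
    builds its JAX-backed replacement from the existing layer."""
    for module in model.modules():
        base = getattr(module, "base_layer", None)
        if base is not None and type(base) in replacements:
            module.base_layer = replacements[type(base)](base)

# Assumed usage (constructors taking the old layer are an assumption):
# swap_base_layers(self.model, {
#     QKVParallelLinear: JaxQKVParallelLinear,
#     RowParallelLinear: JaxRowParallelLinear,
#     MergedColumnParallelLinear: JaxMergedColumnParallelLinear,
# })
```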