@vanbasten23
Created August 5, 2025 23:21
1. After we load the original model via `vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)`, the module tree of `vllm_model`, printed with

   ```python
   for idx, m in enumerate(vllm_model.named_modules()):
       print(idx, '->', m)
   ```

   is captured here: https://gist.github.com/vanbasten23/56a5cf844c0a527453a37af36efd3193
2. After replacing the layers with LoRA layers (via `load_lora_model`), the module tree, printed with

   ```python
   for idx, m in enumerate(vllm_model.named_modules()):
       print(idx, '->', m)
   ```

   is captured here: https://gist.github.com/vanbasten23/fc5ab730ea88d60605057b903a0570ea
3. Before we run `shard_model_to_tpu`, the module tree of `self.model`, printed with

   ```python
   for idx, m in enumerate(self.model.named_modules()):
       print(idx, '->', m)
   ```

   is captured here: https://gist.github.com/vanbasten23/dab6a90283c905882647e8aa5d0b9ca1
4. After we run `shard_model_to_tpu`, the module tree of `self.model`, printed with

   ```python
   for idx, m in enumerate(self.model.named_modules()):
       print(idx, '->', m)
   ```

   is captured here: https://gist.github.com/vanbasten23/d5b1f15645f3d358b4224927f1bdcbce
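
The `named_modules()` dumps in the gists above print each module's full repr, which is verbose. A small helper (my own, not part of vLLM) that prints only the module path and class name makes the four snapshots easier to diff:

```python
def dump_module_tree(model):
    # One line per module: its dotted path and its class name. This is
    # enough to spot the QKVParallelLinear -> *WithLoRA -> Jax* swaps
    # summarized in the conclusion below.
    for name, module in model.named_modules():
        print(name or '<root>', '->', type(module).__name__)
```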
Conclusion:
1. Comparing (1) with (2), LoRA loading replaces (see the wrapper sketch below):
   - `QKVParallelLinear` with `MergedQKVParallelLinearWithLoRA`
   - `MergedColumnParallelLinear` with `MergedColumnParallelLinearWithLoRA`
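
   The substitution in (2) follows the usual LoRA wrap-and-swap pattern: the original layer survives as a `base_layer` child of the new module, which is why the tree in (2) gains one level of nesting. A minimal sketch of that pattern (my own illustration, not vLLM's actual `MergedQKVParallelLinearWithLoRA`):

   ```python
   import torch
   import torch.nn as nn

   class LinearWithLoRA(nn.Module):
       # Hypothetical stand-in for vLLM's *WithLoRA wrappers.
       def __init__(self, base_layer: nn.Linear, rank: int = 8):
           super().__init__()
           self.base_layer = base_layer  # original layer kept as a child module
           self.lora_a = nn.Parameter(torch.zeros(base_layer.in_features, rank))
           self.lora_b = nn.Parameter(torch.zeros(rank, base_layer.out_features))

       def forward(self, x: torch.Tensor) -> torch.Tensor:
           # Base projection plus the low-rank update.
           return self.base_layer(x) + (x @ self.lora_a) @ self.lora_b
   ```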
2. Comparing (3) with (4), `shard_model_to_tpu` changes the `base_layer` part and replaces (see the swap sketch below):
   - `QKVParallelLinear` with `JaxQKVParallelLinear`
   - `RowParallelLinear` with `JaxRowParallelLinear`
   - `Attention` with `JaxAttention`
   - `MergedColumnParallelLinear` with `JaxMergedColumnParallelLinear`
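
   So `shard_model_to_tpu` leaves the LoRA wrappers in place and only swaps their `base_layer` children (plus standalone modules such as `Attention`) for Jax-backed equivalents. Conceptually the swap looks like the sketch below; `JAX_EQUIVALENTS` and the `convert` factory are assumptions for illustration, not the real API:

   ```python
   # Hypothetical class-name mapping to the Jax-backed layer types.
   JAX_EQUIVALENTS = {
       'QKVParallelLinear': 'JaxQKVParallelLinear',
       'RowParallelLinear': 'JaxRowParallelLinear',
       'Attention': 'JaxAttention',
       'MergedColumnParallelLinear': 'JaxMergedColumnParallelLinear',
   }

   def swap_to_jax(model, convert):
       # Two passes: snapshot (parent, child_name, child) triples first so
       # we never mutate the module tree while iterating it, then replace
       # each matching child on its parent. LoRA wrappers are covered
       # because their `base_layer` is a direct child.
       targets = [
           (parent, name, child)
           for parent in model.modules()
           for name, child in parent.named_children()
           if type(child).__name__ in JAX_EQUIVALENTS
       ]
       for parent, name, child in targets:
           setattr(parent, name, convert(child))
   ```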