@vanbasten23
Created August 5, 2025 23:01
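Module hierarchy of vLLM's Qwen2ForCausalLM on the JAX/TPU backend with LoRA enabled: each attention and MLP projection (qkv_proj, o_proj, gate_up_proj, down_proj) is wrapped in a *WithLoRA adapter whose base_layer is a JAX-backed parallel linear. The listing below enumerates every (name, module) pair of the runner. A minimal sketch of how such a dump is typically produced, assuming `runner` is a handle to the _VllmRunner instance shown as entry 0 (it behaves like a torch.nn.Module); this is not the author's exact script:

# Enumerate all submodules of the runner and print them in
# "index -> (dotted_name, module_repr)" form, matching the listing below.
for idx, (name, module) in enumerate(runner.named_modules()):
    print(f"{idx} -> {(name, module)}")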
0 -> ('', _VllmRunner(
  (vllm_model): Qwen2ForCausalLM(
    (model): Qwen2Model(
      (embed_tokens): VocabParallelEmbedding(num_embeddings=151936, embedding_dim=2048, org_vocab_size=151936, num_embeddings_padded=151936, tp_size=1)
      (layers): ModuleList(
        (0-35): 36 x Qwen2DecoderLayer(
          (self_attn): Qwen2Attention(
            (qkv_proj): MergedQKVParallelLinearWithLoRA(
              (base_layer): JaxQKVParallelLinear()
            )
            (o_proj): RowParallelLinearWithLoRA(
              (base_layer): JaxRowParallelLinear()
            )
            (rotary_emb): RotaryEmbedding(head_size=128, rotary_dim=128, max_position_embeddings=32768, base=1000000.0, is_neox_style=True)
            (attn): JaxAttention()
          )
          (mlp): Qwen2MLP(
            (gate_up_proj): MergedColumnParallelLinearWithLoRA(
              (base_layer): JaxMergedColumnParallelLinear()
            )
            (down_proj): RowParallelLinearWithLoRA(
              (base_layer): JaxRowParallelLinear()
            )
            (act_fn): SiluAndMul()
          )
          (input_layernorm): RMSNorm(hidden_size=2048, eps=1e-06)
          (post_attention_layernorm): RMSNorm(hidden_size=2048, eps=1e-06)
        )
      )
      (norm): RMSNorm(hidden_size=2048, eps=1e-06)
    )
    (lm_head): VocabParallelEmbedding(num_embeddings=151936, embedding_dim=2048, org_vocab_size=151936, num_embeddings_padded=151936, tp_size=1)
    (logits_processor): LogitsProcessor(vocab_size=151936, org_vocab_size=151936, scale=1.0, logits_as_input=False)
  )
))
1 -> ('vllm_model', Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): VocabParallelEmbedding(num_embeddings=151936, embedding_dim=2048, org_vocab_size=151936, num_embeddings_padded=151936, tp_size=1)
    (layers): ModuleList(
      (0-35): 36 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (qkv_proj): MergedQKVParallelLinearWithLoRA(
            (base_layer): JaxQKVParallelLinear()
          )
          (o_proj): RowParallelLinearWithLoRA(
            (base_layer): JaxRowParallelLinear()
          )
          (rotary_emb): RotaryEmbedding(head_size=128, rotary_dim=128, max_position_embeddings=32768, base=1000000.0, is_neox_style=True)
          (attn): JaxAttention()
        )
        (mlp): Qwen2MLP(
          (gate_up_proj): MergedColumnParallelLinearWithLoRA(
            (base_layer): JaxMergedColumnParallelLinear()
          )
          (down_proj): RowParallelLinearWithLoRA(
            (base_layer): JaxRowParallelLinear()
          )
          (act_fn): SiluAndMul()
        )
        (input_layernorm): RMSNorm(hidden_size=2048, eps=1e-06)
        (post_attention_layernorm): RMSNorm(hidden_size=2048, eps=1e-06)
      )
    )
    (norm): RMSNorm(hidden_size=2048, eps=1e-06)
  )
  (lm_head): VocabParallelEmbedding(num_embeddings=151936, embedding_dim=2048, org_vocab_size=151936, num_embeddings_padded=151936, tp_size=1)
  (logits_processor): LogitsProcessor(vocab_size=151936, org_vocab_size=151936, scale=1.0, logits_as_input=False)
))
2 -> ('vllm_model.model', Qwen2Model(
  (embed_tokens): VocabParallelEmbedding(num_embeddings=151936, embedding_dim=2048, org_vocab_size=151936, num_embeddings_padded=151936, tp_size=1)
  (layers): ModuleList(
    (0-35): 36 x Qwen2DecoderLayer(
      (self_attn): Qwen2Attention(
        (qkv_proj): MergedQKVParallelLinearWithLoRA(
          (base_layer): JaxQKVParallelLinear()
        )
        (o_proj): RowParallelLinearWithLoRA(
          (base_layer): JaxRowParallelLinear()
        )
        (rotary_emb): RotaryEmbedding(head_size=128, rotary_dim=128, max_position_embeddings=32768, base=1000000.0, is_neox_style=True)
        (attn): JaxAttention()
      )
      (mlp): Qwen2MLP(
        (gate_up_proj): MergedColumnParallelLinearWithLoRA(
          (base_layer): JaxMergedColumnParallelLinear()
        )
        (down_proj): RowParallelLinearWithLoRA(
          (base_layer): JaxRowParallelLinear()
        )
        (act_fn): SiluAndMul()
      )
      (input_layernorm): RMSNorm(hidden_size=2048, eps=1e-06)
      (post_attention_layernorm): RMSNorm(hidden_size=2048, eps=1e-06)
    )
  )
  (norm): RMSNorm(hidden_size=2048, eps=1e-06)
))
3 -> ('vllm_model.model.embed_tokens', VocabParallelEmbedding(num_embeddings=151936, embedding_dim=2048, org_vocab_size=151936, num_embeddings_padded=151936, tp_size=1))
4 -> ('vllm_model.model.layers', ModuleList(
  (0-35): 36 x Qwen2DecoderLayer(
    (self_attn): Qwen2Attention(
      (qkv_proj): MergedQKVParallelLinearWithLoRA(
        (base_layer): JaxQKVParallelLinear()
      )
      (o_proj): RowParallelLinearWithLoRA(
        (base_layer): JaxRowParallelLinear()
      )
      (rotary_emb): RotaryEmbedding(head_size=128, rotary_dim=128, max_position_embeddings=32768, base=1000000.0, is_neox_style=True)
      (attn): JaxAttention()
    )
    (mlp): Qwen2MLP(
      (gate_up_proj): MergedColumnParallelLinearWithLoRA(
        (base_layer): JaxMergedColumnParallelLinear()
      )
      (down_proj): RowParallelLinearWithLoRA(
        (base_layer): JaxRowParallelLinear()
      )
      (act_fn): SiluAndMul()
    )
    (input_layernorm): RMSNorm(hidden_size=2048, eps=1e-06)
    (post_attention_layernorm): RMSNorm(hidden_size=2048, eps=1e-06)
  )
))
5 -> ('vllm_model.model.layers.0', Qwen2DecoderLayer(
  (self_attn): Qwen2Attention(
    (qkv_proj): MergedQKVParallelLinearWithLoRA(
      (base_layer): JaxQKVParallelLinear()
    )
    (o_proj): RowParallelLinearWithLoRA(
      (base_layer): JaxRowParallelLinear()
    )
    (rotary_emb): RotaryEmbedding(head_size=128, rotary_dim=128, max_position_embeddings=32768, base=1000000.0, is_neox_style=True)
    (attn): JaxAttention()
  )
  (mlp): Qwen2MLP(
    (gate_up_proj): MergedColumnParallelLinearWithLoRA(
      (base_layer): JaxMergedColumnParallelLinear()
    )
    (down_proj): RowParallelLinearWithLoRA(
      (base_layer): JaxRowParallelLinear()
    )
    (act_fn): SiluAndMul()
  )
  (input_layernorm): RMSNorm(hidden_size=2048, eps=1e-06)
  (post_attention_layernorm): RMSNorm(hidden_size=2048, eps=1e-06)
))
6 -> ('vllm_model.model.layers.0.self_attn', Qwen2Attention(
  (qkv_proj): MergedQKVParallelLinearWithLoRA(
    (base_layer): JaxQKVParallelLinear()
  )
  (o_proj): RowParallelLinearWithLoRA(
    (base_layer): JaxRowParallelLinear()
  )
  (rotary_emb): RotaryEmbedding(head_size=128, rotary_dim=128, max_position_embeddings=32768, base=1000000.0, is_neox_style=True)
  (attn): JaxAttention()
))
7 -> ('vllm_model.model.layers.0.self_attn.qkv_proj', MergedQKVParallelLinearWithLoRA(
  (base_layer): JaxQKVParallelLinear()
))
8 -> ('vllm_model.model.layers.0.self_attn.qkv_proj.base_layer', JaxQKVParallelLinear())
9 -> ('vllm_model.model.layers.0.self_attn.o_proj', RowParallelLinearWithLoRA(
  (base_layer): JaxRowParallelLinear()
))
10 -> ('vllm_model.model.layers.0.self_attn.o_proj.base_layer', JaxRowParallelLinear())
11 -> ('vllm_model.model.layers.0.self_attn.rotary_emb', RotaryEmbedding(head_size=128, rotary_dim=128, max_position_embeddings=32768, base=1000000.0, is_neox_style=True))
12 -> ('vllm_model.model.layers.0.self_attn.attn', JaxAttention())
13 -> ('vllm_model.model.layers.0.mlp', Qwen2MLP(
  (gate_up_proj): MergedColumnParallelLinearWithLoRA(
    (base_layer): JaxMergedColumnParallelLinear()
  )
  (down_proj): RowParallelLinearWithLoRA(
    (base_layer): JaxRowParallelLinear()
  )
  (act_fn): SiluAndMul()
))
14 -> ('vllm_model.model.layers.0.mlp.gate_up_proj', MergedColumnParallelLinearWithLoRA(
  (base_layer): JaxMergedColumnParallelLinear()
))
15 -> ('vllm_model.model.layers.0.mlp.gate_up_proj.base_layer', JaxMergedColumnParallelLinear())
16 -> ('vllm_model.model.layers.0.mlp.down_proj', RowParallelLinearWithLoRA(
  (base_layer): JaxRowParallelLinear()
))
17 -> ('vllm_model.model.layers.0.mlp.down_proj.base_layer', JaxRowParallelLinear())
18 -> ('vllm_model.model.layers.0.mlp.act_fn', SiluAndMul())
19 -> ('vllm_model.model.layers.0.input_layernorm', RMSNorm(hidden_size=2048, eps=1e-06))
20 -> ('vllm_model.model.layers.0.post_attention_layernorm', RMSNorm(hidden_size=2048, eps=1e-06))
21 -> ('vllm_model.model.layers.1', Qwen2DecoderLayer(
...
545 -> ('vllm_model.model.layers.35.post_attention_layernorm', RMSNorm(hidden_size=2048, eps=1e-06))
546 -> ('vllm_model.model.norm', RMSNorm(hidden_size=2048, eps=1e-06))
547 -> ('vllm_model.logits_processor', LogitsProcessor(vocab_size=151936, org_vocab_size=151936, scale=1.0, logits_as_input=False))
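For reference, a hedged sketch (using the same hypothetical `runner` handle as above, not a vLLM API) of how the dotted names in this listing map back to submodules, and how the LoRA-wrapped projections could be filtered out:

# Fetch a submodule back by the dotted name printed above.
qkv = runner.get_submodule("vllm_model.model.layers.0.self_attn.qkv_proj")
print(type(qkv).__name__, type(qkv.base_layer).__name__)
# MergedQKVParallelLinearWithLoRA JaxQKVParallelLinear

# List every LoRA wrapper by class-name suffix (a convenience check, not an API):
lora_wrapped = {
    name: type(module).__name__
    for name, module in runner.named_modules()
    if type(module).__name__.endswith("WithLoRA")
}
for name, cls in sorted(lora_wrapped.items()):
    print(name, "->", cls)  # qkv_proj, o_proj, gate_up_proj, down_proj in each of the 36 layers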