Tensor Parallelism (TP) shards individual tensors across devices, following the Megatron-LM pattern:
- Column-wise: Input projections (q/k/v, gate/up)
- Row-wise: Output projections (o_proj, down_proj)
- Sequence Parallel: Shards activations along the sequence dimension for memory savings
- Loss Parallel: Keeps logits sharded on the vocab dimension for efficient cross-entropy
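The column-then-row pairing works because splitting the first matmul by output columns and the second by input rows reproduces the unsharded result after a single sum, which is exactly the all-reduce that RowwiseParallel inserts. A minimal single-process sketch of that identity with plain torch tensors (shapes are illustrative):

import torch

x = torch.randn(4, 8)        # [tokens, hidden]
w_up = torch.randn(8, 16)    # column-parallel weight (split on output dim)
w_down = torch.randn(16, 8)  # row-parallel weight (split on input dim)

ref = (x @ w_up) @ w_down    # unsharded reference

# "Two devices": split w_up by columns, w_down by rows
up0, up1 = w_up.chunk(2, dim=1)
down0, down1 = w_down.chunk(2, dim=0)

# Each shard computes a partial result; summing the partials is the
# all-reduce that RowwiseParallel performs in the real TP case.
partial = (x @ up0) @ down0 + (x @ up1) @ down1
assert torch.allclose(ref, partial, atol=1e-5)

An elementwise activation between the two matmuls acts on each column shard independently, which is why the same pairing carries over to the gate/up → down MLP; for attention the split is per head, so q/k/v go column-wise and o_proj row-wise.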
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

def tp_plan_block() -> dict:
    """Classical Megatron TP plan for one decoder block (no sequence parallel yet)."""
    return {
        # Attention projections
        "self_attn.q_proj": ColwiseParallel(),
        "self_attn.k_proj": ColwiseParallel(),
        "self_attn.v_proj": ColwiseParallel(),
        "self_attn.o_proj": RowwiseParallel(),
        # MLP projections
        "mlp.gate_proj": ColwiseParallel(),
        "mlp.up_proj": ColwiseParallel(),
        "mlp.down_proj": RowwiseParallel(),
    }
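Before moving to the sequence-parallel variant, a quick sanity check of what this basic plan does to a block's weights. This is a hypothetical helper (it assumes a torchrun launch, a Qwen2-style decoder block, and an already-built 1D tp_mesh): nn.Linear stores its weight as [out_features, in_features], so ColwiseParallel shards on dim 0 and RowwiseParallel on dim 1.

from torch.distributed.tensor import DTensor, Shard
from torch.distributed.tensor.parallel import parallelize_module

def check_block_placements(block, tp_mesh):
    # Hypothetical helper: apply the basic plan and verify the resulting shardings.
    parallelize_module(block, tp_mesh, tp_plan_block())
    q_w = block.self_attn.q_proj.weight   # column-wise -> sharded on the output dim
    o_w = block.self_attn.o_proj.weight   # row-wise    -> sharded on the input dim
    assert isinstance(q_w, DTensor) and q_w.placements == (Shard(0),)
    assert isinstance(o_w, DTensor) and o_w.placements == (Shard(1),)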
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import PrepareModuleInput, SequenceParallel

def tp_plan_block() -> dict:
    """TP + Sequence Parallel plan for one decoder block (extends the basic plan above)."""
    return {
        # Sequence-parallel norms (output Shard(1))
        "input_layernorm": SequenceParallel(),
        "post_attention_layernorm": SequenceParallel(),
        # Gather sequence-sharded hidden states to replicated for attention
        "self_attn": PrepareModuleInput(
            input_kwarg_layouts={"hidden_states": Shard(1)},
            desired_input_kwarg_layouts={"hidden_states": Replicate()},
        ),
        "self_attn.q_proj": ColwiseParallel(),
        "self_attn.k_proj": ColwiseParallel(),
        "self_attn.v_proj": ColwiseParallel(),
        "self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)),
        # MLP: sequence-sharded input -> replicated
        "mlp": PrepareModuleInput(
            input_layouts=(Shard(1),),
            desired_input_layouts=(Replicate(),),
        ),
        "mlp.gate_proj": ColwiseParallel(),
        "mlp.up_proj": ColwiseParallel(),
        "mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
    }
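Per block, the activation layout round-trips Shard(1) → Replicate() → Shard(1): the sequence-parallel norms emit sequence-sharded hidden states, PrepareModuleInput all-gathers them before attention/MLP, and the row-parallel output projections re-shard on the sequence dimension. A standalone sketch of the gather step that PrepareModuleInput performs (assuming an existing tp_mesh of size 2 and Qwen2-0.5B-like sizes):

import torch
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

# Sequence-sharded hidden states: [batch=2, seq=128, hidden=896], split on dim 1
h = distribute_tensor(torch.randn(2, 128, 896), tp_mesh, placements=[Shard(1)])
# What PrepareModuleInput does at the self_attn / mlp boundary: all-gather to replicated
h_full = h.redistribute(placements=[Replicate()])
assert h_full.to_local().shape == (2, 128, 896)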
from torch.distributed.tensor.parallel import parallelize_module

def tp_full_model(model, tp_mesh):
    # Parallelize every decoder layer with the block plan above
    for blk in model.model.layers:
        parallelize_module(blk, tp_mesh, tp_plan_block())
    # Parallelize top-level components
    top_plan = {
        "model.embed_tokens": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),  # sequence-sharded output for sequence parallel
        ),
        "model.norm": SequenceParallel(),  # final norm
        "lm_head": ColwiseParallel(use_local_output=False),  # keep logits as a DTensor
    }
    parallelize_module(model, tp_mesh, top_plan)
    return model
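A hedged end-to-end usage sketch for the TP-only case (the checkpoint name and the torchrun --nproc-per-node=2 launch are assumptions, not part of the original setup):

import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # set by torchrun
tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B", torch_dtype=torch.bfloat16)
model = tp_full_model(model.to("cuda"), tp_mesh)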
Problem: PrepareModuleInput caused "module inputs and input_layouts should have same length!"

Root Cause: PrepareModuleInput by default only handles positional arguments, but Qwen2Attention.forward() is called with keyword arguments. The hook sees len(inputs) == 0 (empty positional tuple) but len(input_layouts) == 1.

Solution: Use input_kwarg_layouts and desired_input_kwarg_layouts for keyword arguments:
"self_attn": PrepareModuleInput(
input_kwarg_layouts={"hidden_states": Shard(1)},
desired_input_kwarg_layouts={"hidden_states": Replicate()},
)
Problem: SequenceParallel() shards the sequence dimension, causing a rotary-embedding size mismatch.

Error: "The size of tensor a (896) must match the size of tensor b (897)"

Root Cause: Sequence Parallel requires the sequence length to be evenly divisible by the TP world size. When seq_len % tp_world_size != 0, DTensor's all-gather silently drops/pads tokens, causing a length mismatch with the full-length rotary caches.

Solution: Pad all inputs so the sequence length is divisible by the TP world size:
import torch

def pad_to_multiple(x, multiple, value=0):
    # Right-pad the last (sequence) dimension up to the next multiple
    pad = (-x.size(-1)) % multiple
    if pad:
        x = torch.nn.functional.pad(x, (0, pad), value=value)
    return x

for key in ("input_ids", "labels", "position_ids"):
    sample[key] = pad_to_multiple(
        sample[key], tp_mesh.size(0), value=(-100 if key == "labels" else 0)
    )
Source: PyTorch docs state: "Sequence Parallel currently assumes the input sequence length can be evenly divided by the number of devices in the mesh" and "When using Shard(dim) as the input/output layouts... we assume the input/output activation tensors are evenly sharded on the tensor dimension dim"
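If the batch also carries an attention_mask (as HuggingFace collators typically produce), it should be padded the same way so the appended positions stay masked; a small sketch under that assumption:

if "attention_mask" in sample:
    # Pad with 0 so the extra positions are excluded from attention
    sample["attention_mask"] = pad_to_multiple(sample["attention_mask"], tp_mesh.size(0), value=0)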
Problem: The loss from the TP model is a DTensor, while all_reduce() expects a regular tensor.

Error: "found no DeviceMesh from dtensor args"

Solution: Convert the DTensor to a plain local tensor before any collectives, using DTensor.to_local().
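A minimal sketch of that conversion for metric logging (assuming the per-token loss produced under loss_parallel() further below and an initialized process group):

import torch.distributed as dist
from torch.distributed.tensor import DTensor

loss_val = loss.detach()
if isinstance(loss_val, DTensor):
    loss_val = loss_val.to_local()  # plain torch.Tensor with this rank's (replicated) values
loss_val = loss_val.mean()
dist.all_reduce(loss_val, op=dist.ReduceOp.AVG)  # average across ranks for logging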
from torch.distributed.device_mesh import init_device_mesh

# 2D mesh: FSDP=4 (inter-host), TP=2 (intra-host)
mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("fsdp", "tp"))
fsdp_mesh, tp_mesh = mesh["fsdp"], mesh["tp"]

# Pad inputs so seq_len is divisible by the TP world size (sequence parallel requirement)
for key in ("input_ids", "labels", "position_ids"):
    sample[key] = pad_to_multiple(
        sample[key], tp_mesh.size(), value=(-100 if key == "labels" else 0)
    )

# Apply TP (intra-host) first, then FSDP2 (inter-host)
model = tp_full_model(model, tp_mesh)
model = wrap_fsdp2(model, fsdp_mesh)
# In the model parallelization (see tp_full_model above)
"lm_head": ColwiseParallel(use_local_output=False)  # keep logits as a vocab-sharded DTensor

# In the loss computation
import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel

with loss_parallel():
    # Flatten to [tokens, vocab] / [tokens] and reduce to a scalar before backward
    loss = F.cross_entropy(logits.flatten(0, 1), labels.flatten(), reduction="none")
    loss.mean().backward()  # the backward call also needs to run inside the loss_parallel() context
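Because use_local_output=False keeps lm_head's output as a vocab-sharded DTensor, any code path that needs the materialized logits outside of loss_parallel() has to gather them explicitly; a hedged one-liner for that case:

full_logits = logits.full_tensor()  # all-gathers the vocab shards into a plain [batch, seq, vocab] tensor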
- ✅ Classical Megatron TP (column/row sharding)
- ✅ Sequence parallel with proper input padding
- ✅ PrepareModuleInput with keyword argument layouts
- ✅ Loss parallel with DTensor logits
- ✅ FSDP2 integration with activation checkpointing
- ✅ 2D parallelism (TP intra-host, FSDP inter-host)
- Sequence length must be divisible by the TP world size: pad inputs if needed
- Use input_kwarg_layouts for keyword arguments: most HuggingFace models call submodules with kwargs
- Proper layout transformations: Shard(1) → Replicate() → Shard(1) flow
┌─────────────────────────────────────────┐
│ 2D DeviceMesh: (FSDP=4, TP=2) │
├─────────────────────────────────────────┤
│ Input Padding (seq_len % tp_size == 0) │
├─────────────────────────────────────────┤
│ FSDP2 (inter-host) │
│ ├─ Activation Checkpointing │
│ ├─ Mixed Precision (bf16) │
│ └─ Per-layer + Full Model Wrapping │
├─────────────────────────────────────────┤
│ Tensor + Sequence Parallel (intra-host) │
│ ├─ Column: q/k/v/gate/up projections │
│ ├─ Row: o_proj/down_proj │
│ ├─ Sequence: norm layers (Shard(1)) │
│ ├─ PrepareModuleInput: layout transforms│
│ └─ Loss Parallel: vocab-sharded logits │
└─────────────────────────────────────────┘