Tensor Parallelism Knowledge Compilation

Core Understanding

Tensor Parallelism (TP) shards individual tensors across devices, following the Megatron-LM pattern (a shape sketch follows the list below):

  • Column-wise: Input projections (q/k/v, gate/up)
  • Row-wise: Output projections (o_proj, down_proj)
  • Sequence Parallel: Shards activations on the sequence dimension for memory savings
  • Loss Parallel: Keeps logits sharded on the vocab dimension for efficient cross-entropy
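
A minimal single-process sketch of why the column-then-row split works (pure PyTorch, two simulated shards; the ReLU MLP and all names here are illustrative, not the gist's Qwen2 SwiGLU): the first weight is split by columns, the second by rows, each "rank" computes a partial output, and one all-reduce (here a plain sum) recovers the exact unsharded result.

import torch

torch.manual_seed(0)
X = torch.randn(4, 8)            # (tokens, hidden)
W1 = torch.randn(8, 16)          # "up" projection, column-sharded
W2 = torch.randn(16, 8)          # "down" projection, row-sharded

ref = torch.relu(X @ W1) @ W2    # unsharded reference

W1_shards = W1.chunk(2, dim=1)   # each (8, 8)  -> ColwiseParallel
W2_shards = W2.chunk(2, dim=0)   # each (8, 8)  -> RowwiseParallel

# Each simulated rank computes a partial result; summing them plays the role of the all-reduce
partials = [torch.relu(X @ w1) @ w2 for w1, w2 in zip(W1_shards, W2_shards)]
assert torch.allclose(sum(partials), ref, atol=1e-5)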

Working Implementation

1. Basic Tensor Parallel Plan

from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel

def tp_plan_block() -> dict:
  return {
    # Attention projections
    "self_attn.q_proj": ColwiseParallel(),
    "self_attn.k_proj": ColwiseParallel(), 
    "self_attn.v_proj": ColwiseParallel(),
    "self_attn.o_proj": RowwiseParallel(),
    
    # MLP projections  
    "mlp.gate_proj": ColwiseParallel(),
    "mlp.up_proj":   ColwiseParallel(),
    "mlp.down_proj": RowwiseParallel(),
  }
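
A hedged usage sketch for this plan (assumes a torchrun launch with 2 processes and a HuggingFace-style model whose decoder blocks live under model.model.layers):

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older PyTorch
from torch.distributed.tensor.parallel import parallelize_module

tp_mesh = init_device_mesh("cuda", (2,))          # 1-D TP mesh, world_size == 2
blk = model.model.layers[0]                       # one decoder block
parallelize_module(blk, tp_mesh, tp_plan_block())

# The projection weights are now DTensors: ColwiseParallel shards nn.Linear weight on
# dim 0 (output features), RowwiseParallel on dim 1 (input features).
assert isinstance(blk.self_attn.q_proj.weight, DTensor)
print(blk.self_attn.q_proj.weight.placements)     # e.g. (Shard(dim=0),)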

2. Sequence Parallel Plan (Fixed)

from torch.distributed.tensor import Shard, Replicate  # torch.distributed._tensor on older PyTorch
from torch.distributed.tensor.parallel import (
  ColwiseParallel, RowwiseParallel, SequenceParallel, PrepareModuleInput,
)

def tp_plan_block() -> dict:
  return {
    # Sequence-parallel norms (output Shard(1))
    "input_layernorm": SequenceParallel(),
    "post_attention_layernorm": SequenceParallel(),
    
    # Gather sequence-sharded to replicated for attention
    "self_attn": PrepareModuleInput(
      input_kwarg_layouts={"hidden_states": Shard(1)},
      desired_input_kwarg_layouts={"hidden_states": Replicate()},
    ),
    "self_attn.q_proj": ColwiseParallel(),
    "self_attn.k_proj": ColwiseParallel(),
    "self_attn.v_proj": ColwiseParallel(),
    "self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)),
    
    # MLP: sequence-sharded input -> replicated
    "mlp": PrepareModuleInput(
      input_layouts=(Shard(1),),
      desired_input_layouts=(Replicate(),),
    ),
    "mlp.gate_proj": ColwiseParallel(),
    "mlp.up_proj":   ColwiseParallel(),
    "mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)),
  }
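
The layout flow in this plan is Shard(1) → Replicate() → Shard(1): the norms emit sequence-sharded activations, PrepareModuleInput all-gathers them before attention/MLP, and the Shard(1) output layouts on o_proj/down_proj shard the partial outputs back onto the sequence dimension. A standalone sketch of the redistribution itself (assumes a 2-rank torchrun launch; sizes are illustrative):

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard, Replicate

mesh = init_device_mesh("cuda", (2,))
torch.manual_seed(0)                                  # same global tensor on every rank
x = torch.randn(1, 8, 16, device="cuda")              # (batch, seq, hidden)

x_sp = distribute_tensor(x, mesh, [Shard(1)])         # as emitted by a SequenceParallel norm
x_rep = x_sp.redistribute(mesh, [Replicate()])        # all-gather before attention / MLP
x_back = x_rep.redistribute(mesh, [Shard(1)])         # back to sequence-sharded

print(x_sp.to_local().shape, x_rep.to_local().shape)  # (1, 4, 16) vs (1, 8, 16) per rank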

3. Full Model Parallelization

from torch.distributed.tensor.parallel import parallelize_module

def tp_full_model(model, tp_mesh):
  # Parallelize decoder layers
  for blk in model.model.layers:
    parallelize_module(blk, tp_mesh, tp_plan_block())
  
  # Parallelize top-level components
  top_plan = {
    "model.embed_tokens": RowwiseParallel(
      input_layouts=Replicate(),
      output_layouts=Shard(1),  # for sequence parallel
    ),
    "model.norm": SequenceParallel(),  # final norm
    "lm_head": ColwiseParallel(use_local_output=False),  # Keep DTensor
  }
  parallelize_module(model, tp_mesh, top_plan)
  return model
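
Why RowwiseParallel on model.embed_tokens: for an nn.Embedding, row-wise sharding splits the vocabulary dimension, so each rank looks up only the token ids that fall in its vocab slice (contributing zeros otherwise) and the partial results are summed; with output_layouts=Shard(1) that sum lands sequence-sharded. A single-process sketch of the partial-lookup idea (pure PyTorch, illustrative sizes):

import torch

torch.manual_seed(0)
vocab, hidden, tp = 8, 4, 2
table = torch.randn(vocab, hidden)               # full embedding table
ids = torch.tensor([[1, 5, 7, 2]])               # (batch, seq)

ref = table[ids]                                 # unsharded lookup

partials = []
for rank, shard in enumerate(table.chunk(tp, dim=0)):        # vocab-dim shards
  lo = rank * (vocab // tp)
  local_ids = ids - lo
  in_range = (local_ids >= 0) & (local_ids < vocab // tp)
  partials.append(shard[local_ids.clamp(0, vocab // tp - 1)] * in_range.unsqueeze(-1))

assert torch.allclose(sum(partials), ref)        # summing partials == the all-reduce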

Key Issues & Solutions (RESOLVED)

Issue 1: PrepareModuleInput Length Mismatch ✅ FIXED

Problem: PrepareModuleInput raised "module inputs and input_layouts should have same length!"

Root Cause: PrepareModuleInput by default only handles positional arguments, but Qwen2Attention.forward() is called with keyword arguments. The hook sees len(inputs) == 0 (empty positional tuple) but len(input_layouts) == 1.

Solution: Use input_kwarg_layouts and desired_input_kwarg_layouts for keyword arguments:

"self_attn": PrepareModuleInput(
  input_kwarg_layouts={"hidden_states": Shard(1)},
  desired_input_kwarg_layouts={"hidden_states": Replicate()},
)

Issue 2: Sequence Parallel + Rotary Embeddings Conflict ✅ FIXED

Problem: SequenceParallel() shards the sequence dimension, causing a rotary-embedding size mismatch.

Error: "The size of tensor a (896) must match the size of tensor b (897)"

Root Cause: Sequence Parallel requires the sequence length to be evenly divisible by the TP world size. When seq_len % tp_world_size != 0, DTensor's all-gather silently drops/pads tokens, causing a length mismatch with the full-length rotary caches.

Solution: Pad all inputs to make sequence length divisible by TP world size:

import torch

def pad_to_multiple(x, multiple, value=0):
  # Pad the last (sequence) dimension so its length becomes a multiple of `multiple`.
  pad = (-x.size(-1)) % multiple
  if pad:
    x = torch.nn.functional.pad(x, (0, pad), value=value)
  return x

# Labels are padded with -100 so the extra positions are ignored by the loss.
for key in ("input_ids", "labels", "position_ids"):
  sample[key] = pad_to_multiple(sample[key], tp_mesh.size(0),
                                value=(-100 if key == "labels" else 0))

Source: PyTorch docs state: "Sequence Parallel currently assumes the input sequence length can be evenly divided by the number of devices in the mesh" and "When using Shard(dim) as the input/output layouts... we assume the input/output activation tensors are evenly sharded on the tensor dimension dim"
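
A quick sanity check of the helper, using the length from the error above and TP world size 2:

x = torch.randint(0, 1000, (1, 897))   # seq_len 897 is not divisible by tp_size=2
x = pad_to_multiple(x, 2)
assert x.size(-1) == 898 and x.size(-1) % 2 == 0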

Issue 3: DTensor + All-Reduce Incompatibility ✅ FIXED

Problem: The loss coming out of the TP model is a DTensor, while all_reduce() expects a regular tensor.

Error: "found no DeviceMesh from dtensor args"

Solution: Convert the DTensor to a local tensor (to_local()) before collectives.
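
A minimal sketch of that conversion (assumes an initialized process group; loss_for_allreduce is an illustrative helper name, and full_tensor()/to_local() are the standard DTensor materialization calls):

import torch.distributed as dist
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older PyTorch

def loss_for_allreduce(loss):
  # Collectives such as dist.all_reduce expect plain torch.Tensors.
  if isinstance(loss, DTensor):
    loss = loss.full_tensor()   # or loss.to_local() to keep only the local shard
  return loss.detach().clone()

avg_loss = loss_for_allreduce(loss)
dist.all_reduce(avg_loss, op=dist.ReduceOp.AVG)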

Working Configuration

from torch.distributed.device_mesh import init_device_mesh

# 2D mesh: FSDP=4 (outer dim, inter-host), TP=2 (inner dim, intra-host)
mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("fsdp", "tp"))
fsdp_mesh, tp_mesh = mesh["fsdp"], mesh["tp"]

# Pad inputs for sequence parallel
for key in ("input_ids", "labels", "position_ids"):
  sample[key] = pad_to_multiple(sample[key], tp_mesh.size(), value=(-100 if key=="labels" else 0))

# Apply TP + FSDP
model = tp_full_model(model, tp_mesh)
model = wrap_fsdp2(model, fsdp_mesh)

Loss Parallel Integration

# In model parallelization
"lm_head": ColwiseParallel(use_local_output=False)  # Keep DTensor

# In loss computation
from torch.distributed.tensor.parallel import loss_parallel

with loss_parallel():
    loss = F.cross_entropy(logits, labels, reduction='none')
    loss = loss.mean()  # reduce to a scalar (in practice a masked mean) before backward
    loss.backward()     # the backward call must also run inside the loss_parallel() context
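
For a causal LM the lm_head output is a vocab-sharded DTensor of shape (batch, seq, vocab), while F.cross_entropy expects classes on dim 1, so the usual pattern is to flatten batch and sequence first; a hedged sketch (variable names are illustrative):

import torch.nn.functional as F
from torch.distributed.tensor.parallel import loss_parallel

with loss_parallel():
  # -100-padded label positions are ignored via cross_entropy's default ignore_index
  loss = F.cross_entropy(logits.flatten(0, 1),   # (batch*seq, vocab)
                         labels.flatten(0, 1))   # (batch*seq,)
  loss.backward()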

What Works Now ✅

  • ✅ Classical Megatron TP (column/row sharding)
  • ✅ Sequence parallel with proper input padding
  • ✅ PrepareModuleInput with keyword argument layouts
  • ✅ Loss parallel with DTensor logits
  • ✅ FSDP2 integration with activation checkpointing
  • ✅ 2D parallelism (TP intra-host, FSDP inter-host)

Requirements for Sequence Parallel

  1. Sequence length must be divisible by TP world size - pad inputs if needed
  2. Use input_kwarg_layouts for keyword arguments - most HuggingFace models use kwargs
  3. Proper layout transformations - Shard(1) → Replicate() → Shard(1) flow

Final Architecture

┌─────────────────────────────────────────┐
│ 2D DeviceMesh: (FSDP=4, TP=2)           │
├─────────────────────────────────────────┤
│ Input Padding (seq_len % tp_size == 0)  │
├─────────────────────────────────────────┤
│ FSDP2 (inter-host)                      │
│ ├─ Activation Checkpointing             │
│ ├─ Mixed Precision (bf16)               │
│ └─ Per-layer + Full Model Wrapping      │
├─────────────────────────────────────────┤
│ Tensor + Sequence Parallel (intra-host) │
│ ├─ Column: q/k/v/gate/up projections    │
│ ├─ Row: o_proj/down_proj                │
│ ├─ Sequence: norm layers (Shard(1))     │
│ ├─ PrepareModuleInput: layout transforms│
│ └─ Loss Parallel: vocab-sharded logits  │
└─────────────────────────────────────────┘