@wanchaol
Created May 31, 2025 21:33
vanilla TP
import logging
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    parallelize_module,
)

logger = logging.getLogger(__name__)


def apply_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    loss_parallel: bool,
    enable_float8_tensorwise_tp: bool,
    enable_async_tp: bool,
):
    """Apply tensor parallelism."""
    # 1. Parallelize the embedding and shard its outputs (which are the first
    #    transformer block's inputs)
    # 2. Parallelize the root norm layer over the sequence dim
    # 3. Parallelize the final linear output layer
    parallelize_module(
        model,
        tp_mesh,
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Replicate(),
            ),
            "output": ColwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Replicate(),
                use_local_output=True,
            ),
        },
    )

    # Parallel styles used for transformer block linear weights and their
    # inputs may be different for float8 linears with tensorwise scaling.
    if enable_float8_tensorwise_tp:
        # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
        from torchao.float8.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        rowwise_parallel, colwise_parallel, prepare_module_input = (
            Float8RowwiseParallel,
            Float8ColwiseParallel,
            PrepareFloat8ModuleInput,
        )
    else:
        rowwise_parallel, colwise_parallel, prepare_module_input = (
            RowwiseParallel,
            ColwiseParallel,
            PrepareModuleInput,
        )

    # Apply tensor + sequence parallelism to every transformer block
    # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
    #       by folding (and unfolding) the batch dimension and the sequence dimension.
    #       Examples can be found at https://github.com/pytorch/torchtitan/pull/437
    for transformer_block in model.layers.values():
        layer_plan = {
            "attention.wq": colwise_parallel(),
            "attention.wk": colwise_parallel(),
            "attention.wv": colwise_parallel(),
            "attention.wo": rowwise_parallel(),
            "feed_forward.w1": colwise_parallel(),
            "feed_forward.w2": rowwise_parallel(),
            "feed_forward.w3": colwise_parallel(),
        }
        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=layer_plan,
        )

    if enable_async_tp:
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        torch._inductor.config._micro_pipeline_tp = True
        enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info(
        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
        "Tensor Parallelism to the model"
    )