vanilla TP
import logging

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    parallelize_module,
)

logger = logging.getLogger(__name__)


def apply_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    loss_parallel: bool,
    enable_float8_tensorwise_tp: bool,
    enable_async_tp: bool,
):
    """Apply tensor parallelism.

    NOTE: loss_parallel is unused in this vanilla variant; the final linear's
    output is replicated rather than kept sharded for a parallel loss.
    """
    # 1. Parallelize the embedding and replicate its outputs (which are the
    #    first transformer block's inputs)
    # 2. Parallelize the final linear output layer and replicate its output
    parallelize_module(
        model,
        tp_mesh,
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Replicate(),
            ),
            "output": ColwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Replicate(),
                use_local_output=True,
            ),
        },
    )

    # Parallel styles used for transformer block linear weights and their
    # inputs may be different for float8 linears with tensorwise scaling.
    if enable_float8_tensorwise_tp:
        # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
        from torchao.float8.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        rowwise_parallel, colwise_parallel, prepare_module_input = (
            Float8RowwiseParallel,
            Float8ColwiseParallel,
            PrepareFloat8ModuleInput,
        )
    else:
        rowwise_parallel, colwise_parallel, prepare_module_input = (
            RowwiseParallel,
            ColwiseParallel,
            PrepareModuleInput,
        )
    # Apply tensor parallelism to every transformer block.
    # NOTE: prepare_module_input is not used by this vanilla layer plan.
    # NOTE: Sequence Parallel is not applied here. At the cost of model code
    #       changes, Sequence Parallel can be accelerated by folding (and
    #       unfolding) the batch and sequence dimensions; examples can be
    #       found at https://github.com/pytorch/torchtitan/pull/437
    for transformer_block in model.layers.values():
        layer_plan = {
            "attention.wq": colwise_parallel(),
            "attention.wk": colwise_parallel(),
            "attention.wv": colwise_parallel(),
            "attention.wo": rowwise_parallel(),
            "feed_forward.w1": colwise_parallel(),
            "feed_forward.w2": rowwise_parallel(),
            "feed_forward.w3": colwise_parallel(),
        }
        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=layer_plan,
        )

    if enable_async_tp:
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

        torch._inductor.config._micro_pipeline_tp = True
        enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info(
        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
        "Tensor Parallelism to the model"
    )
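

Usage sketch (not part of the original gist): a minimal, hedged example of how apply_tp might be called under a torchrun launch. Transformer and model_args are hypothetical stand-ins for a Llama-style model whose submodule names (tok_embeddings, output, layers.*.attention.wq/wk/wv/wo, layers.*.feed_forward.w1/w2/w3) match the plan above; init_device_mesh is the public PyTorch API for building the 1-D tensor-parallel mesh.

# Minimal usage sketch; assumes `torchrun --nproc_per_node=8 train.py` and a
# hypothetical Llama-style `Transformer` whose module names match the plan above.
import torch
from torch.distributed.device_mesh import init_device_mesh


def main():
    tp_degree = 8  # assumed: TP spans all local GPUs
    # Build a 1-D device mesh over the TP ranks; DeviceMesh initializes the
    # default process group from the torchrun environment if it isn't already.
    tp_mesh = init_device_mesh("cuda", (tp_degree,), mesh_dim_names=("tp",))

    model = Transformer(model_args).cuda()  # hypothetical model and args

    apply_tp(
        model,
        tp_mesh,
        loss_parallel=False,              # unused in this vanilla variant
        enable_float8_tensorwise_tp=False,
        enable_async_tp=False,
    )
    # The block linears are now DTensor-sharded across the TP group; forward
    # and backward passes insert the required collectives automatically.


if __name__ == "__main__":
    main()

Two hedged notes on the optional flags: the float8 path requires torchao to be installed, and `_micro_pipeline_tp` is an inductor setting, so async TP should only be expected to take effect when the model is run under torch.compile.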