wanchaol / tp.py
Created May 31, 2025 21:33
vanilla TP
def apply_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    loss_parallel: bool,
    enable_float8_tensorwise_tp: bool,
    enable_async_tp: bool,
):
    """Apply tensor parallelism."""
    # 1. Parallelize the embedding and shard its outputs (which are the first
    # transformer block's inputs)
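    # A hedged sketch of what this embedding step could look like with the
    # torch.distributed.tensor.parallel API; "tok_embeddings" is an assumed
    # module name, and model / tp_mesh are this function's arguments.
    from torch.distributed._tensor import Replicate, Shard
    from torch.distributed.tensor.parallel import RowwiseParallel, parallelize_module

    parallelize_module(
        model,
        tp_mesh,
        {
            # shard the embedding table row-wise and emit sequence-sharded
            # activations for the first transformer block
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
        },
    )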
import copy
from dataclasses import dataclass
from typing import Callable, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor import (
    DTensor,
    Shard,
    distribute_tensor,
    init_device_mesh,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


class TestDTensorCompile(DTensorTestBase):
    def setUp(self):
        super().setUp()
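
    # A hedged sketch of the kind of test this scaffolding supports; the method
    # name and body are illustrative and not from the gist.
    @with_comms
    def test_distribute_tensor_basic(self):
        # build_device_mesh() comes from DTensorTestBase and spans all ranks
        mesh = self.build_device_mesh()
        torch.manual_seed(0)
        dt = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])
        self.assertEqual(dt.full_tensor().shape, torch.Size([8, 8]))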
wanchaol / test_t.py
Created April 20, 2023 22:47
tracing val issue for fused_adam
import torch
from torch._meta_registrations import register_meta

aten = torch.ops.aten


@register_meta([aten._fused_adam.default])
def meta__fused_adam(
    params,
addmm(bias, input, weight):
  bias: Replicate()
  output = input @ weight
  input Shard(1), weight Shard(0) -> output is a partial tensor (each rank holds a partial sum)
  partial output -> all-reduce to Replicate()?
  output + bias: if the bias is added while the output is still partial, it gets counted once per rank after the reduce, so either add it on a single rank or add it after replication
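A hedged sketch of these rules with DTensor on a 1-D mesh (the 2-rank mesh, shapes, and seeding are illustrative; assumes a 2-rank distributed run):

import torch
from torch.distributed._tensor import DeviceMesh, Replicate, Shard, distribute_tensor

mesh = DeviceMesh("cuda", [0, 1])  # 1-D mesh over 2 ranks

torch.manual_seed(0)  # same global tensors on every rank
inp = distribute_tensor(torch.randn(4, 8), mesh, [Shard(1)])      # shard the K dim
weight = distribute_tensor(torch.randn(8, 16), mesh, [Shard(0)])  # shard the K dim
bias = distribute_tensor(torch.randn(16), mesh, [Replicate()])

out = torch.mm(inp, weight)                   # Partial(): each rank holds a partial sum
out = out.redistribute(mesh, [Replicate()])   # all-reduce the partial sums
out = out + bias                              # bias added after replication, so it is counted once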
mesh = DeviceMesh("cuda", [[0, 1], [2, 3]])
placements describe how the tensor is laid out over the device mesh, one placement per mesh dimension.
Example: a global torch.randn(12, 8) tensor on this 2-dimensional mesh of size (2, 2),
with placements = [Shard(0), Replicate()], gives every rank a (6, 8) local shard:
ranks [0, 1] hold rows 0..5 and ranks [2, 3] hold rows 6..11, replicated along mesh dim 1.
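A hedged sketch of that placement (assumes a 4-rank run and the same mesh layout as the note above):

import torch
from torch.distributed._tensor import DeviceMesh, Replicate, Shard, distribute_tensor

mesh = DeviceMesh("cuda", [[0, 1], [2, 3]])   # 2-D mesh, size (2, 2)

torch.manual_seed(0)
big = torch.randn(12, 8)                      # global tensor, identical on all ranks
# Shard(0) over mesh dim 0, Replicate() over mesh dim 1:
# ranks 0 and 1 hold rows 0..5, ranks 2 and 3 hold rows 6..11
dt = distribute_tensor(big, mesh, [Shard(0), Replicate()])
print(dt.to_local().shape)                    # torch.Size([6, 8]) on every rank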
wanchaol / gist:7439039e5c9aa5d51d0adb137be3c361
Last active October 20, 2020 22:43
pybind class_ friend issue
// forward declaration of pybind11::class_
namespace pybind11 {
  template <typename, typename...>
  class class_;
}

// try to friend the pybind11::class_
template <typename TTarget>
class intrusive_ptr {
  // befriend every instantiation of the forward-declared class template
  template <typename, typename...>
  friend class pybind11::class_;
};
Perf profile before https://github.com/pytorch/pytorch/pull/33157
---------------------  -----------------  ---------------  -------------  ----------  -------------  ----------------
Name                   Self CPU total %   Self CPU total   CPU total %    CPU total   CPU time avg   Number of Calls
---------------------  -----------------  ---------------  -------------  ----------  -------------  ----------------
add_                   94.03%             2223.890s        94.03%         2223.890s   55.625ms       39980
AddmmBackward          0.00%              79.386ms         2.57%          60.762s     10.127ms       6000
mm                     2.56%              60.601s          2.56%          60.601s     5.509ms        11000
EmbeddingBagBackward   0.00%
import torch


@torch.jit.interface
class ModuleInterface(torch.nn.Module):
    def my_method(self):
        # type: () -> Tensor
        pass


class MyScriptModule(torch.jit.ScriptModule):
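A hedged, self-contained sketch of how such a module interface is typically consumed: a concrete module implements the interface, and a container annotates its submodule attribute with the interface type so TorchScript checks calls against the interface only. The names Impl and Container and the method bodies are illustrative, and it uses torch.jit.script rather than the ScriptModule subclass above.

import torch
from torch import Tensor

@torch.jit.interface
class ModuleInterface(torch.nn.Module):
    def my_method(self) -> Tensor:
        pass

class Impl(torch.nn.Module):
    def my_method(self) -> Tensor:
        return torch.zeros(10)

class Container(torch.nn.Module):
    impl: ModuleInterface  # annotated with the interface, not the concrete type

    def __init__(self):
        super().__init__()
        self.impl = Impl()

    def forward(self) -> Tensor:
        return self.impl.my_method()

scripted = torch.jit.script(Container())
print(scripted().shape)  # torch.Size([10])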
import torch


@torch.jit.script
class MyClass(object):
    def my_method(self):
        # type: () -> Tensor
        return torch.randn(10)


@torch.jit.ignore
def mod_init():
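A minimal hedged sketch of how @torch.jit.ignore is typically used; the function body and the scripted caller are illustrative, not from the gist:

import torch
from torch import Tensor

@torch.jit.ignore
def mod_init() -> Tensor:
    # left as plain Python; the TorchScript compiler does not compile this body
    return torch.randn(3)

@torch.jit.script
def use_it() -> Tensor:
    # scripted code may still call the ignored function; it falls back to Python
    return mod_init() * 2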