Skip to content

Instantly share code, notes, and snippets.

import copy
from dataclasses import dataclass
from typing import Callable, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor import init_device_mesh
from torch.distributed._tensor import distribute_tensor, DTensor, Shard
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
import torch.distributed as dist
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)
class TestDTensorCompile(DTensorTestBase):
    """Test case built on the DTensor distributed test base.

    Only the fixture hook is defined here; the base class supplies the
    distributed/communication scaffolding used by individual tests.
    """

    def setUp(self):
        # Run the base class's per-test setup before anything else.
        super(TestDTensorCompile, self).setUp()
@wanchaol
wanchaol / test_t.py
Created April 20, 2023 22:47
tracing val issue for fused_adam
import torch
from torch._meta_registrations import register_meta
aten = torch.ops.aten
@register_meta([aten._fused_adam.default])
def meta__fused_adam(
params,
addmm(bias, input, weight)
bias: replicated
output = input * weight
input shard(1), weight shard(0) -> partial tensor
output -> partial -> replicated?
output + bias -> partial? only do on one rank
mesh = DeviceMesh("cuda", [[0, 1], [2, 3]])
placements -> describes how to place the tensor to the device mesh
torch.randn(12, 8)
device mesh dim = 2, size(2, 2)
placements = [Shard(shard_dim=0), Replicate()]
torch.randn(6, 8) -> [0, 1]
@wanchaol
wanchaol / gist:7439039e5c9aa5d51d0adb137be3c361
Last active October 20, 2020 22:43
pybind class_ friend issue
// forward declaration of pybind11::class_
namespace pybind11 {
  template <typename, typename...>
  class class_;
}

// try to friend the pybind11::class_
template<typename TTarget>
class intrusive_ptr {
Perf profile before https://github.com/pytorch/pytorch/pull/33157
-------------------------------------------- --------------- --------------- --------------- --------------- --------------- ---------------
Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg Number of Calls
-------------------------------------------- --------------- --------------- --------------- --------------- --------------- ---------------
add_ 94.03% 2223.890s 94.03% 2223.890s 55.625ms 39980
AddmmBackward 0.00% 79.386ms 2.57% 60.762s 10.127ms 6000
mm 2.56% 60.601s 2.56% 60.601s 5.509ms 11000
EmbeddingBagBackward 0.00%
import torch
@torch.jit.interface
class ModuleInterface(torch.nn.Module):
    """TorchScript interface: any conforming module must provide ``my_method``.

    ``@torch.jit.interface`` records only the method signatures; the bodies
    here are never executed.
    """

    def my_method(self):
        # type: () -> Tensor
        # Stub only — concrete implementations supply the actual body.
        pass
class MyScriptModule(torch.jit.ScriptModule):
import torch
@torch.jit.script
class MyClass(object):
    """Script class compiled ahead of time by ``torch.jit.script``."""

    def my_method(self):
        # type: () -> Tensor
        # Draws a fresh length-10 standard-normal tensor on every call
        # (no caching; result differs between calls).
        return torch.randn(10)
@torch.jit.ignore
def mod_init():
    """Define a ScriptModule inside a function ignored by the TorchScript compiler.

    ``@torch.jit.ignore`` keeps this function out of scripting, so the eager
    Python class definition below is legal here.
    """
    # NOTE(review): the class is defined but not returned or registered in the
    # visible code — possibly truncated by the page capture; verify against the
    # original gist.
    class MyScriptModule(torch.jit.ScriptModule):
        def __init__(self):
            super().__init__()
            # Tensor state captured eagerly at construction time.
            self.a = torch.randn(10)

        @torch.jit.script_method
        def my_method(self):
            # Scripted accessor returning the stored tensor.
            return self.a