// Forward declaration of pybind11::class_, a class template that takes one
// type parameter followed by a type parameter pack.
namespace pybind11 {
template <typename, typename...>
class class_;
}
// try to friend the pybind11::class_
template<typename TTarget>
class intrusive_ptr {
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy | |
from dataclasses import dataclass | |
from typing import Callable, Optional, Tuple | |
import torch | |
import torch.distributed as dist | |
import torch.nn as nn | |
from torch.distributed._tensor import init_device_mesh | |
from torch.distributed._tensor import distribute_tensor, DTensor, Shard | |
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.distributed as dist | |
from torch.testing._internal.distributed._tensor.common_dtensor import ( | |
DTensorTestBase, | |
with_comms, | |
) | |
class TestDTensorCompile(DTensorTestBase): | |
def setUp(self): | |
super().setUp() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch._meta_registrations import register_meta | |
aten = torch.ops.aten | |
@register_meta([aten._fused_adam.default]) | |
def meta__fused_adam( | |
params, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
addmm(bias, input, weight)
bias: replicated
output = input @ weight
input Shard(1), weight Shard(0) -> partial tensor
output -> partial -> replicated?
output + bias -> partial? only do on one rank
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
mesh = DeviceMesh("cuda", [[0, 1], [2, 3]]) | |
placements -> describes how to place the tensor to the device mesh | |
torch.randn(12, 8) | |
device mesh dim = 2, size(2, 2) | |
placements = [Shard(shard_dim=0), Replicate()] | |
torch.randn(6, 8) -> [0, 1] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Perf profile before https://github.com/pytorch/pytorch/pull/33157 | |
-------------------------------------------- --------------- --------------- --------------- --------------- --------------- --------------- | |
Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg Number of Calls | |
-------------------------------------------- --------------- --------------- --------------- --------------- --------------- --------------- | |
add_ 94.03% 2223.890s 94.03% 2223.890s 55.625ms 39980 | |
AddmmBackward 0.00% 79.386ms 2.57% 60.762s 10.127ms 6000 | |
mm 2.56% 60.601s 2.56% 60.601s 5.509ms 11000 | |
EmbeddingBagBackward 0.00% |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
# Module interface registered through torch.jit.interface. It exposes a
# single method, my_method, whose signature (no arguments, returns a Tensor)
# is given by the type comment directly under the def line. The method body
# is intentionally nothing but `pass`.
@torch.jit.interface
class ModuleInterface(torch.nn.Module):
    def my_method(self):
        # type: () -> Tensor
        pass
class MyScriptModule(torch.jit.ScriptModule): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
# Class compiled with torch.jit.script. Its only method, my_method, takes no
# arguments and returns a freshly sampled tensor from torch.randn(10); the
# signature is given by the type comment directly under the def line.
@torch.jit.script
class MyClass(object):
    def my_method(self):
        # type: () -> Tensor
        return torch.randn(10)
@torch.jit.ignore | |
def mod_init(): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MyScriptModule(torch.jit.ScriptModule):
    """Script module holding one tensor attribute and a method that returns it."""

    def __init__(self):
        super().__init__()
        # Tensor created once at construction time from torch.randn(10).
        self.a = torch.randn(10)

    # Compiled as a script method; returns the tensor stored in ``self.a``.
    @torch.jit.script_method
    def my_method(self):
        return self.a
NewerOlder