Vocab sharding using DTensors
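Shards an embedding table and its tied-weight head over the vocab dimension with PyTorch DTensors: a custom sharding-propagation rule for `aten.embedding`, a MAX reduction strategy for `aten.max`, a vocab-sharded `Embedding` module, a sharded cross-entropy loss, and a small two-process gloo/CPU example. Written against the private `torch.distributed._tensor` APIs of the time (circa PyTorch 2.1); these internals may have changed in newer releases.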
from math import ceil
from typing import Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch import Tensor, nn
from torch.autograd.function import once_differentiable
from torch.distributed._tensor import (
    DeviceMesh,
    DTensor,
    Replicate,
    Shard,
    distribute_module,
    distribute_tensor,
)
from torch.distributed._tensor.op_schema import OpSchema, OpStrategy, OutputSharding, RuntimeSchemaInfo
from torch.distributed._tensor.ops.embedding_ops import embedding_rules
from torch.distributed._tensor.ops.math_ops import (
    _infer_reduction_dims,
    common_reduction_strategy,
)
from torch.distributed._tensor.ops.utils import register_op_strategy, register_prop_rule
from torch.distributed._tensor.placement_types import DTensorSpec, _Partial
from torch.distributed.tensor.parallel._utils import _prepare_input_validate

aten = torch.ops.aten

### Sharding propagation rules
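# Custom rule for aten.embedding when both the weight and the (offset-shifted) indices are
# declared Shard(0): each rank looks up only the rows it owns and (via zero_OOR below)
# contributes zeros for tokens owned by other ranks, so the output is marked _Partial()
# and a sum allreduce reconstructs the full embedding.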
@register_prop_rule(aten.embedding.default)
def embedding_rules_custom(op_schema: OpSchema) -> OutputSharding:
    weight_spec, inp_spec = op_schema.args_spec
    if weight_spec.placements == (Shard(0),) and inp_spec.placements == (Shard(0),):
        return OutputSharding(
            output_spec=DTensorSpec(mesh=inp_spec.mesh, placements=(_Partial(),)),
            schema_suggestions=[
                OpSchema(
                    op=op_schema.op,
                    args_schema=(
                        DTensorSpec(mesh=weight_spec.mesh, placements=(Shard(0),)),
                        DTensorSpec(mesh=inp_spec.mesh, placements=(Shard(0),)),
                    ),
                    kwargs_schema={},
                )
            ],
        )
    # Otherwise fall back to the built-in embedding rules
    return embedding_rules(op_schema=op_schema)
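
# The sharded cross-entropy below needs max() over the class dimension for a numerically
# stable log-softmax. max is "linear" across shards (the max of per-shard maxima is the
# global max), so the partial results can simply be combined with a MAX allreduce.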
@register_op_strategy([aten.max.default, aten.max.dim, aten.max.out], schema_info=RuntimeSchemaInfo(1))
def max_reduction_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
    args_schema = op_schema.args_schema
    input_strategy = args_schema[0]
    assert isinstance(input_strategy, OpStrategy)
    dims = None
    if len(op_schema.args_schema) > 1:
        dims = _infer_reduction_dims(args_schema[1], input_strategy.output_ndim)
    reduce_dims = list(range(input_strategy.output_ndim)) if dims is None else dims
    keep_dim = len(op_schema.args_schema) > 2 and bool(op_schema.args_schema[2])
    return common_reduction_strategy(
        mesh,
        input_strategy,
        reduce_dims,
        keep_dim=keep_dim,
        reduction_linear=True,
        reduction_op=c10d.ReduceOp.MAX,
    )

### Module and function changes
class Embedding(nn.Module):
    """Like nn.Embedding, but with a `zero_OOR` option that zeroes output rows for
    out-of-range indices (needed when the vocab dimension is sharded)."""

    def __init__(self, num_embeddings: int, embedding_dim: int, device=None, dtype=None, zero_OOR=False) -> None:
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.weight = nn.Parameter(
            torch.empty((num_embeddings, embedding_dim), device=device, dtype=dtype), requires_grad=True
        )
        nn.init.normal_(self.weight)
        self.tp_shards = 1
        self.device_mesh: Optional[DeviceMesh] = None
        self.zero_OOR = zero_OOR

    def forward(self, input: Tensor) -> Tensor:
        if self.zero_OOR:
            # Indices that fall outside this rank's vocab shard are out-of-range (OOR):
            # remap them to 0 so the local lookup succeeds...
            weight_shape = self.weight.to_local().shape if isinstance(self.weight, DTensor) else self.weight.shape
            max_possible_index = weight_shape[0] - 1
            OOR_indices = (input < 0) | (input > max_possible_index)
            input = torch.where(OOR_indices, 0, input)
        output = F.embedding(input=input, weight=self.weight)
        if self.zero_OOR:
            # ...and zero their output rows so the partial sum across ranks is correct.
            output = torch.where(OOR_indices.unsqueeze(-1), 0, output)
        return output
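
    # parallelize() shards the weight rows (the vocab dim) across the mesh and installs an
    # input_fn that shifts global token ids into this rank's local index range.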
    def parallelize(self, device_mesh: DeviceMesh) -> None:
        self.tp_shards = device_mesh.size()
        global_vocab_size = self.weight.shape[0]

        def partition_embedding_vocab_fn(name: str, module: nn.Module, device_mesh: DeviceMesh):
            # Shard the weight along the vocab axis
            weight: torch.Tensor = module.weight
            weight = distribute_tensor(weight, device_mesh, [Shard(0)])
            module.register_parameter("weight", nn.Parameter(weight))

        @_prepare_input_validate  # type: ignore[arg-type] # pyre-ignore[56]
        def vocab_shard_input_fn(
            inputs: Tuple[Union[torch.Tensor, DTensor], ...],
            device_mesh: Optional[DeviceMesh] = None,
        ) -> DTensor:
            input = inputs
            vocab_shard_size = ceil(global_vocab_size / self.tp_shards)
            # Adjust indices so they align with local embedding indices
            offset = vocab_shard_size * dist.get_rank()
            input_adjusted = input - offset
            input_adjusted = DTensor.from_local(input_adjusted, device_mesh=device_mesh, placements=[Shard(0)])
            return input_adjusted

        # Acts in place on the module
        distribute_module(self, device_mesh, partition_embedding_vocab_fn, vocab_shard_input_fn)
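
# Cross-entropy over logits sharded along the class dimension. The forward needs two
# collectives (a max and a sum over the sharded dim); the backward reuses the saved
# log-softmax.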
class crossentropylosssharded(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: Tensor, target: Tensor):
        """Forward has two collectives: max and sum"""
        # stable log-softmax: x - max(x) - log(sum(exp(x - max(x))))
        max_input, _ = input.max(dim=1)  # collective 1: global max over the class dim
        translated = input - max_input.reshape(-1, 1)
        denominator = torch.log(torch.exp(translated).sum(dim=1))  # collective 2: global sum
        logsoftmax = translated - denominator.reshape(-1, 1)
        loss = torch.gather(-logsoftmax, 1, target.reshape(-1, 1)).flatten()
        ctx.save_for_backward(target, logsoftmax)
        return loss

    @staticmethod
    @once_differentiable
    def backward(ctx, loss_grad: Tensor):
        target, logsoftmax = ctx.saved_tensors
        # If:
        #   y = softmax(x)
        #   z = nll_loss(y)
        # then, for a given scalar loss l:
        #   dl/dx = dl/dz @ dz/dx
        #   dl/dx = dl/dz . (y - I_t)
        # where I_t is the one-hot encoded target matrix and '.' is element-wise
        # multiplication with broadcasting.
        softmax = torch.exp(logsoftmax)
        softmax[torch.arange(len(softmax)), target] -= 1
        input_grad = loss_grad.reshape(-1, 1) * softmax
        return input_grad, None
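
# Module wrapper around the autograd Function. With zero_OOR, targets that fall outside
# this rank's logit shard are remapped to 0 for the gather and their loss entries are
# zeroed afterwards, mirroring the Embedding module above.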
class CrossEntropyLossSharded(nn.Module):
    def __init__(self, zero_OOR=False) -> None:
        super().__init__()
        self.zero_OOR = zero_OOR

    def forward(self, input, target):
        if self.zero_OOR:
            input_shape = input.to_local().shape if isinstance(input, DTensor) else input.shape
            max_possible_index = input_shape[1] - 1
            OOR_indices = (target < 0) | (target > max_possible_index)
            target = torch.where(OOR_indices, 0, target)
        loss = crossentropylosssharded.apply(input, target)
        if self.zero_OOR:
            loss = torch.where(OOR_indices, 0, loss)
        return loss

    def parallelize(self, classes, device_mesh: DeviceMesh) -> None:
        tp_shards = device_mesh.size()

        @_prepare_input_validate  # type: ignore[arg-type] # pyre-ignore[56]
        def vocab_shard_input_fn(
            inputs: Tuple[Union[torch.Tensor, DTensor], ...],
            device_mesh: Optional[DeviceMesh] = None,
        ) -> Tuple[Union[torch.Tensor, DTensor], DTensor]:
            input, target = inputs
            vocab_shard_size = ceil(classes / tp_shards)
            # Adjust indices so they align with local embedding indices
            offset = vocab_shard_size * dist.get_rank()
            target_adjusted = target - offset
            target_adjusted = DTensor.from_local(target_adjusted, device_mesh=device_mesh, placements=[Shard(0)])
            return input, target_adjusted

        # Acts in place on the module
        distribute_module(self, device_mesh, input_fn=vocab_shard_input_fn)

### Example
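# Two CPU processes (gloo); a vocab of 4 classes is split 2 per rank. Replicated token ids
# go through the vocab-sharded embedding, a tied-weight head, and the sharded loss.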
import os


def worker(rank: int, world_size: int, port: int):
    # Setup env in the worker to prevent pollution of the parent's env
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)

    # Init dist with the gloo/cpu backend
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    device_mesh = DeviceMesh("cpu", torch.arange(world_size))

    # Inputs
    items = 3
    classes = 4
    hidden = 2
    classes_tp = classes // world_size
    torch.manual_seed(42)
    input = torch.randint(classes, (items,))  # torch.rand((items, classes), requires_grad=True)
    target = torch.randint(classes, (items,))

    # DTensors
    input = DTensor.from_local(input, device_mesh, [Replicate()])
    target = DTensor.from_local(target, device_mesh, [Replicate()])

    # Model
    embedding = Embedding(classes, hidden, zero_OOR=True)
    loss_mod = CrossEntropyLossSharded(zero_OOR=True)
    embedding.parallelize(device_mesh)
    loss_mod.parallelize(classes, device_mesh)

    # Embedding layer
    x = embedding(input)

    # Transformer layers here...

    # Head layer
    x = x @ embedding.weight.T  # Head has tied weight with embedding
    loss = loss_mod(x, target)
    print(loss)


if __name__ == "__main__":
    # Example
    world_size = 2
    port = 8345
    mp.spawn(worker, nprocs=world_size, args=(world_size, port))