Skip to content

Instantly share code, notes, and snippets.

import os
import torch
import torch.distributed as dist
import lovely_tensors as lt; lt.monkey_patch()
def split_tensor(data: torch.Tensor, dim: int) -> torch.Tensor:
rank = dist.get_rank()
world_size = dist.get_world_size()