Skip to content

Instantly share code, notes, and snippets.

@woshiyyya
Last active August 30, 2024 23:57
Show Gist options
  • Save woshiyyya/31181c98f818f136ac275e188d48b528 to your computer and use it in GitHub Desktop.
Benchmark NCCL Data Transfer
import ray
import torch
from ray.experimental.channel.torch_tensor_type import TorchTensorType
# shape = (4, 8192)
shape = (4, 24576)


@ray.remote(num_gpus=1)
class MyActor:
    """One stage of a 4-stage tensor pipeline.

    Rank 0 is the source (emits its own pre-allocated tensor), rank 3 is
    the sink (returns nothing), and every other rank relays its input.
    """

    def __init__(self, rank) -> None:
        self.rank = rank
        # Payload that the source rank injects into the pipeline.
        # NOTE(review): allocated on every actor, but only rank 0 sends it.
        self.tensor = torch.rand(*shape).cuda()

    def forward(self, tensor):
        # The two guards are mutually exclusive, so their order is free.
        if self.rank == 3:
            # Sink: consume the tensor, produce nothing.
            return None
        if self.rank == 0:
            # Source: ignore the input, emit the local tensor.
            return self.tensor
        # Intermediate rank: relay unchanged.
        return tensor
from ray.dag.input_node import InputNode

# Build a 4-stage pipeline rank0 -> rank1 -> rank2 -> rank3 with NCCL
# transport on every inter-actor edge.
actors = [MyActor.remote(rank=i) for i in range(4)]

with InputNode() as input_node:
    # dag = actors[0].read_data.bind(input_node)
    dag = input_node
    for rank, actor in enumerate(actors):
        dag = actor.forward.bind(dag)
        if rank < len(actors) - 1:
            # Static shape/dtype + direct return let Ray skip per-send
            # metadata exchange on the NCCL channel.
            dag.with_type_hint(
                TorchTensorType(transport=TorchTensorType.NCCL, _shape=shape, _dtype=torch.float32, _direct_return=True)
                # TorchTensorType(transport=TorchTensorType.NCCL)
            )

dag = dag.experimental_compile()

import time

# Warm-up execution (untimed): the first run pays one-off costs such as
# NCCL communicator init and channel allocation, which would otherwise be
# folded into the reported average. The torch.distributed baseline below
# intends the same ("first one as warm up").
ray.get(dag.execute(1))

elapsed_s = []
for i in range(100):
    s = time.perf_counter()
    ray.get(dag.execute(1))
    elapsed_s.append(time.perf_counter() - s)
print(f"Avg execution time: {sum(elapsed_s) / len(elapsed_s) * 1000} ms")
from ray.train.torch import TorchTrainer
import ray.train
import torch.distributed as dist
import torch
import time
# shape = (4, 8192)
shape = (4, 24576)


def train_func():
    """Time a 4-rank send/recv relay (rank i -> rank i+1) over NCCL.

    Runs on each Ray Train worker; prints the per-iteration average on
    every rank.
    """
    rank = ray.train.get_context().get_world_rank()
    tensor = torch.rand(*shape).cuda()

    # Warm-up round (untimed): the first send/recv pays NCCL communicator
    # setup. The original comment claimed a warm-up but every iteration
    # was inside the timed region — do one real round here instead.
    if rank > 0:
        dist.recv(tensor, src=rank - 1)
    if rank < 3:
        dist.send(tensor, dst=rank + 1)
    dist.barrier()

    n_iters = 100
    s = time.perf_counter()
    for i in range(n_iters):
        if rank > 0:
            dist.recv(tensor, src=rank - 1)
        if rank < 3:
            dist.send(tensor, dst=rank + 1)
    # Barrier so the stop timestamp covers the whole pipeline, not just
    # this rank's local sends.
    dist.barrier()
    e = time.perf_counter()
    print(f"{tensor.shape} Avg Time elapsed: {(e - s) / n_iters * 1000} ms")
# Launch the benchmark on 4 GPU workers via Ray Train.
trainer = TorchTrainer(
    train_func,
    scaling_config=ray.train.ScalingConfig(num_workers=4, use_gpu=True),
)
trainer.fit()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment