@richardliaw
Created September 29, 2024 23:20
from contextlib import contextmanager
from time import perf_counter
from typing import Callable


@contextmanager
def catchtime() -> Callable[[], float]:
    # Yields a callable that returns the elapsed time in seconds once the block exits.
    t1 = t2 = perf_counter()
    yield lambda: t2 - t1
    t2 = perf_counter()
import ray
import ray.dag
from ray.experimental.channel.torch_tensor_type import TorchTensorType
import torch


@ray.remote(num_gpus=1)
class GPUSender:
    def send(self, shape):
        return torch.rand(shape, device="cuda")


@ray.remote(num_gpus=1)
class GPUReceiver:
    def recv(self, tensor: torch.Tensor):
        assert tensor.device.type == "cuda"
        return tensor.shape
shape = (1000, 1000)


def test_basic():
    # Baseline: plain Ray remote calls, with the tensor moving through Ray's
    # regular object transfer path rather than a dedicated NCCL channel.
    sender = GPUSender.remote()
    receiver = GPUReceiver.remote()
    obj = sender.send.remote(shape)
    result = receiver.recv.remote(obj)
    assert ray.get(result) == shape
def test_dag():
    sender = GPUSender.remote()
    receiver = GPUReceiver.remote()
    with ray.dag.InputNode() as inp:
        dag = sender.send.bind(inp)
        # Hint that this value is a torch.Tensor that should be sent over NCCL.
        dag = dag.with_type_hint(TorchTensorType(transport="nccl"))
        dag = receiver.recv.bind(dag)
    # Creates a NCCL group across the participating actors. The group is
    # destroyed during dag.teardown().
    adag = dag.experimental_compile()
    # Execute the DAG. Ray aDAG will orchestrate any NCCL ops.
    assert ray.get(adag.execute(shape)) == shape
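    # Optional cleanup sketch (not in the original gist): assuming the compiled
    # DAG object exposes the teardown() referenced in the comment above,
    # calling it releases the NCCL group explicitly instead of at process exit.
    adag.teardown()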
if __name__ == "__main__":
    ray.init()

    with catchtime() as time:
        test_basic()
    print(f"Basic: {time()}")

    with catchtime() as time:
        test_dag()
    print(f"DAG: {time()}")

    ray.shutdown()