Created
August 20, 2024 00:28
-
-
Save woshiyyya/bbc528cf39e88c1c2b6d9766f149a119 to your computer and use it in GitHub Desktop.
Ray compiled DAG: NCCL channel error when a node is bound to an upstream node on the same actor
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ray | |
import torch | |
from ray.dag.input_node import InputNode | |
from ray.dag.output_node import MultiOutputNode | |
from ray.experimental.channel.torch_tensor_type import TorchTensorType | |
@ray.remote(num_gpus=1)
class MyActor:
    """Ray actor (one GPU requested) whose methods serve as DAG nodes.

    The methods carry no real logic; they only give the graph an input
    stage, a tensor-producing stage, and a stage that consumes several
    upstream outputs.
    """

    def __init__(self):
        # No per-actor state is needed.
        pass

    def entrypoint(self, inp):
        """Accept the DAG input and return ``None``.

        Args:
            inp: Value passed to the DAG; it is not used.
        """

    def aggregate(self, *args):
        """Return the first of the upstream values.

        Args:
            *args: Outputs of the upstream nodes bound to this method.

        Returns:
            ``args[0]``; the remaining values are ignored.
        """
        return args[0]

    def forward(self, inp):
        """Return a freshly sampled 10x10 tensor moved to CUDA.

        Args:
            inp: Upstream value; it is not used.

        Returns:
            The result of ``torch.randn(10, 10).cuda()``.
        """
        return torch.randn(10, 10).cuda()
# Two actors. workers[0] hosts the entrypoint node, one of the forward nodes,
# and the aggregate node, so one NCCL-typed edge links two nodes that live on
# the same actor.
# NOTE(review): presumably that same-actor NCCL edge is what triggers the
# reported channel error -- confirm against the Ray version in use.
workers = [MyActor.remote() for _ in range(2)]

with InputNode() as input_node:
    entrypoint = workers[0].entrypoint.bind(input_node)
    # Keep the node returned by with_type_hint instead of discarding it, so
    # the list holds exactly the nodes that carry the NCCL transport hint.
    activations = [
        worker.forward.bind(entrypoint).with_type_hint(
            TorchTensorType(transport=TorchTensorType.NCCL)
        )
        for worker in workers
    ]
    # Both forward outputs feed aggregate on workers[0].
    dag = workers[0].aggregate.bind(*activations)

dag = dag.experimental_compile()
dag.execute(1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment