@stephanie-wang · Last active August 15, 2024 17:03
1F1B in ray aDAGs
import ray
import ray.dag


@ray.remote
class Worker:
    def __init__(self, rank):
        self.rank = rank
        self.trace = []

    def fwd(self, b):
        print("fwd", self.rank, b)
        self.trace.append(("fwd", b))
        return b

    def bwd(self, b):
        print("bwd", self.rank, b)
        self.trace.append(("bwd", b))
        return b

    def pop_trace(self):
        trace = self.trace
        self.trace = []
        return trace


NUM_WORKERS = 4
workers = [Worker.remote(i) for i in range(NUM_WORKERS)]

NUM_MICROBATCHES = 8
NUM_LEAD_MICROBATCHES = 4

with ray.dag.InputNode() as inp:
    fwd_queues = [[] for _ in range(NUM_WORKERS)]
    bwd_queues = [[] for _ in range(NUM_WORKERS)]
    # Once a worker's counter reaches 0, it cannot execute another fwd until it
    # executes a bwd first.
    fwd_counter = [NUM_LEAD_MICROBATCHES - i for i in range(NUM_WORKERS)]
    # All of the done batches.
    done = []

    # FWD on worker 0.
    for i in range(NUM_MICROBATCHES):
        fwd_queues[0].append(inp[i])

    # Greedy 1F1B schedule: each stage binds a fwd while it still has warm-up
    # budget and a queued microbatch, otherwise it binds the next available bwd.
    while len(done) < NUM_MICROBATCHES:
        for i, worker in enumerate(workers):
            if fwd_counter[i] > 0 and fwd_queues[i]:
                b = fwd_queues[i].pop(0)
                b = worker.fwd.bind(b)
                if i < NUM_WORKERS - 1:
                    fwd_queues[i + 1].append(b)
                else:
                    bwd_queues[i].append(b)
                fwd_counter[i] -= 1
            elif bwd_queues[i]:
                b = bwd_queues[i].pop(0)
                b = worker.bwd.bind(b)
                if i > 0:
                    bwd_queues[i - 1].append(b)
                else:
                    done.append(b)
                fwd_counter[i] += 1

    dag = ray.dag.MultiOutputNode(done)

dag = dag.experimental_compile()
ray.get(dag.execute(*range(NUM_MICROBATCHES)))
print(ray.get(workers[0].pop_trace.remote()))
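
The final print above inspects only the first pipeline stage. A minimal variant (a sketch; note that pop_trace clears each worker's trace, so run this in place of the single-worker print) dumps the schedule every stage actually executed:

# Dump per-stage schedules. The last stage (rank NUM_WORKERS - 1) should
# alternate fwd/bwd strictly after its first fwd, while rank i first runs
# NUM_LEAD_MICROBATCHES - i forward warm-up steps.
traces = ray.get([w.pop_trace.remote() for w in workers])
for rank, trace in enumerate(traces):
    print("rank", rank, trace)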
A second sketch: data- and tensor-parallel training of text and vision encoders against a multimodal (MM) loss, again expressed as a Ray compiled graph. Helpers such as BaseWorker, calculate_loss, initialize_dist_group, the tensor parallel plan, and dp_groups (one (text_encoder, vision_encoders) pair per data-parallel group) are assumed to be defined elsewhere.

import os

import torch.distributed.tensor.parallel
from torch.distributed.device_mesh import init_device_mesh

from ray.dag import InputNode, MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType


@ray.remote(num_gpus=1)
class EncoderWorker(BaseWorker):
    # define and initialize model...

    def init_tp_strategy(self, tp_plan):
        # Shard self.model across all ranks in this worker's distributed group.
        world_size = int(os.environ["WORLD_SIZE"])
        tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
        torch.distributed.tensor.parallel.parallelize_module(
            self.model, tp_mesh, tp_plan
        )

    def forward(self, input):
        self.features = self.model(input)
        return self.features.detach()

    def backward(self, other_features):
        # calculate some MM loss
        self.loss = self.calculate_loss(self.features, other_features)
        self.loss.backward()
        return None


text_encoder = EncoderWorker.remote()

# define a tensor parallel plan for the vision encoder
tp_plan = ...
vision_encoders = [EncoderWorker.remote() for _ in range(3)]
initialize_dist_group(vision_encoders)
ray.get([worker.init_tp_strategy.remote(tp_plan) for worker in vision_encoders])

with InputNode() as input_node:
    # Nested list indexed by data parallel group.
    # Inner list is a list of [text encoder gradient, vision encoder gradients].
    grads = []
    for i, dp_group in enumerate(dp_groups):
        text_encoder, vision_encoders = dp_group
        text_activations = text_encoder.forward.bind(input_node[i])
        vision_activations = [
            worker.forward.bind(input_node[i]) for worker in vision_encoders
        ]
        # Exchange activations between the encoders over NCCL.
        for dag_node in [text_activations] + vision_activations:
            dag_node.with_type_hint(TorchTensorType(transport=TorchTensorType.NCCL))
        text_bwd = text_encoder.backward.bind(vision_activations[0])
        vision_bwd = [
            worker.backward.bind(text_activations) for worker in vision_encoders
        ]
        grads.append([text_bwd] + vision_bwd)

    dag = []
    for dp_group in dp_groups:
        for worker in dp_group:
            # allreduce and apply gradients (need to do some list unpacking here).
            # Returns None to indicate that the worker is done.
            dag.append(worker.allreduce.bind(*grads))
    dag = MultiOutputNode(dag)
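
As with the 1F1B example, the DAG above still has to be compiled and executed to run a training step. A minimal sketch, where `batches` is a hypothetical list holding one input batch per data-parallel group:

# Compile the graph once, then drive one step per execute() call.
# `batches` is a hypothetical list: one input batch per entry in dp_groups.
dag = dag.experimental_compile()
ray.get(dag.execute(*batches))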