1F1B in Ray aDAGs
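The first file builds a 1F1B (one-forward-one-backward) pipeline-parallel schedule as a single Ray accelerated DAG. The driver plays the schedule out at graph-construction time: each stage gets a warm-up budget of `NUM_LEAD_MICROBATCHES - rank` forwards, and once that budget is spent it must execute a backward before it can take another forward, so the compiled DAG encodes the 1F1B interleaving statically.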
import ray
import ray.dag


@ray.remote
class Worker:
    def __init__(self, rank):
        self.rank = rank
        self.trace = []

    def fwd(self, b):
        print("fwd", self.rank, b)
        self.trace.append(("fwd", b))
        return b

    def bwd(self, b):
        print("bwd", self.rank, b)
        self.trace.append(("bwd", b))
        return b

    def pop_trace(self):
        trace = self.trace
        self.trace = []
        return trace


NUM_WORKERS = 4
workers = [Worker.remote(i) for i in range(NUM_WORKERS)]

NUM_MICROBATCHES = 8
NUM_LEAD_MICROBATCHES = 4

with ray.dag.InputNode() as inp:
    fwd_queues = [[] for _ in range(NUM_WORKERS)]
    bwd_queues = [[] for _ in range(NUM_WORKERS)]
    # Once a worker's counter reaches 0, it cannot execute another fwd until it
    # executes a bwd first.
    fwd_counter = [NUM_LEAD_MICROBATCHES - i for i in range(NUM_WORKERS)]
    # All of the done batches.
    done = []

    # FWD on worker 0.
    for i in range(NUM_MICROBATCHES):
        fwd_queues[0].append(inp[i])

    while len(done) < NUM_MICROBATCHES:
        for i, worker in enumerate(workers):
            if fwd_counter[i] > 0 and fwd_queues[i]:
                b = fwd_queues[i].pop(0)
                b = worker.fwd.bind(b)
                if i < NUM_WORKERS - 1:
                    fwd_queues[i + 1].append(b)
                else:
                    bwd_queues[i].append(b)
                fwd_counter[i] -= 1
            elif bwd_queues[i]:
                b = bwd_queues[i].pop(0)
                b = worker.bwd.bind(b)
                if i > 0:
                    bwd_queues[i - 1].append(b)
                else:
                    done.append(b)
                fwd_counter[i] += 1

    dag = ray.dag.MultiOutputNode(done)

dag = dag.experimental_compile()
ray.get(dag.execute(*range(NUM_MICROBATCHES)))
print(ray.get(workers[0].pop_trace.remote()))
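To inspect the schedule this loop produces without running Ray at all, here is a minimal, Ray-free sketch that replays the same greedy scheduling loop and records each stage's operation order. The function name `simulate_1f1b` and its defaults are illustrative, not part of the gist.

def simulate_1f1b(num_workers=4, num_microbatches=8, num_lead_microbatches=4):
    fwd_queues = [[] for _ in range(num_workers)]
    bwd_queues = [[] for _ in range(num_workers)]
    # Same warm-up budget as fwd_counter in the DAG-building loop above.
    fwd_counter = [num_lead_microbatches - i for i in range(num_workers)]
    schedules = [[] for _ in range(num_workers)]
    done = []

    fwd_queues[0].extend(range(num_microbatches))
    while len(done) < num_microbatches:
        for i in range(num_workers):
            if fwd_counter[i] > 0 and fwd_queues[i]:
                b = fwd_queues[i].pop(0)
                schedules[i].append(("fwd", b))
                # Last stage keeps the batch for its own backward.
                (fwd_queues[i + 1] if i < num_workers - 1 else bwd_queues[i]).append(b)
                fwd_counter[i] -= 1
            elif bwd_queues[i]:
                b = bwd_queues[i].pop(0)
                schedules[i].append(("bwd", b))
                # First stage finishing a backward retires the microbatch.
                (bwd_queues[i - 1] if i > 0 else done).append(b)
                fwd_counter[i] += 1
    return schedules

for rank, ops in enumerate(simulate_1f1b()):
    print(rank, ops)

Worker 0's printed order should match the trace returned by `pop_trace` above: the warm-up forwards first, then backwards and forwards alternating until the pipeline drains.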
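The second file sketches a multimodal setup on aDAGs: a text encoder and a tensor-parallel group of vision encoders exchange activations as GPU tensors over NCCL (via the `TorchTensorType` hint), compute each other's loss terms in `backward`, and collect gradients per data-parallel group for a final allreduce.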
import os

import ray
import torch.distributed.tensor.parallel
from ray.dag import InputNode, MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType
from torch.distributed.device_mesh import init_device_mesh


@ray.remote(num_gpus=1)
class EncoderWorker(BaseWorker):
    # define and initialize model...

    def init_tp_strategy(self, tp_plan):
        world_size = int(os.environ["WORLD_SIZE"])
        tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
        torch.distributed.tensor.parallel.parallelize_module(
            self.model, tp_mesh, tp_plan
        )

    def forward(self, input):
        self.features = self.model(input)
        return self.features.detach()

    def backward(self, other_features):
        # calculate some MM loss
        self.loss = self.calculate_loss(self.features, other_features)
        self.loss.backward()
        return None


text_encoder = EncoderWorker.remote()

# define a tensor parallel plan for the vision encoder
tp_plan = ...
vision_encoders = [EncoderWorker.remote() for _ in range(3)]
initialize_dist_group(vision_encoders)
ray.get([worker.init_tp_strategy.remote(tp_plan) for worker in vision_encoders])

# dp_groups: one (text_encoder, vision_encoders) pair per data-parallel
# replica, assembled from workers like the ones created above.
with InputNode() as input_node:
    # Nested list indexed by data parallel group.
    # Inner list is a list of [text encoder gradient, vision encoder gradients].
    grads = []
    for i, dp_group in enumerate(dp_groups):
        text_encoder, vision_encoders = dp_group
        text_activations = text_encoder.forward.bind(input_node[i])
        vision_activations = [
            worker.forward.bind(input_node[i]) for worker in vision_encoders
        ]
        # Ship activations between actors as GPU tensors over NCCL.
        for dag_node in [text_activations] + vision_activations:
            dag_node.with_type_hint(TorchTensorType(transport=TorchTensorType.NCCL))
        text_bwd = text_encoder.backward.bind(vision_activations[0])
        vision_bwd = [worker.backward.bind(text_activations) for worker in vision_encoders]
        grads.append([text_bwd] + vision_bwd)

    dag = []
    for dp_group in dp_groups:
        for worker in dp_group:
            # allreduce and apply gradients (need to do some list unpacking here).
            # Returns None to indicate that the worker is done.
            dag.append(worker.allreduce.bind(*grads))
    dag = MultiOutputNode(dag)
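As in the first file, the assembled `dag` would then be compiled once and driven repeatedly from the driver. A minimal sketch, assuming one input batch per data-parallel group and a hypothetical `dataloader`:

dag = dag.experimental_compile()
for batches in dataloader:  # `dataloader` is hypothetical, not from the gist
    # One input tensor per data-parallel group, matching input_node[i] above.
    ray.get(dag.execute(*batches))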