import ray
import ray.cluster_utils
from ray.experimental.channel.torch_tensor_type import TorchTensorType
from ray.dag import InputNode, MultiOutputNode
from typing import Optional
from ray.dag.compiled_dag_node import CompiledDAG
from argparse import ArgumentError, ArgumentParser


@ray.remote(num_cpus=0, num_gpus=1)
class DummyWorker:
    def __init__(self, rank: Optional[int] = None):
        self.rank = rank
        self.trace = []

    def fwd(self, value):
        # self.trace.append(("FWD", self.rank))
        self.trace.append("F")
        return value

    def bwd(self, value):
        # self.trace.append(("BWD", self.rank))
        self.trace.append("B")
        return value

    def w(self, value):
        # self.trace.append(("W", self.rank))
        self.trace.append("W")
        return None

    def echo(self, value):
        return value

    def pop_trace(self):
        trace = self.trace
        self.trace = []
        return trace

    def read_input(self, input):
        return input

    def no_op(self, value):
        return value

    def no_op_two(self, value1, value2):
        return value1, value2
import torchvision
import torchvision.transforms as transforms
import time
import torch
import torch.nn as nn
import torch.optim as optim

BATCH_SIZE = 4
# FEATURE_SIZE = 8192
FEATURE_SIZE = 24576
FORWARD_SHAPE = (BATCH_SIZE, FEATURE_SIZE)
BACKWARD_SHAPE = (BATCH_SIZE, FEATURE_SIZE)


def cifar_trainset(dl_path="/tmp/cifar10-data"):
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Lambda(lambda x: x.view(-1)),  # Flatten the tensor
        ]
    )
    trainset = torchvision.datasets.CIFAR10(
        root=dl_path, train=True, download=True, transform=transform
    )
    return trainset
@ray.remote(num_gpus=1)
class Worker:
    def __init__(self, n_features, pp_rank, pp_size, micro_batch_size, is_output):
        self.pp_rank = pp_rank
        self.rank = pp_rank
        self.pp_size = pp_size
        self.trace = []

        layers = []
        for i in range(1, len(n_features)):
            in_features, out_features = n_features[i - 1], n_features[i]
            layers.append(nn.Linear(in_features, out_features))
            if not is_output or i < len(n_features) - 1:
                layers.append(nn.ReLU(inplace=True))

        self.module = nn.Sequential(*layers).cuda()
        self.optimizer = optim.SGD(self.module.parameters(), lr=1e-3)
        self.loss = nn.CrossEntropyLoss()

        # if pp_rank == 0:
        self.initialize_dataloader(micro_batch_size=micro_batch_size)

        self.input_activations = dict()
        self.output_activations = dict()
        self.cached_gradients = dict()
        self.max_memory_allocated = 0
        self.max_memory_reserved = 0

        self.fwd_batch_id = 0
        self.bwd_batch_id = 0
        self.w_batch_id = 0

    def initialize_dataloader(self, micro_batch_size):
        trainset = cifar_trainset()
        sampler = torch.utils.data.distributed.DistributedSampler(
            trainset, num_replicas=1, rank=0, shuffle=False
        )
        self.trainloader = torch.utils.data.DataLoader(
            trainset, batch_size=micro_batch_size, shuffle=False, sampler=sampler, pin_memory=True
        )
        self.traindata_iter = iter(self.trainloader)
        # Pre-fetch a single micro-batch and cache it on the GPU.
        input_activation, targets = next(self.traindata_iter)
        self.cache_input_activation = input_activation.cuda()
        self.cache_targets = targets.cuda()
    def fwd(self, fwd_inputs):
        self.trace.append("F")
        input_activation = fwd_inputs
        targets = self.cache_targets
        batch_id = self.fwd_batch_id
        self.fwd_batch_id += 1

        if self.pp_rank > 0:
            input_activation.requires_grad = True
            input_activation.retain_grad()

        # Fetch input batches from dataloader
        if self.pp_rank == 0:
            input_activation = self.cache_input_activation

        # Forward Pass
        self.input_activations[batch_id] = input_activation
        output_activation = self.module(input_activation)

        if self.pp_rank == self.pp_size - 1:
            loss = self.loss(input_activation, targets)
            self.output_activations[batch_id] = loss
            return None
        else:
            self.output_activations[batch_id] = output_activation
            return output_activation

    def bwd(self, bwd_inputs):
        self.trace.append("B")
        for param in self.module.parameters():
            param.requires_grad = False

        gradients = bwd_inputs
        batch_id = self.bwd_batch_id
        self.bwd_batch_id += 1
        self.cached_gradients[batch_id] = gradients

        # Backward Pass
        self.output_activations[batch_id].backward(gradients)
        # (Change #2) Using the line below instead will skip the last i-1 W operations:
        # self.output_activations[batch_id].backward(gradients, retain_graph=True)
        bwd_gradients = self.input_activations[batch_id].grad

        # Clear cache to free GRAM
        input_activations = self.input_activations.pop(batch_id)

        # Reset the flags
        for param in self.module.parameters():
            param.requires_grad = True
        input_activations.requires_grad = False

        # Return None to avoid Actor-Driver Comm
        if self.pp_rank == 0:
            return None
        else:
            return bwd_gradients
    def w(self):
        # (Change #1) Use the correct function signature instead:
        # def w(self, x):
        self.trace.append("W")
        batch_id = self.w_batch_id
        gradients = self.cached_gradients[batch_id]
        self.w_batch_id += 1
        self.output_activations[batch_id].backward(gradients)
        self.output_activations.pop(batch_id)
        self.cached_gradients.pop(batch_id)

    def pop_trace(self):
        trace = self.trace
        self.trace = []
        return trace

    def read_input(self, inp):
        # Placeholder: (batch_id, activations, targets)
        self.batch_id = 0
        return (None, None, None)

    def get_memory_logs(self):
        return [self.max_memory_allocated, self.max_memory_reserved]

    def get_events(self):
        events = getattr(self, "__ray_adag_events", [])
        return [event.to_dict() for event in events]

    def echo(self, value):
        return value
def generate_feature_dim(pp_size):
    input_size = 3 * 32 * 32
    feature_size = FEATURE_SIZE
    feature_dim = []
    feature_dim.append([input_size, feature_size, feature_size, feature_size])
    for _ in range(pp_size - 2):
        feature_dim.append([feature_size, feature_size, feature_size])
    feature_dim.append([feature_size, feature_size, feature_size, 10])
    return feature_dim
def generate_zbh1_dag(num_workers: int, num_microbatches: int, num_lead_microbatches: int, use_dummy: bool):
    if use_dummy:
        workers = [DummyWorker.remote() for _ in range(num_workers)]
    else:
        pp_size = num_workers
        num_lead_microbatches = num_workers
        feature_dim_list = generate_feature_dim(num_workers)
        workers = [
            Worker.remote(
                n_features, pp_rank, pp_size, BATCH_SIZE, bool(pp_rank == pp_size - 1)
            )
            for pp_rank, n_features in enumerate(feature_dim_list)
        ]

    with InputNode() as inp:
        fwd_queues = [[] for _ in range(num_workers)]
        bwd_queues = [[] for _ in range(num_workers)]
        # Once a worker's counter reaches 0, it cannot execute another fwd until it
        # executes a bwd first.
        fwd_counter = [num_lead_microbatches - i for i in range(num_workers)]
        bwd_counter = [0 for i in range(num_workers)]
        # All of the done batches.
        done = []
        # FWD on worker 0.
        input_data = workers[0].read_input.bind(inp)
        for i in range(num_microbatches):
            fwd_queues[0].append(input_data)

        while len(done) < num_microbatches:
            for i, worker in enumerate(workers):
                if fwd_counter[i] > 0 and fwd_queues[i]:
                    b = fwd_queues[i].pop(0)
                    b = worker.fwd.bind(b)
                    if i < num_workers - 1:
                        fwd_queues[i + 1].append(b)
                        # Use NCCL channel for communication between workers.
                        # b.with_type_hint(
                        #     TorchTensorType(transport=TorchTensorType.NCCL, _shape=BACKWARD_SHAPE, _dtype=torch.float32, _direct_return=True)
                        # )
                        b.with_type_hint(
                            TorchTensorType(transport=TorchTensorType.NCCL)
                        )
                    else:
                        bwd_queues[i].append(b)
                    fwd_counter[i] -= 1
                elif bwd_queues[i]:
                    b = bwd_queues[i].pop(0)
                    b2 = worker.bwd.bind(b)
                    # Code change for Zero Bubble PP
                    # ++++++++++++++++++++++++++++++++++++++++++++++++
                    bwd_counter[i] += 1
                    if bwd_counter[i] > i:
                        w2 = worker.w.bind(b2)
                        if bwd_counter[i] == num_microbatches:
                            for _ in range(i):
                                w2 = worker.w.bind(w2)
                    b2 = worker.echo.bind(b2)
                    # ++++++++++++++++++++++++++++++++++++++++++++++++
                    if i > 0:
                        bwd_queues[i - 1].append(b2)
                        # Use NCCL channel for communication between workers.
                        # b2.with_type_hint(
                        #     TorchTensorType(transport=TorchTensorType.NCCL, _shape=BACKWARD_SHAPE, _dtype=torch.float32, _direct_return=True)
                        # )
                        b2.with_type_hint(
                            TorchTensorType(transport=TorchTensorType.NCCL)
                        )
                    else:
                        done.append(b2)
                    fwd_counter[i] += 1

    dag = MultiOutputNode(done)
    compiled_dag = dag.experimental_compile()
    return compiled_dag, workers
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--use_dummy", action="store_true")
    args = parser.parse_args()

    dag, workers = generate_zbh1_dag(
        num_workers=4, num_lead_microbatches=4, num_microbatches=8, use_dummy=args.use_dummy
    )
    ray.get(dag.execute(1))

    print(f"Schedule of {'dummy' if args.use_dummy else 'normal'} workers:")
    for worker in workers:
        print(ray.get(worker.pop_trace.remote()))
Schedule for Dummy Workers:
['F', 'F', 'F', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'B', 'W', 'B', 'W', 'B', 'W']
['F', 'F', 'F', 'B', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'B', 'W', 'B', 'W', 'W']
['F', 'F', 'B', 'F', 'B', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'B', 'W', 'W', 'W']
['F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'W', 'W', 'W']
Schedule of Linear Workers when using the mismatched function signature `def w(self):`
['F', 'F', 'F', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'B', 'B', 'B']
['F', 'F', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'B', 'B']
['F', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'B']
['F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B']
(Change #1) Schedule of Linear Workers when using the matching function signature `def w(self, x):`
['F', 'F', 'F', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'B', 'W', 'B', 'W', 'B', 'W']
['F', 'F', 'F', 'B', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'B', 'W', 'B', 'W']
['F', 'F', 'B', 'F', 'B', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'B', 'W']
['F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W']
(Change #2) Schedule of Linear Workers when adding the `retain_graph=True` option in the `bwd` function.
['F', 'F', 'F', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'B', 'W', 'B', 'W', 'B', 'W']
['F', 'F', 'F', 'B', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'B', 'W', 'B', 'W', 'W']
['F', 'F', 'B', 'F', 'B', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'B', 'W', 'W', 'W']
['F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'W', 'W', 'W']
I don't know why changing the logic in the `bwd` function affects the DAG execution schedule...
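
For reference, here is a minimal, Ray-free sketch (the helper name `simulate_zbh1_schedule` is made up for illustration) that replays the same greedy loop as `generate_zbh1_dag` and records the per-rank F/B/W trace in the order the tasks are bound, i.e. the schedule each actor is expected to report. With `num_workers=4, num_microbatches=8, num_lead_microbatches=4` it reproduces the dummy-worker schedule above, so it can serve as a baseline to diff against what the compiled DAG actually executed.

    def simulate_zbh1_schedule(num_workers=4, num_microbatches=8, num_lead_microbatches=4):
        # Replay the greedy scheduling loop from generate_zbh1_dag without Ray,
        # recording "F"/"B"/"W" per rank in bind order.
        traces = [[] for _ in range(num_workers)]
        fwd_queues = [list(range(num_microbatches))] + [[] for _ in range(num_workers - 1)]
        bwd_queues = [[] for _ in range(num_workers)]
        fwd_counter = [num_lead_microbatches - i for i in range(num_workers)]
        bwd_counter = [0 for _ in range(num_workers)]
        done = []
        while len(done) < num_microbatches:
            for i in range(num_workers):
                if fwd_counter[i] > 0 and fwd_queues[i]:
                    b = fwd_queues[i].pop(0)  # microbatch ids stand in for DAG nodes
                    traces[i].append("F")
                    if i < num_workers - 1:
                        fwd_queues[i + 1].append(b)
                    else:
                        bwd_queues[i].append(b)
                    fwd_counter[i] -= 1
                elif bwd_queues[i]:
                    b = bwd_queues[i].pop(0)
                    traces[i].append("B")
                    bwd_counter[i] += 1
                    if bwd_counter[i] > i:
                        traces[i].append("W")
                        if bwd_counter[i] == num_microbatches:
                            # Flush the i deferred W operations at the end.
                            traces[i].extend(["W"] * i)
                    if i > 0:
                        bwd_queues[i - 1].append(b)
                    else:
                        done.append(b)
                    fwd_counter[i] += 1
        return traces

    for trace in simulate_zbh1_schedule():
        print(trace)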
We can reproduce the bug by changing the `w` operation of DummyWorker to:

    def w(self, value):
        if self.trace[-1] == "W":
            raise RuntimeError
        ...
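
Merged into the `DummyWorker` class above, the modified method would read as follows (the extra `self.trace and` check is only a defensive guard for an empty trace; otherwise this is the snippet with the original method body filled back in):

    def w(self, value):
        # Raise as soon as a W op is recorded directly after another W op on this worker.
        if self.trace and self.trace[-1] == "W":
            raise RuntimeError
        self.trace.append("W")
        return None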
(Figure: the last couple of nodes in the ZBH1 DAG.)