@woshiyyya
Last active September 12, 2024 23:11
import ray
import ray.cluster_utils
from ray.experimental.channel.torch_tensor_type import TorchTensorType
from ray.dag import InputNode, MultiOutputNode
from typing import Optional
from ray.dag.compiled_dag_node import CompiledDAG
from argparse import ArgumentError, ArgumentParser
@ray.remote(num_cpus=0, num_gpus=1)
class DummyWorker:
    def __init__(self, rank: Optional[int] = None):
        self.rank = rank
        self.trace = []

    def fwd(self, value):
        # self.trace.append(("FWD", self.rank))
        self.trace.append("F")
        return value

    def bwd(self, value):
        # self.trace.append(("BWD", self.rank))
        self.trace.append("B")
        return value

    def w(self, value):
        # self.trace.append(("W", self.rank))
        self.trace.append("W")
        return None

    def echo(self, value):
        return value

    def pop_trace(self):
        trace = self.trace
        self.trace = []
        return trace

    def read_input(self, input):
        return input

    def no_op(self, value):
        return value

    def no_op_two(self, value1, value2):
        return value1, value2
import torchvision
import torchvision.transforms as transforms
import time
import torch
import torch.nn as nn
import torch.optim as optim
BATCH_SIZE = 4
# FEATURE_SIZE = 8192
FEATURE_SIZE = 24576
FORWARD_SHAPE = (BATCH_SIZE, FEATURE_SIZE)
BACKWARD_SHAPE = (BATCH_SIZE, FEATURE_SIZE)
def cifar_trainset(dl_path="/tmp/cifar10-data"):
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Lambda(lambda x: x.view(-1)),  # Flatten the tensor
        ]
    )
    trainset = torchvision.datasets.CIFAR10(
        root=dl_path, train=True, download=True, transform=transform
    )
    return trainset
@ray.remote(num_gpus=1)
class Worker:
    def __init__(self, n_features, pp_rank, pp_size, micro_batch_size, is_output):
        self.pp_rank = pp_rank
        self.rank = pp_rank
        self.pp_size = pp_size
        self.trace = []
        layers = []
        for i in range(1, len(n_features)):
            in_features, out_features = n_features[i - 1], n_features[i]
            layers.append(nn.Linear(in_features, out_features))
            if not is_output or i < len(n_features) - 1:
                layers.append(nn.ReLU(inplace=True))
        self.module = nn.Sequential(*layers).cuda()
        self.optimizer = optim.SGD(self.module.parameters(), lr=1e-3)
        self.loss = nn.CrossEntropyLoss()
        # if pp_rank == 0:
        self.initialize_dataloader(micro_batch_size=micro_batch_size)
        self.input_activations = dict()
        self.output_activations = dict()
        self.cached_gradients = dict()
        self.max_memory_allocated = 0
        self.max_memory_reserved = 0
        self.fwd_batch_id = 0
        self.bwd_batch_id = 0
        self.w_batch_id = 0

    def initialize_dataloader(self, micro_batch_size):
        trainset = cifar_trainset()
        sampler = torch.utils.data.distributed.DistributedSampler(
            trainset, num_replicas=1, rank=0, shuffle=False
        )
        self.trainloader = torch.utils.data.DataLoader(
            trainset, batch_size=micro_batch_size, shuffle=False, sampler=sampler, pin_memory=True
        )
        self.traindata_iter = iter(self.trainloader)
        # Cache a single micro-batch of inputs and targets on the GPU.
        input_activation, targets = next(self.traindata_iter)
        self.cache_input_activation = input_activation.cuda()
        self.cache_targets = targets.cuda()
    def fwd(self, fwd_inputs):
        self.trace.append("F")
        input_activation = fwd_inputs
        targets = self.cache_targets
        batch_id = self.fwd_batch_id
        self.fwd_batch_id += 1
        if self.pp_rank > 0:
            input_activation.requires_grad = True
            input_activation.retain_grad()
        # Fetch input batches from dataloader
        if self.pp_rank == 0:
            input_activation = self.cache_input_activation
        # Forward Pass
        self.input_activations[batch_id] = input_activation
        output_activation = self.module(input_activation)
        if self.pp_rank == self.pp_size - 1:
            loss = self.loss(output_activation, targets)
            self.output_activations[batch_id] = loss
            return None
        else:
            self.output_activations[batch_id] = output_activation
            return output_activation

    def bwd(self, bwd_inputs):
        self.trace.append("B")
        for param in self.module.parameters():
            param.requires_grad = False
        gradients = bwd_inputs
        batch_id = self.bwd_batch_id
        self.bwd_batch_id += 1
        self.cached_gradients[batch_id] = gradients
        # Backward Pass
        self.output_activations[batch_id].backward(gradients)
        # (Change #2) Using the line below instead will skip the last i-1 W operations
        # self.output_activations[batch_id].backward(gradients, retain_graph=True)
        bwd_gradients = self.input_activations[batch_id].grad
        # Clear cache to free GRAM
        input_activations = self.input_activations.pop(batch_id)
        # Reset the flags
        for param in self.module.parameters():
            param.requires_grad = True
        input_activations.requires_grad = False
        # Return None to avoid Actor-Driver Comm
        if self.pp_rank == 0:
            return None
        else:
            return bwd_gradients

    def w(self):
        # (Change #1) Use the correct function signature
        # def w(self, x):
        self.trace.append("W")
        batch_id = self.w_batch_id
        gradients = self.cached_gradients[batch_id]
        self.w_batch_id += 1
        self.output_activations[batch_id].backward(gradients)
        self.output_activations.pop(batch_id)
        self.cached_gradients.pop(batch_id)

    def pop_trace(self):
        trace = self.trace
        self.trace = []
        return trace

    def read_input(self, inp):
        # Placeholder: (batch_id, activations, targets)
        self.batch_id = 0
        return (None, None, None)

    def get_memory_logs(self):
        return [self.max_memory_allocated, self.max_memory_reserved]

    def get_events(self):
        events = getattr(self, "__ray_adag_events", [])
        return [event.to_dict() for event in events]

    def echo(self, value):
        return value
def generate_feature_dim(pp_size):
    input_size = 3 * 32 * 32
    feature_size = FEATURE_SIZE
    feature_dim = []
    feature_dim.append([input_size, feature_size, feature_size, feature_size])
    for _ in range(pp_size - 2):
        feature_dim.append([feature_size, feature_size, feature_size])
    feature_dim.append([feature_size, feature_size, feature_size, 10])
    return feature_dim
def generate_zbh1_dag(num_workers: int, num_microbatches: int, num_lead_microbatches: int, use_dummy: bool):
    if use_dummy:
        workers = [DummyWorker.remote() for _ in range(num_workers)]
    else:
        pp_size = num_workers
        num_lead_microbatches = num_workers
        feature_dim_list = generate_feature_dim(num_workers)
        workers = [
            Worker.remote(
                n_features, pp_rank, pp_size, BATCH_SIZE, bool(pp_rank == pp_size - 1)
            )
            for pp_rank, n_features in enumerate(feature_dim_list)
        ]

    with InputNode() as inp:
        fwd_queues = [[] for _ in range(num_workers)]
        bwd_queues = [[] for _ in range(num_workers)]
        # Once a worker's counter reaches 0, it cannot execute another fwd until it
        # executes a bwd first.
        fwd_counter = [num_lead_microbatches - i for i in range(num_workers)]
        bwd_counter = [0 for i in range(num_workers)]
        # All of the done batches.
        done = []
        # FWD on worker 0.
        input_data = workers[0].read_input.bind(inp)
        for i in range(num_microbatches):
            fwd_queues[0].append(input_data)
        while len(done) < num_microbatches:
            for i, worker in enumerate(workers):
                if fwd_counter[i] > 0 and fwd_queues[i]:
                    b = fwd_queues[i].pop(0)
                    b = worker.fwd.bind(b)
                    if i < num_workers - 1:
                        fwd_queues[i + 1].append(b)
                        # Use NCCL channel for communication between workers.
                        # b.with_type_hint(
                        #     TorchTensorType(transport=TorchTensorType.NCCL, _shape=BACKWARD_SHAPE, _dtype=torch.float32, _direct_return=True)
                        # )
                        b.with_type_hint(
                            TorchTensorType(transport=TorchTensorType.NCCL)
                        )
                    else:
                        bwd_queues[i].append(b)
                    fwd_counter[i] -= 1
                elif bwd_queues[i]:
                    b = bwd_queues[i].pop(0)
                    b2 = worker.bwd.bind(b)
                    # Code change for Zero Bubble PP
                    # ++++++++++++++++++++++++++++++++++++++++++++++++
                    bwd_counter[i] += 1
                    if bwd_counter[i] > i:
                        w2 = worker.w.bind(b2)
                        if bwd_counter[i] == num_microbatches:
                            for _ in range(i):
                                w2 = worker.w.bind(w2)
                    b2 = worker.echo.bind(b2)
                    # ++++++++++++++++++++++++++++++++++++++++++++++++
                    if i > 0:
                        bwd_queues[i - 1].append(b2)
                        # Use NCCL channel for communication between workers.
                        # b2.with_type_hint(
                        #     TorchTensorType(transport=TorchTensorType.NCCL, _shape=BACKWARD_SHAPE, _dtype=torch.float32, _direct_return=True)
                        # )
                        b2.with_type_hint(
                            TorchTensorType(transport=TorchTensorType.NCCL)
                        )
                    else:
                        done.append(b2)
                    fwd_counter[i] += 1
        dag = MultiOutputNode(done)
    compiled_dag = dag.experimental_compile()
    return compiled_dag, workers
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--use_dummy", action="store_true")
    args = parser.parse_args()

    dag, workers = generate_zbh1_dag(
        num_workers=4, num_lead_microbatches=4, num_microbatches=8, use_dummy=args.use_dummy
    )
    ray.get(dag.execute(1))

    print(f"Schedule of {'dummy' if args.use_dummy else 'normal'} workers:")
    for worker in workers:
        print(ray.get(worker.pop_trace.remote()))
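For reference, the traces below come from running the script with and without the `--use_dummy` flag, e.g. `python zbh1_dag.py --use_dummy` for the DummyWorker schedule and `python zbh1_dag.py` for the linear-layer Worker schedule; the filename `zbh1_dag.py` is only an assumed name for this gist, and both runs expect 4 GPUs.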
Schedule for Dummy Workers:
['F', 'F', 'F', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'B', 'W', 'B', 'W', 'B', 'W']
['F', 'F', 'F', 'B', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'B', 'W', 'B', 'W', 'W']
['F', 'F', 'B', 'F', 'B', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'B', 'W', 'W', 'W']
['F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'W', 'W', 'W']
Schedule of Linear Workers when using the mismatched function signature `def w(self):`
['F', 'F', 'F', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'B', 'B', 'B']
['F', 'F', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'B', 'B']
['F', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'B']
['F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'F', 'B']
(Change #1) Schedule of Linear Workers when using the matched function signature `def w(self, x):`
['F', 'F', 'F', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'B', 'W', 'B', 'W', 'B', 'W']
['F', 'F', 'F', 'B', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'B', 'W', 'B', 'W']
['F', 'F', 'B', 'F', 'B', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'B', 'W']
['F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W']
(Change #2) Schedule of Linear Workers when adding the `retain_graph=True` option in the `bwd` function.
['F', 'F', 'F', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'B', 'W', 'B', 'W', 'B', 'W']
['F', 'F', 'F', 'B', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'B', 'W', 'B', 'W', 'W']
['F', 'F', 'B', 'F', 'B', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'B', 'W', 'W', 'W']
['F', 'B', 'F', 'B', 'F', 'B', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'F', 'B', 'W', 'W', 'W', 'W']
I don't know why changing the logic in the `bwd` function affects the DAG execution schedule...
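For context, here is a minimal standalone PyTorch sketch (independent of this gist, relying only on stock autograd behavior) of what Change #2 toggles: both `bwd` and `w` call `.backward()` on the same cached output, and the second call only succeeds if the first one passed `retain_graph=True`.

import torch

x = torch.randn(4, 8, requires_grad=True)
out = (x * 2).sum()

# With retain_graph=True the autograd graph survives the first backward,
# so a second backward over the same output is allowed.
out.backward(retain_graph=True)
out.backward()

# Without retain_graph the graph is freed by the first backward,
# and the second call raises a RuntimeError.
out2 = (x * 3).sum()
out2.backward()
try:
    out2.backward()
except RuntimeError as err:
    print("second backward failed:", err)

One guess (not verified): without `retain_graph=True` the backward inside `w` hits this RuntimeError, so the chained `w` tasks that depend on a failed `w` never execute and their "W" entries are missing from the trace, which would line up with the RuntimeError-based reproduction in the comment below.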
@woshiyyya (Author):

[Figure: the last couple of nodes in the ZBH1 DAG.]

@woshiyyya (Author):

We can reproduce the bug by changing the `w` operation of `DummyWorker` to:

def w(self, value):
    if self.trace[-1] == "W":
        raise RuntimeError
    ...
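For completeness, a full version of that modified method might look like the sketch below; it assumes the elided `...` is simply the rest of the original `w` body, and the empty-trace guard and error message are additions for illustration.

def w(self, value):
    # Hypothetical completion of the snippet above: fail loudly whenever
    # two W operations run back to back on this worker.
    if self.trace and self.trace[-1] == "W":
        raise RuntimeError("two consecutive W operations")
    self.trace.append("W")
    return None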
