@woshiyyya
Last active October 9, 2024 03:48
Use aDAG (Ray accelerated DAG) to train a Llama-2-7B model with zero-bubble pipeline parallelism (ZB-H1).
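The script below defines three pieces: `WrappedLinear`, which splits each linear layer's backward pass into an input-gradient step (B) and a deferred weight-gradient step (W), which is what makes the zero-bubble schedule possible; `Worker`, a Ray actor that owns one pipeline stage of Llama-2-7B; and `generate_zbh1_dag`, which builds and compiles an accelerated DAG that interleaves F/B/W across stages and passes activations and gradients between GPUs over NCCL-typed channels. The `__main__` block runs 11 training steps and times the last 10 (the first step is treated as warmup).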
import time
import torch
import torch.nn as nn
from collections import defaultdict
from pytorch_lightning import seed_everything
from transformers import LlamaForCausalLM

import ray
import ray.cluster_utils
from ray.dag import InputNode, MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType

NUM_WORKERS = 8
NUM_MICROBATCHES = 16
BATCH_SIZE = 3
SEQ_LENGTH = 1024
HIDDEN_SIZE = 4096
VOCAB_SIZE = 32000
FORWARD_SHAPE = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
BACKWARD_SHAPE = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
class WrappedLinear(nn.Module):
    """Wraps an nn.Linear so the weight-gradient computation (W) can be
    deferred and run separately from the input-gradient backward pass (B)."""

    def __init__(self, linear, pp_rank, pp_size, num_batches=NUM_MICROBATCHES) -> None:
        super().__init__()
        self.linear = linear
        self.output_activations = []
        self.output_grads = []
        self.is_w = False
        self.pp_rank = pp_rank
        self.pp_size = pp_size
        self.num_batches = num_batches

    def forward(self, x):
        output = self.linear(x)
        output.register_hook(self.save_gradients)
        # Cache output activations for the W operation.
        self.output_activations.append(output)
        return output

    def w(self):
        # Weight-gradient step: replay backward through this linear layer only,
        # using the cached activation/gradient pair of the oldest microbatch.
        self.is_w = True
        output_activation = self.output_activations.pop(0)
        output_grad = self.output_grads.pop(0)
        output_activation.backward(output_grad, inputs=list(self.parameters()))
        self.is_w = False

    def save_gradients(self, grad):
        # Cache the output gradient during the B pass (but not during W).
        if not self.is_w:
            self.output_grads.append(grad)
        return grad
@ray.remote(num_gpus=1)
class Worker:
    def __init__(self, pp_rank, pp_size):
        seed_everything(420)
        self.pp_rank = pp_rank
        self.pp_size = pp_size
        self.trace = []
        self.timer = defaultdict(list)

        llama = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        # assert llama.config.num_hidden_layers % pp_size == 0
        layers_per_chunk = llama.config.num_hidden_layers // pp_size

        # Partition the llama model across pipeline stages.
        self.wrapped_linears = []
        if pp_rank == 0:
            self.embedding = llama.model.embed_tokens.cuda()
            for param in self.embedding.parameters():
                param.requires_grad = False
        if pp_rank == pp_size - 1:
            self.lm_head = llama.lm_head.cuda()
        self.layers = llama.model.layers[
            layers_per_chunk * pp_rank : layers_per_chunk * (pp_rank + 1)
        ]

        # Wrap linear layers for disaggregated BWD and W operations.
        self.num_replaced = 0
        for layer in self.layers:
            named_modules = dict(layer.named_modules())
            for name, module in named_modules.items():
                named_children = dict(module.named_children())
                for child_name, child in named_children.items():
                    if isinstance(child, nn.Linear):
                        wrapped_child = WrappedLinear(child, self.pp_rank, self.pp_size)
                        setattr(module, child_name, wrapped_child)
                        self.num_replaced += 1
                        self.wrapped_linears.append(wrapped_child)
            layer.cuda()

        self.loss_fct = nn.CrossEntropyLoss()
        self.input_activations = dict()
        self.output_activations = dict()
        self.fwd_batch_id = 0
        self.bwd_batch_id = 0
        self.w_batch_id = 0

        # Cached dummy inputs.
        self.cached_input_ids = torch.randint(
            0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH)
        ).cuda()
        self.cache_shift_labels = self.cached_input_ids[..., 1:].contiguous().view(-1)
        self.position_ids = torch.arange(0, SEQ_LENGTH).unsqueeze(0).cuda()

    def calculate_loss(self, input_ids, hidden_states):
        logits = self.lm_head(hidden_states)
        shift_logits = logits[..., :-1, :].contiguous()
        shift_logits = shift_logits.view(-1, VOCAB_SIZE)
        shift_labels = self.cache_shift_labels
        return self.loss_fct(shift_logits, shift_labels)

    def generate_input_tensor(self):
        return self.embedding(self.cached_input_ids).detach()

    def fwd(self, activations):
        self.trace.append("F")
        if self.pp_rank == 0:
            input_activations = self.generate_input_tensor()
        else:
            input_activations = activations
        input_activations.requires_grad = True
        input_activations.retain_grad()
        self.input_activations[self.fwd_batch_id] = input_activations
        for layer in self.layers:
            input_activations = layer(
                input_activations, position_ids=self.position_ids
            )[0]
        output_activation = input_activations
        if self.pp_rank == self.pp_size - 1:
            loss = self.calculate_loss(self.cached_input_ids, output_activation)
            self.output_activations[self.fwd_batch_id] = loss
            self.fwd_batch_id += 1
            return None
        else:
            self.output_activations[self.fwd_batch_id] = output_activation
            self.fwd_batch_id += 1
            return output_activation

    def bwd(self, gradient):
        self.trace.append("B")
        if self.pp_rank == self.pp_size - 1:
            gradient = None
        output_activation = self.output_activations.pop(self.bwd_batch_id)
        input_activation = self.input_activations.pop(self.bwd_batch_id)
        # Freeze the wrapped linears so this backward only computes the input
        # gradient (B); weight gradients are computed later in w().
        for linear in self.wrapped_linears:
            for param in linear.parameters():
                param.requires_grad = False
        input_activation.requires_grad = True
        output_activation.backward(
            gradient, retain_graph=True, inputs=[input_activation]
        )
        if self.pp_rank == 0:
            self.bwd_batch_id += 1
            return None
        else:
            self.bwd_batch_id += 1
            # assert input_activation.grad is not None
            return input_activation.grad

    def w(self, x):
        self.trace.append("W")
        # Unfreeze the wrapped linears and run their deferred weight-gradient steps.
        for linear in self.wrapped_linears:
            for param in linear.parameters():
                param.requires_grad = True
        for linear in self.wrapped_linears:
            linear.w()
        return None

    def pop_trace(self):
        trace = self.trace
        self.trace = []
        return trace

    def read_input(self, inp):
        # Placeholder: (batch_id, activations, targets)
        self.batch_id = 0
        return (None, None, None)

    def echo(self, value):
        return value
def generate_zbh1_dag(workers, num_microbatches):
    num_workers = len(workers)
    num_lead_microbatches = num_workers
    with InputNode() as inp:
        fwd_queues = [[] for _ in range(num_workers)]
        bwd_queues = [[] for _ in range(num_workers)]
        # Once a worker's counter reaches 0, it cannot execute another fwd until it
        # executes a bwd first.
        fwd_counter = [num_lead_microbatches - i for i in range(num_workers)]
        bwd_counter = [0 for _ in range(num_workers)]
        # All of the done batches.
        done = []
        # FWD on worker 0.
        input_data = workers[0].read_input.bind(inp)
        for i in range(num_microbatches):
            fwd_queues[0].append(input_data)
        while len(done) < num_microbatches:
            for i, worker in enumerate(workers):
                if fwd_counter[i] > 0 and fwd_queues[i]:
                    b = fwd_queues[i].pop(0)
                    b = worker.fwd.bind(b)
                    if i < num_workers - 1:
                        fwd_queues[i + 1].append(b)
                        # Use NCCL channel for communication between workers.
                        b.with_type_hint(
                            TorchTensorType(
                                transport=TorchTensorType.NCCL,
                                _shape=FORWARD_SHAPE,
                                _dtype=torch.float32,
                                _direct_return=True,
                            )
                        )
                    else:
                        bwd_queues[i].append(b)
                    fwd_counter[i] -= 1
                elif bwd_queues[i]:
                    b = bwd_queues[i].pop(0)
                    b = worker.bwd.bind(b)
                    # Code changes for Zero Bubble PP.
                    # It becomes 1F1B after removing the code block below.
                    # ++++++++++++++++++++++++++++++++++++++++++++++++
                    bwd_counter[i] += 1
                    w_offset = i
                    if bwd_counter[i] > w_offset:
                        # Schedule the deferred weight-gradient (W) step; echo
                        # forwards the BWD output downstream.
                        echo_b = worker.echo.bind(b)
                        w = worker.w.bind(b)
                        if bwd_counter[i] == num_microbatches:
                            # Drain the remaining deferred W steps for this stage.
                            for _ in range(w_offset):
                                w = worker.w.bind(w)
                    else:
                        echo_b = None
                    if echo_b:
                        b = echo_b
                    # ++++++++++++++++++++++++++++++++++++++++++++++++
                    if i > 0:
                        bwd_queues[i - 1].append(b)
                        # Use NCCL channel for communication between workers.
                        b.with_type_hint(
                            TorchTensorType(
                                transport=TorchTensorType.NCCL,
                                _shape=BACKWARD_SHAPE,
                                _dtype=torch.float32,
                                _direct_return=True,
                            )
                        )
                    else:
                        done.append(b)
                    fwd_counter[i] += 1
        dag = MultiOutputNode(done)
    compiled_dag = dag.experimental_compile()
    return compiled_dag
if __name__ == "__main__":
    workers = [Worker.remote(pp_rank, NUM_WORKERS) for pp_rank in range(NUM_WORKERS)]
    dag = generate_zbh1_dag(workers=workers, num_microbatches=NUM_MICROBATCHES)

    for i in range(11):
        if i == 1:
            # The first step is warmup; start timing from the second step.
            s = time.perf_counter()
        print(f"Step {i} finished:", ray.get(dag.execute(1)))
    e = time.perf_counter()
    print(f"Total Training Time: {e - s}s")