Use aDAG (Ray's accelerated DAG, a.k.a. Ray Compiled Graphs) to train a Llama-2-7B model with a zero-bubble (ZB-H1) pipeline-parallel schedule.
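Zero-bubble pipeline parallelism (ZB-H1) shrinks the pipeline bubble by splitting each backward pass into two parts: B, which computes only the gradient with respect to the stage's input activation (the part the previous stage is waiting for), and W, which computes the weight gradients and can be deferred into otherwise idle slots. As a minimal, standalone sketch (not part of the gist, using a toy nn.Linear) of how torch.autograd supports this split:

import torch
import torch.nn as nn

layer = nn.Linear(8, 8)
x = torch.randn(4, 8, requires_grad=True)
loss = layer(x).sum()

# B: gradient w.r.t. the input only; keep the graph alive for the later W pass.
loss.backward(retain_graph=True, inputs=[x])
assert x.grad is not None and layer.weight.grad is None

# W: weight gradients, computed later from the retained graph.
loss.backward(inputs=list(layer.parameters()))
assert layer.weight.grad is not None

The full training script below wires this split into a Ray compiled DAG: Worker.bwd is the B step and Worker.w is the deferred W step.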
import time
from collections import defaultdict

import torch
import torch.nn as nn
from pytorch_lightning import seed_everything
from transformers import LlamaForCausalLM

import ray
import ray.cluster_utils
from ray.dag import InputNode, MultiOutputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType

NUM_WORKERS = 8
NUM_MICROBATCHES = 16
BATCH_SIZE = 3
SEQ_LENGTH = 1024
HIDDEN_SIZE = 4096
VOCAB_SIZE = 32000

FORWARD_SHAPE = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
BACKWARD_SHAPE = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
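# The shapes above describe the activation/gradient tensors exchanged between
# pipeline stages over NCCL channels; NUM_WORKERS = 8 assumes a node (or
# cluster) with 8 GPUs, one pipeline stage per GPU.
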
class WrappedLinear(nn.Module):
    def __init__(self, linear, pp_rank, pp_size, num_batches=NUM_MICROBATCHES) -> None:
        super().__init__()
        self.linear = linear
        self.output_activations = []
        self.output_grads = []
        self.is_w = False
        self.pp_rank = pp_rank
        self.pp_size = pp_size
        self.num_batches = num_batches

    def forward(self, x):
        output = self.linear(x)
        output.register_hook(self.save_gradients)
        # Cache output activations for W operation
        self.output_activations.append(output)
        return output

    def w(self):
        self.is_w = True
        output_activation = self.output_activations.pop(0)
        output_grad = self.output_grads.pop(0)
        output_activation.backward(output_grad, inputs=list(self.parameters()))
        self.is_w = False

    def save_gradients(self, grad):
        # Cache the output gradient when doing backward
        if not self.is_w:
            self.output_grads.append(grad)
        return grad
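
# How the B/W split works for each wrapped nn.Linear: during the B step
# (Worker.bwd) gradients are accumulated only into the stage's input
# activation, while the registered hook caches each wrapped layer's incoming
# output gradient. The deferred W step (Worker.w -> WrappedLinear.w) replays
# backward() on the cached activation/gradient pair, accumulating into the
# parameters only.
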
@ray.remote(num_gpus=1)
class Worker:
    def __init__(self, pp_rank, pp_size):
        seed_everything(420)
        self.pp_rank = pp_rank
        self.pp_size = pp_size
        self.trace = []
        self.timer = defaultdict(list)

        llama = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        # assert llama.config.num_hidden_layers % pp_size == 0
        layers_per_chunk = llama.config.num_hidden_layers // pp_size

        # Partition llama model
        self.wrapped_linears = []
        if pp_rank == 0:
            self.embedding = llama.model.embed_tokens.cuda()
            for param in self.embedding.parameters():
                param.requires_grad = False
        if pp_rank == pp_size - 1:
            self.lm_head = llama.lm_head.cuda()
        self.layers = llama.model.layers[
            layers_per_chunk * pp_rank : layers_per_chunk * (pp_rank + 1)
        ]

        # Wrap linear layers for disaggregated BWD and W operations.
        self.num_replaced = 0
        for layer in self.layers:
            named_modules = dict(layer.named_modules())
            for name, module in named_modules.items():
                named_children = dict(module.named_children())
                for child_name, child in named_children.items():
                    if isinstance(child, nn.Linear):
                        wrapped_child = WrappedLinear(child, self.pp_rank, self.pp_size)
                        setattr(module, child_name, wrapped_child)
                        self.num_replaced += 1
                        self.wrapped_linears.append(wrapped_child)
            layer.cuda()

        self.loss_fct = nn.CrossEntropyLoss()

        self.input_activations = dict()
        self.output_activations = dict()

        self.fwd_batch_id = 0
        self.bwd_batch_id = 0
        self.w_batch_id = 0

        # Cached dummy inputs
        self.cached_input_ids = torch.randint(
            0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LENGTH)
        ).cuda()
        self.cache_shift_labels = self.cached_input_ids[..., 1:].contiguous().view(-1)
        self.position_ids = torch.arange(0, SEQ_LENGTH).unsqueeze(0).cuda()

    def calculate_loss(self, input_ids, hidden_states):
        logits = self.lm_head(hidden_states)
        shift_logits = logits[..., :-1, :].contiguous()
        shift_logits = shift_logits.view(-1, VOCAB_SIZE)
        shift_labels = self.cache_shift_labels
        return self.loss_fct(shift_logits, shift_labels)

    def generate_input_tensor(self):
        return self.embedding(self.cached_input_ids).detach()

    def fwd(self, activations):
        self.trace.append("F")
        if self.pp_rank == 0:
            input_activations = self.generate_input_tensor()
        else:
            input_activations = activations
        input_activations.requires_grad = True
        input_activations.retain_grad()
        self.input_activations[self.fwd_batch_id] = input_activations
        for layer in self.layers:
            input_activations = layer(
                input_activations, position_ids=self.position_ids
            )[0]
        output_activation = input_activations
        if self.pp_rank == self.pp_size - 1:
            loss = self.calculate_loss(self.cached_input_ids, output_activation)
            self.output_activations[self.fwd_batch_id] = loss
            self.fwd_batch_id += 1
            return None
        else:
            self.output_activations[self.fwd_batch_id] = output_activation
            self.fwd_batch_id += 1
            return output_activation
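
    # bwd() is the B step: it accumulates gradients only into this stage's
    # input activation (inputs=[input_activation]) so the upstream stage can
    # proceed, while the WrappedLinear hooks cache the per-layer output
    # gradients. w() is the deferred W step: it re-enables the parameters and
    # replays the cached activation/gradient pairs to produce weight gradients.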
    def bwd(self, gradient):
        self.trace.append("B")
        if self.pp_rank == self.pp_size - 1:
            gradient = None
        output_activation = self.output_activations.pop(self.bwd_batch_id)
        input_activation = self.input_activations.pop(self.bwd_batch_id)
        for linear in self.wrapped_linears:
            for param in linear.parameters():
                param.requires_grad = False
        input_activation.requires_grad = True
        output_activation.backward(
            gradient, retain_graph=True, inputs=[input_activation]
        )
        if self.pp_rank == 0:
            self.bwd_batch_id += 1
            return None
        else:
            self.bwd_batch_id += 1
            # assert input_activation.grad is not None
            return input_activation.grad

    def w(self, x):
        self.trace.append("W")
        for linear in self.wrapped_linears:
            for param in linear.parameters():
                param.requires_grad = True
        for linear in self.wrapped_linears:
            linear.w()
        return None

    def pop_trace(self):
        trace = self.trace
        self.trace = []
        return trace

    def read_input(self, inp):
        # Placeholder: (batch_id, activations, targets)
        self.batch_id = 0
        return (None, None, None)

    def echo(self, value):
        return value
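
# ZB-H1 schedule sketch: each stage i starts with (num_workers - i) warm-up
# forwards, then alternates F and B in 1F1B fashion. The W step for a
# microbatch is issued only once bwd_counter[i] exceeds the stage's offset i,
# which defers weight-gradient work into slots that would otherwise be bubble
# time; the remaining W steps are flushed after the last backward.
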
def generate_zbh1_dag(workers, num_microbatches):
    num_workers = len(workers)
    num_lead_microbatches = num_workers
    with InputNode() as inp:
        fwd_queues = [[] for _ in range(num_workers)]
        bwd_queues = [[] for _ in range(num_workers)]
        # Once a worker's counter reaches 0, it cannot execute another fwd until it
        # executes a bwd first.
        fwd_counter = [num_lead_microbatches - i for i in range(num_workers)]
        bwd_counter = [0 for _ in range(num_workers)]
        # All of the done batches.
        done = []
        # FWD on worker 0.
        input_data = workers[0].read_input.bind(inp)
        for i in range(num_microbatches):
            fwd_queues[0].append(input_data)
        while len(done) < num_microbatches:
            for i, worker in enumerate(workers):
                if fwd_counter[i] > 0 and fwd_queues[i]:
                    b = fwd_queues[i].pop(0)
                    b = worker.fwd.bind(b)
                    if i < num_workers - 1:
                        fwd_queues[i + 1].append(b)
                        # Use NCCL channel for communication between workers.
                        b.with_type_hint(
                            TorchTensorType(
                                transport=TorchTensorType.NCCL,
                                _shape=FORWARD_SHAPE,
                                _dtype=torch.float32,
                                _direct_return=True,
                            )
                        )
                    else:
                        bwd_queues[i].append(b)
                    fwd_counter[i] -= 1
                elif bwd_queues[i]:
                    b = bwd_queues[i].pop(0)
                    b = worker.bwd.bind(b)
                    # Code changes for Zero Bubble PP
                    # It becomes 1F1B after removing the code block below.
                    # ++++++++++++++++++++++++++++++++++++++++++++++++
                    bwd_counter[i] += 1
                    w_offset = i
                    if bwd_counter[i] > w_offset:
                        echo_b = worker.echo.bind(b)
                        w = worker.w.bind(b)
                        if bwd_counter[i] == num_microbatches:
                            for _ in range(w_offset):
                                w = worker.w.bind(w)
                    else:
                        echo_b = None
                    if echo_b:
                        b = echo_b
                    # ++++++++++++++++++++++++++++++++++++++++++++++++
                    if i > 0:
                        bwd_queues[i - 1].append(b)
                        # Use NCCL channel for communication between workers.
                        b.with_type_hint(
                            TorchTensorType(
                                transport=TorchTensorType.NCCL,
                                _shape=BACKWARD_SHAPE,
                                _dtype=torch.float32,
                                _direct_return=True,
                            )
                        )
                    else:
                        done.append(b)
                    fwd_counter[i] += 1
        dag = MultiOutputNode(done)
    compiled_dag = dag.experimental_compile()
    return compiled_dag
if __name__ == "__main__":
    workers = [Worker.remote(pp_rank, NUM_WORKERS) for pp_rank in range(NUM_WORKERS)]
    dag = generate_zbh1_dag(workers=workers, num_microbatches=NUM_MICROBATCHES)
    for i in range(11):
        if i == 1:
            # first step warmup
            s = time.perf_counter()
        print(f"Step {i} finished:", ray.get(dag.execute(1)))
    e = time.perf_counter()
    print(f"Total Training Time: {e - s}s")