from functools import partial
import types
from typing import List, Optional, Tuple, Union, Dict

import torch
import transformers
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import logging as hf_logging

logger = hf_logging.get_logger(__name__)

"""
make the llama model run in model parallel (pipeline parallel) mode across multiple devices
"""
def llama_model_parallel_forward(
    self,
    layer2device: Dict[int, torch.device],
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")
    seq_length_with_past = seq_length
    past_key_values_length = 0
    if past_key_values is not None:
        past_key_values_length = past_key_values[0][0].shape[2]
        seq_length_with_past = seq_length_with_past + past_key_values_length

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
        )
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    else:
        position_ids = position_ids.view(-1, seq_length).long()

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    # embed positions
    if attention_mask is None:
        attention_mask = torch.ones(
            (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
        )
    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )

    hidden_states = inputs_embeds

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        # move this layer's inputs to the device the layer was placed on
        hidden_states = hidden_states.to(layer2device[idx])
        attention_mask = attention_mask.to(layer2device[idx])
        position_ids = position_ids.to(layer2device[idx])
        if past_key_value is not None:
            # past_key_value is a (key, value) tuple of tensors, so move each tensor individually
            # (a plain tuple has no .to() method)
            past_key_value = tuple(t.to(layer2device[idx]) for t in past_key_value)

        if self.gradient_checkpointing and self.training:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, output_attentions, None)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(decoder_layer),
                hidden_states,
                attention_mask,
                position_ids,
                None,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    # move the final hidden states back to the first device, where the norm
    # (and, for LlamaForCausalLM, the lm head) are expected to live
    hidden_states = hidden_states.to(layer2device[0])
    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )


def make_model_parallel_llama(model: transformers.LlamaModel, devices: List[torch.device]):
    num_layers = len(model.layers)
    num_devices = len(devices)
    # assign contiguous chunks of decoder layers to each device
    layer2device = {
        n.item(): devices[i]
        for i, device_layers in enumerate(torch.arange(0, num_layers).chunk(num_devices))
        for n in device_layers
    }
    for i, layer in enumerate(model.layers):
        layer.to(layer2device[i])
        torch.cuda.empty_cache()  # clear cache to free memory after moving each layer (useful if the model is on another gpu device already)
    # bind the model-parallel forward, with the layer-to-device mapping baked in, to this model instance
    model.forward = types.MethodType(partial(llama_model_parallel_forward, layer2device=layer2device), model)
    return model
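
# Illustrative example (not from the original gist): for a 32-layer model split across
# two GPUs, torch.arange(32).chunk(2) yields two contiguous halves, so layer2device
# maps layers 0-15 to devices[0] and layers 16-31 to devices[1].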

# model = transformers.LlamaForCausalLM.from_pretrained(
#     model_args.model_name_or_path,
#     torch_dtype=torch.bfloat16 if training_args.bf16 else torch.float32,
# )
# model.train()
# print(model.dtype)
# devices = [torch.device("cuda:0"), torch.device("cuda:1")]
# # test model on 1 gpu
# model.to(devices[0])
# with torch.inference_mode():
#     orig_outputs = model(torch.ones((1, 128), dtype=torch.long).to(devices[0])).logits.cpu()
# torch.cuda.empty_cache()
# # test model on 2 gpus
# from model_parallel_llama import make_model_parallel_llama
# model.model = make_model_parallel_llama(model.model, devices)
# with torch.inference_mode():
#     mp_forward_outputs = model(torch.ones((1, 128), dtype=torch.long).to(devices[0])).logits.cpu()
# torch.cuda.empty_cache()
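# # Sketch (not in the original gist): sanity-check that the single-GPU and model-parallel
# # runs agree. Assumes `orig_outputs` and `mp_forward_outputs` from the snippets above;
# # the tolerances are illustrative and allow for bf16 rounding differences across devices.
# print(torch.allclose(orig_outputs, mp_forward_outputs, atol=1e-3, rtol=1e-3))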