import torch
import time
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from argparse import ArgumentParser
from transformers import LlamaForCausalLM, LlamaTokenizer
from inference import CaiInferenceConfig, convert_to_ds_model, recover_from_ds_model
from torch.profiler import profile, record_function, ProfilerActivity
from types import MethodType
from typing import Optional, Sequence, Tuple, Union
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv, LlamaAttention, LlamaModel, LlamaForCausalLM
from einops import rearrange
from colossalai.logging import get_dist_logger
from inference.policy.attention_helper.llama_attention import _forward_v1
from inference.policy.attention_helper.llama2_attention import _forward_v2

logger = get_dist_logger()
parser = ArgumentParser()
parser.add_argument("--name", default="/data3/users/lcjt/projs/chinese-llama2/pretrained/llama/llama-2-7b-hf", type=str, help="model name or path")
parser.add_argument("--batch_size", default=1, type=int, help="batch size")
parser.add_argument("--dtype", default="float16", type=str, choices=["float32", "float16", "int8"], help="data type")
parser.add_argument("--max_tokens", default=2048, type=int, help="maximum tokens used for the text-generation KV-cache")
parser.add_argument("--max_new_tokens", default=128, type=int, help="maximum new tokens to generate")
parser.add_argument("--greedy", default=False, type=bool, help="greedy generation mode")
parser.add_argument("--use_cache", default=True, type=bool, help="use cache for generation")
parser.add_argument("--test_performance", default=True, type=bool, help="enable latency, bandwidth, and throughput testing")
parser.add_argument("--local_rank", type=int, default=0, help="local rank")
parser.add_argument("--kernel_type", type=str, default="triton", choices=["torch", "ds", "triton"], help="kernel implementation")
args = parser.parse_args()
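# Note (assumption about the intended CLI behavior): argparse's `type=bool` converts any
# non-empty string to True, so `--greedy False` still evaluates to True. If these boolean
# flags are meant to be toggled from the command line, a small converter such as the
# hypothetical `str2bool` below is a common workaround:
#
#   def str2bool(v: str) -> bool:
#       return v.lower() in ("1", "true", "yes")
#   # parser.add_argument("--greedy", default=False, type=str2bool, help="greedy generation mode")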
# Use the Llama-2 flash-attention forward as the drop-in replacement for LlamaAttention.forward.
_llama_flash_attention_forward = _forward_v2
def self_defined_tokens(tokenizer):
    text = "how is weather today? I want to know the weather of beijing. "
    inputs = [text]
    input_tokens = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")
    return input_tokens
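# Illustration only (shapes are assumptions; actual ids depend on the tokenizer vocabulary):
# `batch_encode_plus(..., return_tensors="pt")` returns a dict-like BatchEncoding, e.g.
#   {"input_ids": LongTensor of shape (batch, seq_len),
#    "attention_mask": LongTensor of shape (batch, seq_len)}
# which is why the benchmark below can iterate over its keys and move each tensor to GPU.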
def print_perf_stats(latency_set, config, warmup=3):
    # trim warmup queries
    latency_set = list(latency_set)
    latency_set = latency_set[warmup:]
    count = len(latency_set)

    if count > 0:
        latency_set.sort()
        avg = sum(latency_set) / count
        num_layers = getattr(config, "num_layers", config.num_hidden_layers)
        # Rough per-layer parameter estimate for a transformer block: ~12 * hidden_size^2.
        num_parameters = num_layers * config.hidden_size * config.hidden_size * 12
        if args.dtype == "float16":
            num_bytes = 2
        elif args.dtype == "float32":
            num_bytes = 4
        else:
            num_bytes = 1

        print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
        print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))
        print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * num_bytes * args.batch_size / 1e12))
        print("Avg Throughput: tokens/s: {}".format(1000 / (avg * 1000)))
def _prepare_decoder_flash_attention_mask(self: LlamaModel,
                                          attention_mask: torch.Tensor,
                                          input_shape: Union[torch.Size, Sequence[int]],
                                          inputs_embeds: torch.Tensor,
                                          past_key_values_length: int = 0) -> torch.Tensor:
    """
    Prepare the attention mask for a decoder-only LLM (e.g., Llama) when using flash-attn.

    The default `LlamaModel._prepare_decoder_attention_mask` expands the 2D padding mask
    into a 4D additive causal mask; the replacement flash-attention forward applies causal
    masking itself, so the 2D mask is returned unchanged here.

    Args:
        attention_mask (`torch.Tensor`):
            A tensor of shape (bsz, max_len) representing the mini-batch 2D padding mask.
        input_shape (`torch.Size` or `Sequence[int]`):
            Shape of the input token ids, (bsz, max_len); unused here.
        inputs_embeds (`torch.Tensor`):
            Input embeddings, used by the original implementation for dtype/device; unused here.
        past_key_values_length (`int`):
            Length of the cached key/value states; unused here.

    Returns:
        (`torch.Tensor`):
            The unmodified mask tensor of shape (bsz, max_len).
    """
    return attention_mask
def replace_flash_attention_for_llama(model: Union[torch.nn.Module, LlamaForCausalLM]) -> None:
    for module in model.modules():
        if isinstance(module, LlamaAttention):
            module.forward = MethodType(_llama_flash_attention_forward, module)
            logger.info("Replace `LlamaAttention.forward` method.")
        if isinstance(module, LlamaModel):
            # Replace the attention mask computation.
            module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_flash_attention_mask, module)
            logger.info("Replace `LlamaModel._prepare_decoder_attention_mask` method.")
def test(use_self_defined_input=False):
    tokenizer = LlamaTokenizer.from_pretrained(args.name)
    tokenizer.pad_token_id = tokenizer.unk_token_id
    model = LlamaForCausalLM.from_pretrained(args.name, pad_token_id=tokenizer.eos_token_id)
    model = model.half()
    print("model config: ", model.config)

    replace_flash_attention_for_llama(model)
    model.to(torch.cuda.current_device())

    if use_self_defined_input is False:
        input_tokens = {"input_ids": torch.randint(1, 1000, (1, 1024))}
    else:
        input_tokens = self_defined_tokens(tokenizer)

    input_len = 0
    for t in input_tokens:
        if torch.is_tensor(input_tokens[t]):
            input_tokens[t] = input_tokens[t].to(torch.cuda.current_device())
            # print(input_tokens[t].shape)
            input_len = input_tokens[t].shape[1]

    iters = 10 if args.test_performance else 2  # warmup
    print("input token length is " + str(input_len))
    times = []
    warmup = 3
    prof_flag = 0
    generate_kwargs = dict(max_new_tokens=args.max_new_tokens, do_sample=False)

    for i in range(iters):
        if i >= warmup:
            prof_flag = 1
        torch.cuda.synchronize()
        start = time.time()
        outputs = model.generate(**input_tokens,
                                 **generate_kwargs, early_stopping=False)
        torch.cuda.synchronize()
        end = time.time()
        num_tokens_generation = outputs.shape[1] - input_len
        print(num_tokens_generation)
        print(f"generation time is {(end - start) * 1000} ms")
        time_spend = (end - start) / num_tokens_generation
        times.append(time_spend)

    print("outputs shape ", outputs.shape)
    outputs = tokenizer.batch_decode(outputs)
    print(outputs)
    if args.local_rank == 0:
        if args.test_performance:
            print_perf_stats(times, model.config)

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                 record_shapes=True) as prof:
        with record_function("model_inference"):
            torch.cuda.synchronize()
            outputs = model.generate(**input_tokens,
                                     **generate_kwargs)
            torch.cuda.synchronize()
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
if __name__ == "__main__":
    test(False)
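# Example invocation (hypothetical script name and model path; adjust to your setup):
#   python llama_flash_attn_benchmark.py \
#       --name /path/to/llama-2-7b-hf \
#       --batch_size 1 --dtype float16 --max_new_tokens 128 --kernel_type triton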