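"""
Benchmark script for LLaMA text generation.

Loads a Hugging Face LLaMA checkpoint, optionally converts it to a fused-kernel
implementation ("ds" or "triton") via the project-local `inference` module, runs
a fixed 1024-token prompt through `model.generate`, and reports per-token
latency, bandwidth, FLOPs, and throughput.
"""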
import torch
import time
from transformers import LlamaForCausalLM, LlamaTokenizer
from argparse import ArgumentParser
from inference import CaiInferenceConfig, convert_to_ds_model, recover_from_ds_model
parser = ArgumentParser()
parser.add_argument("--name", default="/data/scratch/llama-7b-hf", type=str, help="model name or path")
parser.add_argument("--batch_size", default=1, type=int, help="batch size")
parser.add_argument("--dtype", default="float16", type=str, choices=["float32", "float16", "int8"], help="data type")
parser.add_argument("--max_tokens", default=2048, type=int, help="maximum tokens used for the text-generation KV-cache")
parser.add_argument("--max_new_tokens", default=128, type=int, help="maximum new tokens to generate")
# NOTE: argparse's type=bool treats any non-empty string as True, so the three
# flags below only behave as expected when left at their defaults.
parser.add_argument("--greedy", default=False, type=bool, help="greedy generation mode")
parser.add_argument("--use_cache", default=True, type=bool, help="use cache for generation")
parser.add_argument("--test_performance", default=True, type=bool, help="enable latency, bandwidth, and throughput testing")
parser.add_argument("--local_rank", type=int, default=0, help="local rank")
parser.add_argument("--kernel_type", type=str, default="triton", choices=["torch", "ds", "triton"], help="kernel implementation")
args = parser.parse_args()
# torch.distributed.init_process_group(backend="nccl")
# local_rank = int(os.environ["LOCAL_RANK"])
def print_perf_stats(latency_set, config, warmup=3):
    # trim warmup queries
    latency_set = list(latency_set)
    latency_set = latency_set[warmup:]
    count = len(latency_set)

    if count > 0:
        latency_set.sort()
        avg = sum(latency_set) / count

        num_layers = getattr(config, "num_layers", config.num_hidden_layers)
        # rough parameter-count estimate: ~12 * hidden_size^2 weights per transformer layer
        num_parameters = num_layers * config.hidden_size * config.hidden_size * 12
        if args.dtype == "float16":
            num_bytes = 2
        elif args.dtype == "float32":
            num_bytes = 4
        else:
            num_bytes = 1

        print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
        print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))
        # ~2 FLOPs (multiply + add) per parameter per generated token
        print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * num_parameters * 2 * args.batch_size / 1e12))
        print("Avg Throughput: {0:8.2f} tokens/s".format(1 / avg))
# torch.cuda.set_device(0)
tokenizer = LlamaTokenizer.from_pretrained(args.name)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# tokenizer.pad_token = tokenizer.eos_token

model = LlamaForCausalLM.from_pretrained(args.name, pad_token_id=tokenizer.eos_token_id)
model = model.half()
print("model config: ", model.config)
if args.kernel_type in ["ds", "triton"]:
    # convert the HF LLaMA model to the fused-kernel implementation provided by
    # the project-local `inference` module
    cai_inf_config = CaiInferenceConfig(fp16=True, device=torch.cuda.current_device())
    if args.kernel_type == "triton":
        cai_inf_config.use_triton = True
    model = convert_to_ds_model(model, cai_inf_config)

model.to(torch.cuda.current_device())
def create_text():
    # build a dummy 1024-word prompt
    text = ""
    for i in range(1024):
        text += "a" + " "
    return text

text = create_text()
inputs = [text]
input_tokens = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")
# override the tokenized prompt with 1024 random token ids
input_tokens = {"input_ids": torch.randint(1, 1000, (1, 1024))}
# input_tokens={"input_ids":torch.randint(1, 1000, (1, 1024)), "attention_mask":torch.ones(1, 1024)} | |
for t in input_tokens: | |
if torch.is_tensor(input_tokens[t]): | |
input_tokens[t] = input_tokens[t].to(torch.cuda.current_device()) | |
# print(input_tokens[t].shape) | |
iters = 10 if args.test_performance else 2  # extra iterations for warmup
times = []
warmup = 3
prof_flag = 0
generate_kwargs = dict(max_new_tokens=args.max_new_tokens, do_sample=False)

for i in range(iters):
    if i >= warmup:
        prof_flag = 1
    torch.cuda.synchronize()
    start = time.time()
    outputs = model.generate(**input_tokens, **generate_kwargs)
    torch.cuda.synchronize()
    end = time.time()
    times.append(end - start)
print(f"generation time is {times[1]} sec") | |
print("outputs shape ", outputs.shape) | |
outputs=tokenizer.batch_decode(outputs) | |
if args.local_rank == 0: | |
# for i, o in zip(inputs, outputs): | |
# print(f"\nin={i}\nout={o}\n{'-'*60}") | |
if args.test_performance: | |
print_perf_stats(map(lambda t: t / args.max_new_tokens, times), model.config) |
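# Example invocation (the script filename and model path below are illustrative):
#   python llama_infer_bench.py --name /data/scratch/llama-7b-hf --kernel_type triton --max_new_tokens 128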