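# Benchmark script for LLaMA text generation: load a LLaMA checkpoint, optionally
# convert it to a DeepSpeed/Triton kernel implementation via the local `inference`
# module, run timed generate() calls on a fixed 1024-token input, and report
# per-token latency, bandwidth, FLOPs, and throughput estimates.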
import torch
import time
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from argparse import ArgumentParser
from transformers import LlamaForCausalLM, LlamaTokenizer
from inference import CaiInferenceConfig, convert_to_ds_model, recover_from_ds_model
parser = ArgumentParser()
parser.add_argument("--name", default="/data/scratch/llama-7b-hf", type=str, help="model_name")
parser.add_argument("--batch_size", default=1, type=int, help="batch size")
parser.add_argument("--dtype", default="float16", type=str, choices=["float32", "float16", "int8"], help="data-type")
parser.add_argument("--max_tokens", default=2048, type=int, help="maximum tokens used for the text-generation KV-cache")
parser.add_argument("--max_new_tokens", default=128, type=int, help="maximum new tokens to generate")
parser.add_argument("--greedy", default=False, type=bool, help="greedy generation mode")
parser.add_argument("--use_cache", default=True, type=bool, help="use cache for generation")
parser.add_argument("--test_performance", default=True, type=bool , help="enable latency, bandwidth, and throughout testing")
parser.add_argument("--local_rank", type=int, default=0, help="local rank")
parser.add_argument("--kernel_type", type=str, default="triton", choices=["torch", "ds", "triton"], help="kernel implementation")
args = parser.parse_args()
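# NOTE: argparse's type=bool converts any non-empty string to True, so flags like
# "--greedy False" still evaluate to True; the defaults above are the reliable way
# to set these options.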
# torch.distributed.init_process_group(backend="nccl")
# local_rank = int(os.environ["LOCAL_RANK"])
def print_perf_stats(latency_set, config, warmup=3):
    # trim warmup queries
    latency_set = list(latency_set)
    latency_set = latency_set[warmup:]
    count = len(latency_set)

    if count > 0:
        latency_set.sort()
        avg = sum(latency_set) / count
        num_layers = getattr(config, "num_layers", config.num_hidden_layers)
        # Rough per-model parameter count: ~12 * hidden_size^2 per transformer block
        # (attention + MLP weights), ignoring embeddings.
        num_parameters = num_layers * config.hidden_size * config.hidden_size * 12
        if args.dtype == "float16":
            num_bytes = 2
        elif args.dtype == "float32":
            num_bytes = 4
        else:
            num_bytes = 1

        print("Avg Per Token Latency: {0:8.2f} ms".format(avg * 1000))
        # Weight-streaming bandwidth estimate: every parameter is read once per token.
        print("Avg BW: {0:8.2f} GB/s".format(1 / avg * num_parameters * num_bytes / 1e9))
        # ~2 FLOPs per weight per generated token.
        print("Avg flops: {0:8.2f} TFlops/s".format(1 / avg * 2 * num_parameters * args.batch_size / 1e12))
        print("Avg Throughput: tokens/s: {0:8.2f}".format(args.batch_size / avg))
# torch.cuda.set_device(0)
tokenizer = LlamaTokenizer.from_pretrained(args.name)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# tokenizer.pad_token = tokenizer.eos_token
model = LlamaForCausalLM.from_pretrained(args.name, pad_token_id=tokenizer.eos_token_id)
model = model.half()
print("model config: ", model.config)
if args.kernel_type in ["ds", "triton"]:
    cai_inf_config = CaiInferenceConfig(fp16=True, device=torch.cuda.current_device())
    if args.kernel_type == "triton":
        cai_inf_config.use_triton = True
    model = convert_to_ds_model(model, cai_inf_config)

model.to(torch.cuda.current_device())
def create_text():
    # Build a dummy prompt of 1024 repeated "a " tokens.
    text = ""
    for i in range(1024):
        text += "a "
    return text
text = create_text()
inputs = [text]
input_tokens = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")
input_tokens = {"input_ids": torch.randint(1, 1000, (1, 1024))}
# input_tokens={"input_ids":torch.randint(1, 1000, (1, 1024)), "attention_mask":torch.ones(1, 1024)}
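# NOTE: the tokenized prompt above is immediately replaced by a random batch of
# 1024 token ids (no attention mask), so the benchmark input length is fixed at 1024.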
for t in input_tokens:
    if torch.is_tensor(input_tokens[t]):
        input_tokens[t] = input_tokens[t].to(torch.cuda.current_device())
        # print(input_tokens[t].shape)
iters = 10 if args.test_performance else 2  # warmup
times = []
warmup = 3
prof_flag = 0
generate_kwargs = dict(max_new_tokens=args.max_new_tokens, do_sample=False, use_cache=args.use_cache)
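# Time each generate() call end to end; the first `warmup` iterations are dropped
# later by print_perf_stats so only steady-state runs are averaged.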
for i in range(iters):
    if i >= warmup:
        prof_flag = 1
    torch.cuda.synchronize()
    start = time.time()
    outputs = model.generate(**input_tokens, **generate_kwargs)
    torch.cuda.synchronize()
    end = time.time()
    times.append(end - start)

# Report the second run's latency; the first run includes cold-start overhead.
print(f"generation time is {times[1]} sec")
print("outputs shape ", outputs.shape)
outputs = tokenizer.batch_decode(outputs)
if args.local_rank == 0:
    # for i, o in zip(inputs, outputs):
    #     print(f"\nin={i}\nout={o}\n{'-'*60}")
    if args.test_performance:
        # Convert each per-call generate() time into an approximate per-token latency.
        print_perf_stats(map(lambda t: t / args.max_new_tokens, times), model.config)
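# Example invocation (the script filename is assumed here; adjust to the saved file):
#   python benchmark_llama_inference.py --name /data/scratch/llama-7b-hf \
#       --kernel_type triton --batch_size 1 --max_new_tokens 128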