Benchmark the Mistral 7B model
import argparse
import inspect
from pathlib import Path
from typing import List, Optional

import torch

from mistral.cache import RotatingBufferCache
from mistral.model import Transformer
from mistral.tokenizer import Tokenizer
def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        type=str,
        help="Model path",
        required=True,
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=512,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--num-batches",
        type=int,
        default=1,
        help="Number of times to run the experiment",
    )
    return parser
@torch.inference_mode()
def generate(
    prompts: List[str],
    model: Transformer,
    tokenizer: Tokenizer,
    *,
    max_tokens: int,
    chunk_size: Optional[int] = None,
    temperature: float = 0.7,
):
    model = model.eval()
    B, V = len(prompts), model.args.vocab_size
    device = torch.device("cuda:0")

    # Tokenize
    encoded_prompts = [tokenizer.encode(prompt, bos=True) for prompt in prompts]
    seqlens = [len(x) for x in encoded_prompts]

    # Cache
    cache_window = min(model.args.sliding_window, max(seqlens) + max_tokens)
    cache = RotatingBufferCache(
        model.args.n_layers,
        model.args.max_batch_size,
        cache_window,
        model.args.n_kv_heads,
        model.args.head_dim,
    )
    cache.to(device=model.device, dtype=model.dtype)
    cache.reset()

    last_token_prelogits = None

    # One chunk if size not specified
    max_prompt_len = max(seqlens)
    if chunk_size is None:
        chunk_size = max_prompt_len

    # Time prefill + decode with CUDA events
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_event.record()

    # Encode prompt by chunks
    for s in range(0, max_prompt_len, chunk_size):
        prompt_chunks = [p[s : s + chunk_size] for p in encoded_prompts]
        assert all(len(p) > 0 for p in prompt_chunks)
        prelogits = model.forward(
            torch.tensor(sum(prompt_chunks, []), device=model.device, dtype=torch.long),
            cache,
            seqlens=[len(p) for p in prompt_chunks],
        )
        # Keep only the pre-logits of the last token of each sequence in the chunk
        last_token_prelogits = prelogits.index_select(
            0,
            torch.tensor([len(p) for p in prompt_chunks], device=prelogits.device).cumsum(dim=0) - 1,
        )
        assert last_token_prelogits.shape == (B, V)

    # Decode greedily, one token at a time (temperature is unused here)
    generated_tokens = []
    for i_token in range(max_tokens):
        next_token = torch.argmax(last_token_prelogits, dim=-1)
        generated_tokens.append(next_token[:, None])
        last_token_prelogits = model.forward(next_token, cache, seqlens=[1] * len(prompts))
        assert last_token_prelogits.shape == (B, V)

    end_event.record()
    torch.cuda.synchronize()

    # elapsed_time returns milliseconds
    latency_s = start_event.elapsed_time(end_event) * 1e-3
    max_memory = torch.cuda.max_memory_allocated(device)

    generated_words = []
    if generated_tokens:
        generated_tokens = torch.cat(generated_tokens, 1)
        for i, x in enumerate(encoded_prompts):
            generated_words.append(tokenizer.decode(x + generated_tokens[i].tolist()))

    return generated_words, (latency_s, max_memory)
def get_text():
    # This generates ~11K prompt tokens
    # Modify this method to try out different scenarios
    text = ["""Summarize the following news article in detail:\n""" * 1000]
    return text
def benchmark(model_path: str, max_tokens: int = 35, num_batches: int = 1):
    tokenizer = Tokenizer(str(Path(model_path) / "tokenizer.model"))
    text = get_text()
    transformer = Transformer.from_folder(Path(model_path), max_batch_size=len(text))

    # Check that we are effectively using memory-efficient attention from xformers
    assert "memory_efficient_attention" in inspect.getsource(
        transformer.layers[0].attention.forward
    ), "You did not load the optimized model"
    assert transformer.dtype == torch.float16

    # Warmup
    _ = generate(
        ["hi"],
        transformer,
        tokenizer,
        max_tokens=10,
    )

    total_latency = 0
    total_max_memory = 0

    # Retrieve generation stats
    for _ in range(num_batches):
        _, stats = generate(
            text,
            transformer,
            tokenizer,
            max_tokens=max_tokens,
        )
        latency_s, max_memory = stats
        total_latency += latency_s
        total_max_memory += max_memory

    mean_latency = total_latency / num_batches
    print(f"Mean Latency: {mean_latency} s")
    print(f"{max_tokens / mean_latency} tokens / s")
    print(f"Mean max allocated memory: {total_max_memory / num_batches}")
if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    benchmark(args.model_path, args.max_new_tokens, args.num_batches)
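
Example invocation, assuming the gist is saved as benchmark_mistral.py next to a mistral-src checkout and that /path/to/mistral-7B-v0.1 (a placeholder path) contains the downloaded weights plus tokenizer.model. The flags are the ones defined in get_parser above:

    python benchmark_mistral.py --model-path /path/to/mistral-7B-v0.1 --max-new-tokens 512 --num-batches 3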