Transformers logits benchmark
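Two scripts are included: a plotting script that compares the before/after peak-memory results per model and decoding strategy, and the benchmark script that produced those results by measuring the peak GPU memory used around generate() for increasing input sizes.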
# Plotting script: loads the per-model benchmark results (peak memory per decoding
# strategy and input size) and plots the ratio of peak memory before vs. after the patch.
import numpy as np
import matplotlib.pyplot as plt
import json


def load_json(filename: str) -> dict:
    """Load a benchmark result file, converting the input-size keys back to int."""
    with open(filename, 'r') as fp:
        data = json.load(fp)
    out = {}
    for k1, v1 in data.items():
        out[k1] = {int(k2): v2 for k2, v2 in v1.items()}
    return out


mistral = 'Mistral-7B-v0.1'
llama2 = 'Llama-2-7b-hf'
llama3 = 'Meta-Llama-3-8B'

leg_folder = 'legacy_logits_benchmarks/'
pat_folder = 'patched_logits_benchmarks/'
res_folder = 'figures_float_casting/'

methods = ['contrastive search', 'greedy', 'sample', 'group beam search', 'beam sample', 'beam search']
models = [llama3, llama2, mistral]

for strategy in methods:
    fig, axes = plt.subplots(1, 3, sharex=True, sharey=False, figsize=(6.4*3, 5.2))
    for i, model in enumerate(models):
        results_before = load_json(leg_folder + 'legacy_' + model + '.json')
        results_after = load_json(pat_folder + 'patched_' + model + '.json')

        # Sort the measurements by input size
        memory_before = np.array(list(results_before[strategy].values()))
        inputs_before = np.array(list(results_before[strategy].keys()))
        sorting = np.argsort(inputs_before)
        memory_before = memory_before[sorting]
        inputs_before = inputs_before[sorting]

        memory_after = np.array(list(results_after[strategy].values()))
        inputs_after = np.array(list(results_after[strategy].keys()))
        sorting = np.argsort(inputs_after)
        memory_after = memory_after[sorting]
        inputs_after = inputs_after[sorting]

        # axes[i].plot(inputs_before, memory_before, label=f'Before')
        # axes[i].plot(inputs_after, memory_after, label=f'After')

        # One run may have stopped earlier than the other (OOM): keep the common prefix
        min_ = min(len(memory_after), len(memory_before))
        inputs = inputs_before[:min_]
        memory_before = memory_before[:min_]
        memory_after = memory_after[:min_]

        axes[i].plot(inputs, memory_before / memory_after)
        axes[i].set_xlabel('Input size')
        # axes[i].set_ylabel('Peak Memory [GiB]')
        axes[i].set_ylabel('Memory before / Memory after')
        axes[i].set_title(model)
        axes[i].grid()
        axes[i].legend()

    plt.suptitle(strategy)
    plt.subplots_adjust(top=0.85)
    fig.savefig(res_folder + strategy + '.pdf', bbox_inches='tight')
    plt.show()
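For reference, load_json above expects per-model JSON files shaped as strategy -> input size -> peak memory in GiB, matching what the benchmark script below writes. The snippet below is a minimal sketch of that layout; the file name and memory values are purely illustrative.

import json

# Illustrative only: the result-file structure read by load_json() above.
# Inner keys are input sizes (strings in JSON, converted back to int when loaded);
# values are peak memory in GiB measured around generate().
example_results = {
    'greedy': {'50': 0.5, '631': 0.7},
    'beam search': {'50': 0.9, '631': 1.2},
}

# Example path following the plotting script's folder/prefix convention
# (e.g. patched_logits_benchmarks/patched_Meta-Llama-3-8B.json).
with open('patched_Meta-Llama-3-8B.json', 'w') as fp:
    json.dump(example_results, fp, indent='\t')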
# Benchmark script: for each model and decoding strategy, generates 10 tokens from random
# inputs of increasing length and records the peak GPU memory used by generate(),
# writing the results to one JSON file per model.
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np
import os
import json
from tqdm import tqdm
import gc
import warnings

warnings.filterwarnings("ignore")

# Generation kwargs selecting each decoding strategy
matcher = {
    'contrastive search': {
        'do_sample': False,
        'top_k': 2,
        'penalty_alpha': 0.5,
    },
    'greedy': {
        'do_sample': False,
    },
    'sample': {
        'do_sample': True,
        'top_k': 50,
    },
    'group beam search': {
        'num_beams': 2,
        'num_beam_groups': 2,
        'do_sample': False,
        'diversity_penalty': 0.5,
    },
    'beam sample': {
        'num_beams': 2,
        'do_sample': True,
    },
    'beam search': {
        'num_beams': 2,
        'do_sample': False,
    },
}

model_names = ['meta-llama/Meta-Llama-3-8B', 'mistralai/Mistral-7B-v0.1', 'meta-llama/Llama-2-7b-hf']
dtypes = [torch.bfloat16, torch.bfloat16, torch.float16]
max_inputs = [8192, 8192, 4096]

res_folder = 'logits_benchmarks/'
if not os.path.isdir(res_folder):
    os.mkdir(res_folder)

for model_name, dtype, max_input in tqdm(zip(model_names, dtypes, max_inputs), total=len(model_names), desc='Models'):
    model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='flash_attention_2',
                                                 torch_dtype=dtype, low_cpu_mem_usage=True).cuda(1)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

    eos_token_id = model.generation_config.eos_token_id
    if hasattr(eos_token_id, '__len__') and hasattr(eos_token_id, '__getitem__'):
        eos_token_id = eos_token_id[0]

    # Results
    results_memory = {}
    filename = res_folder + 'patched_' + model_name.split('/', 1)[1]

    for strategy in tqdm(list(matcher.keys()), desc='strategy', leave=False):
        results_memory[strategy] = {}
        sizes = [int(x) for x in np.linspace(50, max_input - 50, 15)]
        for size in sizes:
            torch.manual_seed(1)
            # Random tokens as input
            input = torch.randint(1, 20000, (1, size), device=1)
            try:
                torch.cuda.reset_peak_memory_stats(1)
                actual_peak = torch.cuda.max_memory_allocated(1) / 1024**3
                out = model.generate(input, max_new_tokens=10, min_new_tokens=10, pad_token_id=eos_token_id,
                                     **matcher[strategy])
                memory_used = (torch.cuda.max_memory_allocated(1) / 1024**3) - actual_peak
                results_memory[strategy][size] = memory_used
            except RuntimeError as e:
                if isinstance(e, torch.cuda.OutOfMemoryError):
                    break
                else:
                    raise e

            # Rewrite results at each iteration (so that we keep them if OOM)
            with open(filename + '.json', 'w') as fp:
                json.dump(results_memory, fp, indent='\t')

    del model
    gc.collect()
    torch.cuda.empty_cache()
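Note that this script writes its results as patched_<model>.json under logits_benchmarks/, while the plotting script reads legacy_*.json from legacy_logits_benchmarks/ and patched_*.json from patched_logits_benchmarks/; presumably the benchmark is run once with the legacy transformers code and once with the patched code, and the outputs are renamed or moved accordingly. As a hypothetical quick check (not part of the original gist), the before/after ratio at the largest input size completed by both runs could be printed like this:

import json

# Hypothetical helper: print the before/after peak-memory ratio per strategy at the
# largest input size present in both result files.
def ratio_summary(legacy_file: str, patched_file: str) -> None:
    with open(legacy_file) as fp:
        before = json.load(fp)
    with open(patched_file) as fp:
        after = json.load(fp)
    for strategy in before:
        common = sorted(set(before[strategy]) & set(after.get(strategy, {})), key=int)
        if not common:
            continue
        size = common[-1]
        ratio = before[strategy][size] / after[strategy][size]
        print(f'{strategy}: before/after peak-memory ratio {ratio:.2f} at input size {size}')

# Paths follow the folder/prefix conventions used by the plotting script above
ratio_summary('legacy_logits_benchmarks/legacy_Meta-Llama-3-8B.json',
              'patched_logits_benchmarks/patched_Meta-Llama-3-8B.json')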