@Cyrilvallez
Created June 6, 2024 16:00
Transformers logits benchmark
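Two scripts are included. The first plots, for each decoding strategy and model, the ratio of peak GPU memory used during generation before and after the logits patch, reading results from the 'legacy_logits_benchmarks/' and 'patched_logits_benchmarks/' folders. The second runs the actual benchmark: for each model and decoding strategy it generates 10 new tokens from random inputs of increasing length and records the extra peak memory allocated by generation, writing the results to JSON.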
import numpy as np
import matplotlib.pyplot as plt
import json


def load_json(filename: str) -> dict:
    """Load a results file, converting the inner keys (input sizes) back to int."""
    with open(filename, 'r') as fp:
        data = json.load(fp)
    out = {}
    for k1, v1 in data.items():
        out[k1] = {int(k2): v2 for k2, v2 in v1.items()}
    return out


mistral = 'Mistral-7B-v0.1'
llama2 = 'Llama-2-7b-hf'
llama3 = 'Meta-Llama-3-8B'

leg_folder = 'legacy_logits_benchmarks/'
pat_folder = 'patched_logits_benchmarks/'
res_folder = 'figures_float_casting/'

methods = ['contrastive search', 'greedy', 'sample', 'group beam search', 'beam sample', 'beam search']
models = [llama3, llama2, mistral]

for strategy in methods:
    fig, axes = plt.subplots(1, 3, sharex=True, sharey=False, figsize=(6.4 * 3, 5.2))
    for i, model in enumerate(models):
        results_before = load_json(leg_folder + 'legacy_' + model + '.json')
        results_after = load_json(pat_folder + 'patched_' + model + '.json')

        # Sort both series by input size
        memory_before = np.array(list(results_before[strategy].values()))
        inputs_before = np.array(list(results_before[strategy].keys()))
        sorting = np.argsort(inputs_before)
        memory_before = memory_before[sorting]
        inputs_before = inputs_before[sorting]

        memory_after = np.array(list(results_after[strategy].values()))
        inputs_after = np.array(list(results_after[strategy].keys()))
        sorting = np.argsort(inputs_after)
        memory_after = memory_after[sorting]
        inputs_after = inputs_after[sorting]

        # axes[i].plot(inputs_before, memory_before, label=f'Before')
        # axes[i].plot(inputs_after, memory_after, label=f'After')

        # One run may have hit OOM earlier than the other: only compare the common input sizes
        min_ = min(len(memory_after), len(memory_before))
        inputs = inputs_before[:min_]
        memory_before = memory_before[:min_]
        memory_after = memory_after[:min_]

        axes[i].plot(inputs, memory_before / memory_after)
        axes[i].set_xlabel('Input size')
        # axes[i].set_ylabel('Peak Memory [GiB]')
        axes[i].set_ylabel('Memory before / Memory after')
        axes[i].set_title(model)
        axes[i].grid()
        # axes[i].legend()  # only useful when plotting the labeled 'Before'/'After' curves above

    plt.suptitle(strategy)
    plt.subplots_adjust(top=0.85)
    fig.savefig(res_folder + strategy + '.pdf', bbox_inches='tight')
    plt.show()
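For reference, each results file read by load_json maps a decoding strategy to a dictionary of input size -> extra peak memory in GiB, exactly as written by the benchmark script below. A minimal sketch of the expected layout (the sizes and values here are illustrative, not real measurements):

example_results = {
    "greedy": {
        "50": 0.4,    # keys are strings in the JSON file; load_json converts them back to int
        "631": 0.9,
    },
    "beam search": {
        "50": 0.6,
    },
}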
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np
import os
import json
from tqdm import tqdm
import gc
import warnings

warnings.filterwarnings("ignore")

# Generation kwargs defining each decoding strategy
matcher = {
    'contrastive search': {
        'do_sample': False,
        'top_k': 2,
        'penalty_alpha': 0.5,
    },
    'greedy': {
        'do_sample': False,
    },
    'sample': {
        'do_sample': True,
        'top_k': 50,
    },
    'group beam search': {
        'num_beams': 2,
        'num_beam_groups': 2,
        'do_sample': False,
        'diversity_penalty': 0.5,
    },
    'beam sample': {
        'num_beams': 2,
        'do_sample': True,
    },
    'beam search': {
        'num_beams': 2,
        'do_sample': False,
    },
}

model_names = ['meta-llama/Meta-Llama-3-8B', 'mistralai/Mistral-7B-v0.1', 'meta-llama/Llama-2-7b-hf']
dtypes = [torch.bfloat16, torch.bfloat16, torch.float16]
max_inputs = [8192, 8192, 4096]

res_folder = 'logits_benchmarks/'
if not os.path.isdir(res_folder):
    os.mkdir(res_folder)

for model_name, dtype, max_input in tqdm(zip(model_names, dtypes, max_inputs), total=len(model_names), desc='Models'):
    model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='flash_attention_2',
                                                 torch_dtype=dtype, low_cpu_mem_usage=True).cuda(1)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    eos_token_id = model.generation_config.eos_token_id
    # eos_token_id may be a list: keep the first entry to use as pad_token_id
    if hasattr(eos_token_id, '__len__') and hasattr(eos_token_id, '__getitem__'):
        eos_token_id = eos_token_id[0]

    # Results
    results_memory = {}
    filename = res_folder + 'patched_' + model_name.split('/', 1)[1]

    for strategy in tqdm(list(matcher.keys()), desc='strategy', leave=False):
        results_memory[strategy] = {}
        sizes = [int(x) for x in np.linspace(50, max_input - 50, 15)]
        for size in sizes:
            torch.manual_seed(1)
            # Random tokens as input
            input_ids = torch.randint(1, 20000, (1, size), device=1)
            try:
                torch.cuda.reset_peak_memory_stats(1)
                actual_peak = torch.cuda.max_memory_allocated(1) / 1024**3
                out = model.generate(input_ids, max_new_tokens=10, min_new_tokens=10, pad_token_id=eos_token_id,
                                     **matcher[strategy])
                # Peak memory used by generation itself, on top of what was already allocated (model weights)
                memory_used = (torch.cuda.max_memory_allocated(1) / 1024**3) - actual_peak
                results_memory[strategy][size] = memory_used
            except RuntimeError as e:
                # Stop increasing the input size once we run out of memory, but re-raise anything else
                if isinstance(e, torch.cuda.OutOfMemoryError):
                    break
                else:
                    raise e

            # Rewrite results at each iteration (so that we keep them if OOM)
            with open(filename + '.json', 'w') as fp:
                json.dump(results_memory, fp, indent='\t')

    del model
    gc.collect()
    torch.cuda.empty_cache()
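Note that this script writes its output as 'patched_*.json' inside 'logits_benchmarks/', while the plotting script above reads from 'legacy_logits_benchmarks/' and 'patched_logits_benchmarks/'. Presumably the benchmark is run once on the legacy transformers version and once on the patched one, and the outputs are moved (and, for the legacy run, renamed) into those folders before plotting. A hypothetical sketch of staging one patched run for the plotting script, assuming the folder layout above:

import os
import shutil

src, dst = 'logits_benchmarks/', 'patched_logits_benchmarks/'
os.makedirs(dst, exist_ok=True)
for fname in os.listdir(src):
    # Copy only the benchmark outputs, keeping the 'patched_' prefix the plotting script expects
    if fname.startswith('patched_') and fname.endswith('.json'):
        shutil.copy(os.path.join(src, fname), os.path.join(dst, fname))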