# BENCHMARK 2
# SECOND BENCHMARK OF ALL GENERATION METHODS
from transformers import AutoTokenizer, AutoModelForCausalLM, PhrasalConstraint
import torch
import time
from tqdm import tqdm
import numpy as np
import json
import gc
# model_name = 'meta-llama/Llama-2-7b-hf'
model_name = 'mistralai/Mistral-7B-v0.1'
dtype = torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='flash_attention_2',
                                             torch_dtype=dtype, low_cpu_mem_usage=True).cuda(1)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
# Small draft model used for assisted generation
assistant_name = 'gpt2-medium'
assistant = AutoModelForCausalLM.from_pretrained(assistant_name, torch_dtype=torch.float32,
                                                 low_cpu_mem_usage=True).cuda(1)
# Randomly generated text
LARGE_TEXT = """Title: Monkeys: Nature's Pranksters, Social Geniuses, and Ecological Wonders
Introduction
Monkeys, the charismatic and diverse members of the primate order, have long held a special place in the annals of our fascination with the animal kingdom. With their playful antics, astonishing intelligence, and complex social structures, they serve as a source of both joy and profound scientific inquiry. In this comprehensive exploration, we embark on a deep dive into the world of monkeys, spanning their evolutionary history, classifications, ecological roles, social dynamics, communication methods, and the pressing need for conservation. These captivating creatures offer insights into the intricacies of the natural world, our own evolutionary heritage, and the urgent importance of preserving biodiversity.
I. Evolutionary Origins
To understand the world of monkeys, we must embark on a journey through their evolutionary past, a tapestry that stretches back millions of years. Monkeys are part of the grand order of Primates, and their lineage is interwoven with the broader history of these remarkable mammals.
A. Primate Origins
The story of primates, including monkeys, begins around 60 million years ago. At that time, the world was a vastly different place, dominated by the reign of dinosaurs. It was during this period of Earth's history that the first primates, known as prosimians, emerged. These small, tree-dwelling mammals exhibited several characteristics that would become hallmarks of all primates: grasping hands and feet, forward-facing eyes for stereoscopic vision, and an enlarged brain relative to body size. These adaptations suited them for life in the trees, where they foraged for insects and fruits.
B. The Emergence of Monkeys
Around 35 million years ago, a significant split occurred within the primate family tree, leading to the emergence of two major groups: New World monkeys (Platyrrhini) and Old World monkeys (Catarrhini). This evolutionary divergence set in motion a cascade of adaptations that would result in the striking diversity of monkeys we see today.
The division between New World and Old World monkeys was not merely a matter of geographical separation but also marked significant differences in physical traits and behaviors. New World monkeys, found in Central and South America, are characterized by their prehensile tails and a variety of adaptations that allow them to thrive in the lush forests of the Americas. Old World monkeys, on the other hand, are residents of Africa, Asia, and parts of Gibraltar, and they have developed their own unique set of features to suit the diverse environments they inhabit.
II. Classification and Diversity"""
inputs = tokenizer.encode(LARGE_TEXT*5, return_tensors='pt').to(device=1)
force_word = "monkey"
constraints = [PhrasalConstraint(tokenizer.encode(force_word, add_special_tokens=False))]
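# Quick sanity check (illustrative, not part of the benchmark): a PhrasalConstraint
# forces the given token sequence to appear verbatim in every finished beam, so the
# ids below must show up in the constrained outputs.
# print(tokenizer.encode(force_word, add_special_tokens=False))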
matcher = {
    'contrastive search': {
        'do_sample': False,
        'top_k': 10,
        'penalty_alpha': 0.5,
    },
    'greedy': {
        'do_sample': False,
    },
    'sample': {
        'do_sample': True,
        'top_k': 50,
    },
    'beam sample': {
        'num_beams': 2,
        'do_sample': True,
    },
    'beam search': {
        'num_beams': 2,
        'do_sample': False,
    },
    'group beam search': {
        'num_beams': 2,
        'num_beam_groups': 2,
        'do_sample': False,
        'diversity_penalty': 0.5,
    },
    'constrained beam search': {
        'num_beams': 2,
        'constraints': constraints,
        'do_sample': False,
    },
    'assisted': {
        'do_sample': False,
        'assistant_model': assistant,
    },
}
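# Each matcher entry maps a decoding strategy to the generate() kwargs that select
# it, e.g. (illustrative, not run by the benchmark):
# model.generate(inputs[:, :300], max_new_tokens=20, **matcher['beam search'])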
N_repeat = 3
torch.manual_seed(1)
def main(model, inputs, matcher, N_repeat):
    eos_token_id = model.generation_config.eos_token_id
    # generation_config.eos_token_id may be a list of ids; keep only the first one
    if hasattr(eos_token_id, '__len__') and hasattr(eos_token_id, '__getitem__'):
        eos_token_id = eos_token_id[0]
    N_toks = [100, 500, 1000, 2000, 3000, 4000, 5000, 6000, 7000]
    results_time = {k: {} for k in matcher.keys()}
    results_memory = {k: {} for k in matcher.keys()}
    filename = 'patched_version_benchmark_all_gen_methods'
    for strategy, input_params in tqdm(matcher.items(), desc='strategy', total=len(matcher.keys())):
        for N_tok in tqdm(N_toks, desc='New tokens', leave=False):
            # Flag to break if we OOM
            OOMed = False
            times = []
            memory = []
            for i in tqdm(range(N_repeat), leave=False):
                # To catch OOMs
                try:
                    torch.cuda.reset_peak_memory_stats(1)
                    actual_peak = torch.cuda.max_memory_allocated(1) / 1024**3
                    t0 = time.time()
                    out = model.generate(inputs[:, :300], max_new_tokens=N_tok, min_new_tokens=N_tok,
                                         pad_token_id=eos_token_id, **input_params)
                    dt = time.time() - t0
                    memory_used = (torch.cuda.max_memory_allocated(1) / 1024**3) - actual_peak
                    # Verify that we did generate the correct number of tokens
                    assert out.shape[-1] == 300 + N_tok
                    times.append(dt)
                    memory.append(memory_used)
                # If we actually OOMed, continue with the next strategy, as increasing the tokens will still OOM
                except RuntimeError as e:
                    if isinstance(e, torch.cuda.OutOfMemoryError):
                        OOMed = True
                        break
                    else:
                        raise e
            # If we OOMed, break once again to change decoding strategy
            if OOMed:
                break
            results_time[strategy][N_tok] = np.mean(times)
            results_memory[strategy][N_tok] = np.mean(memory)
            # Rewrite results at each iteration (so that we keep them if we OOM later)
            with open(filename + '_memory.json', 'w') as fp:
                json.dump(results_memory, fp, indent='\t')
            with open(filename + '_time.json', 'w') as fp:
                json.dump(results_time, fp, indent='\t')
    del model
    gc.collect()
if __name__ == '__main__':
    main(model, inputs, matcher, N_repeat)
# FIGURES FOR BENCHMARK 1
import json
import matplotlib.pyplot as plt
import numpy as np
model_names = ['Mistral-7B-v0.1', 'Llama-2-7b-hf', 'Meta-Llama-3-8B']
fix_batch = '_batch_size_1_input_size_300'
fix_length = '_input_size_300_new_tokens_2000'
folder = 'benchmark/'
fig_folder = 'results/'
def load_json(filename: str) -> tuple[np.ndarray, np.ndarray]:
    with open(filename, 'r') as fp:
        data = json.load(fp)
    data = {int(k): v for k, v in data.items()}
    # Return keys and values as arrays, sorted by key
    a, b = np.array(list(data.keys())), np.array(list(data.values()))
    sorting = np.argsort(a)
    return a[sorting], b[sorting]
# Fixed batch size figures
for model in model_names:
    # memory figure
    plt.figure()
    plt.plot(*load_json(folder + 'legacy_' + model + fix_batch + '_memory.json'), 'b-', label='Before')
    plt.plot(*load_json(folder + 'patched_' + model + fix_batch + '_memory.json'), 'r-', label='After')
    plt.xlabel('New tokens generated')
    plt.ylabel('Peak memory usage [GiB]')
    plt.grid()
    plt.legend()
    plt.title(model + '\nInput size 300, Batch size 1')
    plt.savefig(fig_folder + model + '_memory_fix_batch.pdf', bbox_inches='tight')
    plt.show()
    # time figure
    plt.figure()
    plt.plot(*load_json(folder + 'legacy_' + model + fix_batch + '_time.json'), 'b-', label='Before')
    plt.plot(*load_json(folder + 'patched_' + model + fix_batch + '_time.json'), 'r-', label='After')
    plt.xlabel('New tokens generated')
    plt.ylabel('Generation time [s]')
    plt.grid()
    plt.legend()
    plt.title(model + '\nInput size 300, Batch size 1')
    plt.savefig(fig_folder + model + '_time_fix_batch.pdf', bbox_inches='tight')
    plt.show()
# Fixed number of new tokens figures
for model in model_names:
    # memory figure
    plt.figure()
    plt.plot(*load_json(folder + 'legacy_' + model + fix_length + '_memory.json'), 'b-', label='Before')
    plt.plot(*load_json(folder + 'patched_' + model + fix_length + '_memory.json'), 'r-', label='After')
    plt.xlabel('Batch size')
    plt.ylabel('Peak memory usage [GiB]')
    plt.grid()
    plt.legend()
    plt.title(model + '\nInput size 300, New tokens 2000')
    if model == model_names[1]:
        plt.text(4.1, 9, 'OOM after this point', color='b')
    plt.savefig(fig_folder + model + '_memory_fix_length.pdf', bbox_inches='tight')
    plt.show()
    # time figure
    plt.figure()
    plt.plot(*load_json(folder + 'legacy_' + model + fix_length + '_time.json'), 'b-', label='Before')
    plt.plot(*load_json(folder + 'patched_' + model + fix_length + '_time.json'), 'r-', label='After')
    plt.xlabel('Batch size')
    plt.ylabel('Generation time [s]')
    plt.grid()
    plt.legend()
    plt.title(model + '\nInput size 300, New tokens 2000')
    if model == model_names[1]:
        plt.text(4.1, 57, 'OOM after this point', color='b')
    plt.savefig(fig_folder + model + '_time_fix_length.pdf', bbox_inches='tight')
    plt.show()
# FIGURES FOR BENCHMARK 2
import json
import matplotlib.pyplot as plt
import numpy as np
model_name = 'Mistral-7B-v0.1'
folder = 'res/'
filename = '_version_benchmark_all_gen_methods_'
assistant = 'gpt2-medium'
matcher = {
    'contrastive search': {
        'top_k': 10,
        'penalty_alpha': 0.5,
    },
    'greedy': {},
    'sample': {
        'top_k': 50,
    },
    'beam sample': {
        'num_beams': 2,
    },
    'beam search': {
        'num_beams': 2,
    },
    'group beam search': {
        'num_beams': 2,
        'num_beam_groups': 2,
        'diversity_penalty': 0.5,
    },
    'constrained beam search': {
        'num_beams': 2,
    },
    'assisted': {
        'assistant_model': assistant,
    },
}
def load_json(filename: str) -> dict:
    with open(filename, 'r') as fp:
        data = json.load(fp)
    for k, v in data.items():
        int_keys = [int(x) for x in v.keys()]
        a, b = np.array(int_keys), np.array(list(v.values()))
        sorting = np.argsort(a)
        data[k] = (a[sorting], b[sorting])
    return data
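# The benchmark 2 result files are nested as {strategy: {new_tokens: value}}, with
# all keys stored as strings in JSON, e.g. (illustrative):
# {"greedy": {"100": 3.2, "500": 16.1, ...}, ...}
# so load_json returns {strategy: (sorted_new_tokens, sorted_values)}.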
patched_memory, patched_time = load_json(folder + 'patched' + filename + 'memory.json'), load_json(folder + 'patched' + filename + 'time.json')
legacy_memory, legacy_time = load_json(folder + 'legacy' + filename + 'memory.json'), load_json(folder + 'legacy' + filename + 'time.json')
# Figures per decoding strategy (input size 300, batch size 1)
save = False
fig_folder = 'results/all_gen_methods/'
for strategy in patched_memory.keys():
    fig, axes = plt.subplots(1, 2, figsize=(6.4*2.1, 5.5), sharex=True)
    axes[0].plot(*legacy_memory[strategy], 'b-', label='Before')
    axes[0].plot(*patched_memory[strategy], 'r-', label='After')
    axes[0].set_xlabel('New tokens generated')
    axes[0].set_ylabel('Peak memory usage [GiB]')
    axes[0].grid()
    axes[0].legend()
    if strategy == 'contrastive search':
        axes[0].text(1020, 3.75, 'OOM after this point', color='b')
    axes[1].plot(*legacy_time[strategy], 'b-', label='Before')
    axes[1].plot(*patched_time[strategy], 'r-', label='After')
    axes[1].set_xlabel('New tokens generated')
    axes[1].set_ylabel('Generation time [s]')
    axes[1].grid()
    axes[1].legend()
    if strategy == 'contrastive search':
        axes[1].text(1020, 35, 'OOM after this point', color='b')
    title = f'{strategy}\nInput size 300, Batch size 1'
    if len(matcher[strategy]) > 0:
        gen_args = ', '.join(f"'{k}': {v}" for k, v in matcher[strategy].items())
        title += f'\n{gen_args}'
    plt.suptitle(title)
    if save:
        plt.savefig(fig_folder + '_'.join(strategy.split(' ')) + '.pdf', bbox_inches='tight')
    plt.show()
# FIGURES FOR BENCHMARK 3
import numpy as np
import matplotlib.pyplot as plt
import json
def load_json(filename: str) -> dict:
    with open(filename, 'r') as fp:
        data = json.load(fp)
    # Convert the nested string keys back to ints
    out = {}
    for k1, v1 in data.items():
        out[k1] = {}
        for k2, v2 in v1.items():
            out[k1][int(k2)] = {int(k3): v3 for k3, v3 in v2.items()}
    return out
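# The ratio files are nested as {strategy: {input_size: {new_tokens: memory}}},
# e.g. (illustrative): {"sample": {"300": {"100": 0.04, ...}, "1000": {...}}, ...}
# with all keys stored as strings in JSON, hence the int conversions above.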
mistral = 'Mistral-7B-v0.1'
llama = 'Llama-2-7b-hf'
results_before_mistral = load_json('legacy_' + mistral + '_ratios_memory.json')
results_after_mistral = load_json('patched_' + mistral + '_ratios_memory.json')
results_before_llama2 = load_json('legacy_' + llama + '_ratios_memory.json')
results_after_llama2 = load_json('patched_' + llama + '_ratios_memory.json')
results_before_mistral_top_k_4 = load_json('legacy_' + mistral + '_ratios' + '_contrastive_k_4_memory.json')
results_after_mistral_top_k_4 = load_json('patched_' + mistral + '_ratios' + '_contrastive_k_4_memory.json')
results_before_mistral_top_k_8 = load_json('legacy_' + mistral + '_ratios' + '_contrastive_k_8_memory.json')
results_after_mistral_top_k_8 = load_json('patched_' + mistral + '_ratios' + '_contrastive_k_8_memory.json')
# First figures
import numpy as np
import matplotlib.pyplot as plt
methods = ['sample', 'beam sample', 'contrastive search']
input_sizes = [300, 1000]
N_toks = [2, 100, 250, 500, 750, 1000, 1250, 1500, 2000, 3000, 4000]
for strategy in methods:
    fig, axes = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(6.4*2, 5.2))
    for INPUT in input_sizes:
        memory_before_mistral = np.array(list(results_before_mistral[strategy][INPUT].values()))
        toks_before_mistral = np.array(list(results_before_mistral[strategy][INPUT].keys()))
        sorting = np.argsort(toks_before_mistral)
        memory_before_mistral = memory_before_mistral[sorting]
        toks_before_mistral = toks_before_mistral[sorting]
        memory_after_mistral = np.array(list(results_after_mistral[strategy][INPUT].values()))
        toks_after_mistral = np.array(list(results_after_mistral[strategy][INPUT].keys()))
        sorting = np.argsort(toks_after_mistral)
        memory_after_mistral = memory_after_mistral[sorting]
        memory_before_llama2 = np.array(list(results_before_llama2[strategy][INPUT].values()))
        toks_before_llama2 = np.array(list(results_before_llama2[strategy][INPUT].keys()))
        sorting = np.argsort(toks_before_llama2)
        memory_before_llama2 = memory_before_llama2[sorting]
        toks_before_llama2 = toks_before_llama2[sorting]
        memory_after_llama2 = np.array(list(results_after_llama2[strategy][INPUT].values()))
        toks_after_llama2 = np.array(list(results_after_llama2[strategy][INPUT].keys()))
        sorting = np.argsort(toks_after_llama2)
        memory_after_llama2 = memory_after_llama2[sorting]
        # Truncate to a common length in case the legacy version OOMed earlier
        min_ = min(len(memory_after_llama2), len(memory_before_llama2))
        toks_before_llama2 = toks_before_llama2[:min_]
        memory_before_llama2 = memory_before_llama2[:min_]
        memory_after_llama2 = memory_after_llama2[:min_]
        axes[0].plot(toks_before_llama2, memory_before_llama2 / memory_after_llama2, label=f'Input size = {INPUT}')
        axes[1].plot(toks_before_mistral, memory_before_mistral / memory_after_mistral, label=f'Input size = {INPUT}')
    if strategy == 'beam sample':
        axes[0].text(2020, 1.485, 'OOM for legacy version\nafter those points')
    axes[0].set_xlabel('New tokens generated')
    axes[0].set_ylabel('Memory before / Memory after')
    axes[1].set_xlabel('New tokens generated')
    axes[1].set_ylabel('Memory before / Memory after')
    axes[0].set_title('Usual cache size (most models, e.g. Llama2)\nHere: Llama2')
    axes[1].set_title('Small efficient cache size (e.g. Mistral)\nHere: Mistral')
    axes[0].grid()
    axes[1].grid()
    axes[0].legend()
    axes[1].legend()
    # plt.suptitle(strategy)
    fig.savefig(strategy + '.pdf', bbox_inches='tight')
    plt.show()
# Second figure
import numpy as np
import matplotlib.pyplot as plt
methods = ['contrastive search']
input_sizes = [300, 1000]
N_toks = [2, 100, 250, 500, 750, 1000, 1250, 1500, 2000, 3000, 4000]
for strategy in methods:
    fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(6.4*3, 5.2))
    for INPUT in input_sizes:
        memory_before_k2 = np.array(list(results_before_mistral[strategy][INPUT].values()))
        toks_before_k2 = np.array(list(results_before_mistral[strategy][INPUT].keys()))
        sorting = np.argsort(toks_before_k2)
        memory_before_k2 = memory_before_k2[sorting]
        toks_before_k2 = toks_before_k2[sorting]
        memory_after_k2 = np.array(list(results_after_mistral[strategy][INPUT].values()))
        toks_after_k2 = np.array(list(results_after_mistral[strategy][INPUT].keys()))
        sorting = np.argsort(toks_after_k2)
        memory_after_k2 = memory_after_k2[sorting]
        memory_before_k4 = np.array(list(results_before_mistral_top_k_4[strategy][INPUT].values()))
        toks_before_k4 = np.array(list(results_before_mistral_top_k_4[strategy][INPUT].keys()))
        sorting = np.argsort(toks_before_k4)
        memory_before_k4 = memory_before_k4[sorting]
        toks_before_k4 = toks_before_k4[sorting]
        memory_after_k4 = np.array(list(results_after_mistral_top_k_4[strategy][INPUT].values()))
        toks_after_k4 = np.array(list(results_after_mistral_top_k_4[strategy][INPUT].keys()))
        sorting = np.argsort(toks_after_k4)
        memory_after_k4 = memory_after_k4[sorting]
        # Truncate to a common length in case one version OOMed earlier
        min_ = min(len(memory_before_k4), len(memory_after_k4))
        toks_before_k4 = toks_before_k4[:min_]
        memory_before_k4 = memory_before_k4[:min_]
        memory_after_k4 = memory_after_k4[:min_]
        memory_before_k8 = np.array(list(results_before_mistral_top_k_8[strategy][INPUT].values()))
        toks_before_k8 = np.array(list(results_before_mistral_top_k_8[strategy][INPUT].keys()))
        sorting = np.argsort(toks_before_k8)
        memory_before_k8 = memory_before_k8[sorting]
        toks_before_k8 = toks_before_k8[sorting]
        memory_after_k8 = np.array(list(results_after_mistral_top_k_8[strategy][INPUT].values()))
        toks_after_k8 = np.array(list(results_after_mistral_top_k_8[strategy][INPUT].keys()))
        sorting = np.argsort(toks_after_k8)
        memory_after_k8 = memory_after_k8[sorting]
        min_ = min(len(memory_before_k8), len(memory_after_k8))
        toks_before_k8 = toks_before_k8[:min_]
        memory_before_k8 = memory_before_k8[:min_]
        memory_after_k8 = memory_after_k8[:min_]
        axes[0].plot(toks_before_k2, memory_before_k2 / memory_after_k2, label=f'Input size = {INPUT}')
        axes[1].plot(toks_before_k4, memory_before_k4 / memory_after_k4, label=f'Input size = {INPUT}')
        axes[2].plot(toks_before_k8, memory_before_k8 / memory_after_k8, label=f'Input size = {INPUT}')
    axes[0].set_xlabel('New tokens generated')
    axes[0].set_ylabel('Memory before / Memory after')
    axes[1].set_xlabel('New tokens generated')
    axes[1].set_ylabel('Memory before / Memory after')
    axes[2].set_xlabel('New tokens generated')
    axes[2].set_ylabel('Memory before / Memory after')
    axes[0].set_title('Top k = 2')
    axes[1].set_title('Top k = 4')
    axes[2].set_title('Top k = 8')
    axes[0].grid()
    axes[1].grid()
    axes[2].grid()
    axes[0].legend()
    axes[1].legend()
    axes[2].legend()
    plt.suptitle('Mistral')
    fig.savefig('contrastive_top_k.pdf', bbox_inches='tight')
    plt.show()
# BENCHMARK 1
# INITIAL BENCHMARK WHEN OPENING THE PR
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import time
from tqdm import tqdm
import numpy as np
import json
import gc
# Randomly generated text
LARGE_TEXT = """Title: Monkeys: Nature's Pranksters, Social Geniuses, and Ecological Wonders
Introduction
Monkeys, the charismatic and diverse members of the primate order, have long held a special place in the annals of our fascination with the animal kingdom. With their playful antics, astonishing intelligence, and complex social structures, they serve as a source of both joy and profound scientific inquiry. In this comprehensive exploration, we embark on a deep dive into the world of monkeys, spanning their evolutionary history, classifications, ecological roles, social dynamics, communication methods, and the pressing need for conservation. These captivating creatures offer insights into the intricacies of the natural world, our own evolutionary heritage, and the urgent importance of preserving biodiversity.
I. Evolutionary Origins
To understand the world of monkeys, we must embark on a journey through their evolutionary past, a tapestry that stretches back millions of years. Monkeys are part of the grand order of Primates, and their lineage is interwoven with the broader history of these remarkable mammals.
A. Primate Origins
The story of primates, including monkeys, begins around 60 million years ago. At that time, the world was a vastly different place, dominated by the reign of dinosaurs. It was during this period of Earth's history that the first primates, known as prosimians, emerged. These small, tree-dwelling mammals exhibited several characteristics that would become hallmarks of all primates: grasping hands and feet, forward-facing eyes for stereoscopic vision, and an enlarged brain relative to body size. These adaptations suited them for life in the trees, where they foraged for insects and fruits.
B. The Emergence of Monkeys
Around 35 million years ago, a significant split occurred within the primate family tree, leading to the emergence of two major groups: New World monkeys (Platyrrhini) and Old World monkeys (Catarrhini). This evolutionary divergence set in motion a cascade of adaptations that would result in the striking diversity of monkeys we see today.
The division between New World and Old World monkeys was not merely a matter of geographical separation but also marked significant differences in physical traits and behaviors. New World monkeys, found in Central and South America, are characterized by their prehensile tails and a variety of adaptations that allow them to thrive in the lush forests of the Americas. Old World monkeys, on the other hand, are residents of Africa, Asia, and parts of Gibraltar, and they have developed their own unique set of features to suit the diverse environments they inhabit.
II. Classification and Diversity"""
N_repeat = 5
torch.manual_seed(1)
def main(model_name, dtype, restrict_context_size, use_fast):
    model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='flash_attention_2',
                                                 torch_dtype=dtype, low_cpu_mem_usage=True).cuda(1)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=use_fast)
    inputs = tokenizer.encode(LARGE_TEXT*10, return_tensors='pt').to(device=1)
    # Not needed in our case, but silences a warning
    eos_token_id = model.generation_config.eos_token_id
    if hasattr(eos_token_id, '__len__') and hasattr(eos_token_id, '__getitem__'):
        eos_token_id = eos_token_id[0]
    if restrict_context_size:
        N_toks = [500, 1000, 2000, 3000, 3500]
    else:
        N_toks = [500, 1000, 2000, 3000, 4000, 5000, 6000, 7000]
    results_time = {}
    results_memory = {}
    filename = 'patched_' + model_name.split('/', 1)[1] + '_batch_size_1_input_size_300'
    for N_tok in tqdm(N_toks, desc='New tokens'):
        times = []
        memory = []
        for i in tqdm(range(N_repeat), leave=False):
            torch.cuda.reset_peak_memory_stats(1)
            actual_peak = torch.cuda.max_memory_allocated(1) / 1024**3
            t0 = time.time()
            out = model.generate(inputs[:, :300], max_new_tokens=N_tok, min_new_tokens=N_tok, do_sample=True,
                                 num_return_sequences=1, temperature=0.8, top_k=50, top_p=0.9,
                                 pad_token_id=eos_token_id)
            dt = time.time() - t0
            memory_used = (torch.cuda.max_memory_allocated(1) / 1024**3) - actual_peak
            # Verify that we did generate the correct number of tokens
            assert out.shape[-1] == 300 + N_tok
            times.append(dt)
            memory.append(memory_used)
        results_time[N_tok] = np.mean(times)
        results_memory[N_tok] = np.mean(memory)
        # print(f'New tokens: {N_tok} --- {results_time[N_tok]:.3e} s --- {results_memory[N_tok]:.2f} GiB')
        # Rewrite results at each iteration (so that we keep them if we OOM later)
        with open(filename + '_memory.json', 'w') as fp:
            json.dump(results_memory, fp, indent='\t')
        with open(filename + '_time.json', 'w') as fp:
            json.dump(results_time, fp, indent='\t')
    batch_sizes = [2, 4, 6, 8, 10]
    results_time = {}
    results_memory = {}
    filename = 'patched_' + model_name.split('/', 1)[1] + '_input_size_300_new_tokens_2000'
    for batch_size in tqdm(batch_sizes, desc='Batch sizes'):
        times = []
        memory = []
        for i in tqdm(range(N_repeat), leave=False):
            torch.cuda.reset_peak_memory_stats(1)
            actual_peak = torch.cuda.max_memory_allocated(1) / 1024**3
            t0 = time.time()
            out = model.generate(inputs[:, :300], max_new_tokens=2000, min_new_tokens=2000, do_sample=True,
                                 num_return_sequences=batch_size, temperature=0.8, top_k=50, top_p=0.9,
                                 pad_token_id=eos_token_id)
            dt = time.time() - t0
            memory_used = (torch.cuda.max_memory_allocated(1) / 1024**3) - actual_peak
            # Verify that we did generate the correct number of tokens
            assert out.shape[-1] == 2300
            assert out.shape[0] == batch_size
            times.append(dt)
            memory.append(memory_used)
        results_time[batch_size] = np.mean(times)
        results_memory[batch_size] = np.mean(memory)
        # print(f'Batch size: {batch_size} --- {results_time[batch_size]:.3e} s --- {results_memory[batch_size]:.2f} GiB')
        # Rewrite results at each iteration (so that we keep them if we OOM later)
        with open(filename + '_memory.json', 'w') as fp:
            json.dump(results_memory, fp, indent='\t')
        with open(filename + '_time.json', 'w') as fp:
            json.dump(results_time, fp, indent='\t')
    del model
    gc.collect()
if __name__ == '__main__':
    model_names = ['mistralai/Mistral-7B-v0.1', 'meta-llama/Llama-2-7b-hf', 'meta-llama/Meta-Llama-3-8B']
    dtypes = [torch.bfloat16, torch.float16, torch.bfloat16]
    restrict_contexts = [False, True, False]
    fasts = [False]*3
    for name, dtype, restrict_context_size, use_fast in zip(model_names, dtypes, restrict_contexts, fasts):
        try:
            main(name, dtype, restrict_context_size, use_fast)
        except RuntimeError as e:
            if isinstance(e, torch.cuda.OutOfMemoryError):
                pass
            else:
                raise e
        gc.collect()
        torch.cuda.empty_cache()
# BENCHMARK 3
# LAST BENCHMARK (RATIOS) SHOWN IN THE PR
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import warnings
import json
import gc
warnings.filterwarnings("ignore")
LARGE_TEXT = """Title: Monkeys: Nature's Pranksters, Social Geniuses, and Ecological Wonders
Introduction
Monkeys, the charismatic and diverse members of the primate order, have long held a special place in the annals of our fascination with the animal kingdom. With their playful antics, astonishing intelligence, and complex social structures, they serve as a source of both joy and profound scientific inquiry. In this comprehensive exploration, we embark on a deep dive into the world of monkeys, spanning their evolutionary history, classifications, ecological roles, social dynamics, communication methods, and the pressing need for conservation. These captivating creatures offer insights into the intricacies of the natural world, our own evolutionary heritage, and the urgent importance of preserving biodiversity.
I. Evolutionary Origins
To understand the world of monkeys, we must embark on a journey through their evolutionary past, a tapestry that stretches back millions of years. Monkeys are part of the grand order of Primates, and their lineage is interwoven with the broader history of these remarkable mammals.
A. Primate Origins
The story of primates, including monkeys, begins around 60 million years ago. At that time, the world was a vastly different place, dominated by the reign of dinosaurs. It was during this period of Earth's history that the first primates, known as prosimians, emerged. These small, tree-dwelling mammals exhibited several characteristics that would become hallmarks of all primates: grasping hands and feet, forward-facing eyes for stereoscopic vision, and an enlarged brain relative to body size. These adaptations suited them for life in the trees, where they foraged for insects and fruits.
B. The Emergence of Monkeys
Around 35 million years ago, a significant split occurred within the primate family tree, leading to the emergence of two major groups: New World monkeys (Platyrrhini) and Old World monkeys (Catarrhini). This evolutionary divergence set in motion a cascade of adaptations that would result in the striking diversity of monkeys we see today.
The division between New World and Old World monkeys was not merely a matter of geographical separation but also marked significant differences in physical traits and behaviors. New World monkeys, found in Central and South America, are characterized by their prehensile tails and a variety of adaptations that allow them to thrive in the lush forests of the Americas. Old World monkeys, on the other hand, are residents of Africa, Asia, and parts of Gibraltar, and they have developed their own unique set of features to suit the diverse environments they inhabit.
II. Classification and Diversity"""
matcher = {
    'contrastive search': {
        'do_sample': False,
        'top_k': 2,
        'penalty_alpha': 0.5,
    },
    'greedy': {
        'do_sample': False,
    },
    'sample': {
        'do_sample': True,
        'top_k': 50,
    },
    'group beam search': {
        'num_beams': 2,
        'num_beam_groups': 2,
        'do_sample': False,
        'diversity_penalty': 0.5,
    },
    'beam sample': {
        'num_beams': 2,
        'do_sample': True,
    },
    'beam search': {
        'num_beams': 2,
        'do_sample': False,
    },
}
def memory_usage(past_key_values):
    """Recursively compute the memory footprint of past key values (in bytes)."""
    if isinstance(past_key_values, torch.Tensor):
        return past_key_values.nelement() * past_key_values.element_size()
    elif isinstance(past_key_values[0], torch.Tensor):
        return sum([x.nelement() * x.element_size() for x in past_key_values])
    else:
        return sum([memory_usage(x) for x in past_key_values])
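# Illustrative usage (not part of the benchmark below): measure the KV-cache
# footprint of a single forward pass, assuming the model returns past_key_values
# when called with use_cache=True:
# out = model(inputs[:, :300], use_cache=True)
# print(f'KV cache: {memory_usage(out.past_key_values) / 1024**3:.2f} GiB')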
model_names = ['mistralai/Mistral-7B-v0.1', 'meta-llama/Llama-2-7b-hf']
dtypes = [torch.bfloat16, torch.float16]
torch.manual_seed(1)
methods = ['sample', 'beam sample', 'contrastive search']
input_sizes = [300, 1000]
N_toks = [2, 100, 250, 500, 750, 1000, 1250, 1500, 2000, 3000, 4000]
for model_name, dtype in zip(model_names, dtypes):
    model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='flash_attention_2',
                                                 torch_dtype=dtype, low_cpu_mem_usage=True).cuda(0)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    inputs = tokenizer.encode(LARGE_TEXT*300, return_tensors='pt').to(device=0)
    eos_token_id = model.generation_config.eos_token_id
    if hasattr(eos_token_id, '__len__') and hasattr(eos_token_id, '__getitem__'):
        eos_token_id = eos_token_id[0]
    results_memory = {}
    filename = 'patched_' + model_name.split('/', 1)[1] + '_ratios'
    for strategy in methods:
        input_params = matcher[strategy]
        results_memory[strategy] = {}
        for INPUT in input_sizes:
            results_memory[strategy][INPUT] = {}
            for TOKS in N_toks:
                try:
                    torch.cuda.reset_peak_memory_stats(0)
                    actual_peak = torch.cuda.max_memory_allocated(0) / 1024**3
                    out = model.generate(inputs[:, :INPUT], max_new_tokens=TOKS, min_new_tokens=TOKS,
                                         pad_token_id=eos_token_id, **input_params)
                    memory_used = (torch.cuda.max_memory_allocated(0) / 1024**3) - actual_peak
                    results_memory[strategy][INPUT][TOKS] = memory_used
                except RuntimeError as e:
                    if isinstance(e, torch.cuda.OutOfMemoryError):
                        break
                    else:
                        raise e
                # Rewrite results at each iteration (so that we keep them if we OOM later)
                with open(filename + '_memory.json', 'w') as fp:
                    json.dump(results_memory, fp, indent='\t')
    # Do more contrastive tries with bigger top_ks
    if model_name == 'mistralai/Mistral-7B-v0.1':
        for k in [4, 8]:
            input_params = {'do_sample': False, 'top_k': k, 'penalty_alpha': 0.5}
            results_memory = {'contrastive search': {}}
            for INPUT in input_sizes:
                results_memory['contrastive search'][INPUT] = {}
                for TOKS in N_toks:
                    try:
                        torch.cuda.reset_peak_memory_stats(0)
                        actual_peak = torch.cuda.max_memory_allocated(0) / 1024**3
                        out = model.generate(inputs[:, :INPUT], max_new_tokens=TOKS, min_new_tokens=TOKS,
                                             pad_token_id=eos_token_id, **input_params)
                        memory_used = (torch.cuda.max_memory_allocated(0) / 1024**3) - actual_peak
                        results_memory['contrastive search'][INPUT][TOKS] = memory_used
                    except RuntimeError as e:
                        if isinstance(e, torch.cuda.OutOfMemoryError):
                            break
                        else:
                            raise e
                    # Rewrite results at each iteration (so that we keep them if we OOM later)
                    with open(filename + f'_contrastive_k_{k}_memory.json', 'w') as fp:
                        json.dump(results_memory, fp, indent='\t')
    del model
    gc.collect()