Benchmark for PR https://github.com/huggingface/transformers/pull/30536 on Transformers
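All scripts below share the same measurement pattern: reset the CUDA peak-memory counter, time a single `model.generate` call, and read back the extra peak memory it allocated. A minimal sketch of that pattern (assuming a model and tokenized inputs already sit on `cuda:0`; the helper name `measure_generate` is illustrative only, not part of the benchmarks):

import time
import torch

def measure_generate(model, input_ids, **generate_kwargs):
    # Reset the peak-memory counter and record what is already allocated (weights, inputs)
    torch.cuda.reset_peak_memory_stats(0)
    baseline = torch.cuda.max_memory_allocated(0) / 1024**3
    # Time a single generate call
    t0 = time.time()
    out = model.generate(input_ids, **generate_kwargs)
    dt = time.time() - t0
    # Extra peak memory (in GiB) used during generation
    peak = torch.cuda.max_memory_allocated(0) / 1024**3 - baseline
    return out, dt, peak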
# BENCHMARK 2
# SECOND BENCHMARK OF ALL GENERATION METHODS
from transformers import AutoTokenizer, AutoModelForCausalLM, PhrasalConstraint
import torch
import time
from tqdm import tqdm
import numpy as np
import json
import gc

# model_name = 'meta-llama/Llama-2-7b-hf'
model_name = 'mistralai/Mistral-7B-v0.1'
dtype = torch.bfloat16

model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='flash_attention_2',
                                             torch_dtype=dtype, low_cpu_mem_usage=True).cuda(1)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

# Small assistant model for assisted decoding
assistant_name = 'gpt2-medium'
assistant = AutoModelForCausalLM.from_pretrained(assistant_name, torch_dtype=torch.float32,
                                                 low_cpu_mem_usage=True).cuda(1)

# Randomly generated text used as a long prompt
LARGE_TEXT = """Title: Monkeys: Nature's Pranksters, Social Geniuses, and Ecological Wonders | |
Introduction | |
Monkeys, the charismatic and diverse members of the primate order, have long held a special place in the annals of our fascination with the animal kingdom. With their playful antics, astonishing intelligence, and complex social structures, they serve as a source of both joy and profound scientific inquiry. In this comprehensive exploration, we embark on a deep dive into the world of monkeys, spanning their evolutionary history, classifications, ecological roles, social dynamics, communication methods, and the pressing need for conservation. These captivating creatures offer insights into the intricacies of the natural world, our own evolutionary heritage, and the urgent importance of preserving biodiversity. | |
I. Evolutionary Origins | |
To understand the world of monkeys, we must embark on a journey through their evolutionary past, a tapestry that stretches back millions of years. Monkeys are part of the grand order of Primates, and their lineage is interwoven with the broader history of these remarkable mammals. | |
A. Primate Origins | |
The story of primates, including monkeys, begins around 60 million years ago. At that time, the world was a vastly different place, dominated by the reign of dinosaurs. It was during this period of Earth's history that the first primates, known as prosimians, emerged. These small, tree-dwelling mammals exhibited several characteristics that would become hallmarks of all primates: grasping hands and feet, forward-facing eyes for stereoscopic vision, and an enlarged brain relative to body size. These adaptations suited them for life in the trees, where they foraged for insects and fruits. | |
B. The Emergence of Monkeys | |
Around 35 million years ago, a significant split occurred within the primate family tree, leading to the emergence of two major groups: New World monkeys (Platyrrhini) and Old World monkeys (Catarrhini). This evolutionary divergence set in motion a cascade of adaptations that would result in the striking diversity of monkeys we see today. | |
The division between New World and Old World monkeys was not merely a matter of geographical separation but also marked significant differences in physical traits and behaviors. New World monkeys, found in Central and South America, are characterized by their prehensile tails and a variety of adaptations that allow them to thrive in the lush forests of the Americas. Old World monkeys, on the other hand, are residents of Africa, Asia, and parts of Gibraltar, and they have developed their own unique set of features to suit the diverse environments they inhabit. | |
II. Classification and Diversity""" | |
inputs = tokenizer.encode(LARGE_TEXT*5, return_tensors='pt').to(device=1)

force_word = "monkey"
constraints = [PhrasalConstraint(tokenizer.encode(force_word, add_special_tokens=False))]

matcher = {
    'contrastive search': {
        'do_sample': False,
        'top_k': 10,
        'penalty_alpha': 0.5,
    },
    'greedy': {
        'do_sample': False,
    },
    'sample': {
        'do_sample': True,
        'top_k': 50,
    },
    'beam sample': {
        'num_beams': 2,
        'do_sample': True,
    },
    'beam search': {
        'num_beams': 2,
        'do_sample': False,
    },
    'group beam search': {
        'num_beams': 2,
        'num_beam_groups': 2,
        'do_sample': False,
        'diversity_penalty': 0.5,
    },
    'constrained beam search': {
        'num_beams': 2,
        'constraints': constraints,
        'do_sample': False,
    },
    'assisted': {
        'do_sample': False,
        'assistant_model': assistant,
    },
}

N_repeat = 3
torch.manual_seed(1)


def main(model, inputs, matcher, N_repeat):
    eos_token_id = model.generation_config.eos_token_id
    if hasattr(eos_token_id, '__len__') and hasattr(eos_token_id, '__getitem__'):
        eos_token_id = eos_token_id[0]

    N_toks = [100, 500, 1000, 2000, 3000, 4000, 5000, 6000, 7000]
    results_time = {k: {} for k in matcher.keys()}
    results_memory = {k: {} for k in matcher.keys()}
    filename = 'patched_version_benchmark_all_gen_methods'

    for strategy, input_params in tqdm(matcher.items(), desc='strategy', total=len(matcher.keys())):
        for N_tok in tqdm(N_toks, desc='New tokens', leave=False):
            # Flag to break if we OOM
            OOMed = False
            times = []
            memory = []
            for i in tqdm(range(N_repeat), leave=False):
                # To catch OOMs
                try:
                    torch.cuda.reset_peak_memory_stats(1)
                    actual_peak = torch.cuda.max_memory_allocated(1) / 1024**3
                    t0 = time.time()
                    out = model.generate(inputs[:, :300], max_new_tokens=N_tok, min_new_tokens=N_tok, pad_token_id=eos_token_id,
                                         **input_params)
                    dt = time.time() - t0
                    memory_used = (torch.cuda.max_memory_allocated(1) / 1024**3) - actual_peak
                    # Verify that we did generate the correct number of tokens
                    assert out.shape[-1] == 300 + N_tok
                    times.append(dt)
                    memory.append(memory_used)
                # If we actually OOM'ed, continue with next strategy as increasing tokens will still OOM
                except RuntimeError as e:
                    if isinstance(e, torch.cuda.OutOfMemoryError):
                        OOMed = True
                        break
                    else:
                        raise e
            # If we OOMed, break once again to change decoding strategy
            if OOMed:
                break
            results_time[strategy][N_tok] = np.mean(times)
            results_memory[strategy][N_tok] = np.mean(memory)

            # Rewrite results at each iteration (so that we keep them if OOM)
            with open(filename + '_memory.json', 'w') as fp:
                json.dump(results_memory, fp, indent='\t')
            with open(filename + '_time.json', 'w') as fp:
                json.dump(results_time, fp, indent='\t')

    del model
    gc.collect()


if __name__ == '__main__':
    main(model, inputs, matcher, N_repeat)
# TO CREATE FIGURES OF THE BENCHMARK 1
import json
import matplotlib.pyplot as plt
import numpy as np

model_names = ['Mistral-7B-v0.1', 'Llama-2-7b-hf', 'Meta-Llama-3-8B']
fix_batch = '_batch_size_1_input_size_300'
fix_length = '_input_size_300_new_tokens_2000'
folder = 'benchmark/'
fig_folder = 'results/'


def load_json(filename: str) -> tuple:
    """Load a results file and return (sorted keys, corresponding values) as arrays."""
    with open(filename, 'r') as fp:
        data = json.load(fp)
    data = {int(k): v for k, v in data.items()}
    a, b = np.array(list(data.keys())), np.array(list(data.values()))
    sorting = np.argsort(a)
    return a[sorting], b[sorting]
# Fixed batch size figures
for model in model_names:
    # memory figure
    plt.figure()
    plt.plot(*load_json(folder + 'legacy_' + model + fix_batch + '_memory.json'), 'b-', label='Before')
    plt.plot(*load_json(folder + 'patched_' + model + fix_batch + '_memory.json'), 'r-', label='After')
    plt.xlabel('New tokens generated')
    plt.ylabel('Peak memory usage [GiB]')
    plt.grid()
    plt.legend()
    plt.title(model + '\nInput size 300, Batch size 1')
    plt.savefig(fig_folder + model + '_memory_fix_batch.pdf', bbox_inches='tight')
    plt.show()

    # time figure
    plt.figure()
    plt.plot(*load_json(folder + 'legacy_' + model + fix_batch + '_time.json'), 'b-', label='Before')
    plt.plot(*load_json(folder + 'patched_' + model + fix_batch + '_time.json'), 'r-', label='After')
    plt.xlabel('New tokens generated')
    plt.ylabel('Generation time [s]')
    plt.grid()
    plt.legend()
    plt.title(model + '\nInput size 300, Batch size 1')
    plt.savefig(fig_folder + model + '_time_fix_batch.pdf', bbox_inches='tight')
    plt.show()

# Fixed number of new tokens figures
for model in model_names:
    # memory figure
    plt.figure()
    plt.plot(*load_json(folder + 'legacy_' + model + fix_length + '_memory.json'), 'b-', label='Before')
    plt.plot(*load_json(folder + 'patched_' + model + fix_length + '_memory.json'), 'r-', label='After')
    plt.xlabel('Batch size')
    plt.ylabel('Peak memory usage [GiB]')
    plt.grid()
    plt.legend()
    plt.title(model + '\nInput size 300, New tokens 2000')
    if model == model_names[1]:
        plt.text(4.1, 9, 'OOM after this point', color='b')
    plt.savefig(fig_folder + model + '_memory_fix_length.pdf', bbox_inches='tight')
    plt.show()

    # time figure
    plt.figure()
    plt.plot(*load_json(folder + 'legacy_' + model + fix_length + '_time.json'), 'b-', label='Before')
    plt.plot(*load_json(folder + 'patched_' + model + fix_length + '_time.json'), 'r-', label='After')
    plt.xlabel('Batch size')
    plt.ylabel('Generation time [s]')
    plt.grid()
    plt.legend()
    plt.title(model + '\nInput size 300, New tokens 2000')
    if model == model_names[1]:
        plt.text(4.1, 57, 'OOM after this point', color='b')
    plt.savefig(fig_folder + model + '_time_fix_length.pdf', bbox_inches='tight')
    plt.show()
# FIGURES FOR BENCHMARK 2
import json
import matplotlib.pyplot as plt
import numpy as np

model_name = 'Mistral-7B-v0.1'
folder = 'res/'
filename = '_version_benchmark_all_gen_methods_'
assistant = 'gpt2-medium'

# Generation arguments used for each decoding strategy (shown in the figure titles)
matcher = {
    'contrastive search': {
        'top_k': 10,
        'penalty_alpha': 0.5,
    },
    'greedy': {
    },
    'sample': {
        'top_k': 50,
    },
    'beam sample': {
        'num_beams': 2,
    },
    'beam search': {
        'num_beams': 2,
    },
    'group beam search': {
        'num_beams': 2,
        'num_beam_groups': 2,
        'diversity_penalty': 0.5,
    },
    'constrained beam search': {
        'num_beams': 2,
    },
    'assisted': {
        'assistant_model': assistant,
    },
}


def load_json(filename: str) -> dict:
    """Load a results file and sort each strategy's entries by number of new tokens."""
    with open(filename, 'r') as fp:
        data = json.load(fp)
    for k, v in data.items():
        int_keys = [int(x) for x in v.keys()]
        a, b = np.array(int_keys), np.array(list(v.values()))
        sorting = np.argsort(a)
        data[k] = (a[sorting], b[sorting])
    return data
patched_memory, patched_time = load_json(folder + 'patched' + filename + 'memory.json'), load_json(folder + 'patched' + filename + 'time.json')
legacy_memory, legacy_time = load_json(folder + 'legacy' + filename + 'memory.json'), load_json(folder + 'legacy' + filename + 'time.json')

# Fixed batch size figures
save = False
fig_folder = 'results/all_gen_methods/'

for strategy in patched_memory.keys():
    fig, axes = plt.subplots(1, 2, figsize=(6.4*2.1, 5.5), sharex=True)
    axes[0].plot(*legacy_memory[strategy], 'b-', label='Before')
    axes[0].plot(*patched_memory[strategy], 'r-', label='After')
    axes[0].set_xlabel('New tokens generated')
    axes[0].set_ylabel('Peak memory usage [GiB]')
    axes[0].grid()
    axes[0].legend()
    if strategy == 'contrastive search':
        axes[0].text(1020, 3.75, 'OOM after this point', color='b')
    axes[1].plot(*legacy_time[strategy], 'b-', label='Before')
    axes[1].plot(*patched_time[strategy], 'r-', label='After')
    axes[1].set_xlabel('New tokens generated')
    axes[1].set_ylabel('Generation time [s]')
    axes[1].grid()
    axes[1].legend()
    if strategy == 'contrastive search':
        axes[1].text(1020, 35, 'OOM after this point', color='b')
    title = f'{strategy}\nInput size 300, Batch size 1'
    if len(matcher[strategy]) > 0:
        gen_args = ', '.join(f"'{k}': {v}" for k, v in matcher[strategy].items())
        title += f'\n{gen_args}'
    plt.suptitle(title)
    if save:
        plt.savefig(fig_folder + '_'.join(strategy.split(' ')) + '.pdf', bbox_inches='tight')
    plt.show()
# FOR FIGURES OF BENCHMARK 3
import numpy as np
import matplotlib.pyplot as plt
import json


def load_json(filename: str) -> dict:
    with open(filename, 'r') as fp:
        data = json.load(fp)
    out = {}
    for k1, v1 in data.items():
        out[k1] = {}
        for k2, v2 in v1.items():
            out[k1][int(k2)] = {int(k3): v3 for k3, v3 in v2.items()}
    return out


mistral = 'Mistral-7B-v0.1'
llama = 'Llama-2-7b-hf'

results_before_mistral = load_json('legacy_' + mistral + '_ratios_memory.json')
results_after_mistral = load_json('patched_' + mistral + '_ratios_memory.json')
results_before_llama2 = load_json('legacy_' + llama + '_ratios_memory.json')
results_after_llama2 = load_json('patched_' + llama + '_ratios_memory.json')

results_before_mistral_top_k_4 = load_json('legacy_' + mistral + '_ratios' + '_contrastive_k_4_memory.json')
results_after_mistral_top_k_4 = load_json('patched_' + mistral + '_ratios' + '_contrastive_k_4_memory.json')
results_before_mistral_top_k_8 = load_json('legacy_' + mistral + '_ratios' + '_contrastive_k_8_memory.json')
results_after_mistral_top_k_8 = load_json('patched_' + mistral + '_ratios' + '_contrastive_k_8_memory.json')
# First figures
methods = ['sample', 'beam sample', 'contrastive search']
input_sizes = [300, 1000]
N_toks = [2, 100, 250, 500, 750, 1000, 1250, 1500, 2000, 3000, 4000]

for strategy in methods:
    fig, axes = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(6.4*2, 5.2))
    for INPUT in input_sizes:
        memory_before_mistral = np.array(list(results_before_mistral[strategy][INPUT].values()))
        toks_before_mistral = np.array(list(results_before_mistral[strategy][INPUT].keys()))
        sorting = np.argsort(toks_before_mistral)
        memory_before_mistral = memory_before_mistral[sorting]
        toks_before_mistral = toks_before_mistral[sorting]

        memory_after_mistral = np.array(list(results_after_mistral[strategy][INPUT].values()))
        toks_after_mistral = np.array(list(results_after_mistral[strategy][INPUT].keys()))
        sorting = np.argsort(toks_after_mistral)
        memory_after_mistral = memory_after_mistral[sorting]

        memory_before_llama2 = np.array(list(results_before_llama2[strategy][INPUT].values()))
        toks_before_llama2 = np.array(list(results_before_llama2[strategy][INPUT].keys()))
        sorting = np.argsort(toks_before_llama2)
        memory_before_llama2 = memory_before_llama2[sorting]
        toks_before_llama2 = toks_before_llama2[sorting]

        memory_after_llama2 = np.array(list(results_after_llama2[strategy][INPUT].values()))
        toks_after_llama2 = np.array(list(results_after_llama2[strategy][INPUT].keys()))
        sorting = np.argsort(toks_after_llama2)
        memory_after_llama2 = memory_after_llama2[sorting]

        min_ = min(len(memory_after_llama2), len(memory_before_llama2))
        toks_before_llama2 = toks_before_llama2[:min_]
        memory_before_llama2 = memory_before_llama2[:min_]
        memory_after_llama2 = memory_after_llama2[:min_]

        axes[0].plot(toks_before_llama2, memory_before_llama2 / memory_after_llama2, label=f'Input size = {INPUT}')
        axes[1].plot(toks_before_mistral, memory_before_mistral / memory_after_mistral, label=f'Input size = {INPUT}')

    if strategy == 'beam sample':
        axes[0].text(2020, 1.485, 'OOM for legacy version\nafter those points')
    axes[0].set_xlabel('New tokens generated')
    axes[0].set_ylabel('Memory before / Memory after')
    axes[1].set_xlabel('New tokens generated')
    axes[1].set_ylabel('Memory before / Memory after')
    axes[0].set_title('Usual cache size (most models, e.g. Llama2)\nHere: Llama2')
    axes[1].set_title('Small efficient cache size (e.g. Mistral)\nHere: Mistral')
    axes[0].grid()
    axes[1].grid()
    axes[0].legend()
    axes[1].legend()
    # plt.suptitle(strategy)
    fig.savefig(strategy + '.pdf', bbox_inches='tight')
    plt.show()
# Second figure
methods = ['contrastive search']
input_sizes = [300, 1000]
N_toks = [2, 100, 250, 500, 750, 1000, 1250, 1500, 2000, 3000, 4000]

for strategy in methods:
    fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(6.4*3, 5.2))
    for INPUT in input_sizes:
        memory_before_k2 = np.array(list(results_before_mistral[strategy][INPUT].values()))
        toks_before_k2 = np.array(list(results_before_mistral[strategy][INPUT].keys()))
        sorting = np.argsort(toks_before_k2)
        memory_before_k2 = memory_before_k2[sorting]
        toks_before_k2 = toks_before_k2[sorting]

        memory_after_k2 = np.array(list(results_after_mistral[strategy][INPUT].values()))
        toks_after_k2 = np.array(list(results_after_mistral[strategy][INPUT].keys()))
        sorting = np.argsort(toks_after_k2)
        memory_after_k2 = memory_after_k2[sorting]

        memory_before_k4 = np.array(list(results_before_mistral_top_k_4[strategy][INPUT].values()))
        toks_before_k4 = np.array(list(results_before_mistral_top_k_4[strategy][INPUT].keys()))
        sorting = np.argsort(toks_before_k4)
        memory_before_k4 = memory_before_k4[sorting]
        toks_before_k4 = toks_before_k4[sorting]

        memory_after_k4 = np.array(list(results_after_mistral_top_k_4[strategy][INPUT].values()))
        toks_after_k4 = np.array(list(results_after_mistral_top_k_4[strategy][INPUT].keys()))
        sorting = np.argsort(toks_after_k4)
        memory_after_k4 = memory_after_k4[sorting]

        # Truncate to the common length in case one run OOMed earlier than the other
        min_ = min(len(memory_before_k4), len(memory_after_k4))
        toks_before_k4 = toks_before_k4[:min_]
        memory_before_k4 = memory_before_k4[:min_]
        memory_after_k4 = memory_after_k4[:min_]

        memory_before_k8 = np.array(list(results_before_mistral_top_k_8[strategy][INPUT].values()))
        toks_before_k8 = np.array(list(results_before_mistral_top_k_8[strategy][INPUT].keys()))
        sorting = np.argsort(toks_before_k8)
        memory_before_k8 = memory_before_k8[sorting]
        toks_before_k8 = toks_before_k8[sorting]

        memory_after_k8 = np.array(list(results_after_mistral_top_k_8[strategy][INPUT].values()))
        toks_after_k8 = np.array(list(results_after_mistral_top_k_8[strategy][INPUT].keys()))
        sorting = np.argsort(toks_after_k8)
        memory_after_k8 = memory_after_k8[sorting]

        min_ = min(len(memory_before_k8), len(memory_after_k8))
        toks_before_k8 = toks_before_k8[:min_]
        memory_before_k8 = memory_before_k8[:min_]
        memory_after_k8 = memory_after_k8[:min_]

        axes[0].plot(toks_before_k2, memory_before_k2 / memory_after_k2, label=f'Input size = {INPUT}')
        axes[1].plot(toks_before_k4, memory_before_k4 / memory_after_k4, label=f'Input size = {INPUT}')
        axes[2].plot(toks_before_k8, memory_before_k8 / memory_after_k8, label=f'Input size = {INPUT}')

    axes[0].set_xlabel('New tokens generated')
    axes[0].set_ylabel('Memory before / Memory after')
    axes[1].set_xlabel('New tokens generated')
    axes[1].set_ylabel('Memory before / Memory after')
    axes[2].set_xlabel('New tokens generated')
    axes[2].set_ylabel('Memory before / Memory after')
    axes[0].set_title('Top k = 2')
    axes[1].set_title('Top k = 4')
    axes[2].set_title('Top k = 8')
    axes[0].grid()
    axes[1].grid()
    axes[2].grid()
    axes[0].legend()
    axes[1].legend()
    axes[2].legend()
    plt.suptitle('Mistral')
    fig.savefig('contrastive_top_k.pdf', bbox_inches='tight')
    plt.show()
# BENCHMARK 1
# INITIAL BENCHMARK WHEN OPENING THE PR
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import time
from tqdm import tqdm
import numpy as np
import json
import gc

# Randomly generated text used as a long prompt
LARGE_TEXT = """Title: Monkeys: Nature's Pranksters, Social Geniuses, and Ecological Wonders | |
Introduction | |
Monkeys, the charismatic and diverse members of the primate order, have long held a special place in the annals of our fascination with the animal kingdom. With their playful antics, astonishing intelligence, and complex social structures, they serve as a source of both joy and profound scientific inquiry. In this comprehensive exploration, we embark on a deep dive into the world of monkeys, spanning their evolutionary history, classifications, ecological roles, social dynamics, communication methods, and the pressing need for conservation. These captivating creatures offer insights into the intricacies of the natural world, our own evolutionary heritage, and the urgent importance of preserving biodiversity. | |
I. Evolutionary Origins | |
To understand the world of monkeys, we must embark on a journey through their evolutionary past, a tapestry that stretches back millions of years. Monkeys are part of the grand order of Primates, and their lineage is interwoven with the broader history of these remarkable mammals. | |
A. Primate Origins | |
The story of primates, including monkeys, begins around 60 million years ago. At that time, the world was a vastly different place, dominated by the reign of dinosaurs. It was during this period of Earth's history that the first primates, known as prosimians, emerged. These small, tree-dwelling mammals exhibited several characteristics that would become hallmarks of all primates: grasping hands and feet, forward-facing eyes for stereoscopic vision, and an enlarged brain relative to body size. These adaptations suited them for life in the trees, where they foraged for insects and fruits. | |
B. The Emergence of Monkeys | |
Around 35 million years ago, a significant split occurred within the primate family tree, leading to the emergence of two major groups: New World monkeys (Platyrrhini) and Old World monkeys (Catarrhini). This evolutionary divergence set in motion a cascade of adaptations that would result in the striking diversity of monkeys we see today. | |
The division between New World and Old World monkeys was not merely a matter of geographical separation but also marked significant differences in physical traits and behaviors. New World monkeys, found in Central and South America, are characterized by their prehensile tails and a variety of adaptations that allow them to thrive in the lush forests of the Americas. Old World monkeys, on the other hand, are residents of Africa, Asia, and parts of Gibraltar, and they have developed their own unique set of features to suit the diverse environments they inhabit. | |
II. Classification and Diversity""" | |
N_repeat = 5
torch.manual_seed(1)


def main(model_name, dtype, restrict_context_size, use_fast):
    model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='flash_attention_2',
                                                 torch_dtype=dtype, low_cpu_mem_usage=True).cuda(1)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=use_fast)
    inputs = tokenizer.encode(LARGE_TEXT*10, return_tensors='pt').to(device=1)

    # Useless in our case but silences warning
    eos_token_id = model.generation_config.eos_token_id
    if hasattr(eos_token_id, '__len__') and hasattr(eos_token_id, '__getitem__'):
        eos_token_id = eos_token_id[0]

    if restrict_context_size:
        N_toks = [500, 1000, 2000, 3000, 3500]
    else:
        N_toks = [500, 1000, 2000, 3000, 4000, 5000, 6000, 7000]

    results_time = {}
    results_memory = {}
    filename = 'patched_' + model_name.split('/', 1)[1] + '_batch_size_1_input_size_300'

    for N_tok in tqdm(N_toks, desc='New tokens'):
        times = []
        memory = []
        for i in tqdm(range(N_repeat), leave=False):
            torch.cuda.reset_peak_memory_stats(1)
            actual_peak = torch.cuda.max_memory_allocated(1) / 1024**3
            t0 = time.time()
            out = model.generate(inputs[:, :300], max_new_tokens=N_tok, min_new_tokens=N_tok, do_sample=True,
                                 num_return_sequences=1, temperature=0.8, top_k=50, top_p=0.9,
                                 pad_token_id=eos_token_id)
            dt = time.time() - t0
            memory_used = (torch.cuda.max_memory_allocated(1) / 1024**3) - actual_peak
            # Verify that we did generate the correct number of tokens
            assert out.shape[-1] == 300 + N_tok
            times.append(dt)
            memory.append(memory_used)
        results_time[N_tok] = np.mean(times)
        results_memory[N_tok] = np.mean(memory)
        # print(f'New tokens: {N_tok} --- {results_time[N_tok]:.3e} s --- {results_memory[N_tok]:.2f} GiB')

        # Rewrite results at each iteration (so that we keep them if OOM)
        with open(filename + '_memory.json', 'w') as fp:
            json.dump(results_memory, fp, indent='\t')
        with open(filename + '_time.json', 'w') as fp:
            json.dump(results_time, fp, indent='\t')

    batch_sizes = [2, 4, 6, 8, 10]
    results_time = {}
    results_memory = {}
    filename = 'patched_' + model_name.split('/', 1)[1] + '_input_size_300_new_tokens_2000'

    for batch_size in tqdm(batch_sizes, desc='Batch sizes'):
        times = []
        memory = []
        for i in tqdm(range(N_repeat), leave=False):
            torch.cuda.reset_peak_memory_stats(1)
            actual_peak = torch.cuda.max_memory_allocated(1) / 1024**3
            t0 = time.time()
            out = model.generate(inputs[:, :300], max_new_tokens=2000, min_new_tokens=2000, do_sample=True,
                                 num_return_sequences=batch_size, temperature=0.8, top_k=50, top_p=0.9,
                                 pad_token_id=eos_token_id)
            dt = time.time() - t0
            memory_used = (torch.cuda.max_memory_allocated(1) / 1024**3) - actual_peak
            # Verify that we did generate the correct number of tokens
            assert out.shape[-1] == 2300
            assert out.shape[0] == batch_size
            times.append(dt)
            memory.append(memory_used)
        results_time[batch_size] = np.mean(times)
        results_memory[batch_size] = np.mean(memory)
        # print(f'Batch size: {batch_size} --- {results_time[N_tok]:.3e} s --- {results_memory[N_tok]:.2f} GiB')

        # Rewrite results at each iteration (so that we keep them if OOM)
        with open(filename + '_memory.json', 'w') as fp:
            json.dump(results_memory, fp, indent='\t')
        with open(filename + '_time.json', 'w') as fp:
            json.dump(results_time, fp, indent='\t')

    del model
    gc.collect()


if __name__ == '__main__':
    model_names = ['mistralai/Mistral-7B-v0.1', 'meta-llama/Llama-2-7b-hf', 'meta-llama/Meta-Llama-3-8B']
    dtypes = [torch.bfloat16, torch.float16, torch.bfloat16]
    restrict_contexts = [False, True, False]
    fasts = [False]*3

    for name, dtype, restrict_context_size, use_fast in zip(model_names, dtypes, restrict_contexts, fasts):
        try:
            main(name, dtype, restrict_context_size, use_fast)
        except RuntimeError as e:
            if isinstance(e, torch.cuda.OutOfMemoryError):
                pass
            else:
                raise e
        gc.collect()
        torch.cuda.empty_cache()
# BENCHMARK 3
# LAST BENCHMARK (RATIOS) SHOWN IN THE PR
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import warnings
import json
import gc

warnings.filterwarnings("ignore")

LARGE_TEXT = """Title: Monkeys: Nature's Pranksters, Social Geniuses, and Ecological Wonders
Introduction
Monkeys, the charismatic and diverse members of the primate order, have long held a special place in the annals of our fascination with the animal kingdom. With their playful antics, astonishing intelligence, and complex social structures, they serve as a source of both joy and profound scientific inquiry. In this comprehensive exploration, we embark on a deep dive into the world of monkeys, spanning their evolutionary history, classifications, ecological roles, social dynamics, communication methods, and the pressing need for conservation. These captivating creatures offer insights into the intricacies of the natural world, our own evolutionary heritage, and the urgent importance of preserving biodiversity.
I. Evolutionary Origins
To understand the world of monkeys, we must embark on a journey through their evolutionary past, a tapestry that stretches back millions of years. Monkeys are part of the grand order of Primates, and their lineage is interwoven with the broader history of these remarkable mammals.
A. Primate Origins
The story of primates, including monkeys, begins around 60 million years ago. At that time, the world was a vastly different place, dominated by the reign of dinosaurs. It was during this period of Earth's history that the first primates, known as prosimians, emerged. These small, tree-dwelling mammals exhibited several characteristics that would become hallmarks of all primates: grasping hands and feet, forward-facing eyes for stereoscopic vision, and an enlarged brain relative to body size. These adaptations suited them for life in the trees, where they foraged for insects and fruits.
B. The Emergence of Monkeys
Around 35 million years ago, a significant split occurred within the primate family tree, leading to the emergence of two major groups: New World monkeys (Platyrrhini) and Old World monkeys (Catarrhini). This evolutionary divergence set in motion a cascade of adaptations that would result in the striking diversity of monkeys we see today.
The division between New World and Old World monkeys was not merely a matter of geographical separation but also marked significant differences in physical traits and behaviors. New World monkeys, found in Central and South America, are characterized by their prehensile tails and a variety of adaptations that allow them to thrive in the lush forests of the Americas. Old World monkeys, on the other hand, are residents of Africa, Asia, and parts of Gibraltar, and they have developed their own unique set of features to suit the diverse environments they inhabit.
II. Classification and Diversity"""
matcher = {
    'contrastive search': {
        'do_sample': False,
        'top_k': 2,
        'penalty_alpha': 0.5,
    },
    'greedy': {
        'do_sample': False,
    },
    'sample': {
        'do_sample': True,
        'top_k': 50,
    },
    'group beam search': {
        'num_beams': 2,
        'num_beam_groups': 2,
        'do_sample': False,
        'diversity_penalty': 0.5,
    },
    'beam sample': {
        'num_beams': 2,
        'do_sample': True,
    },
    'beam search': {
        'num_beams': 2,
        'do_sample': False,
    },
}


def memory_usage(past_key_values):
    """Recursively compute the memory footprint of past key values (in bytes)."""
    if isinstance(past_key_values, torch.Tensor):
        return past_key_values.nelement() * past_key_values.element_size()
    elif isinstance(past_key_values[0], torch.Tensor):
        return sum([x.nelement() * x.element_size() for x in past_key_values])
    else:
        return sum([memory_usage(x) for x in past_key_values])
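
# Illustrative usage only (not part of the original benchmark): assuming the cache is
# returned in the legacy tuple-of-tuples format, the KV-cache footprint of a single
# forward pass could be inspected like this once `model` and `inputs` exist:
#     pkv = model(inputs[:, :300], use_cache=True).past_key_values
#     print(f'KV cache size: {memory_usage(pkv) / 1024**3:.2f} GiB')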
model_names = ['mistralai/Mistral-7B-v0.1', 'meta-llama/Llama-2-7b-hf']
dtypes = [torch.bfloat16, torch.float16]
torch.manual_seed(1)

methods = ['sample', 'beam sample', 'contrastive search']
input_sizes = [300, 1000]
N_toks = [2, 100, 250, 500, 750, 1000, 1250, 1500, 2000, 3000, 4000]

for model_name, dtype in zip(model_names, dtypes):
    model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation='flash_attention_2',
                                                 torch_dtype=dtype, low_cpu_mem_usage=True).cuda(0)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    inputs = tokenizer.encode(LARGE_TEXT*300, return_tensors='pt').to(device=0)

    eos_token_id = model.generation_config.eos_token_id
    if hasattr(eos_token_id, '__len__') and hasattr(eos_token_id, '__getitem__'):
        eos_token_id = eos_token_id[0]

    results_memory = {}
    filename = 'patched_' + model_name.split('/', 1)[1] + '_ratios'

    for strategy in methods:
        input_params = matcher[strategy]
        results_memory[strategy] = {}
        for INPUT in input_sizes:
            results_memory[strategy][INPUT] = {}
            for TOKS in N_toks:
                try:
                    torch.cuda.reset_peak_memory_stats(0)
                    actual_peak = torch.cuda.max_memory_allocated(0) / 1024**3
                    out = model.generate(inputs[:, :INPUT], max_new_tokens=TOKS, min_new_tokens=TOKS, pad_token_id=eos_token_id,
                                         **input_params)
                    memory_used = (torch.cuda.max_memory_allocated(0) / 1024**3) - actual_peak
                    results_memory[strategy][INPUT][TOKS] = memory_used
                except RuntimeError as e:
                    if isinstance(e, torch.cuda.OutOfMemoryError):
                        break
                    else:
                        raise e
                # Rewrite results at each iteration (so that we keep them if OOM)
                with open(filename + '_memory.json', 'w') as fp:
                    json.dump(results_memory, fp, indent='\t')

    # Do more contrastive tries with bigger top_ks
    if model_name == 'mistralai/Mistral-7B-v0.1':
        for k in [4, 8]:
            input_params = {'do_sample': False, 'top_k': k, 'penalty_alpha': 0.5}
            results_memory = {'contrastive search': {}}
            for INPUT in input_sizes:
                results_memory['contrastive search'][INPUT] = {}
                for TOKS in N_toks:
                    try:
                        torch.cuda.reset_peak_memory_stats(0)
                        actual_peak = torch.cuda.max_memory_allocated(0) / 1024**3
                        out = model.generate(inputs[:, :INPUT], max_new_tokens=TOKS, min_new_tokens=TOKS, pad_token_id=eos_token_id,
                                             **input_params)
                        memory_used = (torch.cuda.max_memory_allocated(0) / 1024**3) - actual_peak
                        results_memory['contrastive search'][INPUT][TOKS] = memory_used
                    except RuntimeError as e:
                        if isinstance(e, torch.cuda.OutOfMemoryError):
                            break
                        else:
                            raise e
                    # Rewrite results at each iteration (so that we keep them if OOM)
                    with open(filename + f'_contrastive_k_{k}_memory.json', 'w') as fp:
                        json.dump(results_memory, fp, indent='\t')

    del model
    gc.collect()