Last active
June 30, 2023 23:55
-
-
Save lucataco/013560e99b89f64e346ff9ed803a9699 to your computer and use it in GitHub Desktop.
Falcon7B HF speed test
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Benchmark generation speed (tokens/sec) of Falcon-7B via a HF pipeline."""
from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers
import torch
import time

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Avoid naming this `pipeline` — that would shadow transformers.pipeline.
generator = transformers.pipeline(
    "text-generation",
    model=model_id,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,  # NOTE: executes model-repo code; trusted source only
    device_map="auto",
)

text = "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:"
max_length = 200
do_sample = True
top_k = 10
num_return_sequences = 1
eos_token_id = tokenizer.eos_token_id

start_time = time.time()
sequences = generator(
    text,
    max_length=max_length,
    do_sample=do_sample,
    top_k=top_k,
    num_return_sequences=num_return_sequences,
    eos_token_id=eos_token_id,
)
end_time = time.time()

for seq in sequences:
    print(f"Result: {seq['generated_text']}")

# Count real tokens with the model's tokenizer, not whitespace-split words,
# and exclude the prompt tokens — only newly generated tokens were timed.
prompt_tokens = len(tokenizer.encode(text))
num_tokens = sum(
    max(len(tokenizer.encode(seq["generated_text"])) - prompt_tokens, 0)
    for seq in sequences
)
duration = end_time - start_time
tokens_per_second = num_tokens / duration
print(f"Number of tokens generated per second: {tokens_per_second:.2f}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment