Simple script showing how to use `torch` for text generation with a `transformers` model, one token at a time, on MPS.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Define the model name
model_name = "HuggingFaceTB/SmolLM-1.7B-Instruct"

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

# Set the model to evaluation mode
model.eval()

# Check if MPS is available and if not, use CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)

# Input conversation for text generation
conversation = [
    {
        "role": "system",
        "content": "You are a concise and accurate assistant, don't explain, just answer.",
    },
    {"role": "user", "content": "What's 2 + 2?"},
]

# Encode the input conversation
input_ids = tokenizer.apply_chat_template(
    conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt"
).to(device)

# Streaming token generation
max_new_tokens = 256  # Maximum number of new tokens to generate
generated_tokens = input_ids

# Generate tokens one by one
for _ in range(max_new_tokens):
    # Generate the next token
    with torch.no_grad():
        outputs = model(generated_tokens)
        next_token_logits = outputs.logits[:, -1, :]
    # Greedy decoding: pick the highest-probability next token
    next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
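    # Alternative (not in the original gist): sample from the distribution instead of
    # taking the argmax, e.g. with a temperature of 0.7 the two lines above become:
    #   probs = torch.softmax(next_token_logits / 0.7, dim=-1)
    #   next_token = torch.multinomial(probs, num_samples=1)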

    # Append the new token to the sequence
    generated_tokens = torch.cat((generated_tokens, next_token), dim=-1)

    # Decode the latest token and print it
    generated_text = tokenizer.decode(
        next_token.squeeze().tolist(), skip_special_tokens=True
    )
    print(generated_text, end="", flush=True)

    # Stop if the model predicts the end-of-sequence token
    if next_token.squeeze().item() == tokenizer.eos_token_id:
        break
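
The loop above re-runs the model over the entire sequence at every step, so each new token gets progressively more expensive as the output grows. As a minimal sketch (not part of the original gist, and assuming the same `model`, `tokenizer`, `device`, and `input_ids` defined above), the variant below passes `past_key_values` back into the forward call so that, after the prompt has been processed once, only the newly generated token is fed in per step.

import torch

max_new_tokens = 256
past_key_values = None
next_input = input_ids  # the first forward pass processes the whole prompt

with torch.no_grad():
    for _ in range(max_new_tokens):
        # Forward pass with caching enabled; past_key_values holds attention state
        outputs = model(next_input, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values  # reuse the cache on the next step

        # Greedy decoding, as in the script above
        next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(0)

        # Decode and stream the latest token
        print(
            tokenizer.decode(next_token.squeeze().tolist(), skip_special_tokens=True),
            end="",
            flush=True,
        )

        # Stop if the model predicts the end-of-sequence token
        if next_token.squeeze().item() == tokenizer.eos_token_id:
            break

        # Once the prompt is cached, only the newest token needs to be fed in
        next_input = next_token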