@alvarobartt
Last active August 20, 2024 07:09
A simple script that uses `torch` to run text generation with a `transformers` model one token at a time on MPS.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Define the model name
model_name = "HuggingFaceTB/SmolLM-1.7B-Instruct"
# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
# Set the model to evaluation mode
model.eval()
# Check if MPS is available and if not, use CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)
# Input conversation for text generation
conversation = [
    {
        "role": "system",
        "content": "You are a concise and accurate assistant, don't explain, just answer.",
    },
    {"role": "user", "content": "What's 2 + 2?"},
]
# Encode the input conversation
input_ids = tokenizer.apply_chat_template(
    conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt"
).to(device)
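# `add_generation_prompt=True` appends the chat template's assistant-turn prefix,
# so the model continues with the assistant's reply rather than a new user turn.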
# Streaming token generation
max_new_tokens = 256 # Maximum number of new tokens to generate
generated_tokens = input_ids
# Generate tokens one by one
for _ in range(max_new_tokens):
    # Generate the next token
    with torch.no_grad():
        outputs = model(generated_tokens)
    next_token_logits = outputs.logits[:, -1, :]
    # Greedy decoding: pick the highest-probability token
    next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)
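    # The line above uses greedy decoding. An illustrative sampling
    # alternative (a sketch, not part of the original script) would be:
    # probs = torch.softmax(next_token_logits / 0.7, dim=-1)  # temperature 0.7
    # next_token = torch.multinomial(probs, num_samples=1)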
    # Append the new token to the sequence
    generated_tokens = torch.cat((generated_tokens, next_token), dim=-1)
    # Decode the latest token and print it
    generated_text = tokenizer.decode(
        next_token.squeeze().tolist(), skip_special_tokens=True
    )
    print(generated_text, end="", flush=True)
    # Stop if the model predicts the end-of-sequence token
    if next_token.squeeze().item() == tokenizer.eos_token_id:
        break
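
Note that the loop above re-runs the model over the entire sequence at every step. A minimal sketch of the same greedy loop using the model's KV cache (`use_cache` / `past_key_values`, supported by standard `transformers` causal LMs) avoids that recomputation; variable names here are illustrative:

past_key_values = None
next_input = input_ids
for _ in range(max_new_tokens):
    with torch.no_grad():
        outputs = model(next_input, past_key_values=past_key_values, use_cache=True)
    past_key_values = outputs.past_key_values
    next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(0)
    # Only the newly generated token is fed on the next step
    next_input = next_token
    print(tokenizer.decode(next_token.squeeze().tolist(), skip_special_tokens=True), end="", flush=True)
    if next_token.squeeze().item() == tokenizer.eos_token_id:
        break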