@xXWarMachineRoXx
Created February 29, 2024 12:40
Google Gemma 7b inference
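A minimal script that loads google/gemma-7b with Hugging Face Transformers, generates a completion for a single prompt, and reports the wall-clock generation time. It assumes a CUDA-capable GPU with enough memory for the 7B weights, a Hugging Face token with read access to the gated Gemma repository, and the transformers, accelerate, and huggingface_hub packages installed.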
import time

from transformers import AutoTokenizer, AutoModelForCausalLM, logging
from huggingface_hub import login

# Silence Transformers warnings before loading the model
logging.set_verbosity_error()

# Authenticate with Hugging Face; Gemma is a gated model, so the token
# needs read access to google/gemma-7b
ACCESS_TOKEN_READ = "token"  # replace with your Hugging Face read token
login(token=ACCESS_TOKEN_READ)

# Load the tokenizer and model; device_map="auto" places the weights on
# the available GPU(s) and requires the accelerate package
tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-7b", device_map="auto")

# Tokenize the prompt (truncated to at most 50 tokens) and move the
# tensors to the GPU
input_text = "Write me a poem about Quantum computing."
inputs = tokenizer(input_text, return_tensors="pt", max_length=50, truncation=True).to("cuda")

# Time a single generation pass; max_length=100 caps the total sequence
# length, prompt tokens included
start_time = time.time()
outputs = model.generate(**inputs, max_length=100)
execution_time = time.time() - start_time

# Decode and print the generated text and the elapsed time
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
print(f"Execution time: {execution_time:.2f} seconds")