Created
February 29, 2024 12:40
-
-
Save xXWarMachineRoXx/2e6210a29c50aada4a5614dd38f0839a to your computer and use it in GitHub Desktop.
Google Gemma 7B inference
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time

from transformers import AutoTokenizer, AutoModelForCausalLM, logging
from huggingface_hub import login

# Silence Transformers warnings BEFORE loading — set_verbosity_error() only
# affects messages emitted after it is called, so it must precede
# from_pretrained(), which is where most of the warnings come from.
logging.set_verbosity_error()

# Access token setup.
# NOTE(review): replace with a real Hugging Face read token; prefer the
# HF_TOKEN environment variable over hard-coding a secret in source.
ACCESS_TOKEN_READ = "token"
login(token=ACCESS_TOKEN_READ)

# Model setup — device_map="auto" lets accelerate place the weights on the
# best available device(s) (GPU if present, otherwise CPU).
tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-7b", device_map="auto")

# Input text
input_text = "Write me a poem about Quantum computing."

# Tokenization — truncate the prompt to at most 50 tokens. Move the tensors
# to wherever device_map actually put the model, rather than hard-coding
# "cuda", which crashes when no GPU is available.
input_ids = tokenizer(input_text, return_tensors="pt", max_length=50, truncation=True).to(model.device)

# Timing execution — perf_counter() is monotonic and has higher resolution
# than time.time(), making it the right choice for elapsed-time measurement.
start_time = time.perf_counter()

# Model inference — max_new_tokens budgets only the generated continuation;
# max_length would have included the prompt tokens in the limit.
outputs = model.generate(**input_ids, max_new_tokens=100)

# Calculating execution time
execution_time = time.perf_counter() - start_time

# Decoding and printing output — skip <bos>/<eos> special tokens so only
# the human-readable text is printed.
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# Printing execution time
print(f"Execution time: {execution_time:.2f} seconds")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment