Running mistralai mixtral locally
f0ster · Created April 28, 2024
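This script loads Mixtral-8x7B-Instruct-v0.1 through Hugging Face transformers and times a single chat completion. To run it you need a recent PyTorch build with CUDA support plus the transformers and accelerate packages (accelerate backs the device_map="auto" placement). A minimal install, with exact versions left to your environment, might look like:

    pip install torch transformers accelerate

Keep in mind the model has roughly 47B parameters, so the float16 weights alone take on the order of 90 GB of GPU memory; expect to need multiple GPUs or quantization (see the note at the end).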
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_model_and_tokenizer(model_id):
    """
    Load the tokenizer and model based on the specified model ID.
    The model is set to use float16 to reduce memory usage and improve performance.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.float16, device_map="auto"
    )
    return tokenizer, model


def prepare_input(tokenizer, messages):
    """
    Convert the list of message dictionaries into model-ready input IDs.
    """
    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to("cuda")
    return input_ids


def generate_response(model, input_ids):
    """
    Generate a response using the model and the provided input IDs.
    """
    outputs = model.generate(input_ids, max_new_tokens=20)
    return outputs


def main():
    model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
    messages = [
        {"role": "user", "content": "What is your favourite condiment?"},
        {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
        {"role": "user", "content": "Do you have mayonnaise recipes?"},
    ]

    # Load model and tokenizer
    tokenizer, model = load_model_and_tokenizer(model_id)

    # Measure the start time
    start_time = time.time()

    # Prepare input for the model
    input_ids = prepare_input(tokenizer, messages)

    # Generate output from the model
    outputs = generate_response(model, input_ids)

    # Measure the end time
    end_time = time.time()

    # Print the elapsed time and the decoded output
    print("Elapsed time: {:.2f} seconds".format(end_time - start_time))
    print("Generated text:", tokenizer.decode(outputs[0], skip_special_tokens=True))


if __name__ == "__main__":
    main()
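Running the script prints the elapsed time and the generated text. A few details worth noting: prepare_input moves the prompt to "cuda", so a CUDA GPU is assumed; max_new_tokens=20 keeps the demo fast, so raise it for full answers; and model.generate returns the prompt tokens followed by the completion, so the decoded output repeats the whole conversation. If you only want the model's reply, a small tweak to the final print, reusing the names already in the script, is:

    # Decode only the newly generated tokens, skipping the echoed prompt.
    print("Generated text:", tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True))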
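If the full float16 model does not fit on your hardware, a common alternative (not part of the original gist) is 4-bit quantization through bitsandbytes. A minimal sketch of the loading step, assuming the bitsandbytes package is installed:

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # Load Mixtral with 4-bit NF4 weights; cuts memory roughly 4x vs. float16.
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
        quantization_config=quant_config,
        device_map="auto",
    )

The rest of the script is unchanged. Generation is typically slower per token, but the weights should fit in roughly 25-30 GB of GPU memory instead of 90+.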