Created
December 20, 2024 20:12
-
-
Save komodoooo/5e1c21eaf308518458dc017edcbfed30 to your computer and use it in GitHub Desktop.
Chat interface with the Qwen 2.5 LLM built with Gradio
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face model ID to load. Larger variants of the same family can be
# swapped in unchanged: Qwen/Qwen2.5-7B-Instruct, Qwen/Qwen2.5-32B-Instruct
MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
class QwenInterface:
    """Stateful chat wrapper around a Qwen 2.5 instruct checkpoint.

    The full conversation is kept in ``self.messages`` (system prompt plus
    every user/assistant turn), so each call to :meth:`get_response` is
    conditioned on the prior turns. Note the history grows without bound
    for the lifetime of the instance.
    """

    def __init__(self) -> None:
        # Seed the conversation with the system prompt; user/assistant
        # turns are appended after it.
        self.messages: list[dict[str, str]] = [
            {"role": "system", "content": "You're a useful assistant who answers everything."}
        ]
        # torch_dtype="auto" keeps the checkpoint's native precision;
        # device_map="auto" places weights on an accelerator when available.
        self.model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, torch_dtype="auto", device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    def get_response(self, prompt: str, max_new_tokens: int = 512) -> str:
        """Append *prompt* to the history and return the model's reply.

        Args:
            prompt: The user's message for this turn.
            max_new_tokens: Cap on generated tokens. Defaults to 512,
                matching the previously hard-coded value, so existing
                callers are unaffected.

        Returns:
            The decoded assistant reply (also appended to the history).
        """
        self.messages.append({"role": "user", "content": prompt})
        text = self.tokenizer.apply_chat_template(
            self.messages, tokenize=False, add_generation_prompt=True
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
        generated_ids = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
        # generate() echoes the prompt tokens; slice them off so only the
        # newly generated completion is decoded.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        self.messages.append({"role": "assistant", "content": response})
        return response
# Wire the model wrapper into a Gradio chat UI and serve it.
chat = QwenInterface()

# Derive a human-readable heading from the model ID, e.g.
# "Qwen/Qwen2.5-1.5B-Instruct" -> "# Qwen2.5 1.5B Instruct".
heading = f"# {MODEL_ID.split('/')[1].replace('-', ' ')}"

with gr.Blocks(title="Chat with Qwen", fill_height=True) as interface:
    gr.Markdown(heading)
    # Gradio hands the callback (message, history); the history argument is
    # ignored because the QwenInterface instance tracks the conversation.
    gr.ChatInterface(
        lambda message, _: chat.get_response(message),
        autofocus=True,
        type="messages",
    )

interface.launch()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment