# based on StableLM chat
# https://huggingface.co/spaces/stabilityai/stablelm-tuned-alpha-chat
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

print("Starting to load the model into memory")
model_name = "rinna/japanese-gpt-neox-3.6b-instruction-sft"
m = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", torch_dtype=torch.float16)
tok = AutoTokenizer.from_pretrained(model_name, use_fast=False)
print("Successfully loaded the model into memory")
start_message = ""

def user(message, history):
    # Append the user's message to the conversation history
    return "", history + [[message, ""]]
def chat(curr_system_message, history):
    # Construct the model input by joining the conversation history into
    # "<NL>"-separated speaker turns (the system message is unused and
    # blanked out for this model)
    curr_system_message = ""
    messages = curr_system_message + \
        "<NL>".join(["<NL>".join(["ユーザー: " + item[0], "システム: " + item[1]])
                     for item in history])
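    # For example, a one-turn history [["こんにちは", ""]] produces:
    #   "ユーザー: こんにちは<NL>システム: "
    # i.e. the string ends with an open "システム: " turn for the model to complete.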
    # Tokenize the messages string
    token_ids = tok.encode(messages, add_special_tokens=False, return_tensors="pt")
    streamer = TextIteratorStreamer(tok, skip_prompt=True)
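    # skip_prompt=True keeps the echoed input prompt out of the streamed text,
    # so only newly generated tokens reach the loop below.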
    generation_args = [token_ids.to(m.device)]
    generation_kwargs = dict(
        streamer=streamer,
        do_sample=True,
        max_new_tokens=256,
        temperature=0.7,
        pad_token_id=tok.pad_token_id,
        bos_token_id=tok.bos_token_id,
        eos_token_id=tok.eos_token_id
    )
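    # do_sample=True with temperature=0.7 gives moderately varied replies;
    # lowering the temperature, or adding a kwarg such as top_p, would make
    # output more deterministic (a tuning suggestion, not part of the gist).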
    t = Thread(target=m.generate, args=generation_args, kwargs=generation_kwargs)
    t.start()
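    # m.generate runs in a background thread so this generator function can
    # consume the streamer (and yield partial output to Gradio) concurrently.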
    # print(history)
    # Initialize an empty string to store the generated text
    partial_text = ""
    for new_text in streamer:
        # print(new_text)
        partial_text += new_text
        history[-1][1] = partial_text
        # Yield the updated conversation history so the chatbot
        # refreshes as tokens stream in
        yield history
    return partial_text
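
# Note (assumption about the tokenizer, not verified in the original gist):
# rinna models represent newlines as the literal string "<NL>", so streamed
# text may contain "<NL>"; replacing it inside the loop above, e.g. with
# partial_text.replace("<NL>", "\n"), may give cleaner chat output.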
with gr.Blocks() as demo:
    # history = gr.State([])
    gr.Markdown("## Rinna japanese-gpt-neox-3.6b Chat")
    chatbot = gr.Chatbot().style(height=500)
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(label="Chat Message Box", placeholder="Chat Message Box",
                             show_label=False).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear")
    system_msg = gr.Textbox(
        start_message, label="System Message", interactive=False, visible=False)

    submit_event = msg.submit(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=chat, inputs=[system_msg, chatbot], outputs=[chatbot], queue=True)
    submit_click_event = submit.click(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=chat, inputs=[system_msg, chatbot], outputs=[chatbot], queue=True)
    stop.click(fn=None, inputs=None, outputs=None, cancels=[
        submit_event, submit_click_event], queue=False)
    clear.click(lambda: None, None, [chatbot], queue=False)

demo.queue(max_size=32, concurrency_count=2)
demo.launch()
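
# Note: .style(...) on components and queue(concurrency_count=...) follow the
# Gradio 3.x API this gist was written against; newer Gradio versions changed
# these calls. To reach the app from other machines, a launch such as
# demo.launch(server_name="0.0.0.0", server_port=7860) is one possible
# (hypothetical) configuration.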