Last active
April 25, 2024 15:20
-
-
Save laksjdjf/762bb2efa201261b67f8740a14127028 to your computer and use it in GitHub Desktop.
デフォルト設定は Command R 用
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gradio as gr | |
import json | |
import requests | |
import argparse | |
from dataclasses import dataclass | |
############### utils ###############
# A streamed chunk that is exactly equal to one of these strings is left out of
# chat responses (workaround for Command R's end-of-turn marker).
BAN_TOKENS = ["<|END_OF_TURN_TOKEN|>"]


def _build_parser():
    """Return the command-line parser for this app."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--share", action="store_true")
    cli.add_argument("--user_avatar", "-ua", help="path to user image")
    cli.add_argument("--chatbot_avatar", "-ca", help="path to chatbot image")
    return cli


parser = _build_parser()
args = parser.parse_args()
@dataclass
class ChatConfig:
    """Mutable generation settings shared by every tab.

    The default templates use the ``<|START_OF_TURN_TOKEN|>`` /
    ``<|END_OF_TURN_TOKEN|>`` turn markers (Command R format). The single
    module-level instance ``config`` is called with every setting at once,
    and the call returns ``repr(self)``.
    """

    # Base URL of the completion server ("/completion" is appended per request).
    url: str = "http://localhost:8080"
    # Text substituted for "{system}" in system_template.
    system: str = "あなたは優秀なアシスタントです。"
    temperature: float = 0.8
    top_p: float = 0.9
    # Sent to the server as "n_predict".
    max_tokens: int = 256
    repeat_penalty: float = 1.0
    ignore_eos: bool = False
    # str.format templates; "{system}", "{message}" and "{response}" are filled in.
    system_template: str = "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{system}<|END_OF_TURN_TOKEN|>"
    user_template: str = "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{message}<|END_OF_TURN_TOKEN|>"
    chatbot_template: str = "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{response}<|END_OF_TURN_TOKEN|>"

    def __call__(
        self,
        url: str,
        system: str,
        temperature: float,
        top_p: float,
        max_tokens: int,
        repeat_penalty: float,
        ignore_eos: bool,
        system_template: str,
        user_template: str,
        chatbot_template: str,
    ):
        """Overwrite every setting with the given values.

        Returns:
            ``repr(self)`` after the update, so the new state can be displayed.
        """
        new_values = {
            "url": url,
            "system": system,
            "temperature": temperature,
            "top_p": top_p,
            "max_tokens": max_tokens,
            "repeat_penalty": repeat_penalty,
            "ignore_eos": ignore_eos,
            "system_template": system_template,
            "user_template": user_template,
            "chatbot_template": chatbot_template,
        }
        for field_name, value in new_values.items():
            setattr(self, field_name, value)
        return repr(self)


config = ChatConfig()
def get_post(prompt):
    """Send ``prompt`` to the server's ``/completion`` endpoint as a streaming request.

    Sampling settings are read from the global ``config``. Streaming is
    requested both in the JSON body (``"stream": True``) and on the HTTP call
    (``stream=True``), so the caller can read partial output with
    ``iter_lines()``.

    Args:
        prompt: The fully formatted prompt text.

    Returns:
        The open ``requests.Response``.

    Raises:
        requests.HTTPError: If the server replies with an error status. Without
            this check the callers' line parsers would most likely just skip
            the error body and show nothing at all.
    """
    # Tolerate a trailing slash in the configured URL ("http://host:8080/"),
    # which would otherwise produce ".../ /completion" with a double slash.
    url = config.url.rstrip("/") + "/completion"
    post = requests.post(
        url=url,
        json={
            "prompt": prompt,
            "temperature": config.temperature,
            "top_p": config.top_p,
            "n_predict": config.max_tokens,
            "repeat_penalty": config.repeat_penalty,
            "ignore_eos": config.ignore_eos,
            "stream": True,
        },
        stream=True,
    )
    try:
        post.raise_for_status()
    except requests.HTTPError:
        # The body was never read (stream=True); release the connection first.
        post.close()
        raise
    return post
############### chat interface ###############
def chat(message, history):
    """Stream the model's reply to ``message``, given the earlier ``history``.

    The prompt is built from the templates in the global ``config`` in this
    order:

    1. the system turn;
    2. each earlier (user, chatbot) pair;
    3. the new user turn;
    4. only the opening part of a chatbot turn (the text before
       ``{response}``), so the model writes the reply.

    Args:
        message: The new user message.
        history: Earlier turns as ``(user_message, chatbot_response)`` pairs.
            ``None`` is treated as an empty history.

    Yields:
        ``("", new_history)`` after every streamed piece of text, and once
        more when the stream ends. The empty string clears the message box.
        ``new_history`` is ``history`` plus ``(message, reply_so_far)``.
    """
    history = history or []
    prompt = config.system_template.format(system=config.system)
    for mes, res in history:
        prompt += config.user_template.format(message=mes)
        prompt += config.chatbot_template.format(response=res)
    prompt += config.user_template.format(message=message)
    # Open a chatbot turn without its closing part so the model continues it.
    prompt += config.chatbot_template.split("{response}")[0]

    response = ""
    # ``with`` closes the HTTP stream even if this generator is closed early.
    with get_post(prompt) as post:
        for chunk in post.iter_lines():
            # [6:] drops a 6-character line prefix (presumably the SSE
            # "data: " marker). Blank and non-JSON lines are skipped.
            try:
                content = json.loads(chunk.decode()[6:]).get("content")
            except (ValueError, AttributeError):
                continue
            if not content:
                continue
            # Drop a chunk that is exactly one of the banned tokens.
            if content not in BAN_TOKENS:
                response += content
            # The yield sits outside the try, so a GeneratorExit raised here
            # (generator closed early) is no longer swallowed.
            yield "", history + [(message, response)]
    # Emit the final state even when the server streamed no text.
    yield "", history + [(message, response)]
def continue_chat(history):
    """Stream more text onto the last chatbot response in ``history``.

    The last turn is re-sent with its chatbot turn left open: the template
    prefix plus the existing response, and no closing part. The model then
    continues that response. Earlier pairs are sent as complete turns.

    Args:
        history: Chat turns as ``(user_message, chatbot_response)`` pairs.

    Yields:
        ``history`` with its last pair replaced by
        ``(user_message, extended_response)``. This happens after each
        streamed piece and once more at the end. An empty (or ``None``)
        history has nothing to continue, so it is yielded back unchanged
        without contacting the server.
    """
    if not history:
        yield history
        return
    *earlier, (last_message, last_response) = history
    prompt = config.system_template.format(system=config.system)
    for mes, res in earlier:
        prompt += config.user_template.format(message=mes)
        prompt += config.chatbot_template.format(response=res)
    prompt += config.user_template.format(message=last_message)
    # Re-open the last chatbot turn with the text produced so far.
    prompt += config.chatbot_template.split("{response}")[0] + last_response

    response = last_response
    # ``with`` closes the HTTP stream even if this generator is closed early.
    with get_post(prompt) as post:
        for chunk in post.iter_lines():
            # [6:] drops a 6-character line prefix (presumably the SSE
            # "data: " marker). Blank and non-JSON lines are skipped.
            try:
                content = json.loads(chunk.decode()[6:]).get("content")
            except (ValueError, AttributeError):
                continue
            if not content:
                continue
            # Drop a chunk that is exactly one of the banned tokens.
            if content not in BAN_TOKENS:
                response += content
            # Outside the try, so GeneratorExit is not swallowed.
            yield history[:-1] + [(last_message, response)]
    yield history[:-1] + [(last_message, response)]
def undo_history(history):
    """Remove the last exchange and hand its user message back for editing.

    Args:
        history: Chat turns as ``(user_message, chatbot_response)`` pairs.
            ``None`` or an empty history means there is nothing to undo.

    Returns:
        ``(last_user_message, history_without_last_pair)``, or ``("", [])``
        when there is no history.
    """
    if not history:
        return "", []
    return history[-1][0], history[:-1]
# Chat tab: a chatbot view, a message box and Continue / Undo / Clear buttons.
with gr.Blocks() as chat_interface:
    # UI text (Japanese): "Chat tab. Press Shift+Enter to send."
    gr.Markdown("Chat用タブです。Shift+Enterで送信できます。")
    # Avatar image paths come from --user_avatar / --chatbot_avatar (None if not given).
    chatbot = gr.Chatbot(avatar_images=[args.user_avatar, args.chatbot_avatar])
    message = gr.Textbox(label="Message", placeholder="Enter your message here...", lines=3)
    with gr.Row():
        continue_button = gr.Button("Continue")
        undo_button_chat = gr.Button("Undo")
        clear = gr.ClearButton([message, chatbot])
    # Submitting streams a reply; chat() yields "" as its first output to clear the message box.
    message.submit(chat, inputs=[message, chatbot], outputs=[message, chatbot])
    # Continue streams more text onto the last chatbot response.
    continue_button.click(continue_chat, inputs=[chatbot], outputs=[chatbot])
    # Undo drops the last exchange and puts its user message back into the message box.
    undo_button_chat.click(undo_history, inputs=[chatbot], outputs=[message, chatbot])
############### completion interface ###############
# Prompt text as it was just before the most recent Generate (restored by Undo).
pre_prompt = ""


def text_generator(prompt):
    """Stream a raw completion appended to ``prompt``.

    Saves ``prompt`` in the global ``pre_prompt`` so Undo can restore it. It
    then streams from the server via ``get_post``; no chat templates are
    applied.

    Yields:
        ``(textbox_update, button_update)`` pairs. The textbox shows the
        growing text. The Generate button is hidden while streaming and shown
        again at the end, including when the stream fails part-way.

    Raises:
        requests.RequestException: Re-raised if the stream breaks mid-way,
            after the Generate button has been made visible again.
    """
    global pre_prompt
    pre_prompt = prompt
    # ``with`` closes the HTTP stream even if this generator is closed early.
    with get_post(prompt) as post:
        try:
            for chunk in post.iter_lines():
                # [6:] drops a 6-character line prefix (presumably the SSE
                # "data: " marker). Blank and non-JSON lines are skipped.
                try:
                    content = json.loads(chunk.decode()[6:]).get("content")
                except (ValueError, AttributeError):
                    continue
                if content:
                    prompt += content
                    # Outside the inner try, so GeneratorExit is not swallowed.
                    yield gr.update(value=prompt, autoscroll=True), gr.update(visible=False)
        except requests.RequestException:
            # Show the Generate button again before reporting the error;
            # otherwise it would stay hidden.
            yield gr.update(value=prompt, autoscroll=True), gr.update(visible=True)
            raise
    yield gr.update(value=prompt, autoscroll=True), gr.update(visible=True)
def undo_prompt():
    """Put back the prompt saved by the most recent Generate click."""
    restored = pre_prompt
    return gr.update(autoscroll=True, value=restored)
def add_input(prompt, message):
    """Append a user turn for ``message`` and an opened chatbot turn to ``prompt``.

    The chatbot turn is only the part of ``config.chatbot_template`` before
    ``{response}``, so a following Generate writes the reply.

    Returns:
        A Gradio update that puts the extended prompt into the textbox.
    """
    user_turn = config.user_template.format(message=message)
    reply_opening = config.chatbot_template.partition("{response}")[0]
    return gr.update(value=prompt + user_turn + reply_opening, autoscroll=True)


def set_default(message):
    """Start a fresh templated prompt: the system turn, then a turn for ``message``.

    Returns:
        A Gradio update that replaces the textbox content with that prompt.
    """
    return add_input(config.system_template.format(system=config.system), message)
# Completion tab: a single editable text area for raw prompt + continuation.
with gr.Blocks() as completion_interface:
    # UI text (Japanese): "Completion tab. The configured templates are ignored;
    # press the Default button or write the prompt yourself."
    gr.Markdown("Completion用タブです。設定したテンプレートは無視されます。Defaultボタンを押すか、自分で書いてください。")
    # Holds both the prompt and the streamed continuation.
    io_textbox = gr.Textbox(
        label="Input/Output",
        placeholder="Enter your prompt here...",
        interactive=True,
        elem_classes=["prompt"],
        lines=5,
    )
    generate_button = gr.Button("Generate", variant="primary")
    with gr.Row():
        undo_button = gr.Button("Undo")
        default_button = gr.Button("Default")
        add_button = gr.Button("Add Input")
    # NOTE(review): the original indentation was lost; this textbox is assumed
    # to sit below the button row rather than inside it — confirm.
    message_textbox = gr.Textbox(label="message", value="INPUT", lines=3)
    # generate_button is also an output so text_generator can hide it while streaming.
    generate_button.click(
        text_generator,
        inputs=[io_textbox],
        outputs=[io_textbox, generate_button],
        show_progress=True,
    )
    # Undo restores the prompt as it was before the last Generate.
    undo_button.click(
        undo_prompt,
        inputs=None,
        outputs=[io_textbox],
    )
    # Default replaces the text area with system turn + user turn + opened chatbot turn.
    default_button.click(
        set_default,
        inputs=[message_textbox],
        outputs=[io_textbox],
    )
    # Add Input appends another user turn + opened chatbot turn to the current text.
    add_button.click(
        add_input,
        inputs=[io_textbox, message_textbox],
        outputs=[io_textbox],
    )
############### setting interface ###############
# Settings tab: every widget feeds the shared ``config`` object.
with gr.Blocks() as setting_interface:
    gr.Markdown("設定用タブです。")  # UI text (Japanese): "Settings tab."
    # Initial values are read from ``config`` instead of repeated literals, so
    # the form always starts out showing the settings actually in effect.
    url = gr.Textbox(label="url", value=config.url)
    system = gr.Textbox(label="system", value=config.system, lines=2)
    temperature = gr.Slider(minimum=0, maximum=2, step=0.01, value=config.temperature, label="temperature")
    top_p = gr.Slider(minimum=0, maximum=1, step=0.01, value=config.top_p, label="top_p")
    max_tokens = gr.Slider(minimum=1, maximum=65536, step=1, value=config.max_tokens, label="max_tokens")
    repeat_penalty = gr.Slider(minimum=1.0, maximum=2.0, step=0.01, value=config.repeat_penalty, label="repeat_penalty")
    ignore_eos = gr.Checkbox(label="ignore_eos", value=config.ignore_eos)
    system_template = gr.Textbox(label="system_template", value=config.system_template, lines=3)
    user_template = gr.Textbox(label="user_template", value=config.user_template, lines=3)
    chatbot_template = gr.Textbox(label="chatbot_template", value=config.chatbot_template, lines=3)
    # Displays the string returned by config(...), i.e. repr(config).
    output = gr.Textbox(label="output", interactive=False)
    # Order must match the parameters of ChatConfig.__call__, which receives
    # these values in list order.
    setting_list = [url, system, temperature, top_p, max_tokens, repeat_penalty, ignore_eos, system_template, user_template, chatbot_template]
    # Any change re-sends the whole form, so config mirrors every field.
    for setting in setting_list:
        setting.change(config, inputs=setting_list, outputs=output)
# Assemble the three tabs and start the app; --share is forwarded to launch().
tab_interfaces = [chat_interface, completion_interface, setting_interface]
tab_titles = ["Chat", "Completion", "Setting"]
demo = gr.TabbedInterface(tab_interfaces, tab_titles, theme=gr.themes.Base())
demo.launch(share=args.share)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment