Skip to content

Instantly share code, notes, and snippets.

@laksjdjf
Last active April 25, 2024 15:20
Show Gist options
  • Save laksjdjf/762bb2efa201261b67f8740a14127028 to your computer and use it in GitHub Desktop.
Save laksjdjf/762bb2efa201261b67f8740a14127028 to your computer and use it in GitHub Desktop.
デフォルト設定はcommand -r 用
import gradio as gr
import json
import requests
import argparse
from dataclasses import dataclass
############### utils ###############
# Streamed pieces equal to one of these strings are dropped from chat replies.
BAN_TOKENS = ["<|END_OF_TURN_TOKEN|>"] # tokens to filter out, for command -r
# Command-line options; parsed once, at import time.
parser = argparse.ArgumentParser()
# --share is forwarded to demo.launch(share=...).
parser.add_argument("--share", action="store_true")
# Optional avatar image paths handed to the chat tab's Chatbot component.
parser.add_argument("--user_avatar", "-ua", help="path to user image")
parser.add_argument("--chatbot_avatar", "-ca", help="path to chatbot image")
args = parser.parse_args()
@dataclass
class ChatConfig:
    """Mutable runtime settings shared by the chat and completion tabs.

    The default templates use the Command R turn tokens. Calling an instance
    overwrites every field in one go, which lets the instance itself be
    registered as the change handler for the settings widgets.
    """

    # Base URL of the completion server; "/completion" is appended per request.
    url: str = "http://localhost:8080"
    # System prompt substituted into ``system_template``.
    system: str = "あなたは優秀なアシスタントです。"
    # Sampling parameters forwarded verbatim in the request body.
    temperature: float = 0.8
    top_p: float = 0.9
    # Sent to the server as ``n_predict``.
    max_tokens: int = 256
    repeat_penalty: float = 1.0
    ignore_eos: bool = False
    # Prompt templates; each one is filled with ``str.format``.
    system_template: str = "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{system}<|END_OF_TURN_TOKEN|>"
    user_template: str = "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{message}<|END_OF_TURN_TOKEN|>"
    chatbot_template: str = "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{response}<|END_OF_TURN_TOKEN|>"

    def __call__(
        self,
        url: str,
        system: str,
        temperature: float,
        top_p: float,
        max_tokens: int,
        repeat_penalty: float,
        ignore_eos: bool,
        system_template: str,
        user_template: str,
        chatbot_template: str,
    ) -> str:
        """Overwrite every setting and return the new state as text.

        Args:
            url: Base URL of the completion server.
            system: System prompt text.
            temperature: Sampling temperature.
            top_p: Nucleus sampling threshold.
            max_tokens: Maximum number of tokens to generate; coerced to
                ``int`` because UI sliders may deliver a float.
            repeat_penalty: Repetition penalty.
            ignore_eos: Whether the server should ignore the EOS token.
            system_template: Template containing ``{system}``.
            user_template: Template containing ``{message}``.
            chatbot_template: Template containing ``{response}``.

        Returns:
            The dataclass ``repr`` of the updated configuration.
        """
        self.url = url
        self.system = system
        self.temperature = temperature
        self.top_p = top_p
        # Keep the declared integer type even when the caller passes 256.0.
        self.max_tokens = int(max_tokens)
        self.repeat_penalty = repeat_penalty
        self.ignore_eos = ignore_eos
        self.system_template = system_template
        self.user_template = user_template
        self.chatbot_template = chatbot_template
        return repr(self)


# Single shared configuration instance read by every request builder.
config = ChatConfig()
def get_post(prompt):
    """POST ``prompt`` to the completion endpoint and return the streaming response.

    The sampling parameters are read from the module-level ``config`` at call
    time, so changes made in the settings tab apply to the next request.

    Args:
        prompt: The fully rendered prompt text.

    Returns:
        The ``requests`` response object, opened with ``stream=True`` so the
        caller can iterate over it line by line.
    """
    # Tolerate a base URL entered with a trailing slash ("http://host:8080/"),
    # which would otherwise produce ".../​/completion".
    endpoint = config.url.rstrip("/") + "/completion"
    payload = {
        "prompt": prompt,
        "temperature": config.temperature,
        "top_p": config.top_p,
        "n_predict": config.max_tokens,
        "repeat_penalty": config.repeat_penalty,
        "ignore_eos": config.ignore_eos,
        "stream": True,
    }
    return requests.post(url=endpoint, json=payload, stream=True)
############### chat interface ###############
def chat(message, history):
    """Stream the chatbot's reply to ``message`` given the earlier turns.

    The prompt is the rendered system template, followed by every
    (user, chatbot) pair in ``history``, the new user message, and finally the
    chatbot template cut off just before ``{response}`` so the model writes
    the reply.

    Args:
        message: The new user message.
        history: Earlier turns as (user message, chatbot response) pairs.

    Yields:
        ``("", updated_history)`` for each received piece of content; the
        empty string is the new value of the message box.
    """
    history = history or []
    prompt = config.system_template.format(system=config.system)
    for past_message, past_response in history:
        prompt += config.user_template.format(message=past_message)
        prompt += config.chatbot_template.format(response=past_response)
    prompt += config.user_template.format(message=message)
    prompt += config.chatbot_template.split("{response}")[0]

    response = ""
    # The context manager closes the connection even if the consumer stops
    # iterating before the stream ends.
    with get_post(prompt) as post:
        for chunk in post.iter_lines():
            try:
                # Skip the first 6 characters of each line (presumably the
                # "data: " event prefix) and parse the rest as JSON.
                content = json.loads(chunk.decode()[6:]).get("content")
            except (ValueError, AttributeError):
                # Blank or malformed lines carry no text. Only parsing errors
                # are ignored here; the yield stays outside the try so that
                # closing the generator is never swallowed.
                continue
            if content:
                if content not in BAN_TOKENS:
                    response += content
                yield "", history + [(message, response)]
def continue_chat(history):
    """Extend the last chatbot response in ``history`` with more streamed text.

    The prompt is rebuilt from all turns, but the final chatbot turn is left
    open (template prefix plus the existing response) so the model continues
    where the previous response stopped.

    Args:
        history: Chat turns as (user message, chatbot response) pairs.

    Yields:
        The history with its last response extended by the content received
        so far. One final state is always yielded when the stream ends.
    """
    if not history:
        # Nothing to continue. The original code reached an unbound variable
        # here; hand the (empty) history back unchanged instead.
        yield history or []
        return

    earlier = history[:-1]
    last_message, last_response = history[-1]

    prompt = config.system_template.format(system=config.system)
    for past_message, past_response in earlier:
        prompt += config.user_template.format(message=past_message)
        prompt += config.chatbot_template.format(response=past_response)
    prompt += config.user_template.format(message=last_message)
    prompt += config.chatbot_template.split("{response}")[0] + last_response

    response = last_response
    # The context manager closes the connection even if the consumer stops
    # iterating before the stream ends.
    with get_post(prompt) as post:
        for chunk in post.iter_lines():
            try:
                # Skip the first 6 characters of each line (presumably the
                # "data: " event prefix) and parse the rest as JSON.
                content = json.loads(chunk.decode()[6:]).get("content")
            except (ValueError, AttributeError):
                # Ignore only parsing errors; the yield stays outside the try
                # so that closing the generator is never swallowed.
                continue
            if content:
                if content not in BAN_TOKENS:
                    response += content
                yield earlier + [(last_message, response)]

    # Emit the final state even when the stream produced no content.
    yield earlier + [(last_message, response)]
def undo_history(history):
    """Remove the last turn and return its user message for editing.

    Args:
        history: Chat turns as (user message, chatbot response) pairs; may be
            empty or ``None``.

    Returns:
        A ``(message, remaining_history)`` tuple. ``message`` is the user
        message of the removed turn, or ``""`` when there was nothing to undo.
    """
    if not history:
        # Covers both an empty list and None, which cannot be sliced.
        return "", []
    return history[-1][0], history[:-1]
# Chat tab: conversation view, message box, and Continue / Undo / Clear buttons.
with gr.Blocks() as chat_interface:
    gr.Markdown("Chat用タブです。Shift+Enterで送信できます。")
    chatbot = gr.Chatbot(
        avatar_images=[args.user_avatar, args.chatbot_avatar],
    )
    message = gr.Textbox(
        label="Message",
        placeholder="Enter your message here...",
        lines=3,
    )
    with gr.Row():
        continue_button = gr.Button("Continue")
        undo_button_chat = gr.Button("Undo")
        clear = gr.ClearButton([message, chatbot])

    # Submitting streams the reply and clears the message box.
    message.submit(
        chat,
        inputs=[message, chatbot],
        outputs=[message, chatbot],
    )
    # Continue extends the last reply in place.
    continue_button.click(
        continue_chat,
        inputs=[chatbot],
        outputs=[chatbot],
    )
    # Undo drops the last turn and puts its user message back in the box.
    undo_button_chat.click(
        undo_history,
        inputs=[chatbot],
        outputs=[message, chatbot],
    )
############### completion interface ###############
# Prompt submitted by the most recent generation; restored by undo_prompt().
pre_prompt = ""


def text_generator(prompt):
    """Stream a raw completion for ``prompt``, appending the output to it.

    The submitted prompt is saved in the module-level ``pre_prompt`` before
    the request is made.

    Args:
        prompt: The text to complete, sent to the server as-is.

    Yields:
        A pair of Gradio updates: the text accumulated so far, and the
        visibility of the second bound output (hidden while content is
        arriving, shown again once the stream has ended).
    """
    global pre_prompt
    pre_prompt = prompt

    # The context manager closes the connection even if the consumer stops
    # iterating before the stream ends.
    with get_post(prompt) as post:
        for chunk in post.iter_lines():
            try:
                # Skip the first 6 characters of each line (presumably the
                # "data: " event prefix) and parse the rest as JSON.
                content = json.loads(chunk.decode()[6:]).get("content")
            except (ValueError, AttributeError):
                # Ignore only parsing errors; the yield stays outside the try
                # so that closing the generator is never swallowed.
                continue
            if content:
                prompt += content
                yield gr.update(value=prompt, autoscroll=True), gr.update(visible=False)

    # Stream finished: emit the final text and make the second output visible.
    yield gr.update(value=prompt, autoscroll=True), gr.update(visible=True)
def undo_prompt():
    """Return an update that resets the text to the saved ``pre_prompt``."""
    restored = gr.update(value=pre_prompt, autoscroll=True)
    return restored
def set_default(message):
    """Build a fresh templated prompt around ``message``.

    The result is the rendered system template, the rendered user template,
    and the chatbot template up to (not including) ``{response}``.

    Args:
        message: Text substituted into the user template.

    Returns:
        A Gradio update carrying the assembled prompt.
    """
    system_part = config.system_template.format(system=config.system)
    user_part = config.user_template.format(message=message)
    # Everything before the placeholder opens the chatbot turn.
    chatbot_opening = config.chatbot_template.partition("{response}")[0]
    assembled = system_part + user_part + chatbot_opening
    return gr.update(value=assembled, autoscroll=True)
def add_input(prompt, message):
    """Append a new user turn and an open chatbot turn to ``prompt``.

    Args:
        prompt: The existing prompt text.
        message: Text substituted into the user template.

    Returns:
        A Gradio update carrying the extended prompt.
    """
    user_part = config.user_template.format(message=message)
    # Everything before the placeholder opens the chatbot turn.
    chatbot_opening = config.chatbot_template.partition("{response}")[0]
    extended = prompt + user_part + chatbot_opening
    return gr.update(value=extended, autoscroll=True)
# Completion tab: one editable text box that holds both the prompt and the output.
with gr.Blocks() as completion_interface:
    gr.Markdown("Completion用タブです。設定したテンプレートは無視されます。Defaultボタンを押すか、自分で書いてください。")
    io_textbox = gr.Textbox(
        label="Input/Output",
        placeholder="Enter your prompt here...",
        interactive=True,
        elem_classes=["prompt"],
        lines=5,
    )
    generate_button = gr.Button("Generate", variant="primary")
    with gr.Row():
        undo_button = gr.Button("Undo")
        default_button = gr.Button("Default")
        add_button = gr.Button("Add Input")
    message_textbox = gr.Textbox(label="message", value="INPUT", lines=3)

    # Generate streams into the text box and toggles the button's visibility.
    generate_button.click(text_generator, inputs=[io_textbox], outputs=[io_textbox, generate_button], show_progress=True)
    # Undo restores the prompt saved by the last generation.
    undo_button.click(undo_prompt, inputs=None, outputs=[io_textbox])
    # Default replaces the text with a templated prompt built from the message box.
    default_button.click(set_default, inputs=[message_textbox], outputs=[io_textbox])
    # Add Input appends another templated user turn to the current text.
    add_button.click(add_input, inputs=[io_textbox, message_textbox], outputs=[io_textbox])
############### setting interface ###############
# Settings tab: one widget per ChatConfig field. Initial values are read from
# ``config`` so the widgets and the dataclass defaults cannot drift apart.
with gr.Blocks() as setting_interface:
    gr.Markdown("設定用タブです。")
    url = gr.Textbox(label="url", value=config.url)
    system = gr.Textbox(label="system", value=config.system, lines=2)
    temperature = gr.Slider(minimum=0, maximum=2, step=0.01, value=config.temperature, label="temperature")
    top_p = gr.Slider(minimum=0, maximum=1, step=0.01, value=config.top_p, label="top_p")
    max_tokens = gr.Slider(minimum=1, maximum=65536, step=1, value=config.max_tokens, label="max_tokens")
    repeat_penalty = gr.Slider(minimum=1.0, maximum=2.0, step=0.01, value=config.repeat_penalty, label="repeat_penalty")
    ignore_eos = gr.Checkbox(label="ignore_eos", value=config.ignore_eos)
    system_template = gr.Textbox(label="system_template", value=config.system_template, lines=3)
    user_template = gr.Textbox(label="user_template", value=config.user_template, lines=3)
    chatbot_template = gr.Textbox(label="chatbot_template", value=config.chatbot_template, lines=3)
    # Read-only box showing the text returned by config(...) after each change.
    output = gr.Textbox(label="output", interactive=False)

    # Order must match the positional parameters of ChatConfig.__call__.
    setting_list = [
        url,
        system,
        temperature,
        top_p,
        max_tokens,
        repeat_penalty,
        ignore_eos,
        system_template,
        user_template,
        chatbot_template,
    ]
    # Any single change re-submits every field, so config is always fully updated.
    for setting in setting_list:
        setting.change(config, inputs=setting_list, outputs=output)
# Combine the three Blocks into one tabbed app labelled Chat / Completion / Setting.
demo = gr.TabbedInterface([chat_interface, completion_interface, setting_interface], ["Chat", "Completion", "Setting"], theme=gr.themes.Base())
# Start the app; the --share command-line flag is passed through as `share`.
demo.launch(share = args.share)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment