Last active
August 9, 2024 07:23
-
-
Save advanceboy/717fde162a6f9ccb592f04898f0aacc1 to your computer and use it in GitHub Desktop.
rinna/japanese-gpt-neox-3.6b-instruction-sft もとい rinna/japanese-gpt-neox-3.6b-instruction-ppo と gradio を使ったチャット UI のサンプル実装です。 transformers.TextIteratorStreamer API を利用して、 ChatGPT のように生成したテキストを少しずつ出力し、gradio の UI でユーザーに表示させています。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8 | |
# License: CC0 | |
""" | |
rinna/japanese-gpt-neox-3.6b-instruction-ppo と gradio を使ったチャット UI のサンプル実装です。 | |
-> https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-ppo | |
transformers.TextIteratorStreamer API を利用して、 ChatGPT のように生成したテキストを少しずつ表示し、ユーザー体験を向上させています。 | |
-> https://huggingface.co/docs/transformers/v4.29.1/en/internal/generation_utils#transformers.TextIteratorStreamer | |
streamer クラスの API は開発中のため、近い将来互換性がなくなる可能性があります。 | |
transformers==4.37.2 gradio==4.16.0 での動作を確認しています。 | |
環境作成手順 | |
1. CUDA Toolkit のインストール https://developer.nvidia.com/cuda-toolkit-archive | |
2. CUDA の環境にあわせた PyTorch のパッケージを pip で追加 | |
* `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118` | |
3. transformers 関連のパッケージを pip で追加 | |
* `pip3 install ipython sentencepiece transformers accelerate gradio` | |
4. Python でスクリプトを実行 | |
* `python rinna_gradio_chat.py` | |
* 初回実行時、 huggingface.co からモデルを DL にするのに時間がかかったり、失敗したりする場合があります。 | |
5. コンソールに表示された URL <http://127.0.0.1:7860> にブラウザでアクセスする | |
pip パッケージを入れる際は、 venv などで仮想環境を作成しておくことを強くおすすめします。 | |
""" | |
import re | |
import torch | |
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer | |
from threading import Thread | |
# 定数宣言 | |
model_name = "rinna/japanese-gpt-neox-3.6b-instruction-ppo" | |
torch_dtype = torch.bfloat16 | |
max_new_tokens = 512 | |
temperature = 0.7 | |
speaker_name_user = 'ユーザー' | |
speaker_name_system = 'システム' | |
start_message = '<NL>'.join([ | |
f'{speaker_name_user}: こんにちは。', | |
f'{speaker_name_system}: こんにちは、私は{speaker_name_system}です。あなたの質問に適切な回答をします。どのようなご用件ですか?' | |
]) | |
# モデルの初期化 | |
if not torch.cuda.is_available(): | |
raise 'CUDA is not available' | |
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) | |
model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', torch_dtype=torch_dtype) | |
def user(message, history): | |
# history にユーザーメッセージを追加 | |
return "", history + [[message, ""]] | |
def chat(curr_system_message, history): | |
# プロンプトの作成 | |
prompt = [ | |
f"{speaker_name_user}: {item[0]}" | |
'<NL>' | |
f"{speaker_name_system}: {item[1]}" | |
for item in history | |
] | |
prompt = '<NL>'.join(prompt) | |
prompt = (curr_system_message | |
+ '<NL>' | |
+ prompt | |
+ '<NL>' | |
+ f'{speaker_name_system}: ' | |
).replace("\n", '<NL>') | |
# テキスト生成の開始 | |
token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt") | |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True) | |
generation_args = [token_ids.to(model.device)] | |
generation_kwargs = dict( | |
streamer=streamer, | |
do_sample=True, | |
max_new_tokens=max_new_tokens, | |
temperature=temperature, | |
pad_token_id=tokenizer.pad_token_id, | |
bos_token_id=tokenizer.bos_token_id, | |
eos_token_id=tokenizer.eos_token_id | |
) | |
thread = Thread(target=model.generate, args=generation_args, kwargs=generation_kwargs) | |
thread.start() | |
# TextIteratorStreamer を使った生成結果の受け取り | |
print(f'{speaker_name_system}: ', end='') | |
generated_text = '' | |
pending_buffer = '' | |
for next_text in streamer: | |
if not next_text: | |
continue | |
print(next_text.replace('<NL>', "\n"), end='', flush=True) | |
last_pending_buffer = pending_buffer | |
generated_token = re.sub('</s>$', '', next_text) | |
pending_buffer = '' if next_text == generated_token else '</s>' | |
generated_text += last_pending_buffer + generated_token.replace('<NL>', "\n") | |
history[-1][1] = generated_text | |
yield history | |
print('') | |
# 生成結果 | |
return generated_text | |
with gr.Blocks() as app: | |
gr.Markdown(f"## {model_name} Chat") | |
chatbot = gr.Chatbot(height=500) | |
with gr.Row(): | |
with gr.Column(): | |
msg = gr.Textbox(label="Chat Message Box", placeholder="Chat Message Box", show_label=False, container=False) | |
with gr.Column(): | |
with gr.Row(): | |
submit = gr.Button("Submit") | |
stop = gr.Button("Stop") | |
clear = gr.Button("Clear") | |
system_msg = gr.Textbox(start_message, label="System Message", interactive=False, visible=False) | |
submit_kwargs = dict(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False) | |
submit_then_kwargs = dict(fn=chat, inputs=[system_msg, chatbot], outputs=[chatbot], queue=True) | |
submit_event = msg.submit(**submit_kwargs).then(**submit_then_kwargs) | |
submit_click_event = submit.click(**submit_kwargs).then(**submit_then_kwargs) | |
stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event, submit_click_event], queue=False) | |
clear.click(lambda: None, None, [chatbot], queue=False) | |
app.queue(max_size=32) | |
app.launch(max_threads=2) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment