Created
May 19, 2023 16:19
-
-
Save advanceboy/b9143aa9de23a6f9a60a07a862e0b4a8 to your computer and use it in GitHub Desktop.
rinna/japanese-gpt-neox-3.6b-instruction-sft を使ったチャット UI のサンプル実装です。 transformers.TextIteratorStreamer API を利用して、 ChatGPT のように生成したテキストを少しずつ表示し、ユーザー体験を向上させています。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8 | |
# License: CC0 | |
""" | |
rinna/japanese-gpt-neox-3.6b-instruction-sft を使ったチャット UI のサンプル実装です。 | |
-> https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-sft | |
transformers.TextIteratorStreamer API を利用して、 ChatGPT のように生成したテキストを少しずつ表示し、ユーザー体験を向上させています。 | |
-> https://huggingface.co/docs/transformers/v4.29.1/en/internal/generation_utils#transformers.TextIteratorStreamer | |
ユーザー入力には、以下のコマンドが使えます。 | |
clear : すべての入力履歴をクリアし、初期プロンプト状態にリセットします。 | |
exit : プログラムを終了します。 | |
retry : 前回と同じプロンプトでテキストを再生成します。 | |
streamer クラスの API は開発中のため、近い将来互換性がなくなる可能性があります。 | |
transformers==4.29.2 での動作を確認しています。 | |
環境作成手順 | |
1. CUDA Toolkit のインストール https://developer.nvidia.com/cuda-toolkit-archive | |
2. CUDA の環境にあわせた PyTorch のパッケージを pip で追加 | |
* `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118` | |
3. transformers 関連のパッケージを pip で追加 | |
* `pip3 install ipython sentencepiece transformers accelerate` | |
4. Python でスクリプトを実行 | |
* `python rinna_chat_streaming.py` | |
* 初回実行時、 huggingface.co からモデルを DL にするのに時間がかかったり、失敗したりする場合があります。 | |
pip パッケージを入れる際は、 venv などで仮想環境を作成しておくことを強くおすすめします。 | |
""" | |
import re | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer | |
from threading import Thread | |
# 定数宣言 | |
model_name = "rinna/japanese-gpt-neox-3.6b-instruction-sft" | |
torch_dtype = torch.bfloat16 | |
max_new_tokens = 256 | |
temperature = 0.7 | |
max_history = 16 | |
speaker_name_user = 'ユーザー' | |
speaker_name_system = 'システム' | |
initial_messages = [ | |
{ | |
'speaker': speaker_name_user, | |
'text': 'こんにちは。' | |
}, | |
{ | |
'speaker': speaker_name_system, | |
'text': f'こんにちは、私は{speaker_name_system}です。あなたの質問に適切な回答をします。どのようなご用件ですか?' | |
}, | |
] | |
messages = [] | |
messages.extend(initial_messages) | |
# モデルの初期化 | |
if not torch.cuda.is_available(): | |
raise 'CUDA is not available' | |
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) | |
model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', torch_dtype=torch_dtype) | |
output = None | |
while True: | |
input_text = input(f'{speaker_name_user}: ') | |
# コマンドの処理 | |
if input_text == 'clear': | |
messages.clear() | |
messages.extend(initial_messages) | |
print('履歴が初期化されました') | |
continue | |
elif input_text == 'exit': | |
break | |
elif input_text == 'retry': | |
pass | |
else: | |
if output: | |
messages.append({ | |
'speaker': speaker_name_system, | |
'text': output | |
}) | |
messages.append({ | |
'speaker': speaker_name_user, | |
'text': input_text | |
}) | |
if len(messages) > max_history: | |
del messages[0:(len(messages)-max_history)] | |
# プロンプトの作成 | |
prompt = [ | |
f"{uttr['speaker']}: {uttr['text']}" | |
for uttr in messages | |
] | |
prompt = "<NL>".join(prompt) | |
prompt = ( | |
prompt | |
+ "<NL>" | |
+ f'{speaker_name_system}: ' | |
) | |
# テキスト生成の開始 | |
token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt") | |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True) | |
generation_args = [token_ids.to(model.device)] | |
generation_kwargs = dict( | |
streamer=streamer, | |
do_sample=True, | |
max_new_tokens=max_new_tokens, | |
temperature=temperature, | |
pad_token_id=tokenizer.pad_token_id, | |
bos_token_id=tokenizer.bos_token_id, | |
eos_token_id=tokenizer.eos_token_id | |
) | |
thread = Thread(target=model.generate, args=generation_args, kwargs=generation_kwargs) | |
thread.start() | |
# TextIteratorStreamer を使った生成結果の受け取り | |
print(f'{speaker_name_system}: ', end='') | |
generated_text = '' | |
for next_text in streamer: | |
if not next_text: | |
continue | |
print(next_text.replace('<NL>', "\n"), end='', flush=True) | |
generated_text += next_text | |
print('') | |
# 生成結果 | |
output = re.sub('</s>$', '', generated_text) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment