|
# japanese-gpt-neox-3.6b-instruction-sftを使ったチャットサンプル |
|
|
|
# 以下のコマンドが使えます。 |
|
# retry: 回答を再生成します。 |
|
# clear, cls: 会話ログを消去します。 |
|
# exit, end: スクリプトを終了します。 |
|
|
|
# 以下のコマンドライン引数が使えます。 |
|
# --path, -p: 会話設定用JSONファイルのパス(省略可) |
|
# --input, -i: 入力テキストを指定。実行したらすぐに終了する。(省略可) |
|
|
|
# ------------設定------------ |
|
use_cuda = True # cudaを使用するかどうか |
|
fp16 = True # モデルを半精度化する |
|
max_history = 5 # 履歷の最大長 |
|
max_completion_count = 3 # 補完が途切れた場合、自動補完継続する回数 |
|
max_new_tokens= 256 # 補完時の最大トークン長 |
|
temperature = 0.7 # 補完の温度 |
|
|
|
# ------------会話設定(JSONファイルでも指定可)------------ |
|
speaker1 = "ユーザー" # デフォルトユーザー名 |
|
speaker2 = "アシスタント" # デフォルトAI名 |
|
|
|
# デフォルトシステムメッセージ |
|
system_message = f"{speaker2}は有能なAIアシスタントです。{speaker1}の質問に対し、適切な回答を行います。" |
|
|
|
# デフォルト初期プロンプト |
|
initial_messages = [ |
|
{ |
|
"speaker": speaker1, |
|
"text": "こんにちは。" |
|
}, |
|
{ |
|
"speaker": speaker2, |
|
"text": f"こんにちは、{speaker1}さん。何かお手伝いできることはありますか?" |
|
}, |
|
] |
|
|
|
# ------------コードここから------------ |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
import json |
|
import argparse |
|
import os |
|
import re |
|
|
|
json_path = "" |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument('--path','-p') |
|
parser.add_argument('--input', '-i', dest='input_text') |
|
args = parser.parse_args() |
|
current_dir = os.path.dirname(os.path.abspath(__file__)) |
|
|
|
if args.path: |
|
json_path = args.path if os.path.isabs(args.path) else os.path.join( |
|
current_dir, args.path) |
|
|
|
if json_path: |
|
with open(json_path, 'r', encoding='utf-8') as f: |
|
data = json.load(f) |
|
system_message = data['system_message'] |
|
speaker1 = data['speaker1'] |
|
speaker2 = data['speaker2'] |
|
initial_messages = data['initial_messages'] |
|
|
|
messages = [] |
|
system_message = { |
|
"speaker": "システム", |
|
"text": system_message |
|
} |
|
messages.extend(initial_messages) |
|
|
|
model_name="rinna/japanese-gpt-neox-3.6b-instruction-sft" |
|
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) |
|
|
|
if use_cuda and torch.cuda.is_available(): |
|
model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', torch_dtype=torch.float16) if fp16 else AutoModelForCausalLM.from_pretrained(model_name).to("cuda") |
|
else: |
|
model = AutoModelForCausalLM.from_pretrained(model_name).to("cpu") |
|
|
|
def complete(prompt): |
|
token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt") |
|
|
|
with torch.no_grad(): |
|
output_ids = model.generate( |
|
token_ids.to(model.device), |
|
do_sample=True, |
|
max_new_tokens=max_new_tokens, |
|
temperature=temperature, |
|
pad_token_id=tokenizer.pad_token_id, |
|
bos_token_id=tokenizer.bos_token_id, |
|
eos_token_id=tokenizer.eos_token_id |
|
) |
|
|
|
output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):]) |
|
return output |
|
|
|
while True: |
|
|
|
input_text = args.input_text if args.input_text else input(f"{speaker1}: ") |
|
input_text = input_text.replace("\n", "<NL>") |
|
if input_text == "exit" or input_text == "end": |
|
break |
|
elif input_text == "retry": |
|
messages.pop() |
|
elif input_text == "clear" or input_text == "cls": |
|
messages.clear() |
|
messages.extend(initial_messages) |
|
continue |
|
else: |
|
messages.append( |
|
{ |
|
"speaker":speaker1, |
|
"text":input_text |
|
} |
|
) |
|
|
|
while len(messages) > max_history: |
|
messages.pop(0) |
|
|
|
prompt = [ |
|
f"{uttr['speaker']}: {uttr['text']}" |
|
for uttr in [system_message] + messages |
|
] |
|
prompt = "<NL>".join(prompt) |
|
prompt = prompt + "<NL>" + f"{speaker2}: " |
|
|
|
completion_count = 0 |
|
output = "" |
|
if not args.input_text: |
|
print(f"{speaker2}: ",end="") |
|
while not output.endswith("</s>") and completion_count < max_completion_count: |
|
current = complete(prompt + output) |
|
output += current |
|
completion_count += 1 |
|
|
|
output = output.replace("</s>", "") |
|
|
|
pattern = re.compile(r"^(.+?)<NL>" + speaker1 + ": ", re.DOTALL) |
|
match = re.search(pattern, output) |
|
if match: |
|
output = match.group(1) |
|
|
|
output = output.replace(speaker2 + ": ", "") |
|
|
|
messages.append( |
|
{ |
|
"speaker": speaker2, |
|
"text": output |
|
} |
|
) |
|
output = output.replace("<NL>", "\n") |
|
print(output) |
|
if args.input_text: |
|
break |