# GitHub gist mutaguchi/2a4fd9a9d90b0103be6cdd4e9b628cd5, created August 20, 2023 00:26.
# GitHub notes that this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than it appears; review it in an editor
# that reveals hidden Unicode characters.
import torch
from transformers import AlbertTokenizer, BitsAndBytesConfig, GPTJForCausalLM
# Hugging Face Hub id of the checkpoint; kept separate from `model`, which
# below holds the loaded model object used by completion().
model_name = 'AIBunCho/japanese-novel-gpt-j-6b'

# Turn off AlbertTokenizer's default accent stripping and whitespace stripping
# so prompts are tokenized as written.
tokenizer = AlbertTokenizer.from_pretrained(model_name, keep_accents=True, remove_space=False)

# Load the weights 4-bit quantized. Passing `load_in_4bit` directly to
# from_pretrained is deprecated; a BitsAndBytesConfig with the same flag is the
# supported spelling. device_map='auto' lets the library place the layers.
model = GPTJForCausalLM.from_pretrained(
    model_name,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    torch_dtype=torch.bfloat16,
    device_map='auto',
)
model.eval()
def completion(prompt, max_new_tokens=256, temperature=0.6, top_p=1.0,
               repetition_penalty=1.2):
    """Sample a continuation of ``prompt`` from the module-level ``model``.

    Args:
        prompt: Text to continue. It is encoded with
            ``add_special_tokens=False``, so no special tokens are added.
        max_new_tokens: Maximum number of tokens generated per call.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        repetition_penalty: Repetition penalty passed to ``model.generate``.

    Returns:
        A tuple ``(output, output_without_input)``. The first item is the
        decoded full sequence (prompt plus continuation). The second is the
        decoded continuation alone. Special tokens are skipped in both.
    """
    # The ids only need to live on the model's device; a separate .cuda()
    # hop before this is redundant.
    input_ids = tokenizer.encode(
        prompt,
        add_special_tokens=False,
        return_tensors="pt",
    ).to(model.device)
    tokens = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    sequence = tokens[0]
    output = tokenizer.decode(sequence, skip_special_tokens=True)
    # The returned sequence begins with the prompt ids; skip them so only
    # the newly generated tokens are decoded.
    output_without_input = tokenizer.decode(sequence[input_ids.shape[1]:], skip_special_tokens=True)
    return (output, output_without_input)
def get_line(prompt, max_rounds=20, markers=None, complete=None):
    """Generate one utterance, streaming it to stdout as it is produced.

    ``complete`` is called repeatedly (at most ``max_rounds`` times). Each
    call continues the text returned by the previous one. Every new chunk is
    truncated at the earliest stop marker found in it, then printed and
    accumulated. Generation stops as soon as a marker has been seen.

    Args:
        prompt: Text to continue; normally ends with ``{speaker2}「``.
        max_rounds: Maximum number of ``complete`` calls.
        markers: Stop strings. Defaults to ``["」", f"{speaker2}「",
            f"{speaker1}「"]``, built from the module-level speaker names at
            call time. Empty strings are ignored.
        complete: Callable ``prompt -> (full_text, new_text)``. Defaults to
            ``completion``.

    Returns:
        The accumulated text with leading/trailing whitespace stripped.
        This is ``""`` when nothing was generated within ``max_rounds``.
    """
    if complete is None:
        complete = completion
    if markers is None:
        # Resolved here, not in the signature: speaker1/speaker2 are
        # module-level globals assigned after this function is defined.
        markers = ["」", f"{speaker2}「", f"{speaker1}「"]
    stop_markers = [m for m in markers if m]

    result = ""
    for _ in range(max_rounds):
        prompt, current_output = complete(prompt)
        if not current_output:
            continue
        # Cut at whichever marker occurs FIRST in the text. Trying markers in
        # list order would, for "…\n私「…」", cut at "」" and keep the next
        # speaker's line in the reply.
        positions = [current_output.find(m) for m in stop_markers]
        cut = min((pos for pos in positions if pos != -1), default=-1)
        if cut != -1:
            current_output = current_output[:cut]
        print(current_output, end="", flush=True)
        result += current_output
        if cut != -1:
            break
    return result.strip()
# Speaker labels as they appear in the prompt, e.g. 私「…」 and メイド「…」.
speaker1 = "私"
speaker2 = "メイド"
# Maximum number of past lines copied into each prompt. This counts one entry
# per utterance, so 20 means the last 10 exchanges.
max_conversation_length = 20
# While fewer than this many exchanges have been recorded, the example
# dialogue is prepended as if it were real history. After that, only
# speaker2's example lines are listed as samples.
example_turns = 5
# Character description placed at the top of every prompt.
system = f'''
[{speaker1}と{speaker2}の会話]
[{speaker2}の職業は、メイド。]
[{speaker2}は、丁寧な話し方をする。]
[{speaker2}は、毒舌家である。]
[{speaker2}の好きな食べ物は、チョコレート。]
[{speaker2}は、{speaker1}に仕えている。]
'''
# Example dialogue that seeds speaker2's voice.
initial_conversation = [
    {"speaker": speaker1, "content": "おはよう、メイドちゃん。"},
    {"speaker": speaker2, "content": "おはようございます、ご主人様。"},
    {"speaker": speaker1, "content": "今日もよろしくね。"},
    {"speaker": speaker2, "content": "はい、精一杯、お仕事を務めさせていただきます。"},
]


def _build_prompt(conversation, turn, speaker1_line):
    """Return the prompt asking the model for speaker2's reply to speaker1_line."""
    prompt = f"{system}\n"
    # Keep only the most recent lines. The explicit guard makes a limit of 0
    # mean "no history"; conversation[-0:] would be the whole list.
    if max_conversation_length > 0:
        current_conversation = conversation[-max_conversation_length:]
    else:
        current_conversation = []
    if turn < example_turns:
        current_conversation = initial_conversation + current_conversation
    else:
        speaker_lines = [line["content"] for line in initial_conversation if line["speaker"] == speaker2]
        prompt += f'[{speaker2}の台詞例:「{"」「".join(speaker_lines)}」]\n'
    for line in current_conversation:
        prompt += f'{line["speaker"]}「{line["content"]}」\n'
    # End with an open quote so the model continues as speaker2.
    prompt += f'{speaker1}「{speaker1_line}」\n{speaker2}「'
    return prompt


def main():
    """Run the interactive chat loop; Ctrl-D or Ctrl-C at the prompt exits."""
    conversation = []
    # Number of exchanges recorded so far (only non-empty replies count).
    turn = 0
    while True:
        try:
            speaker1_line = input(f"{speaker1}: ")
        except (EOFError, KeyboardInterrupt):
            print("")
            break
        prompt = _build_prompt(conversation, turn, speaker1_line)
        print(f'{speaker2}: ', end="", flush=True)
        speaker2_line = get_line(prompt)
        print("")
        # An empty reply is not recorded; the user's line is dropped with it,
        # so the next input is answered from the same history.
        if speaker2_line:
            turn += 1
            conversation.append({"speaker": speaker1, "content": speaker1_line})
            conversation.append({"speaker": speaker2, "content": speaker2_line})


if __name__ == "__main__":
    main()
# (GitHub page footer) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.