Last active
October 31, 2023 14:57
-
-
Save kishida/66b4e456e9e14362eb3d339f19815477 to your computer and use it in GitHub Desktop.
Youri 7B Instructionを使った翻訳システム
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from transformers import AutoTokenizer | |
# Checkpoint selection. To use the GPTQ-quantized weights instead, swap in:
# model_name = "rinna/youri-7b-instruction-gptq"
model_name = "rinna/youri-7b-instruction"
# Passed to from_pretrained for the non-GPTQ checkpoint; also disables the
# manual CUDA move below.
load_in_8bit = True

tokenizer = AutoTokenizer.from_pretrained(model_name)

# The loader is picked from the checkpoint name: any name containing "gptq"
# goes through auto_gptq, everything else through plain transformers.
uses_gptq = "gptq" in model_name
if uses_gptq:
    from auto_gptq import AutoGPTQForCausalLM

    model = AutoGPTQForCausalLM.from_quantized(model_name, use_safetensors=True)
else:
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        load_in_8bit=load_in_8bit,
    )

# Only move the model by hand when a GPU exists and 8-bit loading is off
# (presumably the 8-bit loader handles device placement itself -- confirm).
# NOTE(review): indentation was lost in the source this was recovered from;
# this check is assumed to be top-level (applies to both loaders) -- confirm.
move_to_gpu = torch.cuda.is_available() and not load_in_8bit
if move_to_gpu:
    model = model.to("cuda")
# Task instructions substituted into the prompt, one per translation direction.
# "Please translate the following Japanese into English."
instruction_je = "次の日本語を英語に翻訳してください。"
# "Please translate the following English into Japanese."
instruction_ej = "次の英語を日本語に翻訳してください。"

# Prompt skeleton with {instruction} and {input} placeholders, filled via
# str.format. Written as explicit newline-terminated pieces so the exact
# whitespace sent to the model is visible.
# NOTE(review): blank lines between sections may have been lost when this
# file was extracted -- compare against the model card's prompt format.
template = (
    "\n"
    "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n"
    "### 指示:\n"
    "{instruction}\n"
    "### 入力:\n"
    "{input}\n"
    "### 応答:\n"
)
def generate(input, ej):
    """Translate ``input`` with the loaded model and return the generated text.

    Args:
        input: Source text to translate. (The name shadows the builtin; it is
            kept so existing keyword callers keep working.)
        ej: True to translate English -> Japanese, False for
            Japanese -> English.

    Returns:
        The text decoded from the tokens produced after the prompt, with
        special tokens skipped. An empty string is returned without calling
        the model when ``input`` is None, empty, or only whitespace.
    """
    # Nothing to translate: do not let the model sample text for an empty
    # input slot in the prompt.
    if not input or not input.strip():
        return ""

    instruction = instruction_ej if ej else instruction_je
    prompt = template.format(input=input, instruction=instruction)
    token_ids = tokenizer.encode(
        prompt,
        add_special_tokens=False,
        return_tensors="pt",
    )
    # Number of prompt tokens; used below to drop the prompt from the output.
    prompt_length = token_ids.shape[-1]

    with torch.no_grad():
        output_ids = model.generate(
            input_ids=token_ids.to(model.device),
            max_new_tokens=200,
            do_sample=True,
            temperature=0.5,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only what follows the prompt in the first (and only) sequence.
    generated_ids = output_ids.tolist()[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)
# UI | |
import gradio as gr | |
# Radio-button captions for the two translation directions:
# Japanese -> English and English -> Japanese.
label_je, label_ej = "日->英", "英->日"
def proc(text, type):
    """Click handler: translate ``text`` in the direction named by ``type``.

    ``type`` is compared against ``label_ej``; a match selects
    English -> Japanese, anything else selects Japanese -> English.
    Returns whatever ``generate`` returns.
    """
    return generate(text, type == label_ej)
# Page layout: heading, source text box, direction selector, submit button and
# result box. The widgets are bound to names that do not shadow the builtins
# `input` and `type`, which the module-level assignments previously replaced.
with gr.Blocks() as demo:
    gr.Markdown(f"## 翻訳 by {model_name}")
    source_box = gr.Textbox(lines=5, placeholder="原文")
    # Defaults to English -> Japanese.
    direction = gr.Radio([label_je, label_ej], value=label_ej)
    submit = gr.Button("翻訳", variant="primary")
    answer = gr.Textbox(lines=5)
    # Clicking the button passes (text, selected label) to proc and shows its
    # return value in the answer box.
    submit.click(proc, inputs=[source_box, direction], outputs=answer)

demo.launch()
Author
kishida
commented
Oct 31, 2023
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment