Skip to content

Instantly share code, notes, and snippets.

@kishida
Last active October 31, 2023 14:57
Show Gist options
  • Save kishida/66b4e456e9e14362eb3d339f19815477 to your computer and use it in GitHub Desktop.
Save kishida/66b4e456e9e14362eb3d339f19815477 to your computer and use it in GitHub Desktop.
Youri 7B Instructionを使った翻訳システム
import torch
from transformers import AutoTokenizer
# Model selection: switch to the commented-out GPTQ checkpoint to load the
# pre-quantized variant through auto_gptq instead of the full checkpoint.
# model_name = "rinna/youri-7b-instruction-gptq"
model_name = "rinna/youri-7b-instruction"
# Request bitsandbytes 8-bit loading for the non-GPTQ model. 8-bit loading needs
# a CUDA GPU, so the request is dropped on CPU-only machines instead of making
# from_pretrained fail.
use_8bit = True
load_in_8bit = use_8bit and torch.cuda.is_available()
tokenizer = AutoTokenizer.from_pretrained(model_name)
if "gptq" in model_name:
    from auto_gptq import AutoGPTQForCausalLM
    model = AutoGPTQForCausalLM.from_quantized(model_name, use_safetensors=True)
else:
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig
    # quantization_config replaces the deprecated load_in_8bit= keyword;
    # None means "load unquantized" (the from_pretrained default).
    quantization_config = BitsAndBytesConfig(load_in_8bit=True) if load_in_8bit else None
    model = AutoModelForCausalLM.from_pretrained(
        model_name, quantization_config=quantization_config
    )
# Move the model to the GPU explicitly only when it was NOT loaded in 8-bit:
# bitsandbytes places 8-bit weights on the GPU while loading and transformers
# rejects .to() on such models.
# NOTE(review): the original indentation was lost; this assumes the move applies
# to whichever model was loaded (GPTQ or not) — confirm the intended scope.
if torch.cuda.is_available() and not load_in_8bit:
    model = model.to("cuda")
# Instruction text for each translation direction.
# "Translate the following Japanese into English."
instruction_je = "次の日本語を英語に翻訳してください。"
# "Translate the following English into Japanese."
instruction_ej = "次の英語を日本語に翻訳してください。"

# Alpaca-style prompt with instruction / input / response sections; fill it via
# template.format(instruction=..., input=...). The header line reads: "Below is
# a combination of an instruction describing a task and an input that provides
# context. Write a response that appropriately fulfills the request."
template = "\n".join([
    "",  # leading newline
    "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。",
    "### 指示:",  # "Instruction:"
    "{instruction}",
    "### 入力:",  # "Input:"
    "{input}",
    "### 応答:",  # "Response:"
    "",  # trailing newline, so the model's answer starts on a fresh line
])
def generate(input, ej, max_new_tokens=200, temperature=0.5, do_sample=True):
    """Translate *input* with the loaded instruction model and return the text.

    Args:
        input: Source text to translate. (The name shadows the builtin but is
            kept so existing callers keep working.)
        ej: Truthy for English -> Japanese (uses ``instruction_ej``), falsy for
            Japanese -> English (uses ``instruction_je``).
        max_new_tokens: Maximum number of tokens to generate. Longer
            translations are cut off at this limit (default 200).
        temperature: Sampling temperature, used when ``do_sample`` is True.
        do_sample: Sample from the model (True) or decode greedily (False).

    Returns:
        The decoded generated continuation (prompt excluded), with special
        tokens removed.
    """
    instruction = instruction_ej if ej else instruction_je
    prompt = template.format(input=input, instruction=instruction)
    # Encode the prompt as-is, without letting the tokenizer add special tokens.
    token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    prompt_length = token_ids.shape[-1]
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=token_ids.to(model.device),
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    # generate() returns prompt + continuation; decode only the new tokens of the
    # single sequence, slicing the tensor instead of converting the whole batch.
    return tokenizer.decode(output_ids[0, prompt_length:], skip_special_tokens=True)
# UI
import gradio as gr
# Radio-button labels for the two translation directions.
label_je = "日->英"  # Japanese -> English
label_ej = "英->日"  # English -> Japanese


def proc(text, type):
    """Gradio click handler: translate *text* in the direction chosen in the UI.

    *type* is the selected radio label: ``label_ej`` requests English -> Japanese,
    any other value requests Japanese -> English.
    """
    if type == label_ej:
        return generate(text, True)
    return generate(text, False)
# Gradio UI: source text box, direction selector, translate button, output box.
with gr.Blocks() as demo:
    gr.Markdown(f"## 翻訳 by {model_name}")  # "Translation by <model>"
    # Descriptive component names instead of `input` / `type`, which would
    # shadow the builtins input() and type() for the rest of the module.
    source_text = gr.Textbox(lines=5, placeholder="原文")  # placeholder: "source text"
    direction = gr.Radio([label_je, label_ej], value=label_ej)
    submit = gr.Button("翻訳", variant="primary")  # "Translate"
    answer = gr.Textbox(lines=5)
    # proc receives the text-box value and the selected radio label, in that order.
    submit.click(proc, inputs=[source_text, direction], outputs=answer)
demo.launch()
@kishida
Copy link
Author

kishida commented Oct 31, 2023

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment