Skip to content

Instantly share code, notes, and snippets.

@justdoit0823
Created July 15, 2023 16:00
Show Gist options
  • Save justdoit0823/892820b8399de9133b3df647158728e8 to your computer and use it in GitHub Desktop.
Save justdoit0823/892820b8399de9133b3df647158728e8 to your computer and use it in GitHub Desktop.
A simple Chinese-to-English translation demo built with Gradio.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Setup the gradio Demo.
import gradio as gr
# Hub id of the Opus-MT Chinese -> English checkpoint. The tokenizer and the
# model are loaded from the same id so their vocabularies match.
MODEL_NAME = "Helsinki-NLP/opus-mt-zh-en"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
def translation(text, *, translate_fn=None):
    """Translate a block of Chinese text to English, sentence by sentence.

    The input is split on the Chinese full stop "。" so each sentence is
    translated on its own, and the English sentences are joined with a
    single space.

    Args:
        text: Chinese source text from the UI. ``None`` or "" yields "".
        translate_fn: Optional callable that translates one sentence.
            Defaults to ``do_trans``; mainly useful for testing.

    Returns:
        The English translation, or "" when there is nothing to translate.
    """
    print(f"input: {text}")
    if not text:
        # Gradio can hand us None (e.g. after the box was cleared).
        return ""
    translate = do_trans if translate_fn is None else translate_fn
    # Strip each fragment and skip blank ones (a trailing "。", or the newline
    # between paragraphs) so the model is never asked to translate nothing.
    sentences = (part.strip() for part in text.split("。"))
    translated = (translate(sentence) for sentence in sentences if sentence)
    # Separate sentences with a space; the old "".join glued them together
    # ("Hello.World.").
    return " ".join(piece for piece in translated if piece)
def do_trans(text):
    """Translate one Chinese sentence (or fragment) to English.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        text: Chinese source text, normally one sentence without its
            trailing "。".

    Returns:
        The decoded English translation, ending in sentence-final
        punctuation, or "" if the model produced no text.
    """
    # NOTE(review): truncation=False passes over-long inputs to the model
    # unchanged; presumably generation fails past the model's maximum input
    # length — TODO confirm and decide whether to chunk or truncate.
    input_ids = tokenizer.encode(text, return_tensors="pt", truncation=False)
    # Use the translation model to translate the Chinese sentence into English.
    outputs = model.generate(input_ids=input_ids, num_beams=4, early_stopping=False)
    content = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    # Add a period only when the sentence is not already terminated.
    # - The old check looked for '.' alone, turning "Really?" into "Really?."
    #   and an empty output into ".".
    # - Trailing closing quotes/brackets are ignored when testing, so that
    #   'He said "Hi."' does not get a second period.
    if not content or content.rstrip("\"'”’)").endswith((".", "!", "?", "…")):
        return content
    return content + "."
# Build the UI:
# - source and result text boxes side by side;
# - "Translate" / "Clear" buttons underneath.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("Opus-mt-zh-en Demo")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                # Labels are hidden, so the placeholders ("please enter text
                # to translate" / "translation result") identify the boxes.
                # The former `.style(container=True)` calls are dropped:
                # `.style()` is deprecated in later Gradio 3.x releases, and
                # container=True is presumably the Textbox default — confirm
                # against the pinned Gradio version.
                srcMsg = gr.Textbox(
                    label="Source Text Box",
                    placeholder="请输入翻译内容",
                    show_label=False,
                    lines=20,
                    max_lines=100,
                )
                dstMsg = gr.Textbox(
                    label="Dest Text Box",
                    placeholder="翻译结果",
                    show_label=False,
                    lines=20,
                    max_lines=100,
                )
            with gr.Row():
                submit = gr.Button("Translate")
                clear = gr.Button("Clear")

    # Submitting the source box and clicking "Translate" both run the same
    # queued translation. The previous no-op step (`lambda x: x` writing
    # srcMsg back to itself) cost an extra round-trip and is dropped.
    submit_event = srcMsg.submit(
        fn=translation,
        inputs=[srcMsg],
        outputs=dstMsg,
        queue=True,
    )
    submit_click_event = submit.click(
        fn=translation,
        inputs=[srcMsg],
        outputs=dstMsg,
        queue=True,
    )
    # Reset both boxes to empty strings rather than None, so a later
    # "Translate" click hands translation() a str.
    clear.click(lambda: ("", ""), None, [srcMsg, dstMsg], queue=False)
# Queue at most 128 pending requests and process two at a time.
# NOTE(review): `concurrency_count` is the Gradio 3.x queue API; newer Gradio
# releases replace it (e.g. `default_concurrency_limit`) — confirm the pinned
# Gradio version before upgrading.
demo.queue(max_size=128, concurrency_count=2)

# Launch the translation demo only when run as a script, so importing this
# module does not start a web server as a side effect.
if __name__ == "__main__":
    demo.launch()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment