Created March 15, 2024 11:22
A simple inference script for CohereForAI/aya-101 with a Gradio-based UI, RTL support, and streaming text output.
import torch
import gradio as gr
from threading import Thread
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TextIteratorStreamer

checkpoint = "CohereForAI/aya-101"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, device_map='auto', torch_dtype=torch.bfloat16)

text_title = checkpoint.replace("/", " - ") + ' - Gradio Demo'
########################################################################
# Settings
########################################################################

# Set the maximum number of tokens to generate
max_new_tokens = 250

# Set the value of the repetition penalty
# The higher the value, the less repetitive the generated text will be
# Note that `repetition_penalty` has to be a strictly positive float
repetition_penalty = 1.8

# Set the text direction
# For languages that are written from right to left (RTL), set rtl to True
rtl = False

########################################################################

print(f"Settings: max_new_tokens = {max_new_tokens}, repetition_penalty = {repetition_penalty}, rtl = {rtl}")
if rtl:
    text_title += " - RTL"
    text_align = 'right'
    css = "#output_text{direction: rtl} #input_text{direction: rtl}"
else:
    text_align = 'left'
    css = ""
def generate(text = ""): | |
print("Create streamer") | |
yield "[Please wait for an answer]" | |
decode_kwargs = dict(skip_special_tokens = True, clean_up_tokenization_spaces = True) | |
streamer = TextIteratorStreamer(tokenizer, timeout = 5., decode_kwargs = decode_kwargs) | |
inputs = tokenizer([text], add_special_tokens = False, return_tensors = "pt").to('cuda') | |
print(tokenizer.decode(inputs['input_ids'][0], skip_special_tokens=True)) | |
generation_kwargs = dict(inputs, streamer = streamer, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty) | |
print("Create thread") | |
thread = Thread(target = model.generate, kwargs = generation_kwargs) | |
thread.start() | |
generated_text = "" | |
for new_text in streamer: | |
if tokenizer.eos_token not in new_text: | |
new_text = new_text.replace(tokenizer.pad_token, "") | |
yield generated_text + new_text | |
print(new_text, end ="") | |
generated_text += new_text | |
else: | |
new_text = new_text.replace(tokenizer.eos_token, "\n") | |
print(new_text, end ="") | |
generated_text += new_text | |
return generated_text | |
return generated_text | |

demo = gr.Interface(
    title=text_title,
    fn=generate,
    inputs=gr.Textbox(label="Enter your prompt here", elem_id="input_text", text_align=text_align, rtl=rtl),
    outputs=gr.Textbox(type="text", label="Generated text will appear here", elem_id="output_text", text_align=text_align, rtl=rtl),
    css=css,
    allow_flagging='never'
)

demo.queue()
demo.launch(debug=True)
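
A note on running the script: it assumes torch, transformers, accelerate (needed for device_map='auto'), gradio, and likely sentencepiece for the T5-style aya-101 tokenizer are installed, plus a CUDA GPU with enough memory for the 13B-parameter checkpoint in bfloat16 (roughly 26 GB of weights). The inputs are sent to 'cuda' explicitly while the model itself is placed with device_map='auto'; a small, hypothetical variant of that one line, assuming a single-device placement, would follow the model's reported device instead:

# Possible tweak (assumes single-device placement): use the model's device rather than hard-coding 'cuda'
inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)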