|
import os
from string import Template

import torch
import gradio as gr
from transformers import AutoTokenizer
from text_generation import Client

# Hostname of the machine running the text-generation-inference server
hostname = os.environ.get("HOSTNAME")

# An auth token is only needed if the model repository is gated
auth_token = os.environ.get("HF_API_TOKEN")
tokenizer = AutoTokenizer.from_pretrained(
    "CarperAI/stable-vicuna-13b-fp16",
    use_auth_token=auth_token if auth_token else True,
)
|
|
# Token budget: the model's context window and the cap on newly generated tokens
max_context_length = 2048
max_new_tokens = 768

prompt_template = Template("""\
### Human: $human
### Assistant: $bot\
""")
|
|
# Tokenize the system prompt once so its length can be counted toward the budget
system_prompt = "### Assistant: I am StableVicuna, a large language model created by CarperAI. I am here to chat!"
system_prompt_tokens = tokenizer([f"{system_prompt}\n\n"], return_tensors="pt")
max_sys_tokens = system_prompt_tokens['input_ids'].size(-1)

# Client for the text-generation-inference server
client = Client(f"http://{hostname}:8080")
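
# Quick sanity check (illustrative; assumes the server is reachable): the same
# client also supports non-streaming generation, e.g.
#   resp = client.generate("### Human: Hello\n### Assistant:", max_new_tokens=16)
#   print(resp.generated_text)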
|
|
|
def bot(history):
    history = history or []

    # Render the chat history into the prompt format
    prompt_history = []
    for human, bot_message in history:
        if bot_message is not None:
            bot_message = bot_message.replace("<br>", "\n")
            bot_message = bot_message.rstrip()
        prompt_history.append(
            prompt_template.substitute(
                human=human, bot=bot_message if bot_message is not None else "")
        )

    msg_tokens = tokenizer(
        "\n\n".join(prompt_history).strip(),
        return_tensors="pt",
        add_special_tokens=False,  # the <BOS> token comes from the system prompt
    )

    # Keep only the most recent context: reserve room for the system prompt and
    # the tokens about to be generated, then take the tail of the message
    # history that still fits, prepending the system prompt.
    max_tokens = -max_context_length + max_new_tokens + max_sys_tokens
    input_tokens = torch.concat(
        [system_prompt_tokens['input_ids'], msg_tokens['input_ids'][:, max_tokens:]],
        dim=-1,
    )
    input_text = tokenizer.decode(input_tokens[0].cpu().tolist())
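
    # Worked example of the budget above (assuming, illustratively, a 40-token
    # system prompt): max_tokens = -2048 + 768 + 40 = -1240, so the slice
    # [:, -1240:] keeps the newest 1240 message tokens; with the system prompt
    # prepended the input uses 1280 tokens, leaving the full 768-token
    # generation budget free inside the 2048-token window.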
|
    generate_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.999,
        temperature=1.0,
    )
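
    # Alternative sketch (assumption: the text_generation client's
    # stop_sequences parameter, e.g. stop_sequences=["###"], passed alongside
    # generate_kwargs) would let the server cut generation at the separator;
    # the manual '#' handling below is kept so partial separator tokens never
    # reach the UI mid-stream.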
|
    # Stream tokens from the server, filtering out the "###" turn separator
    partial_text = ""
    for resp in client.generate_stream(input_text, **generate_kwargs):
        new_text = resp.token.text.replace("<br>", "\n")
        if (partial_text + new_text).endswith('###'):
            # The model has begun a new "### ..." turn marker: trim it and stop
            history[-1][1] = (partial_text + new_text)[:-3]
            yield history
            break
        elif new_text in ('#', '##'):
            # Hold back a partial run of '#' until we know whether it
            # completes the separator
            partial_text += new_text
        else:
            partial_text += new_text
            history[-1][1] = partial_text
            yield history
    return partial_text
|
|
def user(user_message, history):
    # Append the new user turn with an empty slot for the bot's reply
    return "", history + [[user_message, None]]
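
# Illustrative history shape after two turns (the second awaiting its reply):
#   [["Hi there!", "Hello! How can I help?"], ["What is 2+2?", None]]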
|
|
with gr.Blocks() as demo:
    gr.Markdown("# StableVicuna by CarperAI")
    gr.HTML("<a href='https://huggingface.co/CarperAI/stable-vicuna-13b-delta'><code>CarperAI/stable-vicuna-13b-delta</code></a>")
    gr.HTML('''<center><a href="https://huggingface.co/spaces/CarperAI/StableVicuna?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to skip the queue and run in a private space</center>''')
|
    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
    state = gr.State([])
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Send a message",
                placeholder="Send a message",
                show_label=False,
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Send")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear History")
|
    # Pressing Enter or clicking Send appends the user turn, then streams the
    # bot's reply; Stop cancels an in-flight generation
    submit_event = msg.submit(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)
    submit_click_event = submit.click(user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)

    stop.click(fn=None, inputs=None, outputs=None, cancels=[submit_event, submit_click_event], queue=False)
    clear.click(lambda: None, None, [chatbot], queue=True)
|
demo.queue(max_size=32, concurrency_count=2)
demo.launch(share=True)
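
# Deployment note (assumptions, for illustration): HOSTNAME must point at a
# text-generation-inference server hosting the model on port 8080, and
# HF_API_TOKEN is only needed if the tokenizer repo is gated, e.g.
#   HOSTNAME=localhost HF_API_TOKEN=hf_xxx python app.py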