import os
from string import Template

import torch
import gradio as gr
from transformers import AutoTokenizer
from text_generation import Client
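
# Expects HOSTNAME (the host running the text-generation-inference server that
# serves the model) and, for gated checkpoints, HF_API_TOKEN in the environment.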
        
hostname = os.environ.get("HOSTNAME")

auth_token = os.environ.get("HF_API_TOKEN")
tokenizer = AutoTokenizer.from_pretrained(
    "CarperAI/stable-vicuna-13b-fp16",
    use_auth_token=auth_token if auth_token else True,
)

max_context_length = 2048
max_new_tokens = 768
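# Of the max_context_length-token window, max_new_tokens are reserved for the
# reply; the rest holds the system prompt plus the most recent chat history
# (see the slicing in bot() below).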
        

prompt_template = Template("""\
### Human: $human
### Assistant: $bot\
""")
        

system_prompt = "### Assistant: I am StableVicuna, a large language model created by CarperAI. I am here to chat!"
system_prompt_tokens = tokenizer([f"{system_prompt}\n\n"], return_tensors="pt")
max_sys_tokens = system_prompt_tokens['input_ids'].size(-1)
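# The call above prepends the tokenizer's <BOS> token to the system prompt, so
# the message tokenization in bot() passes add_special_tokens=False and <BOS>
# appears exactly once in the final input.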
        

client = Client(f"http://{hostname}:8080")
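# Generation runs on the remote server; this process only formats prompts and
# streams tokens back into the UI.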
        

def bot(history):
    history = history or []

    # Inject prompt formatting into the history
    prompt_history = []
    for human, bot_message in history:
        if bot_message is not None:
            bot_message = bot_message.replace("<br>", "\n")
            bot_message = bot_message.rstrip()
        prompt_history.append(
            prompt_template.substitute(
                human=human, bot=bot_message if bot_message is not None else "")
        )
        
    msg_tokens = tokenizer(
        "\n\n".join(prompt_history).strip(),
        return_tensors="pt",
        add_special_tokens=False  # Use <BOS> from the system prompt
    )
        
    # Keep only the most recent context that fits in the window and prepend the
    # system prompt to the messages
    max_tokens = -max_context_length + max_new_tokens + max_sys_tokens

    input_tokens = torch.concat(
        [system_prompt_tokens['input_ids'], msg_tokens['input_ids'][:, max_tokens:]],
        dim=-1,
    )
    input_text = tokenizer.decode(input_tokens[0].cpu().tolist())
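    # With the defaults, max_tokens = -2048 + 768 + max_sys_tokens, a negative
    # index: the slice keeps only the last 1280 - max_sys_tokens message tokens,
    # so system prompt + history + max_new_tokens of reply fit in 2048 tokens.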
        
    generate_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.999,
        temperature=1.0,
    )
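    # top_p=0.999 at temperature 1.0 is close to unfiltered sampling; only the
    # extreme tail of the token distribution is cut off.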
        
    partial_text = ""
    for resp in client.generate_stream(input_text, **generate_kwargs):
        new_text = resp.token.text
        new_text = new_text.replace("<br>", "\n")
        if (partial_text + new_text).endswith('###'):
            # The model started a new "### Human"/"### Assistant" turn: strip
            # the prompt separator, show the final reply, and stop streaming
            history[-1][1] = (partial_text + new_text)[:-3]
            yield history
            break
        elif new_text == '#' or new_text == '##':
            # Possible start of a '###' separator; buffer it until the next
            # token settles whether it belongs to the reply
            partial_text += new_text
        else:
            partial_text += new_text
            history[-1][1] = partial_text
            yield history
    return partial_text
        

def user(user_message, history):
    return "", history + [[user_message, None]]
        
with gr.Blocks() as demo:
    gr.Markdown("# StableVicuna by CarperAI")
    gr.HTML("<a href='https://huggingface.co/CarperAI/stable-vicuna-13b-delta'><code>CarperAI/stable-vicuna-13b-delta</code></a>")
    gr.HTML('''<center><a href="https://huggingface.co/spaces/CarperAI/StableVicuna?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a> Duplicate the Space to skip the queue and run in a private space</center>''')
        
    chatbot = gr.Chatbot([], elem_id="chatbot").style(height=500)
    state = gr.State([])
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Send a message",
                placeholder="Send a message",
                show_label=False
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Send")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear History")
        
    # user() runs unqueued so the message appears immediately; the generator
    # bot() is queued and streams into the chat. Stop cancels both events.
    submit_event = msg.submit(
        user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
    ).then(fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)
    submit_click_event = submit.click(
        user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
    ).then(fn=bot, inputs=[chatbot], outputs=[chatbot], queue=True)

    stop.click(fn=None, inputs=None, outputs=None,
               cancels=[submit_event, submit_click_event], queue=False)
    clear.click(lambda: None, None, [chatbot], queue=True)
        
demo.queue(max_size=32, concurrency_count=2)
demo.launch(share=True)