chat_app.py
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList

# Load the merged QLoRA checkpoint and its tokenizer.
model_path = "./qlora-out-hkg_300B/merged/"
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", trust_remote_code=True)
def format_prompt(message, history):
    """Build the conversation prompt from the chat history and the new user message."""
    prompt = "<s>A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
    for user_prompt, bot_response in history:
        prompt += f"USER: {user_prompt} \n"
        prompt += f"ASSISTANT: {bot_response} "
    prompt += f"USER: {message} \nASSISTANT: "
    return prompt
class EndOfFunctionCriteria(StoppingCriteria):
    """Custom `StoppingCriteria` which checks whether every generated sequence in the batch contains a stop string."""

    def __init__(self, start_length, eof_strings, tokenizer):
        self.start_length = start_length
        self.eof_strings = eof_strings
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        """Returns True once all generated sequences contain any of the stop strings."""
        decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length:])
        return all(
            any(stop_string in decoded_generation for stop_string in self.eof_strings)
            for decoded_generation in decoded_generations
        )
def generate(
    prompt, history, temperature=0.3, max_new_tokens=256, top_p=0.9, repetition_penalty=1.0,
):
    global tokenizer, model
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
    )
    formatted_prompt = format_prompt(prompt, history)
    print("formatted_prompt:", [formatted_prompt])
    inputs = tokenizer([formatted_prompt], return_tensors="pt")
    input_ids = inputs["input_ids"]
    _, prefix_length = input_ids.shape
    # Stop as soon as the model starts emitting a new dialogue turn on its own.
    eof_strings = ["USER:", "ASSISTANT:"]
    stopping_criteria = StoppingCriteriaList(
        [EndOfFunctionCriteria(start_length=prefix_length, eof_strings=eof_strings, tokenizer=tokenizer)]
    )
    generate_kwargs["stopping_criteria"] = stopping_criteria
    # Use the model's own device instead of hard-coding 'cuda'.
    output_ids = model.generate(input_ids.to(model.device), **generate_kwargs)
    # Decode only the newly generated tokens and drop a trailing role marker, if any.
    response = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
    for eof in eof_strings:
        if response.strip().endswith(eof):
            response = response.strip()[:-len(eof)]
            break
    return response
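
# A quick, hypothetical smoke test for generate() without the Gradio UI
# (the example history and question below are made up; uncomment to try it):
#   history = [("Hi, who are you?", "I am a helpful AI assistant.")]
#   print(generate("What can you help me with?", history, temperature=0.3, max_new_tokens=128))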
mychatbot = gr.Chatbot(
    # avatar_images=["./user.png", "./botm.png"],
    bubble_full_width=False, show_label=False, show_copy_button=True, likeable=True,
)

demo = gr.ChatInterface(
    fn=generate,
    chatbot=mychatbot,
    title="Let's Chat",
    retry_btn=None,
    undo_btn=None,
)

demo.queue().launch(show_api=False, share=True)
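Assuming Gradio and Transformers are installed and the merged checkpoint exists at ./qlora-out-hkg_300B/merged/, running python chat_app.py starts the chat UI locally; because share=True is set, launch() also requests a temporary public Gradio link.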