Created
August 3, 2023 15:15
-
-
Save ShadowPower/527656baba17fc822c76277953c70a6b to your computer and use it in GitHub Desktop.
通义千问 7B Gradio
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gradio as gr | |
import mdtex2html | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
from transformers.generation.utils import GenerationConfig | |
# Model path: Hugging Face hub id or a local checkpoint directory
MODEL_PATH = 'Qwen/Qwen-7B-Chat'
# Context window size (tokens) written into the generation config below
CONTEXT_SIZE = 8192
# Cap on newly generated tokens per reply (max_new_tokens)
MAX_LENGTH = 2048
# Nucleus-sampling probability mass for generation
TOP_P = 0.8
# quantization configuration for NF4 (4 bits)
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16
)
# quantization configuration for Int8 (8 bits)
int8_config = BitsAndBytesConfig(load_in_8bit=True)
# NOTE(review): neither quantization config is actually passed below — the
# model loads at full precision unless quantization_config= is uncommented.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True,  # Qwen ships custom modeling code on the hub
    # quantization_config=int8_config,
    device_map="auto",  # spread layers across available devices automatically
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False, trust_remote_code=True)
# Load the model's default generation settings, then override the sampling knobs.
model.generation_config = GenerationConfig.from_pretrained(MODEL_PATH)
model.generation_config.max_context_size = CONTEXT_SIZE
model.generation_config.top_p = TOP_P
model.generation_config.max_new_tokens = MAX_LENGTH
def postprocess(self, y):
    """Render chat history for display: convert each markdown/LaTeX message to HTML.

    Installed as a replacement for gr.Chatbot.postprocess; ``None`` entries
    (e.g. a not-yet-filled bot reply) pass through unchanged.
    """
    if y is None:
        return []

    def _to_html(text):
        # Preserve None placeholders; convert everything else via mdtex2html.
        return None if text is None else mdtex2html.convert(text)

    for idx, (user_msg, bot_msg) in enumerate(y):
        y[idx] = (_to_html(user_msg), _to_html(bot_msg))
    return y
gr.Chatbot.postprocess = postprocess | |
def parse_text(text):
    """Convert raw model/user text into Chatbot-safe HTML.

    Adapted from https://github.com/GaiZhenbiao/ChuanhuChatGPT/.
    Fenced ``` sections become <pre><code class="language-..."> blocks; inside
    a fence, characters that markdown/HTML would interpret are replaced with
    HTML entities so code displays literally. Blank lines are dropped and the
    remaining lines are joined with <br> separators.

    The scraped version of this file had the entities un-escaped (e.g.
    ``line.replace("<", "<")`` — a no-op); the entity targets are restored here.
    """
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]  # drop blank lines
    count = 0  # number of ``` fences seen so far; odd => currently inside a code block
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                # Opening fence: last fragment is the language tag (may be "").
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                # Closing fence.
                lines[i] = '<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    # Inside a fence: escape markdown/HTML-significant characters.
                    line = line.replace("`", "\\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&#42;")
                    line = line.replace("_", "&#95;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text
def predict(input, chatbot, history):
    """Handle one chat turn: echo the user message, query the model, yield updated state.

    Generator handler for Gradio streaming: yields (chatbot, history) once the
    model's reply is available.
    """
    rendered_input = parse_text(input)
    # Show the user's message immediately with an empty placeholder reply.
    chatbot.append((rendered_input, ""))
    response, history = model.chat(tokenizer, rendered_input, history)
    # Replace the placeholder with the rendered model response.
    chatbot[-1] = (rendered_input, parse_text(response))
    yield chatbot, history
def reset_user_input():
    """Clear the input textbox (used right after a message is submitted)."""
    cleared = gr.update(value='')
    return cleared
def reset_state():
    """Start a fresh conversation: empty chatbot display and empty history."""
    empty_chat, empty_history = [], []
    return empty_chat, empty_history
# Build the Gradio UI: a chatbot pane above a textbox with Submit/Reset buttons.
with gr.Blocks() as demo:
    gr.HTML("""通义千问 7B""")  # page title banner
    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=4):
            user_input = gr.Textbox(show_label=False, placeholder="在此输入消息", lines=4, container=False)
        with gr.Column(scale=1):
            submitBtn = gr.Button("Submit", variant="primary")
            emptyBtn = gr.Button("重置会话")  # "Reset conversation"
    # Server-side conversation history (list of (user, bot) message pairs).
    history = gr.State([])
    # Stream the model reply into the chatbot, then clear the textbox.
    submitBtn.click(predict, [user_input, chatbot, history], [chatbot, history], show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])
    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
# queue() is required for generator (streaming) handlers; listen on all
# interfaces at port 9999 and also create a public share link.
demo.queue().launch(share=True, inbrowser=True, server_name="0.0.0.0", server_port=9999)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment