Baichuan 13B for Google Colab
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "Install dependencies:"
      ],
      "metadata": {
        "id": "MCYKztKnii28"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "h3OTlwGJimIc",
        "collapsed": true,
        "cellView": "form"
      },
      "outputs": [],
      "source": [
        "!pip install accelerate\n",
        "!pip install bitsandbytes\n",
        "!pip install colorama\n",
        "!pip install cpm_kernels\n",
        "!pip install sentencepiece\n",
        "!pip install streamlit\n",
        "!pip install transformers_stream_generator\n",
        "!pip install gradio\n",
        "!pip install mdtex2html"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Chat:"
      ],
      "metadata": {
        "id": "cSN-eQ8MilaU"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import gradio as gr\n",
        "import mdtex2html\n",
        "import os\n",
        "import torch\n",
        "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n",
        "from transformers.generation.utils import GenerationConfig\n",
        "\n",
        "# Configuration\n",
        "MODEL_PATH = 'sharpbai/Baichuan-13B-Chat'\n",
        "\n",
        "MAX_LENGTH = 2048\n",
        "TOP_P = 0.85\n",
        "TEMPERATURE = 0.05\n",
        "STREAM = True\n",
        "\n",
        "\n",
"nf4_config = BitsAndBytesConfig(\n", | |
" load_in_4bit=True,\n", | |
" bnb_4bit_quant_type=\"nf4\",\n", | |
" bnb_4bit_use_double_quant=True,\n", | |
" bnb_4bit_compute_dtype=torch.bfloat16,\n", | |
")\n", | |
"\n", | |
"model = AutoModelForCausalLM.from_pretrained(\n", | |
" MODEL_PATH,\n", | |
" trust_remote_code=True,\n", | |
" quantization_config=nf4_config,\n", | |
" device_map=\"auto\",\n", | |
")\n", | |
"\n", | |
"tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False, trust_remote_code=True)\n", | |
"\n", | |
"model.generation_config = GenerationConfig.from_pretrained(MODEL_PATH)\n", | |
"model.generation_config.temperature = TEMPERATURE\n", | |
"model.generation_config.top_p = TOP_P\n", | |
"model.generation_config.max_new_tokens = MAX_LENGTH\n", | |
"\n", | |
"\n", | |
"def postprocess(self, y):\n", | |
" if y is None:\n", | |
" return []\n", | |
" for i, (message, response) in enumerate(y):\n", | |
" y[i] = (\n", | |
" None if message is None else mdtex2html.convert((message)),\n", | |
" None if response is None else mdtex2html.convert(response),\n", | |
" )\n", | |
" return y\n", | |
"\n", | |
"\n", | |
"gr.Chatbot.postprocess = postprocess\n", | |
"\n", | |
"\n", | |
"def parse_text(text):\n", | |
" lines = text.split(\"\\n\")\n", | |
" lines = [line for line in lines if line != \"\"]\n", | |
" count = 0\n", | |
" for i, line in enumerate(lines):\n", | |
" if \"```\" in line:\n", | |
" count += 1\n", | |
" items = line.split('`')\n", | |
" if count % 2 == 1:\n", | |
" lines[i] = f'<pre><code class=\"language-{items[-1]}\">'\n", | |
" else:\n", | |
" lines[i] = f'<br></code></pre>'\n", | |
" else:\n", | |
" if i > 0:\n", | |
" if count % 2 == 1:\n", | |
" line = line.replace(\"`\", \"\\`\")\n", | |
" line = line.replace(\"<\", \"<\")\n", | |
" line = line.replace(\">\", \">\")\n", | |
" line = line.replace(\" \", \" \")\n", | |
" line = line.replace(\"*\", \"*\")\n", | |
" line = line.replace(\"_\", \"_\")\n", | |
" line = line.replace(\"-\", \"-\")\n", | |
" line = line.replace(\".\", \".\")\n", | |
" line = line.replace(\"!\", \"!\")\n", | |
" line = line.replace(\"(\", \"(\")\n", | |
" line = line.replace(\")\", \")\")\n", | |
" line = line.replace(\"$\", \"$\")\n", | |
" lines[i] = \"<br>\" + line\n", | |
" text = \"\".join(lines)\n", | |
" return text\n", | |
"\n", | |
"\n", | |
"def predict(input, chatbot, history):\n", | |
" chatbot.append((parse_text(input), \"\"))\n", | |
" # 只保留最后 6 条对话记录\n", | |
" history = history[-6:]\n", | |
" history.append({\"role\": \"user\", \"content\": parse_text(input)})\n", | |
" if STREAM:\n", | |
" for response in model.chat(tokenizer, history, stream=True):\n", | |
" chatbot[-1] = (parse_text(input), parse_text(response))\n", | |
" yield chatbot, history\n", | |
" history.append({\"role\": \"assistant\", \"content\": response})\n", | |
" else:\n", | |
" response = model.chat(tokenizer, history)\n", | |
" chatbot[-1] = (parse_text(input), parse_text(response))\n", | |
" yield chatbot, history\n", | |
"\n", | |
"\n", | |
"def reset_user_input():\n", | |
" return gr.update(value='')\n", | |
"\n", | |
"\n", | |
"def reset_state():\n", | |
" return [], []\n", | |
"\n", | |
"\n", | |
"with gr.Blocks() as demo:\n", | |
" gr.HTML(\"\"\"<h1 align=\"center\">Baichuan 13B Chat</h1>\"\"\")\n", | |
" chatbot = gr.Chatbot()\n", | |
" with gr.Row():\n", | |
" with gr.Column(scale=4):\n", | |
" user_input = gr.Textbox(show_label=False, placeholder=\"在此输入消息\", lines=4).style(container=False)\n", | |
" with gr.Column(scale=1):\n", | |
" submitBtn = gr.Button(\"Submit\", variant=\"primary\")\n", | |
" emptyBtn = gr.Button(\"重置会话\")\n", | |
" history = gr.State([])\n", | |
" submitBtn.click(predict, [user_input, chatbot, history], [chatbot, history], show_progress=True)\n", | |
" submitBtn.click(reset_user_input, [], [user_input])\n", | |
" emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)\n", | |
"\n", | |
"demo.queue().launch(share=True, inbrowser=True, server_name=\"0.0.0.0\", server_port=9876)" | |
], | |
"metadata": { | |
"id": "jJlSeiHxixF_" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
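,
    {
      "cell_type": "markdown",
      "source": [
        "Optional: a minimal sketch of a single chat turn without the web UI, assuming the cell above has already run so that `model` and `tokenizer` exist (`model.chat` is the custom method shipped with Baichuan-13B-Chat via `trust_remote_code=True`):"
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Minimal smoke test, assuming model and tokenizer are loaded above.\n",
        "messages = [{\"role\": \"user\", \"content\": \"Hello! Please introduce yourself.\"}]\n",
        "response = model.chat(tokenizer, messages)\n",
        "print(response)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    }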
  ]
}