@ShadowPower
Created July 14, 2023 07:04
Baichuan 13B for Google Colab
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"安装依赖:"
],
"metadata": {
"id": "MCYKztKnii28"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "h3OTlwGJimIc",
"collapsed": true,
"cellView": "form"
},
"outputs": [],
"source": [
"!pip install accelerate\n",
"!pip install bitsandbytes\n",
"!pip install colorama\n",
"!pip install cpm_kernels\n",
"!pip install sentencepiece\n",
"!pip install streamlit\n",
"!pip install transformers_stream_generator\n",
"!pip install gradio\n",
"!pip install mdtex2html"
]
},
{
"cell_type": "markdown",
"source": [
"聊天:"
],
"metadata": {
"id": "cSN-eQ8MilaU"
}
},
{
"cell_type": "code",
"source": [
"import gradio as gr\n",
"import mdtex2html\n",
"import os\n",
"import torch\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n",
"from transformers.generation.utils import GenerationConfig\n",
"\n",
"# Configuration\n",
"MODEL_PATH = 'sharpbai/Baichuan-13B-Chat'\n",
"\n",
"MAX_LENGTH = 2048\n",
"TOP_P = 0.85\n",
"TEMPERATURE = 0.05\n",
"STREAM = True\n",
"\n",
"\n",
"nf4_config = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_quant_type=\"nf4\",\n",
" bnb_4bit_use_double_quant=True,\n",
" bnb_4bit_compute_dtype=torch.bfloat16,\n",
")\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" MODEL_PATH,\n",
" trust_remote_code=True,\n",
" quantization_config=nf4_config,\n",
" device_map=\"auto\",\n",
")\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False, trust_remote_code=True)\n",
"\n",
"model.generation_config = GenerationConfig.from_pretrained(MODEL_PATH)\n",
"model.generation_config.temperature = TEMPERATURE\n",
"model.generation_config.top_p = TOP_P\n",
"model.generation_config.max_new_tokens = MAX_LENGTH\n",
"\n",
"\n",
"def postprocess(self, y):\n",
" if y is None:\n",
" return []\n",
" for i, (message, response) in enumerate(y):\n",
" y[i] = (\n",
" None if message is None else mdtex2html.convert((message)),\n",
" None if response is None else mdtex2html.convert(response),\n",
" )\n",
" return y\n",
"\n",
"\n",
"gr.Chatbot.postprocess = postprocess\n",
"\n",
"\n",
"def parse_text(text):\n",
" lines = text.split(\"\\n\")\n",
" lines = [line for line in lines if line != \"\"]\n",
" count = 0\n",
" for i, line in enumerate(lines):\n",
" if \"```\" in line:\n",
" count += 1\n",
" items = line.split('`')\n",
" if count % 2 == 1:\n",
" lines[i] = f'<pre><code class=\"language-{items[-1]}\">'\n",
" else:\n",
" lines[i] = f'<br></code></pre>'\n",
" else:\n",
" if i > 0:\n",
" if count % 2 == 1:\n",
" line = line.replace(\"`\", \"\\`\")\n",
" line = line.replace(\"<\", \"&lt;\")\n",
" line = line.replace(\">\", \"&gt;\")\n",
" line = line.replace(\" \", \"&nbsp;\")\n",
" line = line.replace(\"*\", \"&ast;\")\n",
" line = line.replace(\"_\", \"&lowbar;\")\n",
" line = line.replace(\"-\", \"&#45;\")\n",
" line = line.replace(\".\", \"&#46;\")\n",
" line = line.replace(\"!\", \"&#33;\")\n",
" line = line.replace(\"(\", \"&#40;\")\n",
" line = line.replace(\")\", \"&#41;\")\n",
" line = line.replace(\"$\", \"&#36;\")\n",
" lines[i] = \"<br>\" + line\n",
" text = \"\".join(lines)\n",
" return text\n",
"\n",
"\n",
"def predict(input, chatbot, history):\n",
" chatbot.append((parse_text(input), \"\"))\n",
" # 只保留最后 6 条对话记录\n",
" history = history[-6:]\n",
" history.append({\"role\": \"user\", \"content\": parse_text(input)})\n",
" if STREAM:\n",
" for response in model.chat(tokenizer, history, stream=True):\n",
" chatbot[-1] = (parse_text(input), parse_text(response))\n",
" yield chatbot, history\n",
" history.append({\"role\": \"assistant\", \"content\": response})\n",
" else:\n",
" response = model.chat(tokenizer, history)\n",
" chatbot[-1] = (parse_text(input), parse_text(response))\n",
" yield chatbot, history\n",
"\n",
"\n",
"def reset_user_input():\n",
" return gr.update(value='')\n",
"\n",
"\n",
"def reset_state():\n",
" return [], []\n",
"\n",
"\n",
"with gr.Blocks() as demo:\n",
" gr.HTML(\"\"\"<h1 align=\"center\">Baichuan 13B Chat</h1>\"\"\")\n",
" chatbot = gr.Chatbot()\n",
" with gr.Row():\n",
" with gr.Column(scale=4):\n",
" user_input = gr.Textbox(show_label=False, placeholder=\"在此输入消息\", lines=4).style(container=False)\n",
" with gr.Column(scale=1):\n",
" submitBtn = gr.Button(\"Submit\", variant=\"primary\")\n",
" emptyBtn = gr.Button(\"重置会话\")\n",
" history = gr.State([])\n",
" submitBtn.click(predict, [user_input, chatbot, history], [chatbot, history], show_progress=True)\n",
" submitBtn.click(reset_user_input, [], [user_input])\n",
" emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)\n",
"\n",
"demo.queue().launch(share=True, inbrowser=True, server_name=\"0.0.0.0\", server_port=9876)"
],
"metadata": {
"id": "jJlSeiHxixF_"
},
"execution_count": null,
"outputs": []
}
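,
{
"cell_type": "markdown",
"source": [
"Optional smoke test (a minimal sketch, assuming the cells above have run; `model.chat` is the chat API shipped in Baichuan's remote code):"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# One-turn chat without the Gradio UI, useful to verify the model loaded correctly.\n",
"# NOTE: assumes `model` and `tokenizer` from the previous cell are already in memory.\n",
"messages = [{\"role\": \"user\", \"content\": \"Hello! Please introduce yourself.\"}]\n",
"print(model.chat(tokenizer, messages))"
],
"metadata": {},
"execution_count": null,
"outputs": []
}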
]
}