baichuan-13b quantized
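This Streamlit script serves Baichuan-13B-Chat through a simple chat UI while loading the weights in 4-bit NF4 via bitsandbytes. At roughly half a byte per parameter, the quantized weights take on the order of 7 GB of VRAM instead of the ~26 GB a float16 copy of a 13B model would need, so the demo can fit on a single consumer GPU.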
import json

import torch
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.generation.utils import GenerationConfig

st.set_page_config(page_title="Baichuan-13B-Chat")
st.title("Baichuan-13B-Chat")

model_path = "baichuan-inc/Baichuan-13B-Chat"


@st.cache_resource
def init_model():
    # 4-bit NF4 quantization with nested (double) quantization;
    # matmuls are computed in bfloat16.
    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=nf4_config,
        device_map="auto",
        trust_remote_code=True,
    ).eval()
    model.generation_config = GenerationConfig.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        use_fast=False,
        trust_remote_code=True,
    )
    return model, tokenizer


def clear_chat_history():
    del st.session_state.messages


def init_chat_history():
    with st.chat_message("assistant", avatar='🤖'):
        # "Hello, I am the Baichuan large model, happy to be of service"
        st.markdown("您好,我是百川大模型,很高兴为您服务🥰")

    if "messages" in st.session_state:
        # Replay the conversation so far on every Streamlit rerun.
        for message in st.session_state.messages:
            avatar = '🧑‍💻' if message["role"] == "user" else '🤖'
            with st.chat_message(message["role"], avatar=avatar):
                st.markdown(message["content"])
    else:
        st.session_state.messages = []

    return st.session_state.messages


def main():
    model, tokenizer = init_model()
    messages = init_chat_history()

    # Input placeholder: "Hello, Baichuan"
    if prompt := st.chat_input("你好,百川"):
        with st.chat_message("user", avatar='🧑‍💻'):
            st.markdown(prompt)
        messages.append({"role": "user", "content": prompt})
        print(f"[user] {prompt}", flush=True)

        with st.chat_message("assistant", avatar='🤖'):
            placeholder = st.empty()
            # model.chat with stream=True yields the partial response as it
            # is generated; overwrite the placeholder on each step.
            for response in model.chat(tokenizer, messages, stream=True):
                placeholder.markdown(response)
        messages.append({"role": "assistant", "content": response})
        print(json.dumps(messages, ensure_ascii=False), flush=True)

        # Button label: "Clear conversation"
        st.button("清空对话", on_click=clear_chat_history)


if __name__ == "__main__":
    main()
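To try it locally (a sketch, assuming the script is saved as app.py): install streamlit, torch, transformers, bitsandbytes, and accelerate (the last two are needed for load_in_4bit and device_map="auto"), then run "streamlit run app.py" and open the URL Streamlit prints. Note that the first launch downloads the full-precision checkpoint from the Hugging Face Hub, which is on the order of 26 GB; quantization happens at load time.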