Skip to content

Instantly share code, notes, and snippets.

@MarshalW
Created September 5, 2024 07:24
Show Gist options
  • Save MarshalW/15ded88349781c4e18be855a5388d0da to your computer and use it in GitHub Desktop.
My-KBS, 个人知识库问答系统
import streamlit as st
from streamlit.logger import get_logger
import os
from llama_index.core import Settings
from llama_index.llms.ollama import Ollama
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.core import SimpleDirectoryReader
from llama_index.core import VectorStoreIndex
from llama_index.core import StorageContext, load_index_from_storage
logger = get_logger(__name__)  # Streamlit-aware logger for this module

VERSION = "v0.2.0"  # app version shown in the sidebar

# Knowledge-base source documents (Markdown) and the persisted vector index.
kb_dir = "/data/知识库"
index_dir = "/index"

# Must run before any other Streamlit command on each rerun.
st.set_page_config(
    "My KBS",
)

# Ollama endpoints and model names, overridable via environment variables.
llm_base_url = os.getenv('LLM_BASE_URL', 'http://localhost:11434')
llm_model_name = os.getenv('LLM_MODEL_NAME', 'qwen2')
embedding_base_url = os.getenv(
    'EMBEDDING_BASE_URL', 'http://localhost:11434')
embedding_model_name = os.getenv(
    'EMBEDDING_MODEL_NAME', 'quentinz/bge-large-zh-v1.5')
# Configure llama_index global Settings: the chat LLM and the embedding
# model, both served from Ollama endpoints configured above.
Settings.llm = Ollama(
    base_url=llm_base_url,
    model=llm_model_name,
    is_chat_model=True,
    temperature=0.1,  # low temperature for factual, grounded answers
    request_timeout=60.0
)
Settings.embed_model = OllamaEmbedding(
    model_name=embedding_model_name,
    base_url=embedding_base_url,
    # mirostat N: Mirostat sampling setting (0 = disabled).
    ollama_additional_kwargs={"mirostat": 0},
)
# Load every Markdown file under the knowledge base. filename_as_id keeps
# document ids stable across runs so refresh_ref_docs can detect changes.
documents = SimpleDirectoryReader(
    input_dir=kb_dir,
    recursive=True,
    filename_as_id=True,
    required_exts=[".md"],
).load_data()

if not os.path.exists(index_dir) or not os.path.exists(f"{index_dir}/docstore.json"):
    # First run (or missing persisted docstore): build the index from scratch.
    index = VectorStoreIndex.from_documents(documents)
    index.storage_context.persist(persist_dir=index_dir)
else:
    # Reload the persisted index, then refresh only changed/new documents
    # and persist the updated state back to disk.
    storage_context = StorageContext.from_defaults(persist_dir=index_dir)
    index = load_index_from_storage(storage_context)
    index.refresh_ref_docs(documents)
    index.storage_context.persist(persist_dir=index_dir)
def _show_sources(source_nodes):
    """Render retrieved source nodes (score + text) inside an expander.

    Nodes whose similarity score is missing (``None``) or non-positive are
    hidden; a placeholder message is shown when nothing remains.

    Args:
        source_nodes: iterable of llama_index ``NodeWithScore`` objects.
    """
    with st.expander("搜索结果"):
        # NodeWithScore.score is Optional[float]; guard against None before
        # comparing — `None > 0` would raise TypeError.
        filtered_nodes = [
            node for node in source_nodes
            if node.score is not None and node.score > 0
        ]
        if len(filtered_nodes) == 0:
            st.write('未检索到有效结果')
        else:
            for i, node in enumerate(filtered_nodes):
                st.write(node.score)
                st.write(node.text)
                # Divider between entries, but not after the last one.
                if i < len(filtered_nodes) - 1:
                    st.divider()
def find_md_files(dir_path):
    """Recursively collect the paths of all Markdown (.md) files under *dir_path*."""
    return [
        os.path.join(parent, filename)
        for parent, _subdirs, filenames in os.walk(dir_path)
        for filename in filenames
        if filename.endswith('.md')
    ]
# Enumerate knowledge-base files once per rerun for the sidebar stats.
md_file_paths = find_md_files(kb_dir)

with st.sidebar:
    st.title('My KBS')
    st.caption(f'版本: {VERSION}')
    st.caption(f"知识库文件数: {len(md_file_paths)} 个")
    # Retrieval breadth: how many top-similarity chunks feed the LLM.
    # Stored in st.session_state.similarity_top_k via the key.
    st.number_input("相似度前k条:", 0, 20, 2,
                    key='similarity_top_k',
                    help="检索语义最接近的前k条")

# Initialize the chat history with a greeting on the session's first run.
if "messages" not in st.session_state.keys():
    st.session_state.messages = [
        {"role": "assistant",
         "content": "请问我问题吧。"}
    ]
# Show the chat input box; append the new user question (if any) to history.
if prompt := st.chat_input(placeholder="这里输入问题"):
    st.session_state.messages.append({"role": "user", "content": prompt})

# Replay the conversation so far, including any saved retrieval results.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])
        if "source_nodes" in message:
            _show_sources(message['source_nodes'])

# Generate an assistant answer whenever the last message is a user turn.
if st.session_state.messages[-1]["role"] != "assistant":
    # Query with the stored message rather than `prompt`: `prompt` is unbound
    # on a rerun that had no fresh chat input (e.g. a previous generation
    # failed after the user message was appended), which would raise
    # NameError at query time. When fresh input exists the two are equal.
    question = st.session_state.messages[-1]["content"]
    with st.chat_message("assistant"):
        query_engine = index.as_query_engine(
            streaming=True,
            similarity_top_k=st.session_state.similarity_top_k,
        )
        # Stream tokens to the UI as they arrive; write_stream returns the
        # full concatenated response text.
        stream = query_engine.query(question)
        response = st.write_stream(stream.response_gen)
        _show_sources(stream.source_nodes)
        message = {"role": "assistant", "content": response,
                   "source_nodes": stream.source_nodes}
        st.session_state.messages.append(message)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment