@epicfilemcnulty · Created June 9, 2023
local docs embeddings: a Bottle API server that answers questions over local documents via RetrievalQA, plus an ingest script that embeds a directory of documents into a Chroma vector store.
import argparse
import time

from bottle import Bottle, run, request
from chromadb.config import Settings
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.llms import HuggingFacePipeline, LlamaCpp
from langchain.vectorstores import Chroma
from transformers import LlamaTokenizer, LlamaForCausalLM, pipeline

parser = argparse.ArgumentParser()
parser.add_argument('-e', '--embed_model', required=False, type=str, default='/storage/models/instructor/instructor-large', help="Embedding model")
parser.add_argument('-m', '--model', required=False, type=str, default='/storage/models/LLaMA/Wizard-Vicuna-13B-Uncensored-HF', help="Grasping Model")
parser.add_argument('-a', '--model_name', required=False, type=str, default="WizVicUncen13.8bit", help="Grasping Model's Alias")
parser.add_argument('-l', '--length', required=False, type=int, default=2048, help="Max sequence length for embedding model")
parser.add_argument('-t', '--threads', required=False, type=int, default=16, help="Number of CPU threads")
parser.add_argument('-n', '--ngl', required=False, type=int, default=40, help="Number of layers to upload to GPU")
parser.add_argument('--port', default=8014, required=False, type=int, help="Port to listen on")
parser.add_argument('--ip', default='127.0.0.1', required=False, type=str, help="IP to listen on")
args = parser.parse_args()
app = Bottle()

def load_ggml():
    # Quantized GGML model, served through the llama.cpp bindings.
    return LlamaCpp(
        model_path=args.model,
        n_gpu_layers=args.ngl,
        n_batch=512,
        temperature=0,
        n_ctx=2048,
        n_threads=args.threads,
        verbose=False,
    )

def load_model():
    # HuggingFace checkpoint loaded in 8-bit, wrapped in a text-generation
    # pipeline that LangChain can drive.
    model_id = args.model
    tokenizer = LlamaTokenizer.from_pretrained(model_id)
    model = LlamaForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map='auto')
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=2048,
        temperature=0,
        top_p=0.95,
        repetition_penalty=1.15,
    )
    return HuggingFacePipeline(pipeline=pipe)

embeddings = HuggingFaceInstructEmbeddings(model_name=args.embed_model, model_kwargs={"device": 'cuda'})
embeddings.client.max_seq_length = args.length

# Pick the loader by model path: GGML files go through llama.cpp,
# anything else is treated as a HuggingFace transformers checkpoint.
if 'ggml' in args.model:
    llm = load_ggml()
else:
    llm = load_model()

@app.route('/docs', method='POST')
def chat():
    data = request.json
    embed_db_dir = data['embeddings']
    query = data['query']
    chroma_settings = Settings(
        chroma_db_impl='duckdb+parquet',
        persist_directory=embed_db_dir,
        anonymized_telemetry=False,
    )
    db = Chroma(persist_directory=embed_db_dir, embedding_function=embeddings, client_settings=chroma_settings)
    retriever = db.as_retriever()
    # There is also ConversationalRetrievalChain for multi-turn dialogue
    # (see the sketch after this handler).
    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
    start_time = time.time_ns()
    res = qa(query)
    answer, docs = res['result'], res['source_documents']
    end_time = time.time_ns()
    secs = (end_time - start_time) / 1e9
    # Crude token estimate: ~1.2 tokens per word, doubled (presumably to
    # account for prompt processing as well as generation).
    tokens = len(answer.split()) * 1.2 * 2
    # Deduplicate source paths across the retrieved chunks.
    s_docs = []
    for document in docs:
        if document.metadata["source"] not in s_docs:
            s_docs.append(document.metadata["source"])
            # print(document.page_content)
    return {
        "text": answer,
        "sources": s_docs,
        "tokens": tokens,
        "rate": tokens / secs,
        "model": args.model_name,
    }

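
# A minimal sketch of swapping in ConversationalRetrievalChain instead of
# RetrievalQA (assumes the same llm/retriever as above; chat history
# management is left to the caller):
#   from langchain.chains import ConversationalRetrievalChain
#   qa = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever)
#   res = qa({"question": query, "chat_history": []})
#   answer = res["answer"]
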
run(app, host=args.ip, port=args.port)
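
# Example request (illustrative values; the Chroma DB at "embeddings" must
# already exist, e.g. built with the ingest script below):
#   curl -X POST http://127.0.0.1:8014/docs \
#        -H 'Content-Type: application/json' \
#        -d '{"embeddings": "/path/to/chroma_db", "query": "What does chapter 2 cover?"}'
# The response is a JSON object with keys "text", "sources", "tokens",
# "rate", and "model".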

# --- Ingest script: embed a directory of documents into a Chroma vector DB ---

import argparse
import os
from typing import List

from chromadb.config import Settings
from langchain.docstore.document import Document
from langchain.document_loaders import TextLoader, PyMuPDFLoader, CSVLoader, UnstructuredEPubLoader
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

parser = argparse.ArgumentParser()
parser.add_argument('-e', '--embed_model', required=False, type=str, default='/storage/models/instructor/instructor-large', help="Embedding model")
parser.add_argument('-s', '--source_dir', required=True, type=str, help="Path to the directory with documents to embed")
parser.add_argument('-d', '--db_dir', required=True, type=str, help="Path to the directory where chromadb files will be stored")
parser.add_argument('-l', '--length', required=False, type=int, default=2048, help="Max sequence length for embedding model")
args = parser.parse_args()

CHROMA_SETTINGS = Settings(
    chroma_db_impl='duckdb+parquet',
    persist_directory=args.db_dir,
    anonymized_telemetry=False,
)

def load_single_document(file_path: str) -> Document:
    # Pick a loader by file extension and return the first loaded document.
    if file_path.endswith(".txt"):
        loader = TextLoader(file_path, encoding="utf8")
    elif file_path.endswith(".pdf"):
        loader = PyMuPDFLoader(file_path)
    elif file_path.endswith(".csv"):
        loader = CSVLoader(file_path)
    elif file_path.endswith(".epub"):
        loader = UnstructuredEPubLoader(file_path)
    else:
        # Guard against files load_documents() did not filter out; without
        # this branch, 'loader' would be unbound here.
        raise ValueError(f"Unsupported file type: {file_path}")
    return loader.load()[0]

def load_documents(source_dir: str) -> List[Document]:
    # Recursively collect all supported documents under source_dir.
    docs = []
    for dirpath, dirnames, filenames in os.walk(source_dir):
        for file_name in filenames:
            if file_name.endswith(('.txt', '.pdf', '.csv', '.epub')):
                full_file_path = os.path.join(dirpath, file_name)
                docs.append(load_single_document(full_file_path))
    return docs

def main():
    device = 'cuda'
    print(f"Loading documents from {args.source_dir}")
    documents = load_documents(args.source_dir)
    # Split into overlapping chunks sized for the embedding model.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(documents)
    print(f"Loaded {len(documents)} documents from {args.source_dir}")
    print(f"Split into {len(texts)} chunks of text")
    embeddings = HuggingFaceInstructEmbeddings(model_name=args.embed_model, model_kwargs={"device": device})
    embeddings.client.max_seq_length = args.length
    db = Chroma.from_documents(texts, embeddings, persist_directory=args.db_dir, client_settings=CHROMA_SETTINGS)
    db.persist()
    db = None  # drop the reference so the store is flushed to disk


if __name__ == "__main__":
    main()
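
# Example invocation (paths and the script name are illustrative):
#   python ingest.py -s /path/to/docs -d /path/to/chroma_db
# Pass the same --embed_model and --length to the server script so queries
# are embedded the same way as the ingested documents.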