from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader, DirectoryLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
DATA_PATH="data/"
DB_FAISS_PATH = "vectorstores/db_faiss/"
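# Ingestion step: load every PDF under DATA_PATH, split it into overlapping
# chunks, embed the chunks, and persist a FAISS index to DB_FAISS_PATH.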
def create_vector_db():
    loader = DirectoryLoader(DATA_PATH, glob="*.pdf", loader_cls=PyPDFLoader)
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    texts = text_splitter.split_documents(documents)
    embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2',
                                       model_kwargs={'device': 'cpu'})
    db = FAISS.from_documents(texts, embeddings)
    db.save_local(DB_FAISS_PATH)
if __name__ == "__main__":
    create_vector_db()
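# Quick sanity check of the saved index (a sketch, assuming the ingestion
# script above has already been run once; `db` is loaded the same way the
# chainlit app below loads it):
#
#   embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
#   db = FAISS.load_local(DB_FAISS_PATH, embeddings)
#   print(db.similarity_search("charge card", k=2))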
## Below: the chainlit app (kept in the same gist, hence the repeated imports)
from langchain import PromptTemplate
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import CTransformers
from langchain.chains import RetrievalQA
import chainlit as cl
DB_FAISS_PATH = "vectorstores/db_faiss/"
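# The prompt pins the model to the retrieved context and tells it to admit
# when it doesn't know, which discourages made-up answers.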
custom_prompt_template = '''Use the following pieces of information to answer the user's question.
If you don't know the answer, just say that you don't know; don't make up an answer.

Context: {context}
Question: {question}

Only return the helpful answer below and nothing else.
Helpful answer:
'''
def set_custom_prompt():
    '''
    Prompt template for QA retrieval for each vector store
    '''
    prompt = PromptTemplate(template=custom_prompt_template,
                            input_variables=['context', 'question'])
    return prompt
def load_llm():
    llm = CTransformers(
        model='llama-2-7b-chat.ggmlv3.q8_0.bin',
        model_type='llama',
        max_new_tokens=512,
        temperature=0.5
    )
    return llm
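# Note: the GGML weights are not downloaded automatically. The file named
# above must already be on disk, e.g. the q8_0 quantization from the
# TheBloke/Llama-2-7B-Chat-GGML repo on Hugging Face.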
def retrieval_qa_chain(llm, prompt, db):
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={'k': 2}),
        return_source_documents=True,
        chain_type_kwargs={'prompt': prompt}
    )
    return qa_chain
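# chain_type="stuff" pastes the k=2 retrieved chunks verbatim into the
# prompt's {context} slot; keeping k small keeps the prompt short for the
# quantized 7B model.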
def qa_bot():
    embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2',
                                       model_kwargs={'device': 'cpu'})
    db = FAISS.load_local(DB_FAISS_PATH, embeddings)
    llm = load_llm()
    qa_prompt = set_custom_prompt()
    qa = retrieval_qa_chain(llm, qa_prompt, db)
    return qa
def final_result(query):
    qa_result = qa_bot()
    response = qa_result({'query': query})
    return response
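# Smoke test without the chainlit UI (hypothetical usage):
#
#   response = final_result("What are the characteristics of a charge card?")
#   print(response["result"])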
## chainlit here
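# on_chat_start builds the chain once per session and stashes it in the user
# session; on_message pulls it back out, so the FAISS index and LLM are not
# reloaded on every message.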
@cl.on_chat_start
async def start():
    chain = qa_bot()
    msg = cl.Message(content="Firing up the Shariah bot...")
    await msg.send()
    msg.content = "Hi, welcome to the Shariah bot. What is your query?"
    await msg.update()
    cl.user_session.set("chain", chain)
@cl.on_message
async def main(message):
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler(
        stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
    )
    cb.answer_reached = True
    res = await chain.acall(message, callbacks=[cb])
    answer = res["result"]
    sources = res["source_documents"]
    if sources:
        answer += "\nSources: " + str(sources)
    else:
        answer += "\nNo sources found"
    await cl.Message(content=answer).send()
# Sample query: what are the characteristics of a charge card?