RAG
# Cell 1: Install dependencies
!pip install -q -U transformers llama-index accelerate pypdf einops bitsandbytes
!pip install -q llama-index-llms-huggingface
!pip install -q llama-index-embeddings-huggingface
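# Optional: confirm a GPU runtime is attached before downloading the 7B model
# (assumption: this runs on Colab or another machine with an NVIDIA GPU).
!nvidia-smi -L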
# Cell 2: Import libraries and set up warnings
import warnings
warnings.filterwarnings('ignore')

import torch
from transformers import BitsAndBytesConfig

# Settings replaces the deprecated (and now removed) ServiceContext in current llama-index releases
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.core.prompts import PromptTemplate
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.response.notebook_utils import display_source_node
# Cell 3: Set up prompts and log in to Hugging Face
system_prompt = """<|SYSTEM|>
You are a helpful, respectful and honest assistant. Always consider the chat history when answering questions.
"""

query_wrapper_prompt = PromptTemplate("<|USER|>{query_str}<|ASSISTANT|>")

from huggingface_hub import login
login(token='your_huggingface_token_here')  # replace with your own token; needed for the gated Llama-2 weights
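# Optional sanity check: fetching the model card raises an authorization error
# if the token is invalid or access to the gated Llama-2 repo has not been granted.
from huggingface_hub import model_info
model_info("meta-llama/Llama-2-7b-chat-hf")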
# Cell 4: Load documents
documents = SimpleDirectoryReader("/content/data").load_data()
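# Optional: confirm the reader actually picked up files from /content/data
# (assumes the PDFs or text files have already been uploaded there).
print(f"Loaded {len(documents)} document objects")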
# Cell 5: Set up embedding model
embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-mpnet-base-v2")
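# Optional: quick embedding check; all-mpnet-base-v2 produces 768-dimensional vectors.
sample_embedding = embed_model.get_text_embedding("hello world")
print(len(sample_embedding))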
# Cell 6: Set up LLM (Llama-2-7B-chat, loaded in 4-bit via bitsandbytes)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    llm_int8_enable_fp32_cpu_offload=True,
)

llm = HuggingFaceLLM(
    context_window=4096,
    max_new_tokens=256,
    generate_kwargs={"do_sample": False},  # greedy decoding
    system_prompt=system_prompt,
    query_wrapper_prompt=query_wrapper_prompt,
    tokenizer_name="meta-llama/Llama-2-7b-chat-hf",
    model_name="meta-llama/Llama-2-7b-chat-hf",
    device_map="auto",
    tokenizer_kwargs={"max_length": 4096},
    model_kwargs={
        "torch_dtype": torch.float16,
        "quantization_config": quantization_config,
    },
)
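# Optional smoke test of the quantized model before building the index
# (expect a short greedy completion; the exact wording will vary).
print(llm.complete("Say hello in one sentence."))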
# Cell 7: Configure global settings and build the index
# (Settings replaces the removed ServiceContext API)
Settings.llm = llm
Settings.embed_model = embed_model
Settings.chunk_size = 2048
Settings.chunk_overlap = 50

index = VectorStoreIndex.from_documents(documents)
# Cell 8: Set up chat memory and a conversational chat engine
# (memory only applies to a chat engine; a plain query engine is stateless
# and does not use chat memory)
memory = ChatMemoryBuffer.from_defaults(token_limit=1500)
chat_engine = index.as_chat_engine(
    chat_mode="context",
    memory=memory,
    similarity_top_k=2,
)
# Cell 9: Function to handle one conversational turn
def chat_with_rag(query):
    response_stream = chat_engine.stream_chat(query)
    response_stream.print_response_stream()
    print("\nSources:")
    for node in response_stream.source_nodes:
        print(f"- {node.node.get_content()[:100]}...")
# Cell 10: Main conversation loop
print("Welcome to the Conversational RAG system. Type 'exit' to end the conversation.")
while True:
    user_input = input("You: ")
    if user_input.lower() == 'exit':
        print("Thank you for using the Conversational RAG system. Goodbye!")
        break
    chat_with_rag(user_input)
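# Optional: clear the conversation memory to start a fresh session
# without rebuilding the index.
chat_engine.reset()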