langchain-crag-financial-assistant.ipynb
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "authorship_tag": "ABX9TyOk31xq9//kMYrJkaoArNWL",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/gist/virattt/98fb3aa85603211bad87cc19b79fcfc5/langchain-crag-financial-assistant.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aSdU1s56Pouh"
      },
      "outputs": [],
      "source": [
        "! pip install langchain_community tiktoken langchain-openai langchainhub chromadb langchain langgraph tavily-python pypdf"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import getpass\n",
        "import os\n",
        "\n",
        "# Set your OpenAI API key\n",
        "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()"
      ],
      "metadata": {
        "id": "zSJ6772gP0sJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Set your Tavily API key\n",
        "os.environ[\"TAVILY_API_KEY\"] = getpass.getpass()"
      ],
      "metadata": {
        "id": "ppNjc7LXP3F9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Download and prepare SEC filing"
      ],
      "metadata": {
        "id": "sz639zFf6JoK"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from langchain.document_loaders import PyPDFLoader\n",
        "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
        "\n",
        "# Load $ABNB's financial report. This may take 1-2 minutes since the PDF is large\n",
        "sec_filing_pdf = \"https://d18rn0p25nwr6d.cloudfront.net/CIK-0001559720/8a9ebed0-815a-469a-87eb-1767d21d8cec.pdf\"\n",
        "\n",
        "# Create your PDF loader\n",
        "loader = PyPDFLoader(sec_filing_pdf)\n",
        "\n",
        "# Load the PDF document\n",
        "documents = loader.load()\n",
        "\n",
        "# Chunk the financial report\n",
        "text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=0)\n",
        "docs = text_splitter.split_documents(documents)"
      ],
      "metadata": {
        "id": "rIO5t-j7611h"
      },
      "execution_count": null,
      "outputs": []
    },
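    {
      "cell_type": "code",
      "source": [
        "# Optional sanity check (not part of the original walkthrough): confirm the\n",
        "# loader and splitter produced sensible output before paying to embed it.\n",
        "print(f\"Loaded {len(documents)} pages, split into {len(docs)} chunks\")\n",
        "print(docs[0].page_content[:200])"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },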
    {
      "cell_type": "markdown",
      "source": [
        "# Load the SEC filing into vector store"
      ],
      "metadata": {
        "id": "iaYSqxiMLUGb"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from langchain_community.vectorstores import Chroma\n",
"from langchain.embeddings.openai import OpenAIEmbeddings\n", | |
"\n", | |
"# Load the document into Chroma\n", | |
"embedding_function = OpenAIEmbeddings()\n", | |
"vectorstore = Chroma.from_documents(docs, embedding_function)\n", | |
"\n", | |
"retriever = vectorstore.as_retriever()" | |
], | |
"metadata": { | |
"id": "QVZevdc-Md4N" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
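    {
      "cell_type": "code",
      "source": [
        "# Optional sanity check (not part of the original walkthrough): run one\n",
        "# similarity search to verify the retriever returns filing text. The query\n",
        "# string here is just an illustrative example.\n",
        "sample = retriever.get_relevant_documents(\"Airbnb revenue\")\n",
        "print(f\"Retrieved {len(sample)} chunks; first chunk starts with:\")\n",
        "print(sample[0].page_content[:200])"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },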
    {
      "cell_type": "markdown",
      "source": [
        "# Define graph State"
      ],
      "metadata": {
        "id": "4m0Vu0YhRgAc"
      }
    },
    {
      "cell_type": "code",
      "source": [
"from typing import Dict, TypedDict\n", | |
"\n", | |
"from langchain_core.messages import BaseMessage\n", | |
"\n", | |
"\n", | |
"class GraphState(TypedDict):\n", | |
" \"\"\"\n", | |
" Represents the state of our graph.\n", | |
"\n", | |
" Attributes:\n", | |
" keys: A dictionary where each key is a string.\n", | |
" \"\"\"\n", | |
"\n", | |
" keys: Dict[str, any]" | |
      ],
      "metadata": {
        "id": "ztIym-lyRhgc"
      },
      "execution_count": null,
      "outputs": []
    },
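    {
      "cell_type": "code",
      "source": [
        "# Illustrative only: every node below receives and returns a dict shaped like\n",
        "# GraphState, so intermediate values (question, documents, generation) all\n",
        "# travel under the single \"keys\" entry.\n",
        "example_state: GraphState = {\"keys\": {\"question\": \"What was revenue in Q3?\"}}\n",
        "print(example_state[\"keys\"][\"question\"])"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },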
    {
      "cell_type": "markdown",
      "source": [
        "# Define the graph's Nodes and Edges"
      ],
      "metadata": {
        "id": "opHvs3lrRmF7"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import json\n",
        "import operator\n",
        "from typing import Annotated, Sequence, TypedDict\n",
        "\n",
        "from langchain import hub\n",
        "from langchain.output_parsers.openai_tools import PydanticToolsParser\n",
        "from langchain.prompts import PromptTemplate\n",
        "from langchain.schema import Document\n",
        "from langchain_community.tools.tavily_search import TavilySearchResults\n",
        "from langchain_community.vectorstores import Chroma\n",
        "from langchain_core.messages import BaseMessage, FunctionMessage\n",
        "from langchain_core.output_parsers import StrOutputParser\n",
        "from langchain_core.pydantic_v1 import BaseModel, Field\n",
        "from langchain_core.runnables import RunnablePassthrough\n",
        "from langchain_core.utils.function_calling import convert_to_openai_tool\n",
        "from langchain_openai import ChatOpenAI, OpenAIEmbeddings\n",
        "\n",
        "### Nodes ###\n",
        "\n",
        "\n",
        "def retrieve(state):\n",
        "    \"\"\"\n",
        "    Retrieve documents\n",
        "\n",
        "    Args:\n",
        "        state (dict): The current graph state\n",
        "\n",
        "    Returns:\n",
        "        state (dict): New key added to state, documents, that contains retrieved documents\n",
        "    \"\"\"\n",
        "    print(\"---RETRIEVE---\")\n",
        "    state_dict = state[\"keys\"]\n",
        "    question = state_dict[\"question\"]\n",
        "    documents = retriever.get_relevant_documents(question)\n",
        "    return {\"keys\": {\"documents\": documents, \"question\": question}}\n",
        "\n",
        "\n",
        "def generate(state):\n",
        "    \"\"\"\n",
        "    Generate answer\n",
        "\n",
        "    Args:\n",
        "        state (dict): The current graph state\n",
        "\n",
        "    Returns:\n",
        "        state (dict): New key added to state, generation, that contains LLM generation\n",
        "    \"\"\"\n",
        "    print(\"---GENERATE---\")\n",
        "    state_dict = state[\"keys\"]\n",
        "    question = state_dict[\"question\"]\n",
        "    documents = state_dict[\"documents\"]\n",
        "\n",
        "    # Prompt\n",
        "    prompt = hub.pull(\"rlm/rag-prompt\")\n",
        "\n",
        "    # LLM\n",
        "    llm = ChatOpenAI(model_name=\"gpt-3.5-turbo\", temperature=0, streaming=True)\n",
        "\n",
        "    # Post-processing\n",
        "    def format_docs(docs):\n",
        "        return \"\\n\\n\".join(doc.page_content for doc in docs)\n",
        "\n",
        "    # Chain\n",
        "    rag_chain = prompt | llm | StrOutputParser()\n",
        "\n",
        "    # Run\n",
        "    generation = rag_chain.invoke({\"context\": documents, \"question\": question})\n",
        "    return {\n",
        "        \"keys\": {\"documents\": documents, \"question\": question, \"generation\": generation}\n",
        "    }\n",
        "\n",
        "\n",
        "def grade_documents(state):\n",
        "    \"\"\"\n",
        "    Determines whether the retrieved documents are relevant to the question.\n",
        "\n",
        "    Args:\n",
        "        state (dict): The current graph state\n",
        "\n",
        "    Returns:\n",
        "        state (dict): Updates documents key with relevant documents\n",
        "    \"\"\"\n",
        "\n",
        "    print(\"---CHECK RELEVANCE---\")\n",
        "    state_dict = state[\"keys\"]\n",
        "    question = state_dict[\"question\"]\n",
        "    documents = state_dict[\"documents\"]\n",
        "\n",
        "    # Data model\n",
        "    class grade(BaseModel):\n",
        "        \"\"\"Binary score for relevance check.\"\"\"\n",
        "\n",
        "        binary_score: str = Field(description=\"Relevance score 'yes' or 'no'\")\n",
        "\n",
        "    # LLM\n",
        "    model = ChatOpenAI(temperature=0, model=\"gpt-4-0125-preview\", streaming=True)\n",
        "\n",
" # Tool\n", | |
" grade_tool_oai = convert_to_openai_tool(grade)\n", | |
"\n", | |
" # LLM with tool and enforce invocation\n", | |
" llm_with_tool = model.bind(\n", | |
" tools=[convert_to_openai_tool(grade_tool_oai)],\n", | |
" tool_choice={\"type\": \"function\", \"function\": {\"name\": \"grade\"}},\n", | |
" )\n", | |
"\n", | |
" # Parser\n", | |
" parser_tool = PydanticToolsParser(tools=[grade])\n", | |
"\n", | |
" # Prompt\n", | |
" prompt = PromptTemplate(\n", | |
" template=\"\"\"You are a grader assessing relevance of a retrieved document to a user question. \\n\n", | |
" Here is the retrieved document: \\n\\n {context} \\n\\n\n", | |
" Here is the user question: {question} \\n\n", | |
" If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant. \\n\n", | |
" Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.\"\"\",\n", | |
" input_variables=[\"context\", \"question\"],\n", | |
" )\n", | |
"\n", | |
" # Chain\n", | |
" chain = prompt | llm_with_tool | parser_tool\n", | |
"\n", | |
" # Score\n", | |
" filtered_docs = []\n", | |
" search = \"No\" # Default do not opt for web search to supplement retrieval\n", | |
" for d in documents:\n", | |
" score = chain.invoke({\"question\": question, \"context\": d.page_content})\n", | |
" grade = score[0].binary_score\n", | |
" if grade == \"yes\":\n", | |
" print(\"---GRADE: DOCUMENT RELEVANT---\")\n", | |
" filtered_docs.append(d)\n", | |
" else:\n", | |
" print(\"---GRADE: DOCUMENT NOT RELEVANT---\")\n", | |
" search = \"Yes\" # Perform web search\n", | |
" continue\n", | |
"\n", | |
" return {\n", | |
" \"keys\": {\n", | |
" \"documents\": filtered_docs,\n", | |
" \"question\": question,\n", | |
" \"run_web_search\": search,\n", | |
" }\n", | |
" }\n", | |
"\n", | |
"\n", | |
"def transform_query(state):\n", | |
" \"\"\"\n", | |
" Transform the query to produce a better question.\n", | |
"\n", | |
" Args:\n", | |
" state (dict): The current graph state\n", | |
"\n", | |
" Returns:\n", | |
" state (dict): Updates question key with a re-phrased question\n", | |
" \"\"\"\n", | |
"\n", | |
" print(\"---TRANSFORM QUERY---\")\n", | |
" state_dict = state[\"keys\"]\n", | |
" question = state_dict[\"question\"]\n", | |
" documents = state_dict[\"documents\"]\n", | |
"\n", | |
" # Create a prompt template with format instructions and the query\n", | |
" prompt = PromptTemplate(\n", | |
" template=\"\"\"You are generating questions that is well optimized for retrieval. \\n\n", | |
" Look at the input and try to reason about the underlying sematic intent / meaning. \\n\n", | |
" Here is the initial question:\n", | |
" \\n ------- \\n\n", | |
" {question}\n", | |
" \\n ------- \\n\n", | |
" Formulate an improved question: \"\"\",\n", | |
" input_variables=[\"question\"],\n", | |
" )\n", | |
"\n", | |
" # Grader\n", | |
" model = ChatOpenAI(temperature=0, model=\"gpt-4-0125-preview\", streaming=True)\n", | |
"\n", | |
" # Prompt\n", | |
" chain = prompt | model | StrOutputParser()\n", | |
" better_question = chain.invoke({\"question\": question})\n", | |
"\n", | |
" return {\"keys\": {\"documents\": documents, \"question\": better_question}}\n", | |
"\n", | |
"\n", | |
"def web_search(state):\n", | |
" \"\"\"\n", | |
" Web search based on the re-phrased question using Tavily API.\n", | |
"\n", | |
" Args:\n", | |
" state (dict): The current graph state\n", | |
"\n", | |
" Returns:\n", | |
" state (dict): Updates documents key with appended web results\n", | |
" \"\"\"\n", | |
"\n", | |
" print(\"---WEB SEARCH---\")\n", | |
" state_dict = state[\"keys\"]\n", | |
" question = state_dict[\"question\"]\n", | |
" documents = state_dict[\"documents\"]\n", | |
"\n", | |
" tool = TavilySearchResults()\n", | |
" docs = tool.invoke({\"query\": question})\n", | |
" web_results = \"\\n\".join([d[\"content\"] for d in docs])\n", | |
" web_results = Document(page_content=web_results)\n", | |
" documents.append(web_results)\n", | |
"\n", | |
" return {\"keys\": {\"documents\": documents, \"question\": question}}\n", | |
"\n", | |
"\n", | |
"### Edges\n", | |
"\n", | |
"\n", | |
"def decide_to_generate(state):\n", | |
" \"\"\"\n", | |
" Determines whether to generate an answer or re-generate a question for web search.\n", | |
"\n", | |
" Args:\n", | |
" state (dict): The current state of the agent, including all keys.\n", | |
"\n", | |
" Returns:\n", | |
" str: Next node to call\n", | |
" \"\"\"\n", | |
"\n", | |
" print(\"---DECIDE TO GENERATE---\")\n", | |
" state_dict = state[\"keys\"]\n", | |
" question = state_dict[\"question\"]\n", | |
" filtered_documents = state_dict[\"documents\"]\n", | |
" search = state_dict[\"run_web_search\"]\n", | |
"\n", | |
" if search == \"Yes\":\n", | |
" # All documents have been filtered check_relevance\n", | |
" # We will re-generate a new query\n", | |
" print(\"---DECISION: TRANSFORM QUERY and RUN WEB SEARCH---\")\n", | |
" return \"transform_query\"\n", | |
" else:\n", | |
" # We have relevant documents, so generate answer\n", | |
" print(\"---DECISION: GENERATE---\")\n", | |
" return \"generate\"" | |
], | |
"metadata": { | |
"id": "mdVk88QkRoxl" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
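    {
      "cell_type": "code",
      "source": [
        "# Optional smoke test (not part of the original walkthrough): the nodes are\n",
        "# plain functions over state, so they can be exercised in isolation before\n",
        "# the graph is wired up. This runs a real retrieval against the vector store.\n",
        "test_state = retrieve({\"keys\": {\"question\": \"What was Airbnb's revenue in Q3 2023?\"}})\n",
        "print(len(test_state[\"keys\"][\"documents\"]), \"documents retrieved\")"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },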
    {
      "cell_type": "markdown",
      "source": [
        "# Build the graph"
      ],
      "metadata": {
        "id": "3vNh_m4bRq4r"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import pprint\n",
        "\n",
        "from langgraph.graph import END, StateGraph\n",
        "\n",
        "workflow = StateGraph(GraphState)\n",
        "\n",
        "# Define the nodes\n",
        "workflow.add_node(\"retrieve\", retrieve)  # retrieve\n",
        "workflow.add_node(\"grade_documents\", grade_documents)  # grade documents\n",
"workflow.add_node(\"generate\", generate) # generatae\n", | |
"workflow.add_node(\"transform_query\", transform_query) # transform_query\n", | |
"workflow.add_node(\"web_search\", web_search) # web search\n", | |
"\n", | |
"# Build graph\n", | |
"workflow.set_entry_point(\"retrieve\")\n", | |
"workflow.add_edge(\"retrieve\", \"grade_documents\")\n", | |
"workflow.add_conditional_edges(\n", | |
" \"grade_documents\",\n", | |
" decide_to_generate,\n", | |
" {\n", | |
" \"transform_query\": \"transform_query\",\n", | |
" \"generate\": \"generate\",\n", | |
" },\n", | |
")\n", | |
"workflow.add_edge(\"transform_query\", \"web_search\")\n", | |
"workflow.add_edge(\"web_search\", \"generate\")\n", | |
"workflow.add_edge(\"generate\", END)\n", | |
"\n", | |
"# Compile\n", | |
"app = workflow.compile()" | |
], | |
"metadata": { | |
"id": "du5RScfyRr48" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
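    {
      "cell_type": "code",
      "source": [
        "# Optional (illustrative): list the registered nodes to confirm the wiring.\n",
        "# This assumes StateGraph exposes its nodes as a plain dict, as it does in\n",
        "# the langgraph versions this notebook was written against.\n",
        "print(list(workflow.nodes))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },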
    {
      "cell_type": "markdown",
      "source": [
        "# Run the graph"
      ],
      "metadata": {
        "id": "KJC48i6sRvsV"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Run\n",
        "question = \"What was Airbnb's revenue in Q3 2023?\"\n",
        "inputs = {\"keys\": {\"question\": question}}\n",
        "print(f\"Question: {question}\\n\")\n",
"for output in app.stream(inputs):\n", | |
" for key, value in output.items():\n", | |
" # Print Node\n", | |
" print()\n", | |
"\n", | |
"# Final generation\n", | |
"answer = value['keys']['generation']\n", | |
"print(f\"Answer: {answer}\")" | |
      ],
      "metadata": {
        "id": "t0dCdGbURwcT"
      },
      "execution_count": null,
      "outputs": []
    }
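    ,
    {
      "cell_type": "code",
      "source": [
        "# Alternative to streaming (illustrative): invoke the compiled graph once and\n",
        "# read the final state directly, without iterating over node-by-node output.\n",
        "final_state = app.invoke({\"keys\": {\"question\": \"What was Airbnb's revenue in Q3 2023?\"}})\n",
        "print(final_state[\"keys\"][\"generation\"])"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    }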
  ]
} |