{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "140580ef-5db0-43cc-a524-9c39e04d4df0",
"metadata": {},
"outputs": [],
"source": [
"! pip install langchain \"unstructured[all-docs]\" pydantic lxml openai chromadb tiktoken"
]
},
{
"cell_type": "markdown",
"id": "74b56bde-1ba0-4525-a11d-cab02c5659e4",
"metadata": {},
"source": [
"## Data Loading\n",
"\n",
"### Partition PDF tables, text, and images\n",
"\n",
"* Use [Unstructured](https://unstructured-io.github.io/unstructured/) to partition the PDF into elements"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e98bdeb7-eb77-42e6-a3a5-c3f27a1838d5",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"from unstructured.partition.pdf import partition_pdf\n",
"import pytesseract\n",
"\n",
"# Point pytesseract at the Tesseract binary (Windows path; adjust for your OS)\n",
"pytesseract.pytesseract.tesseract_cmd = r'C:\\\\Program Files\\\\Tesseract-OCR\\\\tesseract.exe'\n",
"\n",
"input_path = os.getcwd()\n",
"output_path = os.path.join(os.getcwd(), \"output\")\n",
"\n",
"# Partition the PDF into elements: extract images, infer table structure,\n",
"# and chunk the text by title sections\n",
"raw_pdf_elements = partition_pdf(\n",
"    filename=os.path.join(input_path, \"test.pdf\"),\n",
"    extract_images_in_pdf=True,\n",
"    infer_table_structure=True,\n",
"    chunking_strategy=\"by_title\",\n",
"    max_characters=4000,\n",
"    new_after_n_chars=3800,\n",
"    combine_text_under_n_chars=2000,\n",
"    image_output_dir_path=output_path,\n",
")"
]
},
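{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check on the partition output. This is a minimal sketch using only the standard library and the `raw_pdf_elements` from the cell above; the exact class names depend on your `unstructured` version."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"# Count how many elements of each type partition_pdf returned,\n",
"# e.g. CompositeElement vs. Table\n",
"Counter(type(el).__name__ for el in raw_pdf_elements)"
]
},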
{
"cell_type": "code",
"execution_count": null,
"id": "5f660305-e165-4b6c-ada3-a67a422defb5",
"metadata": {},
"outputs": [],
"source": [
"import base64\n",
"\n",
"text_elements = []\n",
"table_elements = []\n",
"image_elements = []\n",
"\n",
"# Base64-encode an image file for the vision model\n",
"def encode_image(image_path):\n",
"    with open(image_path, \"rb\") as image_file:\n",
"        return base64.b64encode(image_file.read()).decode('utf-8')\n",
"\n",
"# Sort partitioned elements into text chunks and tables\n",
"for element in raw_pdf_elements:\n",
"    if 'CompositeElement' in str(type(element)):\n",
"        text_elements.append(element)\n",
"    elif 'Table' in str(type(element)):\n",
"        table_elements.append(element)\n",
"\n",
"# Keep only the raw text of each element\n",
"table_elements = [i.text for i in table_elements]\n",
"text_elements = [i.text for i in text_elements]\n",
"\n",
"print(len(table_elements))  # number of tables\n",
"print(len(text_elements))  # number of text chunks"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Encode every image that partition_pdf extracted into the output directory\n",
"for image_file in os.listdir(output_path):\n",
"    if image_file.endswith(('.png', '.jpg', '.jpeg')):\n",
"        image_path = os.path.join(output_path, image_file)\n",
"        encoded_image = encode_image(image_path)\n",
"        image_elements.append(encoded_image)\n",
"print(len(image_elements))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.schema.messages import HumanMessage, AIMessage\n",
"\n",
"chain_gpt_35 = ChatOpenAI(model=\"gpt-3.5-turbo\", max_tokens=1024)\n",
"chain_gpt_4_vision = ChatOpenAI(model=\"gpt-4-vision-preview\", max_tokens=1024)\n",
"\n",
"# Summarize a chunk of text with GPT-3.5\n",
"def summarize_text(text_element):\n",
"    prompt = f\"Summarize the following text:\\n\\n{text_element}\\n\\nSummary:\"\n",
"    response = chain_gpt_35.invoke([HumanMessage(content=prompt)])\n",
"    return response.content\n",
"\n",
"# Summarize a table (passed as text) with GPT-3.5\n",
"def summarize_table(table_element):\n",
"    prompt = f\"Summarize the following table:\\n\\n{table_element}\\n\\nSummary:\"\n",
"    response = chain_gpt_35.invoke([HumanMessage(content=prompt)])\n",
"    return response.content\n",
"\n",
"# Describe a base64-encoded image with GPT-4 Vision\n",
"def summarize_image(encoded_image):\n",
"    prompt = [\n",
"        AIMessage(content=\"You are a bot that is good at analyzing images.\"),\n",
"        HumanMessage(content=[\n",
"            {\"type\": \"text\", \"text\": \"Describe the contents of this image.\"},\n",
"            {\n",
"                \"type\": \"image_url\",\n",
"                \"image_url\": {\n",
"                    \"url\": f\"data:image/jpeg;base64,{encoded_image}\"\n",
"                },\n",
"            },\n",
"        ])\n",
"    ]\n",
"    response = chain_gpt_4_vision.invoke(prompt)\n",
"    return response.content"
]
},
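{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loops below call the OpenAI API once per element. If you summarize many elements you may hit rate limits; the helper below is a hypothetical sketch (`summarize_with_retry` is not part of LangChain) that retries a call with a fixed pause between attempts."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"# Hypothetical helper: retry a summarizer call, pausing between attempts\n",
"# (useful for transient errors such as rate limits); re-raises after the last try\n",
"def summarize_with_retry(summarize_fn, element, retries=3, delay=2):\n",
"    for attempt in range(retries):\n",
"        try:\n",
"            return summarize_fn(element)\n",
"        except Exception:\n",
"            if attempt == retries - 1:\n",
"                raise\n",
"            time.sleep(delay)"
]
},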
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Summarize the first two tables (a small slice to keep the demo cheap),\n",
"# printing progress as we go\n",
"table_summaries = []\n",
"for i, te in enumerate(table_elements[0:2]):\n",
"    summary = summarize_table(te)\n",
"    table_summaries.append(summary)\n",
"    print(f\"Table {i + 1} processed.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Summarize the first two text chunks, printing progress as we go\n",
"text_summaries = []\n",
"for i, te in enumerate(text_elements[0:2]):\n",
"    summary = summarize_text(te)\n",
"    text_summaries.append(summary)\n",
"    print(f\"Text chunk {i + 1} processed.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Summarize the first two images, printing progress as we go\n",
"image_summaries = []\n",
"for i, ie in enumerate(image_elements[0:2]):\n",
"    summary = summarize_image(ie)\n",
"    image_summaries.append(summary)\n",
"    print(f\"Image {i + 1} processed.\")"
]
},
{
"cell_type": "markdown",
"id": "0aa7f52f-bf5c-4ba4-af72-b2ccba59a4cf",
"metadata": {},
"source": [
"## Multi-vector retriever\n",
"\n",
"Use the [multi-vector retriever](https://python.langchain.com/docs/modules/data_connection/retrievers/multi_vector#summary): the summaries are indexed for search, and retrieval returns the raw text chunks, tables, and image descriptions they point to."
]
},
{
"cell_type": "markdown",
"id": "67b030d4-2ac5-41b6-9245-fc3ba5771d87",
"metadata": {},
"source": [
"### Add to vectorstore\n",
"\n",
"Embed the summaries into Chroma and keep the raw content in a docstore, linked by shared document IDs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d643cc61-827d-4f3c-8242-7a7c8291ed8a",
"metadata": {},
"outputs": [],
"source": [
"import uuid\n",
"\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.retrievers.multi_vector import MultiVectorRetriever\n",
"from langchain.schema.document import Document\n",
"from langchain.storage import InMemoryStore\n",
"from langchain.vectorstores import Chroma\n",
"\n",
"# The vector store indexes the summaries; the docstore holds the raw content\n",
"vectorstore = Chroma(collection_name=\"summaries\", embedding_function=OpenAIEmbeddings())\n",
"store = InMemoryStore()\n",
"id_key = \"doc_id\"\n",
"\n",
"# Initialize the retriever\n",
"retriever = MultiVectorRetriever(vectorstore=vectorstore, docstore=store, id_key=id_key)\n",
"\n",
"# Index each summary under a shared ID and store the original content beside it\n",
"def add_documents_to_retriever(summaries, original_contents):\n",
"    doc_ids = [str(uuid.uuid4()) for _ in summaries]\n",
"    summary_docs = [\n",
"        Document(page_content=s, metadata={id_key: doc_ids[i]})\n",
"        for i, s in enumerate(summaries)\n",
"    ]\n",
"    retriever.vectorstore.add_documents(summary_docs)\n",
"    retriever.docstore.mset(list(zip(doc_ids, original_contents)))"
]
},
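{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Chroma collection above lives in memory for this session. If you want the index to survive restarts, Chroma accepts a `persist_directory`; this is only a sketch, and `./chroma_db` is an arbitrary path (the `InMemoryStore` docstore would still be ephemeral)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch: a persisted variant of the vector store\n",
"# (\"./chroma_db\" is an arbitrary path chosen for illustration)\n",
"persistent_vectorstore = Chroma(\n",
"    collection_name=\"summaries\",\n",
"    embedding_function=OpenAIEmbeddings(),\n",
"    persist_directory=\"./chroma_db\",\n",
")"
]
},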
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Add text summaries (raw text chunks stored for retrieval)\n",
"add_documents_to_retriever(text_summaries, text_elements)\n",
"\n",
"# Add table summaries (raw tables stored for retrieval)\n",
"add_documents_to_retriever(table_summaries, table_elements)\n",
"\n",
"# Add image summaries; the summary doubles as the stored content because the\n",
"# downstream QA model is text-only and cannot consume the raw images\n",
"add_documents_to_retriever(image_summaries, image_summaries)"
]
},
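{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check that the docstore was populated; `yield_keys` comes from the `BaseStore` interface that `InMemoryStore` implements."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: count the raw documents now held in the docstore\n",
"print(len(list(retriever.docstore.yield_keys())))"
]
},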
{
"cell_type": "markdown",
"id": "4b45fb81-46b1-426e-aa2c-01aed4eac700",
"metadata": {},
"source": [
"## Retrieval\n",
"\n",
"Query the retriever: the search runs over the summaries, but the raw stored content comes back."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1bea75fe-85af-4955-a80c-6e0b44a8e215",
"metadata": {},
"outputs": [],
"source": [
"# Retrieve the raw documents whose summaries best match the question\n",
"retriever.get_relevant_documents(\n",
"    \"What do you see on the images in the database?\"\n",
")"
]
},
{
"cell_type": "markdown",
"id": "6fde6f17-d244-4270-b759-68e1858d399f",
"metadata": {},
"source": [
"Now build a RAG chain that answers questions from the retrieved context:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "771a47fa-1267-4db8-a6ae-5fde48bbc069",
"metadata": {},
"outputs": [],
"source": [
"from langchain.schema.runnable import RunnablePassthrough\n",
"from langchain.prompts import ChatPromptTemplate\n",
"from langchain.schema.output_parser import StrOutputParser\n",
"\n",
"template = \"\"\"Answer the question based only on the following context, which can include text, images and tables:\n",
"{context}\n",
"Question: {question}\n",
"\"\"\"\n",
"prompt = ChatPromptTemplate.from_template(template)\n",
"\n",
"model = ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo\")\n",
"\n",
"chain = (\n",
"    {\"context\": retriever, \"question\": RunnablePassthrough()}\n",
"    | prompt\n",
"    | model\n",
"    | StrOutputParser()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea8414a8-65ee-4e11-8154-029b454f46af",
"metadata": {},
"outputs": [],
"source": [
"chain.invoke(\n",
"    \"What do you see on the images in the database?\"\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}