{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "140580ef-5db0-43cc-a524-9c39e04d4df0",
"metadata": {},
"outputs": [],
"source": [
"! pip install langchain unstructured[all-docs] pydantic lxml openai chromadb tiktoken"
]
},
{
"cell_type": "markdown",
"id": "74b56bde-1ba0-4525-a11d-cab02c5659e4",
"metadata": {},
"source": [
"## Data Loading\n",
"\n",
"### Partition PDF tables, text, and images\n",
"\n",
"* Use [Unstructured](https://unstructured-io.github.io/unstructured/) to partition elements"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e98bdeb7-eb77-42e6-a3a5-c3f27a1838d5",
"metadata": {},
"outputs": [],
"source": [
"from typing import Any\n",
"import os\n",
"from unstructured.partition.pdf import partition_pdf\n",
"import pytesseract\n",
"import os\n",
"\n",
"pytesseract.pytesseract.tesseract_cmd = r'C:\\Program Files\\Tesseract-OCR\\tesseract.exe'\n",
"\n",
"input_path = os.getcwd()\n",
"output_path = os.path.join(os.getcwd(), \"output\")\n",
"\n",
"# Get elements\n",
"raw_pdf_elements = partition_pdf(\n",
" filename=os.path.join(input_path, \"test.pdf\"),\n",
" extract_images_in_pdf=True,\n",
" infer_table_structure=True,\n",
" chunking_strategy=\"by_title\",\n",
" max_characters=4000,\n",
" new_after_n_chars=3800,\n",
" combine_text_under_n_chars=2000,\n",
" image_output_dir_path=output_path,\n",
")"
]
},
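{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional sanity check (an addition, not part of the original flow), we can count the element types that `partition_pdf` returned before splitting them into text and table buckets. Only the standard-library `collections.Counter` is used here:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"# Optional: inspect which element types partition_pdf produced\n",
"Counter(type(el).__name__ for el in raw_pdf_elements)"
]
},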
{
"cell_type": "code",
"execution_count": null,
"id": "5f660305-e165-4b6c-ada3-a67a422defb5",
"metadata": {},
"outputs": [],
"source": [
"import base64\n",
"\n",
"text_elements = []\n",
"table_elements = []\n",
"image_elements = []\n",
"\n",
"# Function to encode images\n",
"def encode_image(image_path):\n",
" with open(image_path, \"rb\") as image_file:\n",
" return base64.b64encode(image_file.read()).decode('utf-8')\n",
"\n",
"for element in raw_pdf_elements:\n",
" if 'CompositeElement' in str(type(element)):\n",
" text_elements.append(element)\n",
" elif 'Table' in str(type(element)):\n",
" table_elements.append(element)\n",
"\n",
"table_elements = [i.text for i in table_elements]\n",
"text_elements = [i.text for i in text_elements]\n",
"\n",
"# Tables\n",
"print(len(table_elements))\n",
"\n",
"# Text\n",
"print(len(text_elements))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for image_file in os.listdir(output_path):\n",
" if image_file.endswith(('.png', '.jpg', '.jpeg')):\n",
" image_path = os.path.join(output_path, image_file)\n",
" encoded_image = encode_image(image_path)\n",
" image_elements.append(encoded_image)\n",
"print(len(image_elements))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.schema.messages import HumanMessage, AIMessage\n",
"\n",
"chain_gpt_35 = ChatOpenAI(model=\"gpt-3.5-turbo\", max_tokens=1024)\n",
"chain_gpt_4_vision = ChatOpenAI(model=\"gpt-4-vision-preview\", max_tokens=1024)\n",
"\n",
"# Function for text summaries\n",
"def summarize_text(text_element):\n",
" prompt = f\"Summarize the following text:\\n\\n{text_element}\\n\\nSummary:\"\n",
" response = chain_gpt_35.invoke([HumanMessage(content=prompt)])\n",
" return response.content\n",
"\n",
"# Function for table summaries\n",
"def summarize_table(table_element):\n",
" prompt = f\"Summarize the following table:\\n\\n{table_element}\\n\\nSummary:\"\n",
" response = chain_gpt_35.invoke([HumanMessage(content=prompt)])\n",
" return response.content\n",
"\n",
"# Function for image summaries\n",
"def summarize_image(encoded_image):\n",
" prompt = [\n",
" AIMessage(content=\"You are a bot that is good at analyzing images.\"),\n",
" HumanMessage(content=[\n",
" {\"type\": \"text\", \"text\": \"Describe the contents of this image.\"},\n",
" {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\n",
" \"url\": f\"data:image/jpeg;base64,{encoded_image}\"\n",
" },\n",
" },\n",
" ])\n",
" ]\n",
" response = chain_gpt_4_vision.invoke(prompt)\n",
" return response.content"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Processing table elements with feedback and sleep\n",
"table_summaries = []\n",
"for i, te in enumerate(table_elements[0:2]):\n",
" summary = summarize_table(te)\n",
" table_summaries.append(summary)\n",
" print(f\"{i + 1}th element of tables processed.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Processing text elements with feedback and sleep\n",
"text_summaries = []\n",
"for i, te in enumerate(text_elements[0:2]):\n",
" summary = summarize_text(te)\n",
" text_summaries.append(summary)\n",
" print(f\"{i + 1}th element of texts processed.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Processing image elements with feedback and sleep\n",
"image_summaries = []\n",
"for i, ie in enumerate(image_elements[0:2]):\n",
" summary = summarize_image(ie)\n",
" image_summaries.append(summary)\n",
" print(f\"{i + 1}th element of images processed.\")"
]
},
{
"cell_type": "markdown",
"id": "0aa7f52f-bf5c-4ba4-af72-b2ccba59a4cf",
"metadata": {},
"source": [
"## Multi-vector retriever\n",
"\n",
"Use [multi-vector-retriever](https://python.langchain.com/docs/modules/data_connection/retrievers/multi_vector#summary).\n",
"\n",
"Summaries are used to retrieve raw tables and / or raw chunks of text."
]
},
{
"cell_type": "markdown",
"id": "67b030d4-2ac5-41b6-9245-fc3ba5771d87",
"metadata": {},
"source": [
"### Add to vectorstore\n",
"\n",
"Use [Multi Vector Retriever](https://python.langchain.com/docs/modules/data_connection/retrievers/multi_vector#summary) with summaries."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d643cc61-827d-4f3c-8242-7a7c8291ed8a",
"metadata": {},
"outputs": [],
"source": [
"import uuid\n",
"\n",
"from langchain.embeddings import OpenAIEmbeddings\n",
"from langchain.retrievers.multi_vector import MultiVectorRetriever\n",
"from langchain.schema.document import Document\n",
"from langchain.storage import InMemoryStore\n",
"from langchain.vectorstores import Chroma\n",
"\n",
"# Initialize the vector store and storage layer\n",
"vectorstore = Chroma(collection_name=\"summaries\", embedding_function=OpenAIEmbeddings())\n",
"store = InMemoryStore()\n",
"id_key = \"doc_id\"\n",
"\n",
"# Initialize the retriever\n",
"retriever = MultiVectorRetriever(vectorstore=vectorstore, docstore=store, id_key=id_key)\n",
"\n",
"# Function to add documents to the retriever\n",
"def add_documents_to_retriever(summaries, original_contents):\n",
" doc_ids = [str(uuid.uuid4()) for _ in summaries]\n",
" summary_docs = [\n",
" Document(page_content=s, metadata={id_key: doc_ids[i]})\n",
" for i, s in enumerate(summaries)\n",
" ]\n",
" retriever.vectorstore.add_documents(summary_docs)\n",
" retriever.docstore.mset(list(zip(doc_ids, original_contents)))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Add text summaries\n",
"add_documents_to_retriever(text_summaries, text_elements)\n",
"\n",
"# Add table summaries\n",
"add_documents_to_retriever(table_summaries, table_elements)\n",
"\n",
"# Add image summaries\n",
"add_documents_to_retriever(image_summaries, image_summaries) # hopefully real images soon"
]
},
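{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional check (an addition; it assumes `InMemoryStore.yield_keys` from `langchain.storage`), we can confirm how many raw documents were registered in the docstore, one per summary added above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Count the raw documents registered in the docstore\n",
"print(len(list(store.yield_keys())))"
]
},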
{
"cell_type": "markdown",
"id": "4b45fb81-46b1-426e-aa2c-01aed4eac700",
"metadata": {},
"source": [
"# Table retrieval\n",
"\n",
"The most complex table in the paper:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1bea75fe-85af-4955-a80c-6e0b44a8e215",
"metadata": {},
"outputs": [],
"source": [
"# We can retrieve this table\n",
"retriever.get_relevant_documents(\n",
" \"What do you see on the images in the database?\"\n",
")"
]
},
{
"cell_type": "markdown",
"id": "6fde6f17-d244-4270-b759-68e1858d399f",
"metadata": {},
"source": [
"We can retrieve this image summary:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "771a47fa-1267-4db8-a6ae-5fde48bbc069",
"metadata": {},
"outputs": [],
"source": [
"from langchain.schema.runnable import RunnablePassthrough\n",
"from langchain.prompts import ChatPromptTemplate\n",
"from langchain.schema.output_parser import StrOutputParser\n",
"\n",
"template = \"\"\"Answer the question based only on the following context, which can include text, images and tables:\n",
"{context}\n",
"Question: {question}\n",
"\"\"\"\n",
"prompt = ChatPromptTemplate.from_template(template)\n",
"\n",
"model = ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo\")\n",
"\n",
"chain = (\n",
" {\"context\": retriever, \"question\": RunnablePassthrough()}\n",
" | prompt\n",
" | model\n",
" | StrOutputParser()\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea8414a8-65ee-4e11-8154-029b454f46af",
"metadata": {},
"outputs": [],
"source": [
"chain.invoke(\n",
" \"What do you see on the images in the database?\"\n",
")"
]
}
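,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same chain can also be used to query the text and table content. Here is one additional illustrative query (an addition, assuming the sample PDF actually contains tables):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chain.invoke(\n",
"    \"What information do the tables in the document contain?\"\n",
")"
]
}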
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}