Forked from virattt/rag-reranking-gpt-colbert.ipynb
Created January 22, 2024 04:56
-
-
Save davidathompson/837e394de112f6accbf876d456ac9fce to your computer and use it in GitHub Desktop.
rag-reranking-gpt-colBERT.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/virattt/b140fb4bf549b6125d53aa153dc53be6/rag-reranking-gpt-colbert.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Install dependencies" | |
], | |
"metadata": { | |
"id": "S2mGQxA958dW" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Use %pip (not !pip or bare pip) so packages install into the running kernel's environment.\n", | |
"# langchain-community is listed explicitly: the notebook imports from langchain_community,\n", | |
"# and newer langchain releases (0.2+) no longer depend on it.\n", | |
"%pip install --quiet openai chromadb langchain langchain-community tiktoken pypdf" | |
], | |
"metadata": { | |
"id": "2bY0NapN_z98" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import getpass\n", | |
"import os\n", | |
"\n", | |
"# Set your OpenAI API key. Prompt only if it is not already set, so re-running this\n", | |
"# cell (or starting the notebook with the key in the environment) does not ask again.\n", | |
"if not os.environ.get(\"OPENAI_API_KEY\"):\n", | |
"    os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API key: \")" | |
], | |
"metadata": { | |
"id": "tavToGb_MJrc" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Download and prepare SEC filing" | |
], | |
"metadata": { | |
"id": "sz639zFf6JoK" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Document loaders come from langchain_community, the same package the vector store is imported from\n", | |
"from langchain_community.document_loaders import PyPDFLoader\n", | |
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n", | |
"\n", | |
"# Load $ABNB's financial report. This may take 1-2 minutes since the PDF is large\n", | |
"sec_filing_pdf = \"https://d18rn0p25nwr6d.cloudfront.net/CIK-0001559720/8a9ebed0-815a-469a-87eb-1767d21d8cec.pdf\"\n", | |
"\n", | |
"# Create your PDF loader\n", | |
"loader = PyPDFLoader(sec_filing_pdf)\n", | |
"\n", | |
"# Load the PDF document\n", | |
"documents = loader.load()\n", | |
"\n", | |
"# Chunk the financial report into non-overlapping chunks.\n", | |
"# NOTE: the query cell later rebinds `docs` to the search hits, so re-run this cell\n", | |
"# before re-building the vector store.\n", | |
"text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=0)\n", | |
"docs = text_splitter.split_documents(documents)" | |
], | |
"metadata": { | |
"id": "rIO5t-j7611h" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Load the SEC filing into vector store" | |
], | |
"metadata": { | |
"id": "iaYSqxiMLUGb" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from langchain_community.vectorstores import Chroma\n", | |
"from langchain_community.embeddings import OpenAIEmbeddings\n", | |
"\n", | |
"# Embed every chunk with OpenAI embeddings and index the chunks in Chroma\n", | |
"embedding_function = OpenAIEmbeddings()\n", | |
"db = Chroma.from_documents(docs, embedding_function)" | |
], | |
"metadata": { | |
"id": "QVZevdc-Md4N" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Query the vector store" | |
], | |
"metadata": { | |
"id": "m8HqBNyYrDHb" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"query = (\n", | |
"    \"What are the specific factors contributing to Airbnb's \"\n", | |
"    \"increased operational expenses in the last fiscal year?\"\n", | |
")\n", | |
"\n", | |
"# Retrieve the 4 most similar chunks (k=4 is similarity_search's default, made explicit).\n", | |
"# NOTE: this rebinds `docs` from all chunks to just the retrieved ones.\n", | |
"docs = db.similarity_search(query, k=4)" | |
], | |
"metadata": { | |
"id": "3qZTrAtXLPl1" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Re-rank the results using GPT-4" | |
], | |
"metadata": { | |
"id": "0UMU-ogKM6w8" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from openai import OpenAI\n", | |
"import time\n", | |
"import json\n", | |
"\n", | |
"# Send the ranker plain chunk text (not Document reprs with metadata), numbered for reference\n", | |
"docs_text = \"\\n\\n\".join(f\"[{i}] {doc.page_content}\" for i, doc in enumerate(docs, start=1))\n", | |
"\n", | |
"# JSON mode returns a JSON object, so ask for an object whose \"documents\" key holds the\n", | |
"# ranked list -- the next cell reads exactly that key and each entry's \"score\" field.\n", | |
"system_prompt = (\n", | |
"    \"You are an expert relevance ranker. Given a list of documents and a query, \"\n", | |
"    \"your job is to determine how relevant each document is for answering the query. \"\n", | |
"    \"Your output is a JSON object with a single key, documents, whose value is a list of documents. \"\n", | |
"    \"Each document has two fields, content and score. \"\n", | |
"    \"score is from 0.0 to 100.0. Higher relevance means higher score.\"\n", | |
")\n", | |
"\n", | |
"start = time.time()\n", | |
"client = OpenAI(api_key=os.environ[\"OPENAI_API_KEY\"])\n", | |
"response = client.chat.completions.create(\n", | |
"    model='gpt-4-1106-preview',\n", | |
"    response_format={\"type\": \"json_object\"},\n", | |
"    temperature=0,\n", | |
"    messages=[\n", | |
"        {\"role\": \"system\", \"content\": system_prompt},\n", | |
"        {\"role\": \"user\", \"content\": f\"Query: {query} Docs: {docs_text}\"},\n", | |
"    ],\n", | |
")\n", | |
"\n", | |
"print(f\"Took {time.time() - start:.2f} seconds to re-rank documents with GPT-4.\")" | |
], | |
"metadata": { | |
"id": "Z83h16UuMlMt" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Parse the ranker's JSON reply, then order its documents from most to least relevant\n", | |
"payload = json.loads(response.choices[0].message.content)\n", | |
"scores = payload[\"documents\"]\n", | |
"sorted_data = sorted(scores, key=lambda item: item[\"score\"], reverse=True)\n", | |
"print(json.dumps(sorted_data, indent=2))" | |
], | |
"metadata": { | |
"id": "8VZMWffzm0-i" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Re-rank the results using ColBERT" | |
], | |
"metadata": { | |
"id": "wnViL4XQg1FE" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# %pip (unlike !pip) installs into the environment of the running kernel\n", | |
"%pip install --quiet transformers torch" | |
], | |
"metadata": { | |
"id": "fXPdarpEiN65" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from transformers import AutoModel, AutoTokenizer\n", | |
"\n", | |
"# Tokenizer and encoder both come from the ColBERTv2 checkpoint on the Hugging Face Hub\n", | |
"colbert_checkpoint = \"colbert-ir/colbertv2.0\"\n", | |
"tokenizer = AutoTokenizer.from_pretrained(colbert_checkpoint)\n", | |
"# NOTE(review): AutoModel presumably loads only the BERT encoder, so ColBERT's final\n", | |
"# linear projection is likely not applied to the embeddings -- confirm if exact scores matter.\n", | |
"model = AutoModel.from_pretrained(colbert_checkpoint)" | |
], | |
"metadata": { | |
"id": "4jE_MXfelFyv" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import torch\n", | |
"\n", | |
"\n", | |
"def maxsim(query_embedding, document_embedding):\n", | |
"    \"\"\"Late-interaction (MaxSim) relevance score between query and document token embeddings.\n", | |
"\n", | |
"    Args:\n", | |
"        query_embedding: tensor of shape [batch_size, query_length, embedding_size].\n", | |
"        document_embedding: tensor of shape [batch_size, doc_length, embedding_size].\n", | |
"\n", | |
"    Returns:\n", | |
"        Tensor of shape [batch_size]: for every query token, the best cosine similarity\n", | |
"        against any document token, averaged over the query tokens.\n", | |
"    \"\"\"\n", | |
"    # Broadcast [B, Q, 1, E] against [B, 1, D, E] -> cosine similarity matrix [B, Q, D]\n", | |
"    sim_matrix = torch.nn.functional.cosine_similarity(\n", | |
"        query_embedding.unsqueeze(2), document_embedding.unsqueeze(1), dim=-1\n", | |
"    )\n", | |
"    # Best-matching document token for each query token -> [B, Q]\n", | |
"    max_sim_scores = sim_matrix.max(dim=2).values\n", | |
"    # Average over query tokens -> [B]\n", | |
"    return max_sim_scores.mean(dim=1)" | |
], | |
"metadata": { | |
"id": "colbertMaxSimDef" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import time\n", | |
"\n", | |
"import torch\n", | |
"\n", | |
"start = time.time()\n", | |
"scores = []\n", | |
"\n", | |
"# Inference only: no_grad avoids building autograd graphs for every forward pass\n", | |
"with torch.no_grad():\n", | |
"    # Encode the query once, keeping one embedding per token: [1, query_length, hidden]\n", | |
"    query_encoding = tokenizer(query, return_tensors='pt', truncation=True, max_length=512)\n", | |
"    query_embedding = model(**query_encoding).last_hidden_state\n", | |
"\n", | |
"    # Score each retrieved chunk against the query\n", | |
"    for document in docs:\n", | |
"        document_encoding = tokenizer(document.page_content, return_tensors='pt', truncation=True, max_length=512)\n", | |
"        document_embedding = model(**document_encoding).last_hidden_state\n", | |
"\n", | |
"        # Both embeddings already carry a batch dimension of 1, as maxsim expects\n", | |
"        score = maxsim(query_embedding, document_embedding)\n", | |
"        scores.append({\n", | |
"            \"score\": score.item(),\n", | |
"            \"document\": document.page_content,\n", | |
"        })\n", | |
"\n", | |
"print(f\"Took {time.time() - start:.2f} seconds to re-rank documents with ColBERT.\")" | |
], | |
"metadata": { | |
"id": "u84ePIKtjrtg" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# json is imported here too, so the ColBERT section still runs when the GPT-4 cells are skipped\n", | |
"import json\n", | |
"\n", | |
"# Sort the ColBERT scores from highest to lowest and print\n", | |
"sorted_data = sorted(scores, key=lambda x: x['score'], reverse=True)\n", | |
"print(json.dumps(sorted_data, indent=2))" | |
], | |
"metadata": { | |
"id": "jBotW5sxjT6W" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "base", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.12" | |
}, | |
"orig_nbformat": 4, | |
"colab": { | |
"provenance": [], | |
"include_colab_link": true | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.