rag-reranking-gpt-colbert.ipynb
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/gist/virattt/b140fb4bf549b6125d53aa153dc53be6/rag-reranking-gpt-colbert.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Install dependencies"
      ],
      "metadata": {
        "id": "S2mGQxA958dW"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install openai"
      ],
      "metadata": {
        "id": "2bY0NapN_z98"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lEQQJHH9gufm"
      },
      "outputs": [],
      "source": [
        "!pip install chromadb"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install langchain"
      ],
      "metadata": {
        "id": "ygccK6lm54VT"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install tiktoken"
      ],
      "metadata": {
        "id": "K5KyVC5O7Elw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install pypdf"
      ],
      "metadata": {
        "id": "_o1MOUo07GBO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import getpass\n",
        "import os\n",
        "\n",
        "# Set your OpenAI API key\n",
        "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()"
      ],
      "metadata": {
        "id": "tavToGb_MJrc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Download and prepare SEC filing"
      ],
      "metadata": {
        "id": "sz639zFf6JoK"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from langchain.document_loaders import PyPDFLoader\n",
        "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
        "\n",
        "# Load $ABNB's financial report. This may take 1-2 minutes since the PDF is large\n",
        "sec_filing_pdf = \"https://d18rn0p25nwr6d.cloudfront.net/CIK-0001559720/8a9ebed0-815a-469a-87eb-1767d21d8cec.pdf\"\n",
        "\n",
        "# Create your PDF loader\n",
        "loader = PyPDFLoader(sec_filing_pdf)\n",
        "\n",
        "# Load the PDF document\n",
        "documents = loader.load()\n",
        "\n",
        "# Chunk the financial report\n",
        "text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=0)\n",
        "docs = text_splitter.split_documents(documents)"
      ],
      "metadata": {
        "id": "rIO5t-j7611h"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Load the SEC filing into vector store"
      ],
      "metadata": {
        "id": "iaYSqxiMLUGb"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from langchain_community.vectorstores import Chroma\n",
        "from langchain.embeddings.openai import OpenAIEmbeddings\n",
        "\n",
        "# Load the document into Chroma\n",
        "embedding_function = OpenAIEmbeddings()\n",
        "db = Chroma.from_documents(docs, embedding_function)"
      ],
      "metadata": {
        "id": "QVZevdc-Md4N"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Query the vector store"
      ],
      "metadata": {
        "id": "m8HqBNyYrDHb"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "query = \"What are the specific factors contributing to Airbnb's increased operational expenses in the last fiscal year?\"\n",
        "docs = db.similarity_search(query)"
      ],
      "metadata": {
        "id": "3qZTrAtXLPl1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Re-rank the results using GPT-4"
      ],
      "metadata": {
        "id": "0UMU-ogKM6w8"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from openai import OpenAI\n",
        "import time\n",
        "import json\n",
        "\n",
        "start = time.time()\n",
        "client = OpenAI(api_key=os.environ[\"OPENAI_API_KEY\"])\n",
        "response = client.chat.completions.create(\n",
        "    model='gpt-4-1106-preview',\n",
        "    response_format={\"type\": \"json_object\"},\n",
        "    temperature=0,\n",
        "    messages=[\n",
" {\"role\": \"system\", \"content\": \"You are an expert relevance ranker. Given a list of documents and a query, your job is to determine how relevant each document is for answering the query. Your output is JSON, which is a list of documents. Each document has two fields, content and score. relevance_score is from 0.0 to 100.0. Higher relevance means higher score.\"},\n", | |
" {\"role\": \"user\", \"content\": f\"Query: {query} Docs: {docs}\"}\n", | |
" ]\n", | |
" )\n", | |
"\n", | |
"print(f\"Took {time.time() - start} seconds to re-rank documents with GPT-4.\")" | |
], | |
"metadata": { | |
"id": "Z83h16UuMlMt" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Sort the scores by highest to lowest and print\n", | |
"scores = json.loads(response.choices[0].message.content)[\"documents\"]\n", | |
"sorted_data = sorted(scores, key=lambda x: x['score'], reverse=True)\n", | |
"print(json.dumps(sorted_data, indent=2))" | |
], | |
"metadata": { | |
"id": "8VZMWffzm0-i" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Re-rank the results using ColBERT" | |
], | |
"metadata": { | |
"id": "wnViL4XQg1FE" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install --quiet transformers torch" | |
], | |
"metadata": { | |
"id": "fXPdarpEiN65" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from transformers import AutoTokenizer, AutoModel\n", | |
"\n", | |
"# Load the tokenizer and the model\n", | |
"tokenizer = AutoTokenizer.from_pretrained(\"colbert-ir/colbertv2.0\")\n", | |
"model = AutoModel.from_pretrained(\"colbert-ir/colbertv2.0\")" | |
], | |
"metadata": { | |
"id": "4jE_MXfelFyv" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import torch\n", | |
"import time\n", | |
"\n", | |
"start = time.time()\n", | |
"scores = []\n", | |
"\n", | |
"# Function to compute MaxSim\n", | |
"def maxsim(query_embedding, document_embedding):\n", | |
" # Expand dimensions for broadcasting\n", | |
" # Query: [batch_size, query_length, embedding_size] -> [batch_size, query_length, 1, embedding_size]\n", | |
" # Document: [batch_size, doc_length, embedding_size] -> [batch_size, 1, doc_length, embedding_size]\n", | |
" expanded_query = query_embedding.unsqueeze(2)\n", | |
" expanded_doc = document_embedding.unsqueeze(1)\n", | |
"\n", | |
" # Compute cosine similarity across the embedding dimension\n", | |
" sim_matrix = torch.nn.functional.cosine_similarity(expanded_query, expanded_doc, dim=-1)\n", | |
"\n", | |
" # Take the maximum similarity for each query token (across all document tokens)\n", | |
" # sim_matrix shape: [batch_size, query_length, doc_length]\n", | |
" max_sim_scores, _ = torch.max(sim_matrix, dim=2)\n", | |
"\n", | |
" # Average these maximum scores across all query tokens\n", | |
" avg_max_sim = torch.mean(max_sim_scores, dim=1)\n", | |
" return avg_max_sim\n", | |
"\n", | |
"# Encode the query\n", | |
"query_encoding = tokenizer(query, return_tensors='pt')\n", | |
"query_embedding = model(**query_encoding).last_hidden_state.mean(dim=1)\n", | |
"\n", | |
"# Get score for each document\n", | |
"for document in docs:\n", | |
" document_encoding = tokenizer(document.page_content, return_tensors='pt', truncation=True, max_length=512)\n", | |
" document_embedding = model(**document_encoding).last_hidden_state\n", | |
"\n", | |
" # Calculate MaxSim score\n", | |
" score = maxsim(query_embedding.unsqueeze(0), document_embedding)\n", | |
" scores.append({\n", | |
" \"score\": score.item(),\n", | |
" \"document\": document.page_content,\n", | |
" })\n", | |
"\n", | |
"print(f\"Took {time.time() - start} seconds to re-rank documents with ColBERT.\")" | |
], | |
"metadata": { | |
"id": "u84ePIKtjrtg" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Sort the scores by highest to lowest and print\n", | |
"sorted_data = sorted(scores, key=lambda x: x['score'], reverse=True)\n", | |
"print(json.dumps(sorted_data, indent=2))" | |
], | |
"metadata": { | |
"id": "jBotW5sxjT6W" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "base", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.12" | |
}, | |
"orig_nbformat": 4, | |
"colab": { | |
"provenance": [], | |
"include_colab_link": true | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
@truebit If I have done it right, you need to add:

# Add these lines
query = "Your query in string format..."
query_encoding = tokenizer(query, return_tensors='pt', truncation=True, max_length=512)
query_embedding = model(**query_encoding).last_hidden_state.squeeze(0)

# Get score for each document
for document in splits:
    document_encoding = tokenizer(document, return_tensors='pt', truncation=True, max_length=512)
    document_embedding = model(**document_encoding).last_hidden_state

    # Calculate MaxSim score
    score = maxsim(query_embedding.unsqueeze(0), document_embedding)
    ...
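For reference, the notebook itself finishes that loop by initializing scores = [] beforehand, appending each result, and sorting afterwards; a minimal sketch with the same names (assuming splits is a list of plain-text chunks) would be:

scores = []
for document in splits:
    document_encoding = tokenizer(document, return_tensors='pt', truncation=True, max_length=512)
    document_embedding = model(**document_encoding).last_hidden_state
    score = maxsim(query_embedding.unsqueeze(0), document_embedding)
    # Keep each chunk together with its MaxSim score
    scores.append({"score": score.item(), "document": document})

# Highest-scoring chunks first
reranked = sorted(scores, key=lambda x: x["score"], reverse=True)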
@Psancs05 thx
Great catch - updated 🙏
@virattt Do you know the difference between using:
query_embedding = model(**query_encoding).last_hidden_state.squeeze(0)
query_embedding = model(**query_encoding).last_hidden_state.mean(dim=1)
I have tested both, and it seems that squeeze(0) returns better-quality similar documents (maybe it's just the use case I tried).
query_embedding = model(**query_encoding).last_hidden_state.squeeze(0)
is correct since it returns a vector per token, whilst
query_embedding = model(**query_encoding).last_hidden_state.mean(dim=1)
returns a single vector averaged over all tokens.
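A quick shape check makes the difference concrete. This is just an illustrative sketch (the example query string here is made up, not from the gist):

from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("colbert-ir/colbertv2.0")
model = AutoModel.from_pretrained("colbert-ir/colbertv2.0")

encoding = tokenizer("Why did operating expenses increase?", return_tensors='pt')
hidden = model(**encoding).last_hidden_state   # shape [1, num_tokens, 768]

per_token = hidden.squeeze(0)   # [num_tokens, 768]: one vector per query token, which MaxSim compares against every document token
pooled = hidden.mean(dim=1)     # [1, 768]: a single averaged vector, so only one "max over document tokens" is taken

With the pooled query, maxsim() still runs (the query axis just has length 1), but the score collapses to the best single cosine match for the averaged vector rather than ColBERT-style late interaction over each query token.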
Thanks for sharing, but the query_embedding variable is missing an assignment statement.