Last active
April 7, 2025 09:39
-
-
Save virattt/7aee148b89b935aa0e7be03d60d72707 to your computer and use it in GitHub Desktop.
cost-query-rewriting-gpt-mistral-cohere.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/virattt/7aee148b89b935aa0e7be03d60d72707/cost-query-rewriting-gpt-mistral-cohere.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Create Query and Prompt" | |
], | |
"metadata": { | |
"id": "m8HqBNyYrDHb" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Conversational user question that every model below is asked to rewrite.\n", | |
"query = \"What's going on with Airbnb's numbers?\"\n", | |
"\n", | |
"# Instruction sent as the system message in the OpenAI and Mistral cells.\n", | |
"# Note that the value starts and ends with a newline character.\n", | |
"prompt = (\n", | |
"    \"\\nRewrite the following user query into a clear, specific, and\\n\"\n", | |
"    \"formal request suitable for retrieving relevant information from a vector database.\\n\"\n", | |
"    \"Keep in mind that your rewritten query will be sent to a vector database, which\\n\"\n", | |
"    \"does similarity search for retrieving documents. Your output must be 100 tokens max.\\n\"\n", | |
")" | |
], | |
"metadata": { | |
"id": "3qZTrAtXLPl1" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def get_num_tokens(phrase: str, tokens_per_word: float = 1.3) -> float:\n", | |
"    \"\"\"Estimate how many tokens `phrase` contains.\n", | |
"\n", | |
"    This is a heuristic, not a tokenizer: the phrase is split on whitespace and\n", | |
"    the resulting word count is scaled by `tokens_per_word`.\n", | |
"\n", | |
"    Args:\n", | |
"        phrase: Text to measure.\n", | |
"        tokens_per_word: Average number of tokens assumed per word (default 1.3).\n", | |
"\n", | |
"    Returns:\n", | |
"        The estimated token count; zero for empty or whitespace-only text.\n", | |
"    \"\"\"\n", | |
"    return len(phrase.split()) * tokens_per_word" | |
], | |
"metadata": { | |
"id": "brMdLqEbetEC" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# %pip installs into the running kernel's environment (unlike a bare shell `!pip`).\n", | |
"# The `from openai import OpenAI` client used below requires the 1.x SDK.\n", | |
"%pip install -q \"openai>=1.0\"" | |
], | |
"metadata": { | |
"id": "2bY0NapN_z98" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import getpass\n", | |
"import os\n", | |
"\n", | |
"# Read the OpenAI API key from a hidden prompt and store it in the process environment.\n", | |
"# The prompt is labelled because the Mistral and Cohere cells ask for their own keys too.\n", | |
"os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API key: \")" | |
], | |
"metadata": { | |
"id": "tavToGb_MJrc" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Use GPT-4 to rewrite the query" | |
], | |
"metadata": { | |
"id": "bPzoWQhVAmLt" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from openai import OpenAI\n", | |
"import time\n", | |
"import json\n", | |
"\n", | |
"# USD per token for gpt-4-0125-preview, as assumed by this notebook\n", | |
"GPT4_INPUT_PRICE = 0.00001\n", | |
"GPT4_OUTPUT_PRICE = 0.00003\n", | |
"\n", | |
"client = OpenAI(api_key=os.environ[\"OPENAI_API_KEY\"])\n", | |
"\n", | |
"total_time = 0.0\n", | |
"total_cost = 0.0\n", | |
"num_iterations = 10\n", | |
"response = None\n", | |
"\n", | |
"# Both the system prompt and the user query are sent on every call, so both\n", | |
"# count as input. The value does not change between iterations.\n", | |
"input_tokens = get_num_tokens(phrase=prompt) + get_num_tokens(phrase=query)\n", | |
"input_cost = input_tokens * GPT4_INPUT_PRICE\n", | |
"\n", | |
"for _ in range(num_iterations):\n", | |
"    # Time only the API call, using a monotonic clock meant for measuring intervals\n", | |
"    start_time = time.perf_counter()\n", | |
"    response = client.chat.completions.create(\n", | |
"        model='gpt-4-0125-preview',\n", | |
"        temperature=0,\n", | |
"        messages=[\n", | |
"            {\"role\": \"system\", \"content\": prompt},\n", | |
"            {\"role\": \"user\", \"content\": query},\n", | |
"        ]\n", | |
"    )\n", | |
"    end_time = time.perf_counter()\n", | |
"    # Update total execution time (excluding sleep time)\n", | |
"    total_time += end_time - start_time\n", | |
"    # Add this call's estimated cost to the running total\n", | |
"    rewritten_query = response.choices[0].message.content\n", | |
"    output_cost = get_num_tokens(phrase=rewritten_query) * GPT4_OUTPUT_PRICE\n", | |
"    total_cost += input_cost + output_cost\n", | |
"    # Wait 1 second before the next iteration\n", | |
"    time.sleep(1)\n", | |
"\n", | |
"# Average latency of a single call\n", | |
"avg_time = total_time / num_iterations\n", | |
"\n", | |
"print(f\"Took {avg_time} seconds to rewrite the query.\")" | |
], | |
"metadata": { | |
"id": "Z83h16UuMlMt" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Show the rewrite from the most recent response, keeping it in `rewritten_query`\n", | |
"print(rewritten_query := response.choices[0].message.content)" | |
], | |
"metadata": { | |
"id": "8VZMWffzm0-i" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Print the accumulated cost with fixed 8-decimal formatting, like the other cost cells,\n", | |
"# so that very small amounts are not shown in scientific notation.\n", | |
"print(f\"Total cost of rewriting query: ${total_cost:.8f}\")" | |
], | |
"metadata": { | |
"id": "8KaBnljJeUan" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Use GPT-3.5 to rewrite the query" | |
], | |
"metadata": { | |
"id": "AdrLmbzAAsgX" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# USD per token for gpt-3.5-turbo-0125, as assumed by this notebook\n", | |
"GPT35_INPUT_PRICE = 0.0000005\n", | |
"GPT35_OUTPUT_PRICE = 0.0000015\n", | |
"\n", | |
"total_time = 0.0\n", | |
"total_cost = 0.0\n", | |
"num_iterations = 10\n", | |
"response = None\n", | |
"\n", | |
"# Both the system prompt and the user query are sent on every call, so both\n", | |
"# count as input. The value does not change between iterations.\n", | |
"input_tokens = get_num_tokens(phrase=prompt) + get_num_tokens(phrase=query)\n", | |
"input_cost = input_tokens * GPT35_INPUT_PRICE\n", | |
"\n", | |
"for _ in range(num_iterations):\n", | |
"    # Time only the API call, using a monotonic clock meant for measuring intervals\n", | |
"    start_time = time.perf_counter()\n", | |
"    # `client` is the OpenAI client created in the GPT-4 cell\n", | |
"    response = client.chat.completions.create(\n", | |
"        model='gpt-3.5-turbo-0125',\n", | |
"        temperature=0,\n", | |
"        messages=[\n", | |
"            {\"role\": \"system\", \"content\": prompt},\n", | |
"            {\"role\": \"user\", \"content\": query},\n", | |
"        ]\n", | |
"    )\n", | |
"    end_time = time.perf_counter()\n", | |
"    # Update total execution time (excluding sleep time)\n", | |
"    total_time += end_time - start_time\n", | |
"    # Add this call's estimated cost to the running total\n", | |
"    rewritten_query = response.choices[0].message.content\n", | |
"    output_cost = get_num_tokens(phrase=rewritten_query) * GPT35_OUTPUT_PRICE\n", | |
"    total_cost += input_cost + output_cost\n", | |
"    # Wait 1 second before the next iteration\n", | |
"    time.sleep(1)\n", | |
"\n", | |
"# Average latency of a single call\n", | |
"avg_time = total_time / num_iterations\n", | |
"\n", | |
"print(f\"Took {avg_time} seconds to rewrite the query.\")" | |
], | |
"metadata": { | |
"id": "AdpynLvNAvww" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Show the rewrite from the most recent response, keeping it in `rewritten_query`\n", | |
"print(rewritten_query := response.choices[0].message.content)" | |
], | |
"metadata": { | |
"id": "yFKleKcWD84S" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Report the accumulated cost in dollars, fixed to 8 decimal places\n", | |
"print(f\"Total cost of rewriting query: ${total_cost:.8f}\")" | |
], | |
"metadata": { | |
"id": "a2IgO3Fvftus" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Use Mistral to rewrite the query" | |
], | |
"metadata": { | |
"id": "FMHkITrr-ru2" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# %pip installs into the running kernel's environment (unlike a bare shell `!pip`).\n", | |
"# `mistralai.client.MistralClient`, used below, is the legacy 0.x client; pin the\n", | |
"# version so that a newer major release does not break the import.\n", | |
"%pip install -q mistralai==0.4.2" | |
], | |
"metadata": { | |
"id": "cYy332j3cMbt" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Read the Mistral API key from a hidden prompt and store it in the process environment.\n", | |
"# The prompt is labelled so it cannot be confused with the other providers' key prompts.\n", | |
"os.environ[\"MISTRAL_API_KEY\"] = getpass.getpass(\"Mistral API key: \")" | |
], | |
"metadata": { | |
"id": "rcPNaNTR4leC" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from mistralai.client import MistralClient\n", | |
"from mistralai.models.chat_completion import ChatMessage\n", | |
"\n", | |
"# USD per token for mistral-medium, as assumed by this notebook\n", | |
"MISTRAL_INPUT_PRICE = 0.0000027\n", | |
"MISTRAL_OUTPUT_PRICE = 0.0000081\n", | |
"\n", | |
"# Kept under its own name so the OpenAI `client` created earlier is not overwritten\n", | |
"# and the OpenAI cells can still be re-run after this one.\n", | |
"mistral_client = MistralClient(api_key=os.environ[\"MISTRAL_API_KEY\"])\n", | |
"\n", | |
"total_time = 0.0\n", | |
"total_cost = 0.0\n", | |
"num_iterations = 10\n", | |
"response = None\n", | |
"\n", | |
"# Both the system prompt and the user query are sent on every call, so both\n", | |
"# count as input. The value does not change between iterations.\n", | |
"input_tokens = get_num_tokens(phrase=prompt) + get_num_tokens(phrase=query)\n", | |
"input_cost = input_tokens * MISTRAL_INPUT_PRICE\n", | |
"\n", | |
"for _ in range(num_iterations):\n", | |
"    # Time only the API call, using a monotonic clock meant for measuring intervals\n", | |
"    start_time = time.perf_counter()\n", | |
"    response = mistral_client.chat(\n", | |
"        model=\"mistral-medium\",\n", | |
"        messages=[\n", | |
"            ChatMessage(role=\"system\", content=prompt),\n", | |
"            ChatMessage(role=\"user\", content=query)\n", | |
"        ]\n", | |
"    )\n", | |
"    end_time = time.perf_counter()\n", | |
"    # Update total execution time (excluding sleep time)\n", | |
"    total_time += end_time - start_time\n", | |
"    # Add this call's estimated cost to the running total\n", | |
"    rewritten_query = response.choices[0].message.content\n", | |
"    output_cost = get_num_tokens(phrase=rewritten_query) * MISTRAL_OUTPUT_PRICE\n", | |
"    total_cost += input_cost + output_cost\n", | |
"    # Wait 1 second before the next iteration\n", | |
"    time.sleep(1)\n", | |
"\n", | |
"# Average latency of a single call\n", | |
"avg_time = total_time / num_iterations\n", | |
"\n", | |
"print(f\"Took {avg_time} seconds to rewrite the query.\")" | |
], | |
"metadata": { | |
"id": "Z3ZQalMlUUb4" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Show the rewrite from the most recent response, keeping it in `rewritten_query`\n", | |
"print(rewritten_query := response.choices[0].message.content)" | |
], | |
"metadata": { | |
"id": "imAL6_eqUtds" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Report the accumulated cost in dollars, fixed to 8 decimal places\n", | |
"print(f\"Total cost of rewriting query: ${total_cost:.8f}\")" | |
], | |
"metadata": { | |
"id": "HFGVbymlhiuE" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Use Cohere to rewrite the query" | |
], | |
"metadata": { | |
"id": "DrgoAKInw-z8" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# %pip installs into the running kernel's environment (unlike a bare shell `!pip`).\n", | |
"# NOTE(review): the cells below read `response.search_queries` entries with `.get('text')`,\n", | |
"# i.e. as dicts - presumably the 4.x SDK behaviour; confirm before lifting this pin.\n", | |
"%pip install -q \"cohere<5\"" | |
], | |
"metadata": { | |
"id": "uzEp-k7Xxfla" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Read the Cohere API key from a hidden prompt and store it in the process environment.\n", | |
"# The prompt is labelled so it cannot be confused with the other providers' key prompts.\n", | |
"os.environ[\"COHERE_API_KEY\"] = getpass.getpass(\"Cohere API key: \")" | |
], | |
"metadata": { | |
"id": "V7X3rjrb4uAX" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import cohere\n", | |
"\n", | |
"# USD per token for the Cohere call, as assumed by this notebook\n", | |
"COHERE_INPUT_PRICE = 0.000001\n", | |
"COHERE_OUTPUT_PRICE = 0.000002\n", | |
"\n", | |
"# Get your cohere API key on: www.cohere.com\n", | |
"co = cohere.Client(os.environ[\"COHERE_API_KEY\"])\n", | |
"\n", | |
"total_time = 0.0\n", | |
"total_cost = 0.0\n", | |
"num_iterations = 10\n", | |
"response = None\n", | |
"\n", | |
"# Only the raw query is sent here (no system prompt), so it is the whole input.\n", | |
"# The value does not change between iterations.\n", | |
"input_cost = get_num_tokens(phrase=query) * COHERE_INPUT_PRICE\n", | |
"\n", | |
"for _ in range(num_iterations):\n", | |
"    # Time only the API call, using a monotonic clock meant for measuring intervals\n", | |
"    start_time = time.perf_counter()\n", | |
"    response = co.chat(\n", | |
"        message=query,\n", | |
"        search_queries_only=True,\n", | |
"    )\n", | |
"    end_time = time.perf_counter()\n", | |
"    # Update total execution time (excluding sleep time)\n", | |
"    total_time += end_time - start_time\n", | |
"    # The rewrites are the `text` of each returned search query. The loop variable\n", | |
"    # is deliberately not called `query`, to avoid shadowing the input above.\n", | |
"    rewritten_queries = [search_query.get('text', None) for search_query in response.search_queries]\n", | |
"    output_cost = get_num_tokens(phrase=str(rewritten_queries)) * COHERE_OUTPUT_PRICE\n", | |
"    # Accumulate across iterations, as the other provider cells do; plain assignment\n", | |
"    # would report the cost of the last call only.\n", | |
"    total_cost += input_cost + output_cost\n", | |
"    # Wait 1 second before the next iteration\n", | |
"    time.sleep(1)\n", | |
"\n", | |
"# Average latency of a single call\n", | |
"avg_time = total_time / num_iterations\n", | |
"\n", | |
"print(f\"Took {avg_time} seconds to rewrite the query.\")" | |
], | |
"metadata": { | |
"id": "Zc-w9vn_xAnY" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Collect the text of every search query in the most recent Cohere response and show the list\n", | |
"rewritten_queries = [search_query.get('text') for search_query in response.search_queries]\n", | |
"print(rewritten_queries)" | |
], | |
"metadata": { | |
"id": "zdpHKYZXBIx0" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Report the cost held in `total_cost`, in dollars, fixed to 8 decimal places\n", | |
"print(f\"Total cost of rewriting query: ${total_cost:.8f}\")" | |
], | |
"metadata": { | |
"id": "jv0eNVDiiYpt" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.12" | |
}, | |
"orig_nbformat": 4, | |
"colab": { | |
"provenance": [], | |
"gpuType": "T4", | |
"include_colab_link": true | |
}, | |
"accelerator": "GPU" | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment