Created
February 19, 2024 22:45
-
-
Save virattt/9099bf1a32ff2b99383b2fabba0ae763 to your computer and use it in GitHub Desktop.
query_expansion-cohere-gpt.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyO9f8+2MJPoRNjLkLbs2CcQ", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/virattt/9099bf1a32ff2b99383b2fabba0ae763/query_expansion-cohere-gpt.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "HRA6TPl0lXqq" | |
}, | |
"outputs": [], | |
"source": [ | |
"# %pip (unlike !pip) installs into the environment of the running kernel\n", | |
"%pip install -q cohere" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import getpass\n", | |
"import os\n", | |
"\n", | |
"# Set your Cohere API key. Only prompt when it is not already in the environment,\n", | |
"# and label the prompt: this notebook asks for two different keys.\n", | |
"if not os.environ.get(\"COHERE_API_KEY\"):\n", | |
"    os.environ[\"COHERE_API_KEY\"] = getpass.getpass(\"Cohere API key: \")" | |
], | |
"metadata": { | |
"id": "4fJZ05n9lflZ" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# The user question that every model below is asked to expand into search sub-queries\n", | |
"query = \"What are the revenues of Airbnb, Booking, and Expedia?\"" | |
], | |
"metadata": { | |
"id": "5UbW7pydl82v" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Use Cohere's `command` model for query expansion" | |
], | |
"metadata": { | |
"id": "025xZuDOlnT6" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import cohere\n", | |
"\n", | |
"# Get your cohere API key on: www.cohere.com\n", | |
"co = cohere.Client(os.environ[\"COHERE_API_KEY\"])\n", | |
"\n", | |
"# Call the model. `command` is named explicitly (as the heading says) rather than\n", | |
"# relying on the SDK/API default model, which can change between releases.\n", | |
"# search_queries_only=True asks Cohere for generated search queries, not a chat reply.\n", | |
"response = co.chat(\n", | |
"    model='command',\n", | |
"    message=query,\n", | |
"    search_queries_only=True,\n", | |
")" | |
], | |
"metadata": { | |
"id": "b8-Ua66vlsWJ" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Print the rewritten queries.\n", | |
"# cohere<5 returns each search query as a dict, cohere>=5 as an object with a `.text`\n", | |
"# attribute, so handle both. `search_queries` may be None when nothing was generated.\n", | |
"# `q` (not `query`) avoids visually shadowing the user question defined above.\n", | |
"rewritten_queries = [\n", | |
"    q.get('text') if isinstance(q, dict) else getattr(q, 'text', None)\n", | |
"    for q in (response.search_queries or [])\n", | |
"]\n", | |
"print(rewritten_queries)" | |
], | |
"metadata": { | |
"id": "QTXJByrxl3E_" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Use Cohere's `command-light` model for query expansion" | |
], | |
"metadata": { | |
"id": "oYnx5WtEmXz7" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Cohere API keys are issued at www.cohere.com\n", | |
"co = cohere.Client(api_key=os.environ[\"COHERE_API_KEY\"])\n", | |
"\n", | |
"# Same request as above, but served by the smaller `command-light` model;\n", | |
"# only the generated search queries are requested.\n", | |
"response = co.chat(\n", | |
"    message=query,\n", | |
"    search_queries_only=True,\n", | |
"    model='command-light',\n", | |
")" | |
], | |
"metadata": { | |
"id": "lNXcXI9ymJZa" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Print the rewritten queries.\n", | |
"# cohere<5 returns each search query as a dict, cohere>=5 as an object with a `.text`\n", | |
"# attribute, so handle both. `search_queries` may be None when nothing was generated.\n", | |
"# `q` (not `query`) avoids visually shadowing the user question defined above.\n", | |
"rewritten_queries = [\n", | |
"    q.get('text') if isinstance(q, dict) else getattr(q, 'text', None)\n", | |
"    for q in (response.search_queries or [])\n", | |
"]\n", | |
"print(rewritten_queries)" | |
], | |
"metadata": { | |
"id": "MaFJ6lrbmfmQ" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Use GPT-4 for query expansion" | |
], | |
"metadata": { | |
"id": "bPzoWQhVAmLt" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# The cells below use the v1 client (`from openai import OpenAI`). The version\n", | |
"# constraint forces an upgrade if an older 0.x openai package is already installed;\n", | |
"# %pip installs into the running kernel's environment.\n", | |
"%pip install -q \"openai>=1\"" | |
], | |
"metadata": { | |
"id": "D_LA58dWgVT3" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import getpass\n", | |
"import os\n", | |
"\n", | |
"# Set your OpenAI API key. Only prompt when it is not already in the environment,\n", | |
"# and label the prompt: this notebook asks for two different keys.\n", | |
"if not os.environ.get(\"OPENAI_API_KEY\"):\n", | |
"    os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API key: \")" | |
], | |
"metadata": { | |
"id": "WluIGzCFgXlb" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# System prompt for the OpenAI models: split the question into mutually exclusive,\n", | |
"# collectively exhaustive sub-queries and answer in JSON (JSON mode requires the\n", | |
"# word \"JSON\" to appear in the conversation).\n", | |
"prompt = \"\\n\".join([\n", | |
"    \"\",\n", | |
"    \"You are a helpful assistant that expands a user query into sub-queries.\",\n", | |
"    \"The sub-queries should be mutually exclusive and collectively exhaustive.\",\n", | |
"    \"Your response will be a JSON object with a `queries` field, which is a list of `query` objects.\",\n", | |
"    \"\",\n", | |
"])" | |
], | |
"metadata": { | |
"id": "9Q4iyvUZgmKO" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from openai import OpenAI\n", | |
"\n", | |
"client = OpenAI(api_key=os.environ[\"OPENAI_API_KEY\"])\n", | |
"\n", | |
"# Call the model. response_format enables JSON mode so the reply is a JSON object;\n", | |
"# temperature=0 makes the expansion as repeatable as the API allows.\n", | |
"response = client.chat.completions.create(\n", | |
"    model='gpt-4-0125-preview',\n", | |
"    response_format={\"type\": \"json_object\"},\n", | |
"    temperature=0,\n", | |
"    messages=[\n", | |
"        {\"role\": \"system\", \"content\": prompt},\n", | |
"        {\"role\": \"user\", \"content\": query},\n", | |
"    ]\n", | |
")" | |
], | |
"metadata": { | |
"id": "Z83h16UuMlMt" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import json\n", | |
"\n", | |
"# Print the rewritten query. Parse the JSON-mode reply so a malformed, truncated or\n", | |
"# missing reply is reported explicitly, and pretty-print it for readability.\n", | |
"rewritten_query = response.choices[0].message.content\n", | |
"try:\n", | |
"    print(json.dumps(json.loads(rewritten_query), indent=2))\n", | |
"except (json.JSONDecodeError, TypeError):\n", | |
"    print(\"Model did not return valid JSON:\", rewritten_query)" | |
], | |
"metadata": { | |
"id": "8VZMWffzm0-i" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Use GPT-3.5 for query expansion" | |
], | |
"metadata": { | |
"id": "AdrLmbzAAsgX" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Same JSON-mode request as the GPT-4 cell, reusing `client` and `prompt`,\n", | |
"# but answered by gpt-3.5-turbo\n", | |
"response = client.chat.completions.create(\n", | |
"    model='gpt-3.5-turbo-0125',\n", | |
"    messages=[\n", | |
"        {\"role\": \"system\", \"content\": prompt},\n", | |
"        {\"role\": \"user\", \"content\": query},\n", | |
"    ],\n", | |
"    temperature=0,\n", | |
"    response_format={\"type\": \"json_object\"},\n", | |
")" | |
], | |
"metadata": { | |
"id": "AdpynLvNAvww" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import json\n", | |
"\n", | |
"# Print the rewritten query. Parse the JSON-mode reply so a malformed, truncated or\n", | |
"# missing reply is reported explicitly, and pretty-print it for readability.\n", | |
"rewritten_query = response.choices[0].message.content\n", | |
"try:\n", | |
"    print(json.dumps(json.loads(rewritten_query), indent=2))\n", | |
"except (json.JSONDecodeError, TypeError):\n", | |
"    print(\"Model did not return valid JSON:\", rewritten_query)" | |
], | |
"metadata": { | |
"id": "yFKleKcWD84S" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment