Created
February 2, 2024 00:14
-
-
Save virattt/d6943f0630ee90afbb0b644890d2050f to your computer and use it in GitHub Desktop.
query-rewriting-gpt.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/virattt/d6943f0630ee90afbb0b644890d2050f/query-rewriting-gpt.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Install dependencies" | |
], | |
"metadata": { | |
"id": "S2mGQxA958dW" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# The v1 client API used below (`from openai import OpenAI`) requires openai>=1.0.\n", | |
"# %pip (unlike !pip) installs into the environment of the running kernel.\n", | |
"%pip install -q \"openai>=1.0\"" | |
], | |
"metadata": { | |
"id": "2bY0NapN_z98" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import getpass\n", | |
"import os\n", | |
"\n", | |
"# Ask for the OpenAI API key (input is hidden) only if it is not already set,\n", | |
"# so re-running the notebook does not block on the prompt again.\n", | |
"if not os.environ.get(\"OPENAI_API_KEY\"):\n", | |
"    os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"OpenAI API key: \")" | |
], | |
"metadata": { | |
"id": "tavToGb_MJrc" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Create Query" | |
], | |
"metadata": { | |
"id": "m8HqBNyYrDHb" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Example user queries; change the index to compare how each one is rewritten.\n", | |
"EXAMPLE_QUERIES = [\n", | |
"    \"What's going on with Airbnb's numbers?\",\n", | |
"    \"How did the market react to Airbnb's announcement?\",\n", | |
"]\n", | |
"query = EXAMPLE_QUERIES[0]" | |
], | |
"metadata": { | |
"id": "3qZTrAtXLPl1" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Use GPT-4 to rewrite the query" | |
], | |
"metadata": { | |
"id": "bPzoWQhVAmLt" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from openai import OpenAI\n", | |
"import time\n", | |
"import json\n", | |
"\n", | |
"client = OpenAI(api_key=os.environ[\"OPENAI_API_KEY\"])\n", | |
"\n", | |
"# System prompt; the GPT-3.5 comparison below reuses it so both models get the same instructions.\n", | |
"prompt = \"\"\"\n", | |
"Rewrite the following user query into a clear, specific, and\n", | |
"formal request suitable for retrieving relevant information from a vector database.\n", | |
"Keep in mind that your rewritten query will be sent to a vector database, which\n", | |
"does similarity search for retrieving documents.\n", | |
"\"\"\"\n", | |
"\n", | |
"# Time only the API call (not client setup); perf_counter is monotonic and meant for intervals.\n", | |
"start = time.perf_counter()\n", | |
"response = client.chat.completions.create(\n", | |
"    model='gpt-4-0125-preview',\n", | |
"    temperature=0,  # reduce sampling randomness so model comparisons are fairer\n", | |
"    messages=[\n", | |
"        {\"role\": \"system\", \"content\": prompt},\n", | |
"        {\"role\": \"user\", \"content\": query},\n", | |
"    ],\n", | |
")\n", | |
"\n", | |
"print(f\"Took {time.perf_counter() - start:.2f} seconds to rewrite query with GPT-4.\")" | |
], | |
"metadata": { | |
"id": "Z83h16UuMlMt" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Show GPT-4's rewrite: the message text of the first completion choice.\n", | |
"print(rewritten_query := response.choices[0].message.content)" | |
], | |
"metadata": { | |
"id": "8VZMWffzm0-i" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Use GPT-3.5 to rewrite the query" | |
], | |
"metadata": { | |
"id": "AdrLmbzAAsgX" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Same request as the GPT-4 cell (reuses its `client`, `prompt` and `time` import), different model.\n", | |
"start = time.perf_counter()\n", | |
"response = client.chat.completions.create(\n", | |
"    model='gpt-3.5-turbo-0125',\n", | |
"    temperature=0,\n", | |
"    messages=[\n", | |
"        {\"role\": \"system\", \"content\": prompt},\n", | |
"        {\"role\": \"user\", \"content\": query},\n", | |
"    ],\n", | |
")\n", | |
"\n", | |
"print(f\"Took {time.perf_counter() - start:.2f} seconds to rewrite query with GPT-3.5.\")" | |
], | |
"metadata": { | |
"id": "AdpynLvNAvww" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Show GPT-3.5's rewrite: the message text of the first completion choice.\n", | |
"print(rewritten_query := response.choices[0].message.content)" | |
], | |
"metadata": { | |
"id": "yFKleKcWD84S" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.12" | |
}, | |
"orig_nbformat": 4, | |
"colab": { | |
"provenance": [], | |
"gpuType": "T4", | |
"include_colab_link": true | |
}, | |
"accelerator": "GPU" | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment