@subham27-07
Last active November 1, 2024 01:06
nl4dv-llm-colab.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/subham27-07/1552357018963f7f86b04de0f20a2e17/nl4dv-llm-demo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"id": "0a5b63c9-be7e-476e-8b9c-dab1b099a572",
"metadata": {
"id": "0a5b63c9-be7e-476e-8b9c-dab1b099a572"
},
"source": [
"## Installation"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "8462090c-d1f1-40a0-95a0-e1266b063aa4",
"metadata": {
"id": "8462090c-d1f1-40a0-95a0-e1266b063aa4"
},
"outputs": [],
"source": [
"!pip install nl4dv &> /dev/null\n",
"!python -m nltk.downloader popular &> /dev/null\n",
"!python -m spacy download en_core_web_sm &> /dev/null\n",
"!pip install --upgrade notebook &> /dev/null\n",
"!jupyter nbextension install --sys-prefix --py vega &> /dev/null\n",
"!jupyter nbextension enable vega --py --sys-prefix &> /dev/null\n",
"!pip install altair &> /dev/null"
]
},
{
"cell_type": "markdown",
"id": "821b7294-6492-4ea0-8357-6efa479be442",
"metadata": {
"id": "821b7294-6492-4ea0-8357-6efa479be442"
},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c33a2aac-5ab0-4beb-91ff-b7ef76a64140",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "c33a2aac-5ab0-4beb-91ff-b7ef76a64140",
"outputId": "d6e6b281-e18b-47f3-f7de-0ec439cd9bcd"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/altair/vega/v5/__init__.py:18: AltairDeprecationWarning: The module altair.vega.v5 is deprecated and will be removed in Altair 5.\n",
" warnings.warn(\n"
]
}
],
"source": [
"import json\n",
"from nl4dv import NL4DV\n",
"import os\n",
"import altair as alt\n",
"from altair import vega, vegalite\n",
"vega.renderers.enable('colab')\n",
"vegalite.renderers.enable('colab')\n",
"from IPython.display import display"
]
},
{
"cell_type": "markdown",
"id": "6eadfbaf-f30b-4d9e-9bcb-46fa06ec3014",
"metadata": {
"id": "6eadfbaf-f30b-4d9e-9bcb-46fa06ec3014"
},
"source": [
"## Initializing NL4DV with a Cars Dataset"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b7505699-a137-497a-b1b2-e35dc66515e8",
"metadata": {
"id": "b7505699-a137-497a-b1b2-e35dc66515e8"
},
"outputs": [],
"source": [
"data_url=\"https://raw.githubusercontent.com/nl4dv/nl4dv/master/examples/assets/data/cars-w-year.csv\" # paste your data URL\n",
"processing_mode=\"gpt\" # Processing mode: \"gpt\" for the LLM-based mode or \"semantic-parsing\" for the rules-based mode.\n",
"gpt_api_key=\"[OpenAI KEY HERE]\" # paste your OpenAI API key\n",
"nl4dv_instance = NL4DV(data_url=data_url,\n",
" processing_mode=processing_mode,\n",
" gpt_api_key=gpt_api_key)"
]
},
{
"cell_type": "markdown",
"id": "ad40cfb6-b376-416c-b44a-e45680adc1e3",
"metadata": {
"id": "ad40cfb6-b376-416c-b44a-e45680adc1e3"
},
"source": [
"## Querying and Rendering the Most Relevant Visualization"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9a3971c2-3b13-4d48-9552-d1dbeb089aa3",
"metadata": {
"id": "9a3971c2-3b13-4d48-9552-d1dbeb089aa3"
},
"outputs": [],
"source": [
"response=nl4dv_instance.analyze_query(\"Weight of cars grouped by year\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "760a50fa-0072-4419-b254-21c5e74b4df4",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 281
},
"id": "760a50fa-0072-4419-b254-21c5e74b4df4",
"outputId": "894b3823-3829-4134-8911-a236a962524f"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
"<div id=\"altair-viz-aba3db16c046444381b1d64649dc73bd\"></div>\n",
"<script type=\"text/javascript\">\n",
" var VEGA_DEBUG = (typeof VEGA_DEBUG == \"undefined\") ? {} : VEGA_DEBUG;\n",
" (function(spec, embedOpt){\n",
" let outputDiv = document.currentScript.previousElementSibling;\n",
" if (outputDiv.id !== \"altair-viz-aba3db16c046444381b1d64649dc73bd\") {\n",
" outputDiv = document.getElementById(\"altair-viz-aba3db16c046444381b1d64649dc73bd\");\n",
" }\n",
" const paths = {\n",
" \"vega\": \"https://cdn.jsdelivr.net/npm//vega@5?noext\",\n",
" \"vega-lib\": \"https://cdn.jsdelivr.net/npm//vega-lib?noext\",\n",
" \"vega-lite\": \"https://cdn.jsdelivr.net/npm//vega-lite@4.17.0?noext\",\n",
" \"vega-embed\": \"https://cdn.jsdelivr.net/npm//vega-embed@6?noext\",\n",
" };\n",
"\n",
" function maybeLoadScript(lib, version) {\n",
" var key = `${lib.replace(\"-\", \"\")}_version`;\n",
" return (VEGA_DEBUG[key] == version) ?\n",
" Promise.resolve(paths[lib]) :\n",
" new Promise(function(resolve, reject) {\n",
" var s = document.createElement('script');\n",
" document.getElementsByTagName(\"head\")[0].appendChild(s);\n",
" s.async = true;\n",
" s.onload = () => {\n",
" VEGA_DEBUG[key] = version;\n",
" return resolve(paths[lib]);\n",
" };\n",
" s.onerror = () => reject(`Error loading script: ${paths[lib]}`);\n",
" s.src = paths[lib];\n",
" });\n",
" }\n",
"\n",
" function showError(err) {\n",
" outputDiv.innerHTML = `<div class=\"error\" style=\"color:red;\">${err}</div>`;\n",
" throw err;\n",
" }\n",
"\n",
" function displayChart(vegaEmbed) {\n",
" vegaEmbed(outputDiv, spec, embedOpt)\n",
" .catch(err => showError(`Javascript Error: ${err.message}<br>This usually means there's a typo in your chart specification. See the javascript console for the full traceback.`));\n",
" }\n",
"\n",
" if(typeof define === \"function\" && define.amd) {\n",
" requirejs.config({paths});\n",
" require([\"vega-embed\"], displayChart, err => showError(`Error loading script: ${err.message}`));\n",
" } else {\n",
" maybeLoadScript(\"vega\", \"5\")\n",
" .then(() => maybeLoadScript(\"vega-lite\", \"4.17.0\"))\n",
" .then(() => maybeLoadScript(\"vega-embed\", \"6\"))\n",
" .catch(showError)\n",
" .then(() => displayChart(vegaEmbed));\n",
" }\n",
" })({\"$schema\": \"https://vega.github.io/schema/vega-lite/v4.json\", \"data\": {\"url\": \"https://raw.githubusercontent.com/nl4dv/nl4dv/master/examples/assets/data/cars-w-year.csv\"}, \"mark\": \"bar\", \"encoding\": {\"x\": {\"field\": \"Year\", \"type\": \"ordinal\"}, \"y\": {\"aggregate\": \"mean\", \"field\": \"Weight\", \"type\": \"quantitative\"}}}, {\"mode\": \"vega-lite\"});\n",
"</script>"
]
},
"metadata": {}
}
],
"source": [
"display(alt.display.html_renderer(response['visList'][0]['vlSpec']), raw=True)"
]
},
{
"cell_type": "markdown",
"id": "69980448-0ca6-4bfc-9c53-d24430a0b42e",
"metadata": {
"id": "69980448-0ca6-4bfc-9c53-d24430a0b42e"
},
"source": [
"## Auto-detecting Follow-up Queries"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "2a04a11b-0123-416c-8caa-df4c14f889f1",
"metadata": {
"id": "2a04a11b-0123-416c-8caa-df4c14f889f1"
},
"outputs": [],
"source": [
"response=nl4dv_instance.analyze_query(\"Average weight of cars grouped by year and by origin\",dialog=True) # this query is automatically inferred as a follow-up query"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "49db0c16-9548-4416-b20f-8da0cfbdee4e",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 276
},
"id": "49db0c16-9548-4416-b20f-8da0cfbdee4e",
"outputId": "7c9c906e-babd-4f4e-9e41-be53bd62a4af"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
"<div id=\"altair-viz-3fa079d5841944558dc7998858145fff\"></div>\n",
"<script type=\"text/javascript\">\n",
" var VEGA_DEBUG = (typeof VEGA_DEBUG == \"undefined\") ? {} : VEGA_DEBUG;\n",
" (function(spec, embedOpt){\n",
" let outputDiv = document.currentScript.previousElementSibling;\n",
" if (outputDiv.id !== \"altair-viz-3fa079d5841944558dc7998858145fff\") {\n",
" outputDiv = document.getElementById(\"altair-viz-3fa079d5841944558dc7998858145fff\");\n",
" }\n",
" const paths = {\n",
" \"vega\": \"https://cdn.jsdelivr.net/npm//vega@5?noext\",\n",
" \"vega-lib\": \"https://cdn.jsdelivr.net/npm//vega-lib?noext\",\n",
" \"vega-lite\": \"https://cdn.jsdelivr.net/npm//vega-lite@4.17.0?noext\",\n",
" \"vega-embed\": \"https://cdn.jsdelivr.net/npm//vega-embed@6?noext\",\n",
" };\n",
"\n",
" function maybeLoadScript(lib, version) {\n",
" var key = `${lib.replace(\"-\", \"\")}_version`;\n",
" return (VEGA_DEBUG[key] == version) ?\n",
" Promise.resolve(paths[lib]) :\n",
" new Promise(function(resolve, reject) {\n",
" var s = document.createElement('script');\n",
" document.getElementsByTagName(\"head\")[0].appendChild(s);\n",
" s.async = true;\n",
" s.onload = () => {\n",
" VEGA_DEBUG[key] = version;\n",
" return resolve(paths[lib]);\n",
" };\n",
" s.onerror = () => reject(`Error loading script: ${paths[lib]}`);\n",
" s.src = paths[lib];\n",
" });\n",
" }\n",
"\n",
" function showError(err) {\n",
" outputDiv.innerHTML = `<div class=\"error\" style=\"color:red;\">${err}</div>`;\n",
" throw err;\n",
" }\n",
"\n",
" function displayChart(vegaEmbed) {\n",
" vegaEmbed(outputDiv, spec, embedOpt)\n",
" .catch(err => showError(`Javascript Error: ${err.message}<br>This usually means there's a typo in your chart specification. See the javascript console for the full traceback.`));\n",
" }\n",
"\n",
" if(typeof define === \"function\" && define.amd) {\n",
" requirejs.config({paths});\n",
" require([\"vega-embed\"], displayChart, err => showError(`Error loading script: ${err.message}`));\n",
" } else {\n",
" maybeLoadScript(\"vega\", \"5\")\n",
" .then(() => maybeLoadScript(\"vega-lite\", \"4.17.0\"))\n",
" .then(() => maybeLoadScript(\"vega-embed\", \"6\"))\n",
" .catch(showError)\n",
" .then(() => displayChart(vegaEmbed));\n",
" }\n",
" })({\"$schema\": \"https://vega.github.io/schema/vega-lite/v4.json\", \"data\": {\"url\": \"https://raw.githubusercontent.com/nl4dv/nl4dv/master/examples/assets/data/cars-w-year.csv\"}, \"mark\": \"bar\", \"encoding\": {\"x\": {\"field\": \"Year\", \"type\": \"ordinal\"}, \"y\": {\"aggregate\": \"mean\", \"field\": \"Weight\", \"type\": \"quantitative\"}, \"color\": {\"field\": \"Origin\", \"type\": \"nominal\"}}}, {\"mode\": \"vega-lite\"});\n",
"</script>"
]
},
"metadata": {}
}
],
"source": [
"display(alt.display.html_renderer(response['visList'][0]['vlSpec']), raw=True)"
]
},
{
"cell_type": "markdown",
"id": "6cb34901-976d-4967-8b16-0406b220bace",
"metadata": {
"id": "6cb34901-976d-4967-8b16-0406b220bace"
},
"source": [
"## New Standalone Query and Visualization"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "ebb03e9f-1c01-4c85-b4b0-407541bc5ee9",
"metadata": {
"id": "ebb03e9f-1c01-4c85-b4b0-407541bc5ee9"
},
"outputs": [],
"source": [
"response=nl4dv_instance.analyze_query(\"create a line graph of average weight by year\",dialog=True) # this query is automatically inferred as a new standalone query"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "da07a67f-65e7-494a-8c08-80b867d911f6",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 281
},
"id": "da07a67f-65e7-494a-8c08-80b867d911f6",
"outputId": "b48994f0-d359-44e8-cfb0-6ad89180a1b0"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
"<div id=\"altair-viz-40379eeb52954c839704063f98e3b8e6\"></div>\n",
"<script type=\"text/javascript\">\n",
" var VEGA_DEBUG = (typeof VEGA_DEBUG == \"undefined\") ? {} : VEGA_DEBUG;\n",
" (function(spec, embedOpt){\n",
" let outputDiv = document.currentScript.previousElementSibling;\n",
" if (outputDiv.id !== \"altair-viz-40379eeb52954c839704063f98e3b8e6\") {\n",
" outputDiv = document.getElementById(\"altair-viz-40379eeb52954c839704063f98e3b8e6\");\n",
" }\n",
" const paths = {\n",
" \"vega\": \"https://cdn.jsdelivr.net/npm//vega@5?noext\",\n",
" \"vega-lib\": \"https://cdn.jsdelivr.net/npm//vega-lib?noext\",\n",
" \"vega-lite\": \"https://cdn.jsdelivr.net/npm//vega-lite@4.17.0?noext\",\n",
" \"vega-embed\": \"https://cdn.jsdelivr.net/npm//vega-embed@6?noext\",\n",
" };\n",
"\n",
" function maybeLoadScript(lib, version) {\n",
" var key = `${lib.replace(\"-\", \"\")}_version`;\n",
" return (VEGA_DEBUG[key] == version) ?\n",
" Promise.resolve(paths[lib]) :\n",
" new Promise(function(resolve, reject) {\n",
" var s = document.createElement('script');\n",
" document.getElementsByTagName(\"head\")[0].appendChild(s);\n",
" s.async = true;\n",
" s.onload = () => {\n",
" VEGA_DEBUG[key] = version;\n",
" return resolve(paths[lib]);\n",
" };\n",
" s.onerror = () => reject(`Error loading script: ${paths[lib]}`);\n",
" s.src = paths[lib];\n",
" });\n",
" }\n",
"\n",
" function showError(err) {\n",
" outputDiv.innerHTML = `<div class=\"error\" style=\"color:red;\">${err}</div>`;\n",
" throw err;\n",
" }\n",
"\n",
" function displayChart(vegaEmbed) {\n",
" vegaEmbed(outputDiv, spec, embedOpt)\n",
" .catch(err => showError(`Javascript Error: ${err.message}<br>This usually means there's a typo in your chart specification. See the javascript console for the full traceback.`));\n",
" }\n",
"\n",
" if(typeof define === \"function\" && define.amd) {\n",
" requirejs.config({paths});\n",
" require([\"vega-embed\"], displayChart, err => showError(`Error loading script: ${err.message}`));\n",
" } else {\n",
" maybeLoadScript(\"vega\", \"5\")\n",
" .then(() => maybeLoadScript(\"vega-lite\", \"4.17.0\"))\n",
" .then(() => maybeLoadScript(\"vega-embed\", \"6\"))\n",
" .catch(showError)\n",
" .then(() => displayChart(vegaEmbed));\n",
" }\n",
" })({\"$schema\": \"https://vega.github.io/schema/vega-lite/v4.json\", \"data\": {\"url\": \"https://raw.githubusercontent.com/nl4dv/nl4dv/master/examples/assets/data/cars-w-year.csv\"}, \"mark\": \"line\", \"encoding\": {\"x\": {\"field\": \"Year\", \"type\": \"ordinal\"}, \"y\": {\"aggregate\": \"mean\", \"field\": \"Weight\", \"type\": \"quantitative\"}}}, {\"mode\": \"vega-lite\"});\n",
"</script>"
]
},
"metadata": {}
}
],
"source": [
"display(alt.display.html_renderer(response['visList'][0]['vlSpec']), raw=True)"
]
},
{
"cell_type": "markdown",
"id": "3ff55ee3-0120-4f8a-894d-cfcbcca1006a",
"metadata": {
"id": "3ff55ee3-0120-4f8a-894d-cfcbcca1006a"
},
"source": [
"## Following Up on a Previous Conversation"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "c57d0733-f1e1-4d1d-b84c-bab7aa21f423",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 276
},
"id": "c57d0733-f1e1-4d1d-b84c-bab7aa21f423",
"outputId": "cb435401-ea7d-478f-c42d-cedb110c36ae"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
"<div id=\"altair-viz-41d1fe1827a24d028dfdc6d634c7500e\"></div>\n",
"<script type=\"text/javascript\">\n",
" var VEGA_DEBUG = (typeof VEGA_DEBUG == \"undefined\") ? {} : VEGA_DEBUG;\n",
" (function(spec, embedOpt){\n",
" let outputDiv = document.currentScript.previousElementSibling;\n",
" if (outputDiv.id !== \"altair-viz-41d1fe1827a24d028dfdc6d634c7500e\") {\n",
" outputDiv = document.getElementById(\"altair-viz-41d1fe1827a24d028dfdc6d634c7500e\");\n",
" }\n",
" const paths = {\n",
" \"vega\": \"https://cdn.jsdelivr.net/npm//vega@5?noext\",\n",
" \"vega-lib\": \"https://cdn.jsdelivr.net/npm//vega-lib?noext\",\n",
" \"vega-lite\": \"https://cdn.jsdelivr.net/npm//vega-lite@4.17.0?noext\",\n",
" \"vega-embed\": \"https://cdn.jsdelivr.net/npm//vega-embed@6?noext\",\n",
" };\n",
"\n",
" function maybeLoadScript(lib, version) {\n",
" var key = `${lib.replace(\"-\", \"\")}_version`;\n",
" return (VEGA_DEBUG[key] == version) ?\n",
" Promise.resolve(paths[lib]) :\n",
" new Promise(function(resolve, reject) {\n",
" var s = document.createElement('script');\n",
" document.getElementsByTagName(\"head\")[0].appendChild(s);\n",
" s.async = true;\n",
" s.onload = () => {\n",
" VEGA_DEBUG[key] = version;\n",
" return resolve(paths[lib]);\n",
" };\n",
" s.onerror = () => reject(`Error loading script: ${paths[lib]}`);\n",
" s.src = paths[lib];\n",
" });\n",
" }\n",
"\n",
" function showError(err) {\n",
" outputDiv.innerHTML = `<div class=\"error\" style=\"color:red;\">${err}</div>`;\n",
" throw err;\n",
" }\n",
"\n",
" function displayChart(vegaEmbed) {\n",
" vegaEmbed(outputDiv, spec, embedOpt)\n",
" .catch(err => showError(`Javascript Error: ${err.message}<br>This usually means there's a typo in your chart specification. See the javascript console for the full traceback.`));\n",
" }\n",
"\n",
" if(typeof define === \"function\" && define.amd) {\n",
" requirejs.config({paths});\n",
" require([\"vega-embed\"], displayChart, err => showError(`Error loading script: ${err.message}`));\n",
" } else {\n",
" maybeLoadScript(\"vega\", \"5\")\n",
" .then(() => maybeLoadScript(\"vega-lite\", \"4.17.0\"))\n",
" .then(() => maybeLoadScript(\"vega-embed\", \"6\"))\n",
" .catch(showError)\n",
" .then(() => displayChart(vegaEmbed));\n",
" }\n",
" })({\"$schema\": \"https://vega.github.io/schema/vega-lite/v4.json\", \"data\": {\"url\": \"https://raw.githubusercontent.com/nl4dv/nl4dv/master/examples/assets/data/cars-w-year.csv\"}, \"mark\": \"bar\", \"encoding\": {\"x\": {\"field\": \"Year\", \"type\": \"ordinal\"}, \"y\": {\"aggregate\": \"mean\", \"field\": \"Weight\", \"type\": \"quantitative\"}, \"color\": {\"field\": \"Cylinders\", \"type\": \"nominal\"}}}, {\"mode\": \"vega-lite\"});\n",
"</script>"
]
},
"metadata": {}
}
],
"source": [
"response = nl4dv_instance.analyze_query(\"Average weight of cars grouped by year and by the number of cylinders.\", dialog=True, dialog_id=0, query_id=1)\n",
"display(alt.display.html_renderer(response['visList'][0]['vlSpec']), raw=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c93cfd94-37c6-4f99-876f-9df79e03d6d1",
"metadata": {
"id": "c93cfd94-37c6-4f99-876f-9df79e03d6d1"
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
},
"colab": {
"provenance": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}