Created
January 17, 2023 19:27
-
-
Save jflam/60e0265ff225c8c07ad0656d2dbcd139 to your computer and use it in GitHub Desktop.
Prompt engineering with pandas and GPT-3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect | |
import pandas as pd | |
from langchain.prompts import PromptTemplate | |
from langchain.llms import OpenAI | |
from IPython.display import display, Markdown | |
# Shared LLM client: low temperature keeps answers mostly deterministic;
# max_tokens=1000 leaves room for multi-line code or Markdown responses.
llm = OpenAI(temperature=0.2, max_tokens=1000)

# Prompt that frames the dataframe's metadata (variable name, sample rows,
# dtypes, numeric summary) around the user's query.  `result_type` is filled
# in by the caller to steer the output format (code vs. Markdown).
prompt = PromptTemplate(
    input_variables=[
        "df_name",
        "values",
        "column_types",
        "numerical_stats",
        "result_type",
        "query"],
    template="""
The following information is about a dataframe referenced by {df_name}
local variable.
The first 3 rows of the dataframe are:
{values}
This is some information about the data types of the columns:
{column_types}
This is some statistical information about the numerical columns:
{numerical_stats}
Assume that pandas is already imported in the pd namespace.
{result_type}
Query:
{query}
"""
)
# This is a hack to get the name of the variable that the dataframe is. | |
# Adjust the number of f.back based on how many functions deep you are | |
def retrieve_name(var): | |
callers_local_vars = inspect.currentframe().f_back.f_back.f_back.f_locals.items() | |
return [var_name for var_name, var_val in callers_local_vars if var_val is var] | |
def print_result_as_code(result: str):
    """Render *result* in the notebook as a fenced Python code block."""
    fenced = f"""
```python
{result}
```
"""
    display(Markdown(fenced))
def print_result_as_markdown(result: str):
    """Render *result* in the notebook as plain Markdown."""
    rendered = Markdown(result)
    display(rendered)
class DeepWrangler:
    """Attach natural-language and code-generation queries to a DataFrame.

    Each query builds a prompt from the dataframe's caller-scope variable
    name, its first rows, column dtypes, and numeric summary statistics,
    then sends the prompt (plus the user's query) to the module-level LLM.
    """

    def __init__(self, dataframe):
        # The DataFrame this wrangler describes in its prompts.
        self.dataframe = dataframe

    def gen(self, query: str):
        """Generate Python code in response to the query."""
        result_type = """
Only generate Python code in your response to the query. Do not generate
any explanatory text. Comments within the code are fine.
"""
        result = self._query(query, result_type)
        print_result_as_code(result)

    def ask(self, query: str):
        """Ask a question about the dataframe in natural language."""
        result_type = """
Format your response to the query in Markdown.
"""
        result = self._query(query, result_type)
        print_result_as_markdown(result)

    def _query(self, query: str, result_type: str) -> str:
        """Build the dataframe-describing prompt and run it through the LLM.

        *result_type* is the format instruction injected into the prompt
        (code-only vs. Markdown).  Returns the raw LLM response text.
        """
        if not hasattr(self, 'local_var_name'):
            # Hacky: recover the caller-scope variable name of the dataframe.
            # Fall back to a generic name if the frame lookup finds nothing,
            # instead of raising IndexError on an empty list.
            names = retrieve_name(self.dataframe)
            self.local_var_name = names[0] if names else "df"
        # option_context restores the display option even if formatting
        # raises, unlike a bare set_option/reset_option pair.
        with pd.option_context("display.max_columns", None):
            values = str(self.dataframe.iloc[:3, :])
        df_name = self.local_var_name
        column_types = str(self.dataframe.dtypes)
        # describe() already covers the numeric columns; the original
        # self-select of all columns was redundant.
        numerical_stats = self.dataframe.describe()
        openai_query = prompt.format(
            df_name=df_name,
            values=values,
            column_types=column_types,
            numerical_stats=numerical_stats,
            result_type=result_type,
            query=query)
        return llm(openai_query)
# Monkeypatch the pandas DataFrame constructor so every newly constructed
# DataFrame carries a DeepWrangler under the `dw` attribute.
original_dataframe_init = pd.DataFrame.__init__

def new_dataframe_init(self, *args, **kwargs):
    """Run the real DataFrame.__init__, then attach a DeepWrangler as `dw`.

    If `dw` is already present (e.g. __init__ re-invoked on an existing
    object), leave it alone: raising from a patched constructor would
    break pandas internals, so the attach is skipped instead.
    """
    original_dataframe_init(self, *args, **kwargs)
    if not hasattr(self, 'dw'):
        self.dw = DeepWrangler(self)

pd.DataFrame.__init__ = new_dataframe_init
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment