Last active
July 15, 2023 17:40
-
-
Save kyo-takano/2802bedd32aad0e562065d9e0c566796 to your computer and use it in GitHub Desktop.
Iteratively generate, review, and improve a code snippet using OpenAI's `Completion` and `Embedding` APIs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" Iteratively generate, review, and improve a code snippet using OpenAI's `Completion` and `Embedding` APIs """ | |
import os | |
import openai | |
import textwrap | |
from scipy.spatial import distance | |
import matplotlib.pyplot as plt | |
from tqdm import trange | |
from IPython.display import display, Code, clear_output | |
# Configure the module-level OpenAI client from the environment
# (the value is None when OPENAI_API_KEY is unset).
openai.api_key = os.environ.get("OPENAI_API_KEY")
# Iteration budget, and the convergence criterion: the loop stops once two
# consecutive revisions' embeddings are closer than this cosine distance.
max_steps = 10
distance_threshold = 0.01

# Per-step record: generated code, its embedding vector, and the cosine
# distance to the previous step's embedding.
history = {key: [] for key in ("code", "embeddings", "distance")}

# Natural-language specification the generated code must implement.
query = (
    "Function to compute linear projection from a source matrix to a target matrix "
    "using least squares regression"
)
# --- Prompt templates ----------------------------------------------------------
# Templates are dedented *before* values are substituted. Dedenting after
# interpolation (i.e. `textwrap.dedent(f"""...""")`) silently does nothing as
# soon as an interpolated multi-line snippet has lines starting at column 0:
# the common indent becomes empty, the template keeps its source indentation,
# and only the snippet's first line gets shifted.

# Few-shot exemplars on README style markdown, used for the first step.
# Each description shares a line with its `**Response**:` marker.
# NOTE(review): exemplar (1) is named `logsumexp` but returns the plain product
# of x (valid only for positive values), and exemplar (2) passes `logx`/`logy`
# to `sns.lineplot`, which do not appear to be seaborn parameters. They are
# kept as-is; correct them deliberately if desired.
_FEW_SHOT_TEMPLATE = textwrap.dedent(
    """\
    Code descriptions as query and corresponding python objects written by OpenAI experts.
    ---
    **Description (1)**:
    Function to compute the cumulative product of values in the list x, using logprob **Response**:
    ```python
    import numpy as np
    def logsumexp(x):
        return np.exp(np.sum(np.log(x)))
    ```
    **Description (2)**:
    seaborn function to plot loss values in list h with both x and y displayed on a log scale **Response**:
    ```python
    import seaborn as sns
    def func(h):
        sns.lineplot(x=range(len(h)), y=h, logx=True, logy=True)
    ```
    **Description (3)**:
    {query} **Response**:
    ```python
    """
)

# Naive zero-shot fix **without** Chain-of-Thought, used for every later step.
_FIX_TEMPLATE = textwrap.dedent(
    """\
    The following code snippet should suffice the following specification: {query}
    ```python
    {code}
    ```
    After fixing potential issues, it can be rewritten as:
    ```python
    """
)


def _build_prompt(spec, previous_code=None):
    """Return the completion prompt for one refinement step.

    Args:
        spec: Natural-language description the code must implement.
        previous_code: Code produced by the previous step, or ``None`` on the
            first step.

    Returns:
        The few-shot generation prompt when ``previous_code`` is ``None``,
        otherwise the zero-shot "fix and rewrite" prompt. Both end with an
        open ```python fence, and the completion call stops at the closing
        fence, so the model's output is just the code.
    """
    if previous_code is None:
        return _FEW_SHOT_TEMPLATE.format(query=spec)
    return _FIX_TEMPLATE.format(query=spec, code=previous_code)


def _show_progress(snippet, distances):
    """Replace the notebook output with the latest snippet and a distance plot.

    Args:
        snippet: Code text generated at the current step.
        distances: Cosine distances between consecutive steps (one fewer
            entry than the number of steps run so far).
    """
    clear_output()
    display(Code(snippet.strip(), language="python"))
    # Distances start at step 1: step 0 has no predecessor to compare with.
    plt.plot(range(1, len(distances) + 1), distances, "-o")
    plt.ylabel("Cosine distance")
    plt.xlabel("Update steps")
    plt.show()


# --- Generate / review / improve loop -----------------------------------------
for counter in trange(max_steps):
    prompt = _build_prompt(query, history["code"][-1] if counter else None)

    # NOTE(review): `openai.Completion` / `openai.Embedding` and the
    # `text-davinci-003` model are the legacy (pre-1.0) `openai` SDK interface;
    # confirm the installed SDK version and that the model is still served.
    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=prompt,
        max_tokens=1024,
        temperature=0,
        stop="```",  # stop at the closing fence so only code is returned
    )
    code = response["choices"][0]["text"]
    history["code"].append(code)

    response = openai.Embedding.create(
        input=code,
        model="text-embedding-ada-002",
    )
    embeddings = response["data"][0]["embedding"]
    history["embeddings"].append(embeddings)

    # Reset every step so a value from an earlier step (or an earlier run of
    # this cell) can never trigger the convergence check on step 0.
    cosine_distance = None
    if counter:
        cosine_distance = distance.cosine(*history["embeddings"][-2:])
        history["distance"].append(cosine_distance)

    _show_progress(code, history["distance"])

    # Converged: the latest rewrite barely moved in embedding space.
    if cosine_distance is not None and cosine_distance < distance_threshold:
        break
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment