Created
October 8, 2022 05:06
-
-
Save goodside/454169131c93ab5c57a9cdfac0754028 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Toy demonstration of chain-of-thought and consensus prompting using OpenAI API. | |
© Riley Goodside 2022 | |
""" | |
import os | |
import re | |
from statistics import mode | |
import openai | |
# Configure the OpenAI client from the environment, failing fast at import
# time with an actionable message when the key is missing.
try:
    openai.api_key = os.environ["OPENAI_API_KEY"]
except KeyError:
    # ``from None`` suppresses the implicit KeyError context, so the user sees
    # only this message instead of a chained "During handling of the above
    # exception, another exception occurred" traceback.
    raise RuntimeError(
        "Please set the OPENAI_API_KEY environment variable."
    ) from None
def complete(prompt: str, **kwargs):
    """
    Send one Completion request and return the first choice's text, stripped.

    Any keyword arguments are forwarded to ``openai.Completion.create``.
    ``engine`` defaults to "text-davinci-002"; a caller-supplied ``engine``
    (or any other key) takes precedence over the default.

    NOTE(review): ``openai.Completion`` is the legacy (pre-1.0) SDK interface;
    confirm the installed openai version and the engine are still available.
    """
    # Later keys win, so caller kwargs override the default engine.
    request_options = {"engine": "text-davinci-002", **kwargs}
    first_choice = openai.Completion.create(prompt=prompt, **request_options).choices[0]
    return first_choice.text.strip()
def calculate_step_by_step(question: str, max_retries=3) -> str:
    """
    Answer a math question via chain-of-thought prompting.

    Each attempt makes two completions: a sampled (temperature 0.5)
    "Let's think step by step" reasoning, then a deterministic (temperature 0)
    extraction prompt asking for the answer in Arabic numerals. The first
    number found in the extracted text is returned.

    Args:
        question: Question text; it is placed after "Q: " in the prompt, so it
            should not carry its own "Q:" prefix.
        max_retries: Extra attempts (each with fresh reasoning) to make when no
            number can be extracted; total attempts are ``max_retries + 1``.

    Returns:
        The extracted number as a string, e.g. "42" or "-3.5".

    Raises:
        RuntimeError: If no attempt yields a number.
    """
    prompt = f"Q: {question}\nA: Let's think step by step."
    attempts_left = max_retries
    while True:
        long_answer = complete(prompt, max_tokens=128, temperature=0.5)
        extraction_prompt = (
            f"{prompt} {long_answer}\nTherefore, the answer (Arabic numerals) is"
        )
        short_answer = complete(extraction_prompt, max_tokens=32, temperature=0)
        # Drop a trailing period and thousands separators, then keep only the
        # right-hand side of any "expression = value" the model wrote.
        cleaned = short_answer.strip().rstrip(".").replace(",", "")
        candidate = cleaned.split("=")[-1]
        # Require digits after a decimal point: text such as "42.\n\nQ: ..."
        # must yield "42", not "42.", or answer_by_consensus's mode() would
        # count the two spellings as different answers.
        match = re.search(r"-?\d+(?:\.\d+)?", candidate)
        if match:
            return match.group()
        if attempts_left <= 0:
            raise RuntimeError(f"Could not extract answer from '{candidate}'")
        attempts_left -= 1
def answer_by_consensus(question: str, n=10) -> str:
    """
    Sample ``n`` independent chain-of-thought answers and return the most
    common one (self-consistency / majority vote).

    Ties are resolved by ``statistics.mode``; with ``n == 0`` there is no data
    and ``mode`` raises ``StatisticsError``. Any ``RuntimeError`` from a single
    sample propagates unchanged.
    """
    samples = [calculate_step_by_step(question) for _ in range(n)]
    return mode(samples)
# Sample multi-hop question for the demo. It deliberately has no "Q: " prefix:
# calculate_step_by_step() prepends "Q: " to the prompt and the __main__ block
# prints "Q:" before it, so an embedded prefix would appear twice.
EXAMPLE_QUESTION = """\
What is x + 6 * y^5 where x is the sum of the squares of the individual digits of the
release year of Miley Cyrus's "Bangerz" and y is the day-of-month portion of Harry
Styles's birthday?\
"""
if __name__ == "__main__":
    # Echo the demo question, then print the majority-vote answer.
    demo_question = EXAMPLE_QUESTION
    print("Q:", demo_question)
    consensus_answer = answer_by_consensus(demo_question)
    print("A:", consensus_answer)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment