-
-
Save andrwj/cd073f7db8c0d7e66c36887e01606a4b to your computer and use it in GitHub Desktop.
A simple Python script that calls the ChatGPT API.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# | |
# Takes a chat transcript (for ChatGPT) on stdin, calls the OpenAI | |
# ChatGPT API, and prints the response on stdout. | |
# | |
# Your OpenAI API key must be set in the environment variable | |
# OPENAI_API_KEY. | |
# | |
# Logs are written to ~/chat.log. | |
import sys | |
import os | |
import re | |
import openai | |
import datetime | |
import logging | |
# Every request sent to, and every reply received from, the API is logged to
# this file. The leading "~" is expanded to the current user's home directory.
LOG_FILE = os.path.expanduser("~/chat.log")
# Read stdin and save into a string. | |
def read_stdin(): | |
return sys.stdin.read() | |
def set_openai_key():
    """Configure the openai module with the API key from $OPENAI_API_KEY.

    Raises:
        RuntimeError: if OPENAI_API_KEY is unset or empty. Failing here gives
            a direct, actionable message instead of a later, less obvious
            failure inside the API call.
    """
    key = os.getenv("OPENAI_API_KEY")
    if not key:
        raise RuntimeError("OPENAI_API_KEY environment variable is not set")
    openai.api_key = key
def parse_input(input):
    """Split a chat transcript into a list of OpenAI chat messages.

    The transcript is a sequence of sections. Each section starts with a marker
    line of the form ``%role%`` (e.g. ``%system%``, ``%user%``,
    ``%assistant%``). The lines after a marker, up to the next marker, become
    that message's content, with their line breaks preserved. Any text on the
    marker line after the closing ``%`` is kept as the first content line.
    Text before the first marker is ignored.

    Args:
        input: the whole transcript. (The name shadows the builtin but is kept
            so that keyword callers keep working.)

    Returns:
        A list of ``{"role": str, "content": str}`` dicts in transcript order.
        The list is empty if the transcript contains no marker line.
    """
    logging.debug(input)
    # [^%]+ stops at the first closing '%', so "%user% 50% off" yields the
    # role "user" rather than the greedy "user% 50".
    marker = re.compile(r"^%([^%]+)%(.*)$")

    def _message(role, body_lines):
        # Join with newlines so multi-line messages keep their structure, and
        # trim blank lines that only separate one section from the next.
        return {"role": role, "content": "\n".join(body_lines).strip("\n")}

    messages = []
    cur_role = None  # None until the first marker, so preamble text is dropped.
    body = []
    for line in input.splitlines():
        match = marker.match(line)
        if match is None:
            body.append(line)
            continue
        if cur_role is not None:
            messages.append(_message(cur_role, body))
        cur_role = match.group(1)
        rest = match.group(2).strip()
        body = [rest] if rest else []
    if cur_role is not None:
        messages.append(_message(cur_role, body))
    return messages
def get_response(messages, model="gpt-3.5-turbo"):
    """Send *messages* to the OpenAI chat completion API and return the reply text.

    Args:
        messages: list of ``{"role": ..., "content": ...}`` dicts, such as
            those produced by parse_input().
        model: chat model to query. The default keeps the model this script
            has always used.

    Returns:
        The content string of the first choice's message.
    """
    # NOTE(review): openai.ChatCompletion is the pre-1.0 openai-python
    # interface. It is not available in openai>=1.0, so either pin the
    # dependency or migrate to the client-based API.
    completion = openai.ChatCompletion.create(
        model=model,
        messages=messages,
    )
    logging.debug(completion)
    message = completion["choices"][0]["message"]
    reply = message["content"]
    logging.debug('==== OUTPUT⇟\n')
    logging.debug(reply)
    logging.debug(message["role"])
    return reply
def main():
    """Read a %role%-delimited transcript from stdin, query ChatGPT, and print the reply.

    Inputs and outputs are logged at DEBUG level to LOG_FILE. If the
    transcript contains no messages, exits with an explanatory error message.
    """
    # Use an explicit UTF-8 handler: the transcript and the '⇟' log markers
    # can be non-ASCII. basicConfig(filename=...) would open the file with the
    # platform default encoding (e.g. cp1252 on Windows), which may fail to
    # encode them.
    handler = logging.FileHandler(LOG_FILE, encoding="utf-8")
    logging.basicConfig(handlers=[handler], level=logging.DEBUG)
    # Timestamp each run so entries in the appended-to log can be told apart.
    logging.debug(str(datetime.datetime.now()))
    messages = parse_input(read_stdin())
    logging.debug('==== INPUT⇟\n')
    logging.debug(messages)
    if not messages:
        # An empty message list cannot produce a useful completion, so stop
        # here with a hint about the expected input format.
        sys.exit("error: no messages found; start each message with a "
                 "%role% line (e.g. %user%)")
    set_openai_key()
    reply = get_response(messages)
    print(reply)
# Run the CLI only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment