Skip to content

Instantly share code, notes, and snippets.

@IzumiSatoshi
Last active September 21, 2024 10:22
Show Gist options
  • Save IzumiSatoshi/f72bc628652895c800c4c3ed0763dfba to your computer and use it in GitHub Desktop.
Save IzumiSatoshi/f72bc628652895c800c4c3ed0763dfba to your computer and use it in GitHub Desktop.
import subprocess
import openai
import re
from collections import defaultdict
import textwrap
import time
import shlex
# Read the OpenAI API key from a local file.  The context manager closes the
# file handle deterministically instead of leaving it to garbage collection.
with open("./openai_key.txt", "r") as key_file:
    openai.api_key = key_file.read().strip("\n")

# Chat model requested from the OpenAI API.
MODEL = "gpt-4"
# Number of trailing terminal rows returned after each executed command.
TERMINAL_LINES = 20
# File the whole conversation is written to before every API call.
LOG_FILE = "./log.txt"
def main():
    """Run the interactive chat loop that lets the model drive a tmux shell.

    For every line typed by the user the model is queried repeatedly:
    each response is parsed into ``<thought>``/``<action>``/``<reply>``
    blocks, the latest action (if any) is executed in the tmux session and
    the resulting terminal state is appended to the conversation.  The inner
    loop ends when the last block of a response is a ``<reply>``, or when
    the API call ultimately fails.
    """
    terminal = Terminal("MYSES")
    system_prompt = textwrap.dedent(
        """\
        You are an AI that must communicate using the following format:
        Record your thought process step-by-step within <thought></thought> blocks.
        Provide direct responses to the user in <reply></reply> blocks. Users will only see content within these blocks.
        Execute Linux commands to answer user instructions. Write the code within <action></action> blocks, using only one command at a time.
        For code, use the echo command and redirect (>) within inside <action></action>.
        Add the -y option or similar when executing commands to ensure automatic completion.
        Display the current terminal status in <terminal></terminal> after <action></action>.
        Example (User's instruction is "please create python program that print 'hoge'"):
        <thought>
        1. I should create python file.
        2. Then, I'll run it and check if it's working.
        </thought>
        <action>
        ```
        echo "print('hoge')" > hoge.py
        ```
        </action>
        <terminal>
        (mnist_env) 81809@n1:/mnt/disks/disk_main/projects/GPT/playground$ echo "print('hoge')" > hoge.py
        (mnist_env) 81809@n1:/mnt/disks/disk_main/projects/GPT/playground$
        </terminal>
        <thought>
        1. I should run it and check if it's working.
        </thought>
        <action>
        ```
        python hoge.py
        ```
        </action>
        <terminal>
        (mnist_env) 81809@n1:/mnt/disks/disk_main/projects/GPT/playground$ python hoge.py
        hoge
        (mnist_env) 81809@n1:/mnt/disks/disk_main/projects/GPT/playground$
        </terminal>
        <reply>
        The program I created worked correctly.
        </reply>
        ..... Repeat this thought/action/terminal/reply sequence as needed to address user instructions. Do not write outside of these blocks. Remember that you can execute code.
        """
    )
    messages = [
        {"role": "system", "content": system_prompt},
    ]
    col_widths = [50, 50]
    print_columns(
        [
            "[CHAT ROOM]",
            "[GPT's THOUGHT]",
        ],
        col_widths=col_widths,
    )
    while True:
        user_input = input("USER: ")
        messages.append({"role": "user", "content": user_input})
        while True:
            write_messages(messages, LOG_FILE)
            # Generation stops right before the model would invent terminal
            # output; the real output is appended below instead.
            gpt_response = call_gpt(messages, stop="<terminal>")
            if gpt_response is None:
                # call_gpt gave up after its retries.  Parsing None would
                # raise TypeError, so report the failure and hand control
                # back to the user.
                print_columns(
                    ["ASSISTANT: (API call failed, no response received)", ""],
                    col_widths=col_widths,
                )
                break
            gpt_response_dict, last_block = parse_response(gpt_response)
            assistant_reply = "\n".join(gpt_response_dict["reply"])
            print_columns(
                [
                    f"ASSISTANT: {assistant_reply}",
                    "\n".join(gpt_response_dict["thought"]),
                ],
                col_widths=col_widths,
            )
            if last_block == "reply":
                # A trailing <reply> ends the turn.
                messages.append({"role": "assistant", "content": gpt_response})
                break
            if gpt_response_dict["action"]:
                # Only the most recent action is run; the surrounding code
                # fence and newlines are removed first.
                linux_command = gpt_response_dict["action"][-1].strip("`").strip("\n")
                terminal_str = terminal.execute_command(linux_command)
                next_prompt = get_next_prompt(gpt_response, terminal_str)
                messages.append({"role": "assistant", "content": next_prompt})
            else:
                messages.append({"role": "assistant", "content": gpt_response})
def parse_response(text):
    """Split a model response into its ``<tag>...</tag>`` blocks.

    Returns a tuple ``(blocks, final_tag)``.  ``blocks`` is a
    ``defaultdict(list)`` mapping each tag name to the stripped contents of
    every block with that tag, in order of appearance.  ``final_tag`` is the
    tag name of the last block found, or ``None`` when there are no blocks.
    """
    blocks = defaultdict(list)
    final_tag = None
    # DOTALL lets a block body span several lines; the back-reference makes
    # sure the closing tag matches the opening one.
    for found in re.finditer(r"<(\w+)>(.*?)</\1>", text, flags=re.DOTALL):
        final_tag = found.group(1)
        blocks[final_tag].append(found.group(2).strip())
    return blocks, final_tag
def call_gpt(messages, stop=None, max_retries=5, retry_delay=1.0):
    """Request a chat completion, retrying with backoff on failure.

    Args:
        messages: Conversation so far, as a list of role/content dicts.
        stop: Optional stop sequence(s) forwarded to the API.
        max_retries: Maximum number of attempts before giving up.
        retry_delay: Seconds to wait after the first failed attempt; the
            wait doubles after every further failure.

    Returns:
        The content of the first choice, or ``None`` when every attempt
        failed.
    """
    delay = retry_delay
    for attempt in range(max_retries):
        try:
            print("calling api")
            completion = openai.ChatCompletion.create(
                model=MODEL,
                messages=messages,
                stop=stop,
                temperature=0,
            )
            return completion.choices[0].message.content
        except Exception as e:
            # Broad on purpose: this is the retry boundary for any API or
            # transport failure.
            is_last_attempt = attempt + 1 >= max_retries
            suffix = "" if is_last_attempt else " Retrying..."
            print(f"API error: {e}. Attempt {attempt + 1} of {max_retries}.{suffix}")
            if not is_last_attempt:
                # Back off before retrying so transient failures such as
                # rate limits have time to clear.
                time.sleep(delay)
                delay *= 2
    print("Reached maximum retries. Aborting.")
    return None
def get_next_prompt(prev_response, terminal_str):
    """Append the terminal state, wrapped in a ``<terminal>`` block, to a response.

    The result is ``prev_response``, a newline, and then ``terminal_str``
    enclosed between ``<terminal>`` and ``</terminal>`` lines, with no
    trailing newline.
    """
    return f"{prev_response}\n<terminal>\n{terminal_str}\n</terminal>"
class Terminal:
    """Drive a shell that runs inside an existing tmux session.

    Commands are typed into the session with ``tmux send-keys`` and the pane
    contents are read back with ``tmux capture-pane``.  A command counts as
    finished once the last captured line ends with ``$``.
    """

    def __init__(self, session_name):
        """Store *session_name*, the tmux target used by every tmux call."""
        self.session_name = session_name

    def execute_command(self, command):
        """Type *command* into the session, wait for it, and return the screen.

        Returns:
            The last ``TERMINAL_LINES`` rows of the captured pane, joined
            with newlines.
        """
        # -l sends the text literally so nothing in the command is looked up
        # as a tmux key name.
        subprocess.run(["tmux", "send-keys", "-lt", self.session_name, command])
        # Press Enter separately to actually run the command.
        subprocess.run(["tmux", "send-keys", "-t", self.session_name, "Enter"])
        self.wait_for_execution()
        captured_rows = self.capture().split("\n")
        return "\n".join(captured_rows[-TERMINAL_LINES:])

    def capture(self):
        """Return the current pane contents with surrounding whitespace removed."""
        # -p prints the pane to stdout, so no paste buffer round-trip is
        # needed.  The argument list avoids shell parsing of the session name.
        result = subprocess.run(
            ["tmux", "capture-pane", "-p", "-t", self.session_name],
            capture_output=True,
            text=True,
        )
        # TODO: strip() may cause some errors
        return result.stdout.strip()

    def extract_last_command_output(self, text):
        """Return the lines between the last command line and the final prompt.

        Args:
            text: Captured terminal text whose last line is a shell prompt
                ending in ``$``.

        Raises:
            AssertionError: If the last line does not end with ``$``.
        """
        lines = text.split("\n")
        # Raised explicitly so the check survives ``python -O`` and an empty
        # last line is reported instead of raising IndexError.
        if not lines[-1].endswith("$"):
            raise AssertionError("process haven't finished")
        # Index of the last line (excluding the final prompt) that contains
        # "$ "; stays 0 when no such line is present.
        last_command_line_idx = 0
        for i, line in enumerate(lines[:-1]):
            if "$ " in line:
                last_command_line_idx = i
        return "\n".join(lines[last_command_line_idx + 1 : -1])

    def wait_for_execution(self):
        """Poll the pane once per second until its last line ends with ``$``."""
        print("waiting execution...")
        while True:
            time.sleep(1)
            # endswith() tolerates an empty capture, unlike indexing [-1][-1].
            if self.capture().split("\n")[-1].endswith("$"):
                break
        print("execution done")
def write_messages(messages, file_path):
    """Overwrite *file_path* with a plain-text dump of *messages*.

    Each message becomes a ``###Role###`` header line (role capitalized)
    followed by its content and a newline.
    """
    entries = [
        f"###{msg['role'].capitalize()}###\n{msg['content']}\n" for msg in messages
    ]
    with open(file_path, "w") as log:
        log.write("".join(entries))
def display_messages(messages):
    """Print every message as a capitalized role header followed by its content."""
    for entry in messages:
        speaker, body = entry["role"], entry["content"]
        print(f"{speaker.capitalize()}:\n{body}\n")
def wrap_text(text, width):
    """Hard-wrap *text* into chunks of at most *width* characters.

    The text is split on newlines and each line is cut into consecutive
    slices of *width* characters.  Empty lines yield no chunk at all.
    """
    return [
        line[start : start + width]
        for line in text.split("\n")
        for start in range(0, len(line), width)
    ]
def print_columns(strings: list[str], col_widths: list[int]):
    """Print *strings* side by side as fixed-width columns.

    Each string is wrapped to its column's width with ``wrap_text``; rows
    are padded with blanks where a column has run out of lines, and cells
    are separated by ``" | "``.

    Args:
        strings: One text per column.
        col_widths: Width of each column, indexed like *strings*.
    """
    columns = [wrap_text(s, col_widths[i]) for i, s in enumerate(strings)]
    # default=0 keeps an empty column list from raising ValueError in max().
    row_count = max((len(column) for column in columns), default=0)
    for row_idx in range(row_count):
        cells = []
        for col_idx, column in enumerate(columns):
            cell = column[row_idx] if row_idx < len(column) else ""
            # Left-align and pad every cell to its column width.
            cells.append(f"{cell:<{col_widths[col_idx]}}")
        print(" | ".join(cells))
# Start the interactive loop only when run as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment