Last active
September 21, 2024 10:22
-
-
Save IzumiSatoshi/f72bc628652895c800c4c3ed0763dfba to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import subprocess | |
import openai | |
import re | |
from collections import defaultdict | |
import textwrap | |
import time | |
import shlex | |
# Read the OpenAI API key once at import time. The `with` block closes the
# file handle, and a plain strip() also removes "\r" / stray spaces that a
# bare strip("\n") would leave in the key (and which break authentication).
with open("./openai_key.txt", "r", encoding="utf-8") as _key_file:
    openai.api_key = _key_file.read().strip()
# Chat model passed to openai.ChatCompletion.create in call_gpt.
MODEL = "gpt-4"
# Number of trailing terminal lines returned to the model after each command.
TERMINAL_LINES = 20
# Path that write_messages overwrites with the full transcript before each call.
LOG_FILE = "./log.txt"
def _extract_command(action_block):
    """Return the shell command contained in an <action> block, without code fences.

    Accepts a multi-line fenced block (an opening fence line, optionally followed
    by a language tag such as "bash", then the command, then a closing fence), as
    well as an un-fenced or single-line form. The language tag must be dropped
    because the command is typed literally into tmux, where it would otherwise
    run as its own command.
    """
    text = action_block.strip()
    lines = text.split("\n")
    if len(lines) > 1 and lines[0].startswith("```"):
        # Drop the opening fence line (and any language tag on it).
        body = lines[1:]
        if body and body[-1].strip() == "```":
            body = body[:-1]
        return "\n".join(body).strip("\n")
    # Single-line or un-fenced form: same treatment as before.
    return text.strip("`").strip("\n")


def main():
    """Run the interactive chat loop.

    Reads instructions from stdin and sends the conversation to the model.
    Each command the model writes in an <action> block is run in the tmux
    session "MYSES", and the captured terminal output is fed back to the model.
    This repeats until the model ends its turn with a <reply> block. Replies
    and thoughts are printed side by side. End-of-input (Ctrl-D) or Ctrl-C at
    the USER prompt exits cleanly.
    """
    terminal = Terminal("MYSES")
    system_prompt = textwrap.dedent(
        """\
        You are an AI that must communicate using the following format:
        Record your thought process step-by-step within <thought></thought> blocks.
        Provide direct responses to the user in <reply></reply> blocks. Users will only see content within these blocks.
        Execute Linux commands to answer user instructions. Write the code within <action></action> blocks, using only one command at a time.
        For code, use the echo command and redirect (>) within inside <action></action>.
        Add the -y option or similar when executing commands to ensure automatic completion.
        Display the current terminal status in <terminal></terminal> after <action></action>.
        Example (User's instruction is "please create python program that print 'hoge'"):
        <thought>
        1. I should create python file.
        2. Then, I'll run it and check if it's working.
        </thought>
        <action>
        ```
        echo "print('hoge')" > hoge.py
        ```
        </action>
        <terminal>
        (mnist_env) 81809@n1:/mnt/disks/disk_main/projects/GPT/playground$ echo "print('hoge')" > hoge.py
        (mnist_env) 81809@n1:/mnt/disks/disk_main/projects/GPT/playground$
        </terminal>
        <thought>
        1. I should run it and check if it's working.
        </thought>
        <action>
        ```
        python hoge.py
        ```
        </action>
        <terminal>
        (mnist_env) 81809@n1:/mnt/disks/disk_main/projects/GPT/playground$ python hoge.py
        hoge
        (mnist_env) 81809@n1:/mnt/disks/disk_main/projects/GPT/playground$
        </terminal>
        <reply>
        The program I created worked correctly.
        </reply>
        ..... Repeat this thought/action/terminal/reply sequence as needed to address user instructions. Do not write outside of these blocks. Remember that you can execute code.
        """
    )
    messages = [
        {"role": "system", "content": system_prompt},
    ]
    col_widths = [50, 50]
    print_columns(
        [
            "[CHAT ROOM]",
            "[GPT's THOUGHT]",
        ],
        col_widths=col_widths,
    )
    while True:
        try:
            user_input = input("USER: ")
        except (EOFError, KeyboardInterrupt):
            # Leave the loop cleanly instead of dying with a traceback.
            print()
            break
        messages.append({"role": "user", "content": user_input})
        while True:
            # Persist the whole transcript to LOG_FILE before every API call.
            write_messages(messages, LOG_FILE)
            # Stop generation at "<terminal>": the real terminal output is
            # spliced in below by get_next_prompt instead of being invented.
            gpt_response = call_gpt(messages, stop="<terminal>")
            if gpt_response is None:
                # call_gpt gave up after its retries; parsing None would crash.
                print("No response from the model; please try again.")
                break
            gpt_response_dict, last_block = parse_response(gpt_response)
            assistant_reply = "\n".join(gpt_response_dict["reply"])
            print_columns(
                [
                    f"ASSISTANT: {assistant_reply}",
                    "\n".join(gpt_response_dict["thought"]),
                ],
                col_widths=col_widths,
            )
            if last_block == "reply":
                # A trailing <reply> ends the model's turn; wait for the user.
                messages.append({"role": "assistant", "content": gpt_response})
                break
            if gpt_response_dict["action"]:
                # Only the last <action> is executed (the prompt asks for one
                # command at a time).
                linux_command = _extract_command(gpt_response_dict["action"][-1])
                terminal_str = terminal.execute_command(linux_command)
                next_prompt = get_next_prompt(gpt_response, terminal_str)
                messages.append({"role": "assistant", "content": next_prompt})
            else:
                messages.append({"role": "assistant", "content": gpt_response})
# Matches <tag>...</tag> where the closing tag mirrors the opening one; DOTALL
# lets a block span multiple lines.
_TAGGED_BLOCK_RE = re.compile(r"<(\w+)>(.*?)</\1>", re.DOTALL)


def parse_response(text):
    """Split a model response into its tagged sections.

    Args:
        text: Raw response text containing blocks such as <thought>...</thought>.

    Returns:
        A pair ``(sections, last_tag)``. ``sections`` is a ``defaultdict(list)``
        mapping each tag name to the whitespace-stripped bodies of its blocks,
        in order of appearance. ``last_tag`` is the name of the final block
        found, or ``None`` when the text contains no tagged blocks.
    """
    sections = defaultdict(list)
    tag = None
    # After the loop, `tag` still holds the last matched tag name (or stays
    # None if nothing matched).
    for tag, body in _TAGGED_BLOCK_RE.findall(text):
        sections[tag].append(body.strip())
    return sections, tag
def call_gpt(messages, stop=None, max_retries=5, retry_delay=1.0):
    """Send the conversation to the chat model and return the reply text.

    Args:
        messages: Chat history in the ``[{"role": ..., "content": ...}, ...]``
            form passed to ``openai.ChatCompletion.create``.
        stop: Optional stop sequence(s) forwarded to the API.
        max_retries: Total number of attempts before giving up.
        retry_delay: Base delay in seconds between attempts. The delay doubles
            after each failure (1x, 2x, 4x, ...) so transient or rate-limit
            errors get time to clear instead of burning every retry at once.

    Returns:
        The content of the first choice, or ``None`` when every attempt failed.
    """
    for attempt in range(max_retries):
        try:
            print("calling api")
            completion = openai.ChatCompletion.create(
                model=MODEL,
                messages=messages,
                stop=stop,
                temperature=0,
            )
            return completion.choices[0].message.content
        except Exception as e:
            # Broad catch at the API boundary: every failure is reported and
            # retried; the caller sees None once the retries are exhausted.
            print(
                f"API error: {e}. Attempt {attempt + 1} of {max_retries}. Retrying..."
            )
            if attempt + 1 < max_retries:
                time.sleep(retry_delay * 2**attempt)
    print("Reached maximum retries. Aborting.")
    return None
def get_next_prompt(prev_response, terminal_str):
    """Append captured terminal output to the model's previous response.

    The result is the previous response, then a ``<terminal>`` block holding
    ``terminal_str``, joined with newlines and with no trailing newline. It is
    stored as the assistant message so the model sees the real command output.
    """
    # TODO: the original author flagged this ad-hoc transcript splicing as a hack.
    pieces = (f"{prev_response}", "<terminal>", f"{terminal_str}", "</terminal>")
    return "\n".join(pieces)
class Terminal:
    """Type commands into an existing tmux session and read back its screen.

    The session named ``session_name`` is not created here. It must already
    exist, otherwise tmux commands fail and raise
    ``subprocess.CalledProcessError``.
    """

    def __init__(self, session_name):
        self.session_name = session_name

    def execute_command(self, command, timeout=None):
        """Run ``command`` in the session and return the resulting screen.

        Args:
            command: Text typed into the pane literally, followed by Enter.
            timeout: Optional limit in seconds for the prompt to reappear.
                ``None`` (the default) waits indefinitely.

        Returns:
            The last ``TERMINAL_LINES`` lines of the captured pane.

        Raises:
            subprocess.CalledProcessError: if tmux rejects a command
                (e.g. no such session).
            TimeoutError: if ``timeout`` elapses before a prompt appears.
        """
        # -l sends the text literally, so words like "Enter" are not treated
        # as key names; an argv list keeps the shell out of the way.
        subprocess.run(
            ["tmux", "send-keys", "-l", "-t", self.session_name, command], check=True
        )
        subprocess.run(
            ["tmux", "send-keys", "-t", self.session_name, "Enter"], check=True
        )
        self.wait_for_execution(timeout=timeout)
        screen_lines = self.capture().split("\n")
        return "\n".join(screen_lines[-TERMINAL_LINES:])

    def capture(self):
        """Return the visible contents of the session's pane, whitespace-stripped.

        Raises:
            subprocess.CalledProcessError: if tmux cannot capture the pane.
        """
        # -p prints the pane to stdout instead of storing it in a paste buffer.
        # This leaves the user's tmux buffers untouched and avoids reading a
        # stale buffer when capture fails. The argv list means the session
        # name is never interpreted by a shell.
        result = subprocess.run(
            ["tmux", "capture-pane", "-p", "-t", self.session_name],
            capture_output=True,
            text=True,
            check=True,
        )
        # TODO: strip() may cause some errors (it also drops meaningful
        # leading/trailing whitespace of the screen).
        return result.stdout.strip()

    def extract_last_command_output(self, text):
        """Return the output lines between the last two prompt lines of ``text``.

        The final line of ``text`` must be a prompt (it ends with "$").
        Otherwise AssertionError is raised, explicitly, so the check survives
        ``python -O``. If only one earlier line contains "$ ", that line is
        treated as the command line. When none does, the first line is skipped;
        presumably it is a partially scrolled-off line.
        """
        lines = text.split("\n")
        if not lines[-1].endswith("$"):
            raise AssertionError("process haven't finished")
        second_to_last_command_line_idx = 0  # in the case big output
        for i, line in enumerate(lines[:-1]):
            if "$ " in line:
                second_to_last_command_line_idx = i
        return "\n".join(lines[second_to_last_command_line_idx + 1 : -1])

    @staticmethod
    def _is_prompt(screen):
        """Return True when the last line of ``screen`` ends with "$".

        NOTE(review): this heuristic only recognises "$"-terminated prompts;
        a root "#" prompt or an interactive program (e.g. a REPL) is never
        detected as finished.
        """
        return screen.split("\n")[-1].endswith("$")

    def wait_for_execution(self, timeout=None, poll_interval=1):
        """Block until the pane shows a shell prompt again.

        Args:
            timeout: Optional limit in seconds; ``None`` waits indefinitely.
            poll_interval: Seconds to sleep before each screen check.

        Raises:
            TimeoutError: if ``timeout`` elapses without a prompt appearing.
        """
        print("waiting execution...")
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            time.sleep(poll_interval)
            # An empty pane no longer raises IndexError; it simply isn't a prompt.
            if self._is_prompt(self.capture()):
                break
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    f"no shell prompt in tmux session {self.session_name!r} "
                    f"after {timeout} s"
                )
        print("execution done")
def write_messages(messages, file_path):
    """Overwrite ``file_path`` with a plain-text dump of the conversation.

    Each message is written as a ``###Role###`` header line (role capitalized),
    followed by its content and a newline.

    Args:
        messages: Sequence of dicts with "role" and "content" keys.
        file_path: Destination path; any existing file is replaced.
    """
    # join() instead of repeated += avoids quadratic string building.
    message_log = "".join(
        f"###{message['role'].capitalize()}###\n{message['content']}\n"
        for message in messages
    )
    # Explicit UTF-8: model output often contains non-ASCII text that the
    # platform default encoding (e.g. cp1252) cannot always encode.
    with open(file_path, "w", encoding="utf-8") as file:
        file.write(message_log)
def display_messages(messages):
    """Print each message as a capitalized role heading followed by its content.

    Every message is printed as ``Role:``, then the content, then a blank line.
    """
    for entry in messages:
        heading = entry["role"].capitalize()
        body = entry["content"]
        print(f"{heading}:\n{body}\n")
def wrap_text(text, width):
    """Hard-wrap ``text`` into pieces of at most ``width`` characters.

    Each newline-separated line is cut into consecutive ``width``-character
    chunks, with no word-boundary awareness. Blank lines inside the text are
    kept as empty strings so paragraph breaks survive; an empty ``text``
    yields an empty list.

    Raises:
        ValueError: if ``width`` is smaller than 1. A non-positive width would
            otherwise either fail obscurely (0) or silently drop all text
            (negative).
    """
    if width < 1:
        raise ValueError(f"width must be a positive integer, got {width!r}")
    if not text:
        return []
    wrapped_lines = []
    for line in text.split("\n"):
        if not line:
            # Preserve the blank line instead of letting it vanish.
            wrapped_lines.append("")
            continue
        wrapped_lines.extend(line[i : i + width] for i in range(0, len(line), width))
    return wrapped_lines
def print_columns(strings: list[str], col_widths: list[int]):
    """Print ``strings`` side by side as fixed-width columns separated by " | ".

    Each string is hard-wrapped to its column width with ``wrap_text``, and
    shorter columns are padded with blanks. ``col_widths[i]`` is the width of
    column ``i``; it must have at least as many entries as ``strings``.
    An empty ``strings`` list prints nothing (previously ``max([])`` raised
    ValueError).
    """
    columns = [wrap_text(s, col_widths[i]) for i, s in enumerate(strings)]
    max_lines = max((len(column) for column in columns), default=0)
    for i in range(max_lines):
        row = []
        for j, column in enumerate(columns):
            cell = column[i] if i < len(column) else ""
            row.append(f"{cell:<{col_widths[j]}}")
        print(" | ".join(row))
# Start the interactive loop only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment