GPT magic functions
import codecs
from IPython import get_ipython  # type: ignore
from IPython.core.magic import register_line_cell_magic
from IPython.display import clear_output, display, Markdown, update_display  # type: ignore
from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam
from openai.types.chat.completion_create_params import Function
import os
import re
import requests
import tiktoken
import time
from typing import Any, cast, Optional

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

run_description = """When you send a message containing Python code to run, it will be executed in a
stateful Jupyter notebook environment.
You can use this to execute any Python code and complete user requests.
Print, plot or display() anything you want the user to see in the notebook output.
Don't import any modules that have already been imported.
"""

functions: list[Function] = [
    {
        "name": "run",
        "description": run_description,
        "parameters": {
            "type": "object",
            "properties": {
                "code": {
                    "type": "string",
                    "description": "The Python code to run.",
                },
            },
            "required": ["code"],
        },
    },
]
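# For reference, a function call the model might return for this schema looks
# like the following once its streamed arguments are fully assembled (this
# sample payload is illustrative, not captured output):
#   {"name": "run", "arguments": "{\"code\": \"print('hello')\"}"}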
system_messages: list[ChatCompletionMessageParam] = [
    {
        "role": "system",
        "content": r"Use dollar signs instead of brackets and parentheses for inline or block math expressions. For example: $ \sigma $ or $$ \sqrt{x_i} $$",
    },
    # {
    #     "role": "system",
    #     "content": "Be sure to call run() to fulfill user requests if necessary.",
    # },
]
def get_globals() -> dict[str, Any]:
    """Return the interactive namespace of the running IPython kernel."""
    return get_ipython().user_ns  # type: ignore


def get_plain_text(output: dict[str, Any]) -> str:
    """Pick the most readable representation of a cell output, preferring
    text/plain, then text/markdown, then whatever comes first."""
    for item in output["items"]:
        if item["mime"] == "text/plain":
            return item["data"]
    for item in output["items"]:
        if item["mime"] == "text/markdown":
            return item["data"]
    return output["items"][0]["data"]


def get_source(content: str) -> str:
    """Extract Python source from the model's reply, whether it arrived as a
    fenced code block or as streamed function-call arguments."""
    # Normalize Polars' deprecated `with_column` to `with_columns`.
    content = content.replace("with_column(", "with_columns(")
    if is_code_block(content):
        # Strip the leading ```python\n and the trailing \n``` fence.
        return content[10:-4]
    # Otherwise pull the (possibly still incomplete) "code" value out of the
    # streamed JSON arguments and unescape it.
    match = re.search(r'"code":\s*"((?:[^"\\]|\\.)*)', content)
    if match:
        return codecs.decode(match.group(1), "unicode_escape")
    else:
        return content
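# Illustrative examples:
#   get_source('```python\nprint(1)\n```')  returns 'print(1)'
#   get_source('{"code": "print(1)\\n')     returns 'print(1)\n'
#   (the regex tolerates a still-incomplete JSON fragment mid-stream)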
@register_line_cell_magic
def gpt(
    line: str,
    _cell: Optional[str] = None,
    model="gpt-3.5-turbo",
    function_name: str = "gpt",
    max_tokens: int = 4096,
) -> None:
    # Fetch the notebook's cells from a companion local server so the full
    # notebook can be replayed as conversation history.
    notebook = requests.get("http://localhost:4903/notebook").json()
    cells = notebook["cells"]
    messages = system_messages.copy()
    for cell in cells:
        if cell["kind"] == 1:
            # Markdown cells become system context.
            messages.append(
                {
                    "role": "system",
                    "content": f"Markdown cell:\n{cell['document']['text']}",
                }
            )
            continue
        cell_input = cell["document"]["text"]
        cell_output = None
        if cell["outputs"]:
            cell_output = cell["outputs"][0]["items"][0]["data"]
        is_cell_magic = cell_input.startswith(f"%%{function_name}")
        is_line_magic = cell_input.startswith(f"%{function_name}")
        if is_cell_magic or is_line_magic:
            # Earlier magic invocations replay as user/assistant turns.
            user_message = cell_input.replace(
                f"%%{function_name}" if is_cell_magic else f"%{function_name}", ""
            ).strip()
            if user_message == (_cell or line).rstrip():
                # Reached the cell currently being executed; stop here.
                break
            messages.append({"role": "user", "content": user_message})
            if cell_output:
                messages.append({"role": "assistant", "content": cell_output.strip()})
        else:
            # Ordinary code cells (and their outputs) become system context.
            messages.append(
                {
                    "role": "system",
                    "content": f"```python\n{cell_input}\n```"
                    + "".join(
                        f"\nOutput:\n{get_plain_text(cell_output)}"
                        for cell_output in cell["outputs"]
                    ),
                }
            )
    messages.append({"role": "user", "content": (_cell or line).strip()})
    trim_messages(messages, max_tokens, model)
    clear_output()
    time.sleep(0.4)
    display_id = display(Markdown(""), display_id=True).display_id  # type: ignore
    content = ""
    last_update_time = time.time()
    debounce_interval = 0.1  # seconds
    chunks = client.chat.completions.create(
        messages=messages,
        model=model,
        stream=True,
        functions=functions,
        temperature=0,
    )
    is_function_call = False
    function_call_name = None
    for chunk in chunks:
        delta = chunk.choices[0].delta
        if delta.function_call:
            if not is_function_call:
                # Flush any plain text streamed so far to its own display
                # before switching to function-call rendering.
                if content:
                    update_display(
                        Markdown(content),
                        display_id=display_id,
                    )
                    display_id = display(Markdown(""), display_id=True).display_id  # type: ignore
                    last_update_time = time.time()
                    content = ""
                is_function_call = True
            if function_call_name is None:
                function_call_name = delta.function_call.name
            content += delta.function_call.arguments or ""
        else:
            content += delta.content or ""
        current_time = time.time()
        if current_time - last_update_time >= debounce_interval:
            update_display(
                Markdown(
                    f"```python\n{get_source(content)}\n```"
                    if is_function_call
                    else content
                ),
                display_id=display_id,
            )
            last_update_time = current_time
    if is_function_call or is_code_block(content):
        source = get_source(content)
        time.sleep(0.1)
        update_display(
            Markdown(f"```python\n{source}\n```"),
            display_id=display_id,
        )
        time.sleep(0.1)
        # Execute everything but the last line, then eval the last line so its
        # value can be displayed, mimicking normal notebook cell behavior.
        exec_source = "\n".join(source.splitlines()[:-1])
        eval_source = source.splitlines()[-1]
        try:
            exec(exec_source, get_globals())
        except SyntaxError:
            # The split broke a multi-line statement; run the whole source.
            return exec(source, get_globals())
        try:
            eval_result = eval(eval_source, get_globals())
            if eval_result is not None:
                display(eval_result)
        except SyntaxError:
            # The last line is a statement, not an expression.
            exec(eval_source, get_globals())
    else:
        update_display(Markdown(content), display_id=display_id)
@register_line_cell_magic
def gpt3(line: str, _cell: Optional[str] = None) -> None:
    return gpt(
        line, _cell, model="gpt-3.5-turbo-0613", function_name="gpt3", max_tokens=4096
    )


@register_line_cell_magic
def gpt4(line: str, _cell: Optional[str] = None) -> None:
    return gpt(
        line,
        _cell,
        model="gpt-4-1106-preview",
        function_name="gpt4",
        max_tokens=16384,  # May be up to 131072
    )
    # Previous default:
    # return gpt(line, _cell, model="gpt-4-0613", function_name="gpt4", max_tokens=8192)


@register_line_cell_magic
def gpt16(line: str, _cell: Optional[str] = None) -> None:
    return gpt(
        line,
        _cell,
        model="gpt-3.5-turbo-16k-0613",
        function_name="gpt16",
        max_tokens=16384,
    )
def is_code_block(content: str) -> bool:
    return content.startswith("```python\n") and content.endswith("\n```")


def trim_messages(
    messages: list[ChatCompletionMessageParam], max_tokens: int, model: str
) -> None:
    """Drop the oldest non-system messages in place until the conversation
    fits within max_tokens, reserving 1,000 tokens for the model's reply."""
    messages_len = len(messages)
    while num_tokens_from_messages(messages, model=model) > max_tokens - 1_000:
        messages.pop(len(system_messages))
    if len(messages) < messages_len:
        messages.insert(
            len(system_messages),
            {
                "role": "system",
                "content": "Note: Some older messages have been removed for brevity.",
            },
        )


def num_tokens_from_messages(
    messages: list[ChatCompletionMessageParam], model: str
) -> int:
    """
    Return the number of tokens used by a list of messages.
    Source:
    https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    if model in {
        "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-16k-0613",
        "gpt-4-0314",
        "gpt-4-32k-0314",
        "gpt-4-0613",
        "gpt-4-32k-0613",
        "gpt-4-1106-preview",
    }:
        tokens_per_message = 3
        tokens_per_name = 1
    elif model == "gpt-3.5-turbo-0301":
        tokens_per_message = (
            4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
        )
        tokens_per_name = -1  # if there's a name, the role is omitted
    elif "gpt-3.5-turbo" in model:
        print(
            "Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613."
        )
        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
    elif "gpt-4" in model:
        print(
            "Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613."
        )
        return num_tokens_from_messages(messages, model="gpt-4-0613")
    else:
        raise NotImplementedError(
            f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
        )
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(cast(str, value)))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens
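A minimal usage sketch (cell contents are hypothetical, and this assumes the file above has been executed in a notebook so the magics are registered):

    %gpt4 Explain the difference between exec() and eval() in one paragraph.

    %%gpt4
    Generate a list named squares containing the squares of 1 through 10,
    then display it.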
Updated the gist today to support function calling 😁
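For context, a sketch of the function-calling round trip (the payload below is illustrative, not captured output). A request like

    %gpt4 Compute 2**32.

can lead the model to reply with a run() function call whose streamed arguments accumulate into JSON such as

    {"code": "print(2**32)"}

get_source() then extracts and unescapes the "code" value, the magic renders it as a Python block, and the code is executed in the notebook namespace.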
More updates
Updated the gist today to include logic for handling token overflow by removing older messages.
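A quick sketch of the trimming behavior (message names are schematic; the budget numbers are the defaults above). With max_tokens=4096, trim_messages() pops the oldest non-system message until the conversation fits in 4096 - 1000 = 3096 tokens, then inserts a single notice so the model knows history was dropped:

    messages = [system, user_1, assistant_1, ..., user_n]
    trim_messages(messages, max_tokens=4096, model="gpt-3.5-turbo-0613")
    # messages is now [system, notice, ..., user_n], where notice reads
    # "Note: Some older messages have been removed for brevity."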