Created
December 27, 2024 20:11
-
-
Save dearmadisonblue/f73da23358a6571eab2d36a895cbac25 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import sys
from typing import Optional
from xml.sax.saxutils import escape

import openai
import tiktoken
from bs4 import BeautifulSoup
from loguru import logger
# Remove loguru's default (stderr) handler and log to standard output instead.
# Pass sys.stdout directly: loguru already appends "\n" to string formats, so a
# `print(msg)` sink would emit an extra blank line after every record.
logger.remove()
logger.add(
    sys.stdout,
    format="{time:YYYY-MM-DD HH:mm:ss} {level} {message}",
)
class Model:
    """Thin wrapper around the OpenAI chat-completions endpoint.

    Each call sends the fixed system prompt plus one user prompt and returns
    the assistant's text reply. Approximate token counts for the request and
    the reply are logged.
    """

    def __init__(self, api_key: str, system_prompt: str, model_name: str = "gpt-4o"):
        """
        Args:
            api_key: OpenAI API key. It is assigned to ``openai.api_key``,
                i.e. module-level state shared by every user of ``openai``.
            system_prompt: Sent as the system message on every call.
            model_name: Chat model to request; also selects the tokenizer
                used by ``count_tokens``.
        """
        openai.api_key = api_key
        self.system_prompt = system_prompt
        self.model_name = model_name
        try:
            self.encoding = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # tiktoken cannot map every model name; fall back to the encoding
            # of the gpt-4o family so counting still works.
            self.encoding = tiktoken.get_encoding("o200k_base")

    def __call__(self, user_prompt: str) -> str:
        """Send ``user_prompt`` (with the system prompt) and return the reply text.

        Returns:
            The assistant message content, or ``""`` if the API returned no
            text content.
        """
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        num_tokens_sent = self.count_tokens(messages)
        logger.info(f"Sending {num_tokens_sent} tokens to OpenAI API: {messages}")
        response = openai.chat.completions.create(
            model=self.model_name,
            messages=messages,
        )
        # `content` may be None when the model produces no text; normalise to
        # "" so count_tokens does not raise a misleading ValueError (which the
        # filter combinators would swallow as "rule not applicable").
        response_content = response.choices[0].message.content or ""
        num_tokens_received = self.count_tokens(response_content)
        logger.info(f"Received {num_tokens_received} tokens from OpenAI API: {response_content}")
        return response_content

    def count_tokens(self, messages) -> int:
        """Estimate the token count of a plain string or a list of chat messages.

        For a string, this is the length of its encoding. For a list, each
        message's ``content`` is encoded and a fixed overhead of 4 tokens per
        message is added, plus 2 tokens overall for reply priming. This is an
        approximation of the chat framing, not an exact count.

        Raises:
            ValueError: If ``messages`` is neither a ``str`` nor a ``list``.
        """
        if isinstance(messages, str):
            return len(self.encoding.encode(messages))
        if isinstance(messages, list):
            content_tokens = sum(
                len(self.encoding.encode(message["content"])) + 4
                for message in messages
            )
            return content_tokens + 2
        raise ValueError("Input must be either a string or a list of messages.")
class Filter:
    """Base class for state transformers.

    A filter takes the current state string plus a ``Model`` and returns the
    next state. By convention a filter signals "not applicable" by raising
    ``ValueError``; the combinator filters (Alternate, Any, Some, Maybe)
    catch exactly that exception.
    """

    def __call__(self, input_state: str, model: Model) -> str:
        """Transform ``input_state``; subclasses must override this."""
        raise NotImplementedError()
class Rule(Filter):
    """A single natural-language if-then rule evaluated by the model.

    The current state is sent wrapped in ``<state>`` tags, followed by the
    rule text. The reply is parsed as XML and must contain either a
    ``<state>`` element (the new state) or an ``<error>`` element (the rule
    did not apply).
    """

    def __init__(self, if_then_rule: str):
        self.if_then_rule = if_then_rule

    def __call__(self, input_state: str, model: Model) -> str:
        """Ask the model to apply the rule to ``input_state``.

        Returns:
            The stripped text of the reply's ``<state>`` element.

        Raises:
            ValueError: If the reply contains ``<error>`` (its text is the
                message) or contains neither tag.
        """
        # Escape &, < and > so states containing markup-like characters (which
        # `.text` below yields after unescaping a previous reply) cannot break
        # out of the <state> wrapper or make the exchange malformed XML.
        user_prompt = f"<state>{escape(input_state)}</state>\n\n{self.if_then_rule}"
        gpt_response_text = model(user_prompt)

        soup = BeautifulSoup(gpt_response_text, 'xml')
        state_tag = soup.find('state')
        if state_tag is not None:
            return state_tag.text.strip()
        error_tag = soup.find('error')
        if error_tag is not None:
            raise ValueError(error_tag.text.strip())
        raise ValueError(f"Unexpected response format from GPT: {gpt_response_text}")
class Sequence(Filter):
    """Apply every filter in ``body`` in order, threading the state through.

    A ``ValueError`` from any step propagates, aborting the whole sequence.
    An empty body returns the input unchanged.
    """

    def __init__(self, body: list[Filter]):
        self.body = body

    def __call__(self, input_state: str, model: Model) -> str:
        state = input_state
        for step in self.body:
            # Each step consumes the previous step's output.
            state = step(state, model)
        return state
class Alternate(Filter):
    """Try each filter in ``body`` in order; return the first success.

    A branch "fails" by raising ``ValueError``. If every branch fails, a
    ``ValueError`` is raised whose message lists each branch's error. It is
    chained to the last branch's exception.
    """

    def __init__(self, body: list[Filter]):
        self.body = body

    def __call__(self, input_state: str, model: Model) -> str:
        failures = []
        last_error = None
        for filter_item in self.body:
            try:
                return filter_item(input_state, model)
            except ValueError as exc:
                # Record why this branch failed instead of discarding it.
                failures.append(str(exc))
                last_error = exc
        message = "None of the filters in Alternate were successful."
        if failures:
            message += " Errors: " + "; ".join(failures)
        raise ValueError(message) from last_error
class Any(Filter):
    """Apply ``body`` zero or more times, until it stops applying.

    Iteration ends when ``body`` raises ``ValueError``, or when it returns
    the state unchanged. A deterministic body would return that same state
    forever, so continuing would never terminate. Iteration also ends after
    ``max_iterations`` successful applications, if a cap is given. The last
    successfully produced state is returned. If the very first application
    fails, that is the input.
    """

    def __init__(self, body: Filter, max_iterations: Optional[int] = None):
        """
        Args:
            body: Filter to apply repeatedly.
            max_iterations: Optional cap on successful applications; ``None``
                (the default) means no cap.
        """
        self.body = body
        self.max_iterations = max_iterations

    def __call__(self, input_state: str, model: Model) -> str:
        current_state = input_state
        applied = 0
        while self.max_iterations is None or applied < self.max_iterations:
            try:
                next_state = self.body(current_state, model)
            except ValueError:
                break
            applied += 1
            if next_state == current_state:
                # Fixpoint reached: further applications cannot make progress.
                break
            current_state = next_state
        return current_state
class Some(Filter):
    """Apply ``body`` one or more times.

    The first application is mandatory: its ``ValueError`` propagates to the
    caller. Afterwards, iteration continues until the first of these:
    ``body`` raises ``ValueError``, ``body`` returns the state unchanged (a
    fixpoint, which would otherwise loop forever for a deterministic body),
    or ``max_iterations`` total successful applications have been made. The
    last successfully produced state is returned.
    """

    def __init__(self, body: Filter, max_iterations: Optional[int] = None):
        """
        Args:
            body: Filter to apply repeatedly.
            max_iterations: Optional cap on total successful applications
                (including the mandatory first one, which always runs);
                ``None`` (the default) means no cap.
        """
        self.body = body
        self.max_iterations = max_iterations

    def __call__(self, input_state: str, model: Model) -> str:
        # Run once; a ValueError here means "not even one application" and
        # is deliberately left to propagate.
        current_state = self.body(input_state, model)
        applied = 1
        while self.max_iterations is None or applied < self.max_iterations:
            try:
                next_state = self.body(current_state, model)
            except ValueError:
                break
            applied += 1
            if next_state == current_state:
                # Fixpoint reached: further applications cannot make progress.
                break
            current_state = next_state
        return current_state
class Maybe(Filter):
    """Apply ``body`` once if it applies; otherwise pass the state through.

    A ``ValueError`` from ``body`` is treated as "not applicable", and the
    input state is returned untouched.
    """

    def __init__(self, body: Filter):
        self.body = body

    def __call__(self, input_state: str, model: Model) -> str:
        try:
            result = self.body(input_state, model)
        except ValueError:
            # The wrapped filter declined; fall back to the original state.
            result = input_state
        return result
def main() -> None:
    """Demo each filter combinator against the OpenAI API.

    Reads the API key from the ``OPENAI_API_KEY`` environment variable and
    exits with a message if it is missing. Each demo logs either the
    transformed prompt or the ``ValueError`` it raised. A failing demo does
    not stop the remaining ones.
    """
    api_key = os.environ.get('OPENAI_API_KEY')
    if not api_key:
        raise SystemExit("OPENAI_API_KEY environment variable is not set.")

    # System prompt explaining the XML application behavior to the model.
    system_prompt = (
        "The assistant is an XML application that transforms states or signals errors.\n"
        "The user provides an input state wrapped in <state> tags and an if-then rule.\n"
        "The assistant responds with either a new state wrapped in <state> tags or an error wrapped in <error> tags.\n"
        "The assistant transforms the input state into a new state based on the if-then rule provided by the user.\n"
        "If the rule is applicable to the input state, the assistant responds with the new state in <state> tags.\n"
        "If the rule is inapplicable or invalid, the assistant responds with an error message in <error> tags.\n"
        "Always respond with either a <state> tag or an <error> tag.\n"
    )
    model = Model(api_key=api_key, system_prompt=system_prompt)

    # Rules for a still-life painting theme.
    add_fruit_rule = Rule(if_then_rule="If the prompt does NOT mention fruit, then add a description of a fruit, such as a bunch of grapes or a sliced orange.")
    add_floral_element_rule = Rule(if_then_rule="If the prompt does NOT mention flowers, then add a description of a floral element, such as a vase of roses or a single lily.")
    add_reflective_object_rule = Rule(if_then_rule="If the prompt does NOT mention a reflective object, then add a description of a reflective object, such as a silver goblet or a glass vase.")
    add_drapery_rule = Rule(if_then_rule="If the prompt mentions an object on a table, then describe a rich drapery underneath the object.")
    add_another_object_rule = Rule(if_then_rule="If the prompt EXPLICITLY mentions LESS than three objects, then add another object that would fit the scene, such as a ceramic bowl or a book.")

    # (label used in log messages, filter under test, initial state)
    demos = [
        ("Sequence",
         Sequence(body=[add_fruit_rule, add_floral_element_rule, add_reflective_object_rule]),
         "A dimly lit table."),
        # add_drapery_rule is expected to fail (no object on a table), so
        # Alternate should fall through to add_fruit_rule.
        ("Alternate",
         Alternate(body=[add_drapery_rule, add_fruit_rule]),
         "A lone book."),
        ("Any",
         Any(body=add_another_object_rule),
         "An empty room with a wooden table."),
        ("Some",
         Some(body=add_another_object_rule),
         "A ceramic jug."),
        ("Maybe",
         Maybe(body=add_floral_element_rule),
         "A composition with a copper kettle and a ceramic bowl."),
    ]

    for label, filter_under_test, initial_state in demos:
        try:
            transformed_state = filter_under_test(initial_state, model)
        except ValueError as e:
            logger.error(f"Error transforming prompt ({label}): {e}")
        else:
            logger.info(f"Transformed prompt ({label}): {transformed_state}")


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment