Created
April 15, 2023 02:13
-
-
Save veryeasily/28c9ee43f770f146d5a95c204f9f794d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import openai | |
import urllib.parse | |
import peewee as pw | |
from flask import Flask, request | |
from .local_env import LocalEnv | |
# Upper bound on stored chat messages; Message.save_message_with_limit prunes
# the oldest rows once the table grows past this.
TOTAL_MESSAGES: int = 10

# Single SQLite database shared by every model in this module.
# NOTE(review): pw.SqliteDatabase takes a database *filename*, so
# LocalEnv.DATABASE_URL is presumably a file path rather than a URL — confirm.
db = pw.SqliteDatabase(
    LocalEnv.DATABASE_URL,
)
class Message(pw.Model):
    """A chat message to eventually be sent to the OpenAI ChatCompletion API.

    Only the most recent ``TOTAL_MESSAGES`` rows are kept; older rows are
    pruned by :meth:`save_message_with_limit`.
    """

    id = pw.PrimaryKeyField()
    # Chat role; this module writes "user" and "assistant".
    role = pw.TextField()
    content = pw.TextField()
    # Filled in by SQLite at insert time. SQLite's CURRENT_TIMESTAMP has only
    # one-second resolution, so several rows can share the same created_at.
    created_at = pw.DateTimeField(constraints=[pw.SQL("DEFAULT CURRENT_TIMESTAMP")])
    # NOTE(review): nothing in this module updates this column after insert.
    updated_at = pw.DateTimeField(constraints=[pw.SQL("DEFAULT CURRENT_TIMESTAMP")])

    @classmethod
    def save_message_with_limit(cls, role: str, content: str):
        """Insert a message, then delete the oldest rows beyond ``TOTAL_MESSAGES``.

        The insert and the pruning run in a single transaction. Age is decided
        by ``created_at`` with ``id`` as a tie-breaker, because rows inserted
        within the same second share a ``created_at`` value.

        Args:
            role: Chat role of the message (e.g. "user" or "assistant").
            content: Message text.

        Returns:
            The newly created ``Message`` row.
        """
        with cls._meta.database.atomic():
            message = cls.create(role=role, content=content)
            excess = cls.select().count() - TOTAL_MESSAGES
            if excess > 0:
                # Remove *all* surplus rows (not just one) so the table always
                # converges back to the limit, even if it was already over it.
                stale_ids = [
                    row.id
                    for row in cls.select(cls.id)
                    .order_by(cls.created_at, cls.id)
                    .limit(excess)
                ]
                cls.delete().where(cls.id.in_(stale_ids)).execute()
        return message

    class Meta:
        database = db
        db_table = "messages"
class Prompt(pw.Model):
    """The system prompt, persisted in the SQLite database.

    The table is expected to contain exactly one row.
    """

    class Meta:
        database = db
        db_table = "prompts"

    id = pw.PrimaryKeyField()
    content = pw.TextField()
    # Both timestamps are defaulted by SQLite when the row is inserted.
    created_at = pw.DateTimeField(
        constraints=[pw.SQL("DEFAULT CURRENT_TIMESTAMP")],
    )
    updated_at = pw.DateTimeField(
        constraints=[pw.SQL("DEFAULT CURRENT_TIMESTAMP")],
    )
class Bot:
    """Registers the prompt and chat HTTP endpoints for one environment.

    Every route and endpoint name is prefixed with ``LocalEnv.ENV_NAME``
    (presumably so bots for several environments can share one Flask app).
    """

    def __init__(self):
        self.env = LocalEnv.ENV_NAME

    @staticmethod
    def _current_prompt():
        """Return the stored system prompt row, or ``None`` if none exists."""
        # Explicit ordering makes the choice deterministic if extra rows exist.
        return Prompt.select().order_by(Prompt.id).first()

    def register_app(self, app: Flask) -> None:
        """Attach this bot's routes to *app*.

        Routes:
            GET /<env>/prompt          -- return the current system prompt text.
            PUT /<env>/prompt          -- set the prompt from JSON ``{"text": ...}``
                                          and delete all stored messages.
            GET /<env>/chat/<message>  -- store *message*, query OpenAI with the
                                          prompt plus history, store and return
                                          the assistant's reply.
        """
        no_prompt = ("No prompt has been set", 404)

        @app.get(f"/{self.env}/prompt", endpoint=f"{self.env}_prompt")
        def prompt():
            current = self._current_prompt()
            if current is None:
                return no_prompt
            return current.content

        # Updates prompt and deletes all messages
        @app.put(f"/{self.env}/prompt", endpoint=f"{self.env}_put_prompt")
        def update_prompt():
            # silent=True: a missing/invalid JSON body yields None, answered with 400.
            data = request.get_json(silent=True)
            if not isinstance(data, dict) or data.get("text") is None:
                return "Request body must be a JSON object with a 'text' field", 400
            text = data["text"]
            # Change the prompt and clear the history together so a failure
            # cannot leave old messages paired with a new prompt.
            with db.atomic():
                if Prompt.select().exists():
                    Prompt.update(content=text).execute()
                else:
                    # An UPDATE on an empty table would silently do nothing.
                    Prompt.create(content=text)
                Message.delete().execute()
            return self._current_prompt().content

        # "path" (unlike "string") also matches messages that contain "/".
        @app.get(f"/{self.env}/chat/<path:message>", endpoint=f"{self.env}_chat")
        def chat(message: str):
            current = self._current_prompt()
            if current is None:
                return no_prompt
            # The URL path has already been percent-decoded once before routing;
            # this decodes a second time. NOTE(review): presumably the client
            # double-encodes messages — confirm before removing.
            message = urllib.parse.unquote(message)
            Message.save_message_with_limit(role="user", content=message)
            # Replay history oldest-first; without order_by the row order is
            # unspecified. id breaks ties between same-second timestamps.
            history = [
                {"role": m.role, "content": m.content}
                for m in Message.select().order_by(Message.created_at, Message.id)
            ]
            messages = [{"role": "system", "content": current.content}, *history]
            completions = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=messages,
                temperature=0.9,
                presence_penalty=0.6,
                max_tokens=88,
            )
            response = completions["choices"][0]["message"]["content"]
            Message.save_message_with_limit(role="assistant", content=response)
            return response


bot = Bot()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment