Last active
December 2, 2024 19:26
-
-
Save Filimoa/bd883bbbd6475993f3f6d26212750135 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Literal | |
from bank_database import DatabaseConn | |
from openai_sdk import client | |
from pydantic import BaseModel, Field | |
# Template for the final answering call. `{context}` receives the data fetched
# for the classified query and `{user_query}` the customer's raw message.
# Assembled line by line: the empty first and last entries reproduce the
# newline at the very start and very end of the template.
RESPOND_TO_USER_QUERY_SYSTEM_PROMPT = "\n".join(
    [
        "",
        "You are a support agent in our bank, give the best advice to the user based on their query.",
        "If the user is asking for their balance, give them the balance.",
        "If the user is asking to block their card, block their card.",
        "Otherwise, give general advice.",
        "Context:",
        "{context}",
        "User query:",
        "{user_query}",
        "",
    ]
)
# Context handed to the answering prompt for balance queries.
# (A `#` comment rather than a docstring: pydantic copies class docstrings
# into the model's JSON schema.)
class CustomerBalanceContext(BaseModel):
    balance: float = Field(..., description="The balance of the customer")
# Context handed to the answering prompt for card-related queries.
class CustomerCardContext(BaseModel):
    name: str = Field(..., description="The name of the customer")
    # NOTE(review): each card is an untyped dict whose keys come from the
    # `cards` table; that schema is not visible here.
    cards: list[dict] = Field(..., description="The cards of the customer")
    # ...further fields elided ("lots of other stuff")
# Field-less placeholder context used when the query is neither about the
# balance nor about cards.
class UnknownQueryContext(BaseModel):
    ...
# Structured output of the classification call: the kind of query plus a risk
# score that pydantic constrains to the range 0..10 (via ge/le).
# Deliberately no docstring: this model is sent as a `response_model`, and a
# docstring would alter the JSON schema the LLM sees.
class ClassifiedQuery(BaseModel):
    query_type: Literal["balance", "card", "other"] = Field(
        ..., description="The type of query"
    )
    risk: int = Field(..., description="Risk level of query", ge=0, le=10)
async def handle_card_issue(db: DatabaseConn, customer_id: int) -> CustomerCardContext:
    """Look up the customer's name and cards and package them as card context.

    Args:
        db: Database connection used for both lookups.
        customer_id: Matched against ``customers.id`` and ``cards.customer_id``.

    Returns:
        A ``CustomerCardContext`` built from the two query results.
    """
    params = (customer_id,)

    # NOTE(review): `.first()` is called on whatever `db.query(...)` returns
    # *before* the `await`; that only works if `db.query` yields an object with
    # an awaitable `.first()` -- confirm against DatabaseConn.
    # NOTE(review): `.first()` may produce a row rather than a plain str, while
    # CustomerCardContext.name is declared as str -- verify.
    customer_name = await db.query(
        "SELECT name FROM customers WHERE id = %s", params
    ).first()
    customer_cards = await db.query(
        "SELECT * FROM cards WHERE customer_id = %s", params
    )

    # do important stuff here
    return CustomerCardContext(name=customer_name, cards=customer_cards)
async def get_customer_balance(
    db: DatabaseConn, customer_id: int, include_pending: bool
) -> CustomerBalanceContext:
    """Fetch the first row of the customer's balance query.

    Args:
        db: Database connection to query.
        customer_id: Matched against ``customers.id``.
        include_pending: Bound as the second SQL parameter.

    Returns:
        Whatever ``db.query(...).first()`` resolves to.
    """
    # NOTE(review): the annotation promises CustomerBalanceContext, but the
    # raw `.first()` result is returned unconverted -- confirm what DatabaseConn
    # yields and wrap it (e.g. CustomerBalanceContext(balance=...)).
    # NOTE(review): `include_pending` is used as a column filter; it reads like
    # a flag for how to compute the balance instead -- verify the schema.
    sql = "SELECT balance FROM customers WHERE id = %s AND include_pending = %s"
    balance_query = db.query(sql, (customer_id, include_pending))
    return await balance_query.first()
class SupportResult(BaseModel):
    # Structured answer requested from the final LLM call. The fields mirror
    # the example output shown for `main` (support_advice / block_card / risk).
    support_advice: str = Field(description="Advice returned to the customer")
    block_card: bool = Field(description="Whether the customer's card should be blocked")
    risk: int = Field(description="Risk level of query", ge=0, le=10)


async def _build_context(query_type: str, customer_id: int, db: DatabaseConn) -> BaseModel:
    """Fetch the context model the answering prompt needs for ``query_type``."""
    if query_type == "balance":
        return await get_customer_balance(db, customer_id, include_pending=True)
    if query_type == "card":
        return await handle_card_issue(db, customer_id)
    return UnknownQueryContext()


async def support_agent(query: str, customer_id: int, db: DatabaseConn):
    """Answer a customer's support query with two LLM calls.

    1. Classify the query (balance / card / other) into ``ClassifiedQuery``.
    2. Fetch the matching context from ``db``, fill
       ``RESPOND_TO_USER_QUERY_SYSTEM_PROMPT`` with it, and request a
       ``SupportResult``.

    Args:
        query: The customer's raw message.
        customer_id: Customer whose data is looked up.
        db: Database connection passed to the context helpers.

    Returns:
        The result of the final ``client.chat.completions.create`` call (a
        ``SupportResult`` when the client honours ``response_model``).
    """
    # Awaited like the answering call below: this function is async and treats
    # `client` as an async client; an un-awaited call would leave a coroutine
    # here and `.query_type` would fail.
    classified_query = await client.chat.completions.create(
        model="gpt-4o-mini",
        response_model=ClassifiedQuery,
        messages=[{"role": "user", "content": query}],
    )

    context = await _build_context(classified_query.query_type, customer_id, db)

    # The filled-in template is an instruction message. `response_model` must
    # be the model class describing the expected structured output, not text.
    system_prompt = RESPOND_TO_USER_QUERY_SYSTEM_PROMPT.format(
        context=context, user_query=query
    )
    return await client.chat.completions.create(
        model="gpt-4o-mini",
        response_model=SupportResult,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query},
        ],
    )
async def main():
    """Run two sample queries through ``support_agent`` and print the results."""
    customer_id = 123
    db = DatabaseConn()

    custom_query = "What is my balance?"
    result = await support_agent(custom_query, customer_id, db)
    # A response_model-style client returns the model instance itself; its
    # str() is exactly the example output below, so print it directly
    # (the model has no `.data` attribute).
    print(result)
    # Example output:
    # support_advice='Hello John, your current account balance, including pending transactions, is $123.45.' block_card=False risk=1

    custom_query = "I just lost my card!"
    result = await support_agent(custom_query, customer_id, db)
    print(result)
    # Example output:
    # support_advice="I'm sorry to hear that, John. We are temporarily blocking your card to prevent unauthorized transactions." block_card=True risk=8


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment