Last active
September 21, 2024 10:27
-
-
Save mark-mishyn/5381dab38eb2a4bf2cca9f6e86c9458f to your computer and use it in GitHub Desktop.
Simple example of generating a JSON Schema for an OpenAI function_call with Python and Pydantic
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import openai | |
from pydantic import BaseModel, PrivateAttr | |
from src.common.config import Config | |
# Module-level API key read by the legacy `openai.ChatCompletion` API used in
# this script; it is set once, at import time, from the project's Config.
openai.api_key = Config.OPEN_AI_API_KEY
class FunctionDefinition(BaseModel):
    # Base for Pydantic models that double as OpenAI function-call specs: the
    # subclass name becomes the function name, its own docstring the
    # description, and its public fields the JSON Schema "parameters".
    # Deliberately documented with comments rather than a class docstring so
    # that no text from this base can end up in a subclass's schema.

    @classmethod
    def to_function_definition(cls) -> dict:
        """Return this model as one entry of the OpenAI ``functions`` list.

        Returns:
            A dict with ``name`` (the class name), ``description`` (the
            class's own docstring with indentation cleaned, or ``""`` when the
            class has none) and ``parameters`` (the model's JSON Schema).
        """
        import inspect

        # A class without its own docstring has ``__doc__`` set to None; the
        # API expects a string, so send "" instead of null.
        description = inspect.cleandoc(cls.__doc__) if cls.__doc__ else ""
        # Pydantic v2 renamed ``schema()`` to ``model_json_schema()`` and
        # deprecated the old name; prefer the new one when it exists so this
        # works warning-free on both major versions.
        build_schema = getattr(cls, "model_json_schema", None) or cls.schema
        return {
            "name": cls.__name__,
            "description": description,
            "parameters": build_schema(),
        }
class CalculateSalaryFirstDepartment(FunctionDefinition):
    """Makes calculations of salary for first department"""

    # Public fields: these form the JSON Schema "parameters" that the chat
    # model fills in; their order is the order of the schema properties.
    hourly_rate: int
    regular_hours: int
    overtime_hours: int
    bonus: int = 0

    # Private attribute: not part of the generated schema. Department
    # subclasses override it to change the overtime rate.
    _overtime_rate_multiplier: float = PrivateAttr(default=1.5)

    def process(self) -> int:
        """Return regular pay + overtime pay + bonus as an ``int``.

        The sum is passed through ``int()``, which truncates any fractional
        part (for example, a total of 22.5 becomes 22).
        """
        rate = self.hourly_rate
        regular_pay = rate * self.regular_hours
        overtime_pay = self._overtime_rate_multiplier * rate * self.overtime_hours
        return int(regular_pay + overtime_pay + self.bonus)

    def calculation_description(self) -> str:
        """Return a two-line breakdown of what ``process()`` computes.

        Line one is the symbolic formula; line two is the same formula with
        this instance's values substituted in.
        """
        formula = (
            "Calculation: (hourly_rate * regular_hours) + "
            "(overtime_rate_multiplier * hourly_rate * overtime_hours) + bonus"
        )
        substituted = (
            f"({self.hourly_rate} * {self.regular_hours}) + "
            f"({self._overtime_rate_multiplier} * {self.hourly_rate} * {self.overtime_hours}) + {self.bonus}"
        )
        return "\n".join((formula, substituted))
class CalculateSalarySecondDepartment(CalculateSalaryFirstDepartment):
    """Makes calculations of salary for second department"""

    # Same fields, formula and description as the first department; only the
    # overtime multiplier differs. Written as 2.0 rather than 2 so the value
    # matches its `float` annotation: Pydantic does not validate or coerce
    # private attributes, so an int literal would stay an int.
    _overtime_rate_multiplier: float = PrivateAttr(default=2.0)
def calculate_salary(user_input: str) -> "int | str":
    """Return the daily salary computed from a free-text work report.

    The report is sent to the chat model together with the department
    calculator definitions. The model picks a calculator and extracts its
    arguments, which are then validated by that calculator's Pydantic model.

    Args:
        user_input: Free-text description of the working day (e.g. the
            department, hours worked, hourly rate and bonus).

    Returns:
        The salary as an ``int`` when the model issued a function call, or
        the model's plain-text reply (e.g. a clarifying question) when it
        did not.

    Raises:
        KeyError: if the model names a function that is not a registered
            calculator.
        json.JSONDecodeError: if the model's arguments are not valid JSON.
    """
    # Only these classes may be instantiated from model output. Resolving the
    # requested name through this explicit registry (instead of `globals()`)
    # keeps untrusted model output from selecting arbitrary module callables.
    calculators = {
        calculator.__name__: calculator
        for calculator in (
            CalculateSalaryFirstDepartment,
            CalculateSalarySecondDepartment,
        )
    }

    system_msg = (
        "You are a helper to calculate daily employee salary. "
        "Parse user input and extract arguments for the function call. "
        "IMPORTANT: If user worked more than 8 hours per day, extra time is considered as "
        "overtime_hours. For example: if user reported 10 hours, "
        "it means 8 hours of regular_hours and 2 hours is overtime_hours. "
    )
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo-0613",
        messages=[
            {"role": "system", "content": system_msg},
            {"role": "user", "content": user_input},
        ],
        # Advertise exactly the registered calculators, so the offered
        # functions and the accepted names cannot drift apart.
        functions=[
            calculator.to_function_definition()
            for calculator in calculators.values()
        ],
    )
    message = response["choices"][0]["message"]

    # The model can reply in text (e.g. to ask for clarification) instead of
    # calling a function; an interactive caller would need to loop on this.
    function_call = message.get("function_call")
    if not function_call:
        return message["content"]

    # Unknown names raise KeyError here instead of reaching other globals.
    calculator_class = calculators[function_call["name"]]
    function_arguments = json.loads(function_call["arguments"])
    # Instantiating the Pydantic model validates the model-supplied arguments.
    calculator = calculator_class(**function_arguments)
    print(calculator.calculation_description())
    return calculator.process()
>>> calculate_salary("First department. Today I worked 10 hours. My rate - $20, bonus $100")
Calculation: (hourly_rate * regular_hours) + (overtime_rate_multiplier * hourly_rate * overtime_hours) + bonus
(20 * 8) + (1.5 * 20 * 2) + 100
320
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment