Last active
September 24, 2024 23:54
-
-
Save jeremedia/7e874bc6283a10ce8b4d2746413d3ce4 to your computer and use it in GitHub Desktop.
Ruby implementation of OpenAI structured outputs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'json' | |
require 'dry-schema' | |
require 'openai' | |
require 'ostruct' | |
module StructuredOutputs
  # DSL for building the JSON Schema that OpenAI "structured outputs" expects in
  # `response_format: { type: "json_schema", json_schema: ... }`.
  #
  # Every property declared through the DSL is appended to the enclosing
  # object's `required` list, and every object is closed with
  # `additionalProperties: false`.
  #
  # Example:
  #   StructuredOutputs::Schema.new('answer') do
  #     string :text, description: 'The answer'
  #     number :confidence
  #   end
  class Schema
    # Limits enforced by #validate_schema (presumably mirroring OpenAI's
    # documented structured-output limits — confirm against current docs).
    MAX_OBJECT_PROPERTIES = 100
    MAX_NESTING_DEPTH = 5

    # @param name [String, nil] name sent to the API; defaults to the downcased,
    #   unqualified class name ("schema" for anonymous classes).
    # @yield optional DSL block, evaluated with the new schema as +self+.
    # @raise [RuntimeError] if the finished schema (including $defs) exceeds
    #   MAX_OBJECT_PROPERTIES or MAX_NESTING_DEPTH.
    def initialize(name = nil, &block)
      @name = name || (self.class.name || 'schema').split('::').last.downcase
      @schema = {
        type: 'object',
        properties: {},
        required: [],
        additionalProperties: false
      }
      @definitions = {}
      instance_eval(&block) if block
      validate_schema
    end

    # Builds the value for the request's `json_schema` field.
    #
    # `strict: true` belongs at this level, next to `name` and `schema`.
    # Inside the schema object it is not a JSON Schema keyword and does not
    # enable strict mode.
    #
    # @return [Hash] with :name, :description, :strict and :schema keys
    def to_hash
      {
        name: @name,
        description: "Schema for the structured response",
        strict: true,
        schema: @schema.merge({ '$defs' => @definitions })
      }
    end

    protected

    # Raw object schema and reusable definitions. Read by the DSL when one
    # Schema is embedded into another (#object, #define).
    attr_reader :schema, :definitions

    private

    # Declares a required string property.
    # @param enum [Array<String>, nil] allowed values
    # @param description [String, nil]
    def string(name, enum: nil, description: nil)
      add_property(name, { type: 'string', enum: enum, description: description }.compact)
    end

    # Declares a required number property.
    def number(name, description: nil)
      add_property(name, described({ type: 'number' }, description))
    end

    # Declares a required boolean property.
    def boolean(name, description: nil)
      add_property(name, described({ type: 'boolean' }, description))
    end

    # Declares a required nested object whose properties come from the block.
    # Definitions made inside the block are hoisted into this schema's $defs,
    # so refs to them still resolve.
    def object(name, description: nil, &block)
      nested = Schema.new(&block)
      @definitions.merge!(nested.definitions)
      add_property(name, described({
        type: 'object',
        properties: nested.schema[:properties],
        required: nested.schema[:required],
        additionalProperties: false
      }, description))
    end

    # Declares a required array property.
    # @param items [Hash] item schema, e.g. `{ type: 'string' }` or `ref(:step)`
    def array(name, items:, description: nil)
      add_property(name, described({ type: 'array', items: items }, description))
    end

    # Declares a required property matching any of +schemas+.
    def any_of(name, schemas, description: nil)
      add_property(name, described({ anyOf: schemas }, description))
    end

    # Registers a reusable object schema under $defs/+name+.
    # Definitions nested inside the block are hoisted alongside it.
    def define(name, &block)
      nested = Schema.new(&block)
      @definitions.merge!(nested.definitions)
      @definitions[name] = nested.schema
    end

    # JSON pointer reference to a schema registered with #define.
    def ref(name)
      { '$ref' => "#/$defs/#{name}" }
    end

    # Adds +description+ to +definition+ when one is given.
    def described(definition, description)
      description ? definition.merge(description: description) : definition
    end

    # Records a property and marks it required. Redeclaring a property
    # replaces its definition without duplicating it in `required`.
    def add_property(name, definition)
      @schema[:properties][name] = definition
      @schema[:required] << name unless @schema[:required].include?(name)
    end

    # Raises if the schema plus its definitions exceed the configured limits.
    def validate_schema
      schemas = [@schema, *@definitions.values]

      properties_count = schemas.sum { |s| count_properties(s) }
      raise 'Exceeded maximum number of object properties' if properties_count > MAX_OBJECT_PROPERTIES

      max_depth = schemas.map { |s| calculate_max_depth(s) }.max
      raise 'Exceeded maximum nesting depth' if max_depth > MAX_NESTING_DEPTH
    end

    # Total number of properties, recursing through nested `properties`.
    def count_properties(schema)
      return 0 unless schema.is_a?(Hash) && schema[:properties]

      props = schema[:properties]
      props.size + props.each_value.sum { |prop| count_properties(prop) }
    end

    # Deepest `properties` nesting level; the root counts as depth 1.
    def calculate_max_depth(schema, current_depth = 1)
      return current_depth unless schema.is_a?(Hash) && schema[:properties]

      child_depths = schema[:properties].each_value.map { |prop| calculate_max_depth(prop, current_depth + 1) }
      # An object without properties has no children. Array#max would return
      # nil, and comparing nil with an Integer raises ArgumentError.
      child_depths.max || current_depth
    end
  end

  # Thin wrapper around the ruby-openai chat endpoint that requests a
  # json_schema response_format and decodes the reply.
  class OpenAIClient
    # Configures the ruby-openai gem globally (OpenAI.configure) and builds a client.
    # @raise [KeyError] if OPENAI_ACCESS_TOKEN is not set in the environment.
    def initialize
      OpenAI.configure do |config|
        config.access_token = ENV.fetch("OPENAI_ACCESS_TOKEN")
        config.log_errors = true
      end
      @client = OpenAI::Client.new
    end

    # Sends a chat completion request constrained by +response_format+.
    #
    # @param model [String] model id
    # @param messages [Array<Hash>] chat messages ({ role:, content: })
    # @param response_format [#to_hash] e.g. a Schema; its #to_hash becomes the
    #   request's `json_schema` field
    # @return [OpenStruct] +refusal+ (refusal text or nil) and +parsed+
    #   (decoded JSON content, or nil when the model refused)
    # @raise [RuntimeError] if the response contains no first choice message
    # @raise [JSON::ParserError] if the content is not valid JSON
    def parse(model:, messages:, response_format:)
      response = @client.chat(
        parameters: {
          model: model,
          messages: messages,
          response_format: {
            type: "json_schema",
            json_schema: response_format.to_hash
          }
        }
      )

      message = response.dig('choices', 0, 'message')
      raise "No message in OpenAI response: #{response.inspect}" unless message

      # Check for a refusal before parsing. A refusal may come back with null
      # content, and JSON.parse(nil) raises TypeError.
      refusal = message['refusal']
      return OpenStruct.new(refusal: refusal, parsed: nil) if refusal

      OpenStruct.new(refusal: nil, parsed: JSON.parse(message['content']))
    end
  end
end
# Example schema for step-by-step math answers.
#
# The generated object has two required properties:
# - steps:        array whose items reference $defs/step, an object with
#                 string properties `explanation` and `output`
# - final_answer: string
# The schema name falls back to the downcased class name.
class MathReasoning < StructuredOutputs::Schema
  def initialize
    super(nil) do
      # Reusable component, referenced by every entry of `steps`.
      define(:step) do
        string(:explanation)
        string(:output)
      end

      array(:steps, items: ref(:step))
      string(:final_answer)
    end
  end
end
# Example run: ask the model to solve a linear equation and print the
# structured result. StructuredOutputs::OpenAIClient reads the API key from
# ENV["OPENAI_ACCESS_TOKEN"].
begin
  client = StructuredOutputs::OpenAIClient.new
  schema = MathReasoning.new

  result = client.parse(
    model: "gpt-4o-2024-08-06",
    messages: [
      { role: "system", content: "You are a helpful math tutor. Guide the user through the solution step by step." },
      { role: "user", content: "how can I solve 8x + 7 = -23" }
    ],
    response_format: schema
  )

  if result.refusal
    puts "The model refused to respond: #{result.refusal}"
  else
    puts JSON.pretty_generate(result.parsed)
  end
rescue StandardError => e
  # Report on stderr and exit non-zero so shells/CI can tell the run failed,
  # instead of printing to stdout and exiting 0.
  warn "Error: #{e}"
  exit 1
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
@danielfriis thanks! Updated.