Last active
January 8, 2025 01:42
-
-
Save elct9620/15ca1f0311bc736301c3396193c36861 to your computer and use it in GitHub Desktop.
A minimal AI agent inspired by ihower
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# frozen_string_literal: true | |
require 'bundler/inline' | |
gemfile do | |
source 'https://rubygems.org' | |
gem 'dotenv' | |
gem 'ruby-openai' | |
end | |
Dotenv.load | |
module Agent | |
# State is the interface for one step of a thread's state machine. Each state
# performs a single action on the thread and returns the next state; returning
# nil ends Thread#run's loop.
#
# RBS:
#   interface _State
#     def call(thread: Thread) -> _State?
#   end
module State
  # Initial state: moves straight on to reading user input.
  class Start
    def call(*) = Input.new
  end

  # Terminal state: returning nil stops the conversation.
  class Stop
    def call(*) = nil
  end

  # Reads one line of user input and records it as a `user` message.
  class Input
    # @return [Stop] when the user types `exit` or the input is exhausted (nil)
    # @return [Input] on a blank line, re-prompting without calling the model
    # @return [Assistant] after the user message has been recorded
    def call(thread:)
      input = thread.gets
      # nil means there is no more input (e.g. Ctrl-D); stop rather than
      # recording a nil message and sending it to the model.
      return Stop.new if input.nil? || input.strip == 'exit'
      # A blank line would only send an empty user message to the model.
      return self if input.strip.empty?

      thread.add(role: 'user', content: input)
      Assistant.new
    end
  end

  # Sends the conversation so far to the agent and records its reply.
  class Assistant
    # @return [Tool] when the reply requests tool calls
    # @return [Input] otherwise, to wait for the next user message
    def call(thread:)
      res = thread.agent.call(messages: thread.memory.messages)
      thread.add(res)
      thread.puts(res['content']) if res['content']
      return Tool.new if res['tool_calls']&.any?

      Input.new
    end
  end

  # Executes every tool call requested by the last assistant message, records
  # one `tool` message per call, then hands control back to the agent.
  class Tool
    def call(thread:)
      thread.memory.last['tool_calls'].each do |tool_call|
        result = execute(thread, tool_call)
        thread.add(role: 'tool', content: result.to_json, tool_call_id: tool_call['id'])
      end
      Assistant.new
    end

    private

    # Runs a single tool call and returns its result hash. Bad model output
    # (unknown tool, malformed JSON, missing or unexpected keywords) becomes a
    # failed result instead of an exception. That way every tool_call_id still
    # gets a reply and the model can retry. Note: an ArgumentError/TypeError
    # raised inside a tool is reported the same way.
    def execute(thread, tool_call)
      name = tool_call.dig('function', 'name')
      return { success: false, content: "Tool not found: #{name}" } unless thread.agent.tool?(name)

      params = parse_arguments(tool_call.dig('function', 'arguments'))
      thread.agent.tool(name).call(**params)
    rescue JSON::ParserError, ArgumentError, TypeError => e
      { success: false, content: "Invalid arguments for #{name}: #{e.message}" }
    end

    # Parses the JSON-encoded arguments into a symbol-keyed Hash.
    # An absent or blank argument string means "no arguments".
    # @raise [JSON::ParserError] on malformed JSON
    # @raise [TypeError] when the JSON is not an object
    def parse_arguments(raw)
      return {} if raw.to_s.strip.empty?

      parsed = JSON.parse(raw, symbolize_names: true)
      raise TypeError, "expected a JSON object, got #{parsed.class}" unless parsed.is_a?(Hash)

      parsed
    end
  end
end
# Tool provides an additional action the agent can invoke through function
# calling. Subclasses pass their metadata to #initialize and override #call.
#
# RBS:
#   interface _Tool
#     def name() -> (String | Symbol)
#     def description() -> String
#     def parameters() -> Hash[untyped, untyped]
#     def call(**untyped) -> untyped
#   end
class Tool
  attr_reader :name, :description, :parameters

  # @param name [String, Symbol] function name exposed to the model
  # @param description [String] what the tool does, shown to the model
  # @param parameters [Hash] JSON Schema describing the keyword arguments
  def initialize(name:, description:, parameters: {})
    @name = name
    @description = description
    @parameters = parameters
  end

  # Performs the action; keywords are the arguments parsed from the model.
  # @raise [NotImplementedError] unless a subclass overrides it
  def call(**)
    raise NotImplementedError, "#{self.class.name} must implement #call"
  end
end
# Agent is an LLM adapter that provides a chat interface.
#
# RBS:
#   interface _Agent
#     def call(messages: Array[Hash[untyped, untyped]]) -> Hash[String, untyped]?
#     def tool?(String | Symbol) -> bool
#     def tool(String | Symbol) -> _Tool?
#   end
#
# This implementation uses the OpenAI Chat Completions API through a client
# that responds to `chat(parameters:)` (ruby-openai's OpenAI::Client).
class OpenAI
  attr_reader :client, :model

  def initialize(client:, tools: [], model: 'gpt-4o-mini')
    @client = client
    @model = model
    @_tools = tools
  end

  # Sends the conversation and returns the first choice's message hash (string
  # keys such as 'role', 'content', 'tool_calls'), or nil when there is none.
  def call(messages:)
    parameters = { model: model, messages: messages }
    # The API rejects an empty `tools` array, so only send it when tools exist.
    parameters[:tools] = tools unless tools.empty?
    client.chat(parameters: parameters).dig('choices', 0, 'message')
  end

  # Tool definitions in the Chat Completions `tools` format (memoized).
  def tools
    @tools ||= @_tools.map do |tool|
      {
        type: :function,
        function: {
          name: tool.name,
          description: tool.description,
          parameters: tool.parameters
        }
      }
    end
  end

  # Whether a tool with this name is registered.
  def tool?(name)
    !tool(name).nil?
  end

  # Finds a registered tool by name, or nil. String and Symbol names compare
  # equal, so a tool named :add_to_cart matches the model's "add_to_cart".
  def tool(name)
    @_tools.find { |tool| tool.name.to_s == name.to_s }
  end
end
# Memory stores the conversation history in insertion order.
#
# RBS:
#   interface _Memory
#     def <<(Hash[untyped, untyped]) -> self
#     def messages -> Array[Hash[untyped, untyped]]
#     def last -> Hash[untyped, untyped]?
#   end
class Memory
  def initialize
    @messages = []
  end

  # Appends a message. Returns self (like Array#<<) instead of the internal
  # array, so callers cannot mutate history behind #messages' frozen copy.
  def <<(message)
    @messages << message
    self
  end

  # A frozen shallow copy of the history; the message hashes themselves are shared.
  def messages
    @messages.dup.freeze
  end

  # The most recent message, or nil when the history is empty.
  def last
    @messages.last
  end
end
# Thread is a chat context driven as a state machine.
#
# RBS:
#   interface _Thread
#     def agent -> _Agent
#     def memory -> _Memory
#     def gets -> String
#     def puts(String) -> void
#     def add(Hash) -> void
#   end
class Thread
  attr_reader :agent, :state, :input, :output, :memory, :verbose

  def initialize(agent:, input: $stdin, output: $stdout, memory: Memory.new, verbose: false)
    @agent = agent
    @input = input
    @output = output
    @state = State::Start.new
    @memory = memory
    @verbose = verbose
  end

  # Records the system prompt, then steps through states until one returns nil.
  def run(prompt)
    add(role: 'system', content: prompt)
    @state = state.call(thread: self) while state
  end

  # Prompts on `output` and reads one line from `input`, without its newline.
  # At end of input (IO#gets returns nil) this returns 'exit', so the Input
  # state ends the conversation instead of raising on nil.chomp.
  def gets
    output.print '> '
    # The prompt has no newline; flush so it is visible before blocking on input.
    output.flush if output.respond_to?(:flush)
    line = input.gets
    line.nil? ? 'exit' : line.chomp
  end

  # Writes the content to `output`.
  def puts(content)
    output.puts(content)
  end

  # Adds a message to memory; in verbose mode it is also echoed as JSON.
  def add(message)
    puts JSON.pretty_generate(message) if verbose
    memory << message
  end
end
end | |
# Demo tool: does not store anything, it only returns a confirmation message.
class AddToCartTool < Agent::Tool
  def initialize # rubocop:disable Metrics/MethodLength
    super(
      name: :add_to_cart,
      description: 'Add the product to the cart',
      parameters: {
        type: 'object',
        properties: {
          name: { type: 'string', description: 'The name of the product' },
          quantity: { type: 'integer', description: 'The quantity of the product' }
        },
        # Both keywords are mandatory in #call. Declaring them tells the model
        # not to omit them, which would otherwise raise "missing keyword".
        required: %w[name quantity]
      }
    )
  end

  # Extra keywords the model may invent are ignored rather than raising ArgumentError.
  def call(name:, quantity:, **)
    { success: true, content: "Added #{quantity} #{name} to the cart" }
  end
end
# Entry point: a single OpenAI-backed agent with the add-to-cart tool.
# Fail fast with a clear message instead of sending requests with a nil token.
api_key = ENV.fetch('OPENAI_API_KEY') { abort 'OPENAI_API_KEY is not set; export it or put it in .env' }

agent = Agent::OpenAI.new(
  client: OpenAI::Client.new(access_token: api_key),
  tools: [AddToCartTool.new]
)

thread = Agent::Thread.new(
  agent: agent,
  verbose: ARGV.include?('--verbose')
)
# Consume the CLI flags. The thread reads $stdin directly; clearing ARGV only
# matters to code that goes through ARGF (e.g. Kernel#gets), which would treat
# leftover arguments as input files.
ARGV.clear

# System prompt (Traditional Chinese). It describes a bookstore shopping
# assistant that acts on the book title and quantity the user gives, asks for
# a missing product name or quantity, tries a failing tool at most three times
# before returning an error, and replies with function calls only.
thread.run(
  <<~PROMPT
    你是一個書店購物AI助手。你的任務是:
    1. 根據用戶提供的書名和數量
    2. 當資訊不完整時,主動詢問缺失的商品名稱和數量
    3. 當無法使用工具時,最多嘗試三次,並回傳錯誤訊息
    不要任何額外的解釋或判斷,僅需回傳函式呼叫。
  PROMPT
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# frozen_string_literal: true

require 'bundler/inline'

gemfile do
  source 'https://rubygems.org'
  gem 'dotenv'
  gem 'ruby-openai'
end

# Required explicitly rather than relying on ruby-openai's transitive loads:
# Agent::Thread extends Forwardable and the agent code uses JSON. Loaded after
# the gemfile block so the resolved gem versions are activated first.
require 'forwardable'
require 'json'

Dotenv.load
module Agent | |
# State is the interface for one step of a thread's state machine. Each state
# performs a single action on the thread and returns the next state; returning
# nil ends Thread#run's loop.
#
# RBS:
#   interface _State
#     def call(thread: Thread) -> _State?
#   end
module State
  # Initial state: moves straight on to reading user input.
  class Start
    def call(*) = Input.new
  end

  # Terminal state: returning nil stops the conversation.
  class Stop
    def call(*) = nil
  end

  # Reads one line of user input and records it as a `user` message.
  class Input
    # @return [Stop] when the user types `exit` or the input is exhausted (nil)
    # @return [Input] on a blank line, re-prompting without calling the model
    # @return [Assistant] after the user message has been recorded
    def call(thread:)
      input = thread.gets
      # nil means there is no more input (e.g. Ctrl-D); stop rather than
      # recording a nil message and sending it to the model.
      return Stop.new if input.nil? || input.strip == 'exit'
      # A blank line would only send an empty user message to the model.
      return self if input.strip.empty?

      thread.add(role: 'user', content: input)
      Assistant.new
    end
  end

  # Sends the conversation so far to the current agent and records its reply.
  class Assistant
    # @return [Tool] when the reply requests tool calls
    # @return [Input] otherwise, to wait for the next user message
    def call(thread:)
      res = thread.agent.call(messages: thread.memory.messages)
      thread.add(res)
      thread.puts(res['content']) if res['content']
      return Tool.new if res['tool_calls']&.any?

      Input.new
    end
  end

  # Executes every tool call requested by the last assistant message, records
  # one `tool` message per call, then hands control back to the agent.
  class Tool
    def call(thread:)
      thread.memory.last['tool_calls'].each do |tool_call|
        result = execute(thread, tool_call)
        thread.add(role: 'tool', content: result.to_json, tool_call_id: tool_call['id'])
      end
      Assistant.new
    end

    private

    # Runs a single tool call and returns its result hash. Tools receive the
    # thread as `thread:` in addition to the model's arguments. Bad model output
    # (unknown tool, malformed JSON, missing or unexpected keywords) becomes a
    # failed result instead of an exception. That way every tool_call_id still
    # gets a reply and the model can retry. Note: an ArgumentError/TypeError
    # raised inside a tool is reported the same way.
    def execute(thread, tool_call)
      name = tool_call.dig('function', 'name')
      return { success: false, content: "Tool not found: #{name}" } unless thread.agent.tool?(name)

      # `thread:` is supplied by the runtime; model arguments must not override it.
      params = parse_arguments(tool_call.dig('function', 'arguments')).except(:thread)
      thread.agent.tool(name).call(thread:, **params)
    rescue JSON::ParserError, ArgumentError, TypeError => e
      { success: false, content: "Invalid arguments for #{name}: #{e.message}" }
    end

    # Parses the JSON-encoded arguments into a symbol-keyed Hash.
    # An absent or blank argument string means "no arguments".
    # @raise [JSON::ParserError] on malformed JSON
    # @raise [TypeError] when the JSON is not an object
    def parse_arguments(raw)
      return {} if raw.to_s.strip.empty?

      parsed = JSON.parse(raw, symbolize_names: true)
      raise TypeError, "expected a JSON object, got #{parsed.class}" unless parsed.is_a?(Hash)

      parsed
    end
  end
end
# Tool provides an additional action the agent can invoke through function
# calling. Subclasses pass their metadata to #initialize and override #call.
# When invoked by the Tool state, #call also receives the thread as `thread:`.
#
# RBS:
#   interface _Tool
#     def name() -> (String | Symbol)
#     def description() -> String
#     def parameters() -> Hash[untyped, untyped]
#     def call(thread: Thread, **untyped) -> untyped
#   end
class Tool
  attr_reader :name, :description, :parameters

  # @param name [String, Symbol] function name exposed to the model
  # @param description [String] what the tool does, shown to the model
  # @param parameters [Hash] JSON Schema describing the model-supplied keywords
  def initialize(name:, description:, parameters: {})
    @name = name
    @description = description
    @parameters = parameters
  end

  # Performs the action; keywords are the arguments parsed from the model.
  # @raise [NotImplementedError] unless a subclass overrides it
  def call(**)
    raise NotImplementedError, "#{self.class.name} must implement #call"
  end
end
# Agent is an LLM adapter that provides a chat interface.
#
# RBS:
#   interface _Agent
#     def call(messages: Array[Hash[untyped, untyped]]) -> Hash[String, untyped]?
#     def tool?(String | Symbol) -> bool
#     def tool(String | Symbol) -> _Tool?
#   end
#
# This implementation uses the OpenAI Chat Completions API through a client
# that responds to `chat(parameters:)` (ruby-openai's OpenAI::Client).
class OpenAI
  attr_reader :client, :model

  def initialize(client:, tools: [], model: 'gpt-4o-mini')
    @client = client
    @model = model
    @_tools = tools
  end

  # Sends the conversation and returns the first choice's message hash (string
  # keys such as 'role', 'content', 'tool_calls'), or nil when there is none.
  def call(messages:)
    parameters = { model: model, messages: messages }
    # The API rejects an empty `tools` array, so only send it when tools exist.
    parameters[:tools] = tools unless tools.empty?
    client.chat(parameters: parameters).dig('choices', 0, 'message')
  end

  # Tool definitions in the Chat Completions `tools` format (memoized).
  def tools
    @tools ||= @_tools.map do |tool|
      {
        type: :function,
        function: {
          name: tool.name,
          description: tool.description,
          parameters: tool.parameters
        }
      }
    end
  end

  # Whether a tool with this name is registered.
  def tool?(name)
    !tool(name).nil?
  end

  # Finds a registered tool by name, or nil. String and Symbol names compare
  # equal, so a tool named :transfer matches the model's "transfer".
  def tool(name)
    @_tools.find { |tool| tool.name.to_s == name.to_s }
  end
end
# Memory stores the conversation history in insertion order.
#
# RBS:
#   interface _Memory
#     def <<(Hash[untyped, untyped]) -> self
#     def messages -> Array[Hash[untyped, untyped]]
#     def last -> Hash[untyped, untyped]?
#   end
class Memory
  def initialize
    @messages = []
  end

  # Appends a message. Returns self (like Array#<<) instead of the internal
  # array, so callers cannot mutate history behind #messages' frozen copy.
  def <<(message)
    @messages << message
    self
  end

  # A frozen shallow copy of the history; the message hashes themselves are shared.
  def messages
    @messages.dup.freeze
  end

  # The most recent message, or nil when the history is empty.
  def last
    @messages.last
  end
end
# UI handles interaction with the user over a pair of IO-like objects.
#
# RBS:
#   interface _UI
#     def gets -> String
#     def puts(String) -> void
#   end
class UI
  def initialize(input: $stdin, output: $stdout)
    @input = input
    @output = output
  end

  # Prompts on the output stream and returns the next input line without its
  # newline. At end of input (IO#gets returns nil) this returns 'exit', so the
  # conversation ends cleanly instead of raising on nil.chomp.
  def gets
    @output.print '> '
    # The prompt has no newline; flush so it is visible before blocking on input.
    @output.flush if @output.respond_to?(:flush)
    line = @input.gets
    line.nil? ? 'exit' : line.chomp
  end

  # Writes the content to the output stream.
  def puts(content)
    @output.puts(content)
  end
end
# Thread is a chat context run as a state machine. It holds the active agent,
# the agents it can be handed over to, the conversation memory, and the user
# interface that user I/O is forwarded to.
#
# RBS:
#   interface _Thread
#     def agent -> _Agent
#     def memory -> _Memory
#     def gets -> String
#     def puts(String) -> void
#     def add(Hash) -> void
#     def agent?(name: String) -> bool
#     def transfer(agent: String) -> void
#   end
class Thread
  extend Forwardable

  def_delegators :user_interface, :gets, :puts

  attr_reader :agent, :agents, :state, :user_interface, :memory, :verbose

  def initialize(agent:, agents: {}, user_interface: UI.new, memory: Memory.new, verbose: false)
    @agent = agent
    @agents = agents
    @memory = memory
    @user_interface = user_interface
    @verbose = verbose
    @state = State::Start.new
  end

  # Records the system prompt, then advances states until one yields nil.
  def run(prompt)
    add(role: 'system', content: prompt)
    @state = @state.call(thread: self) until @state.nil?
  end

  # Stores a message; in verbose mode it is also echoed as pretty JSON.
  def add(message)
    puts(JSON.pretty_generate(message)) if verbose
    memory << message
  end

  # True when an agent is registered under the given name.
  def agent?(name:)
    agents.include?(name.to_sym)
  end

  # Makes the named agent the active one. Unknown names leave the current
  # agent in place and return nil; otherwise the new agent is returned.
  def transfer(agent:)
    target = agents[agent.to_sym]
    @agent = target if target
  end
end
end | |
# Lets the model hand the conversation over to another registered agent.
class TransferTool < Agent::Tool
  # @param description [String] shown to the model; should list the agent names it may choose
  def initialize(description:)
    super(
      name: :transfer,
      description: description,
      parameters: {
        type: 'object',
        properties: {
          agent: { type: 'string', description: 'The agent name' }
        },
        # `agent:` is a required keyword of #call. Without `required` the model
        # may omit it, and the call would raise "missing keyword: :agent".
        required: %w[agent]
      }
    )
  end

  # @param agent [String] target agent name chosen by the model
  # @param thread [Agent::Thread] injected by the Tool state, not by the model
  # @return [Hash] success flag and a message for the model
  def call(agent:, thread:, **)
    return { success: false, content: "Agent not found: #{agent}" } unless thread.agent?(name: agent)

    thread.transfer(agent: agent)
    { success: true, content: "Agent transferred to: #{agent}" }
  end
end
# Demo tool: does not store anything, it only returns a confirmation message.
class AddToCartTool < Agent::Tool
  def initialize # rubocop:disable Metrics/MethodLength
    super(
      name: :add_to_cart,
      description: 'Add the product to the cart',
      parameters: {
        type: 'object',
        properties: {
          name: { type: 'string', description: 'The name of the product' },
          quantity: { type: 'integer', description: 'The quantity of the product' }
        },
        # Both keywords are mandatory in #call. Declaring them tells the model
        # not to omit them, which would otherwise raise "missing keyword".
        required: %w[name quantity]
      }
    )
  end

  # Extra keywords (such as the injected `thread:`) are accepted and ignored.
  def call(name:, quantity:, **)
    { success: true, content: "Added #{quantity} #{name} to the cart" }
  end
end
# Demo checkout tool: it takes no arguments and only returns a confirmation.
class CheckoutTool < Agent::Tool
  def initialize
    super(name: :checkout, description: 'Checkout the cart')
  end

  # Any keywords (such as the injected `thread:`) are ignored.
  def call(**) = { success: true, content: 'Checkout the cart' }
end
# Entry point: two OpenAI-backed agents (shopping and checkout) share one API
# client, and each can hand the conversation to the other via TransferTool.
# Fail fast with a clear message instead of sending requests with a nil token.
api_key = ENV.fetch('OPENAI_API_KEY') { abort 'OPENAI_API_KEY is not set; export it or put it in .env' }
openai = OpenAI::Client.new(access_token: api_key)

shopping = Agent::OpenAI.new(
  client: openai,
  tools: [
    AddToCartTool.new,
    TransferTool.new(description: 'Can transfer to another agent, available agents: checkout')
  ]
)

checkout = Agent::OpenAI.new(
  client: openai,
  tools: [
    CheckoutTool.new,
    TransferTool.new(description: 'Can transfer to another agent, available agents: shopping')
  ]
)

# The conversation starts with the shopping agent; the keys are the names
# TransferTool accepts.
thread = Agent::Thread.new(
  agent: shopping,
  agents: {
    shopping: shopping,
    checkout: checkout
  },
  verbose: ARGV.include?('--verbose')
)
# Consume the CLI flags. The UI reads $stdin directly; clearing ARGV only
# matters to code that goes through ARGF (e.g. Kernel#gets), which would treat
# leftover arguments as input files.
ARGV.clear

# System prompt (Traditional Chinese). It describes a bookstore shopping
# assistant that acts on the book title and quantity the user gives, asks for
# a missing product name or quantity, tries a failing tool at most three times
# before returning an error, and replies with function calls only.
thread.run(
  <<~PROMPT
    你是一個書店購物AI助手。你的任務是:
    1. 根據用戶提供的書名和數量
    2. 當資訊不完整時,主動詢問缺失的商品名稱和數量
    3. 當無法使用工具時,最多嘗試三次,並回傳錯誤訊息
    不要任何額外的解釋或判斷,僅需回傳函式呼叫。
  PROMPT
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment