Last active
December 15, 2015 12:58
-
-
Save mitmul/5263704 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# encoding: utf-8 | |
require "narray" | |
class DownhillSimplex | |
def initialize(objective_function, simplex_num, min, max, dimension) | |
@obj_func = objective_function | |
@alpha = 1.0 | |
@beta = 0.5 | |
@gamma = 2.0 | |
@dim = dimension | |
@simplex = simplex_num.times.inject([]){|a, i| a << NVector.float(@dim).collect{|v| v = rand(min..max)}} | |
all_eval | |
calc_params | |
end | |
attr_accessor :x_l, :x_l_eval, :x_h, :x_h_eval, :x_path | |
def converge(epsilon) | |
# 基準値がεより小さくなるまで繰り返す | |
while criterion > epsilon | |
optimize | |
end | |
end | |
def criterion | |
2 * (@x_h_eval - @x_l_eval).abs / (@x_h_eval.abs + @x_l_eval.abs) | |
end | |
def optimize | |
# まずは反射位置を計算 | |
@x_r = reflect | |
x_r_eval = @obj_func.call(@x_r) | |
# 反射位置が最良だったら | |
if x_r_eval < @x_l_eval | |
# 拡大する(もっと進んでみる) | |
x_e = expand | |
x_e_eval = @obj_func.call(x_e) | |
# 拡大位置が最良だったら | |
if x_e_eval < x_r_eval | |
# 最悪点を拡大位置で置き換える | |
@simplex[@evaluation.sort_index[-1]] = x_e | |
# 拡大位置が最良でなかったら | |
else | |
# 最悪点を反射位置を置き換える | |
@simplex[@evaluation.sort_index[-1]] = @x_r | |
end | |
# 最悪点と2nd最悪点の間に反射点があるなら | |
elsif @x_s_eval < x_r_eval && x_r_eval < @x_h_eval | |
# 最悪点を反射点で置き換える | |
@simplex[@evaluation.sort_index[-1]] = @x_r | |
# 縮小操作 | |
x_c = contract | |
x_c_eval = @obj_func.call(x_c) | |
# 縮小点が最悪点より良ければ | |
if x_c_eval < @x_h_eval | |
# 最悪点を縮小点で置き換える | |
@simplex[@evaluation.sort_index[-1]] = x_c | |
# 縮小点が最悪点より悪ければ | |
else | |
# 収縮操作 | |
shrink | |
end | |
# いずれでもない場合 | |
else | |
# 最悪点を反射点で置き換える | |
@simplex[@evaluation.sort_index[-1]] = @x_r | |
end | |
all_eval | |
calc_params | |
# puts "x:(#{@x_l[0]},#{@x_l[1]})\teval:#{@x_l_eval}" | |
end | |
def reflect | |
(1 + @alpha) * @x_g - @alpha * @x_h | |
end | |
def expand | |
@gamma * @x_r + (1 - @gamma) * @x_g | |
end | |
def contract | |
@beta * @x_h + (1 - @beta) * @x_g | |
end | |
def shrink | |
@simplex.map! do |s| | |
s = 0.5 * (@x_l + s) | |
end | |
end | |
def calc_params | |
@x_h = @simplex[@evaluation.sort_index[-1]] # 最悪点 | |
@x_s = @simplex[@evaluation.sort_index[-2]] # 2番目に悪い点 | |
@x_l = @simplex[@evaluation.sort_index[0]] # 最良点 | |
@x_g = (@simplex.inject(NVector.float(2)){|g, s| g += s} - @x_h) / (@simplex.size - 1.0) # 最悪点を除いた重心 | |
# 各評価値 | |
@x_h_eval = @obj_func.call(@x_h) | |
@x_s_eval = @obj_func.call(@x_s) | |
@x_l_eval = @obj_func.call(@x_l) | |
end | |
def all_eval | |
@evaluation = NVector.to_na @simplex.inject([]){|a, s| a << @obj_func.call(s)} | |
end | |
end | |
# 目的関数 | |
obj_func = Proc.new do |x| | |
func_a = -Math.exp(-((x[0] + 5)**2 + (x[1] + 5)**2) / 2.0) | |
func_b = -4.0/5.0 * Math.exp(-((x[0] + 3)**2 + (x[1] + 3)**2) / 2.0) | |
func_a + func_b | |
end | |
opt = DownhillSimplex.new(obj_func, 10, -10.0, 10.0, 2) | |
require "Benchmark" | |
puts Benchmark.measure{ | |
opt.converge(1E-13) | |
} | |
puts "x:(#{opt.x_l[0]},#{opt.x_l[1]})\teval:#{opt.x_l_eval}" |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment