-
-
Save hatappi/07e73126e2c54ee87743c52df627e5b4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env ruby
require 'matrix'

# Class labels used in the training data.
CLASS_1 = 1
CLASS_2 = 2

# Training samples: each row is [scalar feature, class label].
data = Matrix[
  [1.0, CLASS_1],
  [0.5, CLASS_1],
  [-0.2, CLASS_1],
  [-0.4, CLASS_2],
  [-1.3, CLASS_2],
  [-2.0, CLASS_2]
]

# Feature-vector matrix: one row per sample, a single feature column.
features = Matrix.column_vector(data.column(0))
# Class label of each sample, in row order.
labels = data.column(1)
# Initial weight vector [w0, w1]; w0 pairs with the constant-1 column of xvecs.
wvec = Vector[0.2, 0.3]
# Augmented feature vectors: a constant 1 is prepended to every row so xvec[0] = 1.
xvecs = Matrix.rows(features.row_vectors.map { |row| [1, *row] })
# Performs one batch Widrow-Hoff (least-mean-squares) update of the weight vector.
#
# One gradient-descent step on J(w) = 1/2 * sum_p (w . x_p - b_p)^2:
#   w' = w - rho * sum_p (w . x_p - b_p) * x_p      # (3.33)
# ("3.33" is the equation tag carried over from the original script.)
# Every component is updated simultaneously from the same, old weight vector.
#
# @param xvecs [Matrix] augmented feature vectors, one sample per row
# @param wvec [Vector] current weight vector; size must equal xvecs.column_count
# @param tvec [Vector, Array] teacher value b_p for each row of xvecs
# @param rho [Numeric] learning rate (default 0.2)
# @return [Vector] the updated weight vector; wvec itself is not modified
# @raise [ExceptionForMatrix::ErrDimensionMismatch] if the sizes do not line up
def train(xvecs, wvec, tvec, rho = 0.2)
  # Residual of every sample: g(x_p) - b_p, where g(x) = w . x.
  residuals = xvecs * wvec - Vector.elements(tvec.to_a)
  # sum_p residual_p * x_p is exactly X^T * residuals.
  gradient = xvecs.transpose * residuals
  wvec - gradient * rho
end
# Learns a weight vector by repeating the training update a fixed number of times.
#
# Each pass feeds the previous result back into train; no convergence test is made.
#
# @param xvecs [Matrix] augmented feature vectors, one sample per row
# @param wvec [Vector] initial weight vector (not modified)
# @param tvec [Vector, Array] teacher value for each sample
# @param iterations [Integer] number of training passes (default 100)
# @return [Vector] the weight vector after +iterations+ calls to train
def get_wvec(xvecs, wvec, tvec, iterations = 100)
  iterations.times.reduce(wvec) { |current, _| train(xvecs, current, tvec) }
end
# Classifies a scalar sample by comparing the two linear discriminants
# g_k(x) = w_k[1] * x + w_k[0].
#
# @param wvec1 [Vector, Array] weights [w0, w1] of the class-1 discriminant
# @param wvec2 [Vector, Array] weights [w0, w1] of the class-2 discriminant
# @param x [Numeric] the feature value to classify
# @return [Integer, nil] CLASS_1 if g1 is larger, CLASS_2 if g2 is larger,
#   nil when the scores tie (or cannot be ordered, e.g. NaN)
def detect_class(wvec1, wvec2, x)
  score = ->(weights) { weights[1] * x + weights[0] }
  case score.call(wvec1) <=> score.call(wvec2)
  when 1 then CLASS_1
  when -1 then CLASS_2
  end
end
# Train one linear discriminant per class, one-vs-rest.
#
# The teacher vector for class +target+ holds 1 for samples labelled +target+ and
# 0 for samples of the other class. A label that is neither CLASS_1 nor CLASS_2
# raises KeyError here instead of turning into nil and failing later inside train.
teacher_vector = lambda do |target|
  signal = { CLASS_1 => 0, CLASS_2 => 0, target => 1 }
  labels.map { |c| signal.fetch(c) }
end

# Set to true to print each training sample's discriminant value next to its label.
show_training_scores = false
print_training_scores = lambda do |k, wvec_k|
  return unless show_training_scores

  xvecs.row_count.times do |i|
    xvec = xvecs.row(i)
    puts "g#{k}(%s) = %s (class:%s)" % [xvec[1], wvec_k.dot(xvec), labels[i]]
  end
end

# Class 1: teacher vector, learned weights, and g1(x) = w1[1] * x + w1[0].
tvec1 = teacher_vector.call(CLASS_1)
wvec1 = get_wvec(xvecs, wvec, tvec1)
puts 'wvec1 = %s' % wvec1
puts 'g1(x) = %f x + %f' % [wvec1[1], wvec1[0]]
print_training_scores.call(1, wvec1)

# Class 2: teacher vector, learned weights, and g2(x) = w2[1] * x + w2[0].
tvec2 = teacher_vector.call(CLASS_2)
wvec2 = get_wvec(xvecs, wvec, tvec2)
puts 'wvec2 = %s' % wvec2
puts 'g2(x) = %f x + %f' % [wvec2[1], wvec2[0]]
print_training_scores.call(2, wvec2)
# Classify probe points around the learned decision boundary and print each result.
# detect_class returns nil on an exact tie, which prints as an empty class.
[-0.4, -0.39, -0.38, -0.37, -0.35, -0.3, -0.2].each do |point|
  predicted = detect_class(wvec1, wvec2, point)
  puts "#{point}: class #{predicted}"
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.