Created April 27, 2012 14:12
Save greyblake/2509589 to your computer and use it in GitHub Desktop.
art.rb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ART1 (Adaptive Resonance Theory) network for clustering binary vectors,
# trained with fast learning. The "step N" comments in #process_vector name
# the steps of the textbook ART1 training algorithm.
class Art
  # Learning parameter L (must be > 1) used in the bottom-up weight update.
  L = 2
  # Maximum number of training epochs run by #teach.
  MAX_EPOCH = 20

  # @param n [Integer] number of input (F1) units, i.e. the length of each vector
  # @param m [Integer] number of cluster (F2) units
  # @param p [Float] vigilance: the minimum match ratio ||x|| / ||s|| for a
  #   cluster unit to resonate with (and learn) an input
  def initialize(n, m, p = 0.45)
    @n = n
    @m = m
    @p = p
    # Bottom-up weights, n x m: @w1[i][j] connects input i to cluster j.
    @w1 = Array.new(n) { Array.new(m) { 1.0 / (1 + n) } }
    # Top-down weights, m x n: @w2[j] is the prototype learned by cluster j.
    @w2 = Array.new(m) { Array.new(n) { 1.0 } }
  end

  # Trains on +vecs+ epoch by epoch, printing the weights before training
  # and after every epoch. Stops after MAX_EPOCH epochs, or earlier once an
  # epoch leaves both weight matrices unchanged.
  #
  # @param vecs [Array<Array<Numeric>>] training vectors of length n
  def teach(vecs)
    print_epoch(0)
    MAX_EPOCH.times do |i|
      # Copy the rows too: #process_vector updates @w1 rows in place, so a
      # shallow #dup would share them and the @w1 comparison below would
      # always succeed.
      w1_old = @w1.map(&:dup)
      w2_old = @w2.map(&:dup)
      process_epoch(vecs)
      print_epoch(i + 1)
      break if w1_old == @w1 && w2_old == @w2
    end
  end

  # Presents every vector of +vecs+ once, in order.
  def process_epoch(vecs)
    vecs.each { |vec| process_vector(vec) }
  end

  # Presents one input vector. If some cluster unit resonates with it, that
  # unit's weights are updated. Vectors that no unit accepts (including the
  # all-zero vector) leave the weights untouched.
  #
  # @param vec [Array<Numeric>] input vector of length n (binary in ART1)
  # @return [Array, nil] the learned prototype, or nil if nothing was learned
  def process_vector(vec)
    # step 4: F1 activations are the input itself
    s_out = vec
    # step 5: ||s||
    s_out_norm = sum(s_out)
    # The match ratio ||x|| / ||s|| is undefined (0/0) for an all-zero input.
    return nil if s_out_norm.zero?

    # step 6-7: net input to every cluster unit from x = s
    y_in = calc_y_in(s_out)
    inhibited = Array.new(@m, false)
    loop do
      # step 8-9: winner = first non-inhibited unit with the largest net input
      jj = winner(y_in, inhibited)
      # Every unit has been reset: this vector cannot be clustered.
      return nil if jj.nil?

      # step 10: x_i = s_i * t_Ji
      z_out = Array.new(@n) { |i| s_out[i] * @w2[jj][i] }
      # step 11: ||x||
      z_out_norm = sum(z_out)
      # step 12: resonance when ||x|| / ||s|| reaches the vigilance, else reset
      return learn(jj, z_out, z_out_norm) if z_out_norm.to_f / s_out_norm >= @p

      inhibited[jj] = true
    end
  end

  # Human-readable dump of both weight matrices.
  def to_s
    out = "w1:\n"
    out << w_to_s(@w1)
    out << "w2:\n"
    out << w_to_s(@w2, "%-3d")
    out << "\n\n"
    out
  end

  private

  # Index of the first non-inhibited unit with the largest net input, or nil
  # when every unit is inhibited.
  def winner(y_in, inhibited)
    candidates = (0...@m).reject { |j| inhibited[j] }
    return nil if candidates.empty?

    best = candidates.map { |j| y_in[j] }.max
    candidates.find { |j| y_in[j] == best }
  end

  # step 13: fast-learning update of cluster unit +jj+ toward x (+z_out+).
  # Returns the new top-down prototype of that unit.
  def learn(jj, z_out, z_out_norm)
    @n.times do |i|
      @w1[i][jj] = (L * z_out[i]).to_f / (L - 1 + z_out_norm)
    end
    @w2[jj] = z_out
  end

  # Sum of the elements of +vec+; 0 for an empty vector.
  def sum(vec)
    vec.inject(0, :+)
  end

  # Net input y_j = sum_i @w1[i][j] * vec[i] for every cluster unit j.
  def calc_y_in(vec)
    y_in = Array.new(@m) { 0 }
    @m.times do |j|
      @n.times do |i|
        y_in[j] += @w1[i][j] * vec[i]
      end
    end
    y_in
  end

  # Formats matrix +w+ one row per line, each element with +format+.
  def w_to_s(w, format = "%-7.3f")
    out = ""
    w.each do |row|
      row.each do |el|
        out << format % [el]
      end
      out << "\n"
    end
    out
  end

  # Prints an epoch header followed by the current weights.
  def print_epoch(num)
    puts "=== Epoch #{num} ==="
    puts self
  end
end
# Demo: train a 13-input, 9-cluster network on the patterns below at three
# vigilance levels. Each pattern is written as a 13-character bit string and
# expanded into an Array of Integer 0/1 values.
vectors = %w[
  1010010001010
  1110010001010
  1000010001010
  1110101010000
  1010101010000
  1100101010000
  0000100101111
  0100100101111
  0010100101111
].map { |bits| bits.chars.map(&:to_i) }

vigilances = [0.4, 0.7, 0.9]
vigilances.each do |vigilance|
  puts "===== P = #{vigilance} ====="
  Art.new(13, 9, vigilance).teach(vectors)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.