-
-
Save wbzyl/7468da035578d21a5fee to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/ruby | |
require("gsl") | |
# Non-negative matrix factorization: looks for w (r x col) and h (col x c)
# such that w * h approximates v, by repeatedly applying `update`.
#
# v        - matrix to factorize (a GSL::Matrix per the rest of this file);
#            only its shape is read here, the values are consumed by `update`.
# col      - number of columns of w / rows of h (the inner dimension).
# thresh   - iteration stops once the distance reported by `update` drops
#            below this value.
# max_iter - upper bound on the number of iterations (defaults to 1000, the
#            value that used to be hard-coded).
#
# Prints the distance after every iteration and a summary line at the end.
# Returns the two-element array [w, h].
def nmf(v, col, thresh, max_iter = 1000)
  r, c = v.shape

  # w * h = v
  w = GSL::Matrix.alloc(r, col)
  h = GSL::Matrix.alloc(col, c)

  # NOTE(review): the bound handed to initrand is the max of a matrix that
  # was just allocated, not of v. Presumably that is 0 for a fresh matrix,
  # which makes initrand fall back to values in [0, 1) - confirm.
  initrand(w, w.max)
  initrand(h, h.max)

  # Seeding dist with thresh guarantees the loop body runs at least once
  # (provided max_iter is positive).
  dist = thresh
  iter = 0
  while dist >= thresh && iter < max_iter
    # multiplicative update rule
    dist, w, h = update(v, w, h)
    puts "#{iter}: #{dist}" # could be throttled, e.g. if (iter % 10 == 0)
    iter += 1
  end

  puts "Ended in #{iter} iteration(s) with dist=#{dist}"
  [w, h]
end
# Fills every cell of m with a pseudo-random Float.
#
# m   - matrix-like object responding to `shape` (=> [rows, cols]) and
#       `[]=(i, j, value)`; it is modified in place.
# max - exclusive upper bound for the generated values. When max is 0 (or
#       converts to 0.0) the values are drawn from [0, 1) instead.
#
# Kernel#rand(max) truncates a Float bound to an Integer and then returns
# Integers, so a bound such as 2.5 used to yield only 0 or 1. Values are now
# always Floats in [0, |max|).
#
# Returns m.
def initrand(m, max)
  r, c = m.shape
  bound = max.to_f.abs

  r.times do |i|
    c.times do |j|
      # Random.rand(Float) rejects a zero bound, hence the fallback.
      m[i, j] = bound > 0 ? Random.rand(bound) : Random.rand
    end
  end

  m
end
# Runs one multiplicative update step with the selected rule.
#
# v, w, h - GSL::Matrix objects (passed through untouched to the rule).
# rule    - :eu (default) dispatches to update_eu (Euclidean distance),
#           :kl dispatches to update_kl (divergence).
#
# Returns the array [dist, w, h] produced by the chosen rule.
# Raises ArgumentError for an unknown rule.
def update(v, w, h, rule = :eu)
  dist, w, h =
    case rule
    when :eu then update_eu(v, w, h)
    when :kl then update_kl(v, w, h)
    else raise ArgumentError, "unknown update rule: #{rule.inspect}"
    end

  [dist, w, h]
end
# One multiplicative update step for minimizing the Euclidean distance
# between v and w * h.
#
# v, w, h - GSL::Matrix objects (anything responding to *, transpose,
#           mul_elements and div_elements the same way works).
#
# h is updated first with w held fixed; w is then updated against the NEW h.
# The transpose of h is therefore taken after the h update: previously the
# w step mixed the updated h with the transpose of the old one.
#
# Returns [dist, w, h], where dist is dist_eu(v, w * h) measured BEFORE the
# update and w, h are the updated factors.
def update_eu(v, w, h)
  dist = dist_eu(v, w * h)

  # Operands are evaluated before the element-wise calls so the result does
  # not depend on whether mul_elements/div_elements work in place.
  wt = w.transpose
  h_numer = wt * v
  h_denom = wt * w * h
  h = h.mul_elements(h_numer).div_elements(h_denom)

  ht = h.transpose
  w_numer = v * ht
  w_denom = w * h * ht
  w = w.mul_elements(w_numer).div_elements(w_denom)

  [dist, w, h]
end
# One multiplicative update step for minimizing the divergence between v
# and w * h.
#
# v, w, h - GSL::Matrix objects.
#
# h is updated first with w held fixed; w is then updated against the NEW h.
# The element-wise ratio v / (w * h) is therefore recomputed between the two
# steps: previously the w step reused the ratio built from the old h.
#
# Note: h and w are updated cell by cell through []=, so the caller's
# matrices are modified.
#
# Returns [dist, w, h], where dist is dist_kl(v, w * h) measured BEFORE the
# update.
def update_kl(v, w, h)
  dist = dist_kl(v, w * h)
  rh, ch = h.shape
  rw, cw = w.shape

  # Assumes div_elements returns a new matrix and leaves v untouched, as the
  # original code already did - TODO confirm against the GSL binding.
  ratio = v.div_elements(w * h)
  rh.times do |i|
    # Loop-invariant for the whole row of h.
    w_col_sum = w.col(i).sum
    ch.times do |j|
      h[i, j] *= w.col(i).mul(ratio.col(j)).sum / w_col_sum
    end
  end

  # Refresh the ratio so the w step sees the h computed above.
  ratio = v.div_elements(w * h)
  rw.times do |i|
    cw.times do |j|
      w[i, j] *= h.row(j).mul(ratio.row(i)).sum / h.row(j).sum
    end
  end

  [dist, w, h]
end
# Squared Euclidean distance: the sum of the squared element-wise
# differences between a and b (no square root is taken).
def dist_eu(a, b)
  # a, b are GSL::Matrix objects
  residual = a - b
  sum_sq = 0
  # collect is used only to visit every element; its result is discarded.
  residual.collect do |delta|
    sum_sq += delta * delta
  end
  sum_sq
end
# Kullback-Leibler divergence (generalized form):
#   sum over all cells of  a * log(a / b) - a + b
#
# a, b - GSL::Matrix objects of the same shape (the shape of a drives the
#        iteration).
#
# If any visited cell of b is 0 the term is undefined: "error" is printed and
# the sentinel 2**32 is returned immediately. The previous `break` only left
# the inner loop, so later rows kept adding to the sentinel.
#
# A cell where a is 0 contributes just b (taking 0 * log(0) as 0) instead of
# turning the whole sum into NaN.
#
# Returns the accumulated divergence, or 2**32 on error.
def dist_kl(a, b)
  total = 0
  r, c = a.shape

  r.times do |i|
    c.times do |j|
      observed = a[i, j]
      estimate = b[i, j]

      if estimate == 0
        puts "error"
        return 2**32
      end

      total +=
        if observed == 0
          estimate
        else
          observed * Math.log(observed / estimate) - observed + estimate
        end
    end
  end

  total
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment