# GitHub Gist wbzyl/7468da035578d21a5fee
# (forked from mfdela/nmf.rb; last active August 29, 2015)
#
# Non-negative matrix factorization (NMF) with Ruby/GSL.
#!/usr/bin/ruby
require("gsl")
# Factorize v (r x c) as v ~ w * h with multiplicative updates.
#
# Each iteration calls +update+ and prints "iter: dist". The loop stops once
# the reported distance drops below +thresh+ or after +max_iter+ iterations.
# The reported distance is the one measured *before* that iteration's update.
#
# @param v [GSL::Matrix] matrix to factorize (r x c); the multiplicative rules
#   presume non-negative entries -- TODO confirm at call sites
# @param col [Integer] inner dimension (rank) of the factorization
# @param thresh [Numeric] convergence threshold on the reported distance
# @param max_iter [Integer] iteration cap (default 1000, the former constant)
# @return [Array(GSL::Matrix, GSL::Matrix)] w (r x col) and h (col x c)
# @raise [ArgumentError] if col < 1
def nmf(v, col, thresh, max_iter = 1000)
  raise ArgumentError, "col must be at least 1" if col < 1

  r, c = v.shape
  # w * h = v
  w = GSL::Matrix.alloc(r, col)
  h = GSL::Matrix.alloc(col, c)
  # Use an explicit scale of 0 rather than w.max / h.max: the contents of a
  # freshly allocated matrix are not a meaningful scale. With 0, initrand
  # draws its values from the unit interval.
  initrand(w, 0)
  initrand(h, 0)

  dist = thresh
  iter = 0
  while dist >= thresh && iter < max_iter
    # multiplicative update rule
    dist, w, h = update(v, w, h)
    puts "#{iter}: #{dist}"
    iter += 1
  end
  puts "Ended in #{iter} iteration(s) with dist=#{dist}"
  [w, h]
end
# Fill matrix m in place with strictly positive pseudo-random floats.
#
# Kernel#rand(max) is a poor initializer for NMF:
# - for |max| >= 1 it returns integers in [0, max.to_i.abs), so rand(1.9) is
#   always 0;
# - an entry that starts at 0 can never change under the multiplicative
#   update rules.
# Instead, each entry is drawn uniformly from (0, scale]. scale is |max|, or
# 1.0 when max is 0 or not a finite number.
#
# @param m [#shape, #[]=] matrix to fill, e.g. a GSL::Matrix
# @param max [Numeric] upper bound of the drawn values
# @return [void]
def initrand(m, max)
  scale = max.to_f.abs
  scale = 1.0 unless scale.finite? && scale.positive?
  rows, cols = m.shape
  rows.times do |i|
    cols.times do |j|
      # 1.0 - rand lies in (0, 1], so no entry is ever exactly zero.
      m[i, j] = (1.0 - rand) * scale
    end
  end
end
# Perform one multiplicative NMF update step with the selected rule.
#
# @param v [GSL::Matrix] matrix being factorized
# @param w [GSL::Matrix] current left factor
# @param h [GSL::Matrix] current right factor
# @param rule [Symbol] :eu (default) selects update_eu (squared Euclidean
#   distance); :kl selects update_kl (Kullback-Leibler divergence)
# @return [Array] [dist, w, h] exactly as returned by the selected rule
# @raise [ArgumentError] for an unknown rule
def update(v, w, h, rule = :eu)
  case rule
  when :eu then update_eu(v, w, h)
  when :kl then update_kl(v, w, h)
  else raise ArgumentError, "unknown update rule: #{rule.inspect}"
  end
end
# One multiplicative update step (Lee & Seung) for the squared Euclidean
# distance ||v - w*h||^2, updating h first and then w:
#
#   h <- h .* (w^T v) ./ (w^T w h)
#   w <- w .* (v h^T) ./ (w h h^T)      (h here is the updated h)
#
# @param v [GSL::Matrix] matrix being factorized (r x c)
# @param w [GSL::Matrix] left factor (r x k)
# @param h [GSL::Matrix] right factor (k x c)
# @return [Array] [dist, w, h], where dist = dist_eu(v, w*h) for the factors
#   passed in (i.e. before the update), and w, h are the updated factors
def update_eu(v, w, h)
  dist = dist_eu(v, w * h)
  wt = w.transpose
  h = h.mul_elements(wt * v).div_elements(wt * w * h)
  # Transpose only after h has been updated. Both terms of the w step must
  # use the same (new) h; otherwise old and new factors get mixed.
  ht = h.transpose
  w = w.mul_elements(v * ht).div_elements(w * h * ht)
  [dist, w, h]
end
# One multiplicative update step (Lee & Seung) for the generalized
# Kullback-Leibler divergence D(v || w*h), updating h first and then w:
#
#   h[k,j] <- h[k,j] * sum_i(w[i,k] * v[i,j] / (w h)[i,j]) / sum_i w[i,k]
#   w[i,k] <- w[i,k] * sum_j(h[k,j] * v[i,j] / (w h)[i,j]) / sum_j h[k,j]
#
# The w step must see the already-updated h, so the ratio v ./ (w h) is
# recomputed between the two steps.
#
# Unlike update_eu, this updates h and w in place via element assignment and
# then returns them.
#
# @param v [GSL::Matrix] matrix being factorized (r x c)
# @param w [GSL::Matrix] left factor (r x k), modified in place
# @param h [GSL::Matrix] right factor (k x c), modified in place
# @return [Array] [dist, w, h], where dist = dist_kl(v, w*h) for the factors
#   passed in (i.e. before the update)
def update_kl(v, w, h)
  dist = dist_kl(v, w * h)
  rows, rank = w.shape
  _, cols = h.shape

  # h step: w is not modified here, so its column sums are loop-invariant.
  ratio = v.div_elements(w * h)
  w_col_sums = (0...rank).map { |k| w.col(k).sum }
  rank.times do |k|
    cols.times do |j|
      h[k, j] *= w.col(k).mul(ratio.col(j)).sum / w_col_sums[k]
    end
  end

  # w step: recompute the ratio from the updated h. h is fixed from here on,
  # so its row sums are loop-invariant.
  ratio = v.div_elements(w * h)
  h_row_sums = (0...rank).map { |k| h.row(k).sum }
  rows.times do |i|
    rank.times do |k|
      w[i, k] *= h.row(k).mul(ratio.row(i)).sum / h_row_sums[k]
    end
  end

  [dist, w, h]
end
# Squared Euclidean (Frobenius) distance ||a - b||^2.
# No square root is taken; the result is the sum of squared element-wise
# differences.
#
# @param a [GSL::Matrix] first matrix
# @param b [GSL::Matrix] second matrix, same shape as a
# @return [Numeric] sum over all entries of (a - b)**2
def dist_eu(a, b)
  residual = a - b
  n_rows, n_cols = residual.shape
  total = 0
  n_rows.times do |i|
    n_cols.times { |j| total += residual[i, j] * residual[i, j] }
  end
  total
end
# Generalized Kullback-Leibler divergence D(a || b):
#
#   D = sum_ij ( a_ij * log(a_ij / b_ij) - a_ij + b_ij )
#
# Terms with a_ij == 0 use the convention 0 * log(0 / b) = 0, so they
# contribute just b_ij and no longer produce NaN.
# If some b_ij == 0 while a_ij != 0, the divergence is unbounded: "error" is
# printed and the sentinel 2**32 is returned immediately. Returning at once
# avoids the sentinel being overwritten or added to by later entries.
#
# @param a [GSL::Matrix] reference matrix
# @param b [GSL::Matrix] approximation, same shape as a
# @return [Numeric] the divergence, or 2**32 when it is unbounded
def dist_kl(a, b)
  rows, cols = a.shape
  dist = 0
  rows.times do |i|
    cols.times do |j|
      x = a[i, j]
      y = b[i, j]
      if x.zero?
        dist += y
        next
      end
      if y.zero?
        puts "error"
        return 2**32
      end
      dist += x * Math.log(x / y) - x + y
    end
  end
  dist
end
# (end of gist)