Skip to content

Instantly share code, notes, and snippets.

@abikoushi
Created December 17, 2025 06:03
Show Gist options
  • Select an option

  • Save abikoushi/cb7619646d880fafecd237d22bb331db to your computer and use it in GitHub Desktop.

Select an option

Save abikoushi/cb7619646d880fafecd237d22bb331db to your computer and use it in GitHub Desktop.
VB-NMF for CSC sparse matrix format
"""
    VI

Mean-field variational Bayes for Poisson non-negative matrix factorization (NMF) of a
`SparseMatrixCSC` count matrix.

Model: `Y[i, j] ~ Poisson(sum_l Z[i, l] * W[l, j])` with independent
`Gamma(shape = a, rate = b)` priors on every entry of `Z` and `W`.
Variational family: `q(Z[i, l]) = Gamma(alpha_z[i, l], rate = beta_z[l])` and
`q(W[l, j]) = Gamma(alpha_w[l, j], rate = beta_w[l])`.
"""
module VI
using LogExpFunctions
using SparseArrays
using SpecialFunctions
using Distributions

# CSC storage cheat sheet (use the accessor functions rather than the raw fields):
# A = sparse([1, 1, 2, 3], [1, 3, 2, 3], [0, 1, 2, 0])
# size(A)        # (number of rows, number of columns)
# nzrange(A, j)  # positions of column j inside rowvals(A) / nonzeros(A)
# rowvals(A)     # row indices of the stored values
# nonzeros(A)    # stored values, typically nonzeros (explicit zeros may be stored)

"""
    klgamma_sub(a, b, c, d)

Return `E_q[log p(x)]`, the expectation of the log-density of
`p = Gamma(shape = a, rate = b)` under `q = Gamma(shape = c, rate = d)`.

Uses `E_q[x] = c / d` and `E_q[log x] = digamma(c) - log(d)`.
"""
function klgamma_sub(a, b, c, d)
    return a * log(b) - loggamma(a) + (a - 1) * (digamma(c) - log(d)) - b * c / d
end

"""
    kl2gamma(a1, b1, a2, b2)

Return the Kullback-Leibler divergence `KL(q || p)` with
`q = Gamma(shape = a2, rate = b2)` and `p = Gamma(shape = a1, rate = b1)`.
"""
function kl2gamma(a1, b1, a2, b2)
    # KL(q || p) = E_q[log q] - E_q[log p]
    return klgamma_sub(a2, b2, a2, b2) - klgamma_sub(a1, b1, a2, b2)
end

"""
    kld(alpha, beta, a, b)

Return `sum_{i,l} KL(Gamma(alpha[i, l], rate = beta[l, 1]) || Gamma(a, rate = b))`.

`alpha` is indexed as `alpha[i, l]`; `beta` holds one rate per column `l` of `alpha`
and is indexed as `beta[l, 1]`. Returns zero when `alpha` is empty.
"""
function kld(alpha, beta, a, b)
    T = float(promote_type(eltype(alpha), eltype(beta), typeof(a), typeof(b)))
    total = zero(T)
    for l in axes(alpha, 2)
        rate = beta[l, 1]
        for i in axes(alpha, 1)
            total += kl2gamma(a, b, alpha[i, l], rate)
        end
    end
    return total
end

# Overwrite `r` with the multinomial responsibilities of entry (i, j), i.e. the softmax
# over l of logZ[i, l] + logW[l, j], and return the log-normalizer (log-sum-exp).
# Working in log space avoids the 0/0 that exp-then-normalize gives on underflow.
function _responsibilities!(r, logZ, logW, i, j)
    @views r .= logZ[i, :] .+ logW[:, j]
    lse = logsumexp(r)
    r .= exp.(r .- lse)
    return lse
end

"""
    NMF(Y::SparseMatrixCSC, lrank, iter)

Fit a rank-`lrank` Poisson NMF to the count matrix `Y` with `iter` sweeps of
coordinate-ascent variational inference. Both priors are `Gamma(shape = 1, rate = 1)`.
The factors are initialised with random `Gamma(1, 1)` draws from the global RNG.

Only stored entries of `Y` are visited; stored zeros contribute nothing and are skipped.

# Returns
A tuple `(Z, W, alpha_z, alpha_w, beta_z, beta_w, lp)`:
- `Z` (`size(Y, 1) x lrank`) and `W` (`lrank x size(Y, 2)`): posterior means.
- `alpha_z`, `alpha_w`: posterior shape parameters, same sizes as `Z` and `W`.
- `beta_z` (`lrank x 1`) and `beta_w` (`1 x lrank`): posterior rate parameters.
- `lp`: length-`iter` vector with the evidence lower bound (ELBO) after each sweep.
  It is non-decreasing up to floating-point error.

# Throws
- `ArgumentError` if `lrank < 1` or `iter < 0`.
"""
function NMF(Y::SparseMatrixCSC, lrank, iter)
    lrank >= 1 || throw(ArgumentError("lrank must be at least 1, got $lrank"))
    iter >= 0 || throw(ArgumentError("iter must be non-negative, got $iter"))
    nrows, ncols = size(Y)
    rows = rowvals(Y)
    vals = nonzeros(Y)
    a = 1.0
    b = 1.0
    Z = rand(Gamma(1.0, 1.0), nrows, lrank)
    W = rand(Gamma(1.0, 1.0), lrank, ncols)
    logZ = log.(Z)
    logW = log.(W)
    alpha_z = zeros(nrows, lrank)
    alpha_w = zeros(lrank, ncols)
    beta_z = sum(W, dims=2) .+ b
    beta_w = sum(Z, dims=1) .+ b
    lp = zeros(iter)
    resp = zeros(lrank)  # reusable buffer for the responsibilities of one entry
    for it in 1:iter
        # ---- update q(Z): alpha_z[i, l] = a + sum_j Y[i, j] * r[i, j, l] ----
        fill!(alpha_z, a)
        for cind in 1:ncols
            for k in nzrange(Y, cind)
                y = vals[k]
                iszero(y) && continue
                rind = rows[k]
                _responsibilities!(resp, logZ, logW, rind, cind)
                @views alpha_z[rind, :] .+= y .* resp
            end
        end
        beta_z = sum(W, dims=2) .+ b
        Z = alpha_z ./ beta_z'
        # E[log Z] of a Gamma(alpha, rate = beta) is digamma(alpha) - log(beta).
        logZ = digamma.(alpha_z) .- log.(beta_z)'

        # ---- update q(W): alpha_w[l, j] = a + sum_i Y[i, j] * r[i, j, l] ----
        loglik = 0.0
        fill!(alpha_w, a)
        for cind in 1:ncols
            for k in nzrange(Y, cind)
                y = vals[k]
                iszero(y) && continue
                lse = _responsibilities!(resp, logZ, logW, rows[k], cind)
                loglik += y * lse
                @views alpha_w[:, cind] .+= y .* resp
            end
        end
        zsum = sum(Z, dims=1)
        beta_w = zsum .+ b
        W = alpha_w ./ beta_w'
        newlogW = digamma.(alpha_w) .- log.(beta_w)'

        # ---- ELBO at the state reached by this sweep ----
        # The responsibilities above were formed with the previous logW, so the
        # multinomial term is y * lse plus a correction for the change in E[log W];
        # alpha_w - a is exactly sum_i Y[i, j] * r[i, j, l].
        loglik += sum((alpha_w .- a) .* (newlogW .- logW))
        logW = newlogW
        # Expected Poisson rate: sum_l (sum_i E[Z[i, l]]) * (sum_j E[W[l, j]]).
        loglik -= sum(vec(zsum) .* vec(sum(W, dims=2)))
        lp[it] = loglik - kld(alpha_z, beta_z, a, b) - kld(alpha_w', beta_w', a, b)
    end
    # Constant -sum(log(Y!)) of the Poisson likelihood; `init` keeps an all-empty Y valid.
    lp .-= mapreduce(logfactorial, +, vals; init=0.0)
    return Z, W, alpha_z, alpha_w, beta_z, beta_w, lp
end
end
using SparseArrays
using LinearAlgebra
using Distributions
using SpecialFunctions
using Plots
# Simulate a 100 x 100 Poisson count matrix from rank-3 Gamma(1, 1) factors and
# store it in CSC sparse format.
Z = rand(Gamma(1.0, 1.0), 100, 3)
W = rand(Gamma(1.0, 1.0), 3, 100)
Y = sparse(rand.(Poisson.(Z * W)))

# Fit a rank-3 model with 500 sweeps; `@time` also reports compilation on the first call.
@time fit = VI.NMF(Y, 3, 500)

# Fitted mean (product of the first two returned factors) against the observed counts,
# with the identity line as a reference.
Yhat = fit[1] * fit[2]
scatter(vec(Yhat), vec(Y); legend=false, alpha=0.4, markerstrokewidth=0)
Plots.abline!(1, 0; linestyle=:dash, color=:grey)

# Per-iteration objective trace (7th returned element), dropping the first sweep.
plot(fit[7][2:end]; legend=false)
fit[7]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment