Created
December 17, 2025 06:03
-
-
Save abikoushi/cb7619646d880fafecd237d22bb331db to your computer and use it in GitHub Desktop.
VB-NMF for CSC sparse matrix format
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| module VI | |
| using LogExpFunctions | |
| using SparseArrays | |
| using SpecialFunctions | |
| using Distributions | |
| # A = sparse([1, 1, 2, 3], [1, 3, 2, 3], [0, 1, 2, 0]) | |
| # A.m # Number of rows | |
| # A.n # Number of columns | |
| # A.colptr # Column j is in colptr[j]:(colptr[j+1]-1) | |
| # A.rowval # Row indices of stored values | |
| # A.nzval # Stored values, typically nonzeros | |
"""
    klgamma_sub(a, b, c, d)

Return the scalar

    -(c*d)/a - b*log(a) - loggamma(1/b) + (b - 1)*(digamma(1/d) + log(c))

This is the building block that `kl2gamma` evaluates twice and subtracts.

NOTE(review): the expression has the shape of an expected Gamma log-density,
but the `inv(b)` / `inv(d)` arguments do not match a standard shape/rate or
shape/scale parametrization -- confirm the intended formula with the author.
"""
function klgamma_sub(a, b, c, d)
    scaled = (c * d) / a
    lognorm = b * log(a)
    lgam = loggamma(inv(b))
    elog = digamma(inv(d)) + log(c)
    # Same left-to-right evaluation order as the one-line original.
    return -scaled - lognorm - lgam + (b - 1) * elog
end
"""
    kl2gamma(a1, b1, a2, b2)

Return `klgamma_sub(a2, b2, a2, b2) - klgamma_sub(a1, b1, a2, b2)`.

The first term evaluates `klgamma_sub` with the pair `(a2, b2)` in both argument
positions; the second evaluates it with `(a1, b1)` in the first position and
`(a2, b2)` in the second.
"""
function kl2gamma(a1, b1, a2, b2)
    self_term = klgamma_sub(a2, b2, a2, b2)
    cross_term = klgamma_sub(a1, b1, a2, b2)
    return self_term - cross_term
end
"""
    kld(alpha, beta, a, b)

Sum `kl2gamma(a, b, alpha[i, l], beta[l, 1])` over every entry `(i, l)` of the
matrix `alpha`.

# Arguments
- `alpha`: matrix indexed as `alpha[i, l]`.
- `beta`: indexed as `beta[l, 1]`, so it needs one entry per column of `alpha`.
- `a`, `b`: scalars passed unchanged as the first two arguments of `kl2gamma`.

# Returns
The accumulated sum as a floating-point scalar (zero when `alpha` is empty).
"""
function kld(alpha, beta, a, b)
    # Start from a float zero of the promoted input type so the accumulator
    # does not switch from Int to float on the first iteration.
    T = float(promote_type(eltype(alpha), eltype(beta), typeof(a), typeof(b)))
    lp = zero(T)
    for i in axes(alpha, 1)
        for l in axes(alpha, 2)
            lp += kl2gamma(a, b, alpha[i, l], beta[l, 1])
        end
    end
    return lp
end
"""
    NMF(Y::SparseMatrixCSC, lrank, iter)

Variational-Bayes non-negative matrix factorization `Y ≈ Z * W` of a sparse
matrix, visiting only the stored entries of `Y`.

Each factor entry has a Gamma variational factor with shape `alpha_*` and rate
`beta_*`. The factors are initialised with draws from `Gamma(1.0, 1.0)` and then
updated alternately (`Z` first, then `W`) for `iter` sweeps. The hyper-parameters
`a` and `b` added to the shapes and rates are fixed at `1.0`.

# Arguments
- `Y`: sparse matrix in CSC format. Its stored values are passed to
  `logfactorial`, so they are expected to be non-negative integers.
- `lrank`: number of latent components (columns of `Z`, rows of `W`).
- `iter`: number of update sweeps.

# Returns
A tuple `(Z, W, alpha_z, alpha_w, beta_z, beta_w, lp)`:
- `Z` (`size(Y, 1) × lrank`) and `W` (`lrank × size(Y, 2)`): posterior means
  `alpha ./ beta`.
- `alpha_z`, `alpha_w`: shape parameters, same sizes as `Z` and `W`.
- `beta_z` (`lrank × 1`) and `beta_w` (`1 × lrank`): rate parameters.
- `lp`: length-`iter` vector holding the objective value recorded after each sweep.

NOTE(review): the bookkeeping of `lp` (the `sum(beta_w)` term and the sign with
which the `kld` terms enter) is kept exactly as originally written; its
derivation as an evidence lower bound has not been re-verified.
"""
function NMF(Y::SparseMatrixCSC, lrank, iter)
    nrows, ncols = size(Y)
    rows = rowvals(Y)
    vals = nonzeros(Y)
    a = 1.0
    b = 1.0
    Z = rand(Gamma(1.0, 1.0), nrows, lrank)
    W = rand(Gamma(1.0, 1.0), lrank, ncols)
    logZ = log.(Z)
    logW = log.(W)
    alpha_z = zeros(nrows, lrank)
    alpha_w = zeros(lrank, ncols)
    beta_z = sum(W, dims=2) .+ b
    beta_w = sum(Z, dims=1) .+ b
    lp = zeros(iter)
    R = zeros(lrank)  # reusable buffer for the per-entry responsibilities

    for it in 1:iter
        # ---- update q(Z) ----
        fill!(alpha_z, 0.0)
        for cind in 1:ncols
            # Walk the stored entries of column `cind` directly rather than
            # looking each value up with Y[rind, cind].
            for k in nzrange(Y, cind)
                rind = rows[k]
                @views R .= exp.(logZ[rind, :] .+ logW[:, cind])
                R ./= sum(R)
                R .*= vals[k]
                @views alpha_z[rind, :] .+= R
            end
        end
        alpha_z .+= a
        beta_z = sum(W, dims=2) .+ b
        Z = alpha_z ./ beta_z'
        # E[log z] for Gamma(shape α, rate β) is ψ(α) - log(β): digamma is
        # applied to the shape, not to the mean α/β.
        logZ = digamma.(alpha_z) .- log.(beta_z)'

        # ---- update q(W) and accumulate the objective ----
        lpt = 0.0
        fill!(alpha_w, 0.0)
        for cind in 1:ncols
            for k in nzrange(Y, cind)
                rind = rows[k]
                y = vals[k]
                @views R .= exp.(logZ[rind, :] .+ logW[:, cind])
                sumR = sum(R)
                lpt += xlogy(y, sumR)
                R ./= sumR
                R .*= y
                @views alpha_w[:, cind] .+= R
            end
        end
        alpha_w .+= a
        beta_w = sum(Z, dims=1)
        lpt -= sum(beta_w)
        beta_w .+= b
        W = alpha_w ./ beta_w'
        logW = digamma.(alpha_w) .- log.(beta_w)'

        lp[it] = lpt + kld(alpha_z, beta_z, a, b) + kld(alpha_w', beta_w', a, b)
    end

    # Constant log(y!) term; skipped when Y has no stored entries because
    # reducing an empty collection with `sum(f, ...)` throws.
    if !isempty(vals)
        lp .-= sum(logfactorial, vals)
    end
    return Z, W, alpha_z, alpha_w, beta_z, beta_w, lp
end
| end | |
| using SparseArrays | |
| using LinearAlgebra | |
| using Distributions | |
| using SpecialFunctions | |
| using Plots | |
# Simulate a 100 × 100 Poisson count matrix whose rate is the product of two
# random Gamma(1, 1) factors of rank 3, and store it in sparse format.
Z = rand(Gamma(1.0, 1.0), 100, 3)
W = rand(Gamma(1.0, 1.0), 3, 100)
Y = sparse(rand.(Poisson.(Z * W)))

# Fit a rank-3 model with 500 sweeps and report the elapsed time.
@time fit = VI.NMF(Y, 3, 500)

# Fitted means against observed counts, with the identity line for reference.
yhat = vec(fit[1] * fit[2])
scatter(yhat, vec(Y); legend=false, alpha=0.4, markerstrokewidth=0)
Plots.abline!(1, 0; linestyle=:dash, color=:grey)

# Objective trace returned by the fit, plotted without its first sweep.
lp_trace = fit[7]
plot(lp_trace[2:end]; legend=false)
fit[7]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment