Skip to content

Instantly share code, notes, and snippets.

@abikoushi
Created January 3, 2026 08:28
Show Gist options
  • Select an option

  • Save abikoushi/8509938c63b9f5a7c47043fe474eebcb to your computer and use it in GitHub Desktop.

Select an option

Save abikoushi/8509938c63b9f5a7c47043fe474eebcb to your computer and use it in GitHub Desktop.
reparameterized Poisson NMF
## reparameterized Poisson NMF
##
## Iteratively fits a factorisation of the form
##   X ~ W %*% diag(lambda) %*% H
## where every column of W and every row of H is normalised to sum to one and
## lambda carries one scale per component.
##
## Arguments:
##   X        numeric matrix with D rows and N columns (anything as.matrix()
##            can convert). Must not contain NA.
##   M        number of components.
##   a_W      constant added to the per-component row totals when updating W;
##            also the gamma shape used to initialise W.
##   a_H      constant added to the per-component column totals when updating
##            H; also the gamma shape used to initialise H.
##   alpha_b  constant added to the per-component totals when updating lambda;
##            also the gamma shape used to initialise lambda.
##   beta_b   gamma rate used to initialise lambda.
##            NOTE(review): beta_b does not enter the iterative updates, only
##            the random start -- confirm this is intended.
##   tol      the loop stops once every element of logW and logH changes by
##            less than tol (in absolute value) between two sweeps.
##   maxit    maximum number of sweeps; 0 returns the random initial values.
##
## Value: a list with
##   lambdahat  length-M vector of component scales,
##   What       D x M matrix whose columns sum to one,
##   Hhat       M x N matrix whose rows sum to one,
##   bhat       final value of the scalar b_W,
##   iter       index of the last sweep that was run (0 if none).
##
## The starting values are drawn with rgamma(), so call set.seed() beforehand
## for reproducible output.
NMF_re <- function(X,
                   M = 2,
                   a_W = 1,
                   a_H = 1,
                   alpha_b = 1, beta_b = 1,
                   tol = 1e-3, maxit = 100) {
  X <- as.matrix(X)
  stopifnot(
    is.numeric(X), !anyNA(X),
    length(M) == 1L, M >= 1,
    length(maxit) == 1L, maxit >= 0
  )
  D <- nrow(X)
  N <- ncol(X)

  # Stack one margin vector per component into a len x M matrix. vapply()
  # returns a bare vector when each result has length one (single-row or
  # single-column X), so the matrix shape is restored explicitly.
  margin_matrix <- function(parts, margin_fun, len) {
    out <- vapply(parts, margin_fun, numeric(len))
    if (is.null(dim(out))) {
      out <- matrix(out, nrow = len, ncol = length(parts))
    }
    out
  }

  # Random start. The draw order (b_W, lam, W, H) determines the result for a
  # given seed, so it must not be rearranged.
  b_W <- rgamma(1, 1, 1)
  lam <- rgamma(M, alpha_b, beta_b)
  W <- matrix(rgamma(D * M, a_W, 1), D, M)
  W <- sweep(W, 2, colSums(W), "/")
  H <- matrix(rgamma(M * N, a_H, 1), M, N)
  H <- H / rowSums(H)
  logW0 <- log(W)
  logH0 <- log(H)
  loglam <- log(lam)

  iter <- 0L
  for (k in seq_len(maxit)) {
    iter <- k

    # Unnormalised per-component weights exp(loglam_m + logW_dm + logH_mn),
    # normalised across components and used to split each entry of X.
    weights <- lapply(seq_len(M), function(m) {
      exp(outer(loglam[m] + logW0[, m], logH0[m, ], "+"))
    })
    den <- Reduce(`+`, weights)
    S <- lapply(weights, function(w) X * (w / den))

    # Scale update: gamma-form mean and digamma-based log term.
    alphahat <- vapply(S, sum, numeric(1)) + alpha_b
    betahat <- 1 + b_W
    loglam <- digamma(alphahat) - log(betahat)
    lam <- alphahat / betahat
    b_W <- 1 / (1 + sum(lam))

    # W update: per-component row totals, normalised within each column.
    ahat_W <- margin_matrix(S, rowSums, D) + a_W
    W <- sweep(ahat_W, 2, colSums(ahat_W), "/")
    logW <- sweep(digamma(ahat_W), 2, digamma(colSums(ahat_W)))

    # H update: per-component column totals, normalised within each row.
    ahat_H <- t(margin_matrix(S, colSums, N)) + a_H
    H <- sweep(ahat_H, 1, rowSums(ahat_H), "/")
    logH <- sweep(digamma(ahat_H), 1, digamma(rowSums(ahat_H)))

    if (all(abs(logH - logH0) < tol) && all(abs(logW - logW0) < tol)) {
      break
    }
    logW0 <- logW
    logH0 <- logH
  }

  list(lambdahat = lam, What = W, Hhat = H, bhat = b_W, iter = iter)
}
# Demo: simulate counts from W %*% diag(lambda) %*% H and refit with NMF_re().
N <- 100
K <- 100
L <- 3

# True factors: columns of W and rows of H are normalised to sum to one.
W <- matrix(rgamma(N * L, shape = 1, rate = 1), N, L)
W <- sweep(W, 2, colSums(W), "/")
H <- matrix(rgamma(L * K, shape = 1, rate = 1), L, K)
H <- H / rowSums(H)

# Component scales, stored from largest to smallest.
lambda <- sort(rgamma(L, 1, 1e-5), decreasing = TRUE)
Y <- matrix(rpois(N * K, sweep(W, 2, lambda, "*") %*% H), N, K)
image(Y)

system.time({
  out_re <- NMF_re(X = Y, M = L, maxit = 500)
})
print(out_re$iter)
# Full element name: `out_re$lambda` only worked through partial matching.
print(out_re$lambdahat)

# Fitted means (orange) and true means (blue) against the observed counts.
plot(with(out_re, sweep(What, 2, lambdahat, "*") %*% Hhat), Y,
     col = rgb(1, 0.5, 0, 0.5), pch = 1)
points(sweep(W, 2, lambda, "*") %*% H, Y, col = rgb(0, 0.5, 1, 0.5), pch = 2)
abline(0, 1, lty = 2)

# The fitted components come back in arbitrary order. The true lambda is
# sorted in decreasing order, so the estimates are aligned the same way
# before being compared with the true factors.
comp_order <- order(out_re$lambdahat, decreasing = TRUE)

plot(out_re$What[, comp_order], W,
     main = "W", xlab = "estimates", ylab = "true value")
abline(0, 1, lty = 2)
plot(out_re$Hhat[comp_order, ], H,
     main = "H", xlab = "estimates", ylab = "true value")
abline(0, 1, lty = 2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment