Skip to content

Instantly share code, notes, and snippets.

@abikoushi
Last active June 6, 2024 15:01
Show Gist options
  • Save abikoushi/b3b19a530172a02e66b85cbc94ce3030 to your computer and use it in GitHub Desktop.
Save abikoushi/b3b19a530172a02e66b85cbc94ce3030 to your computer and use it in GitHub Desktop.
Sparse Poisson regression (prototype)
####
#Negative sampling
####
library(Matrix)
library(ggplot2)
## tau: strength of the quadratic (ridge) prior on W; the gradient adds 0.5*tau*W and the Hessian adds 0.5*tau*I.
# Gradient of the penalized negative Poisson log-likelihood with respect to W.
#   lambda: fitted means, one per row of X (exp(X %*% W) where it is called
#           in this file).
#   Y:      observed counts, same length as lambda.
#   X:      design matrix.
#   W:      current coefficient vector.
#   tau:    prior strength; contributes 0.5 * tau * W.
# Returns a numeric vector of length ncol(X) when W is a plain vector.
dlogll <- function(lambda, Y, X, W, tau) {
  residual <- lambda - Y
  data_grad <- as.vector(crossprod(X, residual))
  data_grad + 0.5 * tau * W
}
# Hessian of the penalized negative Poisson log-likelihood:
#   X^T diag(lambda) X + 0.5 * tau * I.
#   lambda: fitted means, one per row of X.
#   X:      design matrix.
#   W:      accepted but not used by the body.
#   tau:    prior strength; adds 0.5 * tau on the diagonal.
d2logll <- function(lambda, X, W, tau) {
  # Scale row i of X by lambda[i], i.e. diag(lambda) %*% X.
  lambda_X <- sweep(X, MARGIN = 1, STATS = lambda, FUN = "*")
  ridge <- diag(0.5 * tau, ncol(X))
  crossprod(X, lambda_X) + ridge
}
# Gradient contribution of rows whose count is zero: X^T lambda + 0.5 * tau * W.
# This equals dlogll() with Y = 0, prior term included, so adding its result to
# a dlogll() result counts the prior twice unless tau = 0 is passed here.
dlogll0 <- function(lambda, X, W, tau) {
  ridge_grad <- 0.5 * tau * W
  zero_grad <- as.vector(crossprod(X, lambda))
  zero_grad + ridge_grad
}
# Newton-Raphson fit of a ridge-penalized Poisson regression on the full data.
#
# Args:
#   Y:    response counts, one per row of X.
#   X:    design matrix (base matrix or Matrix-package matrix).
#   tau:  prior strength, forwarded to dlogll() / d2logll().
#   iter: number of Newton iterations; 0 returns the zero starting point.
#   lr:   multiplier on the Newton step (1 = full step).
# Returns:
#   list(W  = coefficient vector (plain numeric),
#        ll = Poisson log-likelihood evaluated at the iterate *before* each
#             update, length iter,
#        H  = Hessian from the last iteration, i.e. evaluated at the
#             second-to-last iterate; NULL when iter = 0).
poisreg0 <- function(Y, X, tau, iter, lr = 1) {
  W <- numeric(ncol(X))
  ll <- numeric(iter)
  H <- NULL
  # seq_len(): with iter = 0 the loop must not run (1:0 would run twice).
  for (i in seq_len(iter)) {
    lambda <- as.vector(exp(X %*% W))
    ll[i] <- sum(dpois(Y, lambda, log = TRUE))
    g <- dlogll(lambda, Y, X, W, tau)
    H <- d2logll(lambda, X, W, tau)
    # as.vector() keeps W a plain numeric vector even when X (and hence H) is
    # a Matrix-package object, for which solve() may return a matrix.
    W <- W - lr * as.vector(solve(H, g))
  }
  list(W = W, ll = ll, H = H)
}
# Newton-Raphson fit where the zero-count rows are replaced by m randomly
# generated design rows per iteration (negative-sampling prototype).
#
# Args:
#   Y, X:     observed counts and their design rows. Presumably these are the
#             positive-count rows, since the m simulated rows stand in for the
#             zeros; TODO confirm with callers.
#   m:        number of zero-count rows to simulate per iteration; 0 disables
#             sampling.
#   tau:      prior strength. It is applied once, through the (Y, X) terms.
#   iter:     number of Newton iterations.
#   lr:       multiplier on the Newton step (1 = full step).
#   n_levels: levels of each of the two factors the simulated rows encode.
#             When m > 0, X must have 1 + 2 * (n_levels - 1) columns
#             (intercept + dummies for both factors, first level dropped).
# Returns:
#   list(W = coefficients (plain numeric),
#        ll = log-likelihood of (Y, X) plus that of the simulated zero rows,
#             both at the iterate before each update,
#        H = Hessian from the last iteration; NULL when iter = 0).
# NOTE(review): simulated rows are drawn uniformly over the whole factor grid,
# not only over the zero-count cells, so the zero term is an approximation.
poisreg_sp <- function(Y, X, m, tau, iter, lr = 1, n_levels = 20) {
  D <- ncol(X)
  if (m > 0 && D != 1 + 2 * (n_levels - 1)) {
    stop("ncol(X) must equal 1 + 2 * (n_levels - 1) when m > 0", call. = FALSE)
  }
  W <- numeric(D)
  ll <- numeric(iter)
  H <- NULL
  probs <- rep(1 / n_levels, n_levels)
  for (i in seq_len(iter)) {
    lambda <- as.vector(exp(X %*% W))
    ll[i] <- sum(dpois(Y, lambda, log = TRUE))
    g <- dlogll(lambda, Y, X, W, tau)
    H <- d2logll(lambda, X, W, tau)
    if (m > 0) {
      # One-hot draws for each factor; drop = FALSE keeps a matrix when m == 1.
      Za <- t(rmultinom(m, 1, prob = probs))
      Zb <- t(rmultinom(m, 1, prob = probs))
      Xs <- cbind(1, Za[, -1, drop = FALSE], Zb[, -1, drop = FALSE])
      lam0 <- as.vector(exp(Xs %*% W))
      # log dpois(0, lam0) = -lam0: add it to ll[i] instead of overwriting.
      ll[i] <- ll[i] - sum(lam0)
      # tau = 0 here: the prior was already added once by dlogll()/d2logll().
      g <- g + dlogll0(lam0, Xs, W, 0)
      H <- H + d2logll(lam0, Xs, W, 0)
    }
    # as.vector() keeps W numeric when X is a Matrix-package object.
    W <- W - lr * as.vector(solve(H, g))
  }
  list(W = W, ll = ll, H = H)
}
# ---- Demo ----
# Simulate counts on a 20 x 20 two-factor grid, then compare the full-data
# Newton fit (poisreg0) with the negative-sampling fit (poisreg_sp).
set.seed(1)  # reproducible simulated data and negative samples
df0 <- expand.grid(a = factor(1:20), b = factor(1:20))
X <- model.matrix(~ a + b, data = df0)
B <- rnorm(ncol(X))
Y <- rpois(nrow(X), exp(X %*% B))
df1 <- df0
df1$Y <- Y
df1 <- df1[df1$Y > 0, ]
# Sparse design for the positive-count cells only. Subsetting keeps all factor
# levels, so it has the same columns as X.
X1 <- sparse.model.matrix(~ a + b, data = df1)
# Number of zero-count cells that poisreg_sp() replaces with random draws.
m <- nrow(df0) - nrow(df1)
system.time(
  out <- poisreg0(Y, X, tau = 0.1, iter = 50, lr = 0.5)
)
# Fit on the positive cells only. Passing the full (Y, X) would count the zero
# cells twice: once in the data and once through the m sampled rows.
system.time(
  out_sp <- poisreg_sp(df1$Y, X1, m, tau = 0.1, iter = 50, lr = 0.5)
)
plot(out$ll, type = "l", col = "darkorange")
plot(out_sp$ll, type = "l", col = "royalblue")
# as.vector(): with a sparse design the coefficients may come back as a
# Matrix-package object, which data.frame() cannot take directly.
W_sp <- as.vector(out_sp$W)
df <- data.frame(true = B, est1 = out$W, est2 = W_sp)
ggplot(df, aes(x = true)) +
  geom_linerange(aes(ymin = est1, ymax = est2), alpha = 0.7) +
  geom_point(aes(y = est1), alpha = 0.7, colour = "darkorange") +
  geom_point(aes(y = est2), alpha = 0.7, colour = "royalblue") +
  geom_abline(intercept = 0, slope = 1, linetype = 2) +
  theme_bw() +
  labs(y = "estimates")
sqrt(mean((B - out$W)^2))
sqrt(mean((B - W_sp)^2))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment