Skip to content

Instantly share code, notes, and snippets.

@vankesteren
Created November 7, 2022 10:37
Show Gist options
  • Save vankesteren/804672941e54581534c2ac71c60d66bf to your computer and use it in GitHub Desktop.
Save vankesteren/804672941e54581534c2ac71c60d66bf to your computer and use it in GitHub Desktop.
LightGBM for general utility
library(synthpop)
library(lightgbm)
library(tidyverse)
# Propensity-score mean squared error (pMSE) of a synthpop synthesis as a
# function of the number of LightGBM boosting rounds.
#
# A binary classifier is trained to tell observed rows (label 0) from
# synthetic rows (label 1). For each round count in `round_seq`, the squared
# deviation of the predicted propensity from the synthetic share
# nrow(Xs) / nrow(X) is averaged over the synthetic rows and plotted.

# Observed data: the first seven columns of synthpop's SD2011 example data.
df <- SD2011
obs <- df[, 1:7]
syn_res <- syn(obs)
sdf <- syn_res$syn

train_params <- list(
  learning_rate = 1.0,
  objective = "binary",
  nthread = 6L,
  force_row_wise = TRUE
)

# Stack observed and synthetic rows. `times` (rather than `each`) keeps the
# labels aligned even if the synthetic data has a different number of rows.
pmse_dat <- bind_rows(obs, sdf) |>
  mutate(label = rep(c(0L, 1L), times = c(nrow(obs), nrow(sdf))))

# model.frame() and model.matrix() apply the same NA handling, so `y` and `X`
# refer to the same rows.
y <- model.response(model.frame(label ~ ., data = pmse_dat))
X <- model.matrix(label ~ ., data = pmse_dat)

# Synthetic rows only; slicing `X` guarantees the same column layout as the
# training matrix.
Xs <- X[y == 1L, , drop = FALSE]

lgbdata <- lgb.Dataset(data = unname(X), label = y)

n_sim <- 20L
max_nrounds <- 10000L
Pmat <- matrix(0.0, nrow(Xs), n_sim)
round_seq <- as.integer(round(seq(1, max_nrounds, length.out = n_sim)))

# Boosting is sequential and these params use no row/feature subsampling, so
# the first k trees of one long fit are the k-round model. Train once at the
# largest round count and truncate at prediction time instead of refitting
# from scratch for every entry of `round_seq`.
fit <- lgb.train(
  params = train_params,
  data = lgbdata,
  nrounds = max(round_seq),
  verbose = 0L
)

for (i in seq_along(round_seq)) {
  cat("iteration", i, "\n")
  Pmat[, i] <- predict(fit, Xs, num_iteration = round_seq[i])
}

# NOTE(review): the average is taken over synthetic rows only, not over the
# combined data -- confirm this is the intended pMSE variant.
pMSEs <- colMeans((Pmat - nrow(Xs) / nrow(X))^2)
plot(round_seq, pMSEs, type = "l")
@vankesteren
Copy link
Author

afbeelding

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment