xgboost creating different model structures based on choice of sparsity
library("data.table")
library("xgboost")
library("Matrix")
# Simulate N rows: a binary response (~20% ones) and one Gaussian feature
generate_data <- function(N) {
  data.table(
    response = as.numeric(runif(N) > 0.8),
    float1 = rnorm(N, 3, 3))
}
N <- 1000
set.seed(1236)
train <- generate_data(N)
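# Set one value of float1 to exactly zero. A sparse matrix does not store
# zeros, so this is the one cell where the sparse and dense encodings differ.
# Comment out this line and the model structure is the same for sparse and dense.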
train[1, float1 := 0]
mm_train <- model.matrix(response ~ float1, train)
smm_train <- sparse.model.matrix(response ~ float1, train)
dtrain <- xgb.DMatrix(data = mm_train, label = train[, response])
dtrain_sparse <- xgb.DMatrix(data = smm_train, label = train[, response])
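# Aside (not part of the original gist): a minimal sketch of why a single zero
# can matter. A sparse dgCMatrix keeps only non-zero entries, and xgboost
# appears to treat entries absent from a sparse matrix as missing, whereas the
# dense matrix hands it 0 as an ordinary value. `toy_dense` and `toy_sparse`
# are hypothetical names used only for this illustration.
library("Matrix")
toy_dense <- cbind(intercept = 1, float1 = c(0, 2.5, -1))  # 3 x 2, one zero cell
toy_sparse <- Matrix(toy_dense, sparse = TRUE)             # zero cell is not stored
length(toy_dense)     # 6 cells in the dense matrix
nnzero(toy_sparse)    # 5 entries kept by the sparse matrix
length(toy_sparse@x)  # 5 values actually held in memory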
# Both models use exactly the same parameters; only the input encoding differs
params <- list(eta = 1,
               max_depth = 1,
               min_child_weight = 10,
               subsample = 1.0,
               objective = "binary:logistic",
               eval_metric = "logloss")
model <- xgb.train(params = params, data = dtrain, nrounds = 1)
model_sparse <- xgb.train(params = params, data = dtrain_sparse, nrounds = 1)
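# Same split point, but the sparse and dense encodings give a different
# default direction for missing values (missing=1 vs missing=2) and
# slightly different leaf predictions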
xgb.dump(model = model)
# [1] "booster[0]" "0:[f1<6.8692] yes=1,no=2,missing=1"
# [3] "1:leaf=-1.12623" "2:leaf=-1.54639"
xgb.dump(model = model_sparse)
# [1] "booster[0]" "0:[f1<6.8692] yes=1,no=2,missing=2"
# [3] "1:leaf=-1.12527" "2:leaf=-1.55102"
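# Aside (not part of the original gist): one possible way to test whether the
# difference comes from the zero being treated as missing. This sketch assumes
# the `missing` argument of xgb.DMatrix() marks dense cells equal to that value
# as missing; `train_dense_zero_missing` is a hypothetical helper, and the
# result has not been verified against xgboost 0.4-3.
library("xgboost")
train_dense_zero_missing <- function(dense_matrix, label) {
  # Tell xgboost to treat dense zeros as missing, mimicking the sparse encoding
  d <- xgb.DMatrix(data = dense_matrix, label = label, missing = 0)
  xgb.train(params = list(eta = 1,
                          max_depth = 1,
                          min_child_weight = 10,
                          subsample = 1.0,
                          objective = "binary:logistic",
                          eval_metric = "logloss"),
            data = d,
            nrounds = 1)
}
# Usage, once mm_train and train exist. If the explanation holds, this dump
# should match the one from model_sparse rather than the one from model:
# xgb.dump(model = train_dense_zero_missing(mm_train, train[, response]))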
sessionInfo()
# R version 3.2.2 (2015-08-14)
# Platform: x86_64-pc-linux-gnu (64-bit)
# Running under: Ubuntu 15.10
#
# locale:
# [1] LC_CTYPE=en_IE.UTF-8 LC_NUMERIC=C
# [3] LC_TIME=en_IE.UTF-8 LC_COLLATE=en_IE.UTF-8
# [5] LC_MONETARY=en_IE.UTF-8 LC_MESSAGES=en_IE.UTF-8
# [7] LC_PAPER=en_IE.UTF-8 LC_NAME=C
# [9] LC_ADDRESS=C LC_TELEPHONE=C
# [11] LC_MEASUREMENT=en_IE.UTF-8 LC_IDENTIFICATION=C
#
# attached base packages:
# [1] stats graphics grDevices utils datasets methods base
#
# other attached packages:
# [1] Matrix_1.2-2 xgboost_0.4-3 setwidth_1.0-4 colorout_1.1-1
# [5] magrittr_1.5 data.table_1.9.6 devtools_1.9.1
#
# loaded via a namespace (and not attached):
# [1] tools_3.2.2 memoise_0.2.1 stringi_1.0-1 grid_3.2.2
# [5] stringr_1.0.0 digest_0.6.8 chron_2.3-47 lattice_0.20-33