Last active
April 27, 2023 18:59
-
-
Save noamross/5d3339f45a334418b341b1bd75bee0d4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Demo: fit an XGBoost model in R, hand it to the Python m2cgen package
# (via reticulate), and transpile it into a dependency-free R function
# and a JavaScript function.
library(xgboost)
library(reticulate)
#reticulate::py_install("m2cgen") # https://github.com/BayesWitnesses/m2cgen
#reticulate::py_install("xgboost")
#reticulate::py_install("scikit-learn") # pip package is "scikit-learn", not "sklearn"

# Fit an XGBoost model (example adapted from ?xgb.train).
data(agaricus.train, package = "xgboost")
data(agaricus.test, package = "xgboost")
dtrain <- with(agaricus.train, xgb.DMatrix(data, label = label, nthread = 2))
dtest <- with(agaricus.test, xgb.DMatrix(data, label = label, nthread = 2))
watchlist <- list(train = dtrain, eval = dtest)
# NOTE: `verbose` is an argument to xgb.train(), not a booster parameter --
# putting it in `param` triggers 'Parameters: { "verbose" } are not used'.
param <- list(max_depth = 2, eta = 1, nthread = 2,
              objective = "binary:logistic", eval_metric = "auc")
bst <- xgb.train(param, dtrain, nrounds = 2, watchlist)
#> [1] train-auc:0.958228 eval-auc:0.960373
#> [2] train-auc:0.981413 eval-auc:0.979930

# Save the fitted model to disk (temp file, so we don't litter the
# working directory).
model_path <- file.path(tempdir(), "bst.xgb")
xgboost::xgb.save(bst, model_path)
#> [1] TRUE

# Start a Python session and reload the model there.
m2 <- reticulate::import("m2cgen")
xg <- reticulate::import("xgboost")
# XGBRegressor returns the raw margin (log-odds for binary:logistic);
# we apply plogis() below to map it to a probability.
mod <- xg$XGBRegressor()
mod$load_model(fname = model_path)

# m2cgen's generated R code references `nan`, which is undefined in R,
# so bind it before evaluating the transpiled function.
nan <- 0
# eval(parse(...)) is normally a red flag, but here it is the point:
# m2cgen hands back source text that we turn into a callable.
eval(parse(text = m2$export_to_r(mod, function_name = "pred")))
pred
#> function(input) {
#>     if (input[29] >= -0.0000009536743) {
#>         if (input[109] >= -0.0000009536743) {
#>             var0 <- 1.8596492
#>         } else {
#>             var0 <- -1.9407086
#>         }
#>     } else {
#>         if (input[56] >= -0.0000009536743) {
#>             var0 <- -1.7004405
#>         } else {
#>             var0 <- 1.7121772
#>         }
#>     }
#>     if (input[60] >= -0.0000009536743) {
#>         var1 <- -6.2362447
#>     } else {
#>         if (input[29] >= -0.0000009536743) {
#>             var1 <- -0.96853036
#>         } else {
#>             var1 <- 0.78471756
#>         }
#>     }
#>     return(nan + (var0 + var1))
#> }

# Build one test row once and use it for both predictions, so the two
# calls are directly comparable.
newx <- as.matrix(agaricus.test$data)[1, , drop = FALSE]

# Predict with the original model (returns a probability).
predict(bst, newx)
#> [1] 0.01241208

# Predict with the transpiled R code: pred() returns the raw margin,
# so apply the inverse logit to match predict().
plogis(pred(newx))
#> [1] 0.01241208

# Export the same model as a JavaScript function
# (note the 0-based feature indices in the JS output).
cat(m2$export_to_javascript(mod, function_name = "pred"))
#> function pred(input) {
#>     var var0;
#>     if (input[28] >= -0.0000009536743) {
#>         if (input[108] >= -0.0000009536743) {
#>             var0 = 1.8596492;
#>         } else {
#>             var0 = -1.9407086;
#>         }
#>     } else {
#>         if (input[55] >= -0.0000009536743) {
#>             var0 = -1.7004405;
#>         } else {
#>             var0 = 1.7121772;
#>         }
#>     }
#>     var var1;
#>     if (input[59] >= -0.0000009536743) {
#>         var1 = -6.2362447;
#>     } else {
#>         if (input[28] >= -0.0000009536743) {
#>             var1 = -0.96853036;
#>         } else {
#>             var1 = 0.78471756;
#>         }
#>     }
#>     return nan + (var0 + var1);
#> }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment