# GitHub Gist noamross/5d3339f45a334418b341b1bd75bee0d4
# (last active April 27, 2023 18:59)
# Convert an R XGBoost model into standalone R and JavaScript code with the
# Python package m2cgen, called from R through reticulate.
# Run once to install the Python converter:
# reticulate::py_install("m2cgen")
library(xgboost)

# Fit an XGBoost model (example from ?xgb.train)
data(agaricus.train, package = "xgboost")
data(agaricus.test, package = "xgboost")
dtrain <- with(agaricus.train, xgb.DMatrix(data, label = label, nthread = 2))
dtest <- with(agaricus.test, xgb.DMatrix(data, label = label, nthread = 2))
watchlist <- list(train = dtrain, eval = dtest)
# `verbose` is not a booster parameter: listing it here only produced the
# "Parameters: { "verbose" } are not used" warning, so it is left out.
param <- list(
  max_depth = 2, eta = 1, nthread = 2,
  objective = "binary:logistic", eval_metric = "auc"
)
bst <- xgb.train(param, dtrain, nrounds = 2, watchlist = watchlist)
#> [1] train-auc:0.958228 eval-auc:0.960373
#> [2] train-auc:0.981413 eval-auc:0.979930
# Save to disk so the Python side below can load the same booster
xgb.save(bst, "bst.xgb")
#> [1] TRUE
# Start a Python session and import the converter plus XGBoost's Python API
m2 <- reticulate::import("m2cgen")
xg <- reticulate::import("xgboost")
# Load the booster saved above through XGBoost's scikit-learn wrapper so that
# m2cgen can convert it
mod <- xg$XGBRegressor()
mod$load_model(fname = "bst.xgb")
# Convert to a raw R function named `pred`. Its result is an untransformed
# score; plogis() maps it to the probability that predict() reports.
pred_code <- m2$export_to_r(mod, function_name = "pred")
# m2cgen adds a constant offset (presumably the model's base score -- confirm)
# that it prints as the literal `nan`, which is undefined in R. Substitute 0 so
# `pred` is self-contained rather than depending on a global `nan` variable
# looked up at call time. The matching predictions from the original model and
# from `pred` further below show a zero offset is right for this model.
pred_code <- gsub("\\bnan\\b", "0", pred_code, perl = TRUE)
# eval(parse()) is acceptable here only because the text was just generated
# locally from our own model; never evaluate untrusted text this way.
eval(parse(text = pred_code))
#> function(input) {
#> if (input[29] >= -0.0000009536743) {
#> if (input[109] >= -0.0000009536743) {
#> var0 <- 1.8596492
#> } else {
#> var0 <- -1.9407086
#> }
#> } else {
#> if (input[56] >= -0.0000009536743) {
#> var0 <- -1.7004405
#> } else {
#> var0 <- 1.7121772
#> }
#> }
#> if (input[60] >= -0.0000009536743) {
#> var1 <- -6.2362447
#> } else {
#> if (input[29] >= -0.0000009536743) {
#> var1 <- -0.96853036
#> } else {
#> var1 <- 0.78471756
#> }
#> }
#> return(0 + (var0 + var1))
#> }
# Compare both predictors on the first test row. Subset before densifying so
# only that one row, not the whole sparse test matrix, becomes a dense matrix;
# the same 1-row input is then fed to both predictors.
first_row <- as.matrix(agaricus.test$data[1, , drop = FALSE])
# Predict with the original model
predict(bst, first_row)
#> [1] 0.01241208
# Predict with the raw code; `pred` indexes its input linearly, so a 1-row
# matrix works, and plogis() maps its score to the probability scale
plogis(pred(first_row))
#> [1] 0.01241208
# Export as a JavaScript function. Note that JavaScript indices are 0-based, so
# `input[28]` here is the same feature as `input[29]` in the R export.
# As in the R export, m2cgen emits the constant offset as the literal `nan`.
# JavaScript has no `nan` identifier (its not-a-number value is `NaN`), so the
# function as generated would throw a ReferenceError when called. Substitute
# the same zero offset used for the R function.
js_code <- gsub(
  "\\bnan\\b", "0",
  m2$export_to_javascript(mod, function_name = "pred"),
  perl = TRUE
)
cat(js_code)
#> function pred(input) {
#> var var0;
#> if (input[28] >= -0.0000009536743) {
#> if (input[108] >= -0.0000009536743) {
#> var0 = 1.8596492;
#> } else {
#> var0 = -1.9407086;
#> }
#> } else {
#> if (input[55] >= -0.0000009536743) {
#> var0 = -1.7004405;
#> } else {
#> var0 = 1.7121772;
#> }
#> }
#> var var1;
#> if (input[59] >= -0.0000009536743) {
#> var1 = -6.2362447;
#> } else {
#> if (input[28] >= -0.0000009536743) {
#> var1 = -0.96853036;
#> } else {
#> var1 = 0.78471756;
#> }
#> }
#> return 0 + (var0 + var1);
#> }
# (GitHub page footer) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment