Skip to content

Instantly share code, notes, and snippets.

@shanebutler
Last active September 14, 2020 14:04
Show Gist options
  • Save shanebutler/96f0e78a02c84cdcf558 to your computer and use it in GitHub Desktop.
Save shanebutler/96f0e78a02c84cdcf558 to your computer and use it in GitHub Desktop.
Deploy your randomForest models in SQL! This tool enables in-database scoring of Random Forest models built using R. To use it, simply call the function with the Random Forest model, the output filename, the SQL input data table, and the name of the unique key on that table. For example: sql.export.rf(rf.mdl, file="model_output.SQL", input.table="source_table", id="id")
# sql.export.rf(): save a randomForest model as SQL
# v0.04
# Copyright (c) 2013-2014 Shane Butler <shane dot butler at gmail dot com>
#
# sql.export.rf is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# sql.export.rf is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with sql.export.rf. If not, see <http://www.gnu.org/licenses/>.
#
#
## NOTE:
# This code generates SQL scoring code from your randomForest model.
# Currently the generated code is not optimal, since it makes as many
# passes over the input data as there are trees (i.e., if there are 500
# trees, there will be 500 INSERT ... SELECT statements).
#
## USAGE:
# sql.export.rf(rf1, file="model_output.SQL", input.table="data", id="id")
#
## ARGUMENTS:
# variant: Optional SQL dialect; set variant="teradata" to generate
#          Teradata-specific DDL (default: "generic")
#
# sql.export.rf(): write a SQL script that scores `input.table` with a
# randomForest model.
#
# Each tree becomes one "INSERT INTO tmp_rf SELECT <id>, CASE ... END ...
# FROM <input.table>" statement. The per-tree predictions are then
# aggregated into rf_predictions: majority vote for classification, AVG for
# regression.
#
# Arguments:
#   model       - fitted object inheriting from class "randomForest".
#   file        - path of the SQL script to write (opened via sink()).
#   input.table - name of the table to score (written verbatim into SQL).
#   id          - name of the unique key column of input.table (verbatim).
#   variant     - "teradata" emits Teradata volatile tables; any other value
#                 emits generic SQL.
#
# Returns NULL invisibly; called for the side effect of writing `file`.
# Raises an error if `model` is not a randomForest object or if the
# randomForest package is not installed; in both cases nothing is written.
#
# Notes:
#   * Predictor names have "." replaced by "_" in the generated SQL.
#   * Factor levels and class labels are emitted as SQL string literals with
#     embedded single quotes doubled.
#   * Classification ties are broken deterministically with MIN(pred), so
#     rf_predictions has exactly one row per id.
#   * Each tree is a nested CASE expression as deep as the tree itself; some
#     engines limit CASE nesting depth and may reject deep trees.
sql.export.rf <- function(model, file, input.table = "source_table",
                          id = "id", variant = "generic") {
  # Validate before the output file is touched.
  if (!inherits(model, "randomForest")) {
    stop("Expected a randomForest object")
  }
  if (!requireNamespace("randomForest", quietly = TRUE)) {
    stop("Package 'randomForest' is required to export the model")
  }

  # Render values as SQL string literals, doubling embedded single quotes.
  sql_quote <- function(x) {
    paste0("'", gsub("'", "''", as.character(x), fixed = TRUE), "'")
  }

  # Predictor name -> SQL column name ("." replaced by "_").
  sql_col <- function(x) gsub(".", "_", x, fixed = TRUE)

  # Decode randomForest's integer-coded categorical split: bit (k - 1) of
  # `split.point` set means level k goes to the left daughter. Returns an
  # empty vector when the code does not fit in `ncat` bits.
  left_level_flags <- function(ncat, split.point) {
    if (split.point >= 2^ncat) {
      return(logical(0))
    }
    floor(split.point / 2^(seq_len(ncat) - 1)) %% 2 == 1
  }

  # Recursively write the CASE expression for row `row.num` of one tree
  # (a data frame as returned by getTree(..., labelVar = TRUE)).
  write_node <- function(tree.data, row.num, ind = 0) {
    node <- tree.data[row.num, ]
    indent.str <- strrep("\t", ind)

    if (node[, "status"] == -1) {
      # Terminal node: numeric predictions are written as-is (via paste0 so
      # full precision is kept), class labels as quoted literals.
      pred <- node[, "prediction"]
      if (is.numeric(tree.data$prediction)) {
        cat(paste0(pred, " "))
      } else {
        cat(paste0(sql_quote(pred), " "))
      }
      return(invisible(NULL))
    }

    split.var <- as.character(node[, "split var"])
    split.point <- node[, "split point"]
    col <- sql_col(split.var)
    level.values <- model$forest$xlevels[[split.var]]

    if (is.numeric(level.values)) {
      # Numeric predictor: NULL input yields NULL; <= split goes left.
      cat(paste0("\n", indent.str, "CASE WHEN ", col, " IS NULL THEN NULL",
                 "\n", indent.str, "WHEN ", col, " <= ", split.point,
                 " THEN "))
      write_node(tree.data, node[, "left daughter"], ind + 1)
      cat(paste0("\n", indent.str, "ELSE "))
      write_node(tree.data, node[, "right daughter"], ind + 1)
      cat("END ")
    } else {
      # Categorical predictor: values in neither list (e.g. unseen levels)
      # fall through to NULL.
      flags <- left_level_flags(model$forest$ncat[[split.var]], split.point)
      cat(paste0("\n", indent.str, "CASE WHEN ", col, " IN (",
                 paste(sql_quote(level.values[flags]), collapse = ", "),
                 ") THEN "))
      write_node(tree.data, node[, "left daughter"], ind + 1)
      cat(paste0("\n", indent.str, "WHEN ", col, " IN (",
                 paste(sql_quote(level.values[!flags]), collapse = ", "),
                 ") THEN "))
      write_node(tree.data, node[, "right daughter"], ind + 1)
      cat(paste0("\n", indent.str, "ELSE NULL END "))
    }
    invisible(NULL)
  }

  is.classification <- model$type == "classification"
  pred.type <- if (is.classification) "VARCHAR" else "FLOAT"
  columns <- paste0("\t", id, " INT NOT NULL,\n\tpred ", pred.type, "\n")

  sink(file, type = "output")
  # Always restore console output, even if generation fails part-way.
  on.exit(sink(), add = TRUE)

  if (variant == "teradata") {
    cat(paste0("CREATE VOLATILE TABLE rf_predictions (\n", columns,
               ") ON COMMIT PRESERVE ROWS;\n\n",
               "CREATE MULTISET VOLATILE TABLE tmp_rf (\n", columns,
               ") ON COMMIT PRESERVE ROWS;\n\n"))
  } else {
    cat(paste0("CREATE TABLE rf_predictions (\n", columns, ");\n\n",
               "DROP TABLE IF EXISTS tmp_rf;\n\n",
               "CREATE TABLE tmp_rf (\n", columns, ");\n\n"))
  }

  # One INSERT ... SELECT per tree (one pass over input.table each).
  for (tree.num in seq_len(model$ntree)) {
    cat(paste0("INSERT INTO tmp_rf\nSELECT ", id, ","))
    tree <- randomForest::getTree(model, k = tree.num, labelVar = TRUE)
    write_node(tree, 1)
    cat(paste0("as tree", tree.num, "\nFROM ", input.table, ";\n\n"))
  }

  if (is.classification) {
    # Majority vote written without window functions or WITH, for
    # portability. The outer GROUP BY / MIN keeps one label per id when
    # several classes tie for the highest vote count.
    counts <- paste0("(SELECT ", id, " as id, pred, COUNT(*) as cnt ",
                     "FROM tmp_rf GROUP BY ", id, ", pred)")
    cat(paste0("INSERT INTO rf_predictions\n",
               "SELECT a.id, MIN(a.pred)\n",
               "FROM ", counts, " a\n",
               "INNER JOIN (SELECT id, MAX(cnt) as cnt\n",
               "\t\t\tFROM ", counts, " cc\n",
               "\t\t\tGROUP BY id) b\n",
               "ON a.id = b.id AND a.cnt = b.cnt\n",
               "GROUP BY a.id;\n\n"))
  } else {
    cat(paste0("INSERT INTO rf_predictions\n",
               "SELECT ", id, ", AVG(pred)\n",
               "FROM tmp_rf\n",
               "GROUP BY ", id, ";\n\n"))
  }

  if (variant == "teradata") {
    cat("DROP TABLE tmp_rf;\n\n")
  } else {
    cat("DROP TABLE IF EXISTS tmp_rf;\n\n")
  }
  invisible(NULL)
}
@pjankiewicz
Copy link

I was looking for a way to export an R randomForest object. This is a smart approach. It would probably be more convenient to create an SQL function that accepts the variables as arguments. Alternatively, I would recommend using the pmml R package to export the randomForest model to a standardized XML file.

@shanebutler
Copy link
Author

@pjankiewicz a function would probably be more efficient; however, there is no standard for SQL functions that works across multiple platforms.

@Clls1
Copy link

Clls1 commented Aug 25, 2017

That is very useful. However, how can I get the class probabilities instead of the 0/1 output?

@xavierjl70
Copy link

How do you get the probabilities from Random Forest as the member levels or record levels?
Thank you.

@xavierjl70
Copy link

How do you get the probabilities from the Random Forest at the member level or record level? (I mean in R.)
Thank you.

@ras44
Copy link

ras44 commented Oct 18, 2018

for anyone interested in doing something similar with xgboost, here you go: https://github.com/ras44/articles/blob/master/20181018_xgboost_scoring_via_sql.md

@shanebutler
Copy link
Author

for anyone interested in doing something similar with xgboost, here you go: https://github.com/ras44/articles/blob/master/20181018_xgboost_scoring_via_sql.md

Great, thank you @ras44!

@nezorepla
Copy link

It gives an error because the generated CASE expressions are nested too deeply:

Msg 125, Level 15, State 4, Line 20
Case expressions may only be nested to level 10.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment