Last active
September 14, 2020 14:04
-
-
Save shanebutler/96f0e78a02c84cdcf558 to your computer and use it in GitHub Desktop.
Deploy your randomForest models in SQL! This tool enables in-database scoring of random forest models built in R. To use it, call the function with the randomForest model, the output filename, the SQL input data table, and the name of the unique key on that table. For example: sql.export.rf(rf.mdl, file="model_output.SQL", input.table="sour…
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# sql.export.rf(): save a randomForest model as SQL
# v0.04
# Copyright (c) 2013-2014 Shane Butler <shane dot butler at gmail dot com>
#
# sql.export.rf is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# sql.export.rf is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with sql.export.rf. If not, see <http://www.gnu.org/licenses/>.
#
#
## NOTE:
# This code generates SQL scoring code from your randomForest model.
# The generated code is not optimal: it makes as many passes over the
# input data as there are trees (i.e. a 500-tree forest produces 500
# INSERT ... SELECT statements).
#
## USAGE:
# sql.export.rf(rf1, file="model_output.SQL", input.table="data", id="id")
#
## ARGUMENTS:
# model:       a fitted randomForest object (classification or regression)
# file:        path of the SQL script to write
# input.table: name of the table holding the rows to score
# id:          name of the unique integer key column on input.table
# variant:     optional SQL dialect; use variant="teradata" for Teradata,
#              any other value produces generic SQL
# | |
# sql.export.rf(): write a SQL script that scores a randomForest model
# in-database.
#
# Arguments:
#   model       - fitted "randomForest" object; must be a classification or
#                 regression forest fitted with keep.forest = TRUE.
#   file        - output path (overwritten), or a writable connection.
#   input.table - table holding the rows to score.
#   id          - unique integer key column of input.table; it is also the
#                 key column of the generated rf_predictions table.
#   variant     - "teradata" emits Teradata volatile tables; any other value
#                 emits generic SQL.
#
# The script creates rf_predictions(<id>, pred) and a scratch table tmp_rf,
# inserts one prediction per input row per tree into tmp_rf, aggregates the
# per-tree predictions into rf_predictions (majority vote for
# classification, AVG for regression), then drops tmp_rf.
#
# Each tree becomes ONE flat CASE expression with a WHEN per leaf (the
# leaf's path predicates joined with AND) rather than nested CASEs, whose
# depth would equal the tree depth; engines such as SQL Server reject CASE
# nesting beyond 10 levels. A row whose split variable is NULL, or holds a
# category unseen in training, matches no leaf and scores NULL.
#
# Returns NULL invisibly; called for its side effect of writing the script.
sql.export.rf <- function(model, file, input.table = "source_table",
                          id = "id", variant = "generic") {
  if (!inherits(model, "randomForest")) {
    stop("Expected a randomForest object", call. = FALSE)
  }
  if (!(identical(model$type, "classification") ||
        identical(model$type, "regression"))) {
    stop("Only classification and regression forests can be exported",
         call. = FALSE)
  }
  if (is.null(model$forest)) {
    stop("The model has no forest; refit it with keep.forest = TRUE",
         call. = FALSE)
  }
  # requireNamespace() lets us fail loudly where require() would only warn.
  if (!requireNamespace("randomForest", quietly = TRUE)) {
    stop("The randomForest package is required", call. = FALSE)
  }
  is.classif <- identical(model$type, "classification")
  is.teradata <- identical(variant, "teradata")

  # A bare VARCHAR means VARCHAR(1) on SQL Server and needs a length on
  # Teradata, so size the column to the longest class label (in bytes).
  pred.type <- if (is.classif) {
    sprintf("VARCHAR(%d)", max(1L, nchar(model$classes, type = "bytes")))
  } else {
    "FLOAT"
  }

  # Write through a connection closed by on.exit(), so an error part-way
  # through cannot leave the R session's output diverted (as sink() could).
  if (is.character(file)) {
    con <- base::file(file, open = "w")
    on.exit(close(con), add = TRUE)
  } else {
    con <- file
    if (!isOpen(con)) {
      open(con, "w")
      on.exit(close(con), add = TRUE)
    }
  }
  emit <- function(...) cat(..., file = con, sep = "")

  table.suffix <- if (is.teradata) " ON COMMIT PRESERVE ROWS" else ""
  table.ddl <- function(create, name) {
    paste0(create, " ", name, " (\n",
           "\t", id, " INT NOT NULL,\n",
           "\tpred ", pred.type, "\n",
           ")", table.suffix, ";\n\n")
  }
  if (is.teradata) {
    emit(table.ddl("CREATE VOLATILE TABLE", "rf_predictions"),
         table.ddl("CREATE MULTISET VOLATILE TABLE", "tmp_rf"))
  } else {
    emit(table.ddl("CREATE TABLE", "rf_predictions"),
         "DROP TABLE IF EXISTS tmp_rf;\n\n",
         table.ddl("CREATE TABLE", "tmp_rf"))
  }

  # One INSERT ... SELECT (one pass over input.table) per tree.
  for (k in seq_len(model$ntree)) {
    tree <- randomForest::getTree(model, k = k, labelVar = TRUE)
    emit("INSERT INTO tmp_rf\nSELECT ", id, ",\n",
         .sql.rf.tree.case(tree, model$forest, quote.pred = is.classif),
         " as tree", k, "\nFROM ", input.table, ";\n\n")
  }

  if (is.classif) {
    # Majority vote without window functions or WITH, which some engines
    # (e.g. older SQLite) lack. MIN() keeps exactly one row per id when
    # classes tie on votes; otherwise every tied class would be inserted.
    votes <- paste0("(SELECT ", id, " as id, pred, COUNT(*) as cnt ",
                    "FROM tmp_rf GROUP BY ", id, ", pred)")
    emit("INSERT INTO rf_predictions\n",
         "SELECT a.id, MIN(a.pred)\n",
         "FROM ", votes, " a\n",
         "INNER JOIN (SELECT id, MAX(cnt) as cnt\n",
         "\t\t\tFROM ", votes, " cc\n",
         "\t\t\tGROUP BY id) b\n",
         "ON a.id = b.id AND a.cnt = b.cnt\n",
         "GROUP BY a.id;\n\n")
  } else {
    emit("INSERT INTO rf_predictions\n",
         "SELECT ", id, ", AVG(pred)\n",
         "FROM tmp_rf\n",
         "GROUP BY ", id, ";\n\n")
  }

  if (is.teradata) {
    emit("DROP TABLE tmp_rf;\n\n")
  } else {
    emit("DROP TABLE IF EXISTS tmp_rf;\n\n")
  }
  invisible(NULL)
}

# Quote values as SQL string literals, doubling embedded single quotes so a
# label such as "O'Brien" cannot terminate the literal early.
.sql.rf.quote <- function(x) {
  paste0("'", gsub("'", "''", as.character(x), fixed = TRUE), "'")
}

# "col IN ('a', 'b')" for a set of category labels; an empty set becomes the
# always-false predicate "1 = 0", since "IN ()" is not valid SQL.
.sql.rf.in <- function(col, values) {
  if (length(values) == 0) {
    return("1 = 0")
  }
  paste0(col, " IN (", paste(.sql.rf.quote(values), collapse = ", "), ")")
}

# Left/right predicates for one split node, returned as c(left, right).
# forest is model$forest (its ncat and xlevels are read); var and point are
# the split variable name and split point as reported by getTree().
# Dots in R variable names become underscores in the SQL column name.
.sql.rf.split.conds <- function(forest, var, point) {
  col <- gsub("[.]", "_", var)
  ncat <- unname(forest$ncat[var])   # NA if ncat carries no names
  levs <- forest$xlevels[[var]]      # NULL if xlevels carries no names
  if (!is.na(ncat) && ncat > 1) {
    # Unordered factor: level i goes left iff bit (i - 1) of point is set.
    goes.left <- (point %/% 2^(seq_len(ncat) - 1)) %% 2 == 1
  } else if (is.character(levs)) {
    # Ordered factor (ncat == 1 but has levels): randomForest splits on the
    # integer level codes, so levels with code <= point go left.
    goes.left <- seq_along(levs) <= point
  } else {
    # Numeric: a NULL value satisfies neither predicate.
    return(c(paste0(col, " <= ", point), paste0(col, " > ", point)))
  }
  c(.sql.rf.in(col, levs[goes.left]), .sql.rf.in(col, levs[!goes.left]))
}

# Translate one tree (a getTree(..., labelVar = TRUE) data frame) into a
# single flat CASE expression with one WHEN per leaf, so CASE nesting stays
# at depth 1 whatever the tree depth. quote.pred = TRUE emits leaf
# predictions as string literals (class labels), FALSE as numbers.
.sql.rf.tree.case <- function(tree, forest, quote.pred) {
  left <- tree[["left daughter"]]
  right <- tree[["right daughter"]]
  split.var <- as.character(tree[["split var"]])
  split.point <- tree[["split point"]]
  is.leaf <- tree[["status"]] == -1
  pred <- tree[["prediction"]]
  pred.sql <- if (quote.pred) .sql.rf.quote(pred) else as.character(pred)

  # Depth-first walk returning one "WHEN ... THEN ..." line per leaf below
  # `node`; `conds` holds the predicates on the path from the root.
  leaf.clauses <- function(node, conds) {
    if (is.leaf[node]) {
      return(paste0("\n\tWHEN ", paste(conds, collapse = " AND "),
                    " THEN ", pred.sql[node]))
    }
    split <- .sql.rf.split.conds(forest, split.var[node], split.point[node])
    c(leaf.clauses(left[node], c(conds, split[1])),
      leaf.clauses(right[node], c(conds, split[2])))
  }

  if (is.leaf[1]) {
    return(pred.sql[1])  # a single-leaf tree predicts a constant
  }
  paste0("CASE", paste(leaf.clauses(1, character(0)), collapse = ""),
         "\n\tELSE NULL\nEND")
}
It fails with the following error because the CASE expressions are nested too deeply:
Msg 125, Level 15, State 4, Line 20
Case expressions may only be nested to level 10.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Great, thank you @ras44!