Last active
August 29, 2015 14:04
-
-
Save zmjones/c2d2ff4606b027d4ae0a to your computer and use it in GitHub Desktop.
Parallel calculation of the marginal or joint partial dependence of the response on explanatory variables in a party random forest.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## packages used by this script: party (cforest, predict) and parallel
## (mclapply, detectCores). library() is used instead of require() so that a
## missing package stops the script immediately with a clear error, rather
## than require() returning FALSE and the script failing later with
## "could not find function".
pkgs <- c("party", "parallel")
invisible(lapply(pkgs, library, character.only = TRUE))
partial_dependence <- function(fit, ivar, cores = 1, ...) {
  ## Calculates the partial dependence of the response on explanatory
  ## variable(s) of a party model.
  ##
  ## fit   a party model (e.g. from cforest()); its training data are read
  ##       from fit@data@env ("input" and "response")
  ## ivar  character vector of length >= 1, all of which must name variables
  ##       in the data used to fit the model; if length(ivar) > 1, the joint
  ##       dependence over every combination of the per-variable grids built
  ##       by ivar_points() is calculated
  ## cores number of cores passed to parallel::mclapply() as mc.cores
  ## ...   currently unused
  ##
  ## Returns a data.frame with one column per element of ivar (the grid
  ## values, keeping their original types) and a column "pred": the mean
  ## prediction when the response is numeric, otherwise the most frequently
  ## predicted class (the first class in table order when counts tie).
  df <- data.frame(get("input", fit@data@env), get("response", fit@data@env))
  if (length(ivar) == 0) {
    stop("ivar must name at least one variable", call. = FALSE)
  }
  unknown <- setdiff(ivar, names(df))
  if (length(unknown) > 0) {
    stop("ivar not found in the model data: ",
         paste(unknown, collapse = ", "), call. = FALSE)
  }
  ## the response was appended as the last column above; its type decides
  ## between averaging predictions and taking the modal class
  numeric_response <- is.numeric(df[[ncol(df)]])

  grid <- expand.grid(lapply(ivar, function(x) ivar_points(df, x)),
                      KEEP.OUT.ATTRS = FALSE, stringsAsFactors = FALSE)
  names(grid) <- ivar

  preds <- parallel::mclapply(seq_len(nrow(grid)), function(i) {
    ## fix every ivar at the i-th grid point for all observations
    ## (this modifies a local copy of df)
    for (v in ivar) {
      df[[v]] <- rep(grid[[v]][i], length.out = nrow(df))
    }
    predicted <- predict(fit, newdata = df)
    if (numeric_response) {
      mean(predicted)
    } else {
      counts <- table(predicted)
      ## which.max() returns a single class even when several tie, so each
      ## grid point contributes exactly one prediction
      names(counts)[which.max(counts)]
    }
  }, mc.cores = cores)

  ## mclapply() returns try-error objects (or NULL) for failed workers
  ## instead of raising, so surface the first failure explicitly
  failed <- vapply(preds, function(p) is.null(p) || inherits(p, "try-error"),
                   logical(1))
  if (any(failed)) {
    first <- which(failed)[1]
    stop("prediction failed for grid row ", first, ": ",
         as.character(preds[[first]]), call. = FALSE)
  }

  values <- unlist(preds, use.names = FALSE)
  grid$pred <- if (numeric_response) as.numeric(values) else as.character(values)
  grid
}
ivar_points <- function(df, x, cutoff = 10) {
  ## Creates the grid of prediction points for one variable.
  ##
  ## df     the data.frame used to fit the model
  ## x      character; the name of the variable the grid is created for
  ## cutoff the number of unique values above which an evenly spaced grid of
  ##        `cutoff` points between the observed min and max is used instead
  ##        of all unique values
  ##
  ## Returns the grid with the same class as df[[x]]; NA values are dropped.
  ## Factor, character and logical variables always keep every observed
  ## value, because an evenly spaced sequence is undefined for them
  ## (seq()/min() would fail). Integer grids are rounded back to integers
  ## and de-duplicated.
  observed <- df[[x]]
  points <- unique(observed)
  points <- points[!is.na(points)]
  categorical <- is.factor(points) || is.character(points) || is.logical(points)
  if (length(points) > cutoff && !categorical) {
    grid <- seq(min(points), max(points), length.out = cutoff)
    if (is.integer(observed)) {
      ## round (rather than truncate) so grid points are unbiased, and drop
      ## any duplicates the rounding may create
      points <- unique(as.integer(round(grid)))
    } else {
      ## restore the original class (e.g. Date) in case seq() dropped it
      class(grid) <- class(observed)
      points <- grid
    }
  }
  points
}
## example: partial dependence of the predicted iris species on petal width,
## and joint dependence on petal width and sepal length
data(iris)
## cforest() grows randomized trees; fix the seed so the example is
## reproducible
set.seed(1)
fit <- cforest(Species ~ ., data = iris, control = cforest_unbiased(mtry = 2))
## mclapply() only supports mc.cores = 1 on Windows, and detectCores() may
## return NA when the number of cores cannot be determined
cores <- if (.Platform$OS.type == "windows") {
  1L
} else {
  max(1L, detectCores(), na.rm = TRUE)
}
pd <- partial_dependence(fit, "Petal.Width", cores)
pd_int <- partial_dependence(fit, c("Petal.Width", "Sepal.Length"), cores)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment