Skip to content

Instantly share code, notes, and snippets.

@zmjones
Last active August 29, 2015 14:04
Show Gist options
  • Save zmjones/c2d2ff4606b027d4ae0a to your computer and use it in GitHub Desktop.
Save zmjones/c2d2ff4606b027d4ae0a to your computer and use it in GitHub Desktop.
Parallel calculation of the marginal (single-variable) or joint (multi-variable) partial dependence of the response on explanatory variables, using a random forest fit with the party package
## Packages this script depends on: party (cforest and friends) and
## parallel (mclapply, detectCores).
pkgs <- c("party", "parallel")
## library() stops with an error when a package is not installed, whereas
## require() only warns and returns FALSE, which would let the script run on
## and fail later with a less helpful "could not find function" error.
invisible(lapply(pkgs, library, character.only = TRUE))
partial_dependence <- function(fit, ivar, cores = 1, ...) {
  ## calculates the partial dependence of the response on explanatory
  ## variable(s)
  ##
  ## fit must be a party object; the data used to fit it is read from
  ##   fit@data@env (objects "input" and "response")
  ## ivar must be a character vector of length >= 1 all of which
  ##   exist in the dataframe used to fit the model
  ##   if the length of ivar > 1, joint dependence is calculated
  ## cores is passed to mclapply() as mc.cores
  ## ... is accepted for backward compatibility and is currently unused
  ##
  ## returns a data.frame with one row per grid point: one column per
  ## element of ivar (keeping the type produced by ivar_points) plus a
  ## column "pred" holding the mean prediction (numeric response) or the
  ## most frequently predicted class label (any other response)
  stopifnot(is.character(ivar), length(ivar) >= 1L)

  df <- data.frame(get("input", fit@data@env), get("response", fit@data@env))

  unknown <- setdiff(ivar, colnames(df))
  if (length(unknown) > 0L) {
    stop("ivar not found in the model data: ",
         paste(unknown, collapse = ", "), call. = FALSE)
  }

  ## every combination of the per-variable prediction points; character
  ## values are left as character so they match the column they replace
  grid <- expand.grid(lapply(ivar, function(x) ivar_points(df, x)),
                      KEEP.OUT.ATTRS = FALSE, stringsAsFactors = FALSE)
  colnames(grid) <- ivar

  ## the response is the last column of df (it is appended after the inputs)
  is_regression <- is.numeric(df[, ncol(df)])

  ## seq_len() rather than 1:nrow() so an empty grid yields no iterations
  pred <- mclapply(seq_len(nrow(grid)), function(i) {
    newdata <- df
    ## overwrite each variable of interest with its (recycled) grid value;
    ## assigning column by column keeps every column's own type
    for (v in ivar) {
      newdata[[v]] <- grid[i, v]
    }
    yhat <- predict(fit, newdata = newdata)
    if (is_regression) {
      mean(yhat)
    } else {
      counts <- table(yhat)
      ## which.max() picks exactly one label even when classes tie, so
      ## every grid point contributes a single value
      names(counts)[which.max(counts)]
    }
  }, mc.cores = cores)

  ## mclapply() hands back "try-error" objects for failed workers instead
  ## of stopping; surface the first failure rather than binding it into
  ## the result
  failed <- vapply(pred, inherits, logical(1), what = "try-error")
  if (any(failed)) {
    stop("prediction failed for ", sum(failed), " grid point(s): ",
         as.character(pred[[which(failed)[1L]]]), call. = FALSE)
  }

  pred <- unlist(pred)
  if (is.null(pred)) {
    ## empty grid: still return a "pred" column of the right type
    pred <- if (is_regression) numeric(0) else character(0)
  }
  grid$pred <- pred
  grid
}
ivar_points <- function(df, x, cutoff = 10) {
  ## create a grid of prediction points
  ##
  ## df = the dataframe used to fit the model
  ## x is a character indicating the variable
  ##   for which the prediction grid is being created
  ## cutoff = the number of unique values above which an evenly spaced
  ##   grid of `cutoff` points between the minimum and maximum is created
  ##   rather than using all unique values
  ##
  ## returns a vector of the non-missing prediction points with the same
  ## class as the column; factor and character columns always return all
  ## of their unique values because an evenly spaced grid is not defined
  ## for them
  stopifnot(length(x) == 1L)

  column <- df[[x]]
  points <- unique(column)
  points <- points[!is.na(points)]

  ## seq(min, max) fails for factor and character values, so only thin out
  ## variables for which it can be computed
  griddable <- !is.factor(points) && !is.character(points)
  if (griddable && length(points) > cutoff) {
    points <- seq(min(points), max(points), length.out = cutoff)
  }

  ## restore the column's class (e.g. an integer column stays integer even
  ## though seq() may have produced doubles)
  class(points) <- class(column)
  points
}
## example: fit a cforest to iris using cforest_unbiased controls, then
## compute partial dependence on every core reported by detectCores()
data(iris)
fit <- cforest(
  Species ~ .,
  data = iris,
  control = cforest_unbiased(mtry = 2)
)
## dependence on a single explanatory variable
pd <- partial_dependence(
  fit = fit,
  ivar = "Petal.Width",
  cores = detectCores()
)
## joint dependence on two explanatory variables
pd_int <- partial_dependence(
  fit = fit,
  ivar = c("Petal.Width", "Sepal.Length"),
  cores = detectCores()
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment