Skip to content

Instantly share code, notes, and snippets.

@zmjones
Last active August 29, 2015 14:23
Show Gist options
  • Save zmjones/63d21d0308752755a3ae to your computer and use it in GitHub Desktop.
Save zmjones/63d21d0308752755a3ae to your computer and use it in GitHub Desktop.
partial dependence for supervised methods
library(mlr)
library(checkmate)
# Usage examples (need mlr's example tasks). These calls use
# generatePartialPredictionData() and plotPartialPrediction(), which are
# defined further down in this file, so load those definitions before
# running this section.

# Regression: vary "lstat" and "chas"; one panel per value of "chas".
fr = train(learner = "regr.rpart", task = bh.task)
dr = generatePartialPredictionData(obj = fr, data = getTaskData(bh.task),
                                   features = c("lstat", "chas"))
plotPartialPrediction(obj = dr, facet = "chas")

# Classification with response predictions: each grid point is summarised by
# the proportion of observations predicted as each class.
fc = train(learner = "classif.rpart", task = iris.task)
dc = generatePartialPredictionData(obj = fc, data = getTaskData(iris.task),
                                   features = c("Petal.Width", "Petal.Length"),
                                   fun = function(x) table(x) / length(x))
plotPartialPrediction(obj = dc, facet = "Petal.Length")

# Classification with probability predictions: the default summary (mean)
# is applied to each class-probability column.
fcp = train(learner = makeLearner("classif.rpart", predict.type = "prob"),
            task = iris.task)
dcp = generatePartialPredictionData(obj = fcp, data = getTaskData(iris.task),
                                    features = c("Petal.Width", "Petal.Length"))
plotPartialPrediction(obj = dcp, facet = "Petal.Length")

# Survival: the summarised prediction is stored and plotted as "risk".
fs = train(learner = "surv.coxph", task = wpbc.task)
ds = generatePartialPredictionData(obj = fs, data = getTaskData(wpbc.task),
                                   features = c("pnodes", "worst_radius"))
plotPartialPrediction(obj = ds, facet = "worst_radius")
# Compute partial prediction (partial dependence) data for a fitted model.
#
# A grid over `features` is built with generateFeatureGrid() (a full
# cartesian grid when more than one feature is given). At each grid point the
# `features` columns of `data` are overwritten with the grid values, the model
# predicts on the modified data, and the predictions are summarised by `fun`.
#
# Args:
#   obj: fitted model of class "WrappedModel".
#   data: data.frame whose rows are used as the background sample.
#   features: character vector of column names to vary.
#   fun: summary function. With predict.type "response" it receives the
#     vector of predicted responses; otherwise it is applied to each
#     class-probability column ("prob.<level>") separately.
#   resample, fmin, fmax: passed on to generateFeatureGrid().
#   gridsize: number of grid values per feature (generateFeatureGrid's cutoff).
#   ...: further arguments passed on to predict().
#
# Returns:
#   An S3 object of class "PartialPredictionData" with elements
#   `data` (summarised predictions followed by the grid columns),
#   `task.desc`, `target` and `features`.
generatePartialPredictionData = function(obj, data, features, fun = mean,
                                         resample = NULL, fmin = NULL, fmax = NULL,
                                         gridsize = 10L, ...) {
  assertClass(obj, "WrappedModel")
  td = obj$task.desc

  rng = lapply(features, function(x) generateFeatureGrid(x, data, resample, fmin, fmax, gridsize))
  names(rng) = features
  rng = as.data.frame(rng)
  if (length(features) > 1L)
    rng = expand.grid(rng)

  # Exact names of the class-probability columns, in class.levels order.
  # Matching the level names as regexes would also hit e.g. "truth" or
  # "response" for short level names, or break on metacharacters.
  prob.cols = paste0("prob.", td$class.levels)

  ppred = lapply(seq_len(nrow(rng)), function(i) {
    # grid columns are assigned positionally onto the feature columns
    data[features] = rng[i, , drop = FALSE]
    pred = predict(obj, newdata = data, ...)$data
    if (obj$learner$predict.type == "response") {
      fun(pred$response)
    } else {
      apply(pred[, prob.cols, drop = FALSE], 2, fun)
    }
  })
  ppred = as.data.frame(do.call("rbind", ppred))

  if (td$type %in% c("regr", "surv")) {
    target = td$target
  } else {
    stopifnot(ncol(ppred) == length(td$class.levels))
    target = td$class.levels
  }

  data = cbind(ppred, rng)
  if ((ncol(ppred) == 1L && td$type == "regr") || td$type == "classif") {
    colnames(data) = c(target, features)
  } else if (ncol(ppred) == 3L && td$type == "regr") {
    # three summary values per grid point are stored as lower / target / upper
    colnames(data) = c("lower", target, "upper", features)
    stopifnot(all(data$lower < data[[target]] & data[[target]] < data$upper))
  } else {
    stopifnot(td$type == "surv", ncol(ppred) == 1L)
    colnames(data) = c("risk", features)
  }

  makeS3Obj("PartialPredictionData",
            data = data,
            task.desc = td,
            target = target,
            features = features)
}
# Print method for "PartialPredictionData" objects: shows the task id, the
# varied features, the target column name(s) and the first rows of the data.
# Returns `x` invisibly, as print methods conventionally do.
print.PartialPredictionData = function(x, ...) {
  cat("PartialPredictionData\n")
  cat("Task: ", x$task.desc$id, "\n", sep = "")
  cat("Features: ", paste(x$features, collapse = ", "), "\n", sep = "")
  cat("Target: ", paste(x$target, collapse = ", "), "\n", sep = "")
  print(head(x$data))
  invisible(x)
}
# Plot partial prediction data with ggplot2.
#
# Args:
#   obj: object of class "PartialPredictionData" (at most two features).
#   facet: optional name of one of obj$features. It is only allowed when two
#     features are present; the plot is then split into one panel per value
#     of that feature, with the other feature on the x-axis. With two
#     features, `facet` is required.
#
# Returns:
#   A ggplot object. When the target columns are class levels, one coloured
#   line per class is drawn; when "lower"/"upper" columns exist they are
#   drawn as a ribbon.
plotPartialPrediction = function(obj, facet = NULL) {
  assertClass(obj, "PartialPredictionData")
  stopifnot(length(obj$features) <= 2L)
  bounds = all(c("lower", "upper") %in% colnames(obj$data))

  if (!is.null(facet)) {
    stopifnot(length(facet) == 1L, facet %in% obj$features, length(obj$features) > 1L)
    feature = setdiff(obj$features, facet)
    vals = obj$data[[facet]]
    if (is.numeric(vals)) {
      # round for readable panel titles; levels are ordered by value so the
      # panels are not sorted alphabetically ("x = 13" before "x = 5")
      vals = signif(vals, 2)
      lvls = sort(unique(vals))
    } else if (is.factor(vals)) {
      lvls = levels(vals)
    } else {
      lvls = unique(vals)
    }
    obj$data[[facet]] = factor(paste(facet, "=", vals),
                               levels = unique(paste(facet, "=", lvls)))
  } else {
    feature = obj$features
  }
  # only one feature can go on the x-axis; a second one must be facetted
  if (length(feature) != 1L)
    stop("With two features, 'facet' must name one of them.")

  target = if (obj$task.desc$type == "surv") "risk" else obj$target

  if (all(target %in% obj$task.desc$class.levels)) {
    out = reshape2::melt(obj$data, id.vars = obj$features,
                         variable.name = "Class", value.name = "Probability")
    out$Class = gsub("^prob\\.", "", out$Class)
    plt = ggplot2::ggplot(out, ggplot2::aes_string(feature, "Probability", color = "Class"))
  } else {
    plt = ggplot2::ggplot(obj$data, ggplot2::aes_string(feature, target))
  }
  plt = plt + ggplot2::geom_point() + ggplot2::geom_line()
  if (bounds)
    plt = plt + ggplot2::geom_ribbon(ggplot2::aes_string(ymin = "lower", ymax = "upper"), alpha = 0.5)
  if (!is.null(facet))
    plt = plt + ggplot2::facet_wrap(as.formula(paste("~", facet)), scales = "free_y")
  plt
}
# Build the grid of values at which one feature is evaluated.
#
# Args:
#   feature: name of the column of `data`.
#   data: data.frame containing that column.
#   resample: NULL for an evenly spaced grid from fmin to fmax, or
#     "bootstrap" / "subsample" to draw grid values from the observed,
#     non-missing values in [fmin, fmax] with / without replacement.
#   fmin, fmax: bounds for non-factor features; default to the observed
#     range ignoring NA.
#   cutoff: number of grid values to return.
#
# Returns:
#   A vector of length `cutoff`. Factor features yield a factor that keeps the
#   original level set, with the levels recycled to length `cutoff`.
generateFeatureGrid = function(feature, data, resample = NULL,
                               fmin = NULL, fmax = NULL, cutoff = 10L) {
  x = data[[feature]]
  if (is.factor(x)) {
    lvls = levels(x)
    # keep the factor coding so predictions see the same levels as the data
    return(factor(rep(lvls, length.out = cutoff), levels = lvls))
  }
  if (is.null(fmin))
    fmin = min(x, na.rm = TRUE)
  if (is.null(fmax))
    fmax = max(x, na.rm = TRUE)
  if (is.null(resample))
    return(seq(fmin, fmax, length.out = cutoff))

  assertChoice(resample, c("bootstrap", "subsample"))
  pool = x[!is.na(x) & x >= fmin & x <= fmax]
  # index with sample.int: sample(pool, ...) would draw from 1:pool when
  # pool happens to be a single number
  pool[sample.int(length(pool), cutoff, replace = resample == "bootstrap")]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment