@DexGroves
Created April 14, 2016 17:20
# Packages: mlbench (spiral data), gbm (gradient boosting), ggplot2 (plotting)
library("mlbench")
library("gbm")
library("ggplot2")
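# Simulate N points on two interleaved spirals and return them as a data frame.
# mlbench.spirals() labels the points with a two-level factor; subtracting 1
# from its integer codes gives the 0/1 response that gbm's "bernoulli"
# distribution requires.
get_spiral_df <- function(N, cycles, sd) {
  spiral_data <- mlbench.spirals(N, cycles = cycles, sd = sd)
  data.frame(x = spiral_data$x[, 1],
             y = spiral_data$x[, 2],
             class = as.numeric(spiral_data$classes) - 1)
}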
# 5000 points, 6 spiral turns, no noise
sdf <- get_spiral_df(5000, 6.0, 0)
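# Fit a boosted classifier for class on (x, y). train.fraction = 1.0 trains
# on every row, so no validation error is tracked.
sgbm <- gbm(class ~ .,
            distribution = "bernoulli",
            data = sdf,
            n.trees = 1000,
            shrinkage = 0.55,
            verbose = TRUE,
            interaction.depth = 4,
            n.minobsinnode = 10,
            train.fraction = 1.0
            # , cv.folds = 4, n.cores = 2  # uncomment for 4-fold CV on 2 cores
            )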
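# Joint partial dependence of the fitted probability on x and y.
# return.grid = TRUE makes plot.gbm() return the evaluation grid and its
# predictions instead of drawing them; type = "response" gives probabilities.
partial_dep <- plot(sgbm, i.var = c(1, 2), return.grid = TRUE,
                    n.trees = sgbm$n.trees, type = "response")
colnames(partial_dep) <- c("x", "y", "p")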
# Heat map of the predicted probability of class 1 over the (x, y) grid
ggplot(data = partial_dep) +
  geom_tile(aes(x, y, fill = p)) +
  scale_fill_gradient(low = "red", high = "blue")