Skip to content

Instantly share code, notes, and snippets.

@earino
Created April 15, 2015 20:21
Show Gist options
  • Select an option

  • Save earino/34e85755d5efb1bf7c31 to your computer and use it in GitHub Desktop.

Select an option

Save earino/34e85755d5efb1bf7c31 to your computer and use it in GitHub Desktop.
expanded RF resamples
library(randomForest)
## Seed the random number generator so the simulated data and the fitted
## forests are reproducible from run to run.
set.seed(12345)
## Simulate a two-class data set (see Hastie et al., 10.2).
##
## Draws an n-by-p matrix of independent standard-normal values and labels
## each row "+" when the sum of its squared entries exceeds the median of a
## chi-squared distribution with p degrees of freedom, and "-" otherwise.
##
## Arguments:
##   n  number of rows (observations) to generate
##   p  number of feature columns
## Returns:
##   a data frame holding the p feature columns followed by the factor
##   column `y` with the class labels.
genr_data <- function(n, p) {
  X <- matrix(rnorm(n * p), n, p)
  ## One vectorised pass over the matrix instead of a per-row apply() with a
  ## scalar ifelse().
  sq_norm <- rowSums(X^2)
  y <- as.factor(ifelse(sq_norm > qchisq(0.5, p), "+", "-"))
  data.frame(X, y)
}
## Misclassification rate of a fitted model on a labelled data set.
##
## Arguments:
##   md      fitted model object accepted by predict()
##   d_test  data frame passed to predict() as `newdata`; its column `y`
##           holds the true labels
## Returns:
##   the fraction of rows of `d_test` whose predicted label differs from `y`.
err_rate <- function(md, d_test) {
  predicted <- predict(md, newdata = d_test)
  n_wrong <- sum(predicted != d_test$y)
  n_wrong / nrow(d_test)
}
## Experiment settings.
n_trees <- 500   # total number of trees in every forest being compared
n_runs <- 100    # how many times each method is refitted and scored
p <- 100         # number of feature columns
n <- 1000        # number of training rows

## Simulated data: the training set is drawn first, then a 10000-row test
## set (the order matters for the random number stream).
d_train <- genr_data(n, p)
d_test <- genr_data(10000, p)
## Method 1 (baseline): fit a single forest of n_trees trees on the whole
## training set and record its test error; repeat n_runs times.
## vapply() guarantees a numeric vector and seq_len() yields no iterations
## when n_runs is 0.
error_rates <- vapply(seq_len(n_runs), function(run) {
  md <- randomForest(y ~ ., d_train, ntree = n_trees)
  err_rate(md, d_test)
}, numeric(1))

## Start the results table: one row per method with the mean and standard
## deviation of the test error across runs.
retval <- data.frame(
  method = "1",
  mean = mean(error_rates),
  sd = sd(error_rates)
)
## Method 2: split the training rows into m disjoint groups (row index
## modulo m), fit a forest of n_trees / m trees on each group, merge the m
## forests with combine() and score the merged forest; repeat n_runs times.
m <- 10

## Group membership is the same in every run, so compute it once.
fold_of_row <- seq_len(nrow(d_train)) %% m

error_rates <- vapply(seq_len(n_runs), function(run) {
  mds_split <- lapply(seq_len(m) - 1L, function(k) {
    randomForest(y ~ ., d_train[fold_of_row == k, ], ntree = n_trees / m)
  })
  md_split <- do.call(combine, mds_split)
  err_rate(md_split, d_test)
}, numeric(1))

## Append this method's summary row to the results table.
retval <- rbind(
  retval,
  data.frame(method = "2", mean = mean(error_rates), sd = sd(error_rates))
)
## Method 3: for each of m groups, draw `resplits` random samples (without
## replacement) of nrow(d_train) / m training rows, fit a forest of
## n_trees / m / resplits trees on each sample, and merge everything with
## combine() into one forest that is scored on the test set; repeat n_runs
## times.
resplits <- 10
error_rates <- vapply(seq_len(n_runs), function(run) {
  mds_split <- lapply(seq_len(m), function(group) {
    ## Named `resampled` rather than `t` so base::t() is not shadowed.
    resampled <- lapply(seq_len(resplits), function(draw) {
      idx <- sample(seq_len(nrow(d_train)), nrow(d_train) / m)
      randomForest(y ~ ., d_train[idx, ], ntree = n_trees / m / resplits)
    })
    do.call(combine, resampled)
  })
  md_split <- do.call(combine, mds_split)
  err_rate(md_split, d_test)
}, numeric(1))

## Append this method's summary row to the results table.
retval <- rbind(
  retval,
  data.frame(method = "3", mean = mean(error_rates), sd = sd(error_rates))
)
## Method 4: same resampling scheme as method 3, but every random sample
## (drawn without replacement) holds twice as many rows:
## (nrow(d_train) / m) * 2. Each sample gets a forest of
## n_trees / m / resplits trees; all forests are merged with combine() and
## the merged forest is scored on the test set; repeat n_runs times.
resplits <- 10
error_rates <- vapply(seq_len(n_runs), function(run) {
  mds_split <- lapply(seq_len(m), function(group) {
    ## Named `resampled` rather than `t` so base::t() is not shadowed.
    resampled <- lapply(seq_len(resplits), function(draw) {
      idx <- sample(seq_len(nrow(d_train)), (nrow(d_train) / m) * 2)
      randomForest(y ~ ., d_train[idx, ], ntree = n_trees / m / resplits)
    })
    do.call(combine, resampled)
  })
  md_split <- do.call(combine, mds_split)
  err_rate(md_split, d_test)
}, numeric(1))

## Append this method's summary row to the results table.
retval <- rbind(
  retval,
  data.frame(method = "4", mean = mean(error_rates), sd = sd(error_rates))
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment