Skip to content

Instantly share code, notes, and snippets.

@sergeant-wizard
Created December 2, 2015 05:59
Show Gist options
  • Save sergeant-wizard/91b0c3512accfe960453 to your computer and use it in GitHub Desktop.
Save sergeant-wizard/91b0c3512accfe960453 to your computer and use it in GitHub Desktop.
library(dplyr)
library(ggplot2)
# Return only the four iris measurement columns used as kNN features, in a
# fixed order, as a data frame (row names are kept).
extract_input <- function(data) {
  feature_columns <- c(
    "Sepal.Length",
    "Sepal.Width",
    "Petal.Length",
    "Petal.Width"
  )
  data[feature_columns]
}
# Return the class labels (the Species column) as a factor.
# factor() re-levels the column, so levels absent from `data` are dropped.
# Uses `[[` rather than magrittr's `%$%`: dplyr does not re-export `%$%` and
# this script never attaches magrittr, so the original pipe failed with
# "could not find function". `[[` also avoids `$`'s partial name matching.
extract_labels <- function(data) {
  factor(data[["Species"]])
}
# Split a data frame into the kNN features ("input") and the class labels
# ("labels"), returned together as a named list.
split_columns <- function(data) {
  features <- extract_input(data)
  classes <- extract_labels(data)
  list(input = features, labels = classes)
}
# Fraction of positions where the true and predicted labels disagree.
# Labels are compared as character strings, so factors with different level
# sets (e.g. kNN output vs. the Species factor) are compared by value.
# Returns NaN for empty input (0 / 0) and NA if any label is NA.
error_rate <- function(test_data_labels, classified_labels) {
  truth <- as.character(test_data_labels)
  predicted <- as.character(classified_labels)
  # `!=` would silently recycle a shorter vector; fail loudly instead.
  if (length(truth) != length(predicted)) {
    stop(
      "test_data_labels and classified_labels must have the same length (",
      length(truth), " vs ", length(predicted), ")",
      call. = FALSE
    )
  }
  mean(truth != predicted)
}
# Return `data` with its rows in a uniformly random order.
# sample.int(n) consumes the RNG exactly as sample(1:n) does for n >= 1, but
# returns integer(0) for n == 0 instead of permuting c(1, 0).
# drop = FALSE keeps a single-column data frame from collapsing to a vector.
sample_all <- function(data) {
  data[sample.int(nrow(data)), , drop = FALSE]
}
# Train a k-nearest-neighbour classifier (FNN::knn) on `train` and return its
# error rate on `test`. Both arguments are lists shaped like split_columns()
# output, with elements `input` (features) and `labels` (factor).
# `k` defaults to 32 as before and is exposed so callers can tune it.
# NOTE(review): the holdout sweep in this script starts at 10 training rows,
# fewer than k = 32; confirm how FNN::knn handles k > training-set size.
classifier_error <- function(train, test, k = 32) {
  knn_result <- FNN::knn(train$input, test$input, train$labels, k = k)
  # Use the full element name; the original `test$label` only worked through
  # `$`'s partial matching on lists.
  error_rate(test$labels, knn_result)
}
# Holdout estimate of the kNN error rate: train on the first `num_train` rows
# of `data` and test on the remaining rows. Row order is used as-is, so
# callers shuffle first (the script below uses sample_all()).
holdout <- function(data, num_train) {
  num_rows <- nrow(data)
  # Both sets must be non-empty. Without this check, num_train >= nrow(data)
  # passes tail() a zero or negative count. That yields an empty test set
  # (NaN error) or a test set that overlaps the training rows.
  stopifnot(num_train >= 1, num_train < num_rows)
  train_data <- split_columns(head(data, num_train))
  test_data <- split_columns(tail(data, num_rows - num_train))
  classifier_error(train_data, test_data)
}
# Bootstrap estimates of the kNN error rate.
# Each of `num_bootstrap_samples` resamples (rows drawn with replacement)
# contributes two rows:
#   "N_star, N_star": train and test on the same bootstrap sample;
#   "N_star, N":      train on the bootstrap sample, test on the full data.
# A final "N, N" row trains and tests on the full data.
# Returns a data frame with columns `class` and `error_rate`.
# `num_bootstrap_samples` defaults to the previously hard-coded 1024.
bootstrap <- function(data, num_bootstrap_samples = 1024) {
  stopifnot(num_bootstrap_samples >= 1)
  base_sample <- split_columns(data)
  num_rows <- nrow(data)
  sample_df <- plyr::ldply(seq_len(num_bootstrap_samples), function(sample_index) {
    # Same RNG usage as sample(1:num_rows, replace = TRUE).
    resampled_rows <- sample.int(num_rows, replace = TRUE)
    bootstrap_sample <- split_columns(data[resampled_rows, , drop = FALSE])
    rbind(
      data.frame(
        class = "N_star, N_star",
        error_rate = classifier_error(bootstrap_sample, bootstrap_sample)
      ),
      data.frame(
        class = "N_star, N",
        error_rate = classifier_error(bootstrap_sample, base_sample)
      )
    )
  })
  rbind(
    sample_df,
    data.frame(
      class = "N, N",
      error_rate = classifier_error(base_sample, base_sample)
    )
  )
}
# Cross-validated kNN error rate over consecutive folds of `subset_size` rows.
# Rows are split in their current order; the last fold is smaller when
# nrow(data) is not a multiple of subset_size. Each fold is held out once as
# the test set while the remaining rows train the classifier. The function
# returns the unweighted mean of the per-fold error rates.
# subset_size = 1 gives leave-one-out cross-validation.
cross_validation <- function(data, subset_size) {
  # Use the rows of `data` itself; the original read the global `iris`.
  num_rows <- nrow(data)
  stopifnot(subset_size >= 1, subset_size < num_rows)
  # Fold id per row: 1,...,1, 2,...,2, ... When subset_size divides num_rows,
  # these are the same folds, in the same order, as the original flag vectors.
  fold_ids <- (seq_len(num_rows) - 1L) %/% subset_size + 1L
  error_array <- vapply(unique(fold_ids), function(fold) {
    test_data_flags <- fold_ids == fold
    test_data <- split_columns(data[test_data_flags, , drop = FALSE])
    train_data <- split_columns(data[!test_data_flags, , drop = FALSE])
    classifier_error(train_data, test_data)
  }, numeric(1))
  mean(error_array)
}
# Holdout ----
# For each training-set size, shuffle iris num_samples_per_train times, train
# on the first num_train rows, test on the rest, and box-plot the error rates
# per training-set size.
num_train_array <- seq(from = 10, to = 140, by = 10)
num_samples_per_train <- 64
error_rate_df <- plyr::ldply(num_train_array, function(num_train) {
  plyr::ldply(seq_len(num_samples_per_train), function(sample_index) {
    sampled_iris <- sample_all(iris)
    c(num_train = num_train, error_rate = holdout(sampled_iris, num_train))
  })
})
p <- ggplot(error_rate_df, aes(x = factor(num_train), y = error_rate))
# Explicit print() so the plot also renders when the script is source()d.
print(p + geom_boxplot())
# Cross-validation / jackknife ----
# For each fold size, shuffle iris num_samples_per_train times (the repeat
# count set in the holdout section) and record the cross-validated error;
# subset_size = 1 is leave-one-out.
subset_size_array <- c(1, 5, 10, 15, 25, 50, 75)
error_rate_df <- plyr::ldply(subset_size_array, function(subset_size) {
  plyr::ldply(seq_len(num_samples_per_train), function(repeat_index) {
    sampled_iris <- sample_all(iris)
    c(
      subset_size = subset_size,
      error_rate = cross_validation(sampled_iris, subset_size)
    )
  })
})
p <- ggplot(error_rate_df, aes(x = factor(subset_size), y = error_rate))
# Explicit print() so the plot also renders when the script is source()d.
print(p + geom_boxplot())
# Bootstrap ----
# Bootstrap one shuffled copy of iris and box-plot the error rates for each
# train/test pairing reported by bootstrap().
sampled_iris <- sample_all(iris)
bootstrap_result <- bootstrap(sampled_iris)
p <- ggplot(bootstrap_result, aes(x = factor(class), y = error_rate))
# Explicit print() so the plot also renders when the script is source()d.
print(p + geom_boxplot())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment