Created
March 8, 2015 02:25
-
-
Save czxttkl/66c4c8676f5b07807cbc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This script runs a bootstrap experiment to estimate the average precision, recall and F1
# obtained when every element is assigned to one of the clusters at random.
# The pairwise definitions of precision, recall and F1 used here largely follow:
# http://stats.stackexchange.com/questions/15158/precision-and-recall-for-clustering
remove(list=ls()) | |
library(combinat) | |
# Ground-truth labels, stored as one contiguous run per class.
# Class sizes: label 1 = 1779, label 2 = 2979, label 3 = 1975, label 4 = 3260.
raw.labels <- rep(c(1, 2, 3, 4), times = c(1779, 2979, 1975, 3260))
# function to calculate FN | |
cal.fn <- function(j, count.table) { | |
fn <- 0 | |
for (i in c(1,2,3)) { | |
fn <- fn + count.table[i, j] * sum(count.table[(i+1):4, j]) | |
} | |
fn | |
} | |
# Run one random-assignment trial and score it against the true labels.
#
# Every element of the global `raw.labels` is given a label drawn with
# replacement from 1..4; the resulting clustering is then scored with the
# pair-counting definitions of precision, recall and F1.
#
# Args:
#   k: unused; kept so existing calls (with or without an argument) still work.
#
# Returns:
#   A list with numeric elements "precision", "recall" and "f1".
bootstrap <- function(k) {
  label.values <- c(1, 2, 3, 4)
  random.labels <- sample(c(1, 2, 3, 4), size = length(raw.labels), replace = TRUE)

  # Row i, column j: number of elements assigned to label i whose true label is j.
  # Fixing the factor levels keeps the table 4 x 4 even if a label is never drawn.
  crosstab <- table(
    factor(random.labels, levels = label.values),
    factor(raw.labels, levels = label.values)
  )
  # Strip the table class/dimnames so downstream code sees a plain integer matrix.
  count.table <- matrix(as.vector(crosstab), nrow = length(label.values))

  # TP: pairs sharing both the assigned label and the true label.
  tp <- sum(choose(count.table, 2))
  # TP + FP: all pairs sharing the assigned label (row totals).
  tp.plus.fp <- sum(choose(rowSums(count.table), 2))
  # FN: pairs sharing the true label but split across assigned labels.
  fn <- sum(vapply(
    seq_along(label.values),
    function(j) cal.fn(j, count.table),
    numeric(1)
  ))

  precision <- tp / tp.plus.fp
  recall <- tp / (tp + fn)
  f1 <- 2 * precision * recall / (precision + recall)
  list("precision" = precision, "recall" = recall, "f1" = f1)
}
# Number of independent random-assignment trials to average over.
times <- 100
# Run the trials once. Each trial returns list(precision, recall, f1);
# flattening them yields one row per trial with columns precision, recall, f1.
results <- matrix(
  unlist(replicate(times, bootstrap(), simplify = FALSE)),
  ncol = 3,
  byrow = TRUE
)
# Average each metric across trials.
ave_results <- colMeans(results)
names(ave_results) <- c("precision", "recall", "f1")
print(ave_results)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment