Created
September 8, 2013 00:36
-
-
Save kylebgorman/6480811 to your computer and use it in GitHub Desktop.
lda-match.R: perform group matching via backward selection using a heuristic based on Fisher's linear discriminant
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env Rscript | |
# lda-match.R: Perform group matching via backward selection using a heuristic
# based on Fisher's linear discriminant
# Kyle Gorman <[email protected]>
require(MASS) | |
lda.match <- function(x, grouping, term.fnc=univariate.all) {
  # Create a matched group via backward selection using a heuristic
  # based on Fisher's linear discriminant. Observations are removed
  # in the order of their distance from the mean value of a linear
  # projection.
  #
  # The procedure repeatedly drops one observation from whichever group
  # is currently largest (its most extreme remaining member on the
  # projection), then calls term.fnc on the remaining rows; it stops as
  # soon as term.fnc returns TRUE.
  #
  # @param x        a matrix in which columns contain numerical
  #                 features on which to match
  # @param grouping a factor vector containing corresponding group
  #                 labels (coerced with as.factor)
  # @param term.fnc function to which (x, grouping) will be applied;
  #                 selection halts iff it returns TRUE
  # @return         a logical vector, TRUE iff row is in the match
  #
  x <- as.matrix(x)
  grouping <- as.factor(grouping)
  stopifnot(nrow(x) == length(grouping))
  group.names <- levels(grouping)
  # compute the projection: the single feature itself, or the scores on
  # the first linear discriminant when there are several features
  if (ncol(x) == 1) {
    projection <- x[, 1]
  } else {
    projection <- (x %*% MASS::lda(x, grouping)$scaling)[, 1]
  }
  # things to keep track of
  include <- rep(TRUE, nrow(x))
  gtable <- table(grouping)
  # per-group position of the next observation to drop
  nxt <- setNames(rep(1L, length(group.names)), group.names)
  # by-group list of row indices, sorted by projected value
  ord <- order(projection)
  by.group <- split(ord, grouping[ord])
  # reverse order of indices if group mean is on the righthand side, so
  # that each list starts with the observation farthest from the grand mean
  projection.mu <- mean(projection)
  for (group in group.names) {
    # isTRUE() guards against the NaN mean of an empty (unused) level
    if (isTRUE(mean(projection[grouping == group]) > projection.mu))
      by.group[[group]] <- rev(by.group[[group]])
  }
  matched <- FALSE
  while (any(include)) {
    # determine larger group and take one out
    group <- group.names[which.max(gtable)]
    include[by.group[[group]][nxt[[group]]]] <- FALSE
    # check for convergence; drop = FALSE keeps a one-column x a matrix
    if (term.fnc(x[include, , drop = FALSE], grouping[include])) {
      matched <- TRUE
      break
    }
    # adjust group sizes and "nxt" tables
    gtable[group] <- gtable[group] - 1
    nxt[group] <- nxt[[group]] + 1L
  }
  if (!matched && length(include) > 0)
    warning("lda.match: every observation was removed before term.fnc ",
            "returned TRUE", call. = FALSE)
  return(include)
}
## some sample termination functions
univariate.all <- function(x, grouping, p.value=.2, fun=t.test) {
  # Termination check: apply a two-group test to every column of x.
  #
  # @param x        a matrix (or data frame) of numerical features; a
  #                 plain vector is treated as a single column
  # @param grouping group labels, one per row of x
  # @param p.value  threshold below which a column counts as different
  # @param fun      test function called with a formula `column ~ grouping`;
  #                 its result must have a `p.value` element
  # @return         FALSE as soon as any column's p-value is below
  #                 p.value, TRUE otherwise
  #
  # a one-column matrix indexed without drop = FALSE arrives as a vector
  if (is.null(dim(x)))
    x <- matrix(x, ncol = 1)
  stopifnot(nrow(x) == length(grouping))
  # seq_len() yields an empty loop (not 1:0) when x has no columns
  for (i in seq_len(ncol(x))) {
    column <- x[, i]
    if (fun(column ~ grouping)$p.value < p.value)
      return(FALSE)
  }
  TRUE
}
manova.one <- function(x, grouping, p.value=.2) {
  # Termination check: fit a MANOVA of x on the grouping factor and
  # compare the value in row 1, column 6 of the Wilks summary table
  # against the threshold.
  #
  # @param x        a matrix of numerical features
  # @param grouping group labels, one per row of x
  # @param p.value  threshold the extracted value must exceed
  # @return         TRUE iff the extracted value is greater than p.value
  #
  fit <- manova(x ~ grouping)
  wilks.stats <- summary(fit, test = "Wilks")$stats
  observed.p <- wilks.stats[1, 6]
  observed.p > p.value
}
## sample run
# load the data and keep only the two diagnostic groups being matched
d <- read.csv('DX-NRT.csv')
d <- droplevels(subset(d, DX %in% c('ALI', 'ALN')))
# group sizes before matching
print(table(d$DX))
# feature matrix: one column each for NVIQ, ADOS and CA
mm <- with(d, matrix(c(NVIQ, ADOS, CA), ncol=3))
# group sizes after matching
print(table(d$DX[lda.match(mm, d$DX)]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment