# Source: GitHub gist jokergoo/e8fff4a57ec59efc694b9e730da22b9f by @jokergoo
# (last active May 5, 2022).
# A minimal gene set enrichment analysis: gene-level statistics
# (gene_level()), set-level statistics (set_level()) and a permutation test
# (gsea_tiny()).
library(matrixStats)
library(genefilter)
# Per-gene statistic comparing two groups of samples.
#
# mat:       numeric matrix, genes in rows and samples in columns.
# condition: one label per column of mat (a factor, or anything factor()
#            accepts). Samples carrying the first level form group 1; all
#            remaining samples form group 2.
# method:    "log2fc" (log2 ratio of group means), "s2n" (mean difference
#            over the sum of group SDs), "tvalue" (Welch t statistic), "sam"
#            (t-like statistic whose denominator is inflated by the 10%
#            quantile of the standard errors) or "ttest" (p-value from
#            genefilter::rowttests()).
# transform: "none", "abs", "square" or "binary" (apply `binarize`).
# binarize:  function applied to the statistic when transform = "binary".
#
# Returns one value per row of mat. Unknown method/transform values are
# errors.
gene_level <- function(mat, condition, method = "tvalue", transform = "none",
                       binarize = function(x) x) {
  # levels() is NULL for character labels, which used to yield empty groups
  # (and NaN statistics) silently; factor input is left untouched.
  if (!is.factor(condition)) {
    condition <- factor(condition)
  }
  l_group1 <- condition == levels(condition)[1]
  mat1 <- mat[, l_group1, drop = FALSE]
  mat2 <- mat[, !l_group1, drop = FALSE]

  mean_diff <- function() rowMeans(mat1) - rowMeans(mat2)
  # Standard error of the difference in means: each group's variance is
  # divided by that group's sample size. (Previously mat2 itself was divided
  # by ncol(mat2) before rowVars(), giving var / n^2 for group 2.)
  se_diff <- function() {
    sqrt(rowVars(mat1) / ncol(mat1) + rowVars(mat2) / ncol(mat2))
  }

  stat <- switch(method,
    log2fc = log2(rowMeans(mat1) / rowMeans(mat2)),
    s2n = mean_diff() / (rowSds(mat1) + rowSds(mat2)),
    tvalue = mean_diff() / se_diff(),
    sam = {
      s <- se_diff()
      mean_diff() / (s + quantile(s, 0.1))
    },
    # NOTE(review): this is a p-value, so small values mean strong evidence,
    # the opposite direction of every other method here.
    ttest = rowttests(mat, factor(condition))$p.value,
    stop("method is not supported.")
  )

  switch(transform,
    none = stat,
    abs = abs(stat),
    square = stat^2,
    binary = binarize(stat),
    stop("transform is not supported.")
  )
}
# Aggregate gene-level statistics into a single score for one gene set.
#
# gene_stat: vector of gene-level statistics, one per gene.
# l_set:     logical vector parallel to gene_stat, TRUE for set members.
# method:    "mean", "sum", "median", "maxmean", "ks", "wilcox" or "chisq".
#
# Returns a single value, or NA when the set has no members.
set_level <- function(gene_stat, l_set, method = "mean") {
  if (!any(l_set)) {
    return(NA)
  }
  in_set <- gene_stat[l_set]
  switch(method,
    mean = mean(in_set),
    sum = sum(in_set),
    median = median(in_set),
    maxmean = {
      # Mean of the positive members vs. mean of the negative members; the
      # one with the larger magnitude is returned (equal magnitudes return
      # the negative side). A side with no members counts as 0: previously
      # mean(numeric(0)) = NaN made the score NA, e.g. for every set when
      # gene statistics were abs()- or square-transformed.
      pos <- in_set[in_set > 0]
      neg <- in_set[in_set < 0]
      s1 <- if (length(pos) > 0) mean(pos) else 0
      s2 <- if (length(neg) > 0) mean(neg) else 0
      # NA statistics in the set still propagate to an NA score.
      if (anyNA(c(s1, s2))) NA else if (s1 > abs(s2)) s1 else s2
    },
    ks = {
      # Walk genes from largest to smallest statistic. f1 is the running
      # share of |stat| mass carried by set members, f2 the running share of
      # non-members; the score is the largest positive gap f1 - f2.
      od <- order(gene_stat, decreasing = TRUE)
      in_sorted <- l_set[od]
      hit_mass <- abs(gene_stat[od])
      hit_mass[!in_sorted] <- 0
      f1 <- cumsum(hit_mass) / sum(hit_mass)
      f2 <- cumsum(!in_sorted) / sum(!in_sorted)
      max(f1 - f2)
    },
    wilcox = wilcox_stat(in_set, gene_stat[!l_set]),
    # Only meaningful for binary (logical or 0/1) gene-level statistics.
    chisq = chisq_stat(gene_stat, l_set),
    stop("method is not supported.")
  )
}
# Count of pairs (a, b), a from x1 and b from x2, with a > b. This equals the
# Mann-Whitney U statistic when there are no ties; tied pairs add nothing.
#
# x1, x2: numeric vectors.
# max_n:  each input longer than this is first subsampled without
#         replacement to max_n values, bounding the max_n x max_n outer()
#         matrix. The result is therefore random when either input is
#         longer than max_n. The default matches the former hard-coded 100.
wilcox_stat <- function(x1, x2, max_n = 100) {
  # Indexing by sample.int() draws the same random numbers as sample(x, k)
  # but never falls into sample()'s "single number means 1:x" behaviour.
  if (length(x1) > max_n) {
    x1 <- x1[sample.int(length(x1), max_n)]
  }
  if (length(x2) > max_n) {
    x2 <- x2[sample.int(length(x2), max_n)]
  }
  sum(outer(x1, x2, ">"))
}
# Pearson chi-square statistic (no continuity correction) for the 2 x 2 table
# cross-classifying two binary indicators.
#
# x1: a logical vector or a 0/1 vector (table rows: x1 TRUE / FALSE).
# x2: a logical vector or a 0/1 vector of the same length (columns).
#
# Returns NaN when a row or column of the table is empty, because the
# corresponding expected counts are then 0.
chisq_stat <- function(x1, x2) {
  n <- length(x1)
  both <- sum(x1 & x2)
  row_tot <- c(sum(x1), sum(!x1))
  col_tot <- c(sum(x2), sum(!x2))
  only_x2 <- col_tot[1] - both
  # Cells in the order (1,1), (1,2), (2,1), (2,2).
  observed <- c(both, row_tot[1] - both, only_x2, row_tot[2] - only_x2)
  expected <- n * rep(row_tot / n, each = 2) * rep(col_tot / n, times = 2)
  sum((observed - expected)^2 / expected)
}
# Tiny gene set enrichment analysis with a permutation test.
#
# mat, condition, gene_level_method, transform, binarize:
#                   passed to gene_level() to score every gene.
# geneset:          gene sets; each element is matched against rownames(mat)
#                   with %in%.
# gene_stat:        unused. Gene-level statistics are always recomputed from
#                   mat; the argument is kept so existing calls keep working.
# set_level_method: passed to set_level() to score every gene set.
# nperm:            number of permutations (at least 1).
# perm_type:        "sample" permutes the condition labels and recomputes the
#                   gene statistics; "gene" shuffles the observed statistics.
#
# Returns a data.frame with one row per gene set:
#   stat    observed set score,
#   size    length of the set definition (not the number matched in mat),
#   p.value fraction of permutations whose score is >= the observed score,
#   fdr     Benjamini-Hochberg adjusted p.value.
# NOTE(review): the p-value is upper-tailed. With gene_level_method =
# "ttest" (p-values, small = significant) that direction looks reversed.
gsea_tiny <- function(mat, condition, geneset,
                      gene_level_method = "tvalue", transform = "none",
                      binarize = function(x) x,
                      gene_stat, set_level_method = "mean",
                      nperm = 1000, perm_type = "sample") {
  # Fail fast, before any (possibly slow) statistics are computed.
  if (!isTRUE(perm_type %in% c("sample", "gene"))) {
    stop("wrong permutation type.")
  }
  if (nperm < 1) {
    stop("nperm must be at least 1.")
  }

  gene_stat <- gene_level(mat, condition, method = gene_level_method,
                          transform = transform, binarize = binarize)
  l_set_list <- lapply(geneset, function(set) rownames(mat) %in% set)

  # Score every gene set for one vector of gene-level statistics
  # (type-stable, and still a numeric vector when geneset is empty).
  score_sets <- function(stat) {
    vapply(l_set_list, function(l_set) {
      set_level(stat, l_set, set_level_method)
    }, numeric(1))
  }
  set_stat <- score_sets(gene_stat)

  # Null distribution: one column per permutation, preallocated rather than
  # grown inside the loop.
  set_stat_random <- matrix(NA_real_, nrow = length(l_set_list), ncol = nperm)
  for (i in seq_len(nperm)) {
    gene_stat_random <- if (perm_type == "sample") {
      gene_level(mat, sample(condition), method = gene_level_method,
                 transform = transform, binarize = binarize)
    } else {
      # Same draw as sample(gene_stat), but also correct for a single gene.
      gene_stat[sample.int(length(gene_stat))]
    }
    set_stat_random[, i] <- score_sets(gene_stat_random)
    if (i %% 100 == 0) {
      message(i, " permutations done.")
    }
  }

  # Column-major recycling compares row i with set_stat[i]; NA scores give an
  # NA p-value, as before.
  p <- rowSums(set_stat_random >= set_stat) / nperm
  df <- data.frame(stat = set_stat,
                   size = lengths(geneset),
                   p.value = p)
  df$fdr <- p.adjust(p, "BH")
  df
}
# End of gist.