Last active
November 10, 2022 07:27
-
-
Save amacanovic/054d2fac0cb7888182af5ef0b27e8ee8 to your computer and use it in GitHub Desktop.
Code that allows you to use the measures from the ldatuning R package (https://github.com/nikita-moor/ldatuning) with keyATM models (https://keyatm.github.io/keyATM/). These measures allow you to choose the optimal number of topics to cluster textual data into with the keyATM models by Eshima, Imai, and Sasaki (http://arxiv.org/abs/2004.05964).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Adapting the ldatuning functions to keyATM model structure in R | |
# Code modified by Ana Macanovic (see amacanovic.github.io for contact details) | |
# Please let me know if you find errors or have suggestions regarding these functions. | |
# The function metrics_ldatuning below adapts the metrics from the ldatuning | |
# package (https://github.com/nikita-moor/ldatuning) so they can be used with the output | |
# of keyATM models (https://keyatm.github.io/keyATM/) . | |
# See this discussion for more details: https://github.com/keyATM/keyATM/issues/169 | |
# Note that the authors of keyATM models suggest that one does not need to adjust the number of topics | |
# in these models as long as the model receives a reasonable set of keywords as an input. In that sense, | |
# this code might not be necessary for fine-tuning keyATM models. | |
# Please also note that it is not a priori evident that the ldatuning measures could or should be | |
# used with the keyATM models. Further, I do not guarantee that this code is error free; so please | |
# use it at your own risk. | |
# The input is a list of fitted models, e.g.
# models <- list(model_1 = model_a,
#                model_2 = model_a2,
#                model_3 = model_a3)
# The function also needs an argument with the keyATM data object, e.g.
# dataset_atm <- keyatm_data_a
# Measures as adapted from the ldatuning R package | |
Griffiths2004 <- function(models, burnin = 0) {
  # Harmonic-mean estimator of the marginal likelihood (Griffiths & Steyvers
  # 2004), adapted from ldatuning for keyATM model objects. Higher is better.
  #
  # models: list of fitted keyATM models (each exposing
  #         model$model_fit$`Log Likelihood`).
  # burnin: number of initial stored log-likelihood values to discard
  #         (generalized from the previously hard-coded 0; default keeps
  #         the original behavior).
  #
  # Returns a named numeric vector, one metric per model; NaN for models
  # that kept no log-likelihoods.
  logLiks <- lapply(models, function(model) {
    ll <- model$model_fit$`Log Likelihood`
    # Without stored log-likelihoods the estimator cannot be computed.
    if (length(ll) == 0) {
      message("No logLiks were kept, which is required to use this scoring algorithm. Please regenerate the model using the keep control parameter set to a reasonable value (default = 50).")
      NaN
    } else {
      # Drop the first `burnin` samples, keep the rest.
      utils::tail(ll, n = length(ll) - burnin)
    }
  })
  # Harmonic mean for every model. High-precision arithmetic via Rmpfr
  # avoids overflow/underflow when exponentiating large negative
  # log-likelihoods; see [Ponweiser2012 p. 36] for the llMed trick.
  metrics <- vapply(logLiks, function(x) {
    llMed <- stats::median(x)
    as.double(
      llMed - log(Rmpfr::mean(exp(-Rmpfr::mpfr(x, prec = 2000L) + llMed)))
    )
  }, numeric(1))
  return(metrics)
}
CaoJuan2009 <- function(models) {
  # Cao et al. (2009) metric: mean pairwise cosine similarity between the
  # topic-word distributions of a model. Lower values indicate
  # better-separated topics. Returns one value per model in `models`.
  metrics <- sapply(models, function(model) {
    # Topic-word probability matrix (phi is stored on the log scale).
    beta <- exp(model$phi)
    # All unordered pairs of topic indices.
    topic_pairs <- utils::combn(nrow(beta), 2)
    # Cosine similarity for each pair of topic rows.
    similarities <- apply(topic_pairs, 2, function(idx) {
      a <- beta[idx[1], ]
      b <- beta[idx[2], ]
      crossprod(a, b) / sqrt(crossprod(a) * crossprod(b))
    })
    # Average over the K*(K-1)/2 pairs, K = total topic count.
    k <- model$keyword_k + model$no_keyword_topics
    sum(similarities) / (k * (k - 1) / 2)
  })
  return(metrics)
}
Arun2010 <- function(models, dataset_atm) {
  # Arun et al. (2010) metric: symmetric Kullback-Leibler divergence between
  # the singular-value distribution of the topic-word matrix and the
  # length-weighted document-topic distribution. Lower is better.
  #
  # dataset_atm: keyATM data object; W_raw holds the tokenized documents,
  # so lengths() yields per-document word counts.
  doc_lengths <- lengths(dataset_atm$W_raw)
  metrics <- sapply(models, FUN = function(model) {
    # Singular values of the topic-word matrix (phi is on the log scale).
    sv <- as.matrix(svd(exp(model$phi))$d)
    # Document-topic mass, weighted by document length and scaled by the
    # maximum-modulus norm of the length vector.
    weighted <- doc_lengths %*% model$theta
    scale_norm <- norm(as.matrix(doc_lengths), type = "m")
    weighted <- as.vector(weighted / scale_norm)
    # Symmetric KL divergence between the two distributions.
    sum(sv * log(sv / weighted)) + sum(weighted * log(weighted / sv))
  })
  return(metrics)
}
Deveaud2014 <- function(models) {
  # Deveaud et al. (2014) metric: mean pairwise divergence between the
  # topic-word distributions of a model. Higher values indicate
  # better-separated topics. Returns one value per model in `models`.
  metrics <- sapply(models, function(model) {
    # Topic-word probability matrix (phi is stored on the log scale).
    word_dist <- exp(model$phi)
    # Guard against zero probabilities, which would produce NaN in the logs.
    if (any(word_dist == 0)) {
      word_dist <- word_dist + .Machine$double.xmin
    }
    # All unordered pairs of topic indices.
    topic_pairs <- utils::combn(nrow(word_dist), 2)
    divergences <- apply(topic_pairs, 2, function(idx) {
      p <- word_dist[idx[1], ]
      q <- word_dist[idx[2], ]
      # Deveaud2014 variant: halved symmetrized KL divergence, NOT the
      # standard Jensen-Shannon divergence (which would use the midpoint
      # distribution m = (p + q) / 2).
      0.5 * sum(p * log(p / q)) + 0.5 * sum(q * log(q / p))
    })
    # Average over K*(K-1) ordered pairs, K = total topic count.
    k <- model$keyword_k + model$no_keyword_topics
    sum(divergences) / (k * (k - 1))
  })
  return(metrics)
}
# Combining the metrics in a single function:
metrics_ldatuning <- function(models, dataset_atm) {
  # Computes all four ldatuning metrics for a list of fitted keyATM models.
  #
  # models:      list of fitted keyATM models.
  # dataset_atm: keyATM data object (needed by Arun2010).
  #
  # Returns a data frame with one row per model and columns
  # "topics", "Griffiths2004", "CaoJuan2009", "Arun2010", "Deveaud2014" —
  # the format expected by FindTopicsNumber_plot().
  #
  # Total topic count per model, computed once instead of four times
  # (also fixes the original's duplicated "+ +" typo). Building the data
  # frame directly avoids the previous rownames -> as.numeric round-trip,
  # which silently corrupted the "topics" column (e.g. "8.1") whenever two
  # models had the same total topic count.
  topic_counts <- sapply(models, function(model) {
    model$keyword_k + model$no_keyword_topics
  })
  metrics_output_keyatm <- data.frame(
    topics        = as.numeric(topic_counts),
    Griffiths2004 = as.numeric(Griffiths2004(models)),
    CaoJuan2009   = as.numeric(CaoJuan2009(models)),
    Arun2010      = as.numeric(Arun2010(models, dataset_atm)),
    Deveaud2014   = as.numeric(Deveaud2014(models))
  )
  return(metrics_output_keyatm)
}
# Using the FindTopicsNumber_plot from the ldatuning package, which takes
# the output of the metrics_ldatuning function
FindTopicsNumber_plot <- function(values) {
  # Plots the metrics produced by metrics_ldatuning(), each rescaled to
  # [0, 1], split into a "minimize" panel and a "maximize" panel.
  # Draws the plot as a side effect and returns the ggplot object invisibly.
  #
  # Drop models if present, as they won't rescale.
  if ("LDA_model" %in% names(values)) {
    values <- values[!names(values) %in% c("LDA_model")]
  }
  # Normalize every metric column to [0, 1] so they share one y-axis.
  columns <- base::subset(values, select = 2:ncol(values))
  values <- base::data.frame(
    values["topics"],
    base::apply(columns, 2, function(column) {
      scales::rescale(column, to = c(0, 1), from = range(column))
    })
  )
  # Long format: one row per (topics, metric) pair.
  values <- reshape2::melt(values, id.vars = "topics", na.rm = TRUE)
  # Separate max-arg & min-arg metrics.
  values$group <- values$variable %in% c("Griffiths2004", "Deveaud2014")
  values$group <- base::factor(
    values$group,
    levels = c(FALSE, TRUE),
    labels = c("minimize", "maximize")
  )
  # Standard plot. aes_string() is deprecated since ggplot2 3.0; use aes()
  # with the .data pronoun instead.
  p <- ggplot(values, aes(x = .data$topics, y = .data$value, group = .data$variable))
  p <- p + geom_line()
  p <- p + geom_point(aes(shape = .data$variable), size = 3)
  # guides(size = FALSE) is deprecated; "none" is the supported spelling.
  p <- p + guides(size = "none", shape = guide_legend(title = "metrics:"))
  p <- p + scale_x_continuous(breaks = values$topics)
  p <- p + labs(x = "number of topics", y = NULL)
  # Separate into the two panels.
  p <- p + facet_grid(group ~ .)
  # Style.
  p <- p + theme_bw() %+replace% theme(
    panel.grid.major.y = element_blank(),
    panel.grid.minor.y = element_blank(),
    panel.grid.major.x = element_line(colour = "grey70"),
    panel.grid.minor.x = element_blank(),
    legend.key = element_blank(),
    strip.text.y = element_text(angle = 90)
  )
  # Move the facet strip block to the left side of the plot.
  g <- ggplotGrob(p)
  g$layout[g$layout$name == "strip-right", c("l", "r")] <- 3
  grid::grid.newpage()
  grid::grid.draw(g)
  invisible(p)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment