Last active
July 28, 2021 09:25
-
-
Save trinker/594bd132b180a43945f7 to your computer and use it in GitHub Desktop.
Find the optimal number of topics in a topic model using the harmonic mean of the log likelihood
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#' Find Optimal Number of Topics
#'
#' Iteratively fits topic models with \code{k = 2, ..., max.k} topics and
#' compares the harmonic mean of the post-burn-in log likelihoods of each fit.
#' The \code{k} with the largest harmonic mean is returned; the comparison
#' across all \code{k} can be viewed graphically with \code{plot} (printing
#' the result also draws the plot).
#'
#' @param x A \code{\link[tm]{DocumentTermMatrix}}.
#' @param max.k Maximum number of topics to fit (start small [i.e., default of
#' 30] and add as necessary). Must be at least 2.
#' @param burnin Object of class \code{"integer"}; number of omitted Gibbs
#' iterations at beginning, by default equals 1000.
#' @param iter Object of class \code{"integer"}; number of Gibbs iterations,
#' by default equals 1000.
#' @param keep Object of class \code{"integer"}; a positive integer, the
#' log-likelihood is saved every \code{keep} iterations (default 50). The
#' saved values falling in the burn-in period are discarded and the harmonic
#' mean is computed from the rest, so at least one value must be saved after
#' burn-in.
#' @param method The method to be used for fitting. Only
#' \code{method = "Gibbs"} is supported: \code{burnin} and \code{iter} are
#' Gibbs sampler controls and the harmonic-mean estimate is computed from the
#' log likelihoods saved after burn-in.
#' @param verbose logical. If \code{TRUE} the optimal number of topics is
#' printed to the console.
#' @param \ldots Other arguments added to the \code{control} list passed to
#' \code{\link[topicmodels]{LDA}} (e.g., \code{seed}).
#' @return Returns the optimal number of topics: an integer of class
#' \code{optimal_k} carrying the attribute \code{k_dataframe}, a
#' \code{\link[base]{data.frame}} of \code{k} (number of topics) and the
#' associated \code{harmonic_mean} of the log likelihood.
#' @references \url{http://stackoverflow.com/a/21394092/1000343} \cr
#' Ponweiser, M. (2012). Latent Dirichlet Allocation in R (Diploma Thesis). Vienna University of
#' Economics and Business, Vienna. http://cran.r-project.org/web/packages/topicmodels/vignettes/topicmodels.pdf
#' @keywords k topicmodel
#' @export
#' @author Ben Marwick and Tyler Rinker <tyler.rinker@@gmail.com>.
#' @examples
#' \dontrun{
#' ## Install/Load Tools & Data
#' if (!require("pacman")) install.packages("pacman")
#' pacman::p_load_gh("trinker/gofastr")
#' pacman::p_load(tm, topicmodels, dplyr, tidyr, devtools, LDAvis, ggplot2)
#'
#'
#' ## Source topicmodels2LDAvis function
#' devtools::source_url("https://gist.githubusercontent.com/trinker/477d7ae65ff6ca73cace/raw/79dbc9d64b17c3c8befde2436fdeb8ec2124b07b/topicmodels2LDAvis")
#'
#' data(presidential_debates_2012)
#'
#'
#' ## Generate Stopwords
#' stops <- c(
#'     tm::stopwords("english"),
#'     "governor", "president", "mister", "obama","romney"
#'     ) %>%
#'     gofastr::prep_stopwords()
#'
#'
#' ## Create the DocumentTermMatrix
#' doc_term_mat <- presidential_debates_2012 %>%
#'     with(gofastr::q_dtm_stem(dialogue, paste(person, time, sep = "_"))) %>%
#'     gofastr::remove_stopwords(stops) %>%
#'     gofastr::filter_tf_idf() %>%
#'     gofastr::filter_documents()
#'
#'
#' opti_k <- optimal_k(doc_term_mat)
#' opti_k
#' }
optimal_k <- function(x, max.k = 30, burnin = 1000, iter = 1000, keep = 50,
                      method = "Gibbs", verbose = TRUE, ...) {
  # The harmonic-mean estimate relies on the log likelihoods saved during
  # Gibbs sampling after burn-in; `burnin`/`iter` are Gibbs-only controls.
  if (!identical(method, "Gibbs")) {
    stop(
      "`optimal_k` only supports `method = \"Gibbs\"`: the harmonic mean is ",
      "computed from log likelihoods saved after Gibbs burn-in.",
      call. = FALSE
    )
  }
  stopifnot(
    is.numeric(max.k), length(max.k) == 1L, max.k >= 2,
    is.numeric(burnin), length(burnin) == 1L, burnin >= 0,
    is.numeric(keep), length(keep) == 1L, keep > 0
  )

  if (max.k > 20) {
    message("\nGrab a cup of coffee this is gonna take a while...\n")
    utils::flush.console()
  }

  ks <- 2:max.k
  # Forward `...` into the control list, as documented.
  control <- list(burnin = burnin, iter = iter, keep = keep, ...)
  # Number of saved log likelihoods that fall inside the burn-in period.
  n_burn <- burnin %/% keep

  hm_many <- vapply(ks, function(k) {
    fitted <- topicmodels::LDA(x, k = k, method = method, control = control)
    logLiks <- fitted@logLiks
    # Positional filter: `x[-c(1:0)]` would wrongly drop the first sample when
    # burnin < keep, and `x[-integer(0)]` would drop every sample.
    logLiks <- logLiks[seq_along(logLiks) > n_burn]
    if (length(logLiks) == 0L) {
      stop(
        "No log likelihoods were saved after burn-in; ",
        "decrease `keep` or increase `iter`.",
        call. = FALSE
      )
    }
    harmonicMean(logLiks)
  }, numeric(1))

  out <- ks[which.max(hm_many)]
  class(out) <- c("optimal_k", class(out))
  attributes(out)[["k_dataframe"]] <- data.frame(
    k = ks,
    harmonic_mean = hm_many
  )
  if (isTRUE(verbose)) {
    cat(sprintf("Optimal number of topics = %s\n", as.numeric(out)))
  }
  out
}
# Log of the harmonic mean of likelihoods, given their logs.
#
# For log likelihoods ll_1, ..., ll_n this returns
#   log(n / sum_i exp(-ll_i)),
# i.e. the log of the harmonic mean of exp(ll_i). Topic-model log likelihoods
# are large negative numbers, so exp(-ll_i) would overflow a double. The sum
# is therefore evaluated with the log-sum-exp trick, shifting by min(ll) so
# every exponentiated term lies in (0, 1] and the result is accurate to
# double precision without arbitrary-precision arithmetic.
#
# logLikelihoods: numeric vector of log likelihoods.
# precision: retained for backward compatibility (an earlier version used
#   Rmpfr with this many bits of precision); it is ignored.
# Returns a single double: NA_real_ for empty input, NA if any input is NA,
# and -Inf if any log likelihood is -Inf (a zero likelihood makes the
# harmonic mean zero).
harmonicMean <- function(logLikelihoods, precision = 2000L) {
  n <- length(logLikelihoods)
  if (n == 0L) {
    return(NA_real_)
  }
  ll_min <- min(logLikelihoods)
  if (isTRUE(ll_min == -Inf)) {
    return(-Inf)
  }
  log(n) + ll_min - log(sum(exp(ll_min - logLikelihoods)))
}
#' Plots an optimal_k Object
#'
#' Plots the harmonic mean of the log likelihood against the number of topics
#' stored in an \code{optimal_k} object, circling the optimal \code{k}.
#'
#' @param x An \code{optimal_k} object.
#' @param \ldots Ignored.
#' @return A \code{ggplot} object.
#' @method plot optimal_k
#' @export
plot.optimal_k <- function(x, ...) {
  k_df <- attributes(x)[["k_dataframe"]]
  # Row(s) for the selected k, highlighted with a hollow blue circle.
  best <- k_df[k_df[["k"]] == as.numeric(x), , drop = FALSE]

  # `.data` pronoun replaces the deprecated `aes_string()`.
  ggplot2::ggplot(
    k_df,
    ggplot2::aes(x = .data[["k"]], y = .data[["harmonic_mean"]])
  ) +
    ggplot2::xlab("Number of Topics") +
    ggplot2::ylab("Harmonic Mean of Log Likelihood") +
    ggplot2::geom_point(
      data = best, color = "blue", fill = NA, size = 6, shape = 21
    ) +
    ggplot2::geom_line(size = 1) +
    ggplot2::theme_bw() +
    ggplot2::theme(
      axis.title.x = ggplot2::element_text(vjust = -0.25, size = 14),
      axis.title.y = ggplot2::element_text(size = 14, angle = 90)
    )
}
#' Prints an optimal_k Object
#'
#' Prints an \code{optimal_k} object by drawing its plot (the \code{plot}
#' method for \code{optimal_k} objects).
#'
#' @param x An \code{optimal_k} object.
#' @param \ldots Ignored.
#' @return Invisibly returns \code{x}, as print methods conventionally do.
#' @method print optimal_k
#' @export
print.optimal_k <- function(x, ...) {
  print(graphics::plot(x))
  invisible(x)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi, I have a question: if the method is 'VEM', how do you calculate the harmonic mean, given that the iterations work differently? Thanks.