topic modeling in R
# Brian Abelson @brianabelson
# Harmony Institute
# December 5, 2012

# lda is a wrapper for lda.collapsed.gibbs.sampler in the "lda" package
# it fits topic models using latent dirichlet allocation
# it provides arguments for cleaning the input text and tuning the parameters of the model
# it also returns a lot of useful information about the topics/documents in a format that you can easily join back to your original data
# this allows you to easily model outcomes based on the distribution of topics within a collection of texts
lda <- function(
  # DATA #
  text,                      # a character vector of text documents
  ids = NULL,                # a vector of ids (to allow joining results to other variables). default is 1:N
  # CLEANING #
  lower_case = TRUE,         # logical; should the function make the text lower case?
  remove_stop_words = TRUE,  # logical; should the function remove stop words? NOTE: this will also make the text lower case
  stop_words_to_add = NULL,  # a character vector of stop words to add
  remove_numbers = TRUE,     # logical; should the function remove numbers?
  remove_punctuation = TRUE, # logical; should the function remove punctuation?
  remove_non_ascii = TRUE,   # logical; should the function remove non-ASCII characters?
  stem_words = FALSE,        # logical; should the function stem the words?
  char_range = c(2, 50),     # numeric vector of length two with low and high value of characters per word (inclusive!)
  min_word_count = 5,        # number of times a word/feature must occur in a text to be considered
  # MODEL PARAMETERS #
  n_topics = 10,             # number of topics to fit
  n_topic_words = 20,        # number of top topic words to return
  n_iter = 1000,             # number of iterations
  burnin = 100,              # number of initial iterations to ignore. the function adds burnin to n_iter
  alpha = 0.1,               # the scalar value of the dirichlet hyperparameter for topic proportions
  eta = 0.1,                 # the scalar value of the dirichlet hyperparameter for topic multinomials
  # OUTPUT #
  n_assignments = 3          # number of assignments to return (returned as ass_topic_a, ass_topic_b, ass_topic_c, etc.)
) {
  # LIBRARIES
  if (!require("tm")) {
    install.packages("tm")
    library("tm")
  }
  if (!require("lda")) {
    install.packages("lda")
    library("lda")
  }
  if (!require("plyr")) {
    install.packages("plyr")
    library("plyr")
  }
  if (!require("stringr")) {
    install.packages("stringr")
    library("stringr")
  }
  if (!require("Rstem")) {
    install.packages("Rstem", repos="http://www.omegahat.org/R", type="source")
    library("Rstem")
  }
  # start time (for calculating the time it takes for the function to run)
  start <- Sys.time()
  # gen id var if NULL
  if (is.null(ids)) {
    ids <- 1:length(text)
  }
  # strip urls from the raw text before computing stats / cleaning
  url_pattern <- '\\b(?:(?:https?|ftp|file)://|www\\.|ftp\\.)[-A-Z0-9+&@#/%=~_|$?!:,.]*[A-Z0-9+&@#/%=~_|$]'
  text <- gsub(url_pattern, "", text, ignore.case=TRUE, perl=TRUE)
  # META VARIABLES - RAW TEXT
  # total number of characters / features / unique features
  docStats <- function(x) {
    # length of document
    len <- nchar(x)
    # split words
    words <- str_trim(unlist(strsplit(x, " ")))
    words <- words[words != ""]
    # calculate average word length
    nchars <- laply(words, nchar)
    len_word <- mean(nchars)
    # count features
    n_feat <- length(words)
    n_unq_feat <- length(unique(words))
    # return stats
    return(data.frame(len, len_word, n_feat, n_unq_feat))
  }
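  # for instance, docStats("the quick brown fox") returns a one-row data frame with
  # len = 19, len_word = 4, n_feat = 4, n_unq_feat = 4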
  features_raw <- ldply(text, docStats)
  names(features_raw) <- paste0(names(features_raw), "_raw")
  # CLEAN THE INPUT TEXT #
  # convert text to corpus
  corpus <- Corpus(VectorSource(text))
  # standardize case
  if (lower_case) {
    corpus <- tm_map(corpus, tolower)
  }
  # remove stop words (this also lower cases the text so stop words can be matched)
  if (remove_stop_words) {
    corpus <- tm_map(corpus, tolower)
    print("removing stop words...")
    stop_words <- c(stopwords('english'), stop_words_to_add)
    corpus <- tm_map(corpus, removeWords, stop_words)
  }
  # remove numbers / punctuation / strip whitespace
  print("cleaning text...")
  if (remove_numbers) {
    corpus <- tm_map(corpus, removeNumbers)
  }
  if (remove_punctuation) {
    removePunct <- function(x) {
      gsub("[[:punct:]]", " ", x)
    }
    corpus <- tm_map(corpus, removePunct)
  }
  # remove non-ASCII characters
  if (remove_non_ascii) {
    removeNonASCII <- function(x) {
      iconv(x, "latin1", "ASCII", sub="")
    }
    corpus <- tm_map(corpus, removeNonASCII)
  }
  corpus <- tm_map(corpus, stripWhitespace)
  # filter out words longer than 255 characters - these will break the stemming function
  charFilter <- function(x) {
    words <- str_trim(unlist(strsplit(x, " ")))
    # keep only words with at most 255 characters
    nchars <- laply(words, nchar)
    clean_words <- words[which(nchars <= 255)]
    output <- paste(clean_words, collapse=" ")
    return(output)
  }
  corpus <- tm_map(corpus, charFilter)
  corpus <- tm_map(corpus, stripWhitespace)
  # stem words
  if (stem_words) {
    print("stemming words...")
    # generate stemming function
    wordStemmer <- function(x) {
      words <- str_trim(unlist(strsplit(x, " ")))
      words <- words[words != ""]
      # stem words
      stemmed_words <- wordStem(words)
      # collapse back into one blob
      output <- paste(stemmed_words, collapse=" ")
      return(output)
    }
    # run stemming function
    corpus <- tm_map(corpus, wordStemmer)
  }
  # filter out words that fall outside of the desired char_range
  charFilter2 <- function(x) {
    words <- str_trim(unlist(strsplit(x, " ")))
    nchars <- laply(words, nchar)
    clean_words <- words[which(nchars >= char_range[1] & nchars <= char_range[2])]
    output <- str_trim(paste(clean_words, collapse=" "))
    return(output)
  }
  corpus <- tm_map(corpus, charFilter2)
  # strip white space again for good measure
  corpus <- tm_map(corpus, stripWhitespace)
  # convert corpus back to character vector for lexicalizing
  text <- as.character(corpus)
  # META VARIABLES - CLEAN TEXT
  # total number of characters / features / unique features
  features_clean <- ldply(text, docStats)
  names(features_clean) <- paste0(names(features_clean), "_clean")
  # CREATE / FILTER LEXICON
  # lexicalize text
  print("lexicalizing text...")
  corpus <- lexicalize(text, sep=" ", count=1)
  # only keep words that appear at least min_word_count times
  N <- min_word_count
  keep <- corpus$vocab[word.counts(corpus$documents, corpus$vocab) >= N]
  # re-lexicalize, using this subsetted vocabulary
  documents <- lexicalize(text, lower=TRUE, vocab=keep)
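  # `documents` is now in the format the "lda" package expects: a list with one 2-row
  # integer matrix per document (row 1 = 0-indexed positions in `keep`, row 2 = counts)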
  # FIT TOPICS
  # gibbs sampling
  # K is the number of topics
  print("fitting topics...")
  K <- n_topics
  n_iter <- n_iter + burnin
  result <- lda.collapsed.gibbs.sampler(documents, K, keep, n_iter, alpha, eta)
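  # result$topics is a K x vocabulary matrix of word counts per topic;
  # result$document_sums is a K x n_docs matrix of topic counts per document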
  # PREPARE OUTPUT
  print("preparing output...")
  # top words by document (pass the alpha/eta arguments through rather than hard-coding them)
  predictions <- t(predictive.distribution(result$document_sums, result$topics, alpha, eta))
  document_words <- data.frame(top.topic.words(predictions, n_topic_words, by.score = TRUE))
  names(document_words) <- ids
  # top words by topic
  topic_words <- data.frame(top.topic.words(result$topics, num.words = n_topic_words, by.score = TRUE))
  names(topic_words) <- paste0("topic_", 1:K)
  # topics by documents stats
  raw <- as.data.frame(t(result$document_sums))
  names(raw) <- 1:K
  n_docs <- nrow(raw)
  topics <- data.frame(id = ids, matrix(0, nrow = n_docs, ncol = 2 * K))
  names(topics) <- c("id", paste0("n_topic_", 1:K), paste0("p_topic_", 1:K))
  # add assignment variables dynamically
  topic_ass_vars <- paste0("ass_topic_", letters[1:n_assignments])
  topics[, topic_ass_vars] <- 0
  # assign primary and secondary topic(s), get distribution of topics by document
  for (doc in 1:n_docs) {
    assignments <- as.numeric(names(sort(raw[doc, 1:K], decreasing = TRUE)))
    topics[doc, topic_ass_vars] <- assignments[1:n_assignments]
    topics[doc, grep("n_topic_[0-9]+", names(topics))] <- raw[doc, ]
    topics[doc, grep("p_topic_[0-9]+", names(topics))] <- raw[doc, ] / sum(raw[doc, ])
  }
  # add meta variables
  document_stats <- data.frame(topics, features_raw, features_clean)
  # CALCULATE JOB LENGTH
  end <- Sys.time()
  job_length <- round(difftime(end, start, units = "mins"), digits = 2)
  print(paste("lda finished at:", end))
  print(paste("job took:", job_length, "minutes"))
  # RETURN OUTPUT
  return(list(topic_words = topic_words,
              document_stats = document_stats,
              document_words = document_words))
}
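
# EXAMPLE USAGE (illustrative sketch, not part of the original function; `df` and its
# `id` / `body` columns are hypothetical placeholders for your own data):
#
#   fit <- lda(text = df$body, ids = df$id, n_topics = 15, n_iter = 2000)
#   fit$topic_words            # top n_topic_words terms for each of the K topics
#   fit$document_words         # top terms for each individual document
#   head(fit$document_stats)   # per-document topic counts, proportions, assignments + meta variables
#   # join the topic variables back to the original data for downstream modeling
#   df_topics <- merge(df, fit$document_stats, by = "id")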