Created
May 4, 2023 13:15
-
-
Save jmclawson/640042f2d679bcef1d20cf8056a66acd to your computer and use it in GitHub Desktop.
Functions for building a topic model and exploring it. Visualizations include document-level distributions (static and interactive), word distributions per topic, and topic word clouds.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(wordcloud) | |
library(topicmodels) | |
library(plotly) | |
# Moves a table of texts through the necessary steps of
# preparation and then builds a topic model from it. The
# function applies these steps:
#   1. unnests the text table into one word per row via
#      `unnest_without_caps()` (defined outside this file;
#      presumably it also drops proper nouns, i.e. words
#      that only appear capitalized -- confirm there)
#   2. divides each text, identified by the `doc_id`
#      column, into same-sized chunks of `sample_size`
#      words and renames every row "<doc_id>_<chunk number>"
#   3. removes stop words
#   4. counts word frequencies for each chunk
#   5. converts the table of frequencies into a document
#      term matrix
#   6. builds a topic model with `k` topics
#
# Arguments:
#   df           table of texts, passed to
#                `unnest_without_caps()`
#   doc_id       unquoted column identifying each text
#   sample_size  number of words per chunk (default 1000)
#   k            number of topics (default 15)
#   seed         optional integer handed to the Gibbs
#                sampler for reproducible models; NULL
#                (the default) keeps the sampler's own
#                default seeding
#
# Returns the fitted model produced by `LDA()`.
make_topic_model <- function(
    df,
    doc_id = title,
    sample_size = 1000,
    k = 15,
    seed = NULL) {
  stopifnot(
    is.numeric(sample_size),
    length(sample_size) == 1,
    sample_size > 0
  )

  # Numbers the rows of each document into consecutive
  # chunks of `size` rows (`set_id`). With `set_min`, rows
  # of documents having `set_min` rows or fewer are dropped.
  # With `collapse_cols`, the document column and `set_id`
  # are united into one "<doc_id>_<set_id>" column.
  set_doc_samples <- function(
      df,
      size = 1000,
      doc_id = title,
      set_min = NULL,
      collapse_cols = TRUE) {
    df <- df |>
      group_by({{ doc_id }}) |>
      mutate(set_id = ceiling(row_number() / size)) |>
      ungroup()

    if (!is.null(set_min)) {
      df <- df |>
        group_by({{ doc_id }}) |>
        mutate(set_count = n()) |>
        filter(set_count > set_min) |>
        # do not leak the grouping to later steps
        ungroup()
    }

    if (collapse_cols) {
      df <- df |>
        unite({{ doc_id }}, {{ doc_id }}, set_id)
    }

    df
  }

  lda_control <- list(best = TRUE, initialize = "random")
  if (!is.null(seed)) {
    lda_control$seed <- as.integer(seed)
  }

  df |>
    unnest_without_caps() |>
    set_doc_samples(doc_id = {{ doc_id }},
                    size = sample_size) |>
    # explicit key: silences the join message and keeps any
    # other shared column name out of the match
    anti_join(get_stopwords(), by = "word") |>
    count({{ doc_id }}, word, sort = TRUE) |>
    rename(
      document = {{ doc_id }},
      term = word,
      value = n) |>
    cast_dtm(document, term, value) |>
    LDA(k = k,
        method = "Gibbs",
        control = lda_control)
}
# Draws, for every document in a topic model, a stacked
# area chart of its most prominent topics across the
# length of the text, one facet per document.
#
# Arguments:
#   lda           topic model; the expression the caller
#                 wrote for it is reused as plot title and
#                 file name stem
#   top_n, omit, smooth
#                 passed on to `prep_document_topics()`
#   direct_label  TRUE writes topic numbers at the right
#                 end of each facet and hides the legend;
#                 FALSE shows a legend instead
#   title         add the model's name as the plot title?
#   save          write the plot to `savedir`?
#   saveas        file extension handed to `ggsave()`
#   savedir       output directory, created when missing
#
# Returns the ggplot object.
visualize_document_topics <- function(
    lda,
    top_n = 4,
    direct_label = TRUE,
    title = TRUE,
    save = TRUE,
    saveas = "png",
    savedir = "plots",
    omit = NULL,
    smooth = TRUE) {
  # deparse() can return several strings for a long
  # expression; collapse so title and file name stay scalar
  df_string <- paste(deparse(substitute(lda)), collapse = "")

  # Builds the faceted area chart from a table with the
  # columns `doc_id`, `set`, `topic` and `n`.
  plot_topic_parts <- function(df,
                               direct_label = TRUE) {
    plot <- df |>
      ggplot(aes(x = set, y = n))

    if (direct_label) {
      # One label per topic, taken from each document's
      # last chunk: cumulative `n` gives the height, and
      # successive labels are pushed right in steps of 3000
      # (plus 800 in the aesthetic below).
      label_rows <- df |>
        group_by(doc_id) |>
        filter(set == max(set)) |>
        arrange(desc(topic)) |>
        mutate(n = cumsum(n),
               set = set + (3000 * (row_number() - 1))) |>
        ungroup()

      plot <- plot +
        geom_area(aes(fill = as.factor(topic)),
                  show.legend = FALSE) +
        geom_text(
          data = label_rows,
          aes(x = set + 800,
              label = topic,
              color = as.factor(topic)),
          show.legend = FALSE,
          hjust = 0
        )
    } else {
      plot <- plot +
        geom_area(aes(fill = as.factor(topic)),
                  show.legend = TRUE)
    }

    plot +
      facet_wrap(~ doc_id,
                 strip.position = "top",
                 ncol = 1,
                 labeller = labeller(groupwrap = label_wrap_gen(6))) +
      scale_x_continuous(expand = expansion(c(0, 0.1)),
                         labels = scales::label_comma()) +
      theme_minimal() +
      labs(y = element_blank(),
           x = "words",
           fill = "topic") +
      theme(plot.title.position = "plot",
            strip.background = element_rect(fill = NA, color = NA),
            strip.text = element_text(colour = "black",
                                      hjust = 0),
            panel.grid.major.x = element_blank(),
            panel.grid.minor = element_blank()) +
      scale_fill_viridis_d(option = "turbo") +
      scale_color_viridis_d(option = "turbo") +
      scale_y_continuous(
        position = "right",
        labels = scales::label_percent()) +
      # labels sit beyond the panel edge, so do not clip
      coord_cartesian(clip = 'off')
  }

  plot <- lda |>
    prep_document_topics(top_n = top_n,
                         omit = omit,
                         smooth = smooth) |>
    plot_topic_parts(direct_label = direct_label)

  if (title) {
    plot <- plot + ggtitle(df_string)
  }

  if (save) {
    if (!dir.exists(savedir)) {
      dir.create(savedir, recursive = TRUE)
    }

    filename <- file.path(
      savedir,
      paste0(df_string, " - document topics", ".", saveas))

    # formats other than pdf/png get an explicit white
    # background; pdf and png keep ggsave()'s default
    if (!saveas %in% c("pdf", "png")) {
      ggsave(filename,
             plot = plot,
             dpi = 300,
             bg = "white")
    } else {
      ggsave(filename, plot = plot, dpi = 300)
    }
  }

  plot
}
# Turns a topic model into the long table used by the
# document-level plots: one row per document, chunk and
# topic, limited to each document's strongest topics.
#
# Arguments:
#   lda          topic model whose document names have the
#                form "<title>_<chunk number>"
#   top_n        number of topic ranks kept per document,
#                ranked by mean gamma across its chunks
#                (ties share a rank)
#   omit         optional vector of topic numbers to drop
#                before ranking
#   smooth       replace `n` by the mean of the current and
#                the next two chunks of the same document
#                and topic?
#   sample_size  words per chunk used when the model was
#                built; scales `set` into a word position
#
# Returns an ungrouped tibble with the columns
#   doc_id      document title
#   set         chunk number times `sample_size`
#   topic       topic number
#   n           gamma of the topic in that chunk (smoothed
#               when `smooth = TRUE`)
#   percentage  unsmoothed share of the topic among the
#               kept topics of that chunk
#   words       the topic's ten highest-beta terms
#   display     "topic <number>: <words>"
prep_document_topics <- function(
    lda,
    top_n = 4,
    omit = NULL,
    smooth = FALSE,
    sample_size = 1000) {
  doc_tops <- lda |>
    tidy(matrix = "gamma") |>
    # split at the LAST underscore only, so titles that
    # themselves contain "_" stay intact
    separate(document,
             c("title", "set"),
             sep = "_(?=[^_]*$)") |>
    mutate(set = as.integer(set)) |>
    group_by(title, topic) |>
    mutate(topic_mean = mean(gamma, na.rm = TRUE)) |>
    ungroup()

  if (!is.null(omit)) {
    doc_tops <- doc_tops |>
      filter(!topic %in% omit)
  }

  doc_tops <- doc_tops |>
    group_by(title) |>
    mutate(topic_rank = dense_rank(-topic_mean)) |>
    ungroup() |>
    # by default, n = 4 commonest topics
    filter(topic_rank <= top_n) |>
    mutate(doc_id = title) |>
    group_by(doc_id, set, topic) |>
    summarise(n = sum(gamma, na.rm = TRUE), .groups = "drop") |>
    # share within the chunk; grouping by topic as well
    # would leave one row per group and always yield 1
    group_by(doc_id, set) |>
    mutate(percentage = n / sum(n)) |>
    ungroup() |>
    mutate(set = set * sample_size)

  top_terms <- lda |>
    tidy() |>
    group_by(topic) |>
    arrange(desc(beta)) |>
    slice_head(n = 10) |>
    summarize(words = paste0(term, collapse = ", "))

  result <- doc_tops |>
    left_join(top_terms, by = "topic") |>
    mutate(display = paste0("topic ", topic, ": ", words))

  if (smooth) {
    result <- result |>
      # group by document AND topic so the window never
      # averages values belonging to different documents
      group_by(doc_id, topic) |>
      arrange(set) |>
      # rolling average across three sets (current + next
      # two); missing neighbours at the end are ignored
      mutate(n = rowMeans(
        cbind(n, lead(n), lead(n, n = 2L)),
        na.rm = TRUE)) |>
      ungroup()
  }

  result
}
# Interactive (plotly) version of the document-topic area
# chart: one facet row per document, with each topic's top
# words shown as a tooltip.
#
# Arguments:
#   df      topic model; the expression the caller wrote
#           for it is used as the optional plot title
#   top_n, omit, smooth
#           passed on to `prep_document_topics()`
#   title   add the model's name as the plot title?
#   height  passed to `ggplotly()`
#
# Returns the plotly object.
interactive_document_topics <- function(
    df,
    top_n = 4,
    title = FALSE,
    height = NULL,
    omit = NULL,
    smooth = TRUE) {
  # deparse() can return several strings for a long
  # expression; collapse so the title stays scalar
  df_string <- paste(deparse(substitute(df)), collapse = "")

  plot_data <- df |>
    prep_document_topics(top_n = top_n,
                         omit = omit,
                         smooth = smooth) |>
    # drop any grouping/rowwise state so the string work
    # below runs once over the whole column
    ungroup() |>
    mutate(
      # Shorten title to first word, dropping articles and prepositions
      doc_id = doc_id |>
        str_remove_all("^The\\b|^A\\b|^In the\\b|^In\\b|^To\\b") |>
        str_remove_all("^ ") |>
        strsplit(split = " ") |>
        # vapply keeps a character vector even for 0 rows
        vapply(`[`, character(1), 1) |>
        str_extract("[A-Za-z]+"),
      topic = as.factor(topic))

  plot <- plot_data |>
    ggplot(aes(x = set, y = n)) +
    # `text` is not a ggplot2 aesthetic; it is what
    # ggplotly(tooltip = "text") displays on hover
    geom_area(aes(fill = topic,
                  color = topic,
                  text = display),
              show.legend = FALSE) +
    facet_grid(doc_id ~ .,
               labeller = labeller(groupwrap = label_wrap_gen(6))) +
    scale_x_continuous(expand = expansion(c(0, 0.1)),
                       labels = scales::label_comma()) +
    theme_minimal() +
    labs(y = element_blank(),
         x = "words",
         fill = "topic") +
    theme(
      plot.title.position = "plot",
      strip.background = element_rect(fill = "white", color = "white"),
      strip.text = element_text(colour = "black",
                                hjust = 0),
      panel.grid.major.x = element_blank(),
      panel.grid.minor = element_blank()) +
    scale_fill_viridis_d(alpha = 0.8) +
    scale_color_viridis_d(alpha = 1) +
    scale_y_continuous(
      labels = scales::label_percent())

  if (title) {
    plot <- plot + ggtitle(df_string)
  }

  plot |>
    ggplotly(tooltip = "text", height = height) |>
    hide_legend() |>
    suppressWarnings()
}
# Draws horizontal bar charts of the highest-beta words
# for the chosen topics, one facet per topic.
#
# Arguments:
#   df           topic model; the expression the caller
#                wrote for it is reused as plot title and
#                file name stem
#   topics       topic numbers to show; facets follow this
#                order
#   top_n        number of words per topic
#   expand_bars  TRUE frees both axes per facet, so every
#                topic's bars use the full width; FALSE
#                frees only the word axis, keeping beta
#                comparable across topics
#   save         write a png to `savedir`?
#   savedir      output directory, created when missing
#
# Returns the ggplot object.
visualize_topic_bars <- function(
    df,
    topics,
    top_n = 10,
    expand_bars = TRUE,
    save = TRUE,
    savedir = "plots") {
  # deparse() can return several strings for a long
  # expression; collapse so title and file name stay scalar
  df_string <- paste(deparse(substitute(df)), collapse = "")

  top_words <- tidy(df) |>
    filter(topic %in% topics) |>
    mutate(topic = paste("topic", topic) |>
             factor(levels = paste("topic", topics))) |>
    group_by(topic) |>
    arrange(desc(beta)) |>
    slice_head(n = top_n) |>
    ungroup()

  facet_scales <- if (expand_bars) "free" else "free_y"

  plot <- top_words |>
    # reorder_within()/scale_y_reordered() sort the words
    # separately inside every facet
    ggplot(aes(y = reorder_within(term, beta, topic),
               x = beta)) +
    geom_col(aes(fill = topic),
             show.legend = FALSE) +
    scale_y_reordered() +
    labs(y = NULL,
         x = NULL,
         title = df_string) +
    theme_minimal() +
    theme(axis.text.x = element_blank(),
          panel.grid = element_blank()) +
    facet_wrap(~ topic, scales = facet_scales)

  if (save) {
    if (!dir.exists(savedir)) {
      dir.create(savedir, recursive = TRUE)
    }

    filename <- file.path(
      savedir,
      paste0(df_string,
             " - topics ",
             paste0(topics, collapse = ", "),
             ".png"))
    ggsave(filename, plot = plot)
  }

  plot
}
# Saves one word cloud png per topic and returns the
# images for inclusion in a knitr document.
#
# Arguments:
#   df       topic model; the expression the caller wrote
#            for it is used as the file name stem
#   topics   topic numbers to draw; NULL draws every topic
#            in the model
#   crop     trim the margins of the saved images with
#            `knitr::plot_crop()`?
#   savedir  output directory, created when missing; files
#            are named "<model> - topic <number>.png" and
#            overwrite existing files of the same name
#
# Returns the value of `knitr::include_graphics()` for all
# images written.
visualize_topic_wordcloud <- function(
    df,
    topics = NULL,
    crop = TRUE,
    savedir = "plots") {
  # deparse() can return several strings for a long
  # expression; collapse so file names stay scalar
  df_string <- paste(deparse(substitute(df)), collapse = "")

  # Draws the word cloud of one topic into `path`. The
  # device is closed even when wordcloud() fails.
  save_topic_wordcloud <- function(
      term_table,
      topic_id,
      path,
      count = 150) {
    topic_terms <- term_table[term_table$topic == topic_id, ]

    png(path, width = 12,
        height = 8, units = "in",
        res = 300)
    on.exit(dev.off(), add = TRUE)

    wordcloud(words = topic_terms$term,
              freq = topic_terms$beta,
              max.words = count,
              random.order = FALSE,
              scale = c(3, .3),
              rot.per = 0.2,
              colors = viridis::turbo(
                n = 9,
                direction = -1)[1:8])

    invisible(path)
  }

  term_table <- tidy(df)
  available <- unique(term_table$topic)

  if (is.null(topics)) {
    wanted <- available
  } else {
    wanted <- unique(topics)
    unknown <- setdiff(wanted, available)
    if (length(unknown) > 0) {
      warning("Topics not found in the model: ",
              paste(unknown, collapse = ", "),
              call. = FALSE)
    }
    wanted <- wanted[wanted %in% available]
  }

  if (length(wanted) == 0) {
    stop("No topics to draw.", call. = FALSE)
  }

  if (!dir.exists(savedir)) {
    dir.create(savedir, recursive = TRUE)
  }

  # Paths come from the topics actually drawn, so files are
  # written to and read from the same `savedir` and stale
  # images from earlier runs are never picked up.
  paths <- file.path(
    savedir,
    paste0(df_string, " - topic ", wanted, ".png"))

  for (i in seq_along(wanted)) {
    save_topic_wordcloud(term_table, wanted[[i]], paths[[i]])
  }

  if (crop) {
    # crop one file at a time, then include every image
    for (path in paths) {
      knitr::plot_crop(path)
    }
  }

  knitr::include_graphics(paths)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment