Last active
February 7, 2021 14:00
-
-
Save jmclawson/21c6a40c78fd66d708bec45d5c0b52e2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(tidyverse) | |
library(tidytext) | |
library(reshape2) | |
library(wordVectors) | |
##### Modeling a Corpus ##### | |
# This process for preparing and modeling the corpus is adapted from Women Writers Project's template_word2vec.Rmd | |
# These adaptations should allow for preservation of modeling settings to aid in replicability.
# After training the model, recall its setting parameters by exploring the object's attributes. | |
# Example 1: attributes(w2vModel)$window | |
# Example 2: attributes(w2vModel)$negative_samples | |
# Example 3: attributes(w2vModel)$vectors | |
# Read one plain-text file and collapse it into a single-row tibble.
#
# file       Path to the text file to read.
# path2file  Directory prefix removed from 'file' to build the 'filename'
#            column. It is matched literally, not as a regular expression.
#
# Returns a tibble with columns 'filename' and 'text', where 'text' joins the
# file's lines with single spaces.
readTextFiles <- function(file, path2file) {
  message(file)
  # what = "raw" is a character value, so scan() reads character fields.
  # With sep = "\n" each line becomes one element. strip.white trims the
  # edges of each line, and blank lines are skipped by scan()'s defaults.
  raw_lines <- scan(file,
                    sep = "\n",
                    what = "raw",
                    strip.white = TRUE)
  # A directory path is not a regex: fixed = TRUE keeps characters such as
  # "(", "+" or "." literal. sub() removes only the leading prefix, not
  # later repeats of it inside the path.
  short_name <- sub(path2file, "", file, fixed = TRUE)
  tibble(filename = short_name,
         text = raw_lines) %>%
    group_by(filename) %>%
    summarise(text = paste(text, collapse = " "))
}
# prep_model() builds the corpus files that word2vec trains on.
#
# Step 1: every file in 'source.dir' is read with readTextFiles() and
#         written to data/<model>.txt, one line per source file.
# Step 2: that file is passed to wordVectors::prep_word2vec(), which writes
#         data/<model>_cleaned.txt.
# Each step is skipped (with a message) when its output file already exists,
# so delete the generated files to rebuild them after the corpus changes.
#
# model          Base name for the generated files under "data/".
# lowercase      Passed to prep_word2vec(); FALSE retains uppercase characters.
# bundle_ngrams  Passed to prep_word2vec(); set it to some other value to
#                combine common phrases using underscores.
# source.dir     Folder holding the corpus; defaults to "data/<model>". If the
#                name of the folder with your corpus differs from your model
#                name, be sure to set source.dir.
prep_model <- function(model = "w2vModel",
                       lowercase = TRUE,
                       bundle_ngrams = 1,
                       source.dir = NULL) {
  modelInput <- paste0("data/", model, ".txt")
  modelCleaned <- paste0("data/", model, "_cleaned.txt")
  if (is.null(source.dir)) {
    source.dir <- paste0("data/", model)
  }
  if (!file.exists(modelInput)) {
    fileList <- list.files(source.dir, full.names = TRUE)
    # list.files() returns character(0) for a missing or empty folder.
    # Writing an empty corpus would create modelInput, and every later call
    # would then skip this step and silently reuse the empty file.
    if (length(fileList) == 0) {
      stop("No files found in '", source.dir, "'. ",
           "Check 'source.dir' (or the model name) before preparing '",
           modelInput, "'.",
           call. = FALSE)
    }
    combinedTexts <- tibble(filename = fileList) %>%
      group_by(filename) %>%
      do(readTextFiles(.$filename, source.dir))
    combinedTexts$text %>% write_lines(modelInput)
  } else {
    message("'", getwd(), "/", modelInput, "' already exists.")
  }
  if (!file.exists(modelCleaned)) {
    prep_word2vec(origin = modelInput,
                  destination = modelCleaned,
                  lowercase = lowercase,
                  bundle_ngrams = bundle_ngrams)
  } else {
    message("'", getwd(), "/", modelCleaned, "' already exists.")
  }
}
# For later recall, train_model() saves a metadata_<model>.Rdata file beside
# the <model>.bin file in the data folder.
#
# .model_metadata_value() fetches one setting from saved model metadata.
# It reads two formats:
#   * the list format written by train_model();
#   * the older named-numeric format, in which c() had split the seed vector
#     into elements named "seed1", "seed2", ..., so a lookup of "seed"
#     returned NA.
# Returns NULL when the setting is absent.
.model_metadata_value <- function(model_metadata, key) {
  if (key %in% names(model_metadata)) {
    return(model_metadata[[key]])
  }
  if (identical(key, "seed")) {
    seed_parts <- model_metadata[grepl("^seed[0-9]+$", names(model_metadata))]
    if (length(seed_parts) > 0) {
      return(as.integer(unlist(seed_parts, use.names = FALSE)))
    }
  }
  NULL
}

# train_model() trains a word2vec model on data/<model>_cleaned.txt (the file
# prep_model() writes), or loads data/<model>.bin if it already exists.
#
# On first training, the settings below and R's random-number state
# (.Random.seed) are attached to the model as attributes. They are also saved
# to data/metadata_<model>.Rdata, because this metadata is otherwise lost
# after the first run. When loading an existing .bin, any saved settings are
# re-attached. In both cases:
#   * the model is assigned to a global object named by 'model';
#   * the model is returned invisibly.
# Recall a setting with, e.g., attributes(w2vModel)$window.
# NOTE(review): whether train_word2vec() draws on R's RNG is not visible here,
# so the stored seed may not be enough to reproduce a model.
#
# vectors, window, iter, negative_samples, threads: passed to
# wordVectors::train_word2vec(). 'threads' is not stored in the metadata.
train_model <- function(model = "w2vModel",
                        vectors = 100,
                        window = 6,
                        iter = 10,
                        negative_samples = 15,
                        threads = 3) {
  modelBin <- paste0("data/", model, ".bin")
  modelCleaned <- paste0("data/", model, "_cleaned.txt")
  meta_filename <- paste0("data/metadata_", model, ".Rdata")
  setting_names <- c("vectors", "window", "iter", "negative_samples", "seed")

  if (!file.exists(modelBin)) {
    # .Random.seed lives in the global environment. It is created on first
    # RNG use, so initialise it if nothing has used the RNG yet.
    if (!exists(".Random.seed", envir = globalenv(), inherits = FALSE)) {
      set.seed(NULL)
    }
    modelSeed <- get(".Random.seed", envir = globalenv())
    the_model <- train_word2vec(
      modelCleaned,
      output_file = modelBin,
      vectors = vectors,
      threads = threads,
      window = window,
      iter = iter,
      negative_samples = negative_samples
    )
    # Use a list rather than c(). c() would flatten the multi-element seed
    # into "seed1", "seed2", ..., so "seed" could not be read back later.
    model_metadata <- list(vectors = vectors,
                           window = window,
                           iter = iter,
                           negative_samples = negative_samples,
                           seed = modelSeed)
    save(model_metadata, file = meta_filename)
  } else {
    the_model <- read.vectors(modelBin)
    model_metadata <- NULL
    if (file.exists(meta_filename)) {
      load(file = meta_filename) # restores 'model_metadata'
    }
  }

  # This metadata gets lost after the first run, so attach it to the object.
  if (!is.null(model_metadata)) {
    for (setting in setting_names) {
      attr(the_model, setting) <- .model_metadata_value(model_metadata, setting)
    }
  }
  # Save this model in a global object (assign() returns it invisibly).
  assign(model, the_model, envir = .GlobalEnv)
}
##### Managing the Results ##### | |
# The get_siml() function returns the cosine similarity between one word 'x'
# and each word in the vector 'y' under model 'wem'. Higher values mean the
# words are closer. Results are rounded to 9 decimal places and named by 'y'.
# Example: get_siml(w2vModel, "salty", c("food", "ocean", "attitude", "air"))
get_siml <- function(wem, x, y) {
  # Look up the anchor word's vector once rather than once per word in 'y'.
  x_vector <- wem[[x]]
  sapply(y, function(z) {
    cosineSimilarity(x_vector, wem[[z]]) %>%
      round(9)
  })
}
# The make_siml_matrix() function returns a matrix of cosine similarities
# between the words in 'x' and the words in 'y' under model 'wem'.
#
# Layout of the result:
#   * rows correspond to the words in 'y';
#   * columns correspond to the words in 'x'.
#
# Labels:
#   * Labels for 'y' pass through read.table()'s check.names (make.names()),
#     so an apostrophe becomes "." and duplicates are made unique.
#     cosine_heatmap() and cosine_bars() can restore apostrophes with
#     dot.to.apost = TRUE.
#   * Every occurrence of 'prefix' (default "hue_") is removed from both sets
#     of labels. Use prefix = NULL to keep the labels unchanged.
#
# Example: make_siml_matrix(w2vModel, c("salty", "sweet", "fresh"), c("food", "ocean", "attitude", "air"))
make_siml_matrix <- function(wem, x, y, prefix = "hue_") {
  # Empty data frame with one double column per word in 'y'.
  dis_col <- read.table(text = "",
                        colClasses = "double",
                        col.names = y)
  for (each in x) {
    # Adds (or, for a repeated word, overwrites) the row named 'each'.
    dis_col[each, ] <- wem %>%
      get_siml(each, y) %>%
      data.frame() %>%
      t()
  }
  if (!is.null(prefix) && nzchar(prefix)) {
    rownames(dis_col) <- gsub(prefix, "", rownames(dis_col), fixed = TRUE)
    colnames(dis_col) <- gsub(prefix, "", colnames(dis_col), fixed = TRUE)
  }
  as.matrix(dis_col) %>% t()
}
# The scale_matrix() function amplifies the signal in each row and column of
# a matrix.
#
# How it scales:
#   * every row is rescaled from 0 (its minimum) to 1 (its maximum);
#   * every column is rescaled the same way;
#   * the two rescaled matrices are added together.
# Missing values are ignored when finding minima and maxima. A row or column
# whose values are all equal contributes NaN.
#
# With diagonal = FALSE, cells exactly equal to 1 (such as a word's similarity
# to itself) are set to NA before scaling.
#
# The result keeps the dimensions and dimnames of 'x'. Single-row and
# single-column matrices are supported.
# Example 1: scale_matrix(my_matrix)
# Example 2: make_siml_matrix(w2vModel, my_adj, my_nouns) %>% scale_matrix()
scale_matrix <- function(x, diagonal = TRUE) {
  x <- as.matrix(x)
  if (!diagonal) {
    x[which(x == 1)] <- NA
  }
  # Rescale a vector to [0, 1]. min()/max() of an all-NA vector warn and
  # return +/-Inf; those warnings are suppressed (the result is then NA).
  rescale_01 <- function(v) {
    lo <- suppressWarnings(min(v, na.rm = TRUE))
    hi <- suppressWarnings(max(v, na.rm = TRUE))
    (v - lo) / (hi - lo)
  }
  # Fill copies of 'x' by index so a one-row or one-column matrix keeps
  # its shape (apply() would collapse it to a vector).
  by_row <- x
  for (i in seq_len(nrow(x))) {
    by_row[i, ] <- rescale_01(x[i, ])
  }
  by_col <- x
  for (j in seq_len(ncol(x))) {
    by_col[, j] <- rescale_01(x[, j])
  }
  by_row + by_col
}
##### Mapping the Proximities ##### | |
# cosine_heatmap() draws a heatmap of a similarity matrix, such as the output
# of make_siml_matrix(). Rows of 'input' run along the y axis and columns
# along the x axis.
# Example 1: cosine_heatmap(make_siml_matrix(w2vModel, my_adj, my_nouns))
# Example 2: make_siml_matrix(w2vModel, my_words, my_words) %>% cosine_heatmap()
#
# Display:
#   round         Decimal places shown in each tile's label.
#   legend        Show the fill legend. NULL hides it when 'values' are
#                 printed and shows it otherwise.
#   values        Print each tile's value.
#   colorset      "red" gives a blue-white-red gradient on [-1, 1]. Any other
#                 value is used as a viridis option on [0, 1].
#   alpha         Fill alpha for the viridis scale. The default is 0.75 when
#                 values are printed and 1 otherwise.
#   top.down      TRUE draws the first row at the top.
#   dot.to.apost  Replace "." with "'" in labels, undoing make.names()
#                 mangling of apostrophes.
#
# Selecting and ordering rows and columns:
#   redundant     FALSE keeps only the lower triangle and diagonal (for
#                 matrices whose row and column labels match).
#   sort.y        Order rows by decreasing median.
#   sort.x        Order columns by decreasing median.
#   limit.y       Keep at most the first N rows (after sorting).
#   limit.x       Keep at most the first N columns (after sorting).
#   omit.y        Regex patterns; matching rows are dropped.
#   sort.twice    Re-sort rows and columns by median after all filtering.
#                 Defaults to TRUE for single-column input when sort.y is TRUE.
#
# Amplifying:
#   amplify       Plot scale_matrix() output instead of raw similarities.
#   diagonal      Passed to scale_matrix() when amplify is TRUE.
cosine_heatmap <- function(input,
                           round = 2,
                           legend = NULL,
                           values = TRUE,
                           redundant = TRUE,
                           sort.y = TRUE,
                           sort.x = TRUE,
                           limit.y = NA,
                           limit.x = NA,
                           omit.y,
                           top.down = TRUE,
                           amplify = FALSE,
                           diagonal = TRUE,
                           sort.twice = NULL,
                           dot.to.apost = TRUE,
                           colorset = "viridis",
                           alpha = NULL) {
  # A single column is sorted in long format (sort.twice) instead.
  if (ncol(input) == 1) {
    if (is.null(sort.twice) && sort.y) {
      sort.twice <- TRUE
    }
    sort.x <- FALSE
    sort.y <- FALSE
  }
  if (is.null(alpha)) {
    alpha <- if (values) 0.75 else 1
  }
  if (is.null(sort.twice)) {
    sort.twice <- FALSE
  }
  if (!identical(rownames(input), colnames(input)) && !redundant) {
    redundant <- TRUE
    message("Values for 'x' and 'y' don't match. Setting 'redundant' to TRUE.")
  }
  if (amplify) {
    values <- FALSE
    message("Amplifying a heatmap converts measures from cosine similarity, which is meaningful, to something based wholly on terms from a single row and column. Because these values lose meaning and are useful only for comparison, they will not be displayed.")
  }

  the_matrix <- input
  # Subsets below use drop = FALSE so one remaining row or column stays a
  # matrix for melt(). order() keeps rows/columns whose median is NA (placed
  # last), whereas sort() would drop them.
  if (sort.y) {
    row_medians <- apply(the_matrix, 1, median)
    the_matrix <- the_matrix[order(row_medians, decreasing = TRUE, na.last = TRUE), ,
                             drop = FALSE]
  }
  if (!missing(omit.y)) {
    not_i <- grep(paste0(omit.y, collapse = "|"), rownames(the_matrix))
    # x[-integer(0), ] would remove every row, so subset only on a match.
    if (length(not_i) > 0) {
      the_matrix <- the_matrix[-not_i, , drop = FALSE]
    }
  }
  if (!is.na(limit.y)) {
    keep_rows <- seq_len(min(limit.y, nrow(the_matrix)))
    the_matrix <- the_matrix[keep_rows, , drop = FALSE]
  }
  if (sort.x) {
    col_medians <- apply(the_matrix, 2, median)
    the_matrix <- the_matrix[, order(col_medians, decreasing = TRUE, na.last = TRUE),
                             drop = FALSE]
  }
  if (amplify) {
    the_matrix <- scale_matrix(the_matrix, diagonal)
  }
  if (!redundant) {
    the_matrix[upper.tri(the_matrix)] <- NA
    the_matrix <- melt(the_matrix, na.rm = TRUE)
  } else {
    the_matrix <- melt(the_matrix)
  }
  if (dot.to.apost) {
    # Relabel each level in place, so level order (and the sorting above)
    # is preserved.
    restore_apostrophes <- function(f) {
      f <- factor(f)
      factor(f,
             levels = levels(f),
             labels = gsub("\\.", "'", levels(f)),
             ordered = TRUE)
    }
    the_matrix$Var1 <- restore_apostrophes(the_matrix$Var1)
    the_matrix$Var2 <- restore_apostrophes(the_matrix$Var2)
  }
  if (!is.na(limit.x)) {
    x_levels <- levels(the_matrix$Var2)
    kept_cols <- x_levels[seq_len(min(limit.x, length(x_levels)))]
    the_matrix <- filter(the_matrix, Var2 %in% kept_cols)
  }
  if (sort.twice) {
    col_order <- the_matrix %>%
      group_by(Var2) %>%
      summarise(value = median(value),
                .groups = 'drop') %>%
      arrange(desc(value)) %>%
      .$Var2 %>%
      as.character()
    row_order <- the_matrix %>%
      group_by(Var1) %>%
      summarise(value = median(value),
                .groups = 'drop') %>%
      arrange(desc(value)) %>%
      .$Var1 %>%
      as.character()
    the_matrix$Var2 <- factor(the_matrix$Var2, levels = col_order)
    the_matrix$Var1 <- factor(the_matrix$Var1, levels = row_order)
  }

  if (!top.down) {
    the_plot <- ggplot(the_matrix, aes(x = Var2, y = Var1, fill = value))
  } else {
    # Reverse the row order so the first row is drawn at the top.
    the_plot <- ggplot(the_matrix,
                       aes(x = Var2,
                           y = reorder(Var1, desc(Var1)),
                           fill = value))
  }
  the_plot <- the_plot +
    geom_tile(color = "white")
  if (amplify) {
    the_plot <- the_plot +
      scale_fill_gradient2(low = "blue",
                           high = "red",
                           mid = "white",
                           midpoint = 1,
                           limit = c(0, 2),
                           name = element_blank(),
                           breaks = c(0, 2),
                           labels = c("far", "near"))
  } else if (colorset == "red") {
    the_plot <- the_plot +
      scale_fill_gradient2(low = "blue",
                           high = "red",
                           mid = "white",
                           midpoint = 0,
                           limit = c(-1, 1),
                           name = "Similarity")
  } else {
    the_plot <- the_plot +
      scale_fill_viridis_c(option = colorset,
                           limits = c(0, 1),
                           alpha = alpha,
                           name = "Similarity")
  }
  the_plot <- the_plot +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 45,
                                     vjust = 1,
                                     hjust = 1),
          panel.grid.major = element_blank(),
          panel.border = element_blank(),
          panel.background = element_blank(),
          axis.ticks = element_blank()) +
    labs(x = element_blank(),
         y = element_blank())
  if (values) {
    # Drop the leading zero (e.g. "0.85" -> ".85", "-0.12" -> "-.12").
    the_plot <- the_plot +
      geom_text(aes(label = format(round(value,
                                         round),
                                   nsmall = round) %>%
                      gsub(pattern = "^[ ]?(-?)0.",
                           replacement = "\\1.",
                           x = .)),
                color = "black")
    if (is.null(legend)) {
      legend <- FALSE
    }
  }
  if (is.null(legend)) {
    legend <- TRUE
  }
  if (!legend) {
    the_plot <- the_plot +
      guides(fill = "none")
  }
  if ((top.down && redundant) || (!top.down && !redundant)) {
    the_plot <- the_plot +
      scale_x_discrete(position = "top") +
      theme(axis.text.x = element_text(hjust = 0))
  }
  the_plot
}
# The amplified_heatmap() function builds a similarity matrix with
# make_siml_matrix(), amplifies it with scale_matrix() and draws it as a
# heatmap. It merely amplifies signals it finds; it doesn't validate them.
#
# Arguments:
#   labeled    "title" gives a descriptive title. "simple" uses only the
#              model name. Anything else labels the axes with the 'x' and
#              'y' expressions.
#   legend     Show the "far"/"near" fill legend.
#   diagonal   Passed to scale_matrix(). Forced to TRUE when redundant is
#              FALSE.
#   redundant  FALSE hides the lower triangle; it only applies when x and y
#              are identical.
#
# Example 1: amplified_heatmap(w2vModel, my_adj, my_nouns)
# Example 2: amplified_heatmap(w2vModel, my_words, my_words)
amplified_heatmap <- function(wem, x, y,
                              labeled = "title",
                              legend = TRUE,
                              diagonal = TRUE,
                              redundant = TRUE) {
  # Labels come from the caller's expressions. deparse() can split a long
  # expression into several strings, so join them into one label.
  expr_label <- function(expr) paste(trimws(deparse(expr)), collapse = " ")
  wem_label <- expr_label(substitute(wem))
  x_label <- expr_label(substitute(x))
  y_label <- expr_label(substitute(y))

  if (!identical(x, y) && !redundant) {
    redundant <- TRUE
    message("Values for 'x' and 'y' don't match. Setting 'redundant' to TRUE.")
  }
  if (!redundant) {
    diagonal <- TRUE
  }
  the_matrix <- make_siml_matrix(wem, x, y) %>%
    scale_matrix(diagonal)
  if (!redundant) {
    the_matrix[lower.tri(the_matrix)] <- NA
    the_matrix <- melt(the_matrix, na.rm = TRUE)
  } else {
    the_matrix <- melt(the_matrix)
  }

  plot_labels <- if (labeled == "title") {
    labs(title = paste0("Comparing '", y_label,
                        "' by '", x_label,
                        "' in ", wem_label),
         x = element_blank(),
         y = element_blank())
  } else if (labeled == "simple") {
    labs(title = wem_label,
         x = element_blank(),
         y = element_blank())
  } else {
    labs(title = wem_label,
         x = x_label,
         y = y_label)
  }

  the_plot <- ggplot(the_matrix,
                     aes(x = Var2,
                         y = reorder(Var1, desc(Var1)),
                         fill = value)) +
    geom_tile(color = "white") +
    scale_fill_gradient2(low = "blue",
                         high = "red",
                         mid = "white",
                         midpoint = 1,
                         limit = c(0, 2),
                         name = element_blank(),
                         breaks = c(0, 2),
                         labels = c("far", "near")) +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 45,
                                     vjust = 1,
                                     hjust = 1),
          panel.grid.major = element_blank(),
          panel.border = element_blank(),
          panel.background = element_blank(),
          axis.ticks = element_blank()) +
    plot_labels
  if (!legend) {
    the_plot <- the_plot +
      guides(fill = "none")
  }
  the_plot
}
# cosine_bars() draws horizontal bar charts of cosine similarities.
#
# 'input' is either:
#   * a data frame (or tibble) such as the output of wordVectors' closest_to().
#     The first column holds words and the second, named
#     "similarity to <term>", holds their similarities. One chart is drawn.
#   * a matrix such as the output of make_siml_matrix(). One facet is drawn
#     per column, with a bar for each row.
#
# Arguments:
#   round         Decimal places in bar labels.
#   sort.y        Order bars by similarity (for matrices, by each word's
#                 median).
#   sort.x        For matrices, reorder columns by decreasing median. With
#                 dot.to.apost = TRUE the facets follow that order.
#   wrap_cols     Number of facet columns.
#   dot.to.apost  Replace "." with "'" in labels, undoing make.names()
#                 mangling.
#   force.width   Fix the x axis to [0, 1].
#   colorset      "red" for a white-to-red fill; otherwise a viridis option.
#   alpha         Bar opacity; default 1.
cosine_bars <- function(input, round = 2,
                        sort.y = TRUE,
                        sort.x = TRUE,
                        wrap_cols = NULL,
                        dot.to.apost = TRUE,
                        force.width = TRUE,
                        colorset = "viridis",
                        alpha = NULL) {
  if (is.null(alpha)) {
    alpha <- 1
  }
  # inherits() also accepts tibbles, whose first class is "tbl_df".
  if (inherits(input, "data.frame")) {
    input$compare <- gsub(pattern = "similarity to ",
                          replacement = "",
                          x = colnames(input)[2])
    colnames(input) <- c("word", "value", "compare")
    input <- input %>% select(word, compare, value)
    if (sort.y) {
      input$word <-
        factor(input$word,
               levels = input$word[order(input$value,
                                         decreasing = FALSE)])
    }
  } else if (inherits(input, "matrix")) {
    if (ncol(input) > 1 && sort.x) {
      col_medians <- apply(input, 2, median)
      # drop = FALSE keeps a one-row matrix from collapsing to a vector.
      input <- input[, order(col_medians, decreasing = TRUE, na.last = TRUE),
                     drop = FALSE]
    }
    input <- input %>%
      data.frame() %>%
      rownames_to_column("rn") %>%
      pivot_longer(-rn)
    colnames(input) <- c("word", "compare", "value")
    if (sort.y) {
      input <- input %>%
        group_by(word) %>%
        mutate(med_value = median(value)) %>%
        ungroup() %>%
        arrange(med_value)
      input$word <- factor(input$word,
                           levels = unique(input$word))
      input <- input %>% select(word, compare, value)
    }
  } else {
    stop("'input' must be a data frame (such as the output of closest_to()) ",
         "or a matrix (such as the output of make_siml_matrix()).",
         call. = FALSE)
  }
  if (dot.to.apost) {
    # Convert to a factor first so that plain character words
    # (sort.y = FALSE) also have levels to relabel. An existing factor keeps
    # its level order.
    word_levels <- levels(factor(input$word))
    input$word <- factor(input$word,
                         levels = word_levels,
                         labels = gsub("\\.", "'", word_levels),
                         ordered = TRUE)
    input$compare <- gsub("\\.", "'", input$compare)
    input$compare <- factor(input$compare,
                            levels = unique(input$compare),
                            ordered = TRUE)
  }
  the_plot <- ggplot(input,
                     aes(y = word,
                         x = value,
                         fill = value)) +
    geom_col(aes(color = ifelse(round(value, 1) < 0.1, "gray", "blank")),
             alpha = alpha,
             show.legend = FALSE) +
    # Labels sit just outside short bars and just inside long ones.
    geom_text(aes(x = ifelse(value <= 0.9, value + 0.01, value - 0.01),
                  hjust = ifelse(value <= 0.9, 0, 1),
                  label = format(round(value,
                                       round),
                                 nsmall = round) %>%
                    gsub(pattern = "^[ ]?(-?)0.",
                         replacement = "\\1.",
                         x = .),
                  color = ifelse(round(value, 1) > 0.9, "white", "black")),
              show.legend = FALSE) +
    scale_color_manual(values = c(blank = "white",
                                  white = "black",
                                  black = "black",
                                  gray = "#f5f5f5")) +
    labs(x = element_blank(),
         y = element_blank()) +
    theme_minimal()
  if (force.width) {
    the_plot <- the_plot +
      scale_x_continuous(limits = c(0, 1),
                         expand = expansion(mult = c(0, 0)))
  } else {
    the_plot <- the_plot +
      scale_x_continuous(expand = expansion(mult = c(0, 0.08)))
  }
  the_plot <- the_plot +
    theme(panel.grid = element_blank(),
          axis.text.x = element_blank())
  if (colorset == "red") {
    the_plot <- the_plot +
      scale_fill_gradient(low = "white",
                          high = "red",
                          limit = c(0, 1),
                          name = "Similarity")
  } else {
    the_plot <- the_plot +
      scale_fill_viridis_c(option = colorset,
                           limits = c(0, 1),
                           alpha = 0.8)
  }
  if (length(unique(input$compare)) > 1) {
    if (!is.null(wrap_cols)) {
      the_plot <- the_plot +
        facet_wrap("compare", ncol = wrap_cols)
    } else {
      the_plot <- the_plot +
        facet_wrap("compare")
    }
  } else {
    the_plot <- the_plot +
      ggtitle(unique(input$compare))
  }
  the_plot
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This code builds on a workflow shared by the Women Writers Project, which relies on Benjamin Schmidt's `wordVectors` package. To use these functions in R, paste the following line into the console or a code chunk: `devtools::source_gist("21c6a40c78fd66d708bec45d5c0b52e2")`. Alternatively, copy the functions from the gist and save them locally to run them from your machine.

- Prepare a corpus with `prep_model(model="YourModelName")`. If the folder holding your corpus is named differently from your model, set the `source.dir` parameter. Set the `bundle_ngrams` parameter to some other value to combine common phrases using underscores, and set `lowercase=FALSE` to retain uppercase characters. These settings are passed along to the `prep_word2vec()` function from the `wordVectors` package.
- Train a model with `train_model(model="YourModelName")`. The default settings are `vectors=100, window=6, iter=10, negative_samples=15, threads=3`, but each of these can be changed within the `train_model()` call. These settings are passed along to the `train_word2vec()` function from the `wordVectors` package.
- Recall a trained model's settings from its attributes, for example `attributes(YourModelName)$window`.
- Build a matrix of cosine similarities with `make_siml_matrix()`, for example with this command: `make_siml_matrix(YourModelName, c("red","green","blue","yellow"), c("air","water","earth","fire"))`.
- Visualize a bar chart by passing results from the `closest_to()` function (from the `wordVectors` package) to the `cosine_bars()` function. Visualize a set of bar charts by passing results from `make_siml_matrix()`.
- Visualize a heatmap by passing results from `make_siml_matrix()` to `cosine_heatmap()`.

Everything is explained further in a blog post with examples here: https://jmclawson.net/blog/posts/word-vector-utilities/
Updates are explained in a blog post with examples here: https://jmclawson.net/blog/posts/updates-to-word-vector-utilities/