Skip to content

Instantly share code, notes, and snippets.

@jmclawson
Last active February 7, 2021 14:00
Show Gist options
  • Save jmclawson/21c6a40c78fd66d708bec45d5c0b52e2 to your computer and use it in GitHub Desktop.
Save jmclawson/21c6a40c78fd66d708bec45d5c0b52e2 to your computer and use it in GitHub Desktop.
library(tidyverse)
library(tidytext)
library(reshape2)
library(wordVectors)
##### Modeling a Corpus #####
# This process for preparing and modeling the corpus is adapted from Women Writers Project's template_word2vec.Rmd
# These adaptations preserve the modeling settings to aid in replicability.
# After training the model, recall its setting parameters by exploring the object's attributes.
# Example 1: attributes(w2vModel)$window
# Example 2: attributes(w2vModel)$negative_samples
# Example 3: attributes(w2vModel)$vectors
# Read one corpus file and collapse it into a single row of text.
# Arguments:
#   file       path to the text file to read
#   path2file  directory prefix to strip from 'file' to form the 'filename'
#              column; matched literally (not as a regular expression)
# Returns a tibble with columns 'filename' and 'text', where 'text' is every
# line scan() read from the file joined with single spaces.
readTextFiles <- function(file, path2file) {
  message(file)
  # scan() with sep = "\n" yields one element per line, whitespace-trimmed.
  rawText <- paste(scan(file,
                        sep = "\n",
                        what = "raw",
                        strip.white = TRUE))
  # fixed = TRUE: a directory such as "data/C++" or "data/corpus (2020)" must
  # be removed literally, not interpreted as a regex.
  output <- tibble(filename = gsub(path2file,
                                   "",
                                   file,
                                   fixed = TRUE),
                   text = rawText) %>%
    group_by(filename) %>%
    summarise(text = paste(text, collapse = " "))
  output
}
# If the name of the folder with your corpus differs from your model name, be sure to set source.dir
# Build the training text for a model in two cached steps:
#   1. concatenate every file in 'source.dir' into data/<model>.txt, one
#      line per file (via readTextFiles());
#   2. clean that file into data/<model>_cleaned.txt with prep_word2vec(),
#      passing along 'lowercase' and 'bundle_ngrams'.
# Each step is skipped, with a message, when its output file already exists.
# 'source.dir' defaults to data/<model>.
prep_model <- function(model = "w2vModel",
                       lowercase = TRUE,
                       bundle_ngrams = 1,
                       source.dir = NULL){
  modelInput <- paste0("data/", model, ".txt")
  modelCleaned <- paste0("data/", model, "_cleaned.txt")
  if (is.null(source.dir)) {
    source.dir <- paste0("data/", model)
  }

  # Step 1: combine the raw corpus files.
  if (file.exists(modelInput)) {
    message("'", getwd(), "/", modelInput, "' already exists.")
  } else {
    corpus <- tibble(filename = list.files(source.dir, full.names = TRUE)) %>%
      group_by(filename) %>%
      do(readTextFiles(.$filename, source.dir))
    write_lines(corpus$text, modelInput)
  }

  # Step 2: clean the combined text for word2vec.
  if (file.exists(modelCleaned)) {
    message("'", getwd(), "/", modelCleaned, "' already exists.")
  } else {
    prep_word2vec(origin = modelInput,
                  destination = modelCleaned,
                  lowercase = lowercase,
                  bundle_ngrams = bundle_ngrams)
  }
}
# For later recall, this function saves a data/metadata_<model>.Rdata file beside data/<model>.bin
# Train (or reload) a word2vec model and expose it as a global object named
# by 'model'.
# If data/<model>.bin does not exist, trains on data/<model>_cleaned.txt with
# train_word2vec() and the given settings. It then attaches the settings and
# R's RNG state (.Random.seed) as attributes, and saves them to
# data/metadata_<model>.Rdata.
# If the .bin already exists, it is loaded with read.vectors(), and any saved
# metadata is re-attached. In that case the arguments other than 'model' are
# ignored.
# Returns (invisibly) the model object assigned into .GlobalEnv.
train_model <- function(model="w2vModel",
                        vectors=100,
                        window=6,
                        iter=10,
                        negative_samples=15,
                        threads=3){
  # .Random.seed lives in the global environment and only exists once the RNG
  # has been initialised, so create it if necessary before recording it.
  if (!exists(".Random.seed", envir = .GlobalEnv, inherits = FALSE)) {
    set.seed(NULL)
  }
  modelSeed <- get(".Random.seed", envir = .GlobalEnv, inherits = FALSE)
  modelBin <- paste0("data/", model, ".bin")
  modelCleaned <- paste0("data/", model, "_cleaned.txt")
  meta_filename <- paste0("data/metadata_", model, ".Rdata")
  setting_names <- c("vectors", "window", "iter", "negative_samples", "seed")

  # Fetch one setting from saved metadata. Current files store a list. Files
  # written by earlier versions store a flat numeric vector built with c(),
  # in which the seed vector was split into elements named seed1, seed2, ...
  read_setting <- function(metadata, setting) {
    if (is.list(metadata)) {
      return(metadata[[setting]])
    }
    if (setting == "seed") {
      seed_part <- metadata[grepl("^seed[0-9]*$", names(metadata))]
      if (length(seed_part) == 0) {
        return(NULL)
      }
      return(as.integer(unname(seed_part)))
    }
    if (setting %in% names(metadata)) metadata[[setting]] else NULL
  }

  if (!file.exists(modelBin)) {
    the_model <- train_word2vec(
      modelCleaned,
      output_file = modelBin,
      vectors = vectors,
      threads = threads,
      window = window,
      iter = iter,
      negative_samples = negative_samples
    )
    # The .bin file does not keep these settings, so attach them and store
    # them alongside. A list (not c()) keeps the seed vector intact under a
    # single "seed" name.
    model_metadata <- list(vectors = vectors,
                           window = window,
                           iter = iter,
                           negative_samples = negative_samples,
                           seed = modelSeed)
    for (setting in setting_names) {
      attr(the_model, setting) <- model_metadata[[setting]]
    }
    save(model_metadata, file = meta_filename)
  } else {
    the_model <- read.vectors(modelBin)
    if (file.exists(meta_filename)) {
      # Load into a private environment so nothing in this frame is clobbered.
      loaded <- new.env()
      load(file = meta_filename, envir = loaded)
      for (setting in setting_names) {
        attr(the_model, setting) <- read_setting(loaded$model_metadata, setting)
      }
    }
  }
  # Save this model in a global object
  assign(model, the_model, envir = .GlobalEnv)
}
##### Managing the Results #####
# The get_siml() function returns the cosine similarities between one word 'x' and each word in a vector 'y' for model 'wem'
# Example: get_siml(w2vModel, "salty", c("food", "ocean", "attitude", "air"))
# Cosine similarity between the vector for word 'x' and the vector for each
# word in 'y' in model 'wem', rounded to 9 decimal places.
# Returns the result of sapply() over 'y' (named by the words in 'y').
get_siml <- function(wem, x, y){
  similarity_to <- function(term) {
    round(cosineSimilarity(wem[[x]], wem[[term]]), 9)
  }
  sapply(y, similarity_to)
}
# The make_siml_matrix() function returns a matrix of cosine similarities between two vectors of words 'x' and 'y' in model 'wem', with one row per 'y' word and one column per 'x' word
# Example: make_siml_matrix(w2vModel, c("salty", "sweet", "fresh"), c("food", "ocean", "attitude", "air"))
# Matrix of cosine similarities (via get_siml()) between every word in 'x'
# and every word in 'y' in model 'wem'.
# Returns a numeric matrix with one row per 'y' word and one column per 'x'
# word. Row names pass through read.table()'s make.names() check, so e.g.
# "don't" becomes "don.t" (cosine_heatmap()'s dot.to.apost undoes this).
# Any "hue_" substring is removed from the labels.
# NOTE(review): "hue_" looks like a project-specific token prefix — confirm.
make_siml_matrix <- function(wem, x, y){
  # Zero-row data frame with one double column per 'y' word.
  siml_df <- read.table(text = "",
                        colClasses = "double",
                        col.names = y)
  # Add one row per 'x' word holding its similarity to each 'y' word.
  for (term in x) {
    siml_row <- wem %>%
      get_siml(term, y) %>%
      data.frame() %>%
      t()
    siml_df[term, ] <- siml_row
  }
  rownames(siml_df) <- gsub("hue_", "", rownames(siml_df))
  colnames(siml_df) <- gsub("hue_", "", colnames(siml_df))
  # Transpose so 'y' words become rows and 'x' words columns.
  t(as.matrix(siml_df))
}
# The scale_matrix() function scales values in a matrix, amplifying the signal in each row and column
# Example 1: scale_matrix(my_matrix)
# Example 2: make_siml_matrix(w2vModel, my_adj, my_nouns) %>% scale_matrix()
# Amplify relative signal in a similarity matrix.
# Every row is min-max rescaled to [0, 1], every column is min-max rescaled
# to [0, 1], and the two results are added. Each cell therefore ranges from
# 0 (lowest in both its row and column) to 2 (highest in both).
# Arguments:
#   x         numeric matrix (anything as.matrix() accepts)
#   diagonal  if FALSE, cells exactly equal to 1 (e.g. a word's similarity to
#             itself) are set to NA first so they don't dominate the ranges
# Returns a matrix with the same dimensions and dimnames as 'x'. A row or
# column whose values are all equal rescales to NaN.
scale_matrix <- function(x, diagonal=TRUE){
  x <- as.matrix(x)
  if (!diagonal) {
    x[x == 1] <- NA
  }
  # Min-max rescale ignoring NAs. Warnings from min()/max() on an all-NA
  # vector are silenced; the result is then NA.
  rescale_01 <- function(v) {
    lo <- suppressWarnings(min(v, na.rm = TRUE))
    hi <- suppressWarnings(max(v, na.rm = TRUE))
    (v - lo) / (hi - lo)
  }
  # Fill copies of x in place so shape and dimnames survive even for a single
  # row or column (apply() drops dimensions there and made the sum fail).
  row_scaled <- x
  for (i in seq_len(nrow(x))) {
    row_scaled[i, ] <- rescale_01(x[i, ])
  }
  col_scaled <- x
  for (j in seq_len(ncol(x))) {
    col_scaled[, j] <- rescale_01(x[, j])
  }
  row_scaled + col_scaled
}
##### Mapping the Proximities #####
# cosine_heatmap() draws a similarity matrix, such as one returned by make_siml_matrix(), as a heatmap.
# Build that matrix from the words you'd like to compare, e.g., x=c("salty","sweet","fresh") and y=c("food","ocean","air").
# Toggle 'values' to TRUE or FALSE to show or hide the similarity printed in each tile; set 'amplify' to TRUE to rescale signals.
# Example 1: cosine_heatmap(make_siml_matrix(w2vModel, my_adj, my_nouns))
# Example 2: cosine_heatmap(make_siml_matrix(w2vModel, my_words, my_words))
# Draw a similarity matrix (e.g. from make_siml_matrix()) as a ggplot heatmap.
# Arguments (selection):
#   input        numeric matrix with row and column names
#   round        decimals shown in tile labels when 'values' is TRUE
#   legend       show the fill legend; defaults to FALSE when values are shown
#   redundant    if FALSE and rownames equal colnames, hide the upper triangle
#   sort.y/x     order rows/columns by descending median similarity
#   limit.y/x    keep at most this many rows/columns (after sorting)
#   omit.y       words; rows whose names match any of them (regex) are dropped
#   amplify      rescale with scale_matrix() and hide tile values
#   sort.twice   re-sort using medians of the long (melted) data
#   dot.to.apost turn "." in labels back into "'" (see make_siml_matrix())
#   colorset     "red" or a viridis option name; alpha for the viridis fill
# Returns a ggplot object.
cosine_heatmap <- function(input,
                           round = 2,
                           legend = NULL,
                           values = TRUE,
                           redundant = TRUE,
                           sort.y = TRUE,
                           sort.x = TRUE,
                           limit.y = NA,
                           limit.x = NA,
                           omit.y,
                           top.down = TRUE,
                           amplify = FALSE,
                           diagonal = TRUE,
                           sort.twice = NULL,
                           dot.to.apost = TRUE,
                           colorset = "viridis",
                           alpha = NULL){
  # With a single column, matrix-level sorting is switched off and replaced
  # by sort.twice (unless the caller chose otherwise).
  if (ncol(input) == 1) {
    if (is.null(sort.twice) && sort.y) {
      sort.twice <- TRUE
    }
    sort.x <- FALSE
    sort.y <- FALSE
  }
  if (is.null(alpha) && values) {
    alpha <- 0.75
  } else if (is.null(alpha) && !values) {
    alpha <- 1
  }
  if (is.null(sort.twice)) {
    sort.twice <- FALSE
  }
  if (!identical(rownames(input),
                 colnames(input))) {
    if (!redundant) {
      redundant <- TRUE
      message("Values for 'x' and 'y' don't match. Setting 'redundant' to TRUE.")
    }
  }
  if (amplify) {
    values <- FALSE
    message("Amplifying a heatmap converts measures from cosine similarity, which is meaningful, to something based wholly on terms from a single row and column. Because these values lose meaning and are useful only for comparison, they will not be displayed.")
  }
  the_matrix <- input
  # drop = FALSE throughout: a single remaining row/column must stay a matrix
  # for melt() below.
  if (sort.y) {
    row_order <- the_matrix %>%
      apply(1, median) %>%
      sort(decreasing = TRUE) %>%
      names()
    the_matrix <- the_matrix[row_order, , drop = FALSE]
  }
  if (!missing(omit.y)) {
    not_i <- omit.y %>%
      paste0(collapse = "|") %>%
      grep(rownames(the_matrix))
    # x[-integer(0), ] selects nothing, so only subset when rows matched.
    if (length(not_i) > 0) {
      the_matrix <- the_matrix[-not_i, , drop = FALSE]
    }
  }
  if (!is.na(limit.y)) {
    # Never index past the last row (that would add NA rows).
    keep_rows <- seq_len(min(limit.y, nrow(the_matrix)))
    the_matrix <- the_matrix[keep_rows, , drop = FALSE]
  }
  if (sort.x) {
    col_order <- the_matrix %>%
      apply(2, median) %>%
      sort(decreasing = TRUE) %>%
      names()
    the_matrix <- the_matrix[, col_order, drop = FALSE]
  }
  if (amplify) {
    the_matrix <- the_matrix %>%
      scale_matrix(diagonal)
  }
  if (!redundant) {
    the_matrix[upper.tri(the_matrix)] <- NA
    the_matrix <- the_matrix %>% melt(na.rm = TRUE)
  } else {
    the_matrix <- the_matrix %>% melt()
  }
  if (dot.to.apost) {
    the_matrix$Var1 <-
      factor(the_matrix$Var1,
             labels = gsub("\\.", "'",
                           the_matrix$Var1) %>%
               unique(),
             ordered = TRUE)
    the_matrix$Var2 <-
      factor(the_matrix$Var2,
             labels = gsub("\\.", "'",
                           the_matrix$Var2) %>%
               unique(),
             ordered = TRUE)
  }
  if (!is.na(limit.x)) {
    the_matrix <- the_matrix %>%
      filter(Var2 %in% levels(the_matrix$Var2)[1:limit.x])
  }
  if (sort.twice) {
    col_order <- the_matrix %>%
      group_by(Var2) %>%
      summarise(value = median(value),
                .groups = 'drop') %>%
      arrange(desc(value)) %>%
      .$Var2 %>%
      as.character()
    row_order <- the_matrix %>%
      group_by(Var1) %>%
      summarise(value = median(value),
                .groups = 'drop') %>%
      arrange(desc(value)) %>%
      .$Var1 %>%
      as.character()
    the_matrix$Var2 <- the_matrix$Var2 %>%
      factor(levels = col_order)
    the_matrix$Var1 <- the_matrix$Var1 %>%
      factor(levels = row_order)
  }
  if (!top.down) {
    the_plot <- the_matrix %>%
      ggplot(aes(x = Var2, y = Var1, fill = value))
  } else {
    the_plot <- the_matrix %>%
      ggplot(aes(x = Var2,
                 y = reorder(Var1, desc(Var1)),
                 fill = value))
  }
  the_plot <- the_plot +
    geom_tile(color = "white")
  if (amplify) {
    the_plot <- the_plot +
      scale_fill_gradient2(low = "blue",
                           high = "red",
                           mid = "white",
                           midpoint = 1,
                           limit = c(0, 2),
                           name = element_blank(),
                           breaks = c(0, 2),
                           labels = c("far", "near"))
  } else if (colorset == "red") {
    the_plot <- the_plot +
      scale_fill_gradient2(low = "blue",
                           high = "red",
                           mid = "white",
                           midpoint = 0,
                           limit = c(-1, 1),
                           name = "Similarity")
  } else {
    the_plot <- the_plot +
      scale_fill_viridis_c(option = colorset,
                           limits = c(0, 1),
                           alpha = alpha,
                           name = "Similarity")
  }
  the_plot <- the_plot +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 45,
                                     vjust = 1,
                                     hjust = 1),
          panel.grid.major = element_blank(),
          panel.border = element_blank(),
          panel.background = element_blank(),
          axis.ticks = element_blank()) +
    labs(x = element_blank(),
         y = element_blank())
  if (values) {
    # Print each value with a fixed number of decimals and no leading zero.
    the_plot <- the_plot +
      geom_text(aes(label = format(round(value,
                                         round),
                                   nsmall = round) %>%
                      gsub(pattern = "^[ ]?(-?)0.",
                           replacement = "\\1.",
                           x = .)),
                color = "black")
    if (is.null(legend)) {
      legend <- FALSE
    }
  }
  if (is.null(legend)) {
    legend <- TRUE
  }
  if (!legend) {
    the_plot <- the_plot +
      guides(fill = "none")
  }
  if (top.down && redundant) {
    the_plot <- the_plot +
      scale_x_discrete(position = "top") +
      theme(axis.text.x = element_text(hjust = 0))
  }
  if (!top.down && !redundant) {
    the_plot <- the_plot +
      scale_x_discrete(position = "top") +
      theme(axis.text.x = element_text(hjust = 0))
  }
  the_plot
}
# The amplified_heatmap() function merely amplifies signals it finds; it doesn't validate these signals.
# Try setting 'labeled' to "title" or "simple".
# Example 1: amplified_heatmap(w2vModel, my_adj, my_nouns)
# Example 2: amplified_heatmap(w2vModel, my_words, my_words)
# Heatmap of amplified similarities between words 'x' and 'y' in model 'wem'.
# Similarities come from make_siml_matrix() and are rescaled with
# scale_matrix() (0 = "far" ... 2 = "near"). The signal is only amplified,
# not validated.
# Arguments:
#   labeled    "title" (descriptive title), "simple" (model name as title), or
#              anything else (model name as title, argument expressions as
#              axis labels)
#   legend     show the fill legend
#   diagonal   passed to scale_matrix(); forced TRUE when redundant is FALSE
#   redundant  if FALSE and x and y are identical, hide the lower triangle
# Returns a ggplot object.
amplified_heatmap <- function(wem, x, y,
                              labeled = "title",
                              legend = TRUE,
                              diagonal = TRUE,
                              redundant = TRUE){
  # Capture the caller's expressions for labels. deparse() may split a long
  # expression into several strings, so join them into one label.
  wem_label <- paste(trimws(deparse(substitute(wem))), collapse = " ")
  x_label <- paste(trimws(deparse(substitute(x))), collapse = " ")
  y_label <- paste(trimws(deparse(substitute(y))), collapse = " ")
  if (!identical(x, y)) {
    if (!redundant) {
      redundant <- TRUE
      # message(), as in cosine_heatmap(), rather than an unterminated cat().
      message("Values for 'x' and 'y' don't match. Setting 'redundant' to TRUE.")
    }
  }
  if (!redundant) {
    diagonal <- TRUE
  }
  the_matrix <- make_siml_matrix(wem, x, y) %>%
    scale_matrix(diagonal)
  if (!redundant) {
    the_matrix[lower.tri(the_matrix)] <- NA
    the_matrix <- the_matrix %>% melt(na.rm = TRUE)
  } else {
    the_matrix <- the_matrix %>% melt()
  }
  if (labeled == "title") {
    plot_labels <- labs(title = paste0("Comparing '",
                                       y_label,
                                       "' by '",
                                       x_label,
                                       "' in ",
                                       wem_label),
                        x = element_blank(),
                        y = element_blank())
  } else if (labeled == "simple") {
    plot_labels <- labs(title = wem_label,
                        x = element_blank(),
                        y = element_blank())
  } else {
    plot_labels <- labs(title = wem_label,
                        x = x_label,
                        y = y_label)
  }
  the_plot <- the_matrix %>%
    ggplot(aes(x = Var2,
               y = reorder(Var1, desc(Var1)),
               fill = value)) +
    geom_tile(color = "white") +
    scale_fill_gradient2(low = "blue",
                         high = "red",
                         mid = "white",
                         midpoint = 1,
                         limit = c(0, 2),
                         name = element_blank(),
                         breaks = c(0, 2),
                         labels = c("far", "near")) +
    theme_bw() +
    theme(axis.text.x = element_text(angle = 45,
                                     vjust = 1,
                                     hjust = 1),
          panel.grid.major = element_blank(),
          panel.border = element_blank(),
          panel.background = element_blank(),
          axis.ticks = element_blank()) +
    plot_labels
  if (!legend) {
    the_plot <- the_plot +
      guides(fill = "none")
  }
  the_plot
}
# Bar chart(s) of cosine similarities.
# 'input' is either:
#   - a data frame such as closest_to() returns: column 1 holds words and
#     column 2 is named "similarity to <term>"; draws one chart titled <term>;
#   - a matrix from make_siml_matrix(): one facet per column.
# Arguments (selection):
#   round         decimals shown in bar labels
#   sort.y        order words by (median) similarity
#   sort.x        order matrix columns (facets) by descending median
#   wrap_cols     number of facet columns
#   dot.to.apost  turn "." in labels back into "'"
#   force.width   fix the x axis to [0, 1]
#   colorset      "red" or a viridis option name; alpha for the bars
# Returns a ggplot object.
cosine_bars <- function(input, round = 2,
                        sort.y = TRUE,
                        sort.x = TRUE,
                        wrap_cols = NULL,
                        dot.to.apost = TRUE,
                        force.width = TRUE,
                        colorset = "viridis",
                        alpha = NULL) {
  if (is.null(alpha)) {
    alpha <- 1
  }
  # inherits() also accepts tibbles, which class(input)[1] did not.
  if (inherits(input, "data.frame")) {
    input$compare <- input %>%
      colnames() %>%
      .[2] %>%
      gsub(pattern = "similarity to ",
           replacement = "",
           x = .)
    colnames(input) <- c("word", "value", "compare")
    input <- input %>% select(word, compare, value)
    if (sort.y) {
      input$word <-
        factor(input$word,
               levels = input$word[order(input$value,
                                         decreasing = FALSE)])
    }
  } else if (inherits(input, "matrix")) {
    if (ncol(input) > 1) {
      if (sort.x) {
        col_order <- input %>%
          apply(2, median) %>%
          sort(decreasing = TRUE) %>%
          names()
        input <- input[, col_order]
      }
    }
    input <- input %>%
      data.frame() %>%
      rownames_to_column("rn") %>%
      pivot_longer(-rn)
    colnames(input) <- c("word", "compare", "value")
    if (sort.y) {
      input <- input %>%
        group_by(word) %>%
        mutate(med_value = median(value)) %>%
        ungroup() %>%
        arrange(med_value)
      input$word <- factor(input$word,
                           levels = unique(input$word))
      input <- input %>% select(word, compare, value)
    }
  }
  if (dot.to.apost) {
    # When sort.y is FALSE the words are still character. levels() of a
    # character vector is NULL, which made factor() error, so convert first.
    # factor()'s default sorted levels match ggplot's default ordering.
    word_factor <- if (is.factor(input$word)) input$word else factor(input$word)
    input$word <-
      factor(word_factor,
             labels = gsub("\\.", "'",
                           levels(word_factor)),
             ordered = TRUE)
    input$compare <- gsub("\\.", "'",
                          input$compare)
    input$compare <- input$compare %>%
      factor(levels = input$compare %>%
               unique(),
             ordered = TRUE)
  }
  the_plot <- ggplot(input,
                     aes(y = word,
                         x = value,
                         fill = value)) +
    geom_col(aes(color = ifelse(round(value, 1) < 0.1, "gray", "blank")),
             alpha = alpha,
             show.legend = FALSE) +
    # Labels sit just outside short bars and just inside long ones.
    geom_text(aes(x = ifelse(value <= 0.9, value + 0.01, value - 0.01),
                  hjust = ifelse(value <= 0.9, 0, 1),
                  label = format(round(value,
                                       round),
                                 nsmall = round) %>%
                    gsub(pattern = "^[ ]?(-?)0.",
                         replacement = "\\1.",
                         x = .),
                  color = ifelse(round(value, 1) > 0.9, "white", "black")),
              show.legend = FALSE) +
    scale_color_manual(values = c(blank = "white",
                                  white = "black",
                                  black = "black",
                                  gray = "#f5f5f5")) +
    labs(x = element_blank(),
         y = element_blank()) +
    theme_minimal()
  if (force.width) {
    the_plot <- the_plot +
      scale_x_continuous(limits = c(0, 1),
                         expand = expansion(mult = c(0, 0)))
  } else {
    the_plot <- the_plot +
      scale_x_continuous(expand = expansion(mult = c(0, 0.08)))
  }
  the_plot <- the_plot +
    theme(panel.grid = element_blank(),
          axis.text.x = element_blank())
  if (colorset == "red") {
    the_plot <- the_plot +
      scale_fill_gradient(low = "white",
                          high = "red",
                          limit = c(0, 1),
                          name = "Similarity")
  } else {
    the_plot <- the_plot +
      scale_fill_viridis_c(option = colorset,
                           limits = c(0, 1),
                           alpha = 0.8)
  }
  if (input$compare %>% unique() %>% length() > 1) {
    if (!is.null(wrap_cols)) {
      the_plot <- the_plot +
        facet_wrap("compare", ncol = wrap_cols)
    } else {
      the_plot <- the_plot +
        facet_wrap("compare")
    }
  } else {
    the_plot <- the_plot +
      ggtitle(input$compare %>% unique())
  }
  the_plot
}
@jmclawson
Copy link
Author

jmclawson commented Jul 30, 2019

This code builds on a workflow shared by the Women Writers Project, which relies on Benjamin Schmidt's wordVectors package. To use these functions in R, paste the following line into the console or a code chunk: devtools::source_gist("21c6a40c78fd66d708bec45d5c0b52e2"). Alternatively, copy the functions from the gist and save them locally to run them from your machine.

  1. Put corpus files in "data/YourModelName" and then prep the texts with the function prep_model(model="YourModelName").
  • Alternatively, put them in some other directory and set the path to that directory with the source.dir parameter.
  • Optionally, set the bundle_ngrams parameter to some other value to combine common phrases using underscores, and set lowercase=FALSE to retain uppercase characters. These settings are passed along to the prep_word2vec() function from the wordVectors package.
  2. Train the model with the function train_model(model="YourModelName").
  • Setting defaults are vectors=100, window=6, iter=10, negative_samples=15, threads=3, but these can each be changed within the train_model() call. These settings are passed along to the train_word2vec() function from the wordVectors package.
  3. Recall these setting parameters with, e.g., attributes(YourModelName)$window.
  4. Prepare a matrix of cosine similarities between terms in two vectors or groups of words by using the make_siml_matrix() command: make_siml_matrix(YourModelName, c("red","green","blue","yellow"), c("air","water","earth","fire"))
  5. Visualize a bar chart of cosine similarities by passing results of a closest_to() function (from the wordVectors package) to the cosine_bars() function. Visualize a set of bar charts by passing results from make_siml_matrix().
  6. Visualize a heatmap of cosine similarities between two groups of words by passing the results from make_siml_matrix() to cosine_heatmap().

Everything is explained further in a blog post with examples here: https://jmclawson.net/blog/posts/word-vector-utilities/
Updates are explained in a blog post with examples here: https://jmclawson.net/blog/posts/updates-to-word-vector-utilities/

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment