Skip to content

Instantly share code, notes, and snippets.

@agoldst
Last active November 12, 2023 07:23
Show Gist options
  • Save agoldst/edcfd45b5ac371296b76 to your computer and use it in GitHub Desktop.
Save agoldst/edcfd45b5ac371296b76 to your computer and use it in GitHub Desktop.
Functions for using MALLET's topic-inference capability from R: given an existing topic model, estimate topic proportions for new documents
# mallet-inference.R
#
# functions for using MALLET's topic-inference functionality: given an
# existing topic model, estimate topic proportions for new documents
#
# source() this file
#
# Workflow
# --------
#
# 1. Create instances-list object of base corpus
# (or load from disk with litdata::read_mallet_instances)
# 2. Create topic model of base corpus
# 3. Get inferencer object for the model with inferencer()
# (or, having done this earlier and saved it with write_inferencer(),
# load it from disk)
# 4. Use compatible_instances() to create instances-list object of new corpus
# (or, having done this, load it from disk)
# 5. Use infer_topics() to infer topics for new corpus
#
# Step (4) can be done any time after step (1).
#
# Last update: 2015-05-19 by AG
# -----
# Get the topic inferencer belonging to a trained topic model.
#
# model: model object from MalletLDA(). Run model$train() first.
#
# returns a reference to a topic inferencer object (the result of calling
# getInferencer() on the object stored in model$model)
inferencer <- function (model) {
    trained <- model$model
    trained$getInferencer()
}
# -----
# save an inferencer object to a file using Java object serialization
#
# inf: a reference to a topic inferencer, from inferencer()
#
# out_file: the name of a file to save to (will overwrite an existing file)
#
# returns NULL, invisibly. Called for its side effect of writing out_file.
write_inferencer <- function (inf, out_file) {
    fos <- .jnew("java/io/FileOutputStream", out_file)
    # Guarantee the file handle is released even if constructing the object
    # stream or serializing the inferencer fails. Closing a FileOutputStream
    # that has already been closed is a no-op, so this is harmless on the
    # success path, where oos$close() below has already closed it.
    on.exit(fos$close(), add=TRUE)

    oos <- .jnew("java/io/ObjectOutputStream",
                 .jcast(fos, "java/io/OutputStream"))
    oos$writeObject(inf)
    # Explicit close on the success path: flushes buffered object data to
    # the underlying file stream before that stream is closed.
    oos$close()
    invisible(NULL)
}
# -----
# Load a previously saved inferencer object from disk.
#
# in_file: the name of the file to read the inferencer from
#
# returns a reference to a topic inferencer object, as produced by the
# read() method of cc.mallet.topics.TopicInferencer
read_inferencer <- function (in_file) {
    inferencer_class <- J("cc.mallet.topics.TopicInferencer")
    file_ref <- new(J("java.io.File"), in_file)
    inferencer_class$read(file_ref)
}
# -----
# infer document topics. This is like the Gibbs sampling process for making a
# topic model, but the topic-word proportions are not updated.
#
# inferencer: a topic inferencer object
#
# instances: an instances list object from compatible_instances()---or
# any instances that are compatible with the inferencer, i.e. their
# vocabulary has to correspond to that of the instances used to create
# the model that yielded the inferencer
#
# n_iterations: number of Gibbs sampling iterations
#
# sampling_interval: thinning interval
#
# burn_in: number of burn-in iterations
#
# random_seed: integer random seed; set for reproducibility. If NULL (the
# default), the inferencer's random seed is left untouched.
#
# returns a matrix of estimated document-topic proportions m, where m[i, j]
# gives the proportion (between 0 and 1) of topic j in document i. The
# inferencer sampling state is not accessible. If instances is empty, a
# matrix with zero rows and zero columns is returned.
infer_topics <- function (inferencer, instances,
                          n_iterations=100,
                          sampling_interval=10, # aka "thinning"
                          burn_in=10,
                          random_seed=NULL) {
    # The Java method is called with int arguments, so coerce R numerics
    # (doubles by default) to integers up front.
    n_iterations <- as.integer(n_iterations)
    sampling_interval <- as.integer(sampling_interval)
    burn_in <- as.integer(burn_in)

    if (!is.null(random_seed)) {
        inferencer$setRandomSeed(as.integer(random_seed))
    }

    n_docs <- instances$size()
    if (n_docs == 0L) {
        # Nothing to sample; never call next() on an exhausted iterator.
        return(matrix(numeric(0), nrow=0, ncol=0))
    }

    iter <- instances$iterator()
    doc_topics <- vector("list", n_docs)
    for (j in seq_len(n_docs)) {
        # "next" is a reserved word in R, so the iterator method has to be
        # invoked through .jcall() rather than with iter$next().
        inst <- .jcall(iter, "Ljava/lang/Object;", "next")
        doc_topics[[j]] <- inferencer$getSampledDistribution(inst,
            n_iterations, sampling_interval, burn_in)
    }

    # one row per document
    do.call(rbind, doc_topics)
}
# -----
# Build an instances list for new texts that can be fed to the inferencer.
# The new list is constructed with the pipe taken from an existing instances
# list, so the new documents are processed the same way as the existing ones.
#
# ids: character vector of item ids
#
# texts: character vector of texts (same length as ids)
#
# instances: instances to enforce compatibility with
#
# returns a reference to the new instances list object. Save this to disk
# with the litdata package function write_instances()
compatible_instances <- function (ids, texts, instances) {
    shared_pipe <- .jcast(instances$getPipe(), "cc/mallet/pipe/Pipe")
    result <- .jnew("cc/mallet/types/InstanceList", shared_pipe)

    # addInstances() fills `result` in place through the Java reference
    rtm <- J("cc/mallet/topics/RTopicModel")
    rtm$addInstances(result, ids, texts)

    result
}
# -----
# number of tokens in each document in an instance list
#
# instances: reference to an instances list
#
# returns a vector of integers with token counts (the size() of each
# instance's data), one per instance in iteration order. An empty instances
# list yields integer(0).
instances_lengths <- function (instances) {
    iter <- instances$iterator()
    # vapply() rather than replicate(): replicate() simplifies like sapply(),
    # so it would return an empty list instead of an integer vector when
    # there are no instances.
    vapply(seq_len(instances$size()), function (i) {
        # "next" is a reserved word in R, hence .jcall() instead of iter$next()
        inst <- .jcall(iter, "Ljava/lang/Object;", "next")
        as.integer(inst$getData()$size())
    }, integer(1))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment