Last active
November 12, 2023 07:23
-
-
Save agoldst/edcfd45b5ac371296b76 to your computer and use it in GitHub Desktop.
Functions for using MALLET's topic-inference capability from R: given an existing topic model, estimate topic proportions for new documents
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# mallet-inference.R
#
# Functions for using MALLET's topic-inference functionality: given an
# existing topic model, estimate topic proportions for new documents.
#
# source() this file.
#
# Workflow
# --------
#
# 1. Create an instances-list object for the base corpus
#    (or load one from disk with litdata::read_mallet_instances()).
# 2. Create a topic model of the base corpus.
# 3. Get an inferencer object for the model with inferencer()
#    (or, if you did this earlier and saved it with write_inferencer(),
#    load it from disk with read_inferencer()).
# 4. Use compatible_instances() to create an instances-list object for the
#    new corpus (or, if you did this earlier, load it from disk).
# 5. Use infer_topics() to infer topics for the new corpus.
#
# Step (4) can be done any time after step (1).
#
# Last update: 2015-05-19 by AG
# ----- | |
# return the topic inferencer belonging to a trained topic model
#
# model: a model object from MalletLDA(); call model$train() before this
#
# returns a reference to the topic inferencer obtained from the model's
# underlying Java object (model$model)
inferencer <- function(model) {
  java_model <- model$model
  java_model$getInferencer()
}
# ----- | |
# save an inferencer object to a file | |
# | |
# inf: a reference to a topic inferencer, from inferencer() | |
# | |
# out_file: the name of a file to save to (will overwrite an existing file) | |
write_inferencer <- function (inf, out_file) { | |
fos <- .jnew("java/io/FileOutputStream", out_file) | |
oos <- .jnew("java/io/ObjectOutputStream", | |
.jcast(fos, "java/io/OutputStream")) | |
oos$writeObject(inf) | |
oos$close() | |
} | |
# ----- | |
# retrieve an inferencer object from a file, e.g. one written by
# write_inferencer()
#
# in_file: path of the saved inferencer. A leading "~" is expanded on the
#     R side, since Java does not do that.
#
# returns a reference to a topic inferencer object
read_inferencer <- function(in_file) {
  java_file <- .jnew("java/io/File", path.expand(in_file))
  J("cc.mallet.topics.TopicInferencer")$read(java_file)
}
# ----- | |
# infer document topics. This is like the Gibbs sampling process for making a
# topic model, but the topic-word proportions are not updated.
#
# inferencer: a topic inferencer object, from inferencer() or
#     read_inferencer()
#
# instances: an instances list object from compatible_instances()---or
#     any instances that are compatible with the inferencer, i.e. their
#     vocabulary has to correspond to that of the instances used to create
#     the model that yielded the inferencer
#
# n_iterations: number of Gibbs sampling iterations
#
# sampling_interval: thinning interval
#
# burn_in: number of burn-in iterations
#
# random_seed: integer random seed; set for reproducibility. When NULL (the
#     default), the inferencer's seed is left unchanged.
#
# returns a matrix of estimated document-topic proportions m, where m[i, j]
# gives the proportion (between 0 and 1) of topic j in document i; rows follow
# the iteration order of `instances`. Returns NULL if `instances` is empty.
# The inferencer sampling state is not accessible.
infer_topics <- function(inferencer, instances,
                         n_iterations = 100,
                         sampling_interval = 10, # aka "thinning"
                         burn_in = 10,
                         random_seed = NULL) {
  n_iterations <- as.integer(n_iterations)
  sampling_interval <- as.integer(sampling_interval)
  burn_in <- as.integer(burn_in)
  if (!is.null(random_seed)) {
    inferencer$setRandomSeed(as.integer(random_seed))
  }
  n_docs <- instances$size()
  iter <- instances$iterator()
  doc_topics <- vector("list", n_docs)
  # seq_len() rather than 1:n_docs, so an empty list runs zero iterations
  # instead of calling next() on an exhausted iterator
  for (j in seq_len(n_docs)) {
    # `next` is a reserved word in R, so iter$next() cannot be written;
    # call it through .jcall() instead
    inst <- .jcall(iter, "Ljava/lang/Object;", "next")
    doc_topics[[j]] <- inferencer$getSampledDistribution(
      inst, n_iterations, sampling_interval, burn_in
    )
  }
  do.call(rbind, doc_topics)
}
# ----- | |
# given an existing instances list object and some new texts,
# generate a compatible instances list object which can be input
# into the inferencer. The new texts are processed with the same MALLET
# pipe as `instances`.
#
# ids: vector of item ids (coerced to character)
#
# texts: vector of texts (coerced to character; same length as ids)
#
# instances: instances to enforce compatibility with
#
# returns a reference to the new instances list object. Save this to disk
# with the litdata package function write_instances()
compatible_instances <- function(ids, texts, instances) {
  ids <- as.character(ids)
  texts <- as.character(texts)
  if (length(ids) != length(texts)) {
    stop("`ids` and `texts` must have the same length", call. = FALSE)
  }
  mallet_pipe <- instances$getPipe()
  new_insts <- .jnew("cc/mallet/types/InstanceList",
                     .jcast(mallet_pipe, "cc/mallet/pipe/Pipe"))
  # .jarray() forces String[] even for a single document: rJava otherwise
  # passes a length-1 character vector as a scalar String, and then the
  # String[] signature of addInstances() is not matched
  J("cc/mallet/topics/RTopicModel")$addInstances(
    new_insts, .jarray(ids), .jarray(texts)
  )
  new_insts
}
# ----- | |
# number of tokens in each document in an instance list
#
# instances: reference to an instances list
#
# returns an integer vector of token counts, one per document in iteration
# order; integer(0) for an empty instance list
instances_lengths <- function(instances) {
  n_docs <- instances$size()
  iter <- instances$iterator()
  # vapply() keeps the result an integer vector even when n_docs is 0
  # (replicate() would return list() there)
  vapply(seq_len(n_docs), function(i) {
    # `next` is a reserved word in R, hence .jcall() instead of iter$next()
    inst <- .jcall(iter, "Ljava/lang/Object;", "next")
    as.integer(inst$getData()$size())
  }, integer(1))
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment