# Created May 29, 2012 18:44 (GitHub gist jasonbaldridge/2829963).
# Gibbs sampler for topic models, applied to the artificial data in
# Steyvers and Griffiths (2007).
# GitHub notice: this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review it,
# open the file in an editor that reveals hidden Unicode characters.
## An implementation of Gibbs sampling for topic models for the
## example in section 4 of Steyvers and Griffiths (2007):
## http://cocosci.berkeley.edu/tom/papers/SteyversGriffiths.pdf
##
## Author: Jason Baldridge ([email protected])
# Functions to parse the input data ----

# Word codes and their integer word ids (column 1 holds the ids 1..5).
# The one-letter codes presumably abbreviate the words of the paper's example
# (river, stream, bank, money, loan) -- confirm against the paper.
words.to.indices <- data.frame(row.names = c("r", "s", "b", "m", "l"), 1:5)

# Split a string into a character vector of its individual characters.
# Only the first element of `x` is split.
mysplit <- function(x) {
  strsplit(x, "")[[1]]
}

# Convert a string of one-letter word codes into an integer vector of word
# ids, one id per character, in order.
# Codes are matched exactly against the row names of words.to.indices.
# An unknown code is an error, rather than a silent NA id that would end up
# in the document-word matrix.
word.vector <- function(x) {
  codes <- mysplit(x)
  ids <- words.to.indices[[1]][match(codes, rownames(words.to.indices))]
  if (anyNA(ids)) {
    stop("unknown word code(s): ",
         paste(unique(codes[is.na(ids)]), collapse = ", "),
         call. = FALSE)
  }
  ids
}
# The document-by-word matrix ----
# Row i holds the word ids of the tokens of document i, in order.
# The dimensions come from the data: one row per document string.
# Every document must have the same length, because the documents are stored
# as the rows of a single matrix.
doc.strings <- c(
  "bbbbmmmmmmllllll",
  "bbbbbmmmmmmmllll",
  "bbbbbbbmmmmmllll",
  "bbbbbbbmmmmmmlll",
  "bbbbbbbmmlllllll",
  "bbbbbbbbbmmmllll",
  "rbbbbmmmmmmlllll",
  "rssbbbbbbmmmmlll",
  "rsssbbbbbbmmmmll",
  "rrlllbbbbbbmllll",
  "rrsssbbbbbbbmmml",
  "rrrssssssbbbbbbm",
  "rrrrrrsssbbbbbbl",
  "rrssssssssbbbbbb",
  "rrrrsssssssbbbbb",
  "rrrrrsssssssbbbb"
)
stopifnot(length(unique(nchar(doc.strings))) == 1L)
d <- matrix(unlist(lapply(doc.strings, word.vector)),
            nrow = length(doc.strings), byrow = TRUE)
# Initial topic assignments ----
# t.start[i, j] is the starting topic id (1 or 2 in this data) of token j of
# document i. It therefore must have exactly the same shape as d.
# t holds the current assignments and is overwritten by the sampler.
# The variable name t shadows base::t only as a value; calls to t() still
# find the function.
topic.strings <- c(
  "2222122212112122",
  "2212211111121221",
  "2221222212121222",
  "1112122211222222",
  "1121212122122222",
  "2112111112122211",
  "2211111221221112",
  "1212211112112112",
  "1221222221121121",
  "1211212222211221",
  "2121122211221111",
  "2222222121211212",
  "2221112121222112",
  "2211222111112122",
  "2111111221212121",
  "1211211222211112"
)
t.start <- matrix(as.integer(unlist(lapply(topic.strings, mysplit))),
                  nrow = length(topic.strings), byrow = TRUE)
stopifnot(identical(dim(t.start), dim(d)))
t <- t.start
# Parameters ----
# alpha is a pseudo-count added to every document-topic count (cdt).
# beta is a pseudo-count added to every word-topic count (cwt).
# These are presumably the alpha/beta hyperparameters of the paper's model.
alpha <- 0.1
beta <- 0.1
# Number of full sweeps of the sampler over every token of every document.
num.iterations <- 64

# Constants ----
num.docs <- nrow(d)
# Size the vocabulary from the word-id table, not from the words that happen
# to occur in d. Counting distinct occurring words could make cwt too small to
# index by the word ids in d when some vocabulary word is absent.
vocab.size <- nrow(words.to.indices)
# Topics are numbered 1..num.topics. Taking the largest id, rather than the
# number of distinct ids, stays correct when a lower-numbered topic has no
# tokens in the initial assignment.
num.topics <- max(t)
# Count matrices for the current assignment t ----
# cdt[i, k] is the number of tokens of document i assigned to topic k.
# tabulate() with nbins = num.topics always returns one count per topic, with
# explicit zeros. By contrast, xtabs() omits topics absent from a document,
# and the shorter result would be recycled across the row.
cdt <- matrix(0, nrow = num.docs, ncol = num.topics)
for (i in seq_len(num.docs)) {
  cdt[i, ] <- tabulate(t[i, ], nbins = num.topics)
}

# cwt[w, k] is the number of tokens of word w, across all documents,
# assigned to topic k.
cwt <- matrix(0, nrow = vocab.size, ncol = num.topics)
for (i in seq_len(num.docs)) {
  for (j in seq_len(ncol(d))) {
    word.id <- d[i, j]
    topic.id <- t[i, j]
    cwt[word.id, topic.id] <- cwt[word.id, topic.id] + 1
  }
}
# Gibbs sampling ----
# Each iteration prints its number, then visits every token once.
# For each token it:
#   1. removes the token's current topic from cdt and cwt;
#   2. draws a new topic from equation (3) of Steyvers and Griffiths (2007);
#   3. records the new topic in t and adds it back to the counts.
for (iteration in seq_len(num.iterations)) {
  print(iteration)
  for (i in seq_len(num.docs)) {
    for (j in seq_len(ncol(d))) {
      word.id <- d[i, j]
      topic.old <- t[i, j]

      # Exclude the current token from the counts before computing (3).
      cdt[i, topic.old] <- cdt[i, topic.old] - 1
      cwt[word.id, topic.old] <- cwt[word.id, topic.old] - 1

      # Equation (3), one value per topic k:
      #   (cwt[w, k] + beta)  / sum over words w'  of (cwt[w', k] + beta)
      # * (cdt[i, k] + alpha) / sum over topics k' of (cdt[i, k'] + alpha)
      # prop.table(., 2) normalises each column of cwt + beta;
      # prop.table() on the row vector normalises it to sum to 1.
      vals <- prop.table(cwt + beta, 2)[word.id, ] *
        prop.table(cdt[i, ] + alpha)

      # Draw the new topic in proportion to the normalised values of (3).
      topic.new <- sample(num.topics, 1, prob = vals / sum(vals))

      # Record the new topic and restore the counts.
      t[i, j] <- topic.new
      cdt[i, topic.new] <- cdt[i, topic.new] + 1
      cwt[word.id, topic.new] <- cwt[word.id, topic.new] + 1
    }
  }
}
# Estimated distributions from the final counts ----
# theta[i, k] is the smoothed share of document i's tokens in topic k.
# Each row of theta sums to 1.
theta <- prop.table(cdt + alpha, 1)
# phi[w, k] is the smoothed probability of word w under topic k.
# Each column of phi sums to 1.
phi <- prop.table(cwt + beta, 2)
# GitHub page footer: Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.