Code from "'Infinite LDA' – Implementing the HDP with minimum code complexity".
/*
 * (C) Copyright 2005-2011, Gregor Heinrich (gregor :: arbylon : net)
 * (This file is part of the knowceans-ilda experimental software package.)
 *
 * knowceans-ilda is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the Free
 * Software Foundation; either version 3 of the License, or (at your option)
 * any later version.
 *
 * knowceans-ilda is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along with
 * this program; if not, write to the Free Software Foundation, Inc., 59 Temple
 * Place, Suite 330, Boston, MA 02111-1307 USA
 */
package org.knowceans.topics.simple;

import static java.lang.Math.log;

import java.io.FileNotFoundException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.SortedSet;
import java.util.TreeSet;

import org.knowceans.corpus.NumCorpus;
import org.knowceans.corpus.VisCorpus;
import org.knowceans.util.ArrayIo;
import org.knowceans.util.ArrayUtils;
import org.knowceans.util.CokusRandom;
import org.knowceans.util.DirichletEstimation;
import org.knowceans.util.IndexQuickSort;
import org.knowceans.util.RandomSamplers;
import org.knowceans.util.StopWatch;
import org.knowceans.util.Vectors;
/**
 * LDA Gibbs sampler with nonparametric prior (HDP):
 * <p>
 * (m,k | alpha * tau | gamma), k->inf, (k,t | beta)
 * <p>
 * using the approach of Teh et al. (2006) for the direct assignment sampler,
 * with the modular LDA parametric sampler first published by Griffiths (2002)
 * and explained in Heinrich (2005). For the original LDA paper, see Blei et
 * al. (2002).
 * <p>
 * The general idea is to retain as much as possible of the standard LDA Gibbs
 * sampler, which is achieved by alternating between sampling the finite case
 * with K + 1 topics and resampling the topic weights, taking into account the
 * current assignments of data items to topics and pruning or expanding the
 * topic set accordingly.
 * <p>
 * I tried to find the (subjectively) best tradeoff between simplicity and
 * fidelity to the JASA paper (Teh et al. 2006). Therefore only the direct
 * assignment method is used.
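 * <p>
 * Sketch of the direct-assignment update implemented in run(): for token t =
 * w[m][n], the collapsed full conditional over the active topics k is p(z = k
 * | .) propto (nmk + alpha * tau_k) * (nkt + beta) / (nk + V * beta), and a
 * new topic is proposed with weight alpha * tau_u / V, where tau_u is the
 * weight the root DP reserves for unseen components; after every change of K,
 * tau is resampled in updateTau().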
 * <p>
 * The implementation uses lists instead of primitive arrays, but for
 * performance reasons, this may be changed to use a bound Kmax and allocate
 * fixed-size arrays, similar to a truncated DP.
 * <p>
 * Caveats: (1) Performance is not a core criterion, and OOP encapsulation is
 * ignored for compactness' sake. (2) The code still uses the likelihood
 * function of LDA, without the hyperparameter terms.
 * <p>
 * LICENSE: GPL3, see: http://www.gnu.org/licenses/gpl-3.0.html
 * <p>
 * References:
 * <p>
 * D.M. Blei, A. Ng, M.I. Jordan. Latent Dirichlet Allocation. NIPS, 2002
 * <p>
 * T. Griffiths. Gibbs sampling in the generative model of Latent Dirichlet
 * Allocation. TR, 2002, www-psych.stanford.edu/~gruffydd/cogsci02/lda.ps
 * <p>
 * G. Heinrich. Parameter estimation for text analysis. TR, 2009,
 * www.arbylon.net/publications/textest2.pdf
 * <p>
 * G. Heinrich. "Infinite LDA" -- implementing the HDP with minimum code
 * complexity. TN2011/1, www.arbylon.net/publications/ilda.pdf
 * <p>
 * Y.W. Teh, M.I. Jordan, M.J. Beal, D.M. Blei. Hierarchical Dirichlet
 * Processes. JASA, 101:1566-1581, 2006
 *
 * @author (c) 2008-2011 Gregor Heinrich, gregor :: arbylon : net
 * @version 0.95
 */
public class IldaGibbs implements ISimpleGibbs, ISimpleQueryGibbs, ISimplePpx {

    private static TopicMatrixPanel vis;

    /**
     * test driver for mixture network Gibbs sampler
     *
     * @param args
     */
    public static void main(String[] args) {
        int niter = 500, niterq = 10;
        String filebase = "nips/nips";
        // file or synthetic
        boolean usefile = true;
        // topic display panel
        boolean display = true;
        Random rand = new CokusRandom(56567651);
        NumCorpus corpus;
        if (usefile) {
            corpus = new NumCorpus(filebase + ".corpus");
            if (display) {
                // panel is K x K terms large, so let's use the sqrt.
                vis = new TopicMatrixPanel(900, 400, (int) Math.sqrt(corpus
                        .getNumTerms()), 1);
            }
        } else {
            // test with generated corpus
            int K = 10;
            corpus = VisCorpus.generateLdaCorpus(K, 1000, 200);
            if (display) {
                vis = new TopicMatrixPanel(900, 400, K, 300 / K);
            }
        }
        // corpus.reduce(100, rand);
        corpus.split(10, 2, rand);
        NumCorpus train = (NumCorpus) corpus.getTrainCorpus();
        NumCorpus test = (NumCorpus) corpus.getTestCorpus();
        int[][] w = train.getDocWords(rand);
        int[][] wq = test.getDocWords(rand);
        int K0 = 0;
        int V = corpus.getNumTerms();
        double alpha = 1.;
        // beta = 1 --> K = 12, .5 --> 16, .1 --> 26@200, 75@500, 115@645
        // (beta should be larger)
        double beta = .1;
        double gamma = 1.5;
        // run sampler
        IldaGibbs gs = new IldaGibbs(w, wq, K0, V, alpha, beta, gamma, rand);
        gs.init();
        System.out.println("initialised");
        System.out.println(gs);
        // initial test
        gs.initq();
        gs.runq(niterq);
        System.out.println("perplexity = " + gs.ppx());
        StopWatch.start();
        System.out.println("starting Gibbs sampler with " + niter
                + " iterations");
        gs.run(niter);
        System.out.println(StopWatch.format(StopWatch.stop()));
        // test
        gs.initq();
        gs.runq(niterq);
        System.out.println("perplexity = " + gs.ppx());
        System.out.println(gs);
        gs.packTopics();
        System.out.println("finished");
        System.out.println(gs);
        if (!usefile)
            try {
                PrintStream bw = new PrintStream(filebase + ".ilda.result");
                gs.print(bw, filebase, corpus.getOrigDocIds()[0],
                        train.getNumWords());
                bw.close();
                System.out.println("done");
            } catch (FileNotFoundException e) {
                e.printStackTrace();
            }
    } // main
    /**
     * print the result of the model, save topics to files
     *
     * @param out
     *            output stream
     * @param filebase
     * @param docids
     * @param W
     */
    private void print(PrintStream out, String filebase, int[] docids, int W) {
        phi = new double[K][V];
        double[][] theta = new double[M][K];
        for (int k = 0; k < K; k++) {
            for (int t = 0; t < V; t++) {
                phi[k][t] = (nkt.get(k)[t] + beta) / (nk.get(k) + beta * V);
            }
        }
        for (int m = 0; m < M; m++) {
            for (int k = 0; k < K; k++) {
                theta[m][k] = (nmk[m].get(k) + alpha)
                        / (w[m].length + alpha * K);
            }
        }
        LdaTopics lt = new LdaTopics(filebase, theta, phi);
        for (int m = 0; m < M; m++) {
            // no labels
            out.println(lt.printDocument(m, docids[m], 10, false, true));
            // out.println(lt.printDocument(m, docids[m], 10, true, true));
        }
        int[] kk = IndexQuickSort.sort(nk);
        IndexQuickSort.reverse(kk);
        for (int k = 0; k < K; k++) {
            out.println(String.format("%d (%2.5f / %d): %s", k, nk.get(kk[k])
                    / (double) W * K, K, lt.printTopic(kk[k], 20)));
        }
        ArrayIo.saveBinaryMatrix(filebase + ".ilda.theta.zip", theta);
        ArrayIo.saveBinaryMatrix(filebase + ".ilda.phi.zip", phi);
    }
    private int[][] w;
    private int[][] wq;
    /**
     * inactive components: elements are indices into the count arrays (=
     * values in z) that are currently unused. Filled on component removal.
     */
    // reuse the component emptied last
    // private List<Integer> kgaps;
    // reuse component with the lowest index
    private SortedSet<Integer> kgaps;
    /**
     * active components: value = value in z and count arrays, which should
     * always match. Elements are never removed from the count arrays; instead,
     * emptied components are recorded as gaps. This allows components to be
     * reused and keeps the indices of z and the counts identical.
     */
    private List<Integer> kactive;
    private List<Integer>[] nmk;
    private int[][] nmkq;
    private List<int[]> nkt;
    private List<Integer> nk;
    private double[][] phi;
    private int[][] z;
    private int[][] zq;
    private double[] pp;
    /**
     * step to increase the sampling array
     */
    public final int ppstep = 10;
    /**
     * precision of the 2nd-level DP
     */
    private double alpha;
    /**
     * mean of the 2nd-level DP = sample from 1st-level DP
     */
    private ArrayList<Double> tau;
    /**
     * parameter of root base measure (= component Dirichlet)
     */
    private double beta;
    /**
     * precision of root DP
     */
    private double gamma;
    // hyperparameters for DP and Dirichlet samplers
    // Teh+06: Docs: (1, 1), M1-3: (0.1, 0.1); HMM: (1, 1)
    double aalpha = 5;
    double balpha = 0.1;
    double abeta = 0.1;
    double bbeta = 0.1;
    // Teh+06: Docs: (1, 0.1), M1-3: (5, 0.1), HMM: (1, 1)
    double agamma = 5;
    double bgamma = 0.1;
    // number of samples for parameter samplers
    int R = 10;
    /**
     * total number of tables
     */
    private double T;
    private Random rand;
    RandomSamplers samp;
    private int iter;
    /**
     * current number of non-empty components
     */
    private int K;
    private int M;
    private int Mq;
    private int Wq;
    private int V;
    private boolean inited = false;
    private boolean fixedK = false;
    private boolean fixedHyper = false;
    /**
     * parametrise the Gibbs sampler
     *
     * @param w
     *            word tokens
     * @param wq
     *            word tokens (testing)
     * @param K
     *            initial number of topics: may be 0 if gamma > 0.
     * @param V
     *            number of terms
     * @param alpha
     *            node A precision (document DP)
     * @param beta
     *            node B hyperparameter
     * @param gamma
     *            node A precision (root DP), 0 for fixed K: plain LDA.
     * @param rand
     *            random number generator
     */
    public IldaGibbs(int[][] w, int[][] wq, int K, int V, double alpha,
            double beta, double gamma, Random rand) {
        // assign
        this.w = w;
        this.wq = wq;
        // start with 0 or more topics
        this.K = K;
        this.alpha = alpha;
        this.beta = beta;
        this.gamma = gamma;
        if (gamma == 0) {
            this.fixedK = true;
        }
        this.M = w.length;
        this.Mq = wq.length;
        this.V = V;
        this.rand = rand;
        this.samp = new RandomSamplers(rand);
    }
    /**
     * initialise Markov chain
     */
    @SuppressWarnings("unchecked")
    public void init() {
        // allocate
        nmk = new ArrayList[M];
        nkt = new ArrayList<int[]>();
        nk = new ArrayList<Integer>();
        z = new int[M][];
        for (int m = 0; m < M; m++) {
            nmk[m] = new ArrayList<Integer>();
            for (int k = 0; k < K; k++) {
                nmk[m].add(0);
            }
            z[m] = new int[w[m].length];
        }
        // indexing lists
        kactive = new ArrayList<Integer>();
        // kgaps = new ArrayList<Integer>();
        kgaps = new TreeSet<Integer>();
        // create mean weights
        tau = new ArrayList<Double>();
        for (int k = 0; k < K; k++) {
            kactive.add(k);
            nkt.add(new int[V]);
            nk.add(0);
            // set to value for fixed K
            tau.add(1. / K);
        }
        // tau has one dimension more
        tau.add(1. / K);
        pp = new double[K + ppstep];
        // initialise (run without decrements because z[*][*] = -1)
        run(1);
        if (!fixedK) {
            updateTau();
        }
        inited = true;
    }
    /**
     * initialise Markov chain for querying
     */
    public void initq() {
        // compute parameters
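        // phi is the smoothed mean of the topic-term posterior given the
        // training counts: phi[k][t] = (nkt + beta) / (nk + V * beta)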
        int Kg = K + kgaps.size();
        phi = new double[Kg][V];
        for (int kk = 0; kk < K; kk++) {
            int k = kactive.get(kk);
            for (int t = 0; t < V; t++) {
                phi[k][t] = (nkt.get(k)[t] + beta) / (nk.get(k) + V * beta);
            }
        }
        // allocate
        nmkq = new int[Mq][Kg];
        zq = new int[Mq][];
        Wq = 0;
        // initialise
        for (int m = 0; m < Mq; m++) {
            zq[m] = new int[wq[m].length];
            for (int n = 0; n < wq[m].length; n++) {
                int k = rand.nextInt(K);
                zq[m][n] = k;
                nmkq[m][k]++;
                Wq++;
            }
        }
    }
    /**
     * run Gibbs sampler
     *
     * @param niter
     *            number of Gibbs iterations
     */
    public void run(int niter) {
        for (iter = 0; iter < niter; iter++) {
            System.out.println(iter);
            System.out.println(this);
            for (int m = 0; m < M; m++) {
                for (int n = 0; n < w[m].length; n++) {
                    // sampling z
                    int k, kold = -1;
                    int t = w[m][n];
                    if (inited) {
                        k = z[m][n];
                        // decrement
                        nmk[m].set(k, nmk[m].get(k) - 1);
                        nkt.get(k)[t]--;
                        nk.set(k, nk.get(k) - 1);
                        kold = k;
                    }
                    // compute weights
                    double psum = 0;
                    // (37)
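                    // direct-assignment weights: for each active topic k,
                    // p(z = k | .) propto (nmk + alpha * tau_k)
                    // * (nkt + beta) / (nk + V * beta); below, a new topic u
                    // is proposed with weight alpha * tau_u / V, the
                    // predictive probability of term t under an empty topic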
                    for (int kk = 0; kk < K; kk++) {
                        k = kactive.get(kk);
                        pp[kk] = (nmk[m].get(k) + alpha * tau.get(k)) * //
                                (nkt.get(k)[t] + beta) / (nk.get(k) + V * beta);
                        psum += pp[kk];
                    }
                    // likelihood of new component
                    if (!fixedK) {
                        pp[K] = alpha * tau.get(K) / V;
                        psum += pp[K];
                    }
                    double u = rand.nextDouble();
                    u *= psum;
                    psum = 0;
                    int kk = 0;
                    for (; kk < K + 1; kk++) {
                        psum += pp[kk];
                        if (u <= psum) {
                            break;
                        }
                    }
                    // reassign and increment
                    if (kk < K) {
                        k = kactive.get(kk);
                        z[m][n] = k;
                        nmk[m].set(k, nmk[m].get(k) + 1);
                        nkt.get(k)[t]++;
                        nk.set(k, nk.get(k) + 1);
                    } else {
                        assert (!fixedK);
                        z[m][n] = spawnTopic(m, t);
                        updateTau();
                        System.out.println("K = " + K);
                    }
                    // empty topic?
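                    // a component whose count has dropped to zero is pruned:
                    // its index is recorded as a gap so the count arrays keep
                    // their indexing, and tau is resampled over the reduced
                    // topic set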
                    if (inited && nk.get(kold) == 0) {
                        // remove the object not the index
                        kactive.remove((Integer) kold);
                        kgaps.add(kold);
                        assert (Vectors.sum(nkt.get(kold)) == 0
                                && nk.get(kold) == 0 && nmk[m].get(kold) == 0);
                        K--;
                        System.out.println("K = " + K);
                        updateTau();
                    }
                } // n
            } // m
            if (vis != null) {
                vis.setTopics(nkt);
            }
            if (!fixedK) {
                updateTau();
            }
            if (iter > 10 && !fixedHyper) {
                updateHyper();
            }
        } // i
    }
    /**
     * query Gibbs sampler. This assumes the standard LDA model, as the
     * dimensionality is known from the training set; therefore topics need to
     * be pruned.
     *
     * @param niter
     *            number of Gibbs iterations
     */
    public void runq(int niter) {
        for (int qiter = 0; qiter < niter; qiter++) {
            for (int m = 0; m < nmkq.length; m++) {
                for (int n = 0; n < wq[m].length; n++) {
                    // decrement
                    int k = zq[m][n];
                    int t = wq[m][n];
                    nmkq[m][k]--;
                    // compute weights
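                    // query conditional (standard LDA with phi held fixed):
                    // p(z = k | .) propto (nmkq + alpha) * phi[k][t]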
                    double psum = 0;
                    for (int kk = 0; kk < K; kk++) {
                        pp[kk] = (nmkq[m][kk] + alpha) * phi[kk][t];
                        psum += pp[kk];
                    }
                    // sample
                    double u = rand.nextDouble() * psum;
                    psum = 0;
                    int kk = 0;
                    for (; kk < K; kk++) {
                        psum += pp[kk];
                        if (u <= psum) {
                            break;
                        }
                    }
                    // reassign and increment
                    zq[m][n] = kk;
                    nmkq[m][kk]++;
                } // n
            } // m
        } // i
    }
    /**
     * adds a topic to the list of active topics, either by reusing an existing
     * inactive index (gap) or increasing the count arrays. NB: Within this
     * method, the state is inconsistent.
     *
     * @param m
     *            current document
     * @param t
     *            current term
     * @return index of topic spawned
     */
    private int spawnTopic(int m, int t) {
        int k;
        if (kgaps.size() > 0) {
            // reuse gap
            // k = kgaps.remove(kgaps.size() - 1);
            k = kgaps.first();
            kgaps.remove(k);
            kactive.add(k);
            nmk[m].set(k, 1);
            nkt.get(k)[t] = 1;
            nk.set(k, 1);
        } else {
            // add element to count arrays
            k = K;
            for (int i = 0; i < M; i++) {
                nmk[i].add(0);
            }
            kactive.add(K);
            nmk[m].set(K, 1);
            nkt.add(new int[V]);
            nkt.get(K)[t] = 1;
            nk.add(1);
            tau.add(0.);
        }
        K++;
        if (pp.length <= K) {
            pp = new double[K + ppstep];
        }
        return k;
    }
    /**
     * reorders topics such that no gaps exist in the count arrays and topics
     * are ordered with their counts descending. Removes any gap dimensions.
     */
    public void packTopics() {
        // sort topics by size
        int[] knew2k = IndexQuickSort.sort(nk);
        IndexQuickSort.reverse(knew2k);
        // reorder and weed out empty count arrays
        IndexQuickSort.reorder(nk, knew2k);
        IndexQuickSort.reorder(nkt, knew2k);
        for (int i = 0; i < kgaps.size(); i++) {
            nk.remove(nk.size() - 1);
            nkt.remove(nkt.size() - 1);
        }
        for (int m = 0; m < M; m++) {
            IndexQuickSort.reorder(nmk[m], knew2k);
            for (int i = 0; i < kgaps.size(); i++) {
                nmk[m].remove(nmk[m].size() - 1);
            }
        }
        // any new topics will be appended
        kgaps.clear();
        int[] k2knew = IndexQuickSort.inverse(knew2k);
        // rewrite topic labels
        for (int i = 0; i < K; i++) {
            kactive.set(i, k2knew[kactive.get(i)]);
        }
        for (int m = 0; m < M; m++) {
            for (int n = 0; n < w[m].length; n++) {
                z[m][n] = k2knew[z[m][n]];
            }
        }
    }
    /**
     * update tau, the root DP mixture weights (called after topics have been
     * spawned or pruned).
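     * <p>
     * Sketch of the update, following the direct assignment scheme of Teh et
     * al. (2006): for each active topic k, mk[k] accumulates the table counts
     * across documents, each drawn from the Antoniak distribution (the number
     * of tables a CRP with concentration alpha * tau_k produces for nmk
     * customers); then (tau_1..tau_K, tau_u) ~ Dir(m_1..m_K, gamma).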
     */
    private void updateTau() {
        // (40) sample mk
        double[] mk = new double[K + 1];
        // TODO: average multi-sample?
        for (int kk = 0; kk < K; kk++) {
            int k = kactive.get(kk);
            for (int m = 0; m < M; m++) {
                if (nmk[m].get(k) > 1) {
                    // number of tables a CRP(alpha tau) produces for nmk items
                    mk[kk] += samp.randAntoniak(alpha * tau.get(k), //
                            nmk[m].get(k));
                } else {
                    mk[kk] += nmk[m].get(k);
                }
            }
        }
        // number of tables
        T = Vectors.sum(mk);
        mk[K] = gamma;
        // (36) sample tau
        double[] tt = samp.randDir(mk);
        for (int kk = 0; kk < K; kk++) {
            int k = kactive.get(kk);
            tau.set(k, tt[kk]);
        }
        tau.set(K, tt[K]);
    }
    /**
     * update scalar DP hyperparameters alpha, gamma and Dirichlet
     * hyperparameter beta. Assumes that T is updated (by updateTau).
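     * <p>
     * Sketch of the samplers used below: gamma is resampled with the auxiliary
     * variable scheme of Escobar and West (1995), with the total number of
     * tables T in the role of the sample size; alpha is resampled with the
     * auxiliary variable scheme of Teh et al. (2006) for the document-level
     * concentration; beta is re-estimated as a MAP estimate of the symmetric
     * Dirichlet parameter via DirichletEstimation.estimateAlphaMap.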
     */
    private void updateHyper() {
        for (int r = 0; r < R; r++) {
            // gamma: root level (Escobar+West95) with n = T
            // (14)
            double eta = samp.randBeta(gamma + 1, T);
            double bloge = bgamma - log(eta);
            // (13')
            // double pie = 1. / (1. - (T * bloge / (gamma + K - 1)));
            double pie = 1. / (1. + (T * bloge / (agamma + K - 1)));
            // (13)
            int u = samp.randBernoulli(pie);
            gamma = samp.randGamma(agamma + K - 1 + u, 1. / bloge);
            // alpha: document level (Teh+06)
            double qs = 0;
            double qw = 0;
            for (int m = 0; m < M; m++) {
                // (49) (corrected)
                qs += samp.randBernoulli(w[m].length / (w[m].length + alpha));
                // (48)
                qw += log(samp.randBeta(alpha + 1, w[m].length));
            }
            // (47)
            alpha = samp.randGamma(aalpha + T - qs, 1. / (balpha - qw));
        }
        int[] ak = (int[]) ArrayUtils.asPrimitiveArray(nk);
        int[][] akt = new int[K][V];
        for (int k = 0; k < K; k++) {
            akt[k] = nkt.get(k);
        }
        beta = DirichletEstimation
                .estimateAlphaMap(akt, ak, beta, abeta, bbeta);
    }
    /**
     * @return the perplexity of the last query sample.
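     *         This is exp(-1/Wq * sum_{m,n} log sum_k thetaq[m][k] *
     *         phi[k][wq[m][n]]), the exponentiated negative average per-token
     *         log likelihood of the held-out words.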
     */
    public double ppx() {
        // TODO: this uses LDA's perplexity --> add hyperparameters and DP stuff
        double loglik = 0;
        // compute thetaq
        double[][] thetaq = new double[Mq][K];
        for (int m = 0; m < Mq; m++) {
            for (int k = 0; k < K; k++) {
                thetaq[m][k] = (nmkq[m][k] + alpha)
                        / (wq[m].length + K * alpha);
            }
        }
        // compute ppx
        for (int m = 0; m < Mq; m++) {
            for (int n = 0; n < wq[m].length; n++) {
                double sum = 0;
                for (int k = 0; k < K; k++) {
                    sum += thetaq[m][k] * phi[k][wq[m][n]];
                }
                loglik += Math.log(sum);
            }
        }
        return Math.exp(-loglik / Wq);
    }
    // ////////////// output routines ////////////////

    /**
     * assemble a string of overview information.
     */
    @Override
    public String toString() {
        return String.format("ILDA: M = %d, K = %d, V = %d, "
                + "alpha = %2.5f, beta = %2.5f, gamma = %2.5f", //
                M, K, V, alpha, beta, gamma);
    }
}