Skip to content

Instantly share code, notes, and snippets.

@saudet
Created February 17, 2017 14:19
Show Gist options
  • Save saudet/be8c695b447e7f0fc296e2347ec596b1 to your computer and use it in GitHub Desktop.
Save saudet/be8c695b447e7f0fc296e2347ec596b1 to your computer and use it in GitHub Desktop.
Sample code for t-SNE visualization with T-SNE-Java or Deeplearning4j
import java.io.File;
import java.util.Arrays;
import java.util.Random;
import javax.swing.JFrame;
import org.math.plot.FrameView;
import org.math.plot.Plot2DPanel;
import org.math.plot.PlotPanel;
import org.math.plot.plots.ColoredScatterPlot;
import org.math.plot.plots.ScatterPlot;
import com.jujutsu.tsne.barneshut.BHTSne;
import com.jujutsu.tsne.barneshut.BarnesHutTSne;
import com.jujutsu.tsne.barneshut.ParallelBHTsne;
import com.jujutsu.utils.MatrixOps;
import com.jujutsu.utils.MatrixUtils;
import org.deeplearning4j.plot.BarnesHutTsne;
import org.nd4j.linalg.factory.Nd4j;
/**
 * Sample program that runs t-SNE on random high-dimensional data.
 *
 * <p>By default the embedding is computed with Deeplearning4j's {@link BarnesHutTsne} and handed
 * to its {@code plot} method with the output path {@code plot.csv}. Pass {@code tsne-java} as the
 * first command-line argument to use T-SNE-Java instead. That path shows the 2-D result in a
 * jmathplot Swing window; closing the window exits the JVM.
 */
public class TSnePlot {
    /** Number of random sample points (rows) to embed. */
    private static final int NUM_POINTS = 50000;
    /** Dimensionality (columns) of each random sample point. */
    private static final int NUM_FEATURES = 100;
    /** Dimensionality of the t-SNE output embedding. */
    private static final int OUTPUT_DIMS = 2;
    /**
     * Passed to T-SNE-Java as {@code initial_dims}; -1 is carried over from the original sample.
     * NOTE(review): presumably only relevant when PCA pre-reduction is enabled — confirm.
     */
    private static final int INITIAL_DIMS = -1;
    private static final double PERPLEXITY = 20.0;
    private static final int ITERATIONS = 100;
    /** Barnes-Hut approximation parameter used by both backends. */
    private static final double THETA = 0.5;
    /** Fixed seed so repeated runs embed the same random data. */
    private static final long SEED = 42L;
    /**
     * Random data has no classes, so every point gets this label. The original array was never
     * filled, which sent nulls to the plotting code.
     */
    private static final String PLACEHOLDER_LABEL = "random";
    /** Output path handed to DL4J's {@code plot}. */
    private static final String OUTPUT_CSV = "plot.csv";
    /** First command-line argument that selects the T-SNE-Java backend. */
    private static final String TSNE_JAVA_FLAG = "tsne-java";

    /**
     * Entry point.
     *
     * @param args optional; if {@code args[0]} is {@code "tsne-java"}, the T-SNE-Java backend is
     *     used, otherwise Deeplearning4j (the original sample's active path)
     * @throws Exception propagated from the underlying t-SNE / plotting libraries
     */
    public static void main(String[] args) throws Exception {
        double[][] data = randomMatrix(NUM_POINTS, NUM_FEATURES, new Random(SEED));
        String[] labels = new String[NUM_POINTS];
        Arrays.fill(labels, PLACEHOLDER_LABEL);

        // Replaces the original `if (false)`, which made the T-SNE-Java branch unreachable.
        boolean useTsneJava = args.length > 0 && TSNE_JAVA_FLAG.equals(args[0]);
        if (useTsneJava) {
            plotWithTsneJava(data, labels, true);
        } else {
            plotWithDl4j(data, labels);
        }
    }

    /** Builds a rows x cols matrix of values drawn from {@code rng.nextDouble()} in [0, 1). */
    private static double[][] randomMatrix(int rows, int cols, Random rng) {
        double[][] matrix = new double[rows][cols];
        for (double[] row : matrix) {
            for (int j = 0; j < cols; j++) {
                row[j] = rng.nextDouble();
            }
        }
        return matrix;
    }

    /**
     * Embeds {@code data} with T-SNE-Java and shows the result in a Swing window.
     *
     * @param data input points, one per row
     * @param labels one label per row, or {@code null} to draw an unlabeled scatter plot
     * @param parallel whether to use {@link ParallelBHTsne} rather than {@link BHTSne}
     */
    private static void plotWithTsneJava(double[][] data, String[] labels, boolean parallel)
            throws Exception {
        System.out.println(MatrixOps.doubleArrayToPrintString(data, ", ", 50, 10));
        BarnesHutTSne tsne;
        if (parallel) {
            tsne = new ParallelBHTsne();
        } else {
            tsne = new BHTSne();
        }
        // The `false` argument is passed positionally as in the original sample.
        // NOTE(review): presumably the use-PCA flag — confirm against T-SNE-Java's signature.
        double[][] embedding =
                tsne.tsne(data, OUTPUT_DIMS, INITIAL_DIMS, PERPLEXITY, ITERATIONS, false, THETA);

        Plot2DPanel plot = new Plot2DPanel();
        if (labels != null) {
            ColoredScatterPlot labeledPlot = new ColoredScatterPlot("TSne Result", embedding, labels);
            plot.plotCanvas.addPlot(labeledPlot);
        } else {
            ScatterPlot dataPlot = new ScatterPlot("Data", PlotPanel.COLORLIST[0], embedding);
            plot.plotCanvas.addPlot(dataPlot);
        }
        plot.plotCanvas.setNotable(true);
        plot.plotCanvas.setNoteCoords(true);

        FrameView plotFrame = new FrameView(plot);
        plotFrame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        plotFrame.setVisible(true);
    }

    /**
     * Embeds {@code data} with Deeplearning4j's Barnes-Hut t-SNE and passes the result to
     * {@code BarnesHutTsne.plot} with output path {@link #OUTPUT_CSV}.
     *
     * @param data input points, one per row
     * @param labels one label per row
     */
    private static void plotWithDl4j(double[][] data, String[] labels) throws Exception {
        // The original sample had `.usePca(false)` commented out, so the builder default is used.
        BarnesHutTsne tsne = new BarnesHutTsne.Builder()
                .setMaxIter(ITERATIONS)
                .perplexity(PERPLEXITY)
                .theta(THETA)
                .normalize(false)
                .build();
        tsne.plot(Nd4j.create(data), OUTPUT_DIMS, Arrays.asList(labels), OUTPUT_CSV);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment