RNN Masks not applying
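Self-contained DL4J repro: trains a small ComputationGraph on randomly generated time-series data and prints per-epoch regression stats for three scenarios (nothing masked, everything masked, zeroed data plus full masks), so the effect of the feature and label masks, if any, shows up directly in the scores.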
package com.IceKontroI.AI;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.AdaGrad;
import java.util.ArrayDeque;
import java.util.Random;
import static org.deeplearning4j.nn.api.OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
import static org.deeplearning4j.nn.weights.WeightInit.XAVIER_UNIFORM;
import static org.nd4j.linalg.activations.Activation.IDENTITY;
import static org.nd4j.linalg.activations.Activation.TANH;
import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE;

public class RecurrentMaskTestBench {

    @SuppressWarnings("ConstantConditions")
    public static void main(String[] args) {
        Nd4j.setDataType(DataBuffer.Type.FLOAT);
        int print = 2;
        int epochs = 5;
        int size = 1;
        int timesteps = 24;

        System.out.println("Testing with normal RNN data and nothing masked out");
        System.out.println();
        test(print, epochs, false, false, false, false, size, timesteps);

        System.out.println("Testing with normal RNN data, but everything masked out");
        System.out.println();
        test(print, epochs, false, true, false, false, size, timesteps);

        System.out.println("Testing with RNN data values of all 0s, and everything masked out");
        System.out.println();
        test(print, epochs, true, true, false, false, size, timesteps);
    }
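
    // One full train/evaluate run on freshly generated random data. The *0 flags
    // zero out the corresponding data or mask arrays, so any effect of masking
    // shows up in the per-epoch regression stats printed below.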
    public static void test(int print, int epochs, boolean inputData0, boolean inputMask0, boolean labelData0, boolean labelMask0, int size, int timesteps) {
        DataConfig[] inputs = { // FF or RNN shape auto-detected ({24} would be FF, {1, 24} would be RNN)
                DataConfig.makeInput(new Random(123), 32, inputData0, inputMask0, size, timesteps)
        };
        DataConfig[] labels = {
                DataConfig.makeLabel(new Random(456789), labelData0, labelMask0, 1),
                // DataConfig.makeLabel(new Random(987654), labelData0, labelMask0, 2) // Can adjust these to achieve more complex architectures as desired
        };
        RandomDataIterator RDI = new RandomDataIterator(inputs, labels, 10, 1);
        if (print > 0) {
            int i = 0;
            while (i++ < print && RDI.hasNext()) {
                System.out.println(RDI.next());
            }
            System.out.println();
            RDI.reset();
        }
        ComputationGraph model = generateModel(inputs, labels, 4, 4);
        for (int a = 0; a < epochs; a++) {
            model.fit(RDI);
            RDI.reset();
            RegressionEvaluation[] evals = new RegressionEvaluation[labels.length];
            for (int b = 0; b < evals.length; b++) {
                evals[b] = new RegressionEvaluation();
            }
            while (RDI.hasNext()) {
                MultiDataSet next = RDI.next();
                // Forward pass with feature masks; the fourth argument of output(...) takes the label mask arrays
                INDArray[] output = model.output(false, next.getFeatures(), next.getFeaturesMaskArrays(), next.getLabelsMaskArrays());
                for (int b = 0; b < output.length; b++) {
                    evals[b].eval(next.getLabels(b), output[b], next.getLabelsMaskArray(b));
                }
            }
            RDI.reset();
            System.out.println("EPOCH: " + a);
            for (RegressionEvaluation eval : evals) {
                System.out.println(eval.stats());
            }
        }
    }

    public static class Example {
        public Data[] inputs;
        public Data[] labels;

        public Example(Data[] inputs, Data[] labels) {
            this.inputs = inputs;
            this.labels = labels;
        }
    }

    public static class Data {
        public INDArray data;
        public INDArray mask;

        public Data(INDArray data, INDArray mask) {
            this.data = data;
            this.mask = mask;
        }
    }
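
    // Describes how to generate one random data/mask pair: inputs get a leading
    // batch dimension of 1 prepended to the supplied shape, labels are 2D [1, size].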
    public static class DataConfig {
        public Random random;
        public int density;
        public boolean setData0;
        public boolean setMask0;
        public int[] shape;

        public static DataConfig makeInput(Random random, int density, boolean setData0, boolean setMask0, int... shape) {
            DataConfig c = new DataConfig();
            c.random = random;
            c.density = density;
            c.setData0 = setData0;
            c.setMask0 = setMask0;
            c.shape = new int[shape.length + 1];
            System.arraycopy(shape, 0, c.shape, 1, shape.length);
            c.shape[0] = 1;
            return c;
        }

        public static DataConfig makeLabel(Random random, boolean setData0, boolean setMask0, int size) {
            DataConfig c = new DataConfig();
            c.random = random;
            c.setData0 = setData0;
            c.setMask0 = setMask0;
            c.shape = new int[]{1, size};
            return c;
        }
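
        // DL4J applies time-series feature masks of shape [miniBatch, timeSteps]
        // (one flag per example per step). generate() below builds masks with the
        // full data shape instead; a per-timestep mask for the 3D input case would
        // look roughly like this sketch (not part of the original bench):
        //     INDArray mask = setMask0 ? Nd4j.zeros(1, shape[2]) : Nd4j.ones(1, shape[2]);
        // which may be why the masks here appear not to apply.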
        public Data generate() {
            INDArray data;
            if (setData0) {
                data = Nd4j.zeros(shape);
            } else {
                data = Nd4j.rand(shape, random.nextInt(Integer.MAX_VALUE)).subi(0.5).muli(2); // To get in range [-1, 1]
            }
            INDArray mask = setMask0 ? Nd4j.zeros(shape) : Nd4j.ones(shape);
            return new Data(data, mask);
        }
    }
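
    // MultiDataSetIterator over a queue of randomly generated Examples. reset()
    // refills the queue with fresh draws (the Random streams keep advancing), so
    // each epoch trains on newly generated data rather than a fixed dataset.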
    public static class RandomDataIterator implements MultiDataSetIterator {
        public DataConfig[] inputs;
        public DataConfig[] labels;
        public ArrayDeque<Example> examples;
        public int totalNumExamples;
        public int miniBatchSize;

        public RandomDataIterator(DataConfig[] inputs, DataConfig[] labels, int totalNumExamples, int miniBatchSize) {
            this.inputs = inputs;
            this.labels = labels;
            this.examples = new ArrayDeque<>(totalNumExamples);
            this.totalNumExamples = totalNumExamples;
            this.miniBatchSize = miniBatchSize;
            this.reset();
        }

        @Override
        public boolean hasNext() {
            // Doesn't care about returning exact batches
            return hasNext(1);
        }

        @Override
        public MultiDataSet next() {
            return next(miniBatchSize);
        }

        public boolean hasNext(int batch) {
            return examples.size() >= batch;
        }

        @Override
        public MultiDataSet next(int batch) {
            // Might need to produce a smaller than requested batch
            batch = Math.min(batch, examples.size());
            if (batch == 0) {
                throw new IllegalStateException("Ran out of Examples");
            }
            // Stack several Examples into a batch
            INDArray[][] inputData = new INDArray[numInputs()][batch];
            INDArray[][] inputMask = new INDArray[numInputs()][batch];
            INDArray[][] labelData = new INDArray[numLabels()][batch];
            INDArray[][] labelMask = new INDArray[numLabels()][batch];
            for (int a = 0; a < batch; a++) {
                Example e = examples.pop();
                for (int b = 0; b < numInputs(); b++) {
                    inputData[b][a] = e.inputs[b].data;
                    inputMask[b][a] = e.inputs[b].mask;
                }
                for (int b = 0; b < numLabels(); b++) {
                    labelData[b][a] = e.labels[b].data;
                    labelMask[b][a] = e.labels[b].mask;
                }
            }
            INDArray[] inputDataStack = new INDArray[numInputs()];
            INDArray[] inputMaskStack = new INDArray[numInputs()];
            for (int a = 0; a < numInputs(); a++) {
                inputDataStack[a] = Nd4j.vstack(inputData[a]);
                inputMaskStack[a] = Nd4j.vstack(inputMask[a]);
            }
            INDArray[] labelDataStack = new INDArray[numLabels()];
            INDArray[] labelMaskStack = new INDArray[numLabels()];
            for (int a = 0; a < numLabels(); a++) {
                labelDataStack[a] = Nd4j.vstack(labelData[a]);
                labelMaskStack[a] = Nd4j.vstack(labelMask[a]);
            }
            return new org.nd4j.linalg.dataset.MultiDataSet(inputDataStack, labelDataStack, inputMaskStack, labelMaskStack);
        }

        @Override
        public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
            throw new UnsupportedOperationException("MultiDataSetPreProcessor not supported");
        }

        @Override
        public MultiDataSetPreProcessor getPreProcessor() {
            throw new UnsupportedOperationException("MultiDataSetPreProcessor not supported");
        }

        @Override
        public void reset() {
            examples.clear();
            for (int a = 0; a < totalNumExamples; a++) {
                Data[] inputs = new Data[this.inputs.length];
                for (int b = 0; b < this.inputs.length; b++) {
                    inputs[b] = this.inputs[b].generate();
                }
                Data[] labels = new Data[this.labels.length];
                for (int b = 0; b < this.labels.length; b++) {
                    labels[b] = this.labels[b].generate();
                }
                examples.add(new Example(inputs, labels));
            }
        }

        public int numInputs() {
            return inputs.length;
        }

        public int numLabels() {
            return labels.length;
        }

        @Override
        public boolean resetSupported() {
            return true;
        }

        @Override
        public boolean asyncSupported() {
            return false;
        }
    }
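
    // Builds the graph: one LastTimeStep-wrapped LSTM per input (so each RNN
    // branch emits only its final timestep), merged into a tapering stack of
    // dense layers per output head, ending in an MSE regression output.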
    public static ComputationGraph generateModel(DataConfig[] inputs, DataConfig[] labels, int hiddenLayers, int preOutputDensity) {
        String[] inputNames = new String[inputs.length];
        InputType[] inputTypes = new InputType[inputs.length];
        for (int a = 0; a < inputs.length; a++) {
            // All set elements are merged by now, and have same dimensions
            int[] shape = inputs[a].shape;
            inputNames[a] = "Input " + a;
            inputTypes[a] = InputType.recurrent(shape[1], shape[2]); // [batch, size, timesteps]
        }
        ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(STOCHASTIC_GRADIENT_DESCENT)
                .weightInit(XAVIER_UNIFORM)
                .activation(TANH)
                .updater(new AdaGrad(0.01))
                .l2(0.0001)
                .seed(1234)
                .graphBuilder()
                .addInputs(inputNames)
                .setInputTypes(inputTypes);
        String[] layersToMerge = new String[inputs.length];
        int mergeSize = 0;
        for (int a = 0; a < inputs.length; a++) {
            String thisLayer = "RNN " + inputNames[a];
            long nOut = inputs[a].shape[1] * inputs[a].density; // Density is a nOut multiplier
            builder.addLayer(thisLayer, new LastTimeStep(new LSTM.Builder()
                    .nOut(nOut)
                    .build()), inputNames[a]);
            System.out.println(thisLayer + ": nOut = " + nOut);
            layersToMerge[a] = thisLayer;
            mergeSize += nOut;
        }
        String[] outputs = new String[labels.length];
        for (int a = 0; a < labels.length; a++) {
            String[] lastLayer = layersToMerge;
            int outputSize = labels[a].shape[labels[a].shape.length - 1];
            int reduction = Math.max(1, mergeSize / hiddenLayers);
            long nIn = mergeSize;
            for (int b = 0; b < hiddenLayers; b++) {
                long nOut = nIn - reduction;
                if (nOut < Math.max(preOutputDensity, outputSize)) {
                    break;
                }
                String thisLayer = "Dense Hidden " + a + "-" + b;
                builder.addLayer(thisLayer, new DenseLayer.Builder()
                        .nIn(nIn)
                        .nOut(nOut)
                        .build(), lastLayer);
                System.out.println(thisLayer + ": nIn = " + nIn + " nOut = " + nOut);
                lastLayer = new String[]{thisLayer};
                nIn = nOut;
            }
            String thisLayer = "Output " + a;
            builder.addLayer(thisLayer, new OutputLayer.Builder()
                    .activation(IDENTITY)
                    .lossFunction(MSE)
                    .nOut(outputSize)
                    .build(), lastLayer);
            System.out.println(thisLayer + ": nOut = " + outputSize);
            outputs[a] = thisLayer;
        }
        System.out.println();
        builder.setOutputs(outputs);
        ComputationGraph CG = new ComputationGraph(builder.build());
        CG.init();
        return CG;
    }
}
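
For reference, a minimal standalone sketch of the [miniBatch, timeSteps] feature-mask convention referenced in the comments above. The class name and all shapes here are illustrative, not part of the original bench:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class MaskShapeSketch {
    public static void main(String[] args) {
        INDArray features    = Nd4j.rand(new int[]{1, 1, 24}); // [miniBatch, size, timeSteps]
        INDArray labels      = Nd4j.rand(new int[]{1, 1});     // [miniBatch, size]
        INDArray featureMask = Nd4j.ones(1, 24);               // [miniBatch, timeSteps], one flag per step
        // Mask out the last 8 of the 24 timesteps for example 0
        featureMask.get(NDArrayIndex.point(0), NDArrayIndex.interval(16, 24)).assign(0);
        DataSet ds = new DataSet(features, labels, featureMask, null); // no label mask
        System.out.println(ds.getFeaturesMaskArray());
    }
}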