RNN Masks not applying
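A self-contained DL4J test bench for checking whether RNN feature/label masks are actually applied. It generates random time-series data with configurable all-zero data and all-zero masks, trains a small ComputationGraph (an LSTM wrapped in LastTimeStep, followed by dense layers), and prints regression stats for three scenarios: nothing masked, everything masked, and all-zero data with everything masked. If masking were applied, the fully masked runs should show no learning.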
package com.IceKontroI.AI;

import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

import java.util.ArrayDeque;
import java.util.Random;

import static org.deeplearning4j.nn.api.OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
import static org.deeplearning4j.nn.weights.WeightInit.XAVIER_UNIFORM;
import static org.nd4j.linalg.activations.Activation.IDENTITY;
import static org.nd4j.linalg.activations.Activation.TANH;
import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE;
public class RecurrentMaskTestBench {

    @SuppressWarnings("ConstantConditions")
    public static void main(String[] args) {
        Nd4j.setDataType(DataBuffer.Type.FLOAT);
        int print = 2;
        int epochs = 5;
        int size = 1;
        int timesteps = 24;
        System.out.println("Testing with normal RNN data and nothing masked out");
        System.out.println();
        test(print, epochs, false, false, false, false, size, timesteps);
        System.out.println("Testing with normal RNN data, but everything masked out");
        System.out.println();
        test(print, epochs, false, true, false, false, size, timesteps);
        System.out.println("Testing with RNN data values of all 0s, and everything masked out");
        System.out.println();
        test(print, epochs, true, true, false, false, size, timesteps);
    }
    public static void test(int print, int epochs, boolean inputData0, boolean inputMask0, boolean labelData0, boolean labelMask0, int size, int timesteps) {
        DataConfig[] inputs = { // FF or RNN shape auto-detected ({24} would be FF, {1, 24} would be RNN)
                DataConfig.makeInput(new Random(123), 32, inputData0, inputMask0, size, timesteps)
        };
        DataConfig[] labels = {
                DataConfig.makeLabel(new Random(456789), labelData0, labelMask0, 1),
                // DataConfig.makeLabel(new Random(987654), labelData0, labelMask0, 2) // Can adjust these to achieve more complex architectures as desired
        };
        RandomDataIterator RDI = new RandomDataIterator(inputs, labels, 10, 1);
        if (print > 0) {
            int i = 0;
            while (i++ < print && RDI.hasNext()) {
                System.out.println(RDI.next());
            }
            System.out.println();
            RDI.reset();
        }
        ComputationGraph model = generateModel(inputs, labels, 4, 4);
        for (int a = 0; a < epochs; a++) {
            model.fit(RDI);
            RDI.reset();
            RegressionEvaluation[] evals = new RegressionEvaluation[labels.length];
            for (int b = 0; b < evals.length; b++) {
                evals[b] = new RegressionEvaluation();
            }
            while (RDI.hasNext()) {
                MultiDataSet next = RDI.next();
                // Signature is output(train, inputs, featureMasks, labelMasks)
                INDArray[] output = model.output(false, next.getFeatures(), next.getFeaturesMaskArrays(), next.getLabelsMaskArrays());
                for (int b = 0; b < output.length; b++) {
                    evals[b].eval(next.getLabels(b), output[b], next.getLabelsMaskArray(b));
                }
            }
            RDI.reset();
            System.out.println("EPOCH: " + a);
            for (RegressionEvaluation eval : evals) {
                System.out.println(eval.stats());
            }
        }
    }
    public static class Example {
        public Data[] inputs;
        public Data[] labels;
        public Example(Data[] inputs, Data[] labels) {
            this.inputs = inputs;
            this.labels = labels;
        }
    }

    public static class Data {
        public INDArray data;
        public INDArray mask;
        public Data(INDArray data, INDArray mask) {
            this.data = data;
            this.mask = mask;
        }
    }
    public static class DataConfig {
        public Random random;
        public int density;
        public boolean setData0;
        public boolean setMask0;
        public int[] shape;

        public static DataConfig makeInput(Random random, int density, boolean setData0, boolean setMask0, int... shape) {
            DataConfig c = new DataConfig();
            c.random = random;
            c.density = density;
            c.setData0 = setData0;
            c.setMask0 = setMask0;
            c.shape = new int[shape.length + 1];
            System.arraycopy(shape, 0, c.shape, 1, shape.length);
            c.shape[0] = 1; // prepend a singleton batch dimension
            return c;
        }

        public static DataConfig makeLabel(Random random, boolean setData0, boolean setMask0, int size) {
            DataConfig c = new DataConfig();
            c.random = random;
            c.setData0 = setData0;
            c.setMask0 = setMask0;
            c.shape = new int[]{1, size};
            return c;
        }

        public Data generate() {
            INDArray data;
            if (setData0) {
                data = Nd4j.zeros(shape);
            } else {
                data = Nd4j.rand(shape, random.nextInt(Integer.MAX_VALUE)).subi(0.5).muli(2); // To get in range [-1, 1]
            }
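            // Note: the masks below copy the data's shape; DL4J expects 2d RNN feature
            // masks of shape [miniBatchSize, timeSteps] (see the note after the class)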
            INDArray mask = setMask0 ? Nd4j.zeros(shape) : Nd4j.ones(shape);
            return new Data(data, mask);
        }
    }
    public static class RandomDataIterator implements MultiDataSetIterator {
        public DataConfig[] inputs;
        public DataConfig[] labels;
        public ArrayDeque<Example> examples;
        public int totalNumExamples;
        public int miniBatchSize;

        public RandomDataIterator(DataConfig[] inputs, DataConfig[] labels, int totalNumExamples, int miniBatchSize) {
            this.inputs = inputs;
            this.labels = labels;
            this.examples = new ArrayDeque<>(totalNumExamples);
            this.totalNumExamples = totalNumExamples;
            this.miniBatchSize = miniBatchSize;
            this.reset();
        }

        @Override
        public boolean hasNext() {
            // Doesn't care about returning exact batches
            return hasNext(1);
        }

        @Override
        public MultiDataSet next() {
            return next(miniBatchSize);
        }

        public boolean hasNext(int batch) {
            return examples.size() >= batch;
        }

        @Override
        public MultiDataSet next(int batch) {
            // May need to produce a smaller batch than requested
            batch = Math.min(batch, examples.size());
            if (batch == 0) {
                throw new IllegalStateException("Ran out of Examples");
            }
            // Stack several Examples into a batch
            INDArray[][] inputData = new INDArray[numInputs()][batch];
            INDArray[][] inputMask = new INDArray[numInputs()][batch];
            INDArray[][] labelData = new INDArray[numLabels()][batch];
            INDArray[][] labelMask = new INDArray[numLabels()][batch];
            for (int a = 0; a < batch; a++) {
                Example e = examples.pop();
                for (int b = 0; b < numInputs(); b++) {
                    inputData[b][a] = e.inputs[b].data;
                    inputMask[b][a] = e.inputs[b].mask;
                }
                for (int b = 0; b < numLabels(); b++) {
                    labelData[b][a] = e.labels[b].data;
                    labelMask[b][a] = e.labels[b].mask;
                }
            }
            INDArray[] inputDataStack = new INDArray[numInputs()];
            INDArray[] inputMaskStack = new INDArray[numInputs()];
            for (int a = 0; a < numInputs(); a++) {
                inputDataStack[a] = Nd4j.vstack(inputData[a]);
                inputMaskStack[a] = Nd4j.vstack(inputMask[a]);
            }
            INDArray[] labelDataStack = new INDArray[numLabels()];
            INDArray[] labelMaskStack = new INDArray[numLabels()];
            for (int a = 0; a < numLabels(); a++) {
                labelDataStack[a] = Nd4j.vstack(labelData[a]);
                labelMaskStack[a] = Nd4j.vstack(labelMask[a]);
            }
            return new org.nd4j.linalg.dataset.MultiDataSet(inputDataStack, labelDataStack, inputMaskStack, labelMaskStack);
        }

        @Override
        public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
            throw new UnsupportedOperationException("MultiDataSetPreProcessor not supported");
        }

        @Override
        public MultiDataSetPreProcessor getPreProcessor() {
            throw new UnsupportedOperationException("MultiDataSetPreProcessor not supported");
        }

        @Override
        public void reset() {
            examples.clear();
            for (int a = 0; a < totalNumExamples; a++) {
                Data[] inputs = new Data[this.inputs.length];
                for (int b = 0; b < this.inputs.length; b++) {
                    inputs[b] = this.inputs[b].generate();
                }
                Data[] labels = new Data[this.labels.length];
                for (int b = 0; b < this.labels.length; b++) {
                    labels[b] = this.labels[b].generate();
                }
                examples.add(new Example(inputs, labels));
            }
        }

        public int numInputs() {
            return inputs.length;
        }

        public int numLabels() {
            return labels.length;
        }

        @Override
        public boolean resetSupported() {
            return true;
        }

        @Override
        public boolean asyncSupported() {
            return false;
        }
    }
    public static ComputationGraph generateModel(DataConfig[] inputs, DataConfig[] labels, int hiddenLayers, int preOutputDensity) {
        String[] inputNames = new String[inputs.length];
        InputType[] inputTypes = new InputType[inputs.length];
        for (int a = 0; a < inputs.length; a++) {
            // All set elements are merged by now, and have the same dimensions
            int[] shape = inputs[a].shape;
            inputNames[a] = "Input " + a;
            inputTypes[a] = InputType.recurrent(shape[1], shape[2]); // shape is [batch, size, timesteps]
        }
        ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(STOCHASTIC_GRADIENT_DESCENT)
                .weightInit(XAVIER_UNIFORM)
                .activation(TANH)
                .updater(new AdaGrad(0.01))
                .l2(0.0001)
                .seed(1234)
                .graphBuilder()
                .addInputs(inputNames)
                .setInputTypes(inputTypes);
        String[] layersToMerge = new String[inputs.length];
        int mergeSize = 0;
        for (int a = 0; a < inputs.length; a++) {
            String thisLayer = "RNN " + inputNames[a];
            long nOut = inputs[a].shape[1] * inputs[a].density; // Density is an nOut multiplier
            builder.addLayer(thisLayer, new LastTimeStep(new LSTM.Builder()
                    .nOut(nOut)
                    .build()), inputNames[a]);
            System.out.println(thisLayer + ": nOut = " + nOut);
            layersToMerge[a] = thisLayer;
            mergeSize += nOut;
        }
        String[] outputs = new String[labels.length];
        for (int a = 0; a < labels.length; a++) {
            String[] lastLayer = layersToMerge;
            int outputSize = labels[a].shape[labels[a].shape.length - 1];
            int reduction = Math.max(1, mergeSize / hiddenLayers);
            long nIn = mergeSize;
            for (int b = 0; b < hiddenLayers; b++) {
                long nOut = nIn - reduction;
                if (nOut < Math.max(preOutputDensity, outputSize)) {
                    break;
                }
                String thisLayer = "Dense Hidden " + a + "-" + b;
                builder.addLayer(thisLayer, new DenseLayer.Builder()
                        .nIn(nIn)
                        .nOut(nOut)
                        .build(), lastLayer);
                System.out.println(thisLayer + ": nIn = " + nIn + " nOut = " + nOut);
                lastLayer = new String[]{thisLayer};
                nIn = nOut;
            }
            String thisLayer = "Output " + a;
            builder.addLayer(thisLayer, new OutputLayer.Builder()
                    .activation(IDENTITY)
                    .lossFunction(MSE)
                    .nOut(outputSize)
                    .build(), lastLayer);
            System.out.println(thisLayer + ": nOut = " + outputSize);
            outputs[a] = thisLayer;
        }
        System.out.println();
        builder.setOutputs(outputs);
        ComputationGraph CG = new ComputationGraph(builder.build());
        CG.init();
        return CG;
    }
}
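A note on mask shapes, which may be the culprit here: DL4J expects RNN feature masks to be 2d arrays of shape [miniBatchSize, timeSteps], with one 1/0 entry per time step, whereas DataConfig.generate() above emits masks with the same shape as the data (3d for the RNN inputs). A minimal sketch of a correctly shaped per-timestep mask, assuming the bench's size = 1 and timesteps = 24 (the class name MaskShapeSketch is just for illustration):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class MaskShapeSketch {
    public static void main(String[] args) {
        // DL4J RNN feature masks: 2d, [miniBatchSize, timeSteps], entries 1 (use) or 0 (skip)
        INDArray featureMask = Nd4j.ones(1, 24);      // one example, all 24 steps active
        featureMask.putScalar(new int[]{0, 23}, 0.0); // mask out the final time step of example 0
        System.out.println(featureMask);
    }
}

Label masks for the feedforward OutputLayer (after LastTimeStep collapses the time dimension) can stay 2d per example or per output column, so the {1, size} label masks above are a separate case from the feature masks.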