1..100 deeplearning4j RNN example
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
/**
 * Created by pedro, brought up to date with dl4j 1.0 by jumpingfella.
 * An example of a recurrent neural network applied to a regression problem.
 */
public class RegressionRNN {

    //Random number generator seed, for reproducibility
    public static final int seed = 12345;
    //Number of epochs (full passes of the data)
    public static final int nEpochs = 3000;
    //Number of data points
    public static final int nSamples = 25;
    //Network learning rate
    public static final double learningRate = 0.0001;
    public static void main(String[] args) {
        //Generate the training data
        DataSet trainingData = getTrainingData();
        trainingData.shuffle();
        System.out.println(trainingData);
        System.out.println();

        DataSet testData = getTestData();
        System.out.println(testData);

        //Create the network
        int numInput = 1;
        int numOutputs = 1;
        int nHidden = 30;

        NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
        builder.seed(seed);
        builder.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
        builder.updater(new RmsProp(learningRate));
        builder.gradientNormalization(GradientNormalization.ClipL2PerLayer);
        builder.gradientNormalizationThreshold(0.00001);

        NeuralNetConfiguration.ListBuilder listBuilder = builder.list();
        listBuilder.layer(0, new GravesLSTM.Builder().nIn(numInput).nOut(nHidden)
                .activation(Activation.TANH).l2(0.0001).weightInit(WeightInit.XAVIER)
                .build());
        listBuilder.layer(1, new GravesLSTM.Builder().nIn(nHidden).nOut(nHidden)
                .activation(Activation.TANH).l2(0.0001).weightInit(WeightInit.XAVIER)
                .build());
        listBuilder.layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                .activation(Activation.IDENTITY).l2(0.0001).weightInit(WeightInit.XAVIER)
                .nIn(nHidden).nOut(numOutputs).build());
        listBuilder.pretrain(false).backprop(true);

        MultiLayerConfiguration conf = listBuilder.build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        //net.setListeners(new HistogramIterationListener(1));

        INDArray output;
        //Train the network on the full data set
        for (int i = 0; i < nEpochs; i++) {
            // train the model
            net.fit(trainingData);
            output = net.rnnTimeStep(trainingData.getFeatureMatrix());
            //System.out.println(output);
            net.rnnClearPreviousState();
        }

        System.out.println("Result on training data: ");
        System.out.println(trainingData.getFeatureMatrix());
        System.out.println(net.rnnTimeStep(trainingData.getFeatureMatrix()));
        System.out.println();
        System.out.println("Result on test data: ");
        System.out.println(testData.getFeatureMatrix());
        System.out.println(net.rnnTimeStep(testData.getFeatureMatrix()));

//        INDArray test = Nd4j.zeros(1, 1, 1);
//        test.putScalar(0, 1.00);
//        for (int i = 0; i < nSamples; i++) {
//            output = net.rnnTimeStep(test);
//            test.putScalar(0, output.getDouble(0));
//            System.out.print(" " + output);
//        }
    }
    /*
    Generate the training data. The target sequence to train on is out = 1, 2, 3, ..., nSamples.
    This corresponds to the input sequence seq = 0, 1, 2, ..., nSamples-1, so in this
    training set the input attribute is seq and the regression target is out.
    The RNN should then be able to predict nSamples+1, nSamples+2, ... given the input
    nSamples, nSamples+1, ...  That is: the last output is the next input.
    */
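    // With nSamples = 25 as defined above, the concrete sequences used here are:
    //   input:  0, 1, 2, ..., 24
    //   target: 1, 2, 3, ..., 25
    // and getTestData() continues them with input 25, 26, ..., 49. Both are packed into
    // 3D arrays of shape [1, 1, nSamples], i.e. [miniBatchSize, featureCount,
    // timeSeriesLength], which is the layout DL4J expects for recurrent layers.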
    private static DataSet getTrainingData() {
        double[] seq = new double[nSamples];
        double[] out = new double[nSamples];

        // seq is 0, 1, 2, 3, ..., nSamples-1
        for (int i = 0; i < nSamples; i++) {
            if (i == 0)
                seq[i] = 0;
            else
                seq[i] = seq[i - 1] + 1;
        }

        // out is the next seq input
        for (int i = 0; i < nSamples; i++) {
            if (i != (nSamples - 1))
                out[i] = seq[i + 1];
            else
                out[i] = seq[i] + 1;
        }

        // Scaling to [0, 1] based on the training output
        /*
        int min = 1;
        int max = nSamples;
        for (int i = 0; i < nSamples; i++) {
            seq[i] = (seq[i] - min) / (max - min);
            out[i] = (out[i] - min) / (max - min);
        }
        */

        INDArray seqNDArray = Nd4j.create(seq, new int[]{nSamples, 1});
        INDArray inputNDArray = Nd4j.zeros(1, 1, nSamples);
        inputNDArray.putRow(0, seqNDArray.transpose());
        INDArray outNDArray = Nd4j.create(out, new int[]{nSamples, 1});
        INDArray outputNDArray = Nd4j.zeros(1, 1, nSamples);
        outputNDArray.putRow(0, outNDArray.transpose());
        DataSet dataSet = new DataSet(inputNDArray, outputNDArray);
        return dataSet;
    }
    private static DataSet getTestData() {
        int testLength = nSamples;
        double[] seq = new double[testLength];
        double[] out = new double[testLength];

        // seq is 25, 26, 27, ...: the test input continues where the training input stopped
        for (int i = 0; i < testLength; i++) {
            if (i == 0)
                seq[i] = 25;
            else
                seq[i] = seq[i - 1] + 1;
        }

        // out is the next seq input
        for (int i = 0; i < testLength; i++) {
            if (i != (testLength - 1))
                out[i] = seq[i + 1];
            else
                out[i] = seq[i] + 1;
        }

        // Scaling to [0, 1] using the same normalization as the training data
        /*
        int min = 1;
        int max = nSamples;
        for (int i = 0; i < nSamples; i++) {
            seq[i] = (seq[i] - min) / (max - min);
            out[i] = (out[i] - min) / (max - min);
        }
        */

        INDArray seqNDArray = Nd4j.create(seq, new int[]{testLength, 1});
        INDArray inputNDArray = Nd4j.zeros(1, 1, testLength);
        inputNDArray.putColumn(0, seqNDArray);
        INDArray outNDArray = Nd4j.create(out, new int[]{testLength, 1});
        INDArray outputNDArray = Nd4j.zeros(1, 1, testLength);
        outputNDArray.putColumn(0, outNDArray);
        DataSet dataSet = new DataSet(inputNDArray, outputNDArray);
        return dataSet;
    }
}
Output:
Result on training data:
[[[ 0, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000, 10.0000, 11.0000, 12.0000, 13.0000, 14.0000, 15.0000, 16.0000, 17.0000, 18.0000, 19.0000, 20.0000, 21.0000, 22.0000, 23.0000, 24.0000]]]
[[[ 0.9449, 1.9531, 3.0839, 4.0111, 4.9409, 5.9952, 7.0329, 7.9967, 8.9157, 10.0582, 11.0989, 11.4641, 11.5335, 11.5449, 11.5469, 11.5473, 11.5474, 11.5474, 11.5475, 11.5475, 11.5476, 11.5476, 11.5476, 11.5476, 11.5476]]]
Result on test data:
[[[ 25.0000, 26.0000, 27.0000, 28.0000, 29.0000, 30.0000, 31.0000, 32.0000, 33.0000, 34.0000, 35.0000, 36.0000, 37.0000, 38.0000, 39.0000, 40.0000, 41.0000, 42.0000, 43.0000, 44.0000, 45.0000, 46.0000, 47.0000, 48.0000, 49.0000]]]
[[[ 11.5476, 11.5476, 11.5476, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477]]]
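The output above shows the network tracking the training targets only up to about 11 and then flat-lining near 11.55 for every remaining time step, on the test sequence as well. A likely culprit is that the raw inputs (0 to 49) are fed into tanh LSTM layers without any scaling, so the hidden activations saturate and the prediction stops changing; this is what the commented-out scaling blocks in getTrainingData()/getTestData() are meant to address. As an alternative, the sketch below scales features and labels to [0, 1] with ND4J's NormalizerMinMaxScaler inside main(). This is not part of the original gist, and it assumes the normalizer's fitLabel/revertLabels behaviour for time-series data in recent ND4J releases.

// Sketch only: an alternative to the commented-out manual scaling, meant to slot into main().
// Requires one extra import:
//   import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
DataSet trainingData = getTrainingData();
DataSet testData = getTestData();

NormalizerMinMaxScaler scaler = new NormalizerMinMaxScaler(0, 1);
scaler.fitLabel(true);            // also scale the regression targets, not just the inputs
scaler.fit(trainingData);         // min/max statistics come from the training set only
scaler.transform(trainingData);   // scale features and labels in place
scaler.transform(testData);       // reuse the training statistics for the test set

// ... build the network and run the fit/rnnTimeStep loop for nEpochs exactly as in the gist ...

INDArray scaledPrediction = net.rnnTimeStep(testData.getFeatureMatrix());
scaler.revertLabels(scaledPrediction);   // map the prediction back to the original scale
System.out.println(scaledPrediction);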