@jumpingfella
Created May 25, 2018 14:17
1..100 deeplearning4j RNN example
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.lossfunctions.LossFunctions;
/**
 * Created by pedro, brought up to date with dl4j 1.0 by jumpingfella.
 * An example of a recurrent neural network applied to a regression problem.
 */
public class RegressionRNN {

    //Random number generator seed, for reproducibility
    public static final int seed = 12345;
    //Number of epochs (full passes of the data)
    public static final int nEpochs = 3000;
    //Number of data points
    public static final int nSamples = 25;
    //Network learning rate
    public static final double learningRate = 0.0001;
    public static void main(String[] args) {
        //Generate the training data
        DataSet trainingData = getTrainingData();
        trainingData.shuffle();
        System.out.println(trainingData);
        System.out.println();

        DataSet testData = getTestData();
        System.out.println(testData);
        //Create the network
        int numInput = 1;
        int numOutputs = 1;
        int nHidden = 30;
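        //Topology: numInput -> two GravesLSTM layers of nHidden tanh units -> numOutputs linear output, trained with MSE loss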
        NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder();
        builder.seed(seed);
        builder.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT);
        builder.updater(new RmsProp(learningRate));
        //Clip each layer's gradient to the L2 norm threshold below, to guard against exploding gradients
        builder.gradientNormalization(GradientNormalization.ClipL2PerLayer);
        builder.gradientNormalizationThreshold(0.00001);
        NeuralNetConfiguration.ListBuilder listBuilder = builder.list();
        listBuilder.layer(0, new GravesLSTM.Builder().nIn(numInput).nOut(nHidden)
                .activation(Activation.TANH).l2(0.0001).weightInit(WeightInit.XAVIER)
                .build());
        listBuilder.layer(1, new GravesLSTM.Builder().nIn(nHidden).nOut(nHidden)
                .activation(Activation.TANH).l2(0.0001).weightInit(WeightInit.XAVIER)
                .build());
        listBuilder.layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                .activation(Activation.IDENTITY).l2(0.0001).weightInit(WeightInit.XAVIER)
                .nIn(nHidden).nOut(numOutputs).build());
        listBuilder.pretrain(false).backprop(true);
        MultiLayerConfiguration conf = listBuilder.build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        //net.setListeners(new HistogramIterationListener(1));

        INDArray output;
        //Train the network on the full data set
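        //Each epoch fits the whole sequence once. rnnTimeStep() keeps the LSTM's internal
        //state between calls, so rnnClearPreviousState() resets it before the next epoch.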
        for (int i = 0; i < nEpochs; i++) {
            // train the model
            net.fit(trainingData);
            output = net.rnnTimeStep(trainingData.getFeatureMatrix());
            //System.out.println(output);
            net.rnnClearPreviousState();
        }
System.out.println("Result on training data: ");
System.out.println(trainingData.getFeatureMatrix());
System.out.println(net.rnnTimeStep(trainingData.getFeatureMatrix()));
System.out.println();
System.out.println("Result on test data: ");
System.out.println(testData.getFeatureMatrix());
System.out.println(net.rnnTimeStep(testData.getFeatureMatrix()));
//
// INDArray test = Nd4j.zeros(1, 1, 1);
// test.putScalar(0, 1.00);
// for (int i = 0; i < nSamples; i++) {
// output = net.rnnTimeStep(test);
// test.putScalar(0, output.getDouble(0));
// System.out.print(" " + output);
// }
}
    /*
     Generate the training data. The target sequence is out = 1, 2, 3, ..., nSamples,
     which corresponds to the input sequence seq = 0, 1, 2, ..., nSamples-1. For this
     training data set the input attribute sequence is seq and the class/target attribute is out.
     The RNN should then be able to continue the sequence: given the test input
     nSamples, nSamples+1, ... it should predict nSamples+1, nSamples+2, ...
     That is: each target is the next value of the input.
     */
    private static DataSet getTrainingData() {
        double[] seq = new double[nSamples];
        double[] out = new double[nSamples];
        // seq is 0, 1, 2, 3, ..., nSamples-1
        for (int i = 0; i < nSamples; i++) {
            if (i == 0)
                seq[i] = 0;
            else
                seq[i] = seq[i - 1] + 1;
        }
        // out is the next seq input
        for (int i = 0; i < nSamples; i++) {
            if (i != (nSamples - 1))
                out[i] = seq[i + 1];
            else
                out[i] = seq[i] + 1;
        }
        // Scaling to [0, 1] based on the training output
        /*
        int min = 1;
        int max = nSamples;
        for (int i = 0; i < nSamples; i++) {
            seq[i] = (seq[i] - min) / (max - min);
            out[i] = (out[i] - min) / (max - min);
        }
        */
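        // DL4J recurrent layers expect data of shape [miniBatchSize, featureSize, timeSeriesLength];
        // here that is [1, 1, nSamples]: a single sequence with one feature per time step.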
        INDArray seqNDArray = Nd4j.create(seq, new int[]{nSamples, 1});
        INDArray inputNDArray = Nd4j.zeros(1, 1, nSamples);
        inputNDArray.putRow(0, seqNDArray.transpose());

        INDArray outNDArray = Nd4j.create(out, new int[]{nSamples, 1});
        INDArray outputNDArray = Nd4j.zeros(1, 1, nSamples);
        outputNDArray.putRow(0, outNDArray.transpose());

        DataSet dataSet = new DataSet(inputNDArray, outputNDArray);
        return dataSet;
    }
    private static DataSet getTestData() {
        int testLength = nSamples;
        double[] seq = new double[testLength];
        double[] out = new double[testLength];
        // seq starts where the training input left off: 25, 26, 27, ...
        for (int i = 0; i < testLength; i++) {
            if (i == 0)
                seq[i] = 25;
            else
                seq[i] = seq[i - 1] + 1;
        }
        // out is the next seq input
        for (int i = 0; i < testLength; i++) {
            if (i != (testLength - 1))
                out[i] = seq[i + 1];
            else
                out[i] = seq[i] + 1;
        }
        // Scaling to [0, 1] using the same normalization as the training data
        /*
        int min = 1;
        int max = nSamples;
        for (int i = 0; i < nSamples; i++) {
            seq[i] = (seq[i] - min) / (max - min);
            out[i] = (out[i] - min) / (max - min);
        }
        */
        INDArray seqNDArray = Nd4j.create(seq, new int[]{testLength, 1});
        INDArray inputNDArray = Nd4j.zeros(1, 1, testLength);
        inputNDArray.putColumn(0, seqNDArray);

        INDArray outNDArray = Nd4j.create(out, new int[]{testLength, 1});
        INDArray outputNDArray = Nd4j.zeros(1, 1, testLength);
        outputNDArray.putColumn(0, outNDArray);

        DataSet dataSet = new DataSet(inputNDArray, outputNDArray);
        return dataSet;
    }
}
@jumpingfella

Output:
Result on training data:
[[[ 0, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000, 10.0000, 11.0000, 12.0000, 13.0000, 14.0000, 15.0000, 16.0000, 17.0000, 18.0000, 19.0000, 20.0000, 21.0000, 22.0000, 23.0000, 24.0000]]]
[[[ 0.9449, 1.9531, 3.0839, 4.0111, 4.9409, 5.9952, 7.0329, 7.9967, 8.9157, 10.0582, 11.0989, 11.4641, 11.5335, 11.5449, 11.5469, 11.5473, 11.5474, 11.5474, 11.5475, 11.5475, 11.5476, 11.5476, 11.5476, 11.5476, 11.5476]]]

Result on test data:
[[[ 25.0000, 26.0000, 27.0000, 28.0000, 29.0000, 30.0000, 31.0000, 32.0000, 33.0000, 34.0000, 35.0000, 36.0000, 37.0000, 38.0000, 39.0000, 40.0000, 41.0000, 42.0000, 43.0000, 44.0000, 45.0000, 46.0000, 47.0000, 48.0000, 49.0000]]]
[[[ 11.5476, 11.5476, 11.5476, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477, 11.5477]]]
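
Note that the predictions plateau around 11.5 on the training data and carry over unchanged to the test range. The raw inputs (0 ... 49) are never scaled, so the tanh LSTM units most likely saturate partway through the sequence. Re-enabling the commented-out [0, 1] scaling should help; an alternative is ND4J's NormalizerMinMaxScaler. A minimal, untested sketch, assuming the scaler's usual fitLabel/fit/transform API:

NormalizerMinMaxScaler scaler = new NormalizerMinMaxScaler(0, 1);
scaler.fitLabel(true);          // also scale the targets, not just the features
scaler.fit(trainingData);       // learn min/max from the training set
scaler.transform(trainingData);
scaler.transform(testData);     // apply the same statistics to the test data

Predictions then come out in the scaled range and have to be mapped back (e.g. with revertLabels) to recover the original values.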
