Skip to content

Instantly share code, notes, and snippets.

@liweigu
Created August 30, 2018 08:59
Show Gist options
  • Save liweigu/f0dafd8045d766b39c23e1b52938b2ac to your computer and use it in GitHub Desktop.
package com.liweigu.dls.competition.srad;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder;
import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex;
import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;
/**
 * Uses the first 30 images of a sequence to predict the next 30 images
 * with a CNN-encoder / LSTM seq2seq computation graph.
 *
 * @author liweigu
 *
 */
public class SimpleConvLSTM {
    /** Width/height (pixels) of each square input image. */
    private static final int imageSize = 501;
    /** Examples per minibatch; the data-building code below assumes 1. */
    private static final int minibatch = 1;
    /** Total sequence length: 30 input timesteps + 30 predicted timesteps. */
    private static final int inputSize = 60; // 30 x 2 = 60, to use 30 images to predict the next 30 images.

    /**
     * Builds the network, loads one training example and runs a single fit step.
     *
     * @throws IOException if reading the image data fails
     */
    public static void run() throws IOException {
        ComputationGraph network = getNetwork();
        DataSet dataSet = getData();
        network.fit(dataSet);
    }

    /**
     * Builds and initializes the CNN-encoder / LSTM seq2seq computation graph:
     * inputs -> cnn1 -> lstm1 -> thoughtVector -> dup -> lstmDecode1 -> output.
     *
     * @return an initialized {@link ComputationGraph}
     */
    private static ComputationGraph getNetwork() {
        NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(140)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).weightInit(WeightInit.XAVIER);
        int kernelSize1 = 3;
        int padding = 1;
        int cnnStride1 = 5;
        // Standard conv output size: (in - kernel + pad) / stride + 1.
        // NOTE(review): the usual formula uses 2 * padding; verify this matches the
        // ConvolutionLayer's actual output before relying on lstmInWidth.
        int lstmInWidth = (imageSize - kernelSize1 + padding) / cnnStride1 + 1;
        int lstmHiddenCount = 200;
        int channels = 1;

        // Preprocessors reshape between the time-series (RNN) and image (CNN) layouts.
        Map<String, InputPreProcessor> inputPreProcessors = new HashMap<String, InputPreProcessor>();
        inputPreProcessors.put("cnn1", new RnnToCnnPreProcessor(imageSize, imageSize, channels));
        inputPreProcessors.put("lstm1", new CnnToRnnPreProcessor(lstmInWidth, lstmInWidth, 128));

        // Learning-rate schedule keyed by iteration; currently a single constant rate.
        Map<Integer, Double> lrSchedule = new HashMap<>();
        lrSchedule.put(0, 1e-2);
        ISchedule mapSchedule = new MapSchedule(ScheduleType.ITERATION, lrSchedule);

        GraphBuilder graphBuilder = builder.graphBuilder().pretrain(false).backprop(true)
                .backpropType(BackpropType.Standard)
                .addInputs("inputs")
                .addLayer("cnn1",
                        new ConvolutionLayer.Builder(new int[] { kernelSize1, kernelSize1 },
                                new int[] { cnnStride1, cnnStride1 },
                                new int[] { padding, padding })
                                .nIn(channels)
                                .nOut(64)
                                .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                                .gradientNormalizationThreshold(10)
                                .updater(new AdaGrad(mapSchedule))
                                .weightInit(WeightInit.RELU)
                                .activation(Activation.RELU).build(), "inputs")
                // NOTE(review): cnn1 outputs 64 channels but the lstm1 preprocessor and
                // nIn below assume 128 — confirm these are meant to agree.
                .addLayer("lstm1", new LSTM.Builder()
                        .activation(Activation.SOFTSIGN)
                        .nIn(lstmInWidth * lstmInWidth * 128)
                        .nOut(lstmHiddenCount)
                        .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                        .gradientNormalizationThreshold(10)
                        .updater(new AdaGrad(mapSchedule))
                        .build(), "cnn1")
                // Encoder output at the last (masked) timestep becomes the thought vector,
                // which is duplicated across all decoder timesteps.
                .addVertex("thoughtVector", new LastTimeStepVertex("inputs"), "lstm1")
                .addVertex("dup", new DuplicateToTimeSeriesVertex("inputs"), "thoughtVector")
                .addLayer("lstmDecode1", new LSTM.Builder()
                        .activation(Activation.SOFTSIGN)
                        .nIn(lstmHiddenCount)
                        .nOut(lstmHiddenCount)
                        .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                        .gradientNormalizationThreshold(10)
                        .updater(new AdaGrad(mapSchedule))
                        .build(), "dup")
                .addLayer("output", new RnnOutputLayer
                        .Builder(LossFunctions.LossFunction.MSE)
                        .activation(Activation.RELU)
                        .nIn(lstmHiddenCount)
                        .nOut(imageSize * imageSize)
                        .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                        .gradientNormalizationThreshold(10)
                        .updater(new AdaGrad(mapSchedule))
                        .build(), "lstmDecode1")
                .setOutputs("output");
        graphBuilder.setInputPreProcessors(inputPreProcessors);
        graphBuilder.setInputTypes(InputType.recurrent(imageSize * imageSize, inputSize));

        ComputationGraph network = new ComputationGraph(graphBuilder.build());
        network.init();
        return network;
    }

    /**
     * Builds one training example: timesteps 0-29 hold the input images,
     * timesteps 30-59 hold the target images. Masks select the first half as
     * features and the second half as labels.
     *
     * @return a {@link DataSet} with features, labels and both masks
     */
    private static DataSet getData() {
        double[] input = new double[minibatch * (imageSize * imageSize) * inputSize];
        for (int i = 1; i <= 30; i++) {
            String filePath = "xxx_" + i;
            readToArray(filePath, input, (i - 1));
        }
        INDArray featureData = Nd4j.create(input, new int[] { minibatch, imageSize * imageSize, inputSize });

        double[] inputMask = new double[minibatch * inputSize];
        // assuming minibatch = 1, the following code is simplified
        for (int i = 0; i < 30; i++) {
            inputMask[i] = 1;
        }
        // FIX: RNN masks are [miniBatchSize, timeSeriesLength]; the original used
        // { inputSize, minibatch }, which transposes the expected layout.
        INDArray featuresMask = Nd4j.create(inputMask, new int[] { minibatch, inputSize });

        double[] output = new double[minibatch * (imageSize * imageSize) * inputSize];
        for (int i = 31; i <= 60; i++) {
            String filePath = "xxx_" + i;
            readToArray(filePath, output, (i - 1));
        }
        INDArray labelData = Nd4j.create(output, new int[] { minibatch, imageSize * imageSize, inputSize });

        // mask
        double[] outputMask = new double[minibatch * inputSize];
        // assuming minibatch = 1, the following code is simplified
        // FIX: the original bound was `minibatch * 1` (= 1), so the loop never ran
        // and the labels mask stayed all-zero; timesteps 30-59 must be unmasked.
        for (int i = 30; i < inputSize; i++) {
            outputMask[i] = 1;
        }
        INDArray labelsMask = Nd4j.create(outputMask, new int[] { minibatch, inputSize });

        return new DataSet(featureData, labelData, featuresMask, labelsMask);
    }

    /**
     * Reads one image's pixels into the flat buffer {@code data} at timestep
     * {@code sizeIndex}. The buffer layout matches the C-order of a
     * [minibatch, imageSize*imageSize, inputSize] array: index = feature * inputSize + timestep.
     *
     * @param filePath  source file for the image (currently unused; placeholder value written)
     * @param data      flat destination buffer, length minibatch * imageSize^2 * inputSize
     * @param sizeIndex timestep index (0-based) this image occupies
     */
    private static void readToArray(String filePath, double[] data, int sizeIndex) {
        // assuming minibatch = 1, the following code is simplified
        double fakeValue = 0.12345; // TODO read value from filePath
        for (int i = 0; i < imageSize; i++) {
            for (int j = 0; j < imageSize; j++) {
                int dimensionIndex = i * imageSize + j;
                data[sizeIndex + dimensionIndex * inputSize] = fakeValue;
            }
        }
    }

    public static void main(String[] args) throws IOException {
        SimpleConvLSTM.run();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment