Created
August 30, 2018 08:59
-
-
Save liweigu/f0dafd8045d766b39c23e1b52938b2ac to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.liweigu.dls.competition.srad; | |
import java.io.IOException; | |
import java.util.HashMap; | |
import java.util.Map; | |
import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
import org.deeplearning4j.nn.conf.BackpropType; | |
import org.deeplearning4j.nn.conf.GradientNormalization; | |
import org.deeplearning4j.nn.conf.InputPreProcessor; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder; | |
import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex; | |
import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex; | |
import org.deeplearning4j.nn.conf.inputs.InputType; | |
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; | |
import org.deeplearning4j.nn.conf.layers.LSTM; | |
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; | |
import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor; | |
import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor; | |
import org.deeplearning4j.nn.graph.ComputationGraph; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.dataset.DataSet; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.nd4j.linalg.learning.config.AdaGrad; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
import org.nd4j.linalg.schedule.ISchedule; | |
import org.nd4j.linalg.schedule.MapSchedule; | |
import org.nd4j.linalg.schedule.ScheduleType; | |
/** | |
* Using 30 images to predict new 30 images. | |
* | |
* @author liweigu | |
* | |
*/ | |
public class SimpleConvLSTM { | |
private static int imageSize = 501; | |
private static int minibatch = 1; | |
private static int inputSize = 60; // 30 x 2 = 60, to use 30 images to predict new 30 images. | |
public static void run() throws IOException { | |
ComputationGraph multiLayerNetwork = getNetwork(); | |
DataSet dataSet = getData(); | |
multiLayerNetwork.fit(dataSet); | |
} | |
private static ComputationGraph getNetwork() { | |
ComputationGraph multiLayerNetwork; | |
NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(140) | |
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).weightInit(WeightInit.XAVIER); | |
int kernelSize1 = 3; | |
int padding = 1; | |
int cnnStride1 = 5; | |
int lstmInWidth = (imageSize - kernelSize1 + padding) / cnnStride1 + 1; | |
int lstmHiddenCount = 200; | |
int channels = 1; | |
Map<String, InputPreProcessor> inputPreProcessors = new HashMap<String, InputPreProcessor>(); | |
inputPreProcessors.put("cnn1", new RnnToCnnPreProcessor(imageSize, imageSize, channels)); | |
inputPreProcessors.put("lstm1", new CnnToRnnPreProcessor(lstmInWidth, lstmInWidth, 128)); | |
Map<Integer, Double> lrSchedule = new HashMap<>(); | |
lrSchedule.put(0, 1e-2); | |
ISchedule mapSchedule = new MapSchedule(ScheduleType.ITERATION, lrSchedule); | |
GraphBuilder graphBuilder = builder.graphBuilder().pretrain(false).backprop(true) | |
.backpropType(BackpropType.Standard) | |
.addInputs("inputs") | |
.addLayer("cnn1", | |
new ConvolutionLayer.Builder(new int[] { kernelSize1, kernelSize1 }, | |
new int[] { cnnStride1, cnnStride1 }, | |
new int[] { padding, padding }) | |
.nIn(channels) | |
.nOut(64) | |
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) | |
.gradientNormalizationThreshold(10) | |
.updater(new AdaGrad(mapSchedule)) | |
.weightInit(WeightInit.RELU) | |
.activation(Activation.RELU).build(), "inputs") | |
.addLayer("lstm1", new LSTM.Builder() | |
.activation(Activation.SOFTSIGN) | |
.nIn(lstmInWidth * lstmInWidth * 128) | |
.nOut(lstmHiddenCount) | |
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) | |
.gradientNormalizationThreshold(10) | |
.updater(new AdaGrad(mapSchedule)) | |
.build(), "cnn1") | |
.addVertex("thoughtVector", new LastTimeStepVertex("inputs"), "lstm1") | |
.addVertex("dup", new DuplicateToTimeSeriesVertex("inputs"), "thoughtVector") | |
.addLayer("lstmDecode1", new LSTM.Builder() | |
.activation(Activation.SOFTSIGN) | |
.nIn(lstmHiddenCount) | |
.nOut(lstmHiddenCount) | |
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) | |
.gradientNormalizationThreshold(10) | |
.updater(new AdaGrad(mapSchedule)) | |
.build(), "dup") | |
.addLayer("output", new RnnOutputLayer | |
.Builder(LossFunctions.LossFunction.MSE) | |
.activation(Activation.RELU) | |
.nIn(lstmHiddenCount) | |
.nOut(imageSize * imageSize) | |
.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) | |
.gradientNormalizationThreshold(10) | |
.updater(new AdaGrad(mapSchedule)) | |
.build(), "lstmDecode1") | |
.setOutputs("output"); | |
graphBuilder.setInputPreProcessors(inputPreProcessors); | |
graphBuilder.setInputTypes(InputType.recurrent(imageSize * imageSize, inputSize)); | |
multiLayerNetwork = new ComputationGraph(graphBuilder.build()); | |
multiLayerNetwork.init(); | |
return multiLayerNetwork; | |
} | |
private static DataSet getData() { | |
double[] input = new double[minibatch * (imageSize * imageSize) * inputSize]; | |
for (int i = 1; i <= 30; i++) { | |
String filePath = "xxx_" + i; | |
readToArray(filePath, input, (i - 1)); | |
} | |
INDArray featureData = Nd4j.create(input, new int[] { minibatch, imageSize * imageSize, inputSize }); | |
double[] inputMask = new double[minibatch * inputSize]; | |
// assuming minibatch = 1, the following code is simplied | |
for (int i = 0; i < 30; i++) { | |
inputMask[i] = 1; | |
} | |
INDArray featuresMask = Nd4j.create(inputMask, new int[] { inputSize, minibatch }); | |
double[] output = new double[minibatch * (imageSize * imageSize) * inputSize]; | |
for (int i = 31; i <= 60; i++) { | |
String filePath = "xxx_" + i; | |
readToArray(filePath, output, (i - 1)); | |
} | |
INDArray labelData = Nd4j.create(output, new int[] { minibatch, imageSize * imageSize, inputSize }); | |
// mask | |
double[] outputMask = new double[minibatch * inputSize]; | |
// assuming minibatch = 1, the following code is simplied | |
for (int i = 30; i < minibatch * 1; i++) { | |
outputMask[i] = 1; | |
} | |
INDArray labelsMask = Nd4j.create(outputMask, new int[] { inputSize, minibatch }); | |
return new DataSet(featureData, labelData, featuresMask, labelsMask); | |
} | |
// read data from file to array 'data' | |
private static void readToArray(String filePath, double[] data, int sizeIndex) { | |
// assuming minibatch = 1, the following code is simplied | |
double fakeValue = 0.12345; // TODO read value from filePath | |
for (int i = 0; i < imageSize; i++) { | |
for (int j = 0; j < imageSize; j++) { | |
int demensionIndex = i * imageSize + j; | |
data[sizeIndex + demensionIndex * inputSize] = fakeValue; | |
} | |
} | |
} | |
public static void main(String[] args) throws IOException { | |
SimpleConvLSTM.run(); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment