Skip to content

Instantly share code, notes, and snippets.

@bartvollebregt
Last active September 12, 2024 21:59
Show Gist options
  • Save bartvollebregt/0f6b9c0fcb81cf74583104613ee1d9c6 to your computer and use it in GitHub Desktop.
Save bartvollebregt/0f6b9c0fcb81cf74583104613ee1d9c6 to your computer and use it in GitHub Desktop.
package org.deeplearning4j.examples.recurrent.regression;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.XYPlot;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import org.jfree.ui.RefineryUtilities;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.swing.*;
import java.io.*;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* This example was inspired by Jason Brownlee's regression examples for Keras, found here:
* http://machinelearningmastery.com/time-series-prediction-lstm-recurrent-neural-networks-python-keras/
*
* It demonstrates single time step regression using LSTM
*/
public class TimestempTest {
// Class-wide SLF4J logger.
private static final Logger LOGGER = LoggerFactory.getLogger(TimestempTest.class);
// Resource root; main() appends "/DataExamples/temps_raw.csv" to its absolute path.
private static File baseDir = new File("dl4j-examples/src/main/resources");
// NOTE(review): not referenced by any code visible in this file, and the value 1
// does not obviously agree with the "10 minute steps" remark -- confirm intent.
private static int timeStepsPerDay = 1; //10 minute steps
// Batch size handed to both SequenceRecordReaderDataSetIterator instances in main().
private static int miniBatchSize = 32;
// Line index at which main() splits the raw data: the training slice is requested
// as [0, trainsize) and the test slice as [trainsize, fullSize).
private static int trainsize = 2000;
/**
 * Counts the lines of a file by scanning its raw bytes for {@code '\n'}.
 *
 * <p>A final line that is not terminated by a newline is counted as well, so for
 * files using {@code \n} or {@code \r\n} line endings the result agrees with the
 * number of entries obtained by reading the file line by line. An empty file
 * yields 0.
 *
 * @param filename path of the file to scan
 * @return the number of lines in the file
 * @throws IOException if the file cannot be opened or read
 */
public static int countLines(String filename) throws IOException {
    // try-with-resources guarantees the stream is closed, even if read() throws.
    try (InputStream is = new BufferedInputStream(new FileInputStream(filename))) {
        byte[] buffer = new byte[1024];
        int count = 0;
        // Start as if the previous byte were a newline so that an empty file
        // does not get an extra "unterminated line" added below.
        int lastByte = '\n';
        int readChars;
        while ((readChars = is.read(buffer)) != -1) {
            for (int i = 0; i < readChars; ++i) {
                if (buffer[i] == '\n') {
                    ++count;
                }
            }
            if (readChars > 0) {
                lastByte = buffer[readChars - 1];
            }
        }
        // Content after the last '\n' is a line too, although it has no terminator.
        return (lastByte == '\n') ? count : count + 1;
    }
}
/**
 * Parses a slice of the raw lines as doubles and wraps them in the triple-nested
 * list structure that main() passes to {@code CollectionSequenceRecordReader}.
 *
 * <p>Indices {@code from} (inclusive) up to {@code to - 1} (exclusive) are read.
 * When {@code from} is 0, index 0 is skipped (presumably a header row -- confirm
 * against the data file). All parsed values end up in a single innermost list.
 *
 * @param rawStrings the lines of the data file, one numeric value per line
 * @param from       first index of the slice
 * @param to         upper bound of the slice; the loop stops before {@code to - 1}
 * @return an outer list holding one list, which holds one list of all parsed values
 */
private static List<List<List<Writable>>> prepareTempData(List<String> rawStrings, int from, int to) {
    // A slice starting at the top of the file begins one entry later.
    final int firstIndex = (from == 0) ? 1 : from;
    final int endExclusive = to - 1;

    List<Writable> values = new ArrayList<>();
    for (int idx = firstIndex; idx < endExclusive; idx++) {
        String raw = rawStrings.get(idx);
        values.add(new DoubleWritable(Double.parseDouble(raw)));
    }

    List<List<Writable>> inner = new ArrayList<>();
    inner.add(values);

    List<List<List<Writable>>> outer = new ArrayList<>();
    outer.add(inner);
    return outer;
}
/**
 * Runs the example: loads the raw temperature values, builds train/test data sets,
 * min-max normalizes them and prints them. The code after that point configures and
 * trains an LSTM network, evaluates it and plots the result.
 *
 * <p>NOTE(review): the {@code System.exit(0)} after the "===Normalized===" output
 * terminates the JVM, so the network configuration, training and plotting code
 * below it never runs. It looks like a debugging leftover; it is kept here because
 * removing it changes what the program does -- confirm before deleting.
 *
 * @param args unused
 * @throws Exception if the data file cannot be read or any later step fails
 */
public static void main(String[] args) throws Exception {
    String filePath = baseDir.getAbsolutePath() + "/DataExamples/temps_raw.csv";
    Path rawPath = Paths.get(filePath);
    // Let a read failure propagate instead of printing the stack trace and
    // continuing with a null list.
    List<String> rawStrings = Files.readAllLines(rawPath, Charset.defaultCharset());
    int fullSize = countLines(filePath);

    // Train data: the same slice is used for features and for labels
    List<List<List<Writable>>> sequenceTrainData1 = prepareTempData(rawStrings, 0, trainsize);
    List<List<List<Writable>>> sequenceTrainData2 = prepareTempData(rawStrings, 0, trainsize);
    SequenceRecordReader trainReader = new CollectionSequenceRecordReader(sequenceTrainData1);
    SequenceRecordReader trainReaderLabels = new CollectionSequenceRecordReader(sequenceTrainData2);
    DataSetIterator trainIter = new SequenceRecordReaderDataSetIterator(trainReader, trainReaderLabels, miniBatchSize, -1, true);

    // Test data: everything from trainsize onwards
    List<List<List<Writable>>> sequenceTestData1 = prepareTempData(rawStrings, trainsize, fullSize);
    List<List<List<Writable>>> sequenceTestData2 = prepareTempData(rawStrings, trainsize, fullSize);
    SequenceRecordReader testReader = new CollectionSequenceRecordReader(sequenceTestData1);
    SequenceRecordReader testReaderLabels = new CollectionSequenceRecordReader(sequenceTestData2);
    DataSetIterator testIter = new SequenceRecordReaderDataSetIterator(testReader, testReaderLabels, miniBatchSize, -1, true);

    // Create data sets from the iterators here since we only have a single data set each
    DataSet trainData = trainIter.next();
    DataSet testData = testIter.next();
    System.out.println("===DATA===");
    System.out.println(trainData);
    System.out.println(testData);

    // Normalize data, including labels (fitLabel=true)
    NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1);
    normalizer.fitLabel(true);
    normalizer.fit(trainData); // Collect training data statistics
    normalizer.transform(trainData);
    normalizer.transform(testData);
    System.out.println("===Normalized===");
    System.out.println(trainData);
    System.out.println(testData);

    // NOTE(review): the JVM exits here, so nothing below this line is executed.
    System.exit(0);

    // Configure the network
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(140)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .iterations(1)
        .weightInit(WeightInit.XAVIER)
        .updater(Updater.NESTEROVS).momentum(0.9)
        .learningRate(0.0001)
        .list()
        .layer(0, new GravesLSTM.Builder().activation("tanh").nIn(1).nOut(10)
            .build())
        .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
            .activation("identity").nIn(10).nOut(1).build())
        .pretrain(false)
        .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    // Set up the UI server so we can monitor progress. Both listeners are
    // registered in ONE call: setListeners replaces the current listener set, so
    // two separate calls would leave only the listener from the last call.
    StatsStorage statsStorage = new InMemoryStatsStorage();
    net.setListeners(new ScoreIterationListener(20), new StatsListener(statsStorage));
    UIServer uiServer = UIServer.getInstance();
    uiServer.attach(statsStorage);

    // ----- Train the network, evaluating the test set performance at each epoch -----
    int nEpochs = 50;
    for (int i = 0; i < nEpochs; i++) {
        net.fit(trainData);
        LOGGER.info("Epoch " + i + " complete. Time series evaluation:");
        // Run regression evaluation on our single column input
        RegressionEvaluation evaluation = new RegressionEvaluation(1);
        INDArray features = testData.getFeatureMatrix();
        INDArray labels = testData.getLabels();
        INDArray epochPrediction = net.output(features, false);
        evaluation.evalTimeSeries(labels, epochPrediction);
        // Print directly here since the logger would shift the columns of the stats
        System.out.println(evaluation.stats());
    }

    // Init rnnTimeStep state with the train data, then predict the test data
    net.rnnTimeStep(trainData.getFeatureMatrix());
    INDArray predicted = net.rnnTimeStep(testData.getFeatureMatrix());

    // Revert data back to original values for plotting
    normalizer.revert(trainData);
    normalizer.revert(testData);
    normalizer.revertLabels(predicted);

    // Create plot with our data
    XYSeriesCollection c = new XYSeriesCollection();
    createSeries(c, trainData.getFeatures(), 0, "Train data");
    createSeries(c, testData.getFeatures(), trainsize, "Actual test data");
    createSeries(c, predicted, fullSize-1, "Predicted test data");
    plotDataset(c);
    LOGGER.info("----- Example Complete -----");
}
/**
 * Parses {@code length} consecutive entries of {@code rawStrings}, starting at
 * {@code startIndex}, as doubles and hands them to {@code Nd4j.create} together
 * with the shape {@code {1, length}}.
 * Intended for plotting purposes.
 *
 * @param rawStrings source lines, each holding one numeric value
 * @param startIndex index of the first entry to convert
 * @param length     number of entries to convert
 * @return the array produced by {@code Nd4j.create} for the parsed values
 */
private static INDArray createIndArrayFromStringList(List<String> rawStrings, int startIndex, int length) {
    List<String> window = rawStrings.subList(startIndex, startIndex + length);
    double[] values = new double[window.size()];
    int pos = 0;
    for (String text : window) {
        values[pos++] = Double.parseDouble(text);
    }
    int[] shape = new int[]{1, length};
    return Nd4j.create(shape, values);
}
/**
 * Builds an {@code XYSeries} called {@code name} from {@code data} and appends it
 * to {@code seriesCollection}.
 *
 * <p>The number of points is taken from index 2 of {@code data.shape()} and is also
 * printed to standard output. Point {@code i} is plotted at x = {@code i + offset}
 * with y = {@code data.getDouble(i)}.
 *
 * @param seriesCollection collection that receives the new series
 * @param data             values to plot; its shape must have at least three entries
 * @param offset           x position of the first point
 * @param name             key of the new series
 * @return the same {@code seriesCollection} instance that was passed in
 */
private static XYSeriesCollection createSeries(XYSeriesCollection seriesCollection, INDArray data, int offset, String name) {
    final int pointCount = data.shape()[2];
    System.out.println(pointCount);

    XYSeries line = new XYSeries(name);
    int x = offset;
    for (int idx = 0; idx < pointCount; idx++, x++) {
        line.add(x, data.getDouble(idx));
    }

    seriesCollection.addSeries(line);
    return seriesCollection;
}
/**
 * Shows the given series in an XY line chart inside a new, screen-centered window.
 * Closing the window exits the application (EXIT_ON_CLOSE).
 *
 * @param c the series to draw
 */
private static void plotDataset(XYSeriesCollection c) {
    // Vertical line chart with a legend, without tooltips or URLs.
    JFreeChart chart = ChartFactory.createXYLineChart(
            "Regression example",
            "Timestep",
            "Temperature",
            c,
            PlotOrientation.VERTICAL,
            true,
            false,
            false);

    // Let the value axis auto-range so the whole series fits the initial window.
    XYPlot xyPlot = chart.getXYPlot();
    NumberAxis valueAxis = (NumberAxis) xyPlot.getRangeAxis();
    valueAxis.setAutoRange(true);

    JPanel chartPanel = new ChartPanel(chart);
    JFrame frame = new JFrame();
    frame.add(chartPanel);
    frame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
    frame.pack();
    frame.setTitle("Training Data");
    RefineryUtilities.centerFrameOnScreen(frame);
    frame.setVisible(true);
}
}
@soulaway
Copy link

I get an error with your snippet, though I don't know which versions you are using (I tested on 1.0.0-M1.1):
Shapes do not match: dimensions[0] - x[1] must match y[1], x shape [1, 1640, 1], y shape [1, 1998], dimensions [1]
I got an LSTM working with this stack (v. 1.0.0-M1.1), based on the example found here:
https://github.com/deeplearning4j/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/convolution/Conv1DUCISequenceClassification.java
If you're still interested we may collaborate.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment