package org.deeplearning4j.examples.recurrent.regression;

import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartPanel;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.XYPlot;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import org.jfree.ui.RefineryUtilities;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.swing.*;
import java.io.*;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * This example was inspired by Jason Brownlee's regression examples for Keras, found here:
 * http://machinelearningmastery.com/time-series-prediction-lstm-recurrent-neural-networks-python-keras/
 *
 * It demonstrates single-time-step regression using an LSTM.
 */
public class TimestempTest {

    private static final Logger LOGGER = LoggerFactory.getLogger(TimestempTest.class);
    private static File baseDir = new File("dl4j-examples/src/main/resources");
    private static int timeStepsPerDay = 1; //10 minute steps
    private static int miniBatchSize = 32;
    private static int trainsize = 2000;
    /**
     * Counts the newline characters in the given file; returns 1 for a non-empty
     * file that contains no newline at all.
     */
    public static int countLines(String filename) throws IOException {
        InputStream is = new BufferedInputStream(new FileInputStream(filename));
        try {
            byte[] c = new byte[1024];
            int count = 0;
            int readChars = 0;
            boolean empty = true;
            while ((readChars = is.read(c)) != -1) {
                empty = false;
                for (int i = 0; i < readChars; ++i) {
                    if (c[i] == '\n') {
                        ++count;
                    }
                }
            }
            return (count == 0 && !empty) ? 1 : count;
        } finally {
            is.close();
        }
    }
    /**
     * Wraps the raw values at indices [from, to - 1) in a single sequence, using the
     * nested list structure that CollectionSequenceRecordReader expects.
     * When from == 0, the very first line is skipped as well.
     */
    private static List<List<List<Writable>>> prepareTempData(List<String> rawStrings, int from, int to) {
        List<List<List<Writable>>> topSequences = new ArrayList<>();
        List<List<Writable>> listOfSequences = new ArrayList<>();
        List<Writable> sequence = new ArrayList<>();
        boolean first = true;
        for (int i = from; i < (to - 1); i++) {
            if (first && from == 0) {
                first = false;
                continue;
            }
            sequence.add(new DoubleWritable(Double.parseDouble(rawStrings.get(i))));
        }
        listOfSequences.add(sequence);
        topSequences.add(listOfSequences);
        return topSequences;
    }
    public static void main(String[] args) throws Exception {
        String filePath = baseDir.getAbsolutePath() + "/DataExamples/temps_raw.csv";
        Path rawPath = Paths.get(filePath);
        List<String> rawStrings = null;
        try {
            rawStrings = Files.readAllLines(rawPath, Charset.defaultCharset());
        } catch (IOException e) {
            e.printStackTrace();
        }
        int fullSize = countLines(filePath);
        // Training data
        List<List<List<Writable>>> sequenceTrainData1 = prepareTempData(rawStrings, 0, trainsize);
        List<List<List<Writable>>> sequenceTrainData2 = prepareTempData(rawStrings, 0, trainsize);
        SequenceRecordReader trainReader = new CollectionSequenceRecordReader(sequenceTrainData1);
        SequenceRecordReader trainReaderLabels = new CollectionSequenceRecordReader(sequenceTrainData2);
        DataSetIterator trainIter = new SequenceRecordReaderDataSetIterator(trainReader, trainReaderLabels, miniBatchSize, -1, true);
        // Test data
        List<List<List<Writable>>> sequenceTestData1 = prepareTempData(rawStrings, trainsize, fullSize);
        List<List<List<Writable>>> sequenceTestData2 = prepareTempData(rawStrings, trainsize, fullSize);
        SequenceRecordReader testReader = new CollectionSequenceRecordReader(sequenceTestData1);
        SequenceRecordReader testReaderLabels = new CollectionSequenceRecordReader(sequenceTestData2);
        DataSetIterator testIter = new SequenceRecordReaderDataSetIterator(testReader, testReaderLabels, miniBatchSize, -1, true);
        //Create the data sets from the iterators here since we only have a single data set each
        DataSet trainData = trainIter.next();
        DataSet testData = testIter.next();
        System.out.println("===DATA===");
        System.out.println(trainData);
        System.out.println(testData);
        //Normalize data, including labels (fitLabel=true)
        NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1);
        normalizer.fitLabel(true);
        normalizer.fit(trainData); //Collect training data statistics
        normalizer.transform(trainData);
        normalizer.transform(testData);
        System.out.println("===Normalized===");
        System.out.println(trainData);
        System.out.println(testData);
        // Configure the network
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(140)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .weightInit(WeightInit.XAVIER)
            .updater(Updater.NESTEROVS).momentum(0.9)
            .learningRate(0.0001)
            .list()
            .layer(0, new GravesLSTM.Builder().activation("tanh").nIn(1).nOut(10)
                .build())
            .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                .activation("identity").nIn(10).nOut(1).build())
            .pretrain(false)
            .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        // Set up the UI server so we can monitor training progress
        StatsStorage statsStorage = new InMemoryStatsStorage();
        // setListeners replaces any previously registered listeners, so register both in one call
        net.setListeners(new ScoreIterationListener(20), new StatsListener(statsStorage));
        UIServer uiServer = UIServer.getInstance();
        uiServer.attach(statsStorage);
        // ----- Train the network, evaluating the test set performance at each epoch -----
        int nEpochs = 50;
        for (int i = 0; i < nEpochs; i++) {
            net.fit(trainData);
            LOGGER.info("Epoch " + i + " complete. Time series evaluation:");
            //Run regression evaluation on our single column input
            RegressionEvaluation evaluation = new RegressionEvaluation(1);
            INDArray features = testData.getFeatureMatrix();
            INDArray labels = testData.getLabels();
            INDArray predicted = net.output(features, false);
            evaluation.evalTimeSeries(labels, predicted);
            //Just use sout here since the logger would shift the columns of the stats
            System.out.println(evaluation.stats());
        }
        //Initialise rnnTimeStep with the train data, then predict the test data
        net.rnnTimeStep(trainData.getFeatureMatrix());
        INDArray predicted = net.rnnTimeStep(testData.getFeatureMatrix());
        //Revert data back to original values for plotting
        normalizer.revert(trainData);
        normalizer.revert(testData);
        normalizer.revertLabels(predicted);
        //Create plot with our data
        XYSeriesCollection c = new XYSeriesCollection();
        createSeries(c, trainData.getFeatures(), 0, "Train data");
        createSeries(c, testData.getFeatures(), trainsize, "Actual test data");
        createSeries(c, predicted, fullSize - 1, "Predicted test data");
        plotDataset(c);
        LOGGER.info("----- Example Complete -----");
    }
    /**
     * Creates an INDArray from a list of strings.
     * Used for plotting purposes.
     */
    private static INDArray createIndArrayFromStringList(List<String> rawStrings, int startIndex, int length) {
        List<String> stringList = rawStrings.subList(startIndex, startIndex + length);
        double[] primitives = new double[stringList.size()];
        for (int i = 0; i < stringList.size(); i++) {
            primitives[i] = Double.valueOf(stringList.get(i));
        }
        return Nd4j.create(new int[]{1, length}, primitives);
    }
    private static XYSeriesCollection createSeries(XYSeriesCollection seriesCollection, INDArray data, int offset, String name) {
        // Sequence data is shaped [miniBatchSize, nFeatures, timeSteps]; plot along the time dimension
        int nSteps = data.shape()[2];
        System.out.println(nSteps);
        XYSeries series = new XYSeries(name);
        for (int i = 0; i < nSteps; i++) {
            series.add(i + offset, data.getDouble(i));
        }
        seriesCollection.addSeries(series);
        return seriesCollection;
    }
    /**
     * Generate an xy plot of the datasets provided.
     */
    private static void plotDataset(XYSeriesCollection c) {
        String title = "Regression example";
        String xAxisLabel = "Timestep";
        String yAxisLabel = "Temperature";
        PlotOrientation orientation = PlotOrientation.VERTICAL;
        boolean legend = true;
        boolean tooltips = false;
        boolean urls = false;
        JFreeChart chart = ChartFactory.createXYLineChart(title, xAxisLabel, yAxisLabel, c, orientation, legend, tooltips, urls);
        // get a reference to the plot for further customisation...
        final XYPlot plot = chart.getXYPlot();
        // Auto zoom to fit time series in initial window
        final NumberAxis rangeAxis = (NumberAxis) plot.getRangeAxis();
        rangeAxis.setAutoRange(true);
        JPanel panel = new ChartPanel(chart);
        JFrame f = new JFrame();
        f.add(panel);
        f.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
        f.pack();
        f.setTitle("Training Data");
        RefineryUtilities.centerFrameOnScreen(f);
        f.setVisible(true);
    }
}
I'm getting an error with your snippet, though I don't know which versions you are using (tested on 1.0.0-M1.1):
Shapes do not match: dimensions[0] - x[1] must match y[1], x shape [1, 1640, 1], y shape [1, 1998], dimensions [1]
I got an LSTM working on this stack (v. 1.0.0-M1.1); it can be found here:
https://github.com/deeplearning4j/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/convolution/Conv1DUCISequenceClassification.java
If you're still interested, maybe we can collaborate.
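For what it's worth, that "Shapes do not match" error usually means the feature and label sequences fed to SequenceRecordReaderDataSetIterator ended up with different time-series lengths. Below is a minimal sketch of a shape sanity check, not part of the original gist: the helper name checkSequenceShapes is made up, and it assumes the same trainData/testData DataSet objects as in the code above.

```java
import java.util.Arrays;

import org.nd4j.linalg.dataset.DataSet;

public class ShapeCheck {

    /**
     * Prints feature/label shapes and fails fast if the time dimension differs.
     * DL4J sequence DataSets are typically shaped [miniBatchSize, nChannels, timeSteps],
     * so index 2 is the time-series length.
     */
    public static void checkSequenceShapes(String name, DataSet ds) {
        System.out.println(name
            + " features: " + Arrays.toString(ds.getFeatures().shape())
            + ", labels: " + Arrays.toString(ds.getLabels().shape()));
        long featureSteps = ds.getFeatures().size(2);
        long labelSteps = ds.getLabels().size(2);
        if (featureSteps != labelSteps) {
            throw new IllegalStateException(name + ": feature time steps (" + featureSteps
                + ") != label time steps (" + labelSteps + ")");
        }
    }
}
```

Calling checkSequenceShapes("train", trainData) and checkSequenceShapes("test", testData) right after the iterators are consumed would show whether the 1640 vs. 1998 lengths already disagree before the data reaches the network. If they do, the [from, to - 1) slicing and the header skip in prepareTempData are the first places I would look; if they already match, the mismatch is more likely an API difference between the old DL4J version the gist targets and 1.0.0-M1.1.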