Skip to content

Instantly share code, notes, and snippets.

@Seppo420
Last active August 29, 2015 14:27
Show Gist options
  • Save Seppo420/6fe1544cbc1764b224a3 to your computer and use it in GitHub Desktop.
Save Seppo420/6fe1544cbc1764b224a3 to your computer and use it in GitHub Desktop.
===========INPUT===================
[[ 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00]
[ 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00]
[ 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00]
[ 3.00, 3.00, 3.00, 3.00, 3.00, 3.00, 3.00, 3.00, 3.00, 3.00, 3.00]
[ 4.00, 4.00, 4.00, 4.00, 4.00, 4.00, 4.00, 4.00, 4.00, 4.00, 4.00]
[ 5.00, 5.00, 5.00, 5.00, 5.00, 5.00, 5.00, 5.00, 5.00, 5.00, 5.00]
[ 6.00, 6.00, 6.00, 6.00, 6.00, 6.00, 6.00, 6.00, 6.00, 6.00, 6.00]
[ 7.00, 7.00, 7.00, 7.00, 7.00, 7.00, 7.00, 7.00, 7.00, 7.00, 7.00]
[ 8.00, 8.00, 8.00, 8.00, 8.00, 8.00, 8.00, 8.00, 8.00, 8.00, 8.00]
[ 9.00, 9.00, 9.00, 9.00, 9.00, 9.00, 9.00, 9.00, 9.00, 9.00, 9.00]
[ 10.00, 10.00, 10.00, 10.00, 10.00, 10.00, 10.00, 10.00, 10.00, 10.00, 10.00]]
=================OUTPUT==================
[[ 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00]
[ 0.84, 0.84, 0.84, 0.84, 0.84, 0.84, 0.84, 0.84, 0.84, 0.84, 0.84]
[ 0.91, 0.91, 0.91, 0.91, 0.91, 0.91, 0.91, 0.91, 0.91, 0.91, 0.91]
[ 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14]
[ -0.76, -0.76, -0.76, -0.76, -0.76, -0.76, -0.76, -0.76, -0.76, -0.76, -0.76]
[ -0.96, -0.96, -0.96, -0.96, -0.96, -0.96, -0.96, -0.96, -0.96, -0.96, -0.96]
[ -0.28, -0.28, -0.28, -0.28, -0.28, -0.28, -0.28, -0.28, -0.28, -0.28, -0.28]
[ 0.66, 0.66, 0.66, 0.66, 0.66, 0.66, 0.66, 0.66, 0.66, 0.66, 0.66]
[ 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99]
[ 0.41, 0.41, 0.41, 0.41, 0.41, 0.41, 0.41, 0.41, 0.41, 0.41, 0.41]
[ -0.54, -0.54, -0.54, -0.54, -0.54, -0.54, -0.54, -0.54, -0.54, -0.54, -0.54]]
package testicle;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import java.util.List;
import java.util.function.Function;
import java.util.function.IntFunction;
/**
* Created by jp on 8/19/15.
*/
public class DataSetCreator {
public static DataSet create(int nIn, int nOut, int timeSeriesLength,int miniBatchSize,
IntFunction<double[]> inputGenerator,
IntFunction<double[]> outputGenerator){
INDArray input = Nd4j.zeros(miniBatchSize, nIn, timeSeriesLength);
for( int i=0; i<miniBatchSize; i++ ){
double[] in = inputGenerator.apply(i);
for( int j=0; j<nIn; j++ ){
for( int k=0; k<timeSeriesLength; k++ ){
input.putScalar(new int[]{i,j,k},in[j]);
}
}
}
INDArray output = Nd4j.zeros(miniBatchSize, nIn, timeSeriesLength);
for( int i=0; i<miniBatchSize; i++ ){
double[] in = inputGenerator.apply(i);
double[] out = outputGenerator.apply(i);
for( int j=0; j<nIn; j++ ){
for( int k=0; k<timeSeriesLength; k++ ){
output.putScalar(new int[]{i,j,k},out[j]);
}
}
}
org.nd4j.linalg.dataset.DataSet ds = new org.nd4j.linalg.dataset.DataSet(input,output);
return ds;
}
public static DataSet fromList(List<double[]> inputs, List<double[]> outputs){
int notSureWhatThisDoes=11;
return create(inputs.get(0).length,outputs.get(0).length,inputs.size(),notSureWhatThisDoes,
i->inputs.get(i),
i->outputs.get(i));
}
}
package testicle;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
/**
* Created by jp on 8/19/15.
*/
public class LSTMTest{
static Logger log = LoggerFactory.getLogger(LSTMTest.class);
public static void main(String... args)throws Exception {
Nd4j.getRandom().setSeed(12345L);
int timeSeriesLength = 11; //works when this is 1, crashes otherwise, see stacktrace
int nIn = 1;
int layerSize = 2;
int nOut = 1;
int miniBatchSize = 11;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.weightInit(WeightInit.DISTRIBUTION).dist(new NormalDistribution(0,0.1))
.regularization(false)
.updater(Updater.NONE)
.seed(12345L)
.list(2)
.layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize).activation("tanh").build())
.layer(1, new OutputLayer.Builder(LossFunction.MCXENT).activation("softmax").nIn(layerSize).nOut(nOut).build())
.inputPreProcessor(1, new RnnToFeedForwardPreProcessor(timeSeriesLength))
.pretrain(false).backprop(true)
.build();
MultiLayerNetwork mln = new MultiLayerNetwork(conf);
mln.init();
mln.fit(DataSetCreator.create(nIn, nOut, timeSeriesLength, miniBatchSize,
i -> new double[]{i},
i -> new double[]{Math.sin(i)})
);
log.info("result: {}",mln.output(Nd4j.zeros(1,nIn)));
}
}
Exception in thread "main" java.lang.IllegalArgumentException: Shape must be <= buffer length
at org.nd4j.linalg.api.ndarray.BaseNDArray.<init>(BaseNDArray.java:133)
at org.nd4j.linalg.jblas.NDArray.<init>(NDArray.java:65)
at org.nd4j.linalg.jblas.JblasNDArrayFactory.create(JblasNDArrayFactory.java:229)
at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:2919)
at org.nd4j.linalg.api.ndarray.BaseNDArray.create(BaseNDArray.java:3305)
at org.nd4j.linalg.api.ndarray.BaseNDArray.reshape(BaseNDArray.java:3348)
at org.nd4j.linalg.api.ndarray.BaseNDArray.reshape(BaseNDArray.java:3372)
at org.nd4j.linalg.api.ndarray.BaseNDArray.reshape(BaseNDArray.java:3605)
at org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor.preProcess(RnnToFeedForwardPreProcessor.java:42)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.activationFromPrevLayer(MultiLayerNetwork.java:450)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.feedForward(MultiLayerNetwork.java:513)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.feedForward(MultiLayerNetwork.java:499)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.output(MultiLayerNetwork.java:1250)
at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.output(MultiLayerNetwork.java:1266)
at testicle.LSTMTest.main(LSTMTest.java:50)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:483)
at com.intellij.rt.execution.application.AppMain.main(AppMain.java:140)
-- org.jblas INFO Deleting /tmp/jblas5445841866225978621/libjblas.so
-- org.jblas INFO Deleting /tmp/jblas5445841866225978621/libjblas_arch_flavor.so
-- org.jblas INFO Deleting /tmp/jblas5445841866225978621
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment