Last active
August 29, 2015 14:27
-
-
Save Seppo420/6fe1544cbc1764b224a3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Features (INPUT) and labels (OUTPUT) as built by DataSetCreator.create in LSTMTest below: each row is one example i, each column one time step; the input value is i and the label is sin(i) at every step.
===========INPUT===================
[[ 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00]
 [ 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00, 1.00]
 [ 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00, 2.00]
 [ 3.00, 3.00, 3.00, 3.00, 3.00, 3.00, 3.00, 3.00, 3.00, 3.00, 3.00]
 [ 4.00, 4.00, 4.00, 4.00, 4.00, 4.00, 4.00, 4.00, 4.00, 4.00, 4.00]
 [ 5.00, 5.00, 5.00, 5.00, 5.00, 5.00, 5.00, 5.00, 5.00, 5.00, 5.00]
 [ 6.00, 6.00, 6.00, 6.00, 6.00, 6.00, 6.00, 6.00, 6.00, 6.00, 6.00]
 [ 7.00, 7.00, 7.00, 7.00, 7.00, 7.00, 7.00, 7.00, 7.00, 7.00, 7.00]
 [ 8.00, 8.00, 8.00, 8.00, 8.00, 8.00, 8.00, 8.00, 8.00, 8.00, 8.00]
 [ 9.00, 9.00, 9.00, 9.00, 9.00, 9.00, 9.00, 9.00, 9.00, 9.00, 9.00]
 [ 10.00, 10.00, 10.00, 10.00, 10.00, 10.00, 10.00, 10.00, 10.00, 10.00, 10.00]]
=================OUTPUT==================
[[ 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00]
 [ 0.84, 0.84, 0.84, 0.84, 0.84, 0.84, 0.84, 0.84, 0.84, 0.84, 0.84]
 [ 0.91, 0.91, 0.91, 0.91, 0.91, 0.91, 0.91, 0.91, 0.91, 0.91, 0.91]
 [ 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14]
 [ -0.76, -0.76, -0.76, -0.76, -0.76, -0.76, -0.76, -0.76, -0.76, -0.76, -0.76]
 [ -0.96, -0.96, -0.96, -0.96, -0.96, -0.96, -0.96, -0.96, -0.96, -0.96, -0.96]
 [ -0.28, -0.28, -0.28, -0.28, -0.28, -0.28, -0.28, -0.28, -0.28, -0.28, -0.28]
 [ 0.66, 0.66, 0.66, 0.66, 0.66, 0.66, 0.66, 0.66, 0.66, 0.66, 0.66]
 [ 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99, 0.99]
 [ 0.41, 0.41, 0.41, 0.41, 0.41, 0.41, 0.41, 0.41, 0.41, 0.41, 0.41]
 [ -0.54, -0.54, -0.54, -0.54, -0.54, -0.54, -0.54, -0.54, -0.54, -0.54, -0.54]]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package testicle; | |
| import org.nd4j.linalg.api.ndarray.INDArray; | |
| import org.nd4j.linalg.dataset.api.DataSet; | |
| import org.nd4j.linalg.factory.Nd4j; | |
| import java.util.List; | |
| import java.util.function.Function; | |
| import java.util.function.IntFunction; | |
| /** | |
| * Created by jp on 8/19/15. | |
| */ | |
| public class DataSetCreator { | |
| public static DataSet create(int nIn, int nOut, int timeSeriesLength,int miniBatchSize, | |
| IntFunction<double[]> inputGenerator, | |
| IntFunction<double[]> outputGenerator){ | |
| INDArray input = Nd4j.zeros(miniBatchSize, nIn, timeSeriesLength); | |
| for( int i=0; i<miniBatchSize; i++ ){ | |
| double[] in = inputGenerator.apply(i); | |
| for( int j=0; j<nIn; j++ ){ | |
| for( int k=0; k<timeSeriesLength; k++ ){ | |
| input.putScalar(new int[]{i,j,k},in[j]); | |
| } | |
| } | |
| } | |
| INDArray output = Nd4j.zeros(miniBatchSize, nIn, timeSeriesLength); | |
| for( int i=0; i<miniBatchSize; i++ ){ | |
| double[] in = inputGenerator.apply(i); | |
| double[] out = outputGenerator.apply(i); | |
| for( int j=0; j<nIn; j++ ){ | |
| for( int k=0; k<timeSeriesLength; k++ ){ | |
| output.putScalar(new int[]{i,j,k},out[j]); | |
| } | |
| } | |
| } | |
| org.nd4j.linalg.dataset.DataSet ds = new org.nd4j.linalg.dataset.DataSet(input,output); | |
| return ds; | |
| } | |
| public static DataSet fromList(List<double[]> inputs, List<double[]> outputs){ | |
| int notSureWhatThisDoes=11; | |
| return create(inputs.get(0).length,outputs.get(0).length,inputs.size(),notSureWhatThisDoes, | |
| i->inputs.get(i), | |
| i->outputs.get(i)); | |
| } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| package testicle; | |
| import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
| import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
| import org.deeplearning4j.nn.weights.WeightInit; | |
| import org.nd4j.linalg.factory.Nd4j; | |
| import org.slf4j.Logger; | |
| import org.slf4j.LoggerFactory; | |
| import org.deeplearning4j.nn.conf.Updater; | |
| import org.deeplearning4j.nn.conf.distribution.NormalDistribution; | |
| import org.deeplearning4j.nn.conf.layers.GravesLSTM; | |
| import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
| import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor; | |
| import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
| import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; | |
| /** | |
| * Created by jp on 8/19/15. | |
| */ | |
| public class LSTMTest{ | |
| static Logger log = LoggerFactory.getLogger(LSTMTest.class); | |
| public static void main(String... args)throws Exception { | |
| Nd4j.getRandom().setSeed(12345L); | |
| int timeSeriesLength = 11; //works when this is 1, crashes otherwise, see stacktrace | |
| int nIn = 1; | |
| int layerSize = 2; | |
| int nOut = 1; | |
| int miniBatchSize = 11; | |
| MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | |
| .weightInit(WeightInit.DISTRIBUTION).dist(new NormalDistribution(0,0.1)) | |
| .regularization(false) | |
| .updater(Updater.NONE) | |
| .seed(12345L) | |
| .list(2) | |
| .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize).activation("tanh").build()) | |
| .layer(1, new OutputLayer.Builder(LossFunction.MCXENT).activation("softmax").nIn(layerSize).nOut(nOut).build()) | |
| .inputPreProcessor(1, new RnnToFeedForwardPreProcessor(timeSeriesLength)) | |
| .pretrain(false).backprop(true) | |
| .build(); | |
| MultiLayerNetwork mln = new MultiLayerNetwork(conf); | |
| mln.init(); | |
| mln.fit(DataSetCreator.create(nIn, nOut, timeSeriesLength, miniBatchSize, | |
| i -> new double[]{i}, | |
| i -> new double[]{Math.sin(i)}) | |
| ); | |
| log.info("result: {}",mln.output(Nd4j.zeros(1,nIn))); | |
| } | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Exception in thread "main" java.lang.IllegalArgumentException: Shape must be <= buffer length
	at org.nd4j.linalg.api.ndarray.BaseNDArray.<init>(BaseNDArray.java:133)
	at org.nd4j.linalg.jblas.NDArray.<init>(NDArray.java:65)
	at org.nd4j.linalg.jblas.JblasNDArrayFactory.create(JblasNDArrayFactory.java:229)
	at org.nd4j.linalg.factory.Nd4j.create(Nd4j.java:2919)
	at org.nd4j.linalg.api.ndarray.BaseNDArray.create(BaseNDArray.java:3305)
	at org.nd4j.linalg.api.ndarray.BaseNDArray.reshape(BaseNDArray.java:3348)
	at org.nd4j.linalg.api.ndarray.BaseNDArray.reshape(BaseNDArray.java:3372)
	at org.nd4j.linalg.api.ndarray.BaseNDArray.reshape(BaseNDArray.java:3605)
	at org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor.preProcess(RnnToFeedForwardPreProcessor.java:42)
	at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.activationFromPrevLayer(MultiLayerNetwork.java:450)
	at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.feedForward(MultiLayerNetwork.java:513)
	at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.feedForward(MultiLayerNetwork.java:499)
	at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.output(MultiLayerNetwork.java:1250)
	at org.deeplearning4j.nn.multilayer.MultiLayerNetwork.output(MultiLayerNetwork.java:1266)
	at testicle.LSTMTest.main(LSTMTest.java:50)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:483)
	at com.intellij.rt.execution.application.AppMain.main(AppMain.java:140)
-- org.jblas INFO Deleting /tmp/jblas5445841866225978621/libjblas.so
-- org.jblas INFO Deleting /tmp/jblas5445841866225978621/libjblas_arch_flavor.so
-- org.jblas INFO Deleting /tmp/jblas5445841866225978621
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.