Created
August 10, 2016 16:51
-
-
Save rubenfiszel/9c0f8d2ca83b1d37e56f22ef2f2ccde8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.deeplearning4j.rl4j; | |
import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.Updater; | |
import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
import org.deeplearning4j.rl4j.network.dqn.DQN; | |
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense; | |
import org.deeplearning4j.rl4j.util.Constants; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
/** | |
* @author rubenfiszel ([email protected]) on 8/10/16. | |
*/ | |
public class Example3 { | |
public static MultiLayerNetwork buildModel() { | |
NeuralNetConfiguration.ListBuilder confB = new NeuralNetConfiguration.Builder() | |
.seed(Constants.NEURAL_NET_SEED) | |
.iterations(1) | |
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | |
.learningRate(0.001) | |
.updater(Updater.NESTEROVS).momentum(0.9) | |
.weightInit(WeightInit.XAVIER) | |
.regularization(true) | |
.l2(0.01) | |
.list() | |
.layer(0, new DenseLayer.Builder() | |
.nIn(10) | |
.nOut(Constants.NUM_HIDDEN_NODES) | |
.activation("relu") | |
.build()); | |
for (int i = 1; i < 4; i++) { | |
confB | |
.layer(i, new DenseLayer.Builder() | |
.nIn(Constants.NUM_HIDDEN_NODES) | |
.nOut(Constants.NUM_HIDDEN_NODES) | |
.activation("relu") | |
.build()); | |
} | |
confB | |
.layer(4, new OutputLayer.Builder(LossFunctions.LossFunction.MSE) | |
.activation("identity") | |
.nIn(Constants.NUM_HIDDEN_NODES) | |
.nOut(10) | |
.build()); | |
MultiLayerConfiguration mlnconf = confB.pretrain(false).backprop(true).build(); | |
MultiLayerNetwork model = new MultiLayerNetwork(mlnconf); | |
model.init(); | |
model.setListeners(new ScoreIterationListener(1)); | |
return model; | |
} | |
public static void main(String[] args){ | |
INDArray input1 = Nd4j.readTxt("input0.txt"); | |
INDArray label1 = Nd4j.readTxt("label0.txt"); | |
INDArray input2 = Nd4j.readTxt("input1.txt"); | |
INDArray label2 = Nd4j.readTxt("label1.txt"); | |
MultiLayerNetwork mln = buildModel(); | |
mln.fit(input1, label1); | |
mln.fit(input2, label2); | |
System.out.println("lol"); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment