Skip to content

Instantly share code, notes, and snippets.

@rubenfiszel
Created August 10, 2016 16:51
Show Gist options
  • Save rubenfiszel/9c0f8d2ca83b1d37e56f22ef2f2ccde8 to your computer and use it in GitHub Desktop.
Save rubenfiszel/9c0f8d2ca83b1d37e56f22ef2f2ccde8 to your computer and use it in GitHub Desktop.
package org.deeplearning4j.rl4j;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.rl4j.network.dqn.DQN;
import org.deeplearning4j.rl4j.network.dqn.DQNFactoryStdDense;
import org.deeplearning4j.rl4j.util.Constants;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
/**
 * Small stand-alone example: builds a five-layer feed-forward network and fits it
 * on two input/label pairs that are read from text files.
 *
 * @author rubenfiszel ([email protected]) on 8/10/16.
 */
public class Example3 {

    /** Number of inputs fed to the first dense layer. */
    private static final int INPUT_SIZE = 10;

    /** Number of outputs produced by the output layer. */
    private static final int OUTPUT_SIZE = 10;

    /** How many hidden-to-hidden dense layers sit between the first layer and the output layer. */
    private static final int HIDDEN_TO_HIDDEN_LAYERS = 3;

    /**
     * Creates one dense layer with {@code "relu"} activation and
     * {@link Constants#NUM_HIDDEN_NODES} outputs.
     *
     * @param nIn number of inputs of the layer
     * @return the configured dense layer
     */
    private static DenseLayer reluLayer(int nIn) {
        return new DenseLayer.Builder()
                .nIn(nIn)
                .nOut(Constants.NUM_HIDDEN_NODES)
                .activation("relu")
                .build();
    }

    /**
     * Configures, initializes and returns the example network.
     *
     * <p>Layout: layer 0 maps {@code 10} inputs to {@link Constants#NUM_HIDDEN_NODES} units,
     * layers 1-3 keep that width, and layer 4 is an MSE output layer with identity
     * activation and {@code 10} outputs. Training settings are stochastic gradient descent
     * with learning rate {@code 0.001}, the Nesterovs updater with momentum {@code 0.9},
     * Xavier weight initialization and L2 regularization of {@code 0.01}.
     *
     * @return the initialized network with a {@link ScoreIterationListener} attached
     */
    public static MultiLayerNetwork buildModel() {
        // Global hyper-parameters first, then the first dense layer at index 0.
        NeuralNetConfiguration.ListBuilder layers = new NeuralNetConfiguration.Builder()
                .seed(Constants.NEURAL_NET_SEED)
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(0.001)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .weightInit(WeightInit.XAVIER)
                .regularization(true)
                .l2(0.01)
                .list()
                .layer(0, reluLayer(INPUT_SIZE));

        // Hidden-to-hidden layers occupy indices 1..HIDDEN_TO_HIDDEN_LAYERS.
        for (int index = 1; index <= HIDDEN_TO_HIDDEN_LAYERS; index++) {
            layers.layer(index, reluLayer(Constants.NUM_HIDDEN_NODES));
        }

        // The output layer takes the next free index after the hidden layers.
        layers.layer(HIDDEN_TO_HIDDEN_LAYERS + 1,
                new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .activation("identity")
                        .nIn(Constants.NUM_HIDDEN_NODES)
                        .nOut(OUTPUT_SIZE)
                        .build());

        MultiLayerConfiguration networkConf = layers.pretrain(false).backprop(true).build();

        MultiLayerNetwork network = new MultiLayerNetwork(networkConf);
        network.init();
        // Listener is constructed with 1 - presumably reports the score every iteration.
        network.setListeners(new ScoreIterationListener(1));
        return network;
    }

    /**
     * Reads two input/label pairs from {@code input0.txt}/{@code label0.txt} and
     * {@code input1.txt}/{@code label1.txt}, fits a freshly built network on each pair
     * in turn, and then prints a marker line.
     *
     * @param args command-line arguments (not used)
     */
    public static void main(String[] args) {
        // All four files are loaded before the network is built or trained.
        INDArray firstInput = Nd4j.readTxt("input0.txt");
        INDArray firstLabel = Nd4j.readTxt("label0.txt");
        INDArray secondInput = Nd4j.readTxt("input1.txt");
        INDArray secondLabel = Nd4j.readTxt("label1.txt");

        MultiLayerNetwork network = buildModel();

        network.fit(firstInput, firstLabel);
        network.fit(secondInput, secondLabel);

        System.out.println("lol");
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment