Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save gamemachine/4712d51f2ba8cc53c3df06eff60ffca5 to your computer and use it in GitHub Desktop.

Select an option

Save gamemachine/4712d51f2ba8cc53c3df06eff60ffca5 to your computer and use it in GitHub Desktop.
dl4jtest1
// DL4J smoke test (Java API via .NET interop): train a 1-input -> 2-hidden -> 1-output
// sigmoid network on the identity mapping 0 -> 0, 1 -> 1, then print the prediction for 0.
int rngSeed = 123;            // random number seed for reproducibility
int nEpochs = 50;             // number of fit() calls over the full two-example dataset
int scorePrintFrequency = 10; // log the training score every N iterations

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .seed(rngSeed) // fixed seed -> reproducible weight initialization
    // use stochastic gradient descent as the optimization algorithm
    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
    .iterations(1) // one parameter update per fit() call
    .learningRate(0.006)
    // Nesterov momentum updater with momentum coefficient 0.9
    .updater(org.deeplearning4j.nn.conf.Updater.NESTEROVS).momentum(0.9)
    .regularization(true).l2(1e-4) // L2 weight penalty
    .list()
    .layer(0, new DenseLayer.Builder() // hidden layer: 1 input feature -> 2 sigmoid units
        .nIn(1)
        .nOut(2)
        .activation(Activation.SIGMOID)
        .weightInit(WeightInit.XAVIER)
        .build())
    // Output layer: a single sigmoid unit requires binary cross-entropy (XENT).
    // NEGATIVELOGLIKELIHOOD (== MCXENT) only computes -y*log(p), so examples labelled 0
    // contribute zero loss and zero gradient and the network is only ever pushed toward 1.
    .layer(1, new OutputLayer.Builder(LossFunction.XENT)
        .nIn(2)
        .nOut(1)
        .activation(Activation.SIGMOID)
        .weightInit(WeightInit.XAVIER)
        .build())
    .pretrain(false).backprop(true) // supervised training via backpropagation only
    .build();

MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
// With nEpochs single-iteration fit() calls, a print frequency of 100 would report at most
// once; keep the frequency well below nEpochs so training progress is visible.
model.setListeners(new ScoreIterationListener(scorePrintFrequency));

// Training data: two examples with one feature each; the label equals the input.
double[][] input = new double[][] { new double[] { 0D }, new double[] { 1D } };
double[][] output = new double[][] { new double[] { 0D }, new double[] { 1D } };
INDArray ndIn = Nd4j.create(input);
INDArray ndOut = Nd4j.create(output);

for (int epoch = 0; epoch < nEpochs; epoch++)
{
    model.fit(ndIn, ndOut);
}

// Predict for the first training input (0) and print the single sigmoid output value.
INDArray prediction = model.output(Nd4j.create(input[0]));
string msg = string.Format("{0}", prediction.getDouble(0, 0));
Console.WriteLine(msg);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment