Created October 11, 2016 14:54
-
-
Save osipov/2da9af5273dd2d169b9f04be503aebd1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.deeplearning4j.examples.feedforward.xor; | |
import org.deeplearning4j.eval.Evaluation; | |
import org.deeplearning4j.nn.api.Model; | |
import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.Updater; | |
import org.deeplearning4j.nn.conf.distribution.NormalDistribution; | |
import org.deeplearning4j.nn.conf.distribution.UniformDistribution; | |
import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.dataset.DataSet; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
//import org.deeplearning4j.nn.conf.distribution.UniformDistribution; | |
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | |
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator; | |
//import org.deeplearning4j.eval.Evaluation; | |
//import org.deeplearning4j.nn.api.Model; | |
//import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
//import org.deeplearning4j.nn.conf.Updater; | |
//import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
//import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
//import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
//import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
//import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
//import org.deeplearning4j.nn.weights.WeightInit; | |
//import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
//import org.nd4j.linalg.api.ndarray.INDArray; | |
//import org.nd4j.linalg.dataset.DataSet; | |
//import org.nd4j.linalg.factory.Nd4j; | |
//import org.nd4j.linalg.lossfunctions.LossFunctions; | |
// | |
////import org.nd4j.jita.conf.CudaEnvironment; | |
import java.util.Arrays; | |
import java.util.Random; | |
/** | |
* Created by osipov on 6/28/16. | |
*/ | |
public class FizzBuzz { | |
public static int decodeBinary(INDArray arr) { | |
int i = 0; | |
for (int j = 0; j < arr.length(); j++) { | |
i += Math.pow(2, j) * arr.getInt(j); | |
} | |
return i; | |
} | |
public static int decodeBinary(float[] b) { | |
int i = 0; | |
for (int j = 0; j < b.length; j++) { | |
i += Math.pow(2, j)*b[j]; | |
} | |
return i; | |
} | |
// public static float[] encodeBinary(int val, int numDigits) { | |
// float[] result = new float[numDigits]; | |
// for (int i = 0; i < numDigits; i++) { | |
// result[i] = (val >> i) & 1; | |
// } | |
// return result; | |
// } | |
public static INDArray encodeBinary(int val, int numDigits) { | |
INDArray encoded = Nd4j.zeros(numDigits); | |
for (int i = 0; i < numDigits; i++) | |
encoded.putScalar(i, (val >> i) & 1); | |
return encoded; | |
} | |
// public static float[] encodeFizzBuzz(int i) { | |
// if (i % 15 == 0) return new float[]{0.0f, 0.0f, 0.0f, 1.0f}; | |
// else | |
// if (i % 5 == 0) return new float[]{0.0f, 0.0f, 1.0f, 0.0f}; | |
// else | |
// if (i % 3 == 0) return new float[]{0.0f, 1.0f, 0.0f, 0.0f}; | |
// | |
// else return new float[]{1.0f, 0.0f, 0.0f, 0.0f}; | |
// } | |
public static INDArray encodeFizzBuzz(int i) { | |
INDArray encoded = Nd4j.zeros(4); | |
if (i % 15 == 0) return encoded.putScalar(3, 1); | |
else if (i % 5 == 0) return encoded.putScalar(2, 1); | |
else if (i % 3 == 0) return encoded.putScalar(1, 1); | |
else return encoded.putScalar(0, 1); | |
} | |
// | |
// public static int[] encodeFizzBuzz(int i) { | |
// if (i % 15 == 0) return new int[]{0, 0, 0, 1}; | |
// else | |
// if (i % 5 == 0) return new int[]{0, 0, 1, 0}; | |
// else | |
// if (i % 3 == 0) return new int[]{0, 1, 0, 0}; | |
// | |
// else return new int[]{1, 0, 0, 0}; | |
// } | |
public static void main(String[] args) { | |
// org.nd4j.jita.conf.CudaEnvironment.getInstance().getConfiguration().allowMultiGPU(true); | |
Nd4j.ENFORCE_NUMERICAL_STABILITY = true; | |
// final int NUM_UPPER = 32768; | |
final int NUM_UPPER = 8192; | |
final int NUM_DIGITS = 10; | |
int rngSeed = 12345; | |
int numEpochs = 5000; | |
int batchSize = 128; | |
double learningRate = 0.3; | |
double regularizationRate = learningRate * 0.0005; | |
double nesterovsMomentum = 0.9; | |
Random rnd = new Random(rngSeed); | |
// int numEpochs = 1000; | |
INDArray trainFeaturesTmp = Nd4j.zeros(NUM_UPPER - 101, NUM_DIGITS); | |
INDArray trainLabelsTmp = Nd4j.zeros(NUM_UPPER - 101, 4); | |
int trainCount = 0; | |
for (int i = 101; i < NUM_UPPER; i++) { | |
INDArray features = encodeBinary(i, NUM_DIGITS); | |
INDArray labels = encodeFizzBuzz(i); | |
boolean lucky = false; | |
if (labels.getInt(0) == 1) lucky = rnd.nextInt(8) == 0; | |
else | |
if (labels.getInt(1) == 1) lucky = rnd.nextInt(4) == 0; | |
else | |
if (labels.getInt(2) == 1) lucky = rnd.nextInt(2) == 0; | |
else | |
if (labels.getInt(3) == 1) lucky = true; | |
if (lucky) { | |
trainFeaturesTmp.putRow(trainCount, features); | |
trainLabelsTmp.putRow(trainCount, labels); | |
trainCount++; | |
} | |
} | |
int[] counts = new int[4]; | |
for (int i = 0; i < trainCount; i++) { | |
if (trainLabelsTmp.getRow(i).getInt(0) == 1) counts[0] += 1; | |
else | |
if (trainLabelsTmp.getRow(i).getInt(1) == 1) counts[1] += 1; | |
else | |
if (trainLabelsTmp.getRow(i).getInt(2) == 1) counts[2] += 1; | |
else | |
if (trainLabelsTmp.getRow(i).getInt(3) == 1) counts[3] += 1; | |
} | |
System.out.println("Train count: " + Arrays.toString(counts)); | |
INDArray trainFeatures = Nd4j.zeros(trainCount, NUM_DIGITS); | |
INDArray trainLabels = Nd4j.zeros(trainCount, 4); | |
for (int i = 0; i < trainCount; i++) { | |
trainFeatures.putRow(i, trainFeaturesTmp.getRow(i)); | |
trainLabels.putRow(i, trainLabelsTmp.getRow(i)); | |
} | |
INDArray testFeatures = Nd4j.zeros(100, NUM_DIGITS); | |
for (int i = 1; i < 101; i++) testFeatures.putRow(i - 1, encodeBinary(i, NUM_DIGITS)); | |
INDArray testLabels = Nd4j.zeros(100, 4); | |
for (int i = 1; i < 101; i++) testLabels.putRow(i - 1, encodeFizzBuzz(i)); | |
final DataSet trainDataset = new DataSet(trainFeatures, trainLabels); | |
final DataSet testDataset = new DataSet(testFeatures, testLabels); | |
// for (int i = 0; i < 100; i++) { | |
// System.out.println(testFeatures.getRow(i).toString() + " " + testLabels.getRow(i).toString()); | |
// } | |
// if (true) return; | |
trainDataset.shuffle(rngSeed); | |
DataSetIterator trainDatasetBatches = new ListDataSetIterator(trainDataset.asList(), batchSize); | |
// DataSetIterator testDatasetBatches = new ListDataSetIterator(testDataset.asList(), batchSize); | |
System.out.println("Build model...."); | |
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | |
.seed(rngSeed) | |
.optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT) | |
// .biasInit(0) | |
.iterations(1) | |
.learningRate(learningRate) | |
.activation("relu") | |
.weightInit(WeightInit.XAVIER) | |
.miniBatch(true) | |
.useDropConnect(false) | |
.updater(Updater.NESTEROVS).momentum(nesterovsMomentum) | |
// .regularization(true).l2(regularizationRate) | |
.list() | |
.layer(0, new DenseLayer.Builder() | |
.nIn(10) | |
.nOut(100) | |
// .weightInit(WeightInit.DISTRIBUTION) | |
// .dist(new NormalDistribution(0.0, 0.01)) | |
// .activation("relu") | |
.build()) | |
.layer(1, new DenseLayer.Builder() | |
.nIn(100) | |
.nOut(100) | |
// .weightInit(WeightInit.DISTRIBUTION) | |
// .dist(new NormalDistribution(0.0, 0.01)) | |
// .activation("relu") | |
.build()) | |
.layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) | |
.nIn(100) | |
.nOut(4) | |
// .weightInit(WeightInit.DISTRIBUTION) | |
// .dist(new UniformDistribution(0.1, 1)) | |
.activation("softmax") | |
.build()) | |
.pretrain(false).backprop(true) | |
.build(); | |
final MultiLayerNetwork model = new MultiLayerNetwork(conf); | |
model.init(); | |
// add an listener which outputs the error every 100 parameter updates | |
// model.setListeners(new ScoreIterationListener(100)); | |
model.setListeners(new ScoreIterationListener[]{ | |
new ScoreIterationListener(100), | |
// new ScoreIterationListener(200) { | |
// private int myCount = 0; | |
// | |
// @Override | |
// public void iterationDone(Model m, int iter) { | |
//// super.iterationDone(m, iter); | |
// try { | |
// if (myCount % 200 == 0 && myCount > 0) { | |
// org.deeplearning4j.nn.multilayer.MultiLayerNetwork mod = (org.deeplearning4j.nn.multilayer.MultiLayerNetwork) m; | |
// Evaluation eval = new Evaluation(4); | |
// INDArray output = mod.output(testDataset.getFeatures()); | |
// eval.eval(testDataset.getLabels(), output); | |
// System.out.println("Test Iteration " + myCount); | |
// System.out.println(eval.stats(true)); | |
// } | |
// myCount++; | |
// } catch (Throwable t) { | |
// System.out.println("caught throwable " + t); | |
// } | |
// } | |
// } | |
}); | |
System.out.println("Train model...."); | |
for( int i=0; i<numEpochs; i++ ) { | |
model.fit(trainDatasetBatches); | |
// model.fit(trainDataset); | |
} | |
System.out.println("Evaluate model...."); | |
{ | |
System.out.println("****************Train eval********************"); | |
Evaluation eval = new Evaluation(4); | |
eval.eval(trainDataset.getLabels(), model.output(trainDataset.getFeatures())); | |
System.out.println(eval.stats()); | |
System.out.println("****************Train eval********************"); | |
} | |
{ | |
System.out.println("****************Test eval********************"); | |
Evaluation eval = new Evaluation(4); | |
eval.eval(testDataset.getLabels(), model.output(testDataset.getFeatures())); | |
System.out.println(eval.stats()); | |
System.out.println("****************Test eval********************"); | |
} | |
System.out.println("****************Example finished********************"); | |
for (int i = 0; i < 16; i++) { | |
System.out.println((i + 1) + " " + testFeatures.getRow(i).toString() + " " + model.output(testFeatures.getRow(i)).toString()); | |
} | |
// System.out.println(model.output(testDataset.getFeatures())); | |
// for (int i = 1; i < 101; i++) { | |
// INDArray o = model.output(encodeBinary(i, NUM_DIGITS)); | |
// | |
// System.out.println(i + " " + o.toString() + " " + o.maxNumber() + " " + o.eps(o.maxNumber()).toString()); | |
// } | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.