Complete example where values seem to not get scaled.
package org.fogbeam.dl4j.spark;
import java.io.File;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.input.PortableDataStream;

import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.writable.Writable;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.spark.functions.RecordReaderFunction;

import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.datavec.DataVecDataSetFunction;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.deeplearning4j.util.ModelSerializer;

import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import com.google.common.base.Stopwatch;
public class ExpMain2
{
    public static void main(String[] args) throws Exception
    {
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local");
        sparkConf.setAppName("SparkNeuralNetwork");
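
        // Time the data-loading phase and the training phase separately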
        Stopwatch sw = Stopwatch.createStarted();

        JavaSparkContext sc = new JavaSparkContext( sparkConf );
        sc.hadoopConfiguration().set("mapreduce.input.fileinputformat.input.dir.recursive", "true");
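
        // Load the raw MNIST PNG files as (path, stream) pairs; the recursive
        // flag set above lets binaryFiles() descend into subdirectories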
        JavaPairRDD<String, PortableDataStream> origData = sc.binaryFiles("/home/prhodes/development/experimental/ai_exp/NeuralNetworkSandbox/mnist_png/cutdown/0/**");
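
        // Read each image as 28x28 with a single channel; the label comes from
        // the name of the image's parent directory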
        ImageRecordReader irr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator() );
        List<String> labelsList = Arrays.asList( "0", "1", "2", "3", "4", "5", "6", "7", "8", "9" );
        irr.setLabels(labelsList);
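
        // Convert each (path, stream) pair into a DataVec record: a list of
        // Writables holding the pixel values plus the label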
        RecordReaderFunction rrf = new RecordReaderFunction(irr);
        JavaRDD<List<Writable>> rdd = origData.map(rrf);

        System.out.println( "DataSet RDD created");
        DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
        JavaRDD<DataSet> trainingData = rdd.map(new DataVecDataSetFunction(1, 10, false, scaler, null ));
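
        // Hedged workaround sketch (added; not part of the original repro): if the
        // preProcessor passed to DataVecDataSetFunction is in fact not being invoked,
        // the scaling could be applied explicitly to each DataSet instead --
        // DataSetPreProcessor.preProcess() mutates the DataSet in place. This assumes
        // non-invocation is the cause; uncomment to try:
        //
        // JavaRDD<DataSet> scaledData = trainingData.map( ds -> {
        //     new ImagePreProcessingScaler(0, 1).preProcess(ds);
        //     return ds;
        // } );

        // Debug: print every example's feature matrix so it is visible whether
        // the scaler ran (values in [0,1]) or not (raw 0-255 values)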
        trainingData.foreach( new VoidFunction<DataSet>() {
            int count = 0;

            @Override
            public void call(DataSet arg0) throws Exception {
                System.out.println( "count: " + count++ + "\n");
                System.out.println( "features: " + arg0.getFeatures() + "\n");
            }
        } );
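
        // Simple feed-forward classifier: 784 inputs -> 500 -> 100 -> 10-way softmax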
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(12345)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(10)
                .activation(Activation.RELU)
                .weightInit(WeightInit.XAVIER)
                .learningRate(0.02)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .regularization(true).l2(1e-4)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(500).build())
                .layer(1, new DenseLayer.Builder().nIn(500).nOut(100).build())
                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX).nIn(100).nOut(10).build())
                .pretrain(false).backprop(true)
                .setInputType(InputType.convolutional(28, 28, 1))
                .build();
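
        // InputType.convolutional(28, 28, 1) should make DL4J insert the input
        // preprocessor that flattens each image into the 784-element vector the
        // first dense layer expects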
        // Create the TrainingMaster instance (parameter averaging; each DataSet
        // object in the RDD holds a single example here)
        int examplesPerDataSetObject = 1;
        TrainingMaster trainingMaster = new ParameterAveragingTrainingMaster.Builder(examplesPerDataSetObject)
                .build();
        // Create the Spark network
        SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf, trainingMaster);
        long elapsedPhase1 = sw.elapsed(TimeUnit.SECONDS);
        System.out.println( "Loading data took " + elapsedPhase1 + " seconds. Starting to train model now.");

        MultiLayerNetwork trainedNetwork = null;
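        // Each call to fit() makes one complete pass over the training RDD, so
        // this loop trains for 20 epochs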
        for( int i = 0; i < 20; i++ )
        {
            trainedNetwork = sparkNet.fit( trainingData );
        }

        long elapsedPhase2 = sw.elapsed(TimeUnit.SECONDS);
        System.out.println( "Training model took " + (elapsedPhase2 - elapsedPhase1) + " seconds.");
        System.out.println( "Total elapsed time: " + elapsedPhase2 + " seconds.");
        /* delete any existing model if there is one */
        File oldModelFile = new File( "sparkTrainedNetwork.zip" );
        if( oldModelFile.exists() )
        {
            oldModelFile.delete();
        }
        ModelSerializer.writeModel(trainedNetwork, new File("sparkTrainedNetwork.zip"), false);

        sc.stop();
        System.out.println( "done" );
    }
}