Skip to content

Instantly share code, notes, and snippets.

@mindcrime
Created February 25, 2017 19:46
Show Gist options
  • Save mindcrime/eac54feaeb6181d8ed1732ae7ca4ba1f to your computer and use it in GitHub Desktop.
Loading MNIST in Spark (again)
package org.fogbeam.dl4j.spark;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.input.PortableDataStream;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.writable.Writable;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.spark.functions.RecordReaderFunction;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.datavec.DataVecDataSetFunction;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions;
/**
 * Loads MNIST PNG images into a Spark RDD via DataVec and trains a small
 * feed-forward network using DL4J's parameter-averaging distributed training.
 *
 * <p>Labels are derived from the parent directory name of each image file
 * ("0" through "9") via {@link ParentPathLabelGenerator}.
 */
public class ExpMain2
{
    /** MNIST digits are 28x28 single-channel (grayscale) images. */
    private static final int IMAGE_HEIGHT = 28;
    private static final int IMAGE_WIDTH = 28;
    private static final int CHANNELS = 1;

    /** Ten digit classes, "0" through "9". */
    private static final int NUM_CLASSES = 10;

    /**
     * Entry point: builds the RDD pipeline, configures the network, and runs
     * one round of distributed fitting on the local Spark master.
     *
     * @param args unused
     * @throws Exception propagated from Spark/DL4J setup or training
     */
    public static void main(String[] args) throws Exception
    {
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local");
        sparkConf.setAppName("SparkNeuralNetwork");

        // JavaSparkContext is AutoCloseable; try-with-resources guarantees the
        // context is stopped and its resources released even if training throws.
        // (The original leaked the context by never closing it.)
        try (JavaSparkContext sc = new JavaSparkContext(sparkConf))
        {
            // Read each PNG as a (path, stream) pair. NOTE(review): the glob only
            // covers the ".../training/1" directory, so only digit "1" examples are
            // loaded — presumably intentional for this experiment; widen to
            // ".../training/*/*.png" for the full training set.
            JavaPairRDD<String, PortableDataStream> origData =
                    sc.binaryFiles("/home/prhodes/development/experimental/ai_exp/NeuralNetworkSandbox/mnist_png/training/1/*.png");

            ImageRecordReader irr =
                    new ImageRecordReader(IMAGE_HEIGHT, IMAGE_WIDTH, CHANNELS, new ParentPathLabelGenerator());
            List<String> labelsList =
                    Arrays.asList("0", "1", "2", "3", "4", "5", "6", "7", "8", "9");
            irr.setLabels(labelsList);

            // Convert each (path, stream) pair into a record (list of Writables).
            RecordReaderFunction rrf = new RecordReaderFunction(irr);
            JavaRDD<List<Writable>> rdd = origData.map(rrf);
            System.out.println( "DataSet RDD created");

            // Turn records into DataSet objects. First arg is the label index within
            // the record; NOTE(review): index 0 assumes the label is the first
            // writable — confirm against the record layout ImageRecordReader emits.
            JavaRDD<DataSet> trainingData = rdd.map(new DataVecDataSetFunction(0, NUM_CLASSES, false));

            // Simple MLP: 784 -> 500 -> 100 -> 10 with softmax output and
            // negative-log-likelihood loss; SGD with Nesterov momentum, L2 reg.
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .seed(12345)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
                    .activation(Activation.LEAKYRELU)
                    .weightInit(WeightInit.XAVIER)
                    .learningRate(0.02)
                    .updater(Updater.NESTEROVS).momentum(0.9)
                    .regularization(true).l2(1e-4)
                    .list()
                    .layer(0, new DenseLayer.Builder().nIn(IMAGE_HEIGHT * IMAGE_WIDTH).nOut(500).build())
                    .layer(1, new DenseLayer.Builder().nIn(500).nOut(100).build())
                    .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                            .activation(Activation.SOFTMAX).nIn(100).nOut(NUM_CLASSES).build())
                    .pretrain(false).backprop(true)
                    .build();

            // Parameter-averaging training master: averages worker parameters after
            // each fit pass. Use the concrete type (the TrainingMaster interface is
            // generic; declaring it raw triggers an unchecked-usage warning).
            int examplesPerDataSetObject = 1;
            ParameterAveragingTrainingMaster trainingMaster =
                    new ParameterAveragingTrainingMaster.Builder(examplesPerDataSetObject)
                            .build();

            // Wrap the configuration for distributed training and fit on the RDD.
            SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf, trainingMaster);
            sparkNet.fit( trainingData );
            System.out.println( "done" );
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment