Skip to content

Instantly share code, notes, and snippets.

@mindcrime
Created February 25, 2017 17:43
Show Gist options
  • Save mindcrime/043270fa30ebfe9ba557642585c1609b to your computer and use it in GitHub Desktop.
Save mindcrime/043270fa30ebfe9ba557642585c1609b to your computer and use it in GitHub Desktop.
DL4J code for loading MNIST with Spark
package org.fogbeam.dl4j.spark;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.input.PortableDataStream;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.spark.util.MLLibUtil;
import org.nd4j.linalg.dataset.DataSet;
/**
 * Experimental driver that loads MNIST PNG images into a Spark RDD of DL4J {@link DataSet}s.
 *
 * <p>Usage: {@code ExpMain [imageGlob]}. When no argument is given, the author's original local
 * path (which only covers the digit-"1" training folder) is used. The Spark master defaults to
 * {@code local} but can be overridden externally, e.g. with {@code spark-submit --master ...}.
 */
public class ExpMain
{
    /** Image glob used when no command-line argument is supplied (the original hard-coded path). */
    private static final String DEFAULT_IMAGE_GLOB =
        "/home/prhodes/development/experimental/ai_exp/NeuralNetworkSandbox/mnist_png/training/1/*.png";

    /** Image dimensions handed to the record reader: 28x28 pixels, 1 channel. */
    private static final int IMAGE_HEIGHT = 28;
    private static final int IMAGE_WIDTH = 28;
    private static final int IMAGE_CHANNELS = 1;

    /** Number of examples per DataSet produced by MLLibUtil.fromLabeledPoint. */
    private static final int BATCH_SIZE = 50;

    /**
     * Builds the image-loading pipeline: binary files, then LabeledPoints, then DataSets.
     *
     * @param args optional; {@code args[0]} overrides the image glob passed to
     *             {@code binaryFiles}
     * @throws Exception propagated from Spark / DataVec setup
     */
    public static void main(String[] args) throws Exception
    {
        String imageGlob = (args.length > 0) ? args[0] : DEFAULT_IMAGE_GLOB;

        SparkConf sparkConf = new SparkConf();
        // setIfMissing instead of setMaster: values set directly on SparkConf take precedence
        // over spark-submit flags, so a hard-coded master would silently override --master.
        sparkConf.setIfMissing("spark.master", "local");
        sparkConf.setAppName("SparkNeuralNetwork");

        // try-with-resources stops the Spark context even if pipeline construction throws.
        try (JavaSparkContext sc = new JavaSparkContext(sparkConf))
        {
            // Pattern adapted from:
            // https://github.com/deeplearning4j/DataVec/blob/master/datavec-spark/src/test/java/org/datavec/spark/functions/TestRecordReaderBytesFunction.java
            JavaPairRDD<String, PortableDataStream> files = sc.binaryFiles(imageGlob);
            System.out.println( "binary data RDD created" );

            ImageRecordReader reader = new ImageRecordReader(
                IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS, new ParentPathLabelGenerator());
            // Need this for Spark: the reader can't infer labels without an init call.
            List<String> labelsList = Arrays.asList("0", "1", "2", "3", "4", "5", "6", "7", "8", "9");
            reader.setLabels(labelsList);

            // NOTE: Spark transformations are lazy. The calls below only build RDD lineage;
            // no image is actually read until an action (e.g. count() or training) runs.
            JavaRDD<LabeledPoint> labeledPoints = MLLibUtil.fromBinary(files, reader);
            System.out.println( "labeledPoints RDD created");

            // Derive the label count from the label list so the two cannot drift apart.
            JavaRDD<DataSet> trainingData =
                MLLibUtil.fromLabeledPoint(labeledPoints, labelsList.size(), BATCH_SIZE);
            System.out.println( "DataSet RDD created");

            /* TBD: train model here (using trainingData) */
            System.out.println( "done" );
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment