Created
February 25, 2017 17:43
-
-
Save mindcrime/043270fa30ebfe9ba557642585c1609b to your computer and use it in GitHub Desktop.
DL4J code for loading MNIST with Spark
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.fogbeam.dl4j.spark; | |
import java.util.Arrays; | |
import java.util.List; | |
import org.apache.spark.SparkConf; | |
import org.apache.spark.api.java.JavaPairRDD; | |
import org.apache.spark.api.java.JavaRDD; | |
import org.apache.spark.api.java.JavaSparkContext; | |
import org.apache.spark.input.PortableDataStream; | |
import org.apache.spark.mllib.regression.LabeledPoint; | |
import org.datavec.api.io.labels.ParentPathLabelGenerator; | |
import org.datavec.image.recordreader.ImageRecordReader; | |
import org.deeplearning4j.spark.util.MLLibUtil; | |
import org.nd4j.linalg.dataset.DataSet; | |
public class ExpMain | |
{ | |
public static void main(String[] args) throws Exception | |
{ | |
SparkConf sparkConf = new SparkConf(); | |
sparkConf.setMaster("local"); | |
sparkConf.setAppName("SparkNeuralNetwork"); | |
JavaSparkContext sc = new JavaSparkContext( sparkConf ); | |
// https://github.com/deeplearning4j/DataVec/blob/master/datavec-spark/src/test/java/org/datavec/spark/functions/TestRecordReaderBytesFunction.java | |
// String recursiveSetting = sc.hadoopConfiguration().get("mapreduce.input.fileinputformat.input.dir.recursive"); | |
// System.out.println( "recursiveSetting: " + recursiveSetting ); | |
JavaPairRDD<String, PortableDataStream> files = sc.binaryFiles("/home/prhodes/development/experimental/ai_exp/NeuralNetworkSandbox/mnist_png/training/1/*.png"); | |
System.out.println( "binary data RDD created" ); | |
ImageRecordReader reader = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator()); | |
List<String> labelsList = Arrays.asList("0", "1", "2", "3", "4", "5", "6", "7", "8", "9" ); //Need this for Spark: can't infer without init call | |
reader.setLabels(labelsList); | |
JavaRDD<LabeledPoint> labeledPoints = MLLibUtil.fromBinary(files, reader); | |
System.out.println( "labeledPoints RDD created"); | |
JavaRDD<DataSet> trainingData = MLLibUtil.fromLabeledPoint( labeledPoints, 10, 50); | |
System.out.println( "DataSet RDD created"); | |
/* TBD: train model here */ | |
System.out.println( "done" ); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment