CNN on Spark
package org.deeplearning4j.cnnspark;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import org.apache.commons.io.FilenameUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.MultipleEpochsIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.deeplearning4j.util.NetSaverLoaderUtils;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
/**
 * SDC classification based on the Animal Classification example.
 *
 * Example classification of photos of 4 different animals (bear, duck, deer, turtle).
 *
 * References:
 * - U.S. Fish and Wildlife Service (animal sample dataset): http://digitalmedia.fws.gov/cdm/
 * - Tiny ImageNet Classification with CNN: http://cs231n.stanford.edu/reports/leonyao_final.pdf
 *
 * CHALLENGE: The current setup gets low scores. Can you improve them? Some approaches:
 * - Add additional images to the dataset
 * - Apply more transforms to the dataset (see the commented transform sketch in run() below)
 * - Increase the number of epochs
 * - Try different model configurations
 * - Tune the learning rate, updaters, activation & loss functions, regularization, ...
 */
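/*
 * Hypothetical launch sketch (the jar name and master URL are illustrative, not from
 * this gist; the -useSparkLocal/-batchSizePerWorker/-numEpochs flags are defined below):
 *
 *   spark-submit --class org.deeplearning4j.cnnspark.SDCCNNSpark \
 *       --master spark://host:7077 cnn-spark.jar \
 *       -useSparkLocal false -batchSizePerWorker 32 -numEpochs 1
 */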
public class SDCCNNSpark {
    protected static final Logger log = LoggerFactory.getLogger(SDCCNNSpark.class);

    @Parameter(names = "-useSparkLocal", description = "Use spark local (helper for testing/running without spark submit)", arity = 1)
    private boolean useSparkLocal = true;

    @Parameter(names = "-batchSizePerWorker", description = "Number of examples to fit each worker with")
    private int batchSizePerWorker = 32;

    @Parameter(names = "-numEpochs", description = "Number of epochs for training")
    private int numEpochs = 1;

    protected static int height = 100;
    protected static int width = 100;
    protected static int channels = 3;
    protected static int numExamples = 900;
    protected static int numLabels = 3;
    //protected static int batchSize = 30;
    protected static long seed = 42;
    protected static Random rng = new Random(seed);
    protected static int listenerFreq = 1;
    protected static int iterations = 1;
    protected static double splitTrainTest = 0.8;
    //protected static int nCores = 4;
    protected static boolean save = false;
    protected static String modelType = "AlexNet"; // LeNet, AlexNet or Custom but you need to fill it out
    public static void main(String[] args) throws Exception {
        new SDCCNNSpark().run(args);
    }

    public void run(String[] args) throws Exception {
        JCommander jcmdr = new JCommander(this);
        try {
            jcmdr.parse(args);
        } catch (ParameterException e) {
            // User provided invalid input -> print the usage info
            jcmdr.usage();
            try { Thread.sleep(500); } catch (Exception e2) { }
            throw e;
        }

        SparkConf sparkConf = new SparkConf();
        if (useSparkLocal) {
            sparkConf.setMaster("local[*]");
        } else {
            sparkConf.setMaster("spark://OmerMBP.local:7077");
        }
        sparkConf.setAppName("SDC CNN on Spark");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);
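        // Configure distributed training via parameter averaging: each Spark worker fits
        // its own copy of the network on minibatches of batchSizePerWorker examples, and
        // every averagingFrequency(5) minibatches the workers' parameters are averaged
        // into a single set and redistributed.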
        TrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(batchSizePerWorker) // each DataSet object contains (by default) 32 examples
            .averagingFrequency(5)
            .workerPrefetchNumBatches(2) // async prefetching: 2 minibatches per worker
            .batchSizePerWorker(batchSizePerWorker)
            .build();
log.info("Load data ..."); | |
/**cd | |
* Data Setup -> organize and limit data file paths: | |
* - mainPath = path to image files | |
* - fileSplit = define basic dataset split with limits on format | |
* - pathFilter = define additional file load filter to limit size and balance batch content | |
**/ | |
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); | |
File mainPath = new File(System.getProperty("user.home"), "sdcdata/"); | |
FileSplit fileSplit = new FileSplit(mainPath, NativeImageLoader.ALLOWED_FORMATS, rng); | |
BalancedPathFilter pathFilter = new BalancedPathFilter(rng, labelMaker, numExamples, numLabels, batchSizePerWorker); | |
/** | |
* Data Setup -> train test split | |
* - inputSplit = define train and test split | |
**/ | |
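        // With numExamples = 900 and splitTrainTest = 0.8, this draws roughly 720 training
        // and 180 test examples (subject to the BalancedPathFilter's limits).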
        InputSplit[] inputSplit = fileSplit.sample(pathFilter, numExamples * splitTrainTest, numExamples * (1 - splitTrainTest));
        InputSplit trainData = inputSplit[0];
        InputSplit testData = inputSplit[1];

        /**
         * Data Setup -> normalization
         *  - how to normalize images and generate large dataset to train on
         **/
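        // ImagePreProcessingScaler(0, 1) maps raw pixel values in [0, 255] to the [0, 1] range.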
        DataNormalization scaler = new ImagePreProcessingScaler(0, 1);

        log.info("Build model ...");
        // Uncomment below to try AlexNet. Note: change height and width to at least 100
        //MultiLayerNetwork network = new AlexNet(height, width, channels, numLabels, seed, iterations).init();
        MultiLayerNetwork network;
        switch (modelType) {
            case "LeNet":
                network = lenetModel();
                break;
            case "AlexNet":
                network = alexnetModel();
                break;
            case "custom":
                network = customModel();
                break;
            default:
                throw new InvalidInputTypeException("Incorrect model provided.");
        }
        //network.init();
        //network.setListeners(new ScoreIterationListener(listenerFreq));
        /**
         * Data Setup -> define how to load data into net:
         *  - recordReader = the reader that loads and converts image data; pass in inputSplit to initialize
         *  - dataIter = a generator that only loads one batch at a time into memory, to save memory
         *  - trainIter = uses MultipleEpochsIterator to ensure the model runs through the data for all epochs
         **/
        ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);
        DataSetIterator dataIter;

        // Train without transformations
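        // One way to "apply more transforms to the dataset" (per the CHALLENGE above) is
        // DataVec image augmentation, as in the AnimalsClassification example this gist is
        // based on. A hedged sketch, assuming imports from org.datavec.image.transform and
        // run after trainDataList below is declared:
        //
        //   ImageTransform flipTransform = new FlipImageTransform(rng);
        //   ImageTransform warpTransform = new WarpImageTransform(rng, 42);
        //   for (ImageTransform transform : Arrays.asList(flipTransform, warpTransform)) {
        //       recordReader.initialize(trainData, transform);
        //       dataIter = new RecordReaderDataSetIterator(recordReader, batchSizePerWorker, 1, numLabels);
        //       scaler.fit(dataIter);
        //       dataIter.setPreProcessor(scaler);
        //       while (dataIter.hasNext()) { trainDataList.add(dataIter.next()); }
        //   }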
        recordReader.initialize(trainData);
        dataIter = new RecordReaderDataSetIterator(recordReader, batchSizePerWorker, 1, numLabels);
        scaler.fit(dataIter);
        dataIter.setPreProcessor(scaler);

        //TODO: JavaRDD here
        List<DataSet> trainDataList = new ArrayList<>();
        while (dataIter.hasNext()) {
            trainDataList.add(dataIter.next());
        }
        JavaRDD<DataSet> trainDataRDD = sc.parallelize(trainDataList);
        recordReader.initialize(testData);
        dataIter = new RecordReaderDataSetIterator(recordReader, batchSizePerWorker, 1, numLabels);
        scaler.fit(dataIter);
        dataIter.setPreProcessor(scaler);

        //TODO: JavaRDD here
        List<DataSet> testDataList = new ArrayList<>();
        while (dataIter.hasNext()) {
            testDataList.add(dataIter.next());
        }
        JavaRDD<DataSet> testDataRDD = sc.parallelize(testDataList);

        // Create the Spark network
        SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, network, tm);

        log.info("Train model ...");
        // Execute training:
        for (int i = 0; i < numEpochs; i++) {
            sparkNet.fit(trainDataRDD);
            log.info("Completed Epoch {}", i);
        }

        // Perform evaluation (distributed)
        log.info("Evaluate model ...");
        Evaluation evaluation = sparkNet.evaluate(testDataRDD);
        log.info("***** Evaluation *****");
        log.info(evaluation.stats());
        log.info("***** Example Complete *****");

        if (save) { //TODO: fix saving CNN
            log.info("Save model....");
            //String basePath = FilenameUtils.concat(System.getProperty("user.dir"), "src/main/resources/");
            //NetSaverLoaderUtils.saveNetworkAndParameters(network, basePath);
            //NetSaverLoaderUtils.saveUpdators(network, basePath);
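            // A possible alternative, assuming DL4J's ModelSerializer
            // (import org.deeplearning4j.util.ModelSerializer) and the hypothetical
            // file name below:
            //
            //   ModelSerializer.writeModel(network, new File(basePath, "sdc-cnn-model.zip"), true); // true = also save updater state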
        }

        log.info("****************Example finished********************");

        // Delete the temp training files
        tm.deleteTempFiles(sc);
    }

    private ConvolutionLayer convInit(String name, int in, int out, int[] kernel, int[] stride, int[] pad, double bias) {
        return new ConvolutionLayer.Builder(kernel, stride, pad).name(name).nIn(in).nOut(out).biasInit(bias).build();
    }

    private ConvolutionLayer conv3x3(String name, int out, double bias) {
        return new ConvolutionLayer.Builder(new int[]{3,3}, new int[]{1,1}, new int[]{1,1}).name(name).nOut(out).biasInit(bias).build();
    }

    private ConvolutionLayer conv5x5(String name, int out, int[] stride, int[] pad, double bias) {
        return new ConvolutionLayer.Builder(new int[]{5,5}, stride, pad).name(name).nOut(out).biasInit(bias).build();
    }

    private SubsamplingLayer maxPool(String name, int[] kernel) {
        return new SubsamplingLayer.Builder(kernel, new int[]{2,2}).name(name).build();
    }

    private DenseLayer fullyConnected(String name, int out, double bias, double dropOut, Distribution dist) {
        return new DenseLayer.Builder().name(name).nOut(out).biasInit(bias).dropOut(dropOut).dist(dist).build();
    }
    public MultiLayerNetwork lenetModel() {
        /**
         * Revised LeNet model approach developed by ramgo2; achieves slightly above random.
         * Reference: https://gist.github.com/ramgo2/833f12e92359a2da9e5c2fb6333351c5
         **/
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .regularization(false).l2(0.005) // tried 0.0001, 0.0005
            .activation(Activation.RELU)
            .learningRate(0.0001) // tried 0.00001, 0.00005, 0.000001
            .weightInit(WeightInit.XAVIER)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.RMSPROP).momentum(0.9)
            .list()
            .layer(0, convInit("cnn1", channels, 50, new int[]{5, 5}, new int[]{1, 1}, new int[]{0, 0}, 0))
            .layer(1, maxPool("maxpool1", new int[]{2,2}))
            .layer(2, conv5x5("cnn2", 100, new int[]{5, 5}, new int[]{1, 1}, 0))
            .layer(3, maxPool("maxpool2", new int[]{2,2}))
            .layer(4, new DenseLayer.Builder().nOut(500).build())
            .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(numLabels)
                .activation(Activation.SOFTMAX)
                .build())
            .backprop(true).pretrain(false)
            .setInputType(InputType.convolutional(height, width, channels))
            .build();
        return new MultiLayerNetwork(conf);
    }
    public MultiLayerNetwork alexnetModel() {
        /**
         * AlexNet model interpretation based on the original paper, "ImageNet Classification
         * with Deep Convolutional Neural Networks", and the referenced imagenetExample code.
         * http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf
         **/
        double nonZeroBias = 1;
        double dropOut = 0.5;
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .weightInit(WeightInit.DISTRIBUTION)
            .dist(new NormalDistribution(0.0, 0.01))
            .activation(Activation.RELU)
            .updater(Updater.NESTEROVS)
            .iterations(iterations)
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) // normalize to prevent vanishing or exploding gradients
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(1e-2)
            .biasLearningRate(1e-2 * 2)
            .learningRateDecayPolicy(LearningRatePolicy.Step)
            .lrPolicyDecayRate(0.1)
            .lrPolicySteps(100000)
            .regularization(true)
            .l2(5 * 1e-4)
            .momentum(0.9)
            .miniBatch(false)
            .list()
            .layer(0, convInit("cnn1", channels, 96, new int[]{11, 11}, new int[]{4, 4}, new int[]{3, 3}, 0))
            .layer(1, new LocalResponseNormalization.Builder().name("lrn1").build())
            .layer(2, maxPool("maxpool1", new int[]{3,3}))
            .layer(3, conv5x5("cnn2", 256, new int[]{1,1}, new int[]{2,2}, nonZeroBias))
            .layer(4, new LocalResponseNormalization.Builder().name("lrn2").build())
            .layer(5, maxPool("maxpool2", new int[]{3,3}))
            .layer(6, conv3x3("cnn3", 384, 0))
            .layer(7, conv3x3("cnn4", 384, nonZeroBias))
            .layer(8, conv3x3("cnn5", 256, nonZeroBias))
            .layer(9, maxPool("maxpool3", new int[]{3,3}))
            .layer(10, fullyConnected("ffn1", 4096, nonZeroBias, dropOut, new GaussianDistribution(0, 0.005)))
            .layer(11, fullyConnected("ffn2", 4096, nonZeroBias, dropOut, new GaussianDistribution(0, 0.005)))
            .layer(12, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .name("output")
                .nOut(numLabels)
                .activation(Activation.SOFTMAX)
                .build())
            .backprop(true)
            .pretrain(false)
            .setInputType(InputType.convolutional(height, width, channels))
            .build();
        return new MultiLayerNetwork(conf);
    }
    public static MultiLayerNetwork customModel() {
        /**
         * Use this method to build your own custom model.
         **/
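        // A minimal, hedged sketch of a possible custom model (illustrative only, not a
        // tuned architecture; this method is static, so it uses the static fields and the
        // layer builders directly rather than the instance helper methods above):
        //
        //   MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        //       .seed(seed)
        //       .iterations(iterations)
        //       .activation(Activation.RELU)
        //       .weightInit(WeightInit.XAVIER)
        //       .list()
        //       .layer(0, new ConvolutionLayer.Builder(new int[]{5, 5}).nIn(channels).nOut(32).build())
        //       .layer(1, new SubsamplingLayer.Builder(new int[]{2, 2}, new int[]{2, 2}).build())
        //       .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        //           .nOut(numLabels).activation(Activation.SOFTMAX).build())
        //       .backprop(true).pretrain(false)
        //       .setInputType(InputType.convolutional(height, width, channels))
        //       .build();
        //   return new MultiLayerNetwork(conf);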
        return null;
    }
}