Skip to content

Instantly share code, notes, and snippets.

@omerwase
Created March 24, 2017 19:28
Show Gist options
  • Save omerwase/5dbd2291b346cbedbdc8a1427b61160e to your computer and use it in GitHub Desktop.
Save omerwase/5dbd2291b346cbedbdc8a1427b61160e to your computer and use it in GitHub Desktop.
CNN on Spark
package org.deeplearning4j.cnnspark;
import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import org.apache.commons.io.FilenameUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.MultipleEpochsIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.deeplearning4j.util.NetSaverLoaderUtils;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
/**
 * SDC image classification on Spark, based on the DL4J Animal Classification example.
 *
 * <p>Images are read from {@code ~/sdcdata/} and labelled by the name of their parent directory
 * ({@link ParentPathLabelGenerator}). The files are split into train/test sets, pixel values are
 * scaled to [0, 1], the mini-batches are collected on the driver and distributed as RDDs, and one
 * of the CNNs defined below is trained with parameter averaging. {@link #numLabels} classes are
 * expected.
 *
 * References:
 * - U.S. Fish and Wildlife Service (animal sample dataset of the original example): http://digitalmedia.fws.gov/cdm/
 * - Tiny ImageNet Classification with CNN: http://cs231n.stanford.edu/reports/leonyao_final.pdf
 *
 * CHALLENGE: Current setup gets low score results. Can you improve the scores? Some approaches:
 * - Add additional images to the dataset
 * - Apply more transforms to dataset
 * - Increase epochs
 * - Try different model configurations
 * - Tune by adjusting learning rate, updaters, activation & loss functions, regularization, ...
 */
public class SDCCNNSpark {
    protected static final Logger log = LoggerFactory.getLogger(SDCCNNSpark.class);

    @Parameter(names = "-useSparkLocal", description = "Use spark local (helper for testing/running without spark submit)", arity = 1)
    private boolean useSparkLocal = true;

    @Parameter(names = "-batchSizePerWorker", description = "Number of examples to fit each worker with")
    private int batchSizePerWorker = 32;

    @Parameter(names = "-numEpochs", description = "Number of epochs for training")
    private int numEpochs = 1;

    protected static int height = 100;
    protected static int width = 100;
    protected static int channels = 3;
    protected static int numExamples = 900;
    protected static int numLabels = 3;
    //protected static int batchSize = 30;
    protected static long seed = 42;
    protected static Random rng = new Random(seed);
    protected static int listenerFreq = 1;
    protected static int iterations = 1;
    protected static double splitTrainTest = 0.8;
    //protected static int nCores = 4;
    protected static boolean save = false;
    // "LeNet", "AlexNet", or "custom"/"Custom" (customModel() must be implemented first)
    protected static String modelType = "AlexNet";

    public static void main(String[] args) throws Exception {
        new SDCCNNSpark().run(args);
    }

    /**
     * Parses the command line and builds the selected network. It then loads the data, trains
     * and evaluates the network on Spark.
     *
     * @param args command-line flags (see the {@code @Parameter} fields)
     * @throws ParameterException if the command line is invalid (usage is printed first)
     * @throws Exception if model selection, data loading, training or evaluation fails
     */
    public void run(String[] args) throws Exception {
        JCommander jcmdr = new JCommander(this);
        try {
            jcmdr.parse(args);
        } catch (ParameterException e) {
            // User provided invalid input -> print the usage info
            jcmdr.usage();
            try {
                Thread.sleep(500); // give the usage text a moment to reach the console
            } catch (InterruptedException ie) {
                Thread.currentThread().interrupt(); // preserve the interrupt status
            }
            throw e;
        }

        // Select the model before starting Spark so an unknown/unimplemented model fails fast.
        log.info("Build model ...");
        MultiLayerNetwork network = buildNetwork();

        // The builder argument is the number of examples in each DataSet object of the RDD, which
        // matches the mini-batches produced by loadBatches().
        TrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(batchSizePerWorker)
                .averagingFrequency(5)
                .workerPrefetchNumBatches(2) // async prefetching: 2 mini-batches per worker
                .batchSizePerWorker(batchSizePerWorker)
                .build();

        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster(useSparkLocal ? "local[*]" : "spark://OmerMBP.local:7077");
        sparkConf.setAppName("SDC CNN on Spark");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        try {
            trainAndEvaluate(sc, tm, network);
        } finally {
            // Delete the temp training files and release Spark resources even if training failed
            tm.deleteTempFiles(sc);
            sc.stop();
        }
    }

    /**
     * Loads the train/test splits and fits {@code network} on Spark for {@link #numEpochs} epochs.
     * It then logs the distributed evaluation on the test split.
     */
    private void trainAndEvaluate(JavaSparkContext sc, TrainingMaster tm, MultiLayerNetwork network) throws Exception {
        log.info("Load data ...");
        /*
         * Data Setup -> organize and limit data file paths:
         * - mainPath = path to image files
         * - fileSplit = define basic dataset split with limits on format
         * - pathFilter = limit the total number of files and balance them across labels
         */
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        File mainPath = new File(System.getProperty("user.home"), "sdcdata/");
        FileSplit fileSplit = new FileSplit(mainPath, NativeImageLoader.ALLOWED_FORMATS, rng);
        // The last argument is the maximum number of files per label. Passing the batch size here
        // silently capped the data at batchSizePerWorker * numLabels images instead of numExamples.
        BalancedPathFilter pathFilter =
                new BalancedPathFilter(rng, labelMaker, numExamples, numLabels, numExamples / numLabels);

        // Data Setup -> train/test split, weighted by splitTrainTest
        InputSplit[] inputSplit =
                fileSplit.sample(pathFilter, numExamples * splitTrainTest, numExamples * (1 - splitTrainTest));
        InputSplit trainData = inputSplit[0];
        InputSplit testData = inputSplit[1];

        // Data Setup -> scale pixel values into [0, 1]
        DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
        ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);
        // Fit the scaler on the training split only; the test split is transformed with that fit.
        JavaRDD<DataSet> trainDataRDD = sc.parallelize(loadBatches(recordReader, trainData, scaler, true));
        JavaRDD<DataSet> testDataRDD = sc.parallelize(loadBatches(recordReader, testData, scaler, false));

        // Create the Spark network
        SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, network, tm);

        log.info("Train model ...");
        for (int i = 0; i < numEpochs; i++) {
            sparkNet.fit(trainDataRDD);
            log.info("Completed Epoch {}", i);
        }

        // Perform evaluation (distributed)
        log.info("Evaluate model ...");
        Evaluation evaluation = sparkNet.evaluate(testDataRDD);
        log.info("***** Evaluation *****");
        log.info(evaluation.stats());
        log.info("***** Example Complete *****");

        if (save) { // TODO: fix saving CNN
            log.info("Save model....");
            //String basePath = FilenameUtils.concat(System.getProperty("user.dir"), "src/main/resources/");
            //NetSaverLoaderUtils.saveNetworkAndParameters(network, basePath);
            //NetSaverLoaderUtils.saveUpdators(network, basePath);
        }
        log.info("****************Example finished********************");
    }

    /**
     * Reads every mini-batch of {@code split} into driver memory.
     *
     * @param recordReader reader that is (re)initialized with {@code split}
     * @param split        the file split to read
     * @param scaler       pre-processor attached to the iterator
     * @param fitScaler    whether to fit {@code scaler} on this split (true only for training data,
     *                     so test data never influences normalization)
     * @return the mini-batches, each with at most {@link #batchSizePerWorker} examples
     */
    private List<DataSet> loadBatches(ImageRecordReader recordReader, InputSplit split,
                                      DataNormalization scaler, boolean fitScaler) throws Exception {
        recordReader.initialize(split);
        // label is column 1 of each ImageRecordReader record (image, label)
        DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSizePerWorker, 1, numLabels);
        if (fitScaler) {
            scaler.fit(dataIter);
        }
        dataIter.setPreProcessor(scaler);
        List<DataSet> batches = new ArrayList<>();
        while (dataIter.hasNext()) {
            batches.add(dataIter.next());
        }
        return batches;
    }

    /**
     * Builds the network selected by {@link #modelType}.
     *
     * @throws InvalidInputTypeException if {@code modelType} is not a known model name
     * @throws IllegalStateException     if the selected model builder returned {@code null}
     */
    private MultiLayerNetwork buildNetwork() {
        MultiLayerNetwork network;
        switch (modelType) {
            case "LeNet":
                network = lenetModel();
                break;
            case "AlexNet":
                network = alexnetModel();
                break;
            case "custom":
            case "Custom": // the modelType field comment advertises this spelling
                network = customModel();
                break;
            default:
                throw new InvalidInputTypeException("Incorrect model provided: " + modelType);
        }
        if (network == null) {
            throw new IllegalStateException("Model type '" + modelType + "' is not implemented (builder returned null)");
        }
        return network;
    }

    /** Convolution layer with explicit input depth, kernel, stride, padding and bias init. */
    private ConvolutionLayer convInit(String name, int in, int out, int[] kernel, int[] stride, int[] pad, double bias) {
        return new ConvolutionLayer.Builder(kernel, stride, pad).name(name).nIn(in).nOut(out).biasInit(bias).build();
    }

    /** 3x3 convolution with stride 1 and padding 1; input depth is inferred from the input type. */
    private ConvolutionLayer conv3x3(String name, int out, double bias) {
        return new ConvolutionLayer.Builder(new int[]{3, 3}, new int[]{1, 1}, new int[]{1, 1}).name(name).nOut(out).biasInit(bias).build();
    }

    /** 5x5 convolution with the given stride and padding; input depth is inferred. */
    private ConvolutionLayer conv5x5(String name, int out, int[] stride, int[] pad, double bias) {
        return new ConvolutionLayer.Builder(new int[]{5, 5}, stride, pad).name(name).nOut(out).biasInit(bias).build();
    }

    /** Subsampling layer with the given kernel and a fixed 2x2 stride. */
    private SubsamplingLayer maxPool(String name, int[] kernel) {
        return new SubsamplingLayer.Builder(kernel, new int[]{2, 2}).name(name).build();
    }

    /** Dense layer with the given bias init, dropout and weight distribution. */
    private DenseLayer fullyConnected(String name, int out, double bias, double dropOut, Distribution dist) {
        return new DenseLayer.Builder().name(name).nOut(out).biasInit(bias).dropOut(dropOut).dist(dist).build();
    }

    public MultiLayerNetwork lenetModel() {
        /*
         * Revised LeNet model approach developed by ramgo2; achieves slightly above random.
         * Reference: https://gist.github.com/ramgo2/833f12e92359a2da9e5c2fb6333351c5
         */
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations)
                // NOTE(review): with regularization(false) this l2 value is presumably not applied — confirm.
                .regularization(false).l2(0.005) // tried 0.0001, 0.0005
                .activation(Activation.RELU)
                .learningRate(0.0001) // tried 0.00001, 0.00005, 0.000001
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                // NOTE(review): momentum is presumably unused by the RMSPROP updater — confirm.
                .updater(Updater.RMSPROP).momentum(0.9)
                .list()
                .layer(0, convInit("cnn1", channels, 50, new int[]{5, 5}, new int[]{1, 1}, new int[]{0, 0}, 0))
                .layer(1, maxPool("maxpool1", new int[]{2, 2}))
                // conv5x5 takes (stride, pad): this layer uses stride 5x5 and padding 1x1
                .layer(2, conv5x5("cnn2", 100, new int[]{5, 5}, new int[]{1, 1}, 0))
                .layer(3, maxPool("maxool2", new int[]{2, 2}))
                .layer(4, new DenseLayer.Builder().nOut(500).build())
                .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(numLabels)
                        .activation(Activation.SOFTMAX)
                        .build())
                .backprop(true).pretrain(false)
                .setInputType(InputType.convolutional(height, width, channels))
                .build();
        return new MultiLayerNetwork(conf);
    }

    public MultiLayerNetwork alexnetModel() {
        /*
         * AlexNet model interpretation based on the original paper ImageNet Classification with Deep
         * Convolutional Neural Networks and the imagenetExample code referenced.
         * http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf
         */
        double nonZeroBias = 1;
        double dropOut = 0.5;
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .weightInit(WeightInit.DISTRIBUTION)
                .dist(new NormalDistribution(0.0, 0.01))
                .activation(Activation.RELU)
                .updater(Updater.NESTEROVS)
                .iterations(iterations)
                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) // normalize to prevent vanishing or exploding gradients
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(1e-2)
                .biasLearningRate(1e-2 * 2)
                .learningRateDecayPolicy(LearningRatePolicy.Step)
                .lrPolicyDecayRate(0.1)
                .lrPolicySteps(100000)
                .regularization(true)
                .l2(5 * 1e-4)
                .momentum(0.9)
                .miniBatch(false)
                .list()
                .layer(0, convInit("cnn1", channels, 96, new int[]{11, 11}, new int[]{4, 4}, new int[]{3, 3}, 0))
                .layer(1, new LocalResponseNormalization.Builder().name("lrn1").build())
                .layer(2, maxPool("maxpool1", new int[]{3, 3}))
                .layer(3, conv5x5("cnn2", 256, new int[]{1, 1}, new int[]{2, 2}, nonZeroBias))
                .layer(4, new LocalResponseNormalization.Builder().name("lrn2").build())
                .layer(5, maxPool("maxpool2", new int[]{3, 3}))
                .layer(6, conv3x3("cnn3", 384, 0))
                .layer(7, conv3x3("cnn4", 384, nonZeroBias))
                .layer(8, conv3x3("cnn5", 256, nonZeroBias))
                .layer(9, maxPool("maxpool3", new int[]{3, 3}))
                .layer(10, fullyConnected("ffn1", 4096, nonZeroBias, dropOut, new GaussianDistribution(0, 0.005)))
                .layer(11, fullyConnected("ffn2", 4096, nonZeroBias, dropOut, new GaussianDistribution(0, 0.005)))
                .layer(12, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .name("output")
                        .nOut(numLabels)
                        .activation(Activation.SOFTMAX)
                        .build())
                .backprop(true)
                .pretrain(false)
                .setInputType(InputType.convolutional(height, width, channels))
                .build();
        return new MultiLayerNetwork(conf);
    }

    /**
     * Placeholder for a user-defined model. Returns {@code null} until implemented; selecting it
     * before then makes {@code run} fail with an {@link IllegalStateException}.
     */
    public static MultiLayerNetwork customModel() {
        return null;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment