Created May 24, 2018 23:10
Save MaxLeiter/a939deda099d33e175b9bd8b065336ef to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package examples; | |
import org.datavec.api.io.filters.BalancedPathFilter; | |
import org.datavec.api.io.labels.ParentPathLabelGenerator; | |
import org.datavec.api.split.FileSplit; | |
import org.datavec.api.split.InputSplit; | |
import org.datavec.api.util.ClassPathResource; | |
import org.datavec.image.loader.BaseImageLoader; | |
import org.datavec.image.recordreader.ImageRecordReader; | |
import org.datavec.image.transform.*; | |
import org.deeplearning4j.api.storage.StatsStorage; | |
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; | |
import org.deeplearning4j.datasets.iterator.MultipleEpochsIterator; | |
import org.deeplearning4j.eval.Evaluation; | |
import org.deeplearning4j.nn.conf.*; | |
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution; | |
import org.deeplearning4j.nn.conf.distribution.NormalDistribution; | |
import org.deeplearning4j.nn.conf.inputs.InputType; | |
import org.deeplearning4j.nn.conf.layers.*; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
import org.deeplearning4j.ui.api.UIServer; | |
import org.deeplearning4j.ui.stats.StatsListener; | |
import org.deeplearning4j.ui.storage.InMemoryStatsStorage; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
import java.io.File; | |
import java.util.Arrays; | |
import java.util.List; | |
import java.util.Random; | |
public class Glaucoma { | |
private static final String[] allowedExtensions = BaseImageLoader.ALLOWED_FORMATS; | |
private static final long seed = 12345; | |
private static final int epochs = 10; // epochés | |
private static final Random randNumGen = new Random(seed); | |
private static final int height = 300; // image info | |
private static final int width = 300; | |
private static final int channels = 1; // RGB = 3, but we're converting to Grayscale | |
private static final int numLabels = 2; // Yes or No glaucoma | |
private static final int batchSize = 200; // Images to train per epoch | |
private static final int labelIndex = 1; | |
public static void main(String[] args) throws Exception { | |
File parentDir = new ClassPathResource("/glaucoma_test/").getFile(); | |
FileSplit filesInDir = new FileSplit(parentDir, allowedExtensions, randNumGen); | |
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); | |
BalancedPathFilter pathFilter = new BalancedPathFilter(randNumGen, allowedExtensions, labelMaker); | |
InputSplit[] filesInDirSplit = filesInDir.sample(pathFilter, 50, 50); // 80, 20 | |
InputSplit trainData = filesInDirSplit[0]; | |
InputSplit testData = filesInDirSplit[1]; | |
ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker); | |
UIServer uiServer = UIServer.getInstance(); | |
StatsStorage storage = new InMemoryStatsStorage(); | |
uiServer.attach(storage); | |
DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numLabels); | |
ImageTransform flipTransform1 = new FlipImageTransform(); | |
ImageTransform flipTransform2 = new FlipImageTransform(new Random(seed)); | |
List<ImageTransform> transforms = Arrays.asList(new ImageTransform[]{flipTransform1, flipTransform2}); | |
MultiLayerConfiguration conf = ourNet(); | |
MultiLayerNetwork net = new MultiLayerNetwork(conf); | |
recordReader.initialize(trainData); | |
net.init(); | |
net.setListeners(new StatsListener(storage), new ScoreIterationListener(10)); | |
MultipleEpochsIterator trainIter = new MultipleEpochsIterator(epochs, dataIter); // we will train with multiple epochs | |
net.fit(trainIter); | |
for (ImageTransform t : transforms) { // re-train on every transform | |
System.out.print("\nTraining on transformation: " + t.getClass().toString() + "\n\n"); | |
recordReader.initialize(trainData, t); | |
dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numLabels); | |
trainIter = new MultipleEpochsIterator(epochs, dataIter); | |
net.fit(trainIter); | |
} | |
recordReader.initialize(testData); | |
dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numLabels); | |
Evaluation eval = net.evaluate(dataIter); | |
System.out.println(eval.stats(true)); | |
} | |
private static MultiLayerConfiguration ourNet() { | |
double nonZeroBias = 1; | |
int[] inputShape = new int[]{channels, height, width}; | |
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | |
.seed(seed) | |
.weightInit(WeightInit.DISTRIBUTION) | |
.dist(new NormalDistribution(0.0, 0.01)) | |
.activation(Activation.RELU) | |
.updater(Updater.NESTEROVS) | |
.convolutionMode(ConvolutionMode.Same) | |
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) // normalize to prevent vanishing or exploding gradients | |
.l2(5 * 1e-4) | |
.miniBatch(false) | |
.list() | |
.layer(0, new ConvolutionLayer.Builder(new int[]{11, 11}, new int[]{4, 4}) | |
.name("cnn1") | |
.cudnnAlgoMode(ConvolutionLayer.AlgoMode.PREFER_FASTEST) | |
.convolutionMode(ConvolutionMode.Truncate) | |
.nIn(inputShape[0]) | |
.nOut(96) | |
.build()) | |
.layer(1, new LocalResponseNormalization.Builder().build()) | |
.layer(2, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) | |
.kernelSize(3, 3) | |
.stride(2, 2) | |
.padding(1, 1) | |
.name("maxpool1") | |
.build()) | |
.layer(3, new ConvolutionLayer.Builder(new int[]{5, 5}, new int[]{1, 1}, new int[]{2, 2}) | |
.name("cnn2") | |
.cudnnAlgoMode(ConvolutionLayer.AlgoMode.PREFER_FASTEST) | |
.convolutionMode(ConvolutionMode.Truncate) | |
.nOut(256) | |
.biasInit(nonZeroBias) | |
.build()) | |
.layer(4, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[]{3, 3}, new int[]{2, 2}) | |
.convolutionMode(ConvolutionMode.Truncate) | |
.name("maxpool2") | |
.build()) | |
.layer(5, new LocalResponseNormalization.Builder().build()) | |
.layer(6, new ConvolutionLayer.Builder() | |
.kernelSize(3, 3) | |
.stride(1, 1) | |
.convolutionMode(ConvolutionMode.Same) | |
.name("cnn3") | |
.cudnnAlgoMode(ConvolutionLayer.AlgoMode.PREFER_FASTEST) | |
.nOut(384) | |
.build()) | |
.layer(7, new ConvolutionLayer.Builder(new int[]{3, 3}, new int[]{1, 1}) | |
.name("cnn4") | |
.cudnnAlgoMode(ConvolutionLayer.AlgoMode.PREFER_FASTEST) | |
.nOut(384) | |
.biasInit(nonZeroBias) | |
.build()) | |
.layer(8, new ConvolutionLayer.Builder(new int[]{3, 3}, new int[]{1, 1}) | |
.name("cnn5") | |
.cudnnAlgoMode(ConvolutionLayer.AlgoMode.PREFER_FASTEST) | |
.nOut(256) | |
.biasInit(nonZeroBias) | |
.build()) | |
.layer(9, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[]{3, 3}, new int[]{2, 2}) | |
.name("maxpool3") | |
.convolutionMode(ConvolutionMode.Truncate) | |
.build()) | |
/* .layer(10, new DenseLayer.Builder() | |
.name("ffn1") | |
.nIn(256 * 6 * 6) | |
.nOut(4096) | |
.dist(new GaussianDistribution(0, 0.005)) | |
.biasInit(nonZeroBias) | |
.build())*/ | |
.layer(10, new DenseLayer.Builder() | |
.name("ffn2") | |
.nOut(4096) | |
.weightInit(WeightInit.DISTRIBUTION).dist(new GaussianDistribution(0, 0.005)) | |
.biasInit(nonZeroBias) | |
.dropOut(0.5) | |
.build()) | |
.layer(11, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) | |
.name("output") | |
.nOut(numLabels) | |
.activation(Activation.SOFTMAX) | |
.weightInit(WeightInit.DISTRIBUTION).dist(new GaussianDistribution(0, 0.005)) | |
.biasInit(0.1) | |
.build()) | |
.backprop(true) | |
.pretrain(false) | |
.setInputType(InputType.convolutional(inputShape[2], inputShape[1], inputShape[0])) | |
.build(); | |
return conf; | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.