Created
September 28, 2018 12:40
-
-
Save C4N4D4M4N/ff46cebf1a0d7024f3a9ad75aae03630 to your computer and use it in GitHub Desktop.
Testing Batch Normalization, Reduces Accuracy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public static ComputationGraphConfiguration getBatchNormModel() { | |
int[][][] i; | |
if (LENGTHWISE) { | |
i = new int[][][] {{{0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0}}}; | |
} else { | |
i = new int[][][] {{{0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}}}; | |
} | |
int channels = i.length; | |
int length = i[0].length; | |
int width = i[0][0].length; | |
InputType graph = InputType.convolutional(channels, length, width); | |
int itemOut = 16; | |
int marketOut = 16; | |
int playersOut = 8; | |
int outputCNN = itemOut * 4 + itemOut * 4 + marketOut * 2 + marketOut * 2 + playersOut * 2; | |
HashMap<String, InputPreProcessor> preProcessors = new HashMap<>(); | |
preProcessors.put("CNN -> ANN", new CnnToFeedForwardPreProcessor(length, width, outputCNN)); | |
outputCNN *= length; | |
outputCNN *= width; | |
ComputationGraphConfiguration.GraphBuilder config = new NeuralNetConfiguration.Builder() | |
.weightInit(XAVIER) | |
.activation(RELU) | |
.updater(new Adam.Builder() | |
.learningRateSchedule(new MapSchedule(EPOCH, LR_SCHEDULE)) | |
.build()) | |
.convolutionMode(Same) | |
.l2(0.0001) | |
.seed(SEED) | |
.graphBuilder() | |
.addInputs("Item Price", "Item Count", "Market Price", "Market Count", "Player Count", "Neural Data") | |
.setInputTypes(graph, graph, graph, graph, graph, InputType.feedForward(ANN_INDEXES.length)) | |
// Item price CNN layers | |
.addLayer("Item Price CNN 1", new ConvolutionLayer.Builder() | |
.kernelSize(channels, length) | |
.stride(1, 1) | |
.padding(0, 0) | |
.nIn(1) | |
.nOut(itemOut) | |
.build(), "Item Price") | |
.addLayer("Item Price CNN 2", new ConvolutionLayer.Builder() | |
.kernelSize(channels, length) | |
.stride(1, 1) | |
.padding(0, 0) | |
.nIn(itemOut) | |
.nOut(itemOut * 2) | |
.build(), "Item Price CNN 1") | |
.addLayer("Item Price CNN 3", new ConvolutionLayer.Builder() | |
.kernelSize(channels, length) | |
.stride(1, 1) | |
.padding(0, 0) | |
.nIn(itemOut * 2) | |
.nOut(itemOut * 4) | |
.build(), "Item Price CNN 2") | |
// Item count CNN layers | |
.addLayer("Item Count CNN 1", new ConvolutionLayer.Builder() | |
.kernelSize(channels, length) | |
.stride(1, 1) | |
.padding(0, 0) | |
.nIn(1) | |
.nOut(itemOut) | |
.build(), "Item Count") | |
.addLayer("Item Count CNN 2", new ConvolutionLayer.Builder() | |
.kernelSize(channels, length) | |
.stride(1, 1) | |
.padding(0, 0) | |
.nIn(itemOut) | |
.nOut(itemOut * 2) | |
.build(), "Item Count CNN 1") | |
.addLayer("Item Count CNN 3", new ConvolutionLayer.Builder() | |
.kernelSize(channels, length) | |
.stride(1, 1) | |
.padding(0, 0) | |
.nIn(itemOut * 2) | |
.nOut(itemOut * 4) | |
.build(), "Item Count CNN 2") | |
// Market price CNN layers | |
.addLayer("Market Price CNN 1", new ConvolutionLayer.Builder() | |
.kernelSize(channels, length) | |
.stride(1, 1) | |
.padding(0, 0) | |
.nIn(1) | |
.nOut(marketOut) | |
.build(), "Market Price") | |
.addLayer("Market Price CNN 2", new ConvolutionLayer.Builder() | |
.kernelSize(channels, length) | |
.stride(1, 1) | |
.padding(0, 0) | |
.nIn(marketOut) | |
.nOut(marketOut * 2) | |
.build(), "Market Price CNN 1") | |
// Market count CNN layers | |
.addLayer("Market Count CNN 1", new ConvolutionLayer.Builder() | |
.kernelSize(channels, length) | |
.stride(1, 1) | |
.padding(0, 0) | |
.nIn(1) | |
.nOut(marketOut) | |
.build(), "Market Count") | |
.addLayer("Market Count CNN 2", new ConvolutionLayer.Builder() | |
.kernelSize(channels, length) | |
.stride(1, 1) | |
.padding(0, 0) | |
.nIn(marketOut) | |
.nOut(marketOut * 2) | |
.build(), "Market Count CNN 1") | |
// Player count CNN layers | |
.addLayer("Player Count CNN 1", new ConvolutionLayer.Builder() | |
.kernelSize(channels, length) | |
.stride(1, 1) | |
.padding(0, 0) | |
.nIn(1) | |
.nOut(playersOut) | |
.build(), "Player Count") | |
.addLayer("Player Count CNN 2", new ConvolutionLayer.Builder() | |
.kernelSize(channels, length) | |
.stride(1, 1) | |
.padding(0, 0) | |
.nIn(playersOut) | |
.nOut(playersOut * 2) | |
.build(), "Player Count CNN 1") | |
// Merge all CNN graph analysis layers and format them to work with DenseLayers | |
.addVertex("Graph Merge", new MergeVertex(), "Item Price CNN 3", "Item Count CNN 3", "Market Price CNN 2", "Market Count CNN 2", "Player Count CNN 2") | |
.addLayer("CNN -> ANN", new DenseLayer.Builder() | |
.nIn(outputCNN) | |
.nOut(128) | |
.build(), "Graph Merge") | |
.addLayer("CNN -> ANN Norm", new BatchNormalization(), "CNN -> ANN") | |
.addVertex("CNN ANN Merge", new MergeVertex(), "CNN -> ANN Norm", "Neural Data") | |
// ANN layers for graph data plus neural data | |
.addLayer("ANN 1", new DenseLayer.Builder() | |
.nOut(256) | |
.build(), "CNN ANN Merge") | |
.addLayer("ANN 1 Norm", new BatchNormalization(), "ANN 1") | |
.addLayer("ANN 2", new DenseLayer.Builder() | |
.nIn(256) | |
.nOut(128) | |
.build(), "ANN 1 Norm") | |
.addLayer("ANN 2 Norm", new BatchNormalization(), "ANN 2") | |
.addLayer("ANN 3", new DenseLayer.Builder() | |
.nIn(128) | |
.nOut(64) | |
.build(), "ANN 2 Norm") | |
.addLayer("ANN 3 Norm", new BatchNormalization(), "ANN 3") | |
.addLayer("ANN 4", new DenseLayer.Builder() | |
.nIn(64) | |
.nOut(32) | |
.build(), "ANN 3 Norm") | |
.addLayer("ANN 4 Norm", new BatchNormalization(), "ANN 4") | |
.addLayer("ANN 5", new DenseLayer.Builder() | |
.nIn(32) | |
.nOut(16) | |
.build(), "ANN 4 Norm") | |
// Regression output | |
.addLayer("Output", new OutputLayer.Builder() | |
.lossFunction(MSE) // MSE tends to be better as it allows the model to better understand outliers. | |
.activation(IDENTITY) // When this isn't Identity, predictions tend to have a "minimum" value and it converges more slowly and erratically. | |
.nIn(16) | |
.nOut(1) | |
.build(), "ANN 5") | |
.setOutputs("Output"); | |
config.setInputPreProcessors(preProcessors); | |
return config.build(); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment