Skip to content

Instantly share code, notes, and snippets.

@BartKeulen
Last active April 4, 2017 22:35
Show Gist options
  • Save BartKeulen/00ef00291338f1ad662089f62439a374 to your computer and use it in GitHub Desktop.
Save BartKeulen/00ef00291338f1ad662089f62439a374 to your computer and use it in GitHub Desktop.
Deeplearning4j version: 0.8.0; OS: Ubuntu 16.04.1 LTS; Java version: 1.8.0_121
/**
 * Stand-alone script that builds a small three-layer dense network, runs one
 * forward pass on a random input row, and logs per-layer z values obtained in
 * two different ways so they can be compared by eye.
 *
 * <p>The first list comes straight from {@code MultiLayerNetwork.computeZ} and is
 * logged under the label "wrong z values"; the second list is assembled layer by
 * layer with {@code zFromPrevLayer} / {@code activationFromPrevLayer} and is
 * logged under the label "correct z values". The labels are the author's; this
 * class does not itself verify either result.
 */
public class GradientTest {

    private static final Logger logger = LoggerFactory.getLogger(GradientTest.class);

    /**
     * Entry point: builds the network, logs its parameters, then logs the two
     * z-value lists for a single random input.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        final int nIn = 2;
        final int nOut = 1;

        // Seed the global ND4J generator before anything random is created.
        Nd4j.getRandom().setSeed(12345);

        MultiLayerNetwork model = new MultiLayerNetwork(buildConfiguration(nIn, nOut));
        model.init();

        logger.info("Number parameters: " + model.numParams() + ", Number layers: " + model.getnLayers());
        logLayerParams(model, 0, "Number params layer 1: ");
        logLayerParams(model, 1, "Number params layer 2: ");
        logLayerParams(model, 2, "Number params layer 3: ");

        // One random sample, pushed through the whole network once.
        final int minibatch = 1;
        INDArray input = Nd4j.rand(minibatch, nIn);
        INDArray output = model.output(input);
        logger.info("Input: " + input.toString());
        logger.info("Output: " + output.toString());

        // Variant 1: ask the network for all z values in a single call.
        List<INDArray> reportedZ = model.computeZ(input.getRow(0), false);
        logger.info("wrong z values: " + reportedZ.toString());
        logger.info("pre-output: " + model.preOutput(input.getRow(0)));

        // Variant 2: walk the layers by hand. The list starts with the raw input,
        // then gets one z entry per layer; the cursor carries each layer's
        // activation forward as the next layer's input.
        List<INDArray> manualZ = new ArrayList<>();
        manualZ.add(input);
        INDArray layerInput = input;
        int layer = 0;
        while (layer < model.getnLayers()) {
            manualZ.add(model.zFromPrevLayer(layer, layerInput, false));
            layerInput = model.activationFromPrevLayer(layer, layerInput, false);
            layer++;
        }
        logger.info("correct z values: " + manualZ.toString());
    }

    /**
     * Assembles the network configuration: two hidden dense layers of width 3
     * using the builder-level RELU activation, followed by a dense layer with
     * TANH activation producing {@code nOut} values.
     *
     * @param nIn  number of inputs to the first layer
     * @param nOut number of outputs of the last layer
     * @return the built configuration
     */
    private static MultiLayerConfiguration buildConfiguration(int nIn, int nOut) {
        final int hidden = 3;
        return new NeuralNetConfiguration.Builder()
                .seed(12345)
                .activation(Activation.RELU)
                .weightInit(WeightInit.XAVIER)
                .updater(Updater.SGD)
                .learningRate(0.1)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(hidden).build())
                .layer(1, new DenseLayer.Builder().nIn(hidden).nOut(hidden).build())
                .layer(2, new DenseLayer.Builder().activation(Activation.TANH).nIn(hidden).nOut(nOut).build())
                .backprop(true).pretrain(false)
                .build();
    }

    /**
     * Logs the parameter count and parameter values of one layer.
     *
     * @param model  network whose layer is inspected
     * @param index  zero-based layer index
     * @param prefix text placed at the start of the log line
     */
    private static void logLayerParams(MultiLayerNetwork model, int index, String prefix) {
        logger.info(prefix + model.getLayer(index).numParams() + ", params: " + model.getLayer(index).params().toString());
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment