Created
August 20, 2018 08:30
-
-
Save gfrison/e6a7e3a1a1626f5daf2865ca87880366 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.Arrays;

import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;

import static java.util.stream.IntStream.range;
import static org.slf4j.LoggerFactory.getLogger;
/**
 * Trains a small network to learn the sum of its inputs, driving training with an
 * externally computed error signal instead of a built-in loss layer. Adapted from:
 * https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/misc/externalerrors/MultiLayerNetworkExternalErrors.java
 */
public class Sum { | |
private static final Logger log = getLogger(Sum.class); | |
public static void main(String[] args) { | |
//Create the model | |
int nIn = 2; | |
int nOut = 1; | |
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() | |
.activation(Activation.RELU) | |
.weightInit(WeightInit.XAVIER) | |
.graphBuilder() | |
.addInputs("input") | |
.addLayer("0", new DenseLayer.Builder() | |
.activation(Activation.TANH) | |
.nIn(nIn).nOut(10).build(), "input") | |
.addLayer("output", new DenseLayer.Builder() | |
.activation(Activation.IDENTITY) | |
.nIn(10).nOut(nOut).build(), "0") | |
.setOutputs("output") | |
.backprop(true).pretrain(false) | |
.build(); | |
ComputationGraph model = new ComputationGraph(conf); | |
model.setListeners(new ScoreIterationListener(1)); | |
model.init(); | |
//Calculate gradient with respect to an external error | |
int minibatch = 1000; | |
range(0, 1000).forEach(epoch -> { | |
INDArray input = Nd4j.rand(minibatch, nIn); | |
//Do forward pass, but don't clear the input activations in each layers - we need those set so we can calculate | |
// gradients based on them | |
INDArray out = model.feedForward(new INDArray[]{input}, true, false).get("output"); | |
double[] errs = new double[minibatch]; | |
for (int i = 0; i < minibatch; i++) { | |
int ii = i; | |
double sum = range(0, nIn).mapToDouble(t -> input.getDouble(ii, t)).sum(); | |
final double predictedSum = out.getDouble(i); | |
double err = Math.abs(sum - predictedSum); | |
// System.out.printf("predicted:%.2f, actual:%.2f, err:%.2f \n", predictedSum, sum, err); | |
errs[i] = err; | |
} | |
System.out.printf("avg err: %.2f\n", Arrays.stream(errs).average().getAsDouble()); | |
INDArray externalError = Nd4j.create(errs, new int[]{minibatch, 1}); | |
Gradient gradient = model.backpropGradient(externalError); //Calculate backprop gradient based on error array | |
//Update the gradient: apply learning rate, momentum, etc | |
//This modifies the Gradient object in-place | |
int iteration = 0; | |
model.getUpdater().update(gradient, iteration, epoch, minibatch, LayerWorkspaceMgr.noWorkspaces()); | |
//Get a row vector gradient array, and apply it to the parameters to update the model | |
INDArray updateVector = gradient.gradient(); | |
model.params().subi(updateVector); | |
}); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment