Skip to content

Instantly share code, notes, and snippets.

@agibsonccc
Created August 27, 2014 15:52
Show Gist options
  • Save agibsonccc/736ba27163e5abf6d3eb to your computer and use it in GitHub Desktop.
Save agibsonccc/736ba27163e5abf6d3eb to your computer and use it in GitHub Desktop.
package org.deeplearning4j.nn.conf;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.nn.WeightInit;
import org.deeplearning4j.nn.api.NeuralNetwork;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
/**
* A Serializable configuration
*/
public class NeuralNetConfiguration implements Serializable {
    // NOTE(review): the only constructor assigns every field below, so these field
    // initializers are never observable; they presumably document intended defaults.

    private float sparsity;
    private boolean useAdaGrad = true;
    private float lr = 1e-1f;
    /* momentum for learning */
    protected float momentum = 0.5f;
    /* L2 Regularization constant */
    protected float l2 = 0.1f;
    protected boolean useRegularization = false;
    // momentum after n iterations (key: iteration count, value: momentum)
    protected Map<Integer, Float> momentumAfter = new HashMap<>();
    // reset adagrad historical gradient after n iterations
    protected int resetAdaGradIterations = -1;
    protected float dropOut = 0;
    // use only when binary hidden layers are active
    protected boolean applySparsity = false;
    // weight initialization scheme (see WeightInit)
    protected WeightInit weightInit;
    protected NeuralNetwork.OptimizationAlgorithm optimizationAlgo;
    protected int renderWeightsEveryNumEpochs = -1;
    // whether to concat hidden bias or add it
    protected boolean concatBiases = false;
    // whether to constrain the gradient to unit norm or not
    protected boolean constrainGradientToUnitNorm = false;
    /* RNG for sampling. */
    // NOTE(review): Java serialization of this object fails with NotSerializableException
    // if the supplied rng implementation is not itself Serializable.
    protected RandomGenerator rng;
    protected long seed = 123;

    /**
     * Creates a configuration with every setting given explicitly.
     *
     * <p>Field meanings beyond those commented on the fields themselves are inferred from
     * their names (e.g. {@code lr}, {@code dropOut}, {@code seed}); confirm against the
     * code that consumes this configuration.
     *
     * @param sparsity sparsity target
     * @param useAdaGrad whether AdaGrad is used
     * @param lr learning rate
     * @param momentum momentum for learning
     * @param l2 L2 regularization constant
     * @param useRegularization whether regularization is enabled
     * @param momentumAfter iteration count to momentum overrides; copied defensively, and
     *     {@code null} is treated as an empty map
     * @param resetAdaGradIterations reset the AdaGrad historical gradient after this many iterations
     * @param dropOut dropout value
     * @param applySparsity whether sparsity is applied (only when binary hidden layers are active)
     * @param weightInit weight initialization scheme
     * @param optimizationAlgo optimization algorithm
     * @param renderWeightsEveryNumEpochs weight-rendering interval in epochs
     * @param concatBiases whether to concatenate the hidden bias instead of adding it
     * @param constrainGradientToUnitNorm whether to constrain the gradient to unit norm
     * @param rng random generator used for sampling
     * @param seed random seed
     */
    public NeuralNetConfiguration(float sparsity, boolean useAdaGrad, float lr, float momentum,
                                  float l2, boolean useRegularization,
                                  Map<Integer, Float> momentumAfter, int resetAdaGradIterations,
                                  float dropOut, boolean applySparsity, WeightInit weightInit,
                                  NeuralNetwork.OptimizationAlgorithm optimizationAlgo,
                                  int renderWeightsEveryNumEpochs, boolean concatBiases,
                                  boolean constrainGradientToUnitNorm, RandomGenerator rng,
                                  long seed) {
        this.sparsity = sparsity;
        this.useAdaGrad = useAdaGrad;
        this.lr = lr;
        this.momentum = momentum;
        this.l2 = l2;
        this.useRegularization = useRegularization;
        // Copy so later mutation of the caller's map cannot silently change this configuration,
        // and never store null (the field's declared default is an empty map).
        this.momentumAfter = momentumAfter == null
                ? new HashMap<Integer, Float>()
                : new HashMap<>(momentumAfter);
        this.resetAdaGradIterations = resetAdaGradIterations;
        this.dropOut = dropOut;
        this.applySparsity = applySparsity;
        this.weightInit = weightInit;
        this.optimizationAlgo = optimizationAlgo;
        this.renderWeightsEveryNumEpochs = renderWeightsEveryNumEpochs;
        this.concatBiases = concatBiases;
        this.constrainGradientToUnitNorm = constrainGradientToUnitNorm;
        this.rng = rng;
        this.seed = seed;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment