Skip to content

Instantly share code, notes, and snippets.

@marcusking01
Created October 13, 2016 03:32
Show Gist options
  • Save marcusking01/52bda89b4124ec9c2f953c08ac6ba3eb to your computer and use it in GitHub Desktop.
Save marcusking01/52bda89b4124ec9c2f953c08ac6ba3eb to your computer and use it in GitHub Desktop.
package RandomizedOptimization;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import shared.Trainer;
/**
* A momentum convergence trainer trains a network
* until convergence based on a sliding window of statistical variance on the error, using another trainer
* @author Marcus King
* @version 1.0
*/
public class MomentumConvergenceTrainer implements Trainer {
/** The default threshold */
private static final double THRESHOLD = .05;
/** The maxium number of iterations */
private static final int MAX_ITERATIONS = 5000;
/**
* The trainer
*/
private Trainer trainer;
/**
* The threshold
*/
private double threshold;
/**
* The number of iterations trained
*/
private int iterations;
/**
* The maximum number of iterations to use
*/
private int maxIterations;
/**
* Create a new convergence trainer
* @param trainer the thrainer to use
* @param threshold the error threshold
* @param maxIterations the maximum iterations
*/
public MomentumConvergenceTrainer(Trainer trainer,
double threshold, int maxIterations) {
this.trainer = trainer;
this.threshold = threshold;
this.maxIterations = maxIterations;
}
/**
* Create a new convergence trainer
* @param trainer the trainer to use
*/
public MomentumConvergenceTrainer(Trainer trainer) {
this(trainer, THRESHOLD, MAX_ITERATIONS);
}
/**
* @see Trainer#train()
*/
public double train() {
int window = Math.max(20, (int) (maxIterations*.03));
DescriptiveStatistics dstats = new DescriptiveStatistics(window);
double error = Double.MAX_VALUE;
do {
iterations++;
error = trainer.train();
dstats.addValue(error);
} while ((iterations < window || dstats.getVariance() > threshold) && iterations < maxIterations);
return error;
}
/**
* Get the number of iterations used
* @return the number of iterations
*/
public int getIterations() {
return iterations;
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment