Created October 13, 2016 03:32
-
-
Save marcusking01/52bda89b4124ec9c2f953c08ac6ba3eb to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package RandomizedOptimization; | |
import java.util.Objects;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import shared.Trainer;
/** | |
* A momentum convergence trainer trains a network | |
* until convergence based on a sliding window of statistical variance on the error, using another trainer | |
* @author Marcus King | |
* @version 1.0 | |
*/ | |
public class MomentumConvergenceTrainer implements Trainer { | |
/** The default threshold */ | |
private static final double THRESHOLD = .05; | |
/** The maxium number of iterations */ | |
private static final int MAX_ITERATIONS = 5000; | |
/** | |
* The trainer | |
*/ | |
private Trainer trainer; | |
/** | |
* The threshold | |
*/ | |
private double threshold; | |
/** | |
* The number of iterations trained | |
*/ | |
private int iterations; | |
/** | |
* The maximum number of iterations to use | |
*/ | |
private int maxIterations; | |
/** | |
* Create a new convergence trainer | |
* @param trainer the thrainer to use | |
* @param threshold the error threshold | |
* @param maxIterations the maximum iterations | |
*/ | |
public MomentumConvergenceTrainer(Trainer trainer, | |
double threshold, int maxIterations) { | |
this.trainer = trainer; | |
this.threshold = threshold; | |
this.maxIterations = maxIterations; | |
} | |
/** | |
* Create a new convergence trainer | |
* @param trainer the trainer to use | |
*/ | |
public MomentumConvergenceTrainer(Trainer trainer) { | |
this(trainer, THRESHOLD, MAX_ITERATIONS); | |
} | |
/** | |
* @see Trainer#train() | |
*/ | |
public double train() { | |
int window = Math.max(20, (int) (maxIterations*.03)); | |
DescriptiveStatistics dstats = new DescriptiveStatistics(window); | |
double error = Double.MAX_VALUE; | |
do { | |
iterations++; | |
error = trainer.train(); | |
dstats.addValue(error); | |
} while ((iterations < window || dstats.getVariance() > threshold) && iterations < maxIterations); | |
return error; | |
} | |
/** | |
* Get the number of iterations used | |
* @return the number of iterations | |
*/ | |
public int getIterations() { | |
return iterations; | |
} | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment