Last active: December 14, 2015 15:09
-
-
Save tomelm/5105779 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public class RandomizedOptimization { | |
public static void main(String args[]) { | |
NNWithRandomHillClimbing rhc = new NNWithRandomHillClimbing(); | |
NNWithSimulatedAnnealing sa = new NNWithSimulatedAnnealing(); | |
NNWithGeneticAlgorithm ga = new NNWithGeneticAlgorithm(); | |
Long start; | |
// Run neural network with different random hill climbing iterations | |
try { | |
System.out.println("Running neural networks with random hill climbing at different iterations:"); | |
System.out.println("Starting RHC with 100 iterations..."); | |
start = System.currentTimeMillis(); | |
rhc.run(100); | |
System.out.println(); | |
System.out.println("Ran in: " + (System.currentTimeMillis() - start) + " seconds"); | |
System.out.println("Starting RHC with 500 iterations..."); | |
start = System.currentTimeMillis(); | |
rhc.run(500); | |
System.out.println(); | |
System.out.println("Ran in: " + (System.currentTimeMillis() - start) + " seconds"); | |
System.out.println("Starting RHC with 1000 iterations..."); | |
start = System.currentTimeMillis(); | |
rhc.run(1000); | |
System.out.println(); | |
System.out.println("Ran in: " + (System.currentTimeMillis() - start) + " seconds"); | |
System.out.println("Starting RHC with 2500 iterations..."); | |
start = System.currentTimeMillis(); | |
rhc.run(2500); | |
System.out.println(); | |
System.out.println("Ran in: " + (System.currentTimeMillis() - start) + " seconds"); | |
System.out.println("Starting RHC with 5000 iterations..."); | |
start = System.currentTimeMillis(); | |
rhc.run(5000); | |
System.out.println(); | |
System.out.println("Ran in: " + (System.currentTimeMillis() - start) + " seconds"); | |
System.out.println("Starting RHC with 50000 iterations..."); | |
start = System.currentTimeMillis(); | |
rhc.run(50000); | |
System.out.println(); | |
System.out.println("Ran in: " + (System.currentTimeMillis() - start) + " seconds"); | |
} catch (Exception e) { | |
System.out.println("Something went terribly wrong with the random hill climb!"); | |
System.out.println(e); | |
} | |
/* | |
// Run neural network with different simulated annealing iterations | |
try { | |
System.out.println("Running neural networks with simulated annealing at different iterations:"); | |
System.out.println("Starting SA with 100 iterations..."); | |
start = System.currentTimeMillis(); | |
sa.run(100); | |
System.out.println(); | |
System.out.println("Ran in: " + (System.currentTimeMillis() - start) + " seconds"); | |
System.out.println("Starting SA with 500 iterations..."); | |
start = System.currentTimeMillis(); | |
sa.run(500); | |
System.out.println(); | |
System.out.println("Ran in: " + (System.currentTimeMillis() - start) + " seconds"); | |
System.out.println("Starting SA with 1000 iterations..."); | |
start = System.currentTimeMillis(); | |
sa.run(1000); | |
System.out.println(); | |
System.out.println("Ran in: " + (System.currentTimeMillis() - start) + " seconds"); | |
System.out.println("Starting SA with 2500 iterations..."); | |
start = System.currentTimeMillis(); | |
sa.run(2500); | |
System.out.println(); | |
System.out.println("Ran in: " + (System.currentTimeMillis() - start) + " seconds"); | |
System.out.println("Starting SA with 5000 iterations..."); | |
start = System.currentTimeMillis(); | |
sa.run(5000); | |
System.out.println(); | |
System.out.println("Ran in: " + (System.currentTimeMillis() - start) + " seconds"); | |
} catch (Exception e) { | |
System.out.println("Something went terribly wrong with the simulated annealing!"); | |
System.out.println(e); | |
} | |
// Run neural network with different genetic algorithm iterations | |
try { | |
System.out.println("Running neural networks with genetic algorithm optimizations at different iterations:"); | |
System.out.println("Starting GA with 100 iterations..."); | |
start = System.currentTimeMillis(); | |
ga.run(100); | |
System.out.println(); | |
System.out.println("Ran in: " + (System.currentTimeMillis() - start) + " seconds"); | |
System.out.println("Starting GA with 500 iterations..."); | |
start = System.currentTimeMillis(); | |
ga.run(500); | |
System.out.println(); | |
System.out.println("Ran in: " + (System.currentTimeMillis() - start) + " seconds"); | |
System.out.println("Starting GA with 1000 iterations..."); | |
start = System.currentTimeMillis(); | |
ga.run(1000); | |
System.out.println(); | |
System.out.println("Ran in: " + (System.currentTimeMillis() - start) + " seconds"); | |
System.out.println("Starting GA with 2500 iterations..."); | |
start = System.currentTimeMillis(); | |
ga.run(2500); | |
System.out.println(); | |
System.out.println("Ran in: " + (System.currentTimeMillis() - start) + " seconds"); | |
System.out.println("Starting GA with 5000 iterations..."); | |
start = System.currentTimeMillis(); | |
ga.run(5000); | |
System.out.println(); | |
System.out.println("Ran in: " + (System.currentTimeMillis() - start) + " seconds"); | |
} catch (Exception e) { | |
System.out.println("Something went terribly wrong with the genetic algorithm!"); | |
System.out.println(e); | |
}*/ | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import opt.OptimizationAlgorithm; | |
import opt.RandomizedHillClimbing; | |
import opt.example.NeuralNetworkOptimizationProblem; | |
import shared.DataSet; | |
import shared.DataSetDescription; | |
import shared.ErrorMeasure; | |
import shared.FixedIterationTrainer; | |
import shared.Instance; | |
import shared.SumOfSquaresError; | |
import shared.filt.LabelSplitFilter; | |
import shared.tester.AccuracyTestMetric; | |
import shared.tester.ConfusionMatrixTestMetric; | |
import shared.tester.NeuralNetworkTester; | |
import shared.tester.TestMetric; | |
import shared.tester.Tester; | |
import shared.reader.CSVDataSetReader; | |
import shared.reader.ArffDataSetReader; | |
import shared.reader.DataSetLabelBinarySeperator; | |
import func.nn.feedfwd.FeedForwardNetwork; | |
import func.nn.feedfwd.FeedForwardNeuralNetworkFactory; | |
public class NNWithRandomHillClimbing { | |
public void run(int iterations) throws Exception { | |
// 1) Construct data instances for training. These will also be run | |
// through the network at the bottom to verify the output | |
//CSVDataSetReader reader = new CSVDataSetReader("data/subset.data"); | |
ArffDataSetReader reader = new ArffDataSetReader("data/iris.arff"); | |
DataSet set = reader.read(); | |
LabelSplitFilter flt = new LabelSplitFilter(); | |
flt.filter(set); | |
DataSetDescription desc = set.getDescription(); | |
DataSetDescription labelDesc = desc.getLabelDescription(); | |
// 2) Instantiate a network using the FeedForwardNeuralNetworkFactory. This network | |
// will be our classifier. | |
FeedForwardNeuralNetworkFactory factory = new FeedForwardNeuralNetworkFactory(); | |
// 2a) These numbers correspond to the number of nodes in each layer. | |
// This network has 4 input nodes, 3 hidden nodes in 1 layer, and 1 output node in the output layer. | |
FeedForwardNetwork network = factory.createClassificationNetwork(new int[] { 4, 5, 3 }); | |
// 3) Instantiate a measure, which is used to evaluate each possible set of weights. | |
ErrorMeasure measure = new SumOfSquaresError(); | |
// 4) Instantiate a DataSet, which adapts a set of instances to the optimization problem. | |
//DataSet set = new DataSet(patterns); | |
// 5) Instantiate an optimization problem, which is used to specify the dataset, evaluation | |
// function, mutator and crossover function (for Genetic Algorithms), and any other | |
// parameters used in optimization. | |
NeuralNetworkOptimizationProblem nno = new NeuralNetworkOptimizationProblem( | |
set, network, measure); | |
// 6) Instantiate a specific OptimizationAlgorithm, which defines how we pick our next potential | |
// hypothesis. | |
OptimizationAlgorithm o = new RandomizedHillClimbing(nno); | |
// 7) Instantiate a trainer. The FixtIterationTrainer takes another trainer (in this case, | |
// an OptimizationAlgorithm) and executes it a specified number of times. | |
FixedIterationTrainer fit = new FixedIterationTrainer(o, iterations); | |
// 8) Run the trainer. This may take a little while to run, depending on the OptimizationAlgorithm, | |
// size of the data, and number of iterations. | |
fit.train(); | |
// 9) Once training is done, get the optimal solution from the OptimizationAlgorithm. These are the | |
// optimal weights found for this network. | |
Instance opt = o.getOptimal(); | |
network.setWeights(opt.getData()); | |
//10) Run the training data through the network with the weights discovered through optimization, and | |
// print out the expected label and result of the classifier for each instance. | |
int[] labels = {0,1,2}; | |
TestMetric acc = new AccuracyTestMetric(); | |
TestMetric cm = new ConfusionMatrixTestMetric(labels); | |
Tester t = new NeuralNetworkTester(network, acc, cm); | |
t.test(set.getInstances()); | |
acc.printResults(); | |
cm.printResults(); | |
} | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.