Last active
January 10, 2017 12:53
-
-
Save imjacobclark/d55e527abc453f773dc0888fa5ffdcd0 to your computer and use it in GitHub Desktop.
XOR Machine Learning Example (Encog, Java 8 + annotated)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package xyz.jacobclark; | |
import org.encog.engine.network.activation.ActivationSigmoid; | |
import org.encog.ml.data.MLData; | |
import org.encog.ml.data.MLDataPair; | |
import org.encog.ml.data.basic.BasicMLDataSet; | |
import org.encog.neural.networks.BasicNetwork; | |
import org.encog.neural.networks.layers.BasicLayer; | |
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation; | |
public class MachineLearningApplication { | |
public static void main(String args[]) { | |
/* | |
Encog + Java 8 | |
Machine Learning 3 layer neural network which classifies whether an XOR expression is truthy or falsey in a range of 0d - 1d | |
Overview: Neural network with a sigmoid activation function which takes input and ideal outputs through an resilient propagation training algorithm | |
Topics Covered: | |
Training data setup | |
Neural network setup | |
Network training | |
Running against real values | |
Neurals: (simplified for text) | |
I1 -(bias)> H1 | |
-> O1 | |
I1 -(bias)> H2 | |
Activation function: | |
Sigmoid | |
Training method: | |
Rprop (Resilient Propagation) | |
*/ | |
// ----- | |
// Training input | |
// Training data set | |
// XOR -> Exclusive OR training set | |
// Performs ML on the real values of an XOR operation. | |
// Example XOR operations and their values | |
// 0.0 + 0.0 = False | |
// 1.0 + 0.0 = True | |
// 0.0 + 1.0 = True | |
// 1.0 + 1.0 = False | |
double[][] XOR_INPUT = { | |
{ 0.0, 0.0 }, | |
{ 1.0, 0.0 }, | |
{ 0.0, 1.0 }, | |
{ 1.0, 1.0 } | |
}; | |
// Training input XOR evaluations | |
double[][] XOR_IDEAL = { | |
{0.0}, | |
{1.0}, | |
{1.0}, | |
{0.0} | |
}; | |
BasicMLDataSet basicMLDataSet = new BasicMLDataSet(XOR_INPUT, XOR_IDEAL); | |
// ----- | |
// Neural Network | |
// Create the network | |
BasicNetwork network = new BasicNetwork(); | |
// Create input network layer | |
// Creates a layer with 2 neurons and a bias nuron | |
BasicLayer input = new BasicLayer(null, true, 2); | |
// Sigmoid = Mathematical Function | |
// Create hidden network layer | |
// Takes the activation sigmoid and creates a layer with 2 neurons and a bias nuron | |
BasicLayer hidden = new BasicLayer(new ActivationSigmoid(), true, 2); | |
// Create output network layer | |
// Takes the activation sigmoid and creates a layer with 1 neuron | |
BasicLayer output = new BasicLayer(new ActivationSigmoid(), false, 1); | |
// Add network layers | |
network.addLayer(input); | |
network.addLayer(hidden); | |
network.addLayer(output); | |
// All networks are added | |
network.getStructure().finalizeStructure(); | |
// Randomly init all weights | |
network.reset(); | |
// ----- | |
// Training | |
// Create training data set (regression) | |
ResilientPropagation resilientPropagation = new ResilientPropagation(network, basicMLDataSet); | |
// Train the network until the error rate gets very low (1%) | |
do { | |
resilientPropagation.iteration(); | |
} while(resilientPropagation.getError() > 0.01); | |
// ----- | |
// Running ideal data against trained network | |
// Iterate over each pair of training data created above | |
for(MLDataPair pair: basicMLDataSet ) { | |
// Compute the real value (hypotehsis function) | |
final MLData data = network.compute(pair.getInput()); | |
// Log out the values | |
System.out.println( | |
"Input = " | |
+ pair.getInput().getData(0) + ", " + pair.getInput().getData(1) | |
+ " | Computed = " + data.getData(0) | |
+ " | Ideal = " + pair.getIdeal().getData(0) | |
); | |
} | |
// ----- | |
// Program output | |
/* | |
Input = 0.0, 0.0 | Computed = 0.13492055225524496 | Ideal = 0.0 | |
Input = 1.0, 0.0 | Computed = 0.9480355027515462 | Ideal = 1.0 | |
Input = 0.0, 1.0 | Computed = 0.9133561421396043 | Ideal = 1.0 | |
Input = 1.0, 1.0 | Computed = 0.06562084788083862 | Ideal = 0.0 | |
*/ | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment