Created June 20, 2017 17:38
-
-
Save funrep/c2dfb3225c0eb5956aed722f776f2d24 to your computer and use it in GitHub Desktop.
Neural net - not working
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.ArrayList; | |
public class Layer { | |
private ArrayList<Neuron> nodes; | |
public Layer(int nodeCount, int inputCount) { | |
nodes = new ArrayList<>(); | |
for (int i = 0; i < nodeCount; i++) { | |
nodes.add(new Neuron(inputCount)); | |
} | |
} | |
public ArrayList<Neuron> getNodes() { | |
return nodes; | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.ArrayList; | |
public class Main { | |
public static void main(String[] args) { | |
// Test by learning XOR boolean operator | |
double[][] sampleIn = { { 0, 0 }, { 0, 1}, { 1, 0 }, { 1, 1 } }; | |
double[][] sampleOut = { { 0 }, { 1 }, { 1 }, { 0 } }; | |
ArrayList<ArrayList<Double>> trainingInp = conv(sampleIn); | |
ArrayList<ArrayList<Double>> trainingOut = conv(sampleOut); | |
// Two input nodes, two nodes in hidden layer, 1 output node | |
int[] layers = { 2, 2, 1 }; | |
Network net = new Network(layers, 0.01); | |
train(1000, net, trainingInp, trainingOut); | |
for (int i = 0; i < trainingInp.size(); i++) { | |
ArrayList<Double> res = net.runNetwork(trainingInp.get(i)); | |
printArr(trainingInp.get(i)); | |
System.out.print("net: "); | |
printArr(res); | |
System.out.print("sample: "); | |
printArr(trainingOut.get(i)); | |
System.out.println(); | |
} | |
} | |
public static void train(int maxEpoch, Network net, | |
ArrayList<ArrayList<Double>> trainingInp, | |
ArrayList<ArrayList<Double>> trainingOut) { | |
for (int epoch = 0; epoch < maxEpoch; epoch++) { | |
double error = 0; | |
for (int i = 0; i < trainingInp.size(); i++) { | |
ArrayList<Double> inputs = trainingInp.get(i); | |
ArrayList<Double> targets = trainingOut.get(i); | |
net.backprop(inputs, targets); | |
error += totalError(targets, net.lastOutput()); | |
} | |
System.out.println("Error: " + error); | |
} | |
} | |
public static double totalError(ArrayList<Double> target, ArrayList<Double> output) { | |
double sum = 0; | |
for (int i = 0; i < target.size(); i++) { | |
sum += 0.5 * (Math.pow(target.get(i) - output.get(i), 2)); | |
} | |
return sum; | |
} | |
public static ArrayList<ArrayList<Double>> conv(double[][] arr) { | |
ArrayList<ArrayList<Double>> newArr = new ArrayList<>(new ArrayList<>()); | |
for (int i = 0; i < arr.length; i++) { | |
ArrayList<Double> inner = new ArrayList<>(); | |
for (int j = 0; j < arr[i].length; j++) { | |
inner.add(arr[i][j]); | |
} | |
newArr.add(inner); | |
} | |
return newArr; | |
} | |
public static void printArr(ArrayList<Double> arrayList) { | |
for (Double n : arrayList) { | |
System.out.print(n + " "); | |
} | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.ArrayList; | |
public class Network { | |
private ArrayList<Layer> layers; | |
private double learningRate; | |
public Network(int[] layerCounts, double learningRate) { | |
layers = new ArrayList<>(); | |
for (int i = 1; i < layerCounts.length; i++) { | |
layers.add(new Layer(layerCounts[i], layerCounts[i - 1])); | |
} | |
this.learningRate = learningRate; | |
} | |
public void backprop(ArrayList<Double> inputs, ArrayList<Double> targets) { | |
feedForward(inputs); | |
calcErrors(targets); | |
backpropErrors(inputs); | |
} | |
public ArrayList<Double> runNetwork(ArrayList<Double> input) { | |
feedForward(input); | |
ArrayList<Double> result = new ArrayList<>(); | |
for (Neuron n : layers.get(layers.size() - 1).getNodes()) { | |
result.add(n.getOutput()); | |
} | |
return result; | |
} | |
public ArrayList<Double> lastOutput() { | |
ArrayList<Double> result = new ArrayList<>(); | |
for (Neuron n : layers.get(layers.size() - 1).getNodes()) { | |
result.add(n.getOutput()); | |
} | |
return result; | |
} | |
public void feedForward(ArrayList<Double> inputs) { | |
for (Neuron n : layers.get(0).getNodes()) { | |
n.calcOut(inputs); | |
} | |
for (int i = 1; i < layers.size(); i++) { | |
ArrayList<Double> prev = new ArrayList<>(); | |
for (Neuron n : layers.get(i - 1).getNodes()) { | |
prev.add(n.getOutput()); | |
} | |
for (Neuron n : layers.get(i).getNodes()) { | |
n.calcOut(prev); | |
} | |
} | |
} | |
public void calcErrors(ArrayList<Double> targets) { | |
// Calculate output layer errors | |
for (int i = 0; i < targets.size(); i++) { | |
double t = targets.get(i); | |
Neuron n = layers.get(layers.size() - 1).getNodes().get(i); | |
double o = n.getOutput(); | |
n.setError((t - o) * o * (1 - o)); | |
} | |
// Calculate hidden layers errors | |
for (int i = layers.size() - 2; i >= 0; i--) { | |
for (Neuron n : layers.get(i).getNodes()) { | |
double sum = 0; | |
for (double w : n.getWeights()) { | |
sum += w * n.getError(); | |
} | |
double o = n.getOutput(); | |
n.setError(o * (1 - o) * sum); | |
} | |
} | |
} | |
public void backpropErrors(ArrayList<Double> inputs) { | |
for (int i = layers.size() - 1; i >= 0; i--) { | |
for (int j = 0; j < layers.get(i).getNodes().size(); j++) { | |
Neuron n = layers.get(i).getNodes().get(j); | |
double biasDiff = learningRate * n.getError(); | |
n.setBias(n.getBias() + biasDiff); | |
for (int k = 0; k < n.getWeights().size(); k++) { | |
double prevOut; | |
if (i == 0) { | |
prevOut = inputs.get(k); | |
} else { | |
prevOut = layers.get(i - 1).getNodes().get(k).getOutput(); | |
} | |
double wDiff = learningRate * n.getError() * prevOut; | |
n.setWeight(k, n.getWeight(k) + wDiff); | |
} | |
} | |
} | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.ArrayList; | |
import java.util.Random; | |
/**
 * A single sigmoid unit: one weight per input plus a bias. It caches its most recent
 * activation ({@code output}) and the delta assigned to it during training ({@code error}).
 */
public class Neuron {
    private ArrayList<Double> weights;
    private double bias;
    private double output;
    private double error;

    /**
     * Creates a neuron whose {@code weightCount} weights and bias are each drawn uniformly
     * from [-1, 1). The weights are drawn first, then the bias.
     *
     * @param weightCount number of inputs this neuron accepts
     */
    public Neuron(int weightCount) {
        Random source = new Random();
        weights = new ArrayList<>();
        for (int w = 0; w < weightCount; w++) {
            weights.add(2 * source.nextDouble() - 1);
        }
        bias = 2 * source.nextDouble() - 1;
        output = 0.0;
        error = 0.0;
    }

    /**
     * Computes {@code sigmoid(bias + sum(inputs[i] * weights[i]))}, caches it as the output,
     * and returns it. Only the first {@code weights.size()} inputs are read.
     *
     * @param inputs input vector, at least as long as the weight list
     * @return the new cached output
     */
    public double calcOut(ArrayList<Double> inputs) {
        double z = bias;
        int idx = 0;
        for (Double w : weights) {
            z += w * inputs.get(idx++);
        }
        output = sigmoid(z);
        return output;
    }

    /** Logistic function {@code 1 / (1 + e^-n)}. */
    public double sigmoid(double n) {
        return 1 / (1 + Math.exp(-n));
    }

    /** Returns the live weight list (not a copy). */
    public ArrayList<Double> getWeights() {
        return weights;
    }

    public double getWeight(int i) {
        return weights.get(i);
    }

    public void setWeight(int i, double w) {
        weights.set(i, w);
    }

    public double getBias() {
        return bias;
    }

    public void setBias(double bias) {
        this.bias = bias;
    }

    public double getOutput() {
        return output;
    }

    public void setOutput(double output) {
        this.output = output;
    }

    public double getError() {
        return error;
    }

    public void setError(double error) {
        this.error = error;
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment