Created
February 17, 2017 14:33
-
-
Save exception/448e1720ac2470e7c00e0ae7c3a6cbd4 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Winner-take-all prototype trainer: for each training sample the closest
 * neuron (prototype) is located and replaced by
 * {@code neuron + learningRate * (sample - neuron)}, i.e. it is moved towards
 * the sample by a fixed fraction of the gap.
 *
 * <p>The arithmetic on {@code T} is supplied by subclasses through
 * {@link #distance}, {@link #add}, {@link #subtract} and {@link #multiply}.
 * No class labels are consulted anywhere in this class.
 *
 * <p>The neuron array passed to the constructor is stored without copying and
 * is returned as-is by {@link #getNeurons()}; {@link #train} overwrites its
 * elements.
 *
 * @param <T> representation of a sample / neuron (e.g. {@code double[]})
 */
public abstract class LVQ<T> {

    /** Prototypes; elements are overwritten by {@link #train}. Not copied from the constructor argument. */
    protected final T[] neurons;

    /** Fraction of the sample-minus-neuron difference applied on every update. */
    protected final double learningRate;

    /**
     * @param neurons      initial prototypes; stored by reference, and must hold
     *                     at least one element before {@link #train} is given a
     *                     non-empty list
     * @param learningRate step size used by {@link #train}
     */
    protected LVQ(T[] neurons, double learningRate) {
        this.neurons = neurons;
        this.learningRate = learningRate;
    }

    /** @return the internal neuron array itself (not a copy) */
    public T[] getNeurons() {
        return this.neurons;
    }

    /** @return the learning rate given at construction time */
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * Processes the samples in iteration order. For each one, the neuron with
     * the smallest {@link #distance} wins (the lowest index wins ties, because
     * only a strictly smaller distance replaces the current best) and is
     * replaced by {@code neuron + learningRate * (sample - neuron)}.
     *
     * <p>An empty list leaves the neurons untouched.
     *
     * @param trainingData samples to learn from
     */
    // java.util.List is fully qualified because no import for it is visible in this file.
    public void train(java.util.List<T> trainingData) {
        for (T sample : trainingData) {
            int winner = this.nearestNeuron(sample);
            T current = this.neurons[winner];
            // Direction from the neuron to the sample, scaled by the learning rate.
            T step = this.multiply(this.learningRate, this.subtract(sample, current));
            this.neurons[winner] = this.add(current, step);
        }
    }

    /** Returns the index of the neuron closest to {@code sample}; the first one wins ties. */
    private int nearestNeuron(T sample) {
        int best = 0;
        double bestDistance = this.distance(best, sample);
        for (int n = 1; n < this.neurons.length; n++) {
            double d = this.distance(n, sample);
            if (d < bestDistance) {
                best = n;
                bestDistance = d;
            }
        }
        return best;
    }

    /**
     * @param neuron index into {@link #neurons}
     * @param data   sample to compare against
     * @return distance between the indexed neuron and {@code data}; smaller means closer
     */
    protected abstract double distance(int neuron, T data);

    /** @return the sum {@code t1 + t2} */
    protected abstract T add(T t1, T t2);

    /**
     * @return the difference {@code t1 - t2}. {@link #train} calls this as
     *         {@code subtract(sample, neuron)} and relies on the result pointing
     *         from the neuron towards the sample.
     */
    protected abstract T subtract(T t1, T t2);

    /** @return {@code t} scaled by the factor {@code d} */
    protected abstract T multiply(double d, T t);

    /**
     * {@link LVQ} over {@code double[]} vectors, using Euclidean distance and
     * element-wise arithmetic. Every arithmetic method returns a freshly
     * allocated array and leaves its arguments unmodified.
     */
    public static class DoubleLVQ extends LVQ<double[]> {

        /**
         * @param neurons      initial prototype vectors (stored by reference)
         * @param learningRate step size used during training
         */
        public DoubleLVQ(double[][] neurons, double learningRate) {
            super(neurons, learningRate);
        }

        /**
         * Euclidean distance between the indexed neuron and {@code data}.
         * Iterates over the neuron's length, so {@code data} must be at least
         * that long; any extra trailing components of {@code data} are ignored.
         */
        @Override
        protected double distance(int neuron, double[] data) {
            double[] prototype = neurons[neuron];
            double sum = 0.0D;
            for (int i = 0; i < prototype.length; ++i) {
                double d = prototype[i] - data[i];
                sum += d * d;
            }
            return Math.sqrt(sum);
        }

        /** Element-wise {@code t1 + t2}; the result has the length of {@code t1}. */
        @Override
        protected double[] add(double[] t1, double[] t2) {
            double[] result = new double[t1.length];
            for (int i = 0; i < t1.length; i++) {
                result[i] = t1[i] + t2[i];
            }
            return result;
        }

        /**
         * Element-wise {@code t1 - t2}; the result has the length of {@code t1}.
         * The operand order matters: computing {@code t2 - t1} here would make
         * training push the winning neuron away from each sample.
         */
        @Override
        protected double[] subtract(double[] t1, double[] t2) {
            double[] result = new double[t1.length];
            for (int i = 0; i < t1.length; i++) {
                result[i] = t1[i] - t2[i];
            }
            return result;
        }

        /** Scales every component of {@code t} by {@code d}. */
        @Override
        protected double[] multiply(double d, double[] t) {
            double[] result = new double[t.length];
            for (int i = 0; i < t.length; i++) {
                result[i] = d * t[i];
            }
            return result;
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment