Skip to content

Instantly share code, notes, and snippets.

@exception
Created February 17, 2017 14:33
Show Gist options
  • Save exception/448e1720ac2470e7c00e0ae7c3a6cbd4 to your computer and use it in GitHub Desktop.
Save exception/448e1720ac2470e7c00e0ae7c3a6cbd4 to your computer and use it in GitHub Desktop.
/**
 * Simple winner-take-all vector quantizer over an arbitrary vector type {@code T}.
 *
 * <p>For each training sample the closest neuron (smallest {@link #distance}) is chosen as the
 * winner and moved toward the sample: {@code w <- w + learningRate * (x - w)}. Only the winner
 * is updated, and no class labels are used.
 *
 * <p>Subclasses supply the distance metric and the vector arithmetic ({@link #add},
 * {@link #subtract}, {@link #multiply}). The neuron array passed to the constructor is stored by
 * reference and is updated in place by {@link #train}. No synchronization is performed.
 *
 * @param <T> representation of a neuron / data vector
 */
public abstract class LVQ<T> {
    /** Neuron vectors; this is the caller's array, updated in place by {@link #train}. */
    protected final T[] neurons;
    /** Step size of the winner's move toward each training sample. */
    protected final double learningRate;

    /**
     * Creates a quantizer over the given neurons.
     *
     * @param neurons initial neuron vectors; stored by reference, not copied
     * @param learningRate step size used by {@link #train}
     * @throws NullPointerException if {@code neurons} is null
     * @throws IllegalArgumentException if {@code neurons} is empty
     */
    protected LVQ(T[] neurons, double learningRate) {
        this.neurons = java.util.Objects.requireNonNull(neurons, "neurons");
        // train() always needs at least one candidate winner; fail fast instead of
        // with an ArrayIndexOutOfBoundsException on the first sample.
        if (neurons.length == 0) {
            throw new IllegalArgumentException("neurons must not be empty");
        }
        this.learningRate = learningRate;
    }

    /**
     * Returns the live neuron array (not a copy); changes to it affect this instance.
     *
     * @return the neuron array
     */
    public T[] getNeurons() {
        return this.neurons;
    }

    /**
     * Returns the learning rate.
     *
     * @return the step size used by {@link #train}
     */
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * Runs one pass over {@code trainingData}, in order. For each sample, the nearest neuron is
     * replaced by {@code neuron + learningRate * (sample - neuron)}.
     *
     * @param trainingData samples to learn from
     */
    public void train(java.util.List<T> trainingData) {
        for (T data : trainingData) {
            int winner = nearestNeuron(data);
            T current = this.neurons[winner];
            // Move the winner toward the sample: w + rate * (x - w).
            this.neurons[winner] =
                    this.add(current, this.multiply(this.learningRate, this.subtract(data, current)));
        }
    }

    /** Returns the index of the neuron closest to {@code data}; the lowest index wins ties. */
    private int nearestNeuron(T data) {
        int best = 0;
        double bestDistance = this.distance(0, data);
        for (int n = 1; n < this.neurons.length; n++) {
            double d = this.distance(n, data);
            if (d < bestDistance) {
                best = n;
                bestDistance = d;
            }
        }
        return best;
    }

    /**
     * Computes the distance between {@code neurons[neuron]} and {@code data}.
     *
     * @param neuron index into {@link #neurons}
     * @param data the sample
     * @return a distance where smaller means closer
     */
    protected abstract double distance(int neuron, T data);

    /**
     * Computes the element-wise sum of two vectors.
     *
     * @param t1 first operand
     * @param t2 second operand
     * @return {@code t1 + t2}
     */
    protected abstract T add(T t1, T t2);

    /**
     * Computes the element-wise difference of two vectors. {@link #train} relies on this
     * operand order to move the winner toward the sample.
     *
     * @param t1 minuend
     * @param t2 subtrahend
     * @return {@code t1 - t2}
     */
    protected abstract T subtract(T t1, T t2);

    /**
     * Scales a vector by a scalar.
     *
     * @param d scalar factor
     * @param t vector
     * @return {@code d * t}
     */
    protected abstract T multiply(double d, T t);

    /** {@link LVQ} over dense {@code double[]} vectors using Euclidean distance. */
    public static class DoubleLVQ extends LVQ<double[]> {

        /**
         * Creates a quantizer over the given neuron vectors.
         *
         * @param neurons initial neuron vectors; stored by reference and updated in place
         * @param learningRate step size used by {@link #train}
         * @throws NullPointerException if {@code neurons} is null
         * @throws IllegalArgumentException if {@code neurons} is empty
         */
        public DoubleLVQ(double[][] neurons, double learningRate) {
            super(neurons, learningRate);
        }

        /**
         * Euclidean distance between {@code neurons[neuron]} and {@code data}.
         *
         * @throws IllegalArgumentException if the vectors have different lengths
         */
        @Override
        protected double distance(int neuron, double[] data) {
            double[] weights = neurons[neuron];
            requireSameLength(weights, data);
            double sum = 0.0D;
            for (int i = 0; i < weights.length; i++) {
                double d = weights[i] - data[i];
                sum += d * d;
            }
            return Math.sqrt(sum);
        }

        /**
         * Returns a new array holding {@code t1 + t2}.
         *
         * @throws IllegalArgumentException if the vectors have different lengths
         */
        @Override
        protected double[] add(double[] t1, double[] t2) {
            requireSameLength(t1, t2);
            double[] result = new double[t1.length];
            for (int i = 0; i < t1.length; i++) {
                result[i] = t1[i] + t2[i];
            }
            return result;
        }

        /**
         * Returns a new array holding {@code t1 - t2}.
         *
         * @throws IllegalArgumentException if the vectors have different lengths
         */
        @Override
        protected double[] subtract(double[] t1, double[] t2) {
            requireSameLength(t1, t2);
            double[] result = new double[t1.length];
            for (int i = 0; i < t1.length; i++) {
                // Was t2 - t1, which made train() push the winner away from the sample.
                result[i] = t1[i] - t2[i];
            }
            return result;
        }

        /** Returns a new array holding {@code d * t}. */
        @Override
        protected double[] multiply(double d, double[] t) {
            double[] result = new double[t.length];
            for (int i = 0; i < t.length; i++) {
                result[i] = d * t[i];
            }
            return result;
        }

        /** Rejects vectors of different dimension instead of silently truncating or overrunning. */
        private static void requireSameLength(double[] a, double[] b) {
            if (a.length != b.length) {
                throw new IllegalArgumentException(
                        "dimension mismatch: " + a.length + " vs " + b.length);
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment