Last active
August 29, 2015 14:01
-
-
Save yeison/7af98cdb69a213f32cae to your computer and use it in GitHub Desktop.
DBN 53% error
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package classifier.vasilev; | |
import com.github.neuralnetworks.architecture.Matrix; | |
import com.github.neuralnetworks.training.TrainingInputData; | |
import com.github.neuralnetworks.training.TrainingInputProvider; | |
/** | |
* Simple input provider for testing purposes. | |
* Training and target data are two dimensional float arrays | |
*/ | |
public class SimpleInputProvider implements TrainingInputProvider { | |
private static final long serialVersionUID = 1L; | |
private float[][] input; | |
private float[][] target; | |
private SimpleTrainingInputData data; | |
private int count; | |
private int miniBatchSize; | |
private int current; | |
public SimpleInputProvider(float[][] input, float[][] target, int count, int miniBatchSize) { | |
super(); | |
this.count = count; | |
this.miniBatchSize = miniBatchSize; | |
data = new SimpleTrainingInputData(null, null); | |
if (input != null) { | |
this.input = input; | |
data.setInput(new Matrix(input[0].length, miniBatchSize)); | |
} | |
if (target != null) { | |
this.target = target; | |
data.setTarget(new Matrix(target[0].length, miniBatchSize)); | |
} | |
} | |
@Override | |
public int getInputSize() { | |
return count; | |
} | |
@Override | |
public void reset() { | |
current = 0; | |
} | |
@Override | |
public TrainingInputData getNextInput() { | |
if (current < count) { | |
for (int i = 0; i < miniBatchSize; i++, current++) { | |
if (input != null) { | |
if(data == null || input == null || input[current % input.length] == null ) | |
System.out.println("data or input is null"); | |
for (int j = 0; j < input[current % input.length].length; j++) { | |
data.getInput().set(j, i, input[current % input.length][j]); | |
} | |
} | |
if (target != null) { | |
for (int j = 0; j < target[current % target.length].length; j++) { | |
data.getTarget().set(j, i, target[current % target.length][j]); | |
} | |
} | |
} | |
return data; | |
} | |
return null; | |
} | |
private static class SimpleTrainingInputData implements TrainingInputData { | |
private static final long serialVersionUID = 1L; | |
private Matrix input; | |
private Matrix target; | |
public SimpleTrainingInputData(Matrix input, Matrix target) { | |
super(); | |
this.input = input; | |
this.target = target; | |
} | |
@Override | |
public Matrix getInput() { | |
return input; | |
} | |
public void setInput(Matrix input) { | |
this.input = input; | |
} | |
@Override | |
public Matrix getTarget() { | |
return target; | |
} | |
public void setTarget(Matrix target) { | |
this.target = target; | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import classifier.vasilev.SimpleInputProvider; | |
import com.amd.aparapi.Kernel; | |
import com.github.neuralnetworks.architecture.NeuralNetwork; | |
import com.github.neuralnetworks.architecture.types.DBN; | |
import com.github.neuralnetworks.architecture.types.NNFactory; | |
import com.github.neuralnetworks.input.MultipleNeuronsOutputError; | |
import com.github.neuralnetworks.training.OneStepTrainer; | |
import com.github.neuralnetworks.training.TrainerFactory; | |
import com.github.neuralnetworks.training.backpropagation.BackPropagationTrainer; | |
import com.github.neuralnetworks.training.events.LogTrainingListener; | |
import com.github.neuralnetworks.training.random.MersenneTwisterRandomInitializer; | |
import com.github.neuralnetworks.training.random.NNRandomInitializer; | |
import com.github.neuralnetworks.training.rbm.AparapiCDTrainer; | |
import com.github.neuralnetworks.training.rbm.DBNTrainer; | |
import com.github.neuralnetworks.util.Environment; | |
import org.apache.log4j.Logger; | |
import java.io.BufferedReader; | |
import java.io.File; | |
import java.io.FileReader; | |
import java.io.IOException; | |
import java.util.ArrayList; | |
import java.util.Arrays; | |
import java.util.HashMap; | |
import java.util.Map; | |
/** | |
* Created by yeison on 5/14/2014. | |
*/ | |
public class TestClassifier { | |
private static final Logger logger = Logger.getLogger(TestClassifier.class); | |
public static void main(String[] args){ | |
float[][] data = loadTrainingData("training.csv"); | |
runVasilevClassifier(data); | |
} | |
private static float[][] loadTrainingData(String fileName) { | |
ArrayList<float[]> trainingData = new ArrayList<>(); | |
logger.info("Loading data from file: " + fileName); | |
try(BufferedReader br = new BufferedReader(new FileReader(new File(fileName)))) { | |
for(String line; (line = br.readLine()) != null; ) { | |
boolean skip = false; | |
String[] values = line.split(","); | |
float[] event = new float[values.length]; | |
for (int i = 0; i < values.length; i++) { | |
event[i] = Float.valueOf(values[i]); | |
if(event[i] == -999.0) { | |
skip = true; | |
break; | |
} | |
} | |
if(skip) { | |
continue; | |
} else { | |
trainingData.add(event); | |
} | |
} | |
} catch (IOException e){ | |
logger.error("Error reading file: " + fileName); | |
} | |
logger.info("Finished loading data."); | |
return trainingData.toArray(new float[0][]); | |
} | |
public static void runVasilevClassifier(float[][] data){ | |
final float[][] input = new float[data.length][]; | |
final float[][] target = new float[input.length][]; | |
/* Create a new matrix with the first and last two columns removed. */ | |
for (int i = 0; i < data.length; i++) { | |
input[i] = Arrays.copyOfRange(data[i], 1, data[i].length-2); | |
} | |
/* The last column contains the labels. Place it in it's own vector called target. */ | |
int labelIndex = data[0].length-1; | |
for (int i = 0; i < data.length; i++) { | |
if(data[i][labelIndex] == 1.0) { | |
target[i] = new float[]{1, 0}; | |
}else{ | |
target[i] = new float[]{0, 1}; | |
} | |
} | |
// int segment = (int) (input.length * 1); | |
// | |
// float[][] trainInputSet = Arrays.copyOfRange(input, 0, segment); | |
// float[][] trainTargetSet = Arrays.copyOfRange(target, 0, segment); | |
// | |
// float[][] testInputSet = Arrays.copyOfRange(input, segment, input.length); | |
// float[][] testTargetSet = Arrays.copyOfRange(target, segment, target.length); | |
// NeuralNetworkImpl mlp = NNFactory.mlpSigmoid(new int[]{30, 25, 15, 2}, false); | |
DBN dbn = NNFactory.dbn(new int[]{30, 121, 2}, true); | |
dbn.setLayerCalculator(NNFactory.lcSigmoid(dbn, null)); | |
SimpleInputProvider trainInputProvider = new SimpleInputProvider(input, target, input.length, 1); | |
SimpleInputProvider testInputProvider = new SimpleInputProvider(input, target, input.length, 1); | |
// rbm trainers for each layer | |
AparapiCDTrainer firstTrainer = TrainerFactory.cdSigmoidTrainer(dbn.getFirstNeuralNetwork(), null, null, null, new NNRandomInitializer(new MersenneTwisterRandomInitializer(-0.01f, 0.01f)), 0.01f, 0.5f, 0f, 0f, 1, true); | |
AparapiCDTrainer lastTrainer = TrainerFactory.cdSigmoidTrainer(dbn.getLastNeuralNetwork(), null, null, null, new NNRandomInitializer(new MersenneTwisterRandomInitializer(-0.01f, 0.01f)), 0.01f, 0.5f, 0f, 0f, 1, true); | |
Map<NeuralNetwork, OneStepTrainer<?>> map = new HashMap<>(); | |
map.put(dbn.getFirstNeuralNetwork(), firstTrainer); | |
map.put(dbn.getLastNeuralNetwork(), lastTrainer); | |
// deep trainer | |
DBNTrainer deepTrainer = TrainerFactory.dbnTrainer(dbn, map, trainInputProvider, null, null); | |
Environment.getInstance().setExecutionMode(Kernel.EXECUTION_MODE.SEQ); | |
// layer pre-training | |
deepTrainer.train(); | |
// fine tuning backpropagation | |
BackPropagationTrainer<?> bpt = TrainerFactory.backPropagation(dbn, trainInputProvider, testInputProvider, new MultipleNeuronsOutputError(), new NNRandomInitializer(new MersenneTwisterRandomInitializer(-0.01f, 0.01f)), 0.01f, 0.5f, 0f, 0f); | |
// log data | |
bpt.addEventListener(new LogTrainingListener(Thread.currentThread().getStackTrace()[1].getMethodName(), true, true)); | |
// training | |
bpt.train(); | |
// testing | |
bpt.test(); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment