Skip to content

Instantly share code, notes, and snippets.

@yeison
Last active August 29, 2015 14:01
Show Gist options
  • Save yeison/7af98cdb69a213f32cae to your computer and use it in GitHub Desktop.
DBN 53% error
package classifier.vasilev;
import com.github.neuralnetworks.architecture.Matrix;
import com.github.neuralnetworks.training.TrainingInputData;
import com.github.neuralnetworks.training.TrainingInputProvider;
/**
* Simple input provider for testing purposes.
* Training and target data are two dimensional float arrays
*/
public class SimpleInputProvider implements TrainingInputProvider {

    private static final long serialVersionUID = 1L;

    /** Raw training samples; one row per sample, one column per feature. May be null. */
    private float[][] input;
    /** Raw target vectors; one row per sample. May be null. */
    private float[][] target;
    /** Reusable mini-batch container returned (and overwritten) by every getNextInput() call. */
    private SimpleTrainingInputData data;
    /** Total number of samples to serve before getNextInput() starts returning null. */
    private int count;
    /** Number of samples copied into each batch. */
    private int miniBatchSize;
    /** Samples served since construction / the last reset(). */
    private int current;

    /**
     * Creates a provider that serves {@code count} samples in batches of
     * {@code miniBatchSize}, cycling over the source arrays via modulo when
     * {@code count} exceeds their length.
     *
     * @param input         training data, rows = samples (may be null)
     * @param target        target data, rows = samples (may be null)
     * @param count         total samples to serve; should be a multiple of
     *                      miniBatchSize, otherwise the final batch runs past count
     * @param miniBatchSize samples per batch
     */
    public SimpleInputProvider(float[][] input, float[][] target, int count, int miniBatchSize) {
        super();
        this.count = count;
        this.miniBatchSize = miniBatchSize;

        data = new SimpleTrainingInputData(null, null);

        if (input != null) {
            this.input = input;
            // Batch matrix layout: rows = features, columns = position within the batch.
            data.setInput(new Matrix(input[0].length, miniBatchSize));
        }

        if (target != null) {
            this.target = target;
            data.setTarget(new Matrix(target[0].length, miniBatchSize));
        }
    }

    @Override
    public int getInputSize() {
        return count;
    }

    @Override
    public void reset() {
        current = 0;
    }

    /**
     * Fills the shared mini-batch matrices with the next {@code miniBatchSize}
     * samples (wrapping around the source arrays) and returns them, or returns
     * null once {@code count} samples have been served.
     *
     * <p>Note: the same {@link TrainingInputData} instance is returned every
     * call; callers must consume it before requesting the next batch.
     */
    @Override
    public TrainingInputData getNextInput() {
        if (current < count) {
            for (int i = 0; i < miniBatchSize; i++, current++) {
                if (input != null) {
                    // Hoist the row lookup out of the feature loop.
                    float[] row = input[current % input.length];
                    for (int j = 0; j < row.length; j++) {
                        data.getInput().set(j, i, row[j]);
                    }
                }
                if (target != null) {
                    float[] row = target[current % target.length];
                    for (int j = 0; j < row.length; j++) {
                        data.getTarget().set(j, i, row[j]);
                    }
                }
            }
            return data;
        }
        return null;
    }

    /** Plain holder for an (input, target) pair of batch matrices. */
    private static class SimpleTrainingInputData implements TrainingInputData {

        private static final long serialVersionUID = 1L;

        private Matrix input;
        private Matrix target;

        public SimpleTrainingInputData(Matrix input, Matrix target) {
            super();
            this.input = input;
            this.target = target;
        }

        @Override
        public Matrix getInput() {
            return input;
        }

        public void setInput(Matrix input) {
            this.input = input;
        }

        @Override
        public Matrix getTarget() {
            return target;
        }

        public void setTarget(Matrix target) {
            this.target = target;
        }
    }
}
import classifier.vasilev.SimpleInputProvider;
import com.amd.aparapi.Kernel;
import com.github.neuralnetworks.architecture.NeuralNetwork;
import com.github.neuralnetworks.architecture.types.DBN;
import com.github.neuralnetworks.architecture.types.NNFactory;
import com.github.neuralnetworks.input.MultipleNeuronsOutputError;
import com.github.neuralnetworks.training.OneStepTrainer;
import com.github.neuralnetworks.training.TrainerFactory;
import com.github.neuralnetworks.training.backpropagation.BackPropagationTrainer;
import com.github.neuralnetworks.training.events.LogTrainingListener;
import com.github.neuralnetworks.training.random.MersenneTwisterRandomInitializer;
import com.github.neuralnetworks.training.random.NNRandomInitializer;
import com.github.neuralnetworks.training.rbm.AparapiCDTrainer;
import com.github.neuralnetworks.training.rbm.DBNTrainer;
import com.github.neuralnetworks.util.Environment;
import org.apache.log4j.Logger;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
/**
* Created by yeison on 5/14/2014.
*/
public class TestClassifier {

    private static final Logger logger = Logger.getLogger(TestClassifier.class);

    /** Sentinel the CSV uses for a missing measurement; rows containing it are skipped. */
    private static final float MISSING_VALUE = -999.0f;

    public static void main(String[] args) {
        float[][] data = loadTrainingData("training.csv");
        runVasilevClassifier(data);
    }

    /**
     * Loads comma-separated float rows from {@code fileName}, skipping any row
     * that contains the {@link #MISSING_VALUE} sentinel.
     *
     * @param fileName path to the CSV file
     * @return one float[] per usable CSV row; empty array if the file could not be read
     */
    private static float[][] loadTrainingData(String fileName) {
        ArrayList<float[]> trainingData = new ArrayList<>();
        logger.info("Loading data from file: " + fileName);

        try (BufferedReader br = new BufferedReader(new FileReader(new File(fileName)))) {
            for (String line; (line = br.readLine()) != null; ) {
                String[] values = line.split(",");
                float[] event = new float[values.length];
                boolean skip = false;
                for (int i = 0; i < values.length; i++) {
                    // parseFloat returns the primitive directly; valueOf would box.
                    event[i] = Float.parseFloat(values[i]);
                    if (event[i] == MISSING_VALUE) {
                        skip = true;
                        break;
                    }
                }
                if (!skip) {
                    trainingData.add(event);
                }
            }
        } catch (IOException e) {
            // Preserve the cause in the log instead of dropping the exception.
            logger.error("Error reading file: " + fileName, e);
        }

        logger.info("Finished loading data.");
        return trainingData.toArray(new float[0][]);
    }

    /**
     * Pre-trains a 30-121-2 DBN layer-by-layer with contrastive divergence,
     * then fine-tunes with backpropagation and runs the test pass.
     *
     * @param data raw CSV rows: first column an id, last column the class label
     */
    public static void runVasilevClassifier(float[][] data) {
        final float[][] input = new float[data.length][];
        final float[][] target = new float[input.length][];

        // Feature matrix: drop the first column and the last two columns,
        // keeping indices 1 .. length-3.
        // NOTE(review): assumes each row yields exactly 30 features to match
        // the DBN's 30 input units — confirm against the CSV layout.
        for (int i = 0; i < data.length; i++) {
            input[i] = Arrays.copyOfRange(data[i], 1, data[i].length - 2);
        }

        // One-hot encode the label stored in the last column.
        int labelIndex = data[0].length - 1;
        for (int i = 0; i < data.length; i++) {
            if (data[i][labelIndex] == 1.0) {
                target[i] = new float[]{1, 0};
            } else {
                target[i] = new float[]{0, 1};
            }
        }

        DBN dbn = NNFactory.dbn(new int[]{30, 121, 2}, true);
        dbn.setLayerCalculator(NNFactory.lcSigmoid(dbn, null));

        // NOTE(review): train and test providers share the same data, so the
        // reported error measures training fit, not generalization.
        SimpleInputProvider trainInputProvider = new SimpleInputProvider(input, target, input.length, 1);
        SimpleInputProvider testInputProvider = new SimpleInputProvider(input, target, input.length, 1);

        // One contrastive-divergence trainer per RBM layer.
        AparapiCDTrainer firstTrainer = TrainerFactory.cdSigmoidTrainer(dbn.getFirstNeuralNetwork(), null, null, null, new NNRandomInitializer(new MersenneTwisterRandomInitializer(-0.01f, 0.01f)), 0.01f, 0.5f, 0f, 0f, 1, true);
        AparapiCDTrainer lastTrainer = TrainerFactory.cdSigmoidTrainer(dbn.getLastNeuralNetwork(), null, null, null, new NNRandomInitializer(new MersenneTwisterRandomInitializer(-0.01f, 0.01f)), 0.01f, 0.5f, 0f, 0f, 1, true);

        Map<NeuralNetwork, OneStepTrainer<?>> map = new HashMap<>();
        map.put(dbn.getFirstNeuralNetwork(), firstTrainer);
        map.put(dbn.getLastNeuralNetwork(), lastTrainer);

        // Greedy layer-wise pre-training.
        DBNTrainer deepTrainer = TrainerFactory.dbnTrainer(dbn, map, trainInputProvider, null, null);
        Environment.getInstance().setExecutionMode(Kernel.EXECUTION_MODE.SEQ);
        deepTrainer.train();

        // Fine-tuning via backpropagation, then evaluation.
        BackPropagationTrainer<?> bpt = TrainerFactory.backPropagation(dbn, trainInputProvider, testInputProvider, new MultipleNeuronsOutputError(), new NNRandomInitializer(new MersenneTwisterRandomInitializer(-0.01f, 0.01f)), 0.01f, 0.5f, 0f, 0f);
        bpt.addEventListener(new LogTrainingListener(Thread.currentThread().getStackTrace()[1].getMethodName(), true, true));
        bpt.train();
        bpt.test();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment