Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save tomthetrainer/82ddaf45df855766b939976a38858156 to your computer and use it in GitHub Desktop.
Save tomthetrainer/82ddaf45df855766b939976a38858156 to your computer and use it in GitHub Desktop.
Regression example
package ai.skymind.training.solutions;
import org.apache.log4j.BasicConfigurator;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Random;
import java.io.File;
/**
* Created by tomhanlon on 2/23/17.
*/
public class AbeloneFeedForwardNetworkregression {

    private static final Logger log = LoggerFactory.getLogger(AbeloneFeedForwardNetworkregression.class);

    /**
     * Trains a two-hidden-layer feed-forward network to predict the abalone
     * "rings" value (last CSV column, index 8) from the 8 preceding feature
     * columns, then prints regression metrics on a held-out test file.
     *
     * @param args unused
     * @throws Exception if the classpath CSV resources cannot be located or read
     */
    public static void main(String[] args) throws Exception {
        BasicConfigurator.configure();

        // --- CSV / iterator parameters ---
        int numLinesToSkip = 0;   // abalone CSVs have no header row
        String delimiter = ",";
        int batchSize = 600;
        int labelIndex = 8;       // 9 columns per row: 8 features, then the regression target at index 8

        // --- training hyper-parameters ---
        int seed = 123;           // fixed seed for reproducible shuffling and weight init
        double learningRate = 0.005;
        int numInputs = 8;
        int numHiddenNodes = 40;
        int nEpochs = 5;
        int iterations = 100;     // optimisation iterations per minibatch (pre-1.0 DL4J semantics)
        Random rng = new Random(seed);

        File traindata = new ClassPathResource("abalone/abalone_train.csv").getFile();
        File testdata = new ClassPathResource("abalone/abalone_test.csv").getFile();

        // Training data, shuffled via rng. The regression constructor is
        // (reader, converter, batchSize, labelIndexFrom, labelIndexTo, regression).
        // FIX: the original passed numClasses (30) as labelIndexTo; for a single
        // target column at index 8 the label range must be [8, 8].
        RecordReader rrtrain = new CSVRecordReader(numLinesToSkip, delimiter);
        rrtrain.initialize(new FileSplit(traindata, rng));
        DataSetIterator trainIter =
                new RecordReaderDataSetIterator(rrtrain, null, batchSize, labelIndex, labelIndex, true);

        // Test/evaluation data (no shuffle required).
        RecordReader rrTest = new CSVRecordReader(numLinesToSkip, delimiter);
        rrTest.initialize(new FileSplit(testdata));
        DataSetIterator testIter =
                new RecordReaderDataSetIterator(rrTest, null, batchSize, labelIndex, labelIndex, true);

        // Two tanh hidden layers; identity activation + MSE loss on the output
        // layer — the standard DL4J setup for scalar regression.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(iterations)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .regularization(true).l2(1e-4)
                .weightInit(WeightInit.XAVIER)
                .activation(Activation.TANH)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .build())
                .layer(1, new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
                        .build())
                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .activation(Activation.IDENTITY)
                        .nIn(numHiddenNodes).nOut(1).build())
                .pretrain(false).backprop(true).build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        // Training UI wiring. FIX: the original called setListeners twice; the
        // second call replaces the first, so its ScoreIterationListener(10) was
        // never used. A single call registers both listeners.
        UIServer uiServer = UIServer.getInstance();
        StatsStorage statsStorage = new InMemoryStatsStorage();
        uiServer.attach(statsStorage);
        model.setListeners(new StatsListener(statsStorage), new ScoreIterationListener(1));

        for (int n = 0; n < nEpochs; n++) {
            model.fit(trainIter);
        }

        System.out.println("Evaluate model....");
        // One regression target column -> RegressionEvaluation(1).
        RegressionEvaluation eval = new RegressionEvaluation(1);
        while (testIter.hasNext()) {
            DataSet t = testIter.next();
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            eval.eval(labels, predicted);
        }
        System.out.println(eval.stats());
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment