nmse = 0.095
Last active
August 29, 2015 13:55
-
-
Save caub/8694234 to your computer and use it in GitHub Desktop.
Neural-network forecasting tests on the Santa Fe laser time series, using JSAT (code.google.com/p/java-statistical-analysis-tool) and jaRBM (sourceforge.net/projects/jarbm).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.io.BufferedReader; | |
import java.io.File; | |
import java.io.FileReader; | |
import java.io.IOException; | |
import java.util.Arrays; | |
import java.util.concurrent.ExecutorService; | |
import java.util.concurrent.Executors; | |
import jsat.classifiers.CategoricalData; | |
import jsat.classifiers.DataPoint; | |
import jsat.classifiers.neuralnetwork.BackPropagationNet; | |
import jsat.linear.DenseVector; | |
import jsat.regression.RegressionDataSet; | |
import jsat.utils.SystemInfo; | |
import org.math.plot.Plot2DPanel; //https://code.google.com/p/jmathplot/downloads/list | |
import javax.swing.*; | |
public class FFN { | |
// http://webee.technion.ac.il/people/rmeir/Publications/KrlsReport.pdf | |
static ExecutorService ex = Executors.newFixedThreadPool(SystemInfo.LogicalCores); | |
static int k = 40; // window size | |
static int n = 4; //number of iterations in learning | |
static BackPropagationNet ffn; | |
static RegressionDataSet ds; | |
public static void main(String[] args) throws IOException, InterruptedException { | |
int trainingStart = 0; | |
int trainingEnd = 1000; | |
int testEnd = 1100; | |
//double[] ts = new double[testEnd];for (int i=0;i<ts.length;i++) ts[i] = Math.sin(i); | |
double[] ts = read("C:/Users/profit/A.dat", testEnd); //http://www-psych.stanford.edu/~andreas/Time-Series/SantaFe.html A.cont was appended to A.dat | |
ts = scale(ts, 0, 256, 0, 1); | |
double[] training = Arrays.copyOfRange(ts, trainingStart, trainingEnd); | |
double[][] data = makeLagMatrix(Arrays.copyOfRange(training, 0, training.length - 1), k); | |
double[][] target = makeLagMatrix(Arrays.copyOfRange(training, k, training.length), 1); | |
ds = new RegressionDataSet(k, new CategoricalData[0]); | |
for (int i = 0; i < data.length; i++) { | |
ds.addDataPoint(new DenseVector(data[i]), new int[0], training[k + i]); | |
} | |
ffn = new BackPropagationNet(10, 5); | |
//gridSearch(); | |
//System.out.println("cv error: " + getCrossValidationMeanError(10)); | |
ffn.train(ds, ex); | |
double[][] dat = data.clone(); | |
for (int j = 1; j <= n; j++) { | |
System.out.println("step " + j + "/" + n); | |
double[][] dat_ = new double[dat.length-1][k]; | |
for (int i =0; i<dat.length-j; i++) { | |
double r = ffn.regress(new DataPoint(new DenseVector(dat[i]), new int[0], null)); | |
System.arraycopy(dat[i], 1, dat_[i], 0, k - 1); | |
dat_[i][k - 1] = r; | |
//System.out.println(str(data[i])+" --> "+target[i][0]); | |
ds.addDataPoint(new DenseVector(dat_[i]), new int[0], training[k + i + j]); | |
} | |
dat = dat_; | |
ffn.train(ds, ex); | |
} | |
double[] rec = new double[training.length-k]; | |
for (int i = 0; i < rec.length; i++) { | |
rec[i] = ffn.regress(new DataPoint(new DenseVector(data[i]), new int[0], null)); | |
} | |
double[] targ = Arrays.copyOfRange(training, k, training.length); | |
//plot(rec, targ); | |
System.out.println("training nsme: " + getSquaredError(rec, targ)/(targ.length*variance(targ))); | |
double[] ideal = Arrays.copyOfRange(ts, trainingEnd, testEnd); | |
double[] forecast = new double[ideal.length]; | |
double[] w = Arrays.copyOfRange(training, training.length - k, training.length); | |
for (int i = 0; i < forecast.length; i++) { | |
double r = ffn.regress(new DataPoint(new DenseVector(w), new int[0], null)); | |
System.arraycopy(w, 1, w, 0, k - 1); | |
w[k - 1] = r; | |
forecast[i] = r<0?0:r; | |
} | |
plot(forecast, ideal); | |
System.out.println("forecast nmse: " + getSquaredError(forecast, ideal)/(ideal.length*variance(ideal))); | |
} | |
/*public static void gridSearch() { | |
double[] sigmas = new double[]{.8, .85, .9, .95, 1.}; | |
double[] errTols = new double[]{5e-3, 1e-2, 2e-2}; | |
GridSearch gs = new GridSearch(krls, 20);//default params overriden by what follows | |
DoubleParameter paramK = (DoubleParameter) ((Parameterized) gs.getBaseRegressor()).getParameter("RBFKernel_sigma"); | |
gs.addParameter(paramK, sigmas); | |
DoubleParameter paramE = (DoubleParameter) ((Parameterized) gs.getBaseRegressor()).getParameter("Error Tolerance"); | |
gs.addParameter(paramE, errTols); | |
System.out.println("before: " + krls.getParameter("RBFKernel_sigma").getValueString() + " " + krls.getParameter("Error Tolerance").getValueString()); | |
gs.train(ds, ex); | |
sigma = ((DoubleParameter) ((Parameterized) gs.getTrainedRegressor()).getParameter("RBFKernel_sigma")).getValue(); | |
errTol = ((DoubleParameter) ((Parameterized) gs.getTrainedRegressor()).getParameter("Error Tolerance")).getValue(); | |
((RBFKernel) krls.getKernelTrick()).setSigma(sigma); | |
krls.setErrorTolerance(errTol); | |
System.out.println("after: " + krls.getParameter("RBFKernel_sigma").getValueString() + " " + krls.getParameter("Error Tolerance").getValueString()); | |
} | |
public static double getCrossValidationMeanError(int folds) { | |
RegressionModelEvaluation crossEval = new RegressionModelEvaluation(krls, ds); | |
crossEval.evaluateCrossValidation(folds); | |
return crossEval.getMeanError(); | |
}*/ | |
public static void plot(double[]... graphs) { | |
JFrame frame = new JFrame(); | |
frame.setSize(1000, 600); | |
Plot2DPanel plot = new Plot2DPanel(); | |
for (int i = 0; i < graphs.length; i++) { | |
plot.addLinePlot("", seq(1, graphs[i].length), graphs[i]); | |
} | |
frame.setContentPane(plot); | |
frame.setVisible(true); | |
} | |
public static double[][] makeLagMatrix(double[] a, int k) { | |
double[][] m = new double[a.length - k + 1][k]; | |
for (int i = 0; i < m.length; i++) | |
m[i] = Arrays.copyOfRange(a, i, i + k); | |
return m; | |
} | |
public static double[] seq(int from, int to) { | |
double[] indexes = new double[to - from]; | |
for (int i = from; i < to; i++) { | |
indexes[i - from] = i; | |
} | |
return indexes; | |
} | |
public static double[] mult(double[] a, double v) { | |
for (int i = 0; i < a.length; i++) | |
a[i] *= v; | |
return a; | |
} | |
public static double[] scale(double[] v, double min, double max, double newmin, double newmax) { | |
double[] r = new double[v.length]; | |
double K = (newmax - newmin) / (max - min); | |
for (int i = 0; i < v.length; i++) | |
r[i] = newmin + K * (v[i] - min); | |
return r; | |
} | |
public static double getSquaredError(double[] vector1, double[] vector2) { | |
double squaredError = 0; | |
for (int i = 0; i < vector1.length; i++) { | |
squaredError += (vector1[i] - vector2[i]) * (vector1[i] - vector2[i]); | |
} | |
return squaredError; | |
} | |
public static double sum(double[] a) { | |
double s = 0; | |
for (double d : a) | |
s += d; | |
return s; | |
} | |
public static double mean(double[] a) { | |
return sum(a) / a.length; | |
} | |
public static double variance(double[] v) { | |
return variance(v, mean(v)); | |
} | |
public static double variance(double[] v, double mean) { | |
double r = 0.0; | |
for (int i = 0; i < v.length; i++) | |
r += (v[i] - mean) * (v[i] - mean); | |
return r / (v.length - 1); | |
} | |
public static double sd(double[] v) { | |
return Math.sqrt(variance(v)); | |
} | |
public static double sd(double[] v, double mean) { | |
return Math.sqrt(variance(v, mean)); | |
} | |
public static String str(double[] a) { | |
String r = a[0] + ""; | |
for (int i = 1; i < a.length; i++) | |
r += " " + a[i]; | |
return r; | |
} | |
public static void print(double[] a) { | |
System.out.println(str(a)); | |
} | |
public static double[] read(String filename, int len) throws IOException { | |
BufferedReader rdr = new BufferedReader(new FileReader(new File(filename))); | |
int i; | |
String line = ""; | |
double[] d = new double[len]; | |
for (i = 0; i < len && (line = rdr.readLine()) != null; i++) { | |
d[i] = Double.parseDouble(line); | |
} | |
rdr.close(); | |
return d; | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import com.syvys.jaRBM.FactoredRBM; | |
import com.syvys.jaRBM.IO.BatchDatasourceReaderImpl; | |
import com.syvys.jaRBM.Layers.*; | |
import com.syvys.jaRBM.Math.Matrix; | |
import com.syvys.jaRBM.RBM; | |
import com.syvys.jaRBM.RBMImpl; | |
import com.syvys.jaRBM.RBMLearn.CDStochasticRBMLearner; | |
import com.syvys.jaRBM.RBMNet; | |
import com.syvys.jaRBM.RBMNetLearn.BackpropLearner; | |
import com.syvys.jaRBM.RBMNetLearn.GreedyLearner; | |
import com.syvys.jaRBM.RBMNetLearn.StochasticMetaDescentLearner; | |
import org.math.plot.Plot2DPanel; | |
import javax.swing.*; | |
import java.awt.*; | |
import java.io.BufferedReader; | |
import java.io.File; | |
import java.io.FileReader; | |
import java.io.IOException; | |
import java.util.Arrays; | |
import java.util.Random; | |
public class NN { | |
static int n = 0; | |
static int k = 40; // window size | |
static int k2 = 12; //layer 2 number of units | |
static int k3 = 8; //layer 3 number of units | |
public static void main(String[] args) throws Exception { | |
int trainingStart = 0; | |
int trainingEnd = 1000; | |
int testEnd = 1100; | |
double[] ts = read("C:/Users/profit/A.dat", testEnd); //http://www-psych.stanford.edu/~andreas/Time-Series/SantaFe.html A.cont was appended to A.dat | |
ts = scale(ts, 0, 256, 0, 1); | |
double[] training = Arrays.copyOfRange(ts, trainingStart, trainingEnd); | |
double[][] data = makeLagMatrix(Arrays.copyOfRange(training, 0, training.length - 1), k); | |
double[][] target = makeLagMatrix(Arrays.copyOfRange(training, k, training.length), 1); | |
RBMImpl rbm1 = new RBMImpl(new LinearLayer(k), new LogisticLayer(k2)); | |
RBMNet rbmnet = new RBMNet(rbm1); | |
rbm1.setLearningRate(0.1); | |
rbm1.setMomentum(.6); | |
rbm1.setWeightCost(0.000002); | |
RBMImpl rbm2 = new RBMImpl(new StochasticBinaryLayer(rbmnet.getNumOfTopHiddenUnits()), new LogisticLayer(k3)); | |
rbmnet.AddRBM(rbm2); | |
rbm2.setLearningRate(0.1); | |
rbm2.setMomentum(.6); | |
rbm2.setWeightCost(0.000002); | |
RBMImpl rbm3 = new RBMImpl(new StochasticBinaryLayer(rbmnet.getNumOfTopHiddenUnits()), new LinearLayer(1)); | |
rbmnet.AddRBM(rbm3); | |
rbm3.setLearningRate(0.1); | |
rbm3.setMomentum(.6); | |
rbm3.setWeightCost(0.000002); | |
BatchDatasourceReaderImpl batchReader = new BatchDatasourceReaderImpl(data); | |
//PartiallySupervisedGreedyLearner glearner = new PartiallySupervisedGreedyLearner(myrbmnet, batchdata, batchtarget, GREEDY_BATCH_SIZE); | |
GreedyLearner glearner = new GreedyLearner(rbmnet, batchReader, 1); | |
System.out.println("greedy training " + glearner.Learn(100, rbmnet.getNumRBMs() - 2)); | |
//rbmnet = glearner.getLearnedRBMNet(); | |
BackpropLearner mylearner; | |
StochasticMetaDescentLearner.DEFAULT_ETA_INIT = 0.1; | |
StochasticMetaDescentLearner.DEFAULT_MU = 0.1; | |
BatchDatasourceReaderImpl batchTargets = new BatchDatasourceReaderImpl(target); | |
//mylearner = new BackpropLearner(rbmnet, batchReader, batchTargets, 1); | |
//mylearner = new ConjugateGradientLearner(myrbmnet, batchdata.clone(), batchtarget.clone(), 1); | |
//mylearner = new ConjugateGradientPRLearner(myrbmnet, batchdata.clone(), batchtarget.clone(), 1); | |
mylearner = new StochasticMetaDescentLearner(rbmnet, batchReader.clone(), batchTargets.clone(), 1); | |
double error = 0; | |
for (int i = 0; i <= 300; i++) { | |
error = mylearner.Learn(); | |
if (i % 10 == 0) | |
System.out.println("Epoch " + i + ": Backprop training error = " + error); | |
} | |
double[][] outputs = rbmnet.getHiddenActivitiesFromVisibleData(data); | |
double[] rec = col(outputs, 0); | |
double[] targ = Arrays.copyOfRange(training, k, training.length); | |
plot(rec, targ); | |
System.out.println("training nsme: " + getSquaredError(rec, targ)/(targ.length*variance(targ))); | |
double[] ideal = Arrays.copyOfRange(ts, trainingEnd, testEnd); | |
double[] forecast = new double[ideal.length]; | |
double[] w = Arrays.copyOfRange(training, training.length - k, training.length); | |
for (int i = 0; i < forecast.length; i++) { | |
double r = rbmnet.getHiddenActivitiesFromVisibleData(w)[0]; | |
System.arraycopy(w, 1, w, 0, k - 1); | |
w[k - 1] = r; | |
forecast[i] = r; | |
} | |
plot(forecast, ideal); | |
System.out.println("forecast nmse: " + getSquaredError(forecast, ideal)/(ideal.length*variance(ideal))); | |
} | |
public static void plot(double[]... graphs) { | |
JFrame frame = new JFrame(); | |
frame.setSize(1000, 600); | |
Plot2DPanel plot = new Plot2DPanel(); | |
for (int i = 0; i < graphs.length; i++) { | |
plot.addLinePlot("", seq(1, graphs[i].length), graphs[i]); | |
} | |
frame.setContentPane(plot); | |
frame.setVisible(true); | |
} | |
public static double[] col(double[][] a, int j){ | |
double[] c = new double[a.length]; | |
for (int i=0; i<a.length; i++) | |
c[i] = a[i][j]; | |
return c; | |
} | |
public static double[][] makeLagMatrix(double[] a, int k) { | |
double[][] m = new double[a.length - k + 1][k]; | |
for (int i = 0; i < m.length; i++) | |
m[i] = Arrays.copyOfRange(a, i, i + k); | |
return m; | |
} | |
public static double[] seq(int from, int to) { | |
double[] indexes = new double[to - from]; | |
for (int i = from; i < to; i++) { | |
indexes[i - from] = i; | |
} | |
return indexes; | |
} | |
public static double[] mult(double[] a, double v) { | |
for (int i = 0; i < a.length; i++) | |
a[i] *= v; | |
return a; | |
} | |
public static double[] scale(double[] v, double min, double max, double newmin, double newmax) { | |
double[] r = new double[v.length]; | |
double K = (newmax - newmin) / (max - min); | |
for (int i = 0; i < v.length; i++) | |
r[i] = newmin + K * (v[i] - min); | |
return r; | |
} | |
public static double getSquaredError(double[] vector1, double[] vector2) { | |
double squaredError = 0; | |
for (int i = 0; i < vector1.length; i++) { | |
squaredError += (vector1[i] - vector2[i]) * (vector1[i] - vector2[i]); | |
} | |
return squaredError; | |
} | |
public static double sum(double[] a) { | |
double s = 0; | |
for (double d : a) | |
s += d; | |
return s; | |
} | |
public static double mean(double[] a) { | |
return sum(a) / a.length; | |
} | |
public static double variance(double[] v) { | |
return variance(v, mean(v)); | |
} | |
public static double variance(double[] v, double mean) { | |
double r = 0.0; | |
for (int i = 0; i < v.length; i++) | |
r += (v[i] - mean) * (v[i] - mean); | |
return r / (v.length - 1); | |
} | |
public static double sd(double[] v) { | |
return Math.sqrt(variance(v)); | |
} | |
public static double sd(double[] v, double mean) { | |
return Math.sqrt(variance(v, mean)); | |
} | |
public static String str(double[] a) { | |
String r = a[0] + ""; | |
for (int i = 1; i < a.length; i++) | |
r += " " + a[i]; | |
return r; | |
} | |
public static void print(double[] a) { | |
System.out.println(str(a)); | |
} | |
public static double[] read(String filename, int len) throws IOException { | |
BufferedReader rdr = new BufferedReader(new FileReader(new File(filename))); | |
int i; | |
String line = ""; | |
double[] d = new double[len]; | |
for (i = 0; i < len && (line = rdr.readLine()) != null; i++) { | |
d[i] = Double.parseDouble(line); | |
} | |
rdr.close(); | |
return d; | |
} | |
/* | |
double[][] dataX=Arrays.copyOf(data, data.length); | |
for (int i=1; i<3; i++){ | |
double[][] v = rbmnet.getHiddenActivitiesFromVisibleData(dataX); | |
for (int j=0; j<dataX.length-i;j++){ | |
System.arraycopy(dataX[j], 1, dataX[j], 0, k-1); | |
dataX[j][k-1] = v[j][0]; | |
} | |
mylearner.Learn(Arrays.copyOfRange(dataX, 0, dataX.length-i), Arrays.copyOfRange(target, i, target.length)); | |
}*/ | |
public static int[] minmaxBars(double[] a, int nMax) { //grouped by 4 | |
int li = 2, hi = 1; | |
double l = a[li], h = a[hi]; | |
for (int i = 8; i < nMax; i += 4) { | |
if (a[i - 2] < l) { | |
li = i - 2; | |
l = a[li]; | |
} | |
if (a[i - 3] > h) { | |
hi = i - 3; | |
h = a[hi]; | |
} | |
} | |
return new int[]{li, hi}; | |
} | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment