NMSE = 0.039
Last active
August 29, 2015 13:55
-
-
Save caub/8692363 to your computer and use it in GitHub Desktop.
Kernel Recursive Least Squares (KRLS) using JSAT (https://code.google.com/p/java-statistical-analysis-tool/), tested on the Santa Fe laser data
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.io.BufferedReader; | |
import java.io.File; | |
import java.io.FileReader; | |
import java.io.IOException; | |
import java.util.Arrays; | |
import java.util.concurrent.ExecutorService; | |
import java.util.concurrent.Executors; | |
import jsat.classifiers.CategoricalData; | |
import jsat.classifiers.DataPoint; | |
import jsat.distributions.kernels.RBFKernel; | |
import jsat.linear.DenseVector; | |
import jsat.parameters.DoubleParameter; | |
import jsat.parameters.GridSearch; | |
import jsat.parameters.Parameterized; | |
import jsat.regression.KernelRLS; | |
import jsat.regression.RegressionDataSet; | |
import jsat.regression.RegressionModelEvaluation; | |
import jsat.regression.Regressor; | |
import jsat.utils.SystemInfo; | |
import org.math.plot.Plot2DPanel; //https://code.google.com/p/jmathplot/downloads/list | |
import javax.swing.*; | |
public class Pred { | |
// http://webee.technion.ac.il/people/rmeir/Publications/KrlsReport.pdf | |
static ExecutorService ex = Executors.newFixedThreadPool(SystemInfo.LogicalCores); | |
static int k = 40; // window size | |
static int n = 7; //number of iterations in learning | |
static double sigma = 0.9; //rbf kernel param | |
static double errTol = 0.01; | |
static KernelRLS krls; | |
static RegressionDataSet ds; | |
public static void main(String[] args) throws IOException, InterruptedException { | |
int trainingStart = 0; | |
int trainingEnd = 1000; | |
int testEnd = 1100; | |
//double[] ts = new double[testEnd];for (int i=0;i<ts.length;i++) ts[i] = Math.sin(i); | |
double[] ts = read("C:/Users/profit/A.dat", testEnd); //http://www-psych.stanford.edu/~andreas/Time-Series/SantaFe.html A.cont was appended to A.dat | |
ts = scale(ts, 0, 256, 0, 1); | |
double[] training = Arrays.copyOfRange(ts, trainingStart, trainingEnd); | |
double[][] data = makeLagMatrix(Arrays.copyOfRange(training, 0, training.length - 1), k); | |
//double[][] target = makeLagMatrix(Arrays.copyOfRange(training, k, training.length), 1); | |
ds = new RegressionDataSet(k, new CategoricalData[0]); | |
for (int i = 0; i < data.length; i++) { | |
ds.addDataPoint(new DenseVector(data[i]), new int[0], training[k + i]); | |
} | |
krls = new KernelRLS(new RBFKernel(sigma), errTol); | |
//gridSearch(); | |
//System.out.println("cv error: " + getCrossValidationMeanError(10)); | |
krls.train(ds, ex); | |
double[][] dat = data.clone(); | |
for (int j = 1; j <= n; j++) { | |
System.out.println("step " + j + "/" + n); | |
double[][] dat_ = new double[dat.length-1][k]; | |
for (int i =0; i<dat_.length; i++) { | |
double r = krls.regress(new DataPoint(new DenseVector(dat[i]), new int[0], null)); | |
System.arraycopy(dat[i], 1, dat_[i], 0, k - 1); | |
dat_[i][k - 1] = r; | |
//System.out.println(str(data[i])+" --> "+target[i][0]); | |
ds.addDataPoint(new DenseVector(dat_[i]), new int[0], training[k + i + j]); | |
} | |
dat = dat_; | |
krls.train(ds, ex); | |
/*for (int i =0; i<dat_.length; i++) { | |
double r = krls.regress(new DataPoint(new DenseVector(dat[i]), new int[0], null)); | |
System.arraycopy(dat[i], 1, dat_[i], 0, k - 1); | |
dat_[i][k - 1] = r; | |
//System.out.println(str(data[i])+" --> "+target[i][0]); | |
krls.update(new DataPoint(new DenseVector(dat_[i]), new int[0], null), training[k + i + j]); | |
} | |
dat = dat_;*/ | |
} | |
double[] rec = new double[training.length-k]; | |
for (int i = 0; i < rec.length; i++) { | |
rec[i] = krls.regress(new DataPoint(new DenseVector(data[i]), new int[0], null)); | |
} | |
double[] targ = Arrays.copyOfRange(training, k, training.length); | |
//plot(rec, targ); | |
System.out.println("training nsme: " + getSquaredError(rec, targ)/(targ.length*variance(targ))); | |
double[] ideal = Arrays.copyOfRange(ts, trainingEnd, testEnd); | |
double[] forecast = new double[ideal.length]; | |
double[] w = Arrays.copyOfRange(training, training.length - k, training.length); | |
for (int i = 0; i < forecast.length; i++) { | |
double r = krls.regress(new DataPoint(new DenseVector(w), new int[0], null)); | |
System.arraycopy(w, 1, w, 0, k - 1); | |
w[k - 1] = r; | |
forecast[i] = r<0?0:r; | |
} | |
plot(forecast, ideal); | |
System.out.println("forecast nmse: " + getSquaredError(forecast, ideal)/(ideal.length*variance(ideal))); | |
} | |
public static void gridSearch() { | |
double[] sigmas = new double[]{.8, .85, .9, .95, 1.}; | |
double[] errTols = new double[]{5e-3, 1e-2, 2e-2}; | |
GridSearch gs = new GridSearch(krls, 20);//default params overriden by what follows | |
DoubleParameter paramK = (DoubleParameter) ((Parameterized) gs.getBaseRegressor()).getParameter("RBFKernel_sigma"); | |
gs.addParameter(paramK, sigmas); | |
DoubleParameter paramE = (DoubleParameter) ((Parameterized) gs.getBaseRegressor()).getParameter("Error Tolerance"); | |
gs.addParameter(paramE, errTols); | |
System.out.println("before: " + krls.getParameter("RBFKernel_sigma").getValueString() + " " + krls.getParameter("Error Tolerance").getValueString()); | |
gs.train(ds, ex); | |
sigma = ((DoubleParameter) ((Parameterized) gs.getTrainedRegressor()).getParameter("RBFKernel_sigma")).getValue(); | |
errTol = ((DoubleParameter) ((Parameterized) gs.getTrainedRegressor()).getParameter("Error Tolerance")).getValue(); | |
((RBFKernel) krls.getKernelTrick()).setSigma(sigma); | |
krls.setErrorTolerance(errTol); | |
System.out.println("after: " + krls.getParameter("RBFKernel_sigma").getValueString() + " " + krls.getParameter("Error Tolerance").getValueString()); | |
} | |
public static double getCrossValidationMeanError(int folds) { | |
RegressionModelEvaluation crossEval = new RegressionModelEvaluation(krls, ds); | |
crossEval.evaluateCrossValidation(folds); | |
return crossEval.getMeanError(); | |
} | |
public static void plot(double[]... graphs) { | |
JFrame frame = new JFrame(); | |
frame.setSize(1000, 600); | |
Plot2DPanel plot = new Plot2DPanel(); | |
for (int i = 0; i < graphs.length; i++) { | |
plot.addLinePlot("", seq(1, graphs[i].length), graphs[i]); | |
} | |
frame.setContentPane(plot); | |
frame.setVisible(true); | |
} | |
public static double[][] makeLagMatrix(double[] a, int k) { | |
double[][] m = new double[a.length - k + 1][k]; | |
for (int i = 0; i < m.length; i++) | |
m[i] = Arrays.copyOfRange(a, i, i + k); | |
return m; | |
} | |
public static double[] seq(int from, int to) { | |
double[] indexes = new double[to - from]; | |
for (int i = from; i < to; i++) { | |
indexes[i - from] = i; | |
} | |
return indexes; | |
} | |
public static double[] mult(double[] a, double v) { | |
for (int i = 0; i < a.length; i++) | |
a[i] *= v; | |
return a; | |
} | |
public static double[] scale(double[] v, double min, double max, double newmin, double newmax) { | |
double[] r = new double[v.length]; | |
double K = (newmax - newmin) / (max - min); | |
for (int i = 0; i < v.length; i++) | |
r[i] = newmin + K * (v[i] - min); | |
return r; | |
} | |
public static double getSquaredError(double[] vector1, double[] vector2) { | |
double squaredError = 0; | |
for (int i = 0; i < vector1.length; i++) { | |
squaredError += (vector1[i] - vector2[i]) * (vector1[i] - vector2[i]); | |
} | |
return squaredError; | |
} | |
public static double sum(double[] a) { | |
double s = 0; | |
for (double d : a) | |
s += d; | |
return s; | |
} | |
public static double mean(double[] a) { | |
return sum(a) / a.length; | |
} | |
public static double variance(double[] v) { | |
return variance(v, mean(v)); | |
} | |
public static double variance(double[] v, double mean) { | |
double r = 0.0; | |
for (int i = 0; i < v.length; i++) | |
r += (v[i] - mean) * (v[i] - mean); | |
return r / (v.length - 1); | |
} | |
public static double sd(double[] v) { | |
return Math.sqrt(variance(v)); | |
} | |
public static double sd(double[] v, double mean) { | |
return Math.sqrt(variance(v, mean)); | |
} | |
public static String str(double[] a) { | |
String r = a[0] + ""; | |
for (int i = 1; i < a.length; i++) | |
r += " " + a[i]; | |
return r; | |
} | |
public static void print(double[] a) { | |
System.out.println(str(a)); | |
} | |
public static double[] read(String filename, int len) throws IOException { | |
BufferedReader rdr = new BufferedReader(new FileReader(new File(filename))); | |
int i; | |
String line = ""; | |
double[] d = new double[len]; | |
for (i = 0; i < len && (line = rdr.readLine()) != null; i++) { | |
d[i] = Double.parseDouble(line); | |
} | |
rdr.close(); | |
return d; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment