Created November 5, 2013 04:43
Save Alrecenk/7314019 to your computer and use it in GitHub Desktop.
The core learning algorithm for the rotation forest that calculates the best split based on approximate information gain.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//splits this node if it should and returns whether it did | |
//data is assumed to be a set of presorted lists where data[k][j] is the jth element of data when sorted by axis[k] | |
public boolean split(int minpoints){ | |
//if already split or one class or not enough points remaining then don't split | |
if (branchnode || totalpositive == 0 || totalnegative == 0 || totalpositive + totalnegative < minpoints){ | |
return false; | |
}else{ | |
int bestaxis = -1, splitafter=-1; | |
double bestscore = Double.MAX_VALUE;//any valid split will beat no split | |
int bestLp=0, bestLn=0; | |
for (int k = 0; k < data.length; k++){//try each axis | |
int Lp = 0, Ln = 0, Rp = totalpositive, Rn = totalnegative;//reset the +/- counts | |
for (int j = 0; j < data[k].length - 1; j++){//walk through the data points | |
if (data[k][j].output){ | |
Lp++;//update positive counts | |
Rp--; | |
}else{ | |
Ln++;//update negative counts | |
Rn--; | |
} | |
//score by a parabola approximating information gain | |
double score = Lp * Ln / (double)(Lp + Ln) + Rp * Rn / (double)(Rp + Rn); | |
if (score < bestscore){ // lower score is better | |
bestscore = score;//save score | |
bestaxis = k;//save axis | |
splitafter = j;//svale split location | |
bestLp = Lp;//save positives and negatives to left of split | |
bestLn = Ln ;//so they don't need to be counted again later | |
} | |
} | |
} | |
//if we got a valid split | |
if (bestscore < Double.MAX_VALUE){ | |
splitaxis = axis[bestaxis]; | |
//split halfway between the 2 points around the split | |
splitvalue = 0.5 * (data[bestaxis][splitafter].dot(splitaxis) + data[bestaxis][splitafter + 1].dot(splitaxis)); | |
Datapoint[][] lowerdata = new Datapoint[axis.length][] ; | |
Datapoint[][] upperdata = new Datapoint[axis.length][] ; | |
lowerdata[bestaxis] = new Datapoint[splitafter+1] ;//initialize the child data arrays for the split axis | |
upperdata[bestaxis] = new Datapoint[data[bestaxis].length-splitafter-1] ; | |
for (int k = 0; k <= splitafter; k++){//for the lower node data points | |
data[bestaxis][k].setchild(false);//mark which leaf it goes to | |
lowerdata[bestaxis][k] = data[bestaxis][k];//go ahead and separate the split axis | |
} | |
for (int k = splitafter+1; k < data[bestaxis].length; k++){ | |
data[bestaxis][k].setchild(true);//mark which leaf it goes on | |
upperdata[bestaxis][k-splitafter-1] = data[bestaxis][k] ;//go ahead and separate the split axis | |
} | |
//separate all the other axes maintaining sorting | |
for (int k = 0; k < axis.length; k++){ | |
if (k != bestaxis){//we already did bestaxis=k above | |
//initialize the arrays | |
lowerdata[k] = new Datapoint[splitafter + 1]; | |
upperdata[k] = new Datapoint[data[bestaxis].length - splitafter - 1]; | |
//fill the data into these arrays without changing order | |
int lowerindex=0,upperindex=0; | |
for (int j = 0; j < data[k].length; j++){ | |
if (data[k][j].upperchild){//if goes in upper node | |
upperdata[k][upperindex] = data[k][j]; | |
upperindex++;//put in upper node data array | |
}else{//if goes in lower node | |
lowerdata[k][lowerindex] = data[k][j]; | |
lowerindex++;//put in lower node array | |
} | |
} | |
} | |
} | |
//initialize but do not yet split the children | |
lower = new Treenode(lowerdata, axis, bestLp, bestLn); | |
upper = new Treenode(upperdata, axis, totalpositive - bestLp, totalnegative - bestLn); | |
branchnode = true; | |
return true; | |
}else{//if no valid splits found | |
return false ;//return did not split | |
} | |
} | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.