forked from h2oai/h2o-2
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added the twoing node splitting criterion.
While it is most important to offer entropy and Gini splitting criteria, it may be nice to offer twoing as well. Reference: Breiman, L. (1992). Some Properties of Splitting Criteria. Statistics Department, University of California, Berkeley. Added some very basic JUnit tests for the twoing computation.
- Loading branch information
1 parent
9c94cb1
commit ce6c45c
Showing
2 changed files
with
162 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
package hex.singlenoderf; | ||
|
||
import water.util.Utils; | ||
|
||
import java.util.Random; | ||
|
||
import org.junit.Test; | ||
import org.junit.Assert; | ||
|
||
/** Computes the twoing split statistic. | ||
* | ||
* The decrease in (twoing) impurity as the result of a given split is | ||
* computed as follows: | ||
* | ||
* 1 weight left weight right | ||
* - * ------------ * ------------- * twoing( left, right ) | ||
* 4 weight total weight total | ||
* | ||
* twoing( left, right ) = (\sum(|p_i(left) - p_i(right)|)^2, where | ||
* p_i( left ) is the fraction of observations in the left node of class i | ||
* p_i( right ) is the fraction of observations in the right node of class i | ||
* | ||
* The split that produces the largest decrease in impurity is selected. | ||
* Same is done for exclusions, where again left stands for the rows with column | ||
* value equal to the split value and right for all different ones. | ||
* | ||
* ece 11/14 | ||
*/ | ||
public class TwoingStatistic extends Statistic { | ||
|
||
public TwoingStatistic(Data data, int features, long seed, int exclusiveSplitLimit) { super(data, features, seed, exclusiveSplitLimit, false /*classification*/); } | ||
|
||
private double twoing(int[] dd_l, int sum_l, int[] dd_r, int sum_r ) { | ||
double result = 0.0; | ||
double sd_l = (double)sum_l; | ||
double sd_r = (double)sum_r; | ||
for (int i = 0; i < dd_l.length; i++) { | ||
double tmp = Math.abs(((double)dd_l[i])/sd_l - ((double)dd_r[i])/sd_r); | ||
result = result + tmp; | ||
} | ||
result = result * result; | ||
return result; | ||
} | ||
|
||
@Override protected Split ltSplit(int col, Data d, int[] dist, int distWeight, Random _) { | ||
int[] leftDist = new int[d.classes()]; | ||
int[] riteDist = dist.clone(); | ||
int lW = 0; | ||
int rW = distWeight; | ||
double totWeight = rW; | ||
// we are not a single class, calculate the best split for the column | ||
int bestSplit = -1; | ||
double bestFitness = 0.0; | ||
assert leftDist.length==_columnDists[col][0].length; | ||
|
||
for (int i = 0; i < _columnDists[col].length-1; ++i) { | ||
// first copy the i-th guys from rite to left | ||
for (int j = 0; j < leftDist.length; ++j) { | ||
int t = _columnDists[col][i][j]; | ||
lW += t; | ||
rW -= t; | ||
leftDist[j] += t; | ||
riteDist[j] -= t; | ||
} | ||
// now make sure we have something to split | ||
if( lW == 0 || rW == 0 ) continue; | ||
double f = 0.25 * ((double)lW / totWeight) * ((double)rW / totWeight) * | ||
twoing(leftDist, lW, riteDist, rW); | ||
if( f>bestFitness ) { // Take split with largest fitness | ||
bestSplit = i; | ||
bestFitness = f; | ||
} | ||
} | ||
return bestSplit == -1 | ||
? Split.impossible(Utils.maxIndex(dist, _random)) | ||
: Split.split(col, bestSplit, bestFitness); | ||
} | ||
|
||
@Override protected Split eqSplit(int colIndex, Data d, int[] dist, int distWeight, Random _) { | ||
int[] inclDist = new int[d.classes()]; | ||
int[] exclDist = dist.clone(); | ||
// we are not a single class, calculate the best split for the column | ||
int bestSplit = -1; | ||
double bestFitness = 0.0; // Fitness to maximize | ||
for( int i = 0; i < _columnDists[colIndex].length-1; ++i ) { | ||
// first copy the i-th guys from rite to left | ||
int sumt = 0; | ||
for( int j = 0; j < inclDist.length; ++j ) { | ||
int t = _columnDists[colIndex][i][j]; | ||
sumt += t; | ||
inclDist[j] = t; | ||
exclDist[j] = dist[j] - t; | ||
} | ||
int inclW = sumt; | ||
int exclW = distWeight - inclW; | ||
// now make sure we have something to split | ||
if( inclW == 0 || exclW == 0 ) continue; | ||
double f = ((double)inclW / distWeight) * ((double)exclW / distWeight) * | ||
twoing(inclDist, inclW, exclDist, exclW); | ||
if( f>bestFitness ) { // Take split with largest fitness | ||
bestSplit = i; | ||
bestFitness = f; | ||
} | ||
} | ||
return bestSplit == -1 | ||
? Split.impossible(Utils.maxIndex(dist, _random)) | ||
: Split.exclusion(colIndex, bestSplit, bestFitness); | ||
} | ||
|
||
@Override | ||
protected Split ltSplit(int colIndex, Data d, float[] dist, float distWeight, Random rand) { | ||
return null; //not called for classification | ||
} | ||
|
||
@Override | ||
protected Split eqSplit(int colIndex, Data d, float[] dist, float distWeight, Random rand) { | ||
return null; //not called for classification | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
package hex.singlenoderf; | ||
|
||
import org.junit.Test; | ||
import org.junit.Assert; | ||
|
||
public class TwoingStatisticTest { | ||
|
||
@Test | ||
public void twoingTest() { | ||
// basic test cases to check twoing computation | ||
int[] dd_l = {4,0,0,1}; | ||
int[] dd_r = {0,3,2,0}; | ||
double result = twoing(dd_l, 5, dd_r, 5); | ||
Assert.assertTrue(Math.abs(result - 4.0) < 1e-10); | ||
|
||
dd_l[0] = 4; dd_l[1] = 3; dd_l[2] = 2; dd_l[3] = 0; | ||
dd_r[0] = 0; dd_r[1] = 0; dd_r[2] = 0; dd_r[3] = 1; | ||
result = twoing(dd_l, 9, dd_r, 1); | ||
Assert.assertTrue(Math.abs(result - 4.0) < 1e-10); | ||
|
||
dd_l[0] = 4; dd_l[1] = 3; dd_l[2] = 1; dd_l[3] = 0; | ||
dd_r[0] = 0; dd_r[1] = 0; dd_r[2] = 1; dd_r[3] = 1; | ||
result = twoing(dd_l, 8, dd_r, 2); | ||
Assert.assertTrue(Math.abs(result - 3.0625) < 1e-10); | ||
|
||
dd_l[0] = 999; dd_l[1] = 1000005; dd_l[2] = 3009; dd_l[3] = 1; | ||
dd_r[0] = 999; dd_r[1] = 1000005; dd_r[2] = 3009; dd_r[3] = 1; | ||
result = twoing(dd_l, 999+1000005+3009+1, dd_r, 999+1000005+3009+1); | ||
Assert.assertTrue(Math.abs(result - 0.0) < 1e-10); | ||
} | ||
|
||
private double twoing(int[] dd_l, int sum_l, int[] dd_r, int sum_r ) { | ||
double result = 0.0; | ||
double sd_l = (double)sum_l; | ||
double sd_r = (double)sum_r; | ||
for (int i = 0; i < dd_l.length; i++) { | ||
double tmp = Math.abs(((double)dd_l[i])/sd_l - ((double)dd_r[i])/sd_r); | ||
result = result + tmp; | ||
} | ||
result = result * result; | ||
return result; | ||
} | ||
} |