Commit

Added the twoing node splitting criterion. While it is most important to offer entropy and gini splitting criteria, it may be nice to offer twoing as well.

Reference: Breiman, L. Some Properties of Splitting Criteria, Statistics Department, University of California, Berkeley. 1992.

Added some very basic jUnit tests for the twoing computation.
ericeckstrand committed Nov 5, 2014
1 parent 9c94cb1 commit ce6c45c
Showing 2 changed files with 162 additions and 0 deletions.
119 changes: 119 additions & 0 deletions src/main/java/hex/singlenoderf/TwoingStatistic.java
@@ -0,0 +1,119 @@
package hex.singlenoderf;

import water.util.Utils;

import java.util.Random;

/** Computes the twoing split statistic.
*
* The decrease in (twoing) impurity as the result of a given split is
* computed as follows:
*
*   (1/4) * (weight_left / weight_total) * (weight_right / weight_total) * twoing(left, right)
*
* where twoing(left, right) = (\sum_i |p_i(left) - p_i(right)|)^2,
*   p_i(left)  is the fraction of observations of class i in the left node, and
*   p_i(right) is the fraction of observations of class i in the right node.
*
* The split that produces the largest decrease in impurity is selected.
* The same computation is used for exclusion splits, where the left node holds the
* rows whose column value equals the split value and the right node holds all
* remaining rows.
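*
* As a concrete illustration (using the counts from the first case in
* TwoingStatisticTest): for left class counts {4,0,0,1} (5 rows) and right class
* counts {0,3,2,0} (5 rows), the per-class fraction differences sum to
* |0.8-0| + |0-0.6| + |0-0.4| + |0.2-0| = 2.0, so twoing(left, right) = 2.0^2 = 4.0,
* and with equal left/right weights the impurity decrease is 0.25 * 0.5 * 0.5 * 4.0 = 0.25.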
*
* ece 11/14
*/
public class TwoingStatistic extends Statistic {

public TwoingStatistic(Data data, int features, long seed, int exclusiveSplitLimit) { super(data, features, seed, exclusiveSplitLimit, false /*classification*/); }

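/** Returns the twoing measure for a candidate split: the squared L1 distance between
* the left and right per-class fraction vectors. */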
private double twoing(int[] dd_l, int sum_l, int[] dd_r, int sum_r ) {
double result = 0.0;
double sd_l = (double)sum_l;
double sd_r = (double)sum_r;
for (int i = 0; i < dd_l.length; i++) {
double tmp = Math.abs(((double)dd_l[i])/sd_l - ((double)dd_r[i])/sd_r);
result = result + tmp;
}
result = result * result;
return result;
}

@Override protected Split ltSplit(int col, Data d, int[] dist, int distWeight, Random _) {
int[] leftDist = new int[d.classes()];
int[] riteDist = dist.clone();
int lW = 0;
int rW = distWeight;
double totWeight = rW;
// we are not a single class, calculate the best split for the column
int bestSplit = -1;
double bestFitness = 0.0;
assert leftDist.length==_columnDists[col][0].length;

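// sweep candidate split points left to right, moving one histogram bin at a time
// from the right distribution into the left distribution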
for (int i = 0; i < _columnDists[col].length-1; ++i) {
// move the i-th bin's class counts from the right distribution into the left
for (int j = 0; j < leftDist.length; ++j) {
int t = _columnDists[col][i][j];
lW += t;
rW -= t;
leftDist[j] += t;
riteDist[j] -= t;
}
// now make sure we have something to split
if( lW == 0 || rW == 0 ) continue;
double f = 0.25 * ((double)lW / totWeight) * ((double)rW / totWeight) *
twoing(leftDist, lW, riteDist, rW);
if( f>bestFitness ) { // Take split with largest fitness
bestSplit = i;
bestFitness = f;
}
}
return bestSplit == -1
? Split.impossible(Utils.maxIndex(dist, _random))
: Split.split(col, bestSplit, bestFitness);
}

@Override protected Split eqSplit(int colIndex, Data d, int[] dist, int distWeight, Random _) {
int[] inclDist = new int[d.classes()];
int[] exclDist = dist.clone();
// we are not a single class, calculate the best split for the column
int bestSplit = -1;
double bestFitness = 0.0; // Fitness to maximize
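// each candidate i places rows with the i-th column value on the inclusion side
// and all remaining rows on the exclusion side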
for( int i = 0; i < _columnDists[colIndex].length-1; ++i ) {
// build the inclusion/exclusion class counts for the i-th column value
int sumt = 0;
for( int j = 0; j < inclDist.length; ++j ) {
int t = _columnDists[colIndex][i][j];
sumt += t;
inclDist[j] = t;
exclDist[j] = dist[j] - t;
}
int inclW = sumt;
int exclW = distWeight - inclW;
// now make sure we have something to split
if( inclW == 0 || exclW == 0 ) continue;
// apply the documented 1/4 factor so exclusion fitness is on the same scale as ltSplit
double f = 0.25 * ((double)inclW / distWeight) * ((double)exclW / distWeight) *
twoing(inclDist, inclW, exclDist, exclW);
if( f>bestFitness ) { // Take split with largest fitness
bestSplit = i;
bestFitness = f;
}
}
return bestSplit == -1
? Split.impossible(Utils.maxIndex(dist, _random))
: Split.exclusion(colIndex, bestSplit, bestFitness);
}

@Override
protected Split ltSplit(int colIndex, Data d, float[] dist, float distWeight, Random rand) {
return null; //not called for classification
}

@Override
protected Split eqSplit(int colIndex, Data d, float[] dist, float distWeight, Random rand) {
return null; //not called for classification
}
}
43 changes: 43 additions & 0 deletions src/test/java/hex/singlenoderf/TwoingStatisticTest.java
@@ -0,0 +1,43 @@
package hex.singlenoderf;

import org.junit.Test;
import org.junit.Assert;

public class TwoingStatisticTest {

@Test
public void twoingTest() {
// basic test cases to check twoing computation
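// left {4,0,0,1}, right {0,3,2,0}: |4/5-0| + |0-3/5| + |0-2/5| + |1/5-0| = 2.0, squared = 4.0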
int[] dd_l = {4,0,0,1};
int[] dd_r = {0,3,2,0};
double result = twoing(dd_l, 5, dd_r, 5);
Assert.assertEquals(4.0, result, 1e-10);

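// left holds all of classes 0-2, right holds all of class 3: differences sum to 1.0 + 1.0 = 2.0, squared = 4.0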
dd_l[0] = 4; dd_l[1] = 3; dd_l[2] = 2; dd_l[3] = 0;
dd_r[0] = 0; dd_r[1] = 0; dd_r[2] = 0; dd_r[3] = 1;
result = twoing(dd_l, 9, dd_r, 1);
Assert.assertEquals(4.0, result, 1e-10);

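// |4/8-0| + |3/8-0| + |1/8-1/2| + |0-1/2| = 0.5 + 0.375 + 0.375 + 0.5 = 1.75, squared = 3.0625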
dd_l[0] = 4; dd_l[1] = 3; dd_l[2] = 1; dd_l[3] = 0;
dd_r[0] = 0; dd_r[1] = 0; dd_r[2] = 1; dd_r[3] = 1;
result = twoing(dd_l, 8, dd_r, 2);
Assert.assertEquals(3.0625, result, 1e-10);

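// identical left and right class distributions yield zero twoing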
dd_l[0] = 999; dd_l[1] = 1000005; dd_l[2] = 3009; dd_l[3] = 1;
dd_r[0] = 999; dd_r[1] = 1000005; dd_r[2] = 3009; dd_r[3] = 1;
result = twoing(dd_l, 999+1000005+3009+1, dd_r, 999+1000005+3009+1);
Assert.assertEquals(0.0, result, 1e-10);
}

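// local copy of the twoing computation, duplicated here because TwoingStatistic#twoing is private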
private double twoing(int[] dd_l, int sum_l, int[] dd_r, int sum_r ) {
double result = 0.0;
double sd_l = (double)sum_l;
double sd_r = (double)sum_r;
for (int i = 0; i < dd_l.length; i++) {
double tmp = Math.abs(((double)dd_l[i])/sd_l - ((double)dd_r[i])/sd_r);
result = result + tmp;
}
result = result * result;
return result;
}
}
