Skip to content

Commit

Permalink
Fix distributed DRF/GBM bug with min/max on bins.
Browse files Browse the repository at this point in the history
Also GBM marks missing-response rows, so hot loop can ignore them.
  • Loading branch information
cliffclick committed Jan 1, 2014
1 parent 37e3801 commit eb2598e
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 11 deletions.
2 changes: 1 addition & 1 deletion prj.el
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
'(jde-run-option-debug nil)
'(jde-run-option-vm-args (quote ("-XX:+PrintGC")))
'(jde-compile-option-directory "./target/classes")
'(jde-run-option-application-args (quote ("-beta" "-mainClass" "org.junit.runner.JUnitCore" "hex.drf.DRFModelAdaptTest" "hex.gbm.GBMTestX")))
'(jde-run-option-application-args (quote ("-beta" "-mainClass" "org.junit.runner.JUnitCore" "hex.gbm.GBMTest")))
'(jde-debugger (quote ("JDEbug")))
'(jde-compile-option-source (quote ("1.6")))
'(jde-compile-option-classpath (quote ("./target/classes" "./lib/javassist.jar" "./lib/hadoop/cdh4/hadoop-common.jar" "./lib/hadoop/cdh4/hadoop-auth.jar" "./lib/hadoop/cdh4/slf4j-api-1.6.1.jar" "./lib/hadoop/cdh4/slf4j-nop-1.6.1.jar" "./lib/hadoop/cdh4/hadoop-hdfs.jar" "./lib/hadoop/cdh4/protobuf-java-2.4.0a.jar" "./lib/apache/commons-codec-1.4.jar" "./lib/apache/commons-configuration-1.6.jar" "./lib/apache/commons-lang-2.4.jar" "./lib/apache/commons-logging-1.1.1.jar" "./lib/apache/httpclient-4.1.1.jar" "./lib/apache/httpcore-4.1.jar" "./lib/junit/junit-4.11.jar" "./lib/apache/guava-12.0.1.jar" "./lib/gson/gson-2.2.2.jar" "./lib/poi/poi-3.8-20120326.jar" "./lib/poi/poi-ooxml-3.8-20120326.jar" "./lib/poi/poi-ooxml-schemas-3.8-20120326.jar" "./lib/poi/dom4j-1.6.1.jar" "./lib/Jama/Jama.jar" "./lib/s3/aws-java-sdk-1.3.27.jar" "./lib/log4j/log4j-1.2.15.jar")))
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/hex/gbm/DHistogram.java
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ void add( TDH dsh ) {
assert (_bins == null && dsh._bins == null) || (_bins != null && dsh._bins != null);
if( _bins == null ) return;
Utils.add(_bins,dsh._bins);
if( _min2 < _min2 ) _min2 = dsh._min2 ;
if( _maxIn > _maxIn ) _maxIn = dsh._maxIn;
if( _min2 > dsh._min2 ) _min2 = dsh._min2 ;
if( _maxIn < dsh._maxIn ) _maxIn = dsh._maxIn;
add0(dsh);
}

Expand Down
20 changes: 18 additions & 2 deletions src/main/java/hex/gbm/GBM.java
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ public static String link(Key k, String content) {
GBMModel model = new GBMModel(outputKey, dataKey, testKey, names, domains, ntrees, max_depth, min_rows, nbins, learn_rate);
DKV.put(outputKey, model);

// Tag out rows missing the response column
new ExcludeNAResponse().doAll(fr);

// Build trees until we hit the limit
int tid;
DTree[] ktrees = null; // Trees
Expand Down Expand Up @@ -142,6 +145,18 @@ public static String link(Key k, String content) {
cleanUp(fr,t_build); // Shared cleanup
}

// --------------------------------------------------------------------------
// Tag out rows missing the response column
class ExcludeNAResponse extends MRTask2<ExcludeNAResponse> {
@Override public void map( Chunk chks[] ) {
Chunk ys = chk_resp(chks);
for( int row=0; row<ys._len; row++ )
if( ys.isNA0(row) )
for( int t=0; t<_nclass; t++ )
chk_nids(chks,t).set0(row,-1);
}
}

// --------------------------------------------------------------------------
// Compute Prediction from prior tree results.
// Classification: Probability Distribution of loglikelyhoods
Expand Down Expand Up @@ -317,6 +332,7 @@ private DTree[] buildNextKTrees(Frame fr) {
final Chunk ct = chk_tree(chks,k);
for( int row=0; row<nids._len; row++ ) {
int nid = (int)nids.at80(row);
if( nid < 0 ) continue;
ct.set0(row, (float)(ct.at0(row) + ((LeafNode)tree.node(nid))._pred));
nids.set0(row,0);
}
Expand Down Expand Up @@ -364,7 +380,7 @@ private class GammaPass extends MRTask2<GammaPass> {
if( tree.root() instanceof LeafNode ) continue;
for( int row=0; row<nids._len; row++ ) { // For all rows
int nid = (int)nids.at80(row); // Get Node to decide from
int oldnid = nid;
if( nid < 0 ) continue; // Missing response
if( tree.node(nid) instanceof UndecidedNode ) // If we bottomed out the tree
nid = tree.node(nid)._pid; // Then take parent's decision
DecidedNode dn = tree.decided(nid); // Must have a decision point
Expand All @@ -378,7 +394,7 @@ private class GammaPass extends MRTask2<GammaPass> {
// sum-of-residuals (and sum/abs/mult residuals) for all rows in the
// leaf, and get our prediction from that.
nids.set0(row,leafnid);
if( ress.isNA0(row) ) continue;
assert !ress.isNA0(row);
double res = ress.at0(row);
double ares = Math.abs(res);
gs[leafnid-leaf] += _nclass > 1 ? ares*(1-ares) : 1;
Expand Down
8 changes: 7 additions & 1 deletion src/main/java/hex/gbm/SharedTreeModelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,9 @@ public ScoreBuildHistogram(int k, DTree tree, int leaf, DHistogram hcs[][], bool
int nnids[] = new int[nids._len];
if( _leaf > 0) // Prior pass exists?
score_decide(chks,nids,wrks,tree,nnids);
else // Just flag all the NA rows
for( int row=0; row<nids._len; row++ )
if( isDecidedRow((int)nids.at0(row)) ) nnids[row] = -1;

// Pass 2: accumulate all rows, cols into histograms
if( _subset ) accum_subset(chks,nids,wrks,nnids);
Expand Down Expand Up @@ -295,7 +298,10 @@ else if( hs2[j] != null )
private void score_decide(Chunk chks[], Chunk nids, Chunk wrks, Chunk tree, int nnids[]) {
for( int row=0; row<nids._len; row++ ) { // Over all rows
int nid = (int)nids.at80(row); // Get Node to decide from
if( isDecidedRow(nid)) continue; // already done
if( isDecidedRow(nid)) { // already done
nnids[row] = nid;
continue;
}
// Score row against current decisions & assign new split
boolean oob = isOOBRow(nid);
if( oob ) nid = oob2Nid(nid); // sampled away - we track the position in the tree
Expand Down
10 changes: 5 additions & 5 deletions src/test/java/hex/drf/DRFTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ abstract static class PrepData { abstract int prep(Frame fr); }
1,
a( a(6, 0, 0),
a(0, 7, 0),
a(0, 2, 11)),
a(0, 3, 10)),
s("Iris-setosa","Iris-versicolor","Iris-virginica") );
}

Expand All @@ -44,7 +44,7 @@ abstract static class PrepData { abstract int prep(Frame fr); }
50,
a( a(30, 0, 0),
a(0, 31, 3),
a(0, 3, 33)),
a(0, 4, 32)),
s("Iris-setosa","Iris-versicolor","Iris-virginica") );
}

Expand All @@ -56,7 +56,7 @@ abstract static class PrepData { abstract int prep(Frame fr); }
new PrepData() { @Override int prep(Frame fr) { UKV.remove(fr.remove("name")._key); return fr.find("cylinders"); } },
1,
a( a(0, 2, 0, 0, 0),
a(0, 52, 0, 3, 1),
a(1, 51, 0, 3, 1),
a(0, 0, 0, 0, 0),
a(0, 2, 0,16, 2),
a(0, 0, 0, 0,33)),
Expand All @@ -72,8 +72,8 @@ abstract static class PrepData { abstract int prep(Frame fr); }
a( a(0, 4, 0, 0, 0),
a(0, 207, 0, 0, 0),
a(0, 2, 0, 1, 0),
a(0, 4, 0, 80, 0),
a(0, 0, 0, 4, 104)),
a(0, 4, 0, 79, 1),
a(0, 0, 1, 3, 104)),
s("3", "4", "5", "6", "8"));
}

Expand Down

0 comments on commit eb2598e

Please sign in to comment.