Skip to content

Commit

Permalink
Merge branch 'master' of github.com:0xdata/h2o
Browse files Browse the repository at this point in the history
  • Loading branch information
tongxin committed Dec 5, 2013
2 parents a0a06aa + e195262 commit 099fa24
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
20 changes: 15 additions & 5 deletions src/main/java/hex/drf/DRF.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import water.H2O.H2OCountedCompleter;
import water.api.DRFProgressPage;
import water.api.DocGen;
import water.api.Request.API;
import water.fvec.*;
import water.util.*;
import water.util.Log.Tag.Sys;
Expand All @@ -37,6 +38,9 @@ public class DRF extends SharedTreeModelBuilder {
@API(help = "Compute variable importance (true/false).", filter = Default.class )
boolean importance = false; // compute variable importance

@API(help = "Computed number of split features")
protected int _mtry;

/** DRF model holding serialized tree and implementing logic for scoring a row */
public static class DRFModel extends DTree.TreeModel {
static final int API_WEAVER = 1; // This file has auto-gen'd doc & json fields
Expand Down Expand Up @@ -123,11 +127,17 @@ public static String link(Key k, String content) {
return DRFProgressPage.redirect(this, self(), dest());
}

@Override protected void buildModel( final Frame fr, String names[], String domains[][], final Key outputKey, final Key dataKey, final Key testKey, final Timer t_build ) {
final int cmtries = (mtries==-1) ? // classification: mtry=sqrt(_ncols), regression: mtry=_ncols/3
@Override protected void init() {
super.init();
// Compute the _mtry field and validate the mtry / sample_rate parameters
_mtry = (mtries==-1) ? // classification: mtry=sqrt(_ncols), regression: mtry=_ncols/3
( classification ? Math.max((int)Math.sqrt(_ncols),1) : Math.max(_ncols/3,1)) : mtries;
assert 1 <= cmtries && cmtries <= _ncols : "Too large mtries="+cmtries+", ncols="+_ncols;
assert 0.0 < sample_rate && sample_rate <= 1.0;
if (!(1 <= _mtry && _mtry <= _ncols)) throw new IllegalArgumentException("Computed mtry should be in interval <1,#cols> but it is " + _mtry);
if (!(0.0 < sample_rate && sample_rate <= 1.0)) throw new IllegalArgumentException("Sample rate should be interval (0,1> but it is " + sample_rate);
}

@Override protected void buildModel( final Frame fr, String names[], String domains[][], final Key outputKey, final Key dataKey, final Key testKey, final Timer t_build ) {

DRFModel model = new DRFModel(outputKey,dataKey,testKey,names,domains,ntrees, max_depth, min_rows, nbins, mtries, sample_rate, seed);
DKV.put(outputKey, model);

Expand All @@ -147,7 +157,7 @@ public static String link(Key k, String content) {
// TODO: parallelize more? Building more than k trees at a time requires care with temporary data.
// Idea: launch several DRF jobs at once.
Timer t_kTrees = new Timer();
ktrees = buildNextKTrees(fr,cmtries,sample_rate,rand);
ktrees = buildNextKTrees(fr,_mtry,sample_rate,rand);
Log.info(Sys.DRF__, "Tree "+(tid+1)+"x"+_nclass+" produced in "+t_kTrees);
if( cancelled() ) break; // If canceled during building, do not bulkscore

Expand Down
4 changes: 2 additions & 2 deletions src/main/java/hex/gbm/SharedTreeModelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ public abstract class SharedTreeModelBuilder extends ValidatedJob {
}
_nclass = response.isEnum() ? (char)(response.domain().length) : 1;
_errs = new double[0]; // No trees yet
if (_nclass < 1)
throw new IllegalArgumentException("Only one level in response column!");
if (classification && _nclass <= 1)
throw new IllegalArgumentException("Constant response column!");
if (_nclass > MAX_SUPPORTED_LEVELS)
throw new IllegalArgumentException("Too many levels in response column!");
}
Expand Down

0 comments on commit 099fa24

Please sign in to comment.