Skip to content

Commit

Permalink
Update (ignored) JUnit for 1M Airline to build 10 trees on all predic…
Browse files Browse the repository at this point in the history
…tors (num+cat) to get AUC 0.7426.
  • Loading branch information
arnocandel committed May 19, 2015
1 parent 8f7f697 commit cce7bed
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions src/test/java/hex/drf/DRFTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -285,21 +285,25 @@ public static class repro {
// Load data, hack frames
tfr = parseFrame(Key.make("air.hex"), "/users/arno/sz_bench_data/train-1m.csv");
test = parseFrame(Key.make("airt.hex"), "/users/arno/sz_bench_data/test.csv");
for (int i : new int[]{4,5,6}) {
tfr.vecs()[i] = tfr.vecs()[i].toEnum();
test.vecs()[i] = test.vecs()[i].toEnum();
}

DRF parms = new DRF();
parms.source = tfr;
parms.validation = test;
//parms.ignored_cols_by_name = new int[]{4,5,6};
parms.ignored_cols_by_name = new int[]{0,1,2,3,4,5,7};
// parms.ignored_cols_by_name = new int[]{4,5,6};
// parms.ignored_cols_by_name = new int[]{0,1,2,3,4,5,7};
parms.response = tfr.lastVec();
parms.nbins = 20;
parms.ntrees = 1;
parms.max_depth = 5;
parms.mtries = 1;
parms.sample_rate = 1;
parms.min_rows = 1;
parms.ntrees = 10;
parms.max_depth = 20;
parms.mtries = -1;
parms.sample_rate = 0.667f;
parms.min_rows = 10;
parms.classification = true;
parms.seed = 1;
parms.seed = 12;

DRFModel drf = parms.fork().get();
Frame pred = drf.score(test);
Expand Down

0 comments on commit cce7bed

Please sign in to comment.