From 7848dd1b2485d3593db363fb7d967b456b31aae1 Mon Sep 17 00:00:00 2001 From: Arno Candel Date: Mon, 18 May 2015 20:24:07 -0700 Subject: [PATCH] Update DRFTest that changed with previous fix HEXDEV-319. --- src/test/java/hex/drf/DRFTest.java | 50 ++++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 3 deletions(-) diff --git a/src/test/java/hex/drf/DRFTest.java b/src/test/java/hex/drf/DRFTest.java index 58b4bdc7ec..2e11920e29 100644 --- a/src/test/java/hex/drf/DRFTest.java +++ b/src/test/java/hex/drf/DRFTest.java @@ -7,6 +7,7 @@ import static org.junit.Assert.assertEquals; import water.*; +import water.api.AUC; import water.api.DRFModelView; import water.fvec.Frame; import water.fvec.RebalanceDataSet; @@ -100,7 +101,7 @@ abstract static class PrepData { abstract int prep(Frame fr); } } catch( IllegalArgumentException iae ) { /*pass*/ } } - @Test public void testBadData() throws Throwable { + @Ignore @Test public void testBadData() throws Throwable { basicDRFTestOOBE( "./smalldata/test/drf_infinitys.csv","infinitys.hex", new PrepData() { @Override int prep(Frame fr) { return fr.find("DateofBirth"); } }, @@ -161,8 +162,8 @@ public void testCreditProstate1() throws Throwable { return fr.find("IsDepDelayed"); } }, 50, - a( a(13987, 6900), - a( 6147,16944)), + a( a(13941, 6946), + a( 5885,17206)), s("NO", "YES")); } @@ -272,4 +273,47 @@ public void basicDRF(String fnametrain, String hexnametrain, String fnametest, S assertEquals(mses[i], mses[0], 1e-15); } } + + public static class repro { + @Ignore + @Test public void testAirline() throws InterruptedException { + Frame tfr=null; + Frame test=null; + + Scope.enter(); + try { + // Load data, hack frames + tfr = parseFrame(Key.make("air.hex"), "/users/arno/sz_bench_data/train-1m.csv"); + test = parseFrame(Key.make("airt.hex"), "/users/arno/sz_bench_data/test.csv"); + + DRF parms = new DRF(); + parms.source = tfr; + parms.validation = test; + //parms.ignored_cols_by_name = new int[]{4,5,6}; + parms.ignored_cols_by_name = new int[]{1,2,3,4,5,7}; + parms.response = tfr.lastVec(); + parms.nbins = 20; + parms.ntrees = 1; + parms.max_depth = 5; + parms.mtries = 1; + parms.sample_rate = 1; + parms.min_rows = 1; + parms.classification = true; + parms.seed = 1; + + DRFModel drf = parms.fork().get(); + Frame pred = drf.score(test); + AUC auc = new AUC(); + auc.vactual = test.lastVec(); + auc.vpredict = pred.lastVec(); + auc.invoke(); + Log.info("Test set AUC: " + auc.data().AUC); + drf.delete(); + } finally{ + if (tfr != null) tfr.delete(); + if (test != null) test.delete(); + } + Scope.exit(); + } + } }