|
7 | 7 |
|
8 | 8 | import static org.junit.Assert.assertEquals;
|
9 | 9 | import water.*;
|
| 10 | +import water.api.AUC; |
10 | 11 | import water.api.DRFModelView;
|
11 | 12 | import water.fvec.Frame;
|
12 | 13 | import water.fvec.RebalanceDataSet;
|
@@ -100,7 +101,7 @@ abstract static class PrepData { abstract int prep(Frame fr); }
|
100 | 101 | } catch( IllegalArgumentException iae ) { /*pass*/ }
|
101 | 102 | }
|
102 | 103 |
|
103 |
| - @Test public void testBadData() throws Throwable { |
| 104 | + @Ignore @Test public void testBadData() throws Throwable { |
104 | 105 | basicDRFTestOOBE(
|
105 | 106 | "./smalldata/test/drf_infinitys.csv","infinitys.hex",
|
106 | 107 | new PrepData() { @Override int prep(Frame fr) { return fr.find("DateofBirth"); } },
|
@@ -161,8 +162,8 @@ public void testCreditProstate1() throws Throwable {
|
161 | 162 | return fr.find("IsDepDelayed"); }
|
162 | 163 | },
|
163 | 164 | 50,
|
164 |
| - a( a(13987, 6900), |
165 |
| - a( 6147,16944)), |
| 165 | + a( a(13941, 6946), |
| 166 | + a( 5885,17206)), |
166 | 167 | s("NO", "YES"));
|
167 | 168 | }
|
168 | 169 |
|
@@ -272,4 +273,47 @@ public void basicDRF(String fnametrain, String hexnametrain, String fnametest, S
|
272 | 273 | assertEquals(mses[i], mses[0], 1e-15);
|
273 | 274 | }
|
274 | 275 | }
|
| 276 | + |
| 277 | + public static class repro { |
| 278 | + @Ignore |
| 279 | + @Test public void testAirline() throws InterruptedException { |
| 280 | + Frame tfr=null; |
| 281 | + Frame test=null; |
| 282 | + |
| 283 | + Scope.enter(); |
| 284 | + try { |
| 285 | + // Load data, hack frames |
| 286 | + tfr = parseFrame(Key.make("air.hex"), "/users/arno/sz_bench_data/train-1m.csv"); |
| 287 | + test = parseFrame(Key.make("airt.hex"), "/users/arno/sz_bench_data/test.csv"); |
| 288 | + |
| 289 | + DRF parms = new DRF(); |
| 290 | + parms.source = tfr; |
| 291 | + parms.validation = test; |
| 292 | + //parms.ignored_cols_by_name = new int[]{4,5,6}; |
| 293 | + parms.ignored_cols_by_name = new int[]{1,2,3,4,5,7}; |
| 294 | + parms.response = tfr.lastVec(); |
| 295 | + parms.nbins = 20; |
| 296 | + parms.ntrees = 1; |
| 297 | + parms.max_depth = 5; |
| 298 | + parms.mtries = 1; |
| 299 | + parms.sample_rate = 1; |
| 300 | + parms.min_rows = 1; |
| 301 | + parms.classification = true; |
| 302 | + parms.seed = 1; |
| 303 | + |
| 304 | + DRFModel drf = parms.fork().get(); |
| 305 | + Frame pred = drf.score(test); |
| 306 | + AUC auc = new AUC(); |
| 307 | + auc.vactual = test.lastVec(); |
| 308 | + auc.vpredict = pred.lastVec(); |
| 309 | + auc.invoke(); |
| 310 | + Log.info("Test set AUC: " + auc.data().AUC); |
| 311 | + drf.delete(); |
| 312 | + } finally{ |
| 313 | + if (tfr != null) tfr.delete(); |
| 314 | + if (test != null) test.delete(); |
| 315 | + } |
| 316 | + Scope.exit(); |
| 317 | + } |
| 318 | + } |
275 | 319 | }
|
0 commit comments