Skip to content

Commit

Permalink
Add DRF reproducibility test, same as for HEXDEV-194 in h2o-dev.
Browse files Browse the repository at this point in the history
  • Loading branch information
arnocandel committed Mar 17, 2015
1 parent 44b07e7 commit a99ae66
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 1 deletion.
55 changes: 55 additions & 0 deletions src/test/java/hex/drf/DRFTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@

import hex.drf.DRF.DRFModel;

import hex.gbm.GBM;
import org.junit.*;

import static org.junit.Assert.assertEquals;
import water.*;
import water.api.DRFModelView;
import water.fvec.Frame;
import water.fvec.RebalanceDataSet;
import water.fvec.Vec;
import water.util.Log;

public class DRFTest extends TestUtil {

Expand Down Expand Up @@ -220,4 +224,55 @@ public void basicDRF(String fnametrain, String hexnametrain, String fnametest, S
if( pred != null ) pred.delete();
}
}

/**
 * Reproducibility test for DRF: trains the same single-tree model several
 * times on a rebalanced copy of the dataset, with identical parameters and a
 * fixed seed, and checks that every run reports the same MSE.
 */
public static class repro {
  @BeforeClass public static void stall() { stall_till_cloudsize(1); }

  /**
   * Builds {@code N} DRF models with identical settings and asserts that the
   * MSE of every model matches the MSE of the first one (tolerance 1e-15).
   */
  @Test
  public void run() {
    Frame tfr = null;
    final int N = 5;
    double[] mses = new double[N];

    Scope.enter();
    try {
      // Load data, hack frames
      tfr = parseFrame(Key.make("air.hex"), "./smalldata/covtype/covtype.20k.data");

      // rebalance to 256 chunks
      Key dest = Key.make("df.rebalanced.hex");
      RebalanceDataSet rb = new RebalanceDataSet(tfr, dest, 256);
      H2O.submitTask(rb);
      rb.join();
      tfr.delete();
      // Drop the stale reference so the cleanup below cannot delete the
      // original frame a second time if fetching the rebalanced frame fails.
      tfr = null;
      tfr = DKV.get(dest).get();

      for (int i = 0; i < N; ++i) {
        DRF parms = new DRF();
        parms.source = tfr;
        parms.response = tfr.lastVec();
        parms.nbins = 1000;
        parms.ntrees = 1;
        parms.max_depth = 8;
        parms.mtries = -1;
        parms.min_rows = 10;
        parms.seed = 1234;

        // Every iteration uses the same parameters and seed, so all models
        // are expected to produce the same MSE as the first one.
        DRFModel drf = parms.fork().get();
        mses[i] = drf.mse();

        drf.delete();
      }
    } finally {
      // Always release the frame and leave the scope, even when parsing,
      // rebalancing or training throws; otherwise the scope entered above
      // would stay open after a failure.
      try {
        if (tfr != null) tfr.delete();
      } finally {
        Scope.exit();
      }
    }

    for (int i = 0; i < mses.length; ++i) {
      Log.info("trial: " + i + " -> mse: " + mses[i]);
    }
    for (int i = 0; i < mses.length; ++i) {
      // JUnit argument order: expected (first run) before actual (run i).
      assertEquals(mses[0], mses[i], 1e-15);
    }
  }
}
}
5 changes: 4 additions & 1 deletion src/test/java/hex/gbm/GBMTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;

import water.*;
Expand Down Expand Up @@ -357,7 +358,9 @@ public GBMModel basicGBM(String fname, String hexname, PrepData prep, boolean va


public static class repro {
@Test public void testChunkReprodubility() {
@BeforeClass public static void stall() { stall_till_cloudsize(1); }
@Test
public void run() {
Frame tfr=null;
final int N = 5;
double[] mses = new double[N];
Expand Down

0 comments on commit a99ae66

Please sign in to comment.