Skip to content

Commit

Permalink
HEXDEV-194: Rename reproducibility tests for DRF/GBM.
Browse files Browse the repository at this point in the history
  • Loading branch information
arnocandel committed Mar 17, 2015
1 parent a99ae66 commit d70a4b8
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 129 deletions.
90 changes: 43 additions & 47 deletions src/test/java/hex/drf/DRFTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -225,54 +225,50 @@ public void basicDRF(String fnametrain, String hexnametrain, String fnametest, S
}
}

public static class repro {
@BeforeClass public static void stall() { stall_till_cloudsize(1); }
@Test
public void run() {
Frame tfr=null;
final int N = 5;
double[] mses = new double[N];

Scope.enter();
try {
// Load data, hack frames
tfr = parseFrame(Key.make("air.hex"), "./smalldata/covtype/covtype.20k.data");

// rebalance to 256 chunks
Key dest = Key.make("df.rebalanced.hex");
RebalanceDataSet rb = new RebalanceDataSet(tfr, dest, 256);
H2O.submitTask(rb);
rb.join();
tfr.delete();
tfr = DKV.get(dest).get();

for (int i=0; i<N; ++i) {
DRF parms = new DRF();
parms.source = tfr;
parms.response = tfr.lastVec();
parms.nbins = 1000;
parms.ntrees = 1;
parms.max_depth = 8;
parms.mtries = -1;
parms.min_rows = 10;
parms.seed = 1234;

// Build a first model; all remaining models should be equal
DRFModel drf = parms.fork().get();
mses[i] = drf.mse();

drf.delete();
}
} finally{
if (tfr != null) tfr.delete();
}
Scope.exit();
for (int i=0; i<mses.length; ++i) {
Log.info("trial: " + i + " -> mse: " + mses[i]);
}
for (int i=0; i<mses.length; ++i) {
assertEquals(mses[i], mses[0], 1e-15);
/**
 * Verifies that DRF training is reproducible for a fixed seed.
 *
 * Parses the covtype 20k data set, rebalances it to 256 chunks, then trains
 * N single-tree DRF models with identical parameters (seed 1234) and checks
 * that every model reports the same MSE as the first one, within 1e-15.
 */
@Test public void testReproducibility() {
  Frame tfr=null;
  final int N = 5;
  double[] mses = new double[N];

  Scope.enter();
  try {
    // Load data, hack frames
    tfr = parseFrame(Key.make("air.hex"), "./smalldata/covtype/covtype.20k.data");

    // rebalance to 256 chunks
    Key dest = Key.make("df.rebalanced.hex");
    RebalanceDataSet rb = new RebalanceDataSet(tfr, dest, 256);
    H2O.submitTask(rb);
    rb.join();
    tfr.delete();
    tfr = DKV.get(dest).get();

    for (int i=0; i<N; ++i) {
      DRF parms = new DRF();
      parms.source = tfr;
      parms.response = tfr.lastVec();
      parms.nbins = 1000;
      parms.ntrees = 1;
      parms.max_depth = 8;
      parms.mtries = -1;
      parms.min_rows = 10;
      parms.seed = 1234;

      // Same parameters and seed on every iteration; all models should be equal
      DRFModel drf = parms.fork().get();
      mses[i] = drf.mse();

      drf.delete();
    }
  } finally {
    // Exit the scope even when parsing, rebalancing, training or the frame
    // cleanup throws, so a failing run does not leave the scope entered.
    try {
      if (tfr != null) tfr.delete();
    } finally {
      Scope.exit();
    }
  }
  for (int i=0; i<mses.length; ++i) {
    Log.info("trial: " + i + " -> mse: " + mses[i]);
  }
  // The first trial is the baseline; comparing it with itself adds nothing.
  for (int i=1; i<mses.length; ++i) {
    // JUnit order is (expected, actual, delta).
    assertEquals(mses[0], mses[i], 1e-15);
  }
}
}
158 changes: 76 additions & 82 deletions src/test/java/hex/gbm/GBMTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;

import water.*;
Expand Down Expand Up @@ -89,33 +88,33 @@ private static class CompErr extends MRTask2<CompErr> {
@Test public void testBasicGBM() {
// Regression tests
basicGBM("./smalldata/cars.csv","cars.hex",
new PrepData() { int prep(Frame fr ) { UKV.remove(fr.remove("name")._key); return ~fr.find("economy (mpg)"); }});
new PrepData() { int prep(Frame fr ) { UKV.remove(fr.remove("name")._key); return ~fr.find("economy (mpg)"); }});

// Classification tests
basicGBM("./smalldata/test/test_tree.csv","tree.hex",
new PrepData() { int prep(Frame fr) { return 1; }
});
new PrepData() { int prep(Frame fr) { return 1; }
});
basicGBM("./smalldata/test/test_tree_minmax.csv","tree_minmax.hex",
new PrepData() { int prep(Frame fr) { return fr.find("response"); }
});
new PrepData() { int prep(Frame fr) { return fr.find("response"); }
});
basicGBM("./smalldata/logreg/prostate.csv","prostate.hex",
new PrepData() {
int prep(Frame fr) {
assertEquals(380,fr.numRows());
// Remove patient ID vector
UKV.remove(fr.remove("ID")._key);
// Prostate: predict on CAPSULE
return fr.find("CAPSULE");
}
});
new PrepData() {
int prep(Frame fr) {
assertEquals(380,fr.numRows());
// Remove patient ID vector
UKV.remove(fr.remove("ID")._key);
// Prostate: predict on CAPSULE
return fr.find("CAPSULE");
}
});
basicGBM("./smalldata/cars.csv","cars.hex",
new PrepData() { int prep(Frame fr) { UKV.remove(fr.remove("name")._key); return fr.find("cylinders"); }
});
new PrepData() { int prep(Frame fr) { UKV.remove(fr.remove("name")._key); return fr.find("cylinders"); }
});
basicGBM("./smalldata/airlines/allyears2k_headers.zip","air.hex",
new PrepData() { int prep(Frame fr) {
for( String s : ignored_aircols ) UKV.remove(fr.remove(s)._key);
return fr.find("IsArrDelayed"); }
});
new PrepData() { int prep(Frame fr) {
for( String s : ignored_aircols ) UKV.remove(fr.remove(s)._key);
return fr.find("IsArrDelayed"); }
});
//basicGBM("../datasets/UCI/UCI-large/covtype/covtype.data","covtype.hex",
// new PrepData() {
// int prep(Frame fr) {
Expand All @@ -132,18 +131,18 @@ int prep(Frame fr) {
Scope.enter();
// Classification with Bernoulli family
basicGBM("./smalldata/logreg/prostate.csv","prostate.hex",
new PrepData() {
int prep(Frame fr) {
assertEquals(380,fr.numRows());
// Remove patient ID vector
UKV.remove(fr.remove("ID")._key);
// Change CAPSULE and RACE to categoricals
Scope.track(fr.factor(fr.find("CAPSULE"))._key);
Scope.track(fr.factor(fr.find("RACE" ))._key);
// Prostate: predict on CAPSULE
return fr.find("CAPSULE");
}
}, Family.bernoulli);
new PrepData() {
int prep(Frame fr) {
assertEquals(380,fr.numRows());
// Remove patient ID vector
UKV.remove(fr.remove("ID")._key);
// Change CAPSULE and RACE to categoricals
Scope.track(fr.factor(fr.find("CAPSULE"))._key);
Scope.track(fr.factor(fr.find("RACE" ))._key);
// Prostate: predict on CAPSULE
return fr.find("CAPSULE");
}
}, Family.bernoulli);
Scope.exit();
}

Expand Down Expand Up @@ -336,7 +335,7 @@ public GBMModel basicGBM(String fname, String hexname, PrepData prep, boolean va
// Prostate: predict on CAPSULE
return fr.find("CAPSULE");
}
};
};
double[] mseWithoutVal = basicGBM("./smalldata/logreg/prostate.csv","prostate.hex", prostatePrep, false).errs;
double[] mseWithVal = basicGBM("./smalldata/logreg/prostate.csv","prostate.hex", prostatePrep, true ).errs;
Assert.assertArrayEquals("GBM has to report same list of MSEs for run without/with validation dataset (which is equal to training data)", mseWithoutVal, mseWithVal, 0.0001);
Expand All @@ -350,61 +349,56 @@ public GBMModel basicGBM(String fname, String hexname, PrepData prep, boolean va
// Titanic: predict on survived
return fr.find("survived");
}
};
};
double[] mseWithoutVal = basicGBM("./smalldata/titanicalt.csv","titanic.hex", titanicPrep, false).errs;
double[] mseWithVal = basicGBM("./smalldata/titanicalt.csv","titanic.hex", titanicPrep, true ).errs;
Assert.assertArrayEquals("GBM has to report same list of MSEs for run without/with validation dataset (which is equal to training data)", mseWithoutVal, mseWithVal, 0.0001);
}

@Test public void testReproducibility() {
Frame tfr=null;
final int N = 5;
double[] mses = new double[N];

public static class repro {
@BeforeClass public static void stall() { stall_till_cloudsize(1); }
@Test
public void run() {
Frame tfr=null;
final int N = 5;
double[] mses = new double[N];

Scope.enter();
try {
// Load data, hack frames
tfr = parseFrame(Key.make("air.hex"), "./smalldata/covtype/covtype.20k.data");

// rebalance to 256 chunks
Key dest = Key.make("df.rebalanced.hex");
RebalanceDataSet rb = new RebalanceDataSet(tfr, dest, 256);
H2O.submitTask(rb);
rb.join();
tfr.delete();
tfr = DKV.get(dest).get();

for (int i=0; i<N; ++i) {
GBM parms = new GBM();
parms.source = tfr;
parms.response = tfr.lastVec();
parms.nbins = 1000;
parms.ntrees = 1;
parms.max_depth = 8;
parms.learn_rate = 0.1;
parms.min_rows = 10;
parms.family = Family.AUTO;

// Build a first model; all remaining models should be equal
GBMModel gbm = parms.fork().get();
mses[i] = gbm.mse();

gbm.delete();
}
} finally{
if (tfr != null) tfr.delete();
}
Scope.exit();
for (int i=0; i<mses.length; ++i) {
Log.info("trial: " + i + " -> mse: " + mses[i]);
}
for (int i=0; i<mses.length; ++i) {
assertEquals(mses[i], mses[0], 1e-15);
Scope.enter();
try {
// Load data, hack frames
tfr = parseFrame(Key.make("air.hex"), "./smalldata/covtype/covtype.20k.data");

// rebalance to 256 chunks
Key dest = Key.make("df.rebalanced.hex");
RebalanceDataSet rb = new RebalanceDataSet(tfr, dest, 256);
H2O.submitTask(rb);
rb.join();
tfr.delete();
tfr = DKV.get(dest).get();

for (int i=0; i<N; ++i) {
GBM parms = new GBM();
parms.source = tfr;
parms.response = tfr.lastVec();
parms.nbins = 1000;
parms.ntrees = 1;
parms.max_depth = 8;
parms.learn_rate = 0.1;
parms.min_rows = 10;
parms.family = Family.AUTO;

// Build a first model; all remaining models should be equal
GBMModel gbm = parms.fork().get();
mses[i] = gbm.mse();

gbm.delete();
}
} finally{
if (tfr != null) tfr.delete();
}
Scope.exit();
for (int i=0; i<mses.length; ++i) {
Log.info("trial: " + i + " -> mse: " + mses[i]);
}
for (int i=0; i<mses.length; ++i) {
assertEquals(mses[i], mses[0], 1e-15);
}
}

Expand Down

0 comments on commit d70a4b8

Please sign in to comment.