forked from h2oai/h2o-2
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
If the enum domains of compatible columns do not match, a mapping is computed and a new vector that transforms the values is used.
- Loading branch information
1 parent
1f94f78
commit 25a501f
Showing
7 changed files
with
200 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
A,b,X | ||
A,a,X | ||
D,b,X |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
A,a,X | ||
A,b,Y | ||
A,c,Y | ||
A,d,Y | ||
B,a,X | ||
B,b,X | ||
B,c,Y | ||
A,d,Y | ||
C,a,X | ||
C,b,X | ||
C,c,X | ||
A,d,Y | ||
D,a,X |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
package water.fvec; | ||
|
||
import water.*; | ||
|
||
/** | ||
* Dummy vector transforming values of given vector according to given domain mapping. | ||
*/ | ||
public class TransfVec extends Vec { | ||
|
||
final Key _masterVecKey; | ||
final int[] _domMap; | ||
|
||
transient Vec _masterVec; | ||
|
||
public TransfVec(Key masterVecKey, int[] domMap, Key key, long[] espc) { | ||
super(key, espc); | ||
_masterVecKey = masterVecKey; | ||
_domMap = domMap; | ||
} | ||
|
||
private Vec masterVec() { | ||
if (_masterVec==null) _masterVec = DKV.get(_masterVecKey).get(); | ||
return _masterVec; | ||
} | ||
|
||
@Override public Chunk elem2BV(int cidx) { | ||
Chunk c = masterVec().elem2BV(cidx); | ||
return new TransfChunk(c, _domMap); | ||
} | ||
|
||
static class TransfChunk extends Chunk { | ||
Chunk _c; | ||
int[] _domMap; | ||
|
||
public TransfChunk(Chunk c, int[] domMap) { super(); _c = c; _domMap = domMap; } | ||
|
||
@Override protected double atd_impl(int idx) { | ||
double val = _c.atd_impl(idx); | ||
return _domMap[(int)val]; | ||
} | ||
|
||
@Override protected long at8_impl(int idx) { | ||
long val = _c.at8_impl(idx); | ||
return _domMap[(int)val]; | ||
} | ||
|
||
@Override protected boolean isNA_impl(int idx) { return _c.isNA_impl(idx); } | ||
|
||
@Override boolean set_impl(int idx, long l) { return false; } | ||
@Override boolean set_impl(int idx, double d) { return false; } | ||
@Override boolean set_impl(int idx, float f) { return false; } | ||
@Override boolean setNA_impl(int idx) { return false; } | ||
|
||
@Override boolean hasFloat() { return _c.hasFloat(); } | ||
|
||
@Override NewChunk inflate_impl(NewChunk nc) { throw new UnsupportedOperationException(); } | ||
@Override public AutoBuffer write(AutoBuffer bb) { throw new UnsupportedOperationException(); } | ||
@Override public Chunk read(AutoBuffer bb) { throw new UnsupportedOperationException(); } | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
package hex.gbm; | ||
|
||
|
||
import java.io.File; | ||
|
||
import org.junit.BeforeClass; | ||
import org.junit.Test; | ||
|
||
import water.*; | ||
import water.fvec.*; | ||
|
||
public class GBMDomainTest extends TestUtil { | ||
|
||
private abstract class PrepData { abstract Vec prep(Frame fr); } | ||
|
||
@BeforeClass public static void stall() { stall_till_cloudsize(1); } | ||
|
||
@Test public void testModelAdapt() { | ||
runAndScoreGBM( | ||
"./smalldata/test/classifier/coldom_train.csv", | ||
"./smalldata/test/classifier/coldom_test.csv", | ||
new PrepData() { @Override Vec prep(Frame fr) { return fr.vecs()[fr.numCols()-1]; } }); | ||
} | ||
|
||
// Adapt a trained model to a test dataset with different enums | ||
void runAndScoreGBM(String train, String test, PrepData prepData) { | ||
File file1 = TestUtil.find_test_file(train); | ||
Key fkey1 = NFSFileVec.make(file1); | ||
Key dest1 = Key.make("train.hex"); | ||
File file2 = TestUtil.find_test_file(test); | ||
Key fkey2 = NFSFileVec.make(file2); | ||
Key dest2 = Key.make("test.hex"); | ||
GBM gbm = null; | ||
Frame preds = null; | ||
try { | ||
gbm = new GBM(); | ||
gbm.source = ParseDataset2.parse(dest1,new Key[]{fkey1}); | ||
gbm.response = prepData.prep(gbm.source); | ||
gbm.ntrees = 2; | ||
gbm.max_depth = 3; | ||
gbm.learn_rate = 0.2f; | ||
gbm.min_rows = 10; | ||
gbm.nbins = 1024; | ||
gbm.cols = new int[] {0,1,2}; | ||
gbm.run(); | ||
|
||
// The test data set has a few more enums than the train | ||
Frame ftest = ParseDataset2.parse(dest2,new Key[]{fkey2}); | ||
preds = gbm.score(ftest); | ||
// Delete test frame | ||
ftest.remove(); | ||
|
||
System.err.println(preds); | ||
|
||
} catch (Throwable t) { | ||
t.printStackTrace(); | ||
} finally { | ||
UKV.remove(fkey1); | ||
UKV.remove(dest1); // Remove original hex frame key | ||
UKV.remove(fkey2); | ||
UKV.remove(dest2); | ||
if( gbm != null ) { | ||
UKV.remove(gbm.dest()); // Remove the model | ||
UKV.remove(gbm.response._key); | ||
gbm.remove(); // Remove GBM Job | ||
if( preds != null ) preds.remove(); | ||
} | ||
} | ||
} | ||
|
||
} |