Skip to content

Commit

Permalink
Mapping of column domains.
Browse files Browse the repository at this point in the history
If the enum domains of compatible columns do not match, a mapping is
computed and a new vector that transforms the values on the fly is used.
  • Loading branch information
mmalohlava committed Oct 1, 2013
1 parent 1f94f78 commit 25a501f
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 19 deletions.
3 changes: 3 additions & 0 deletions smalldata/test/classifier/coldom_test.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
A,b,X
A,a,X
D,b,X
13 changes: 13 additions & 0 deletions smalldata/test/classifier/coldom_train.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
A,a,X
A,b,Y
A,c,Y
A,d,Y
B,a,X
B,b,X
B,c,Y
A,d,Y
C,a,X
C,b,X
C,c,X
A,d,Y
D,a,X
52 changes: 38 additions & 14 deletions src/main/java/water/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ public Frame score( Frame fr, boolean exact ) {
Vec v = fr2.anyVec().makeZero();
// If the model produces a classification/enum, copy the domain into the
// result vector.
// FIXME adapt domain according to a mapping!
v._domain = _domains[_domains.length-1];
fr2.add("predict",v);
if( nclasses() > 1 )
Expand All @@ -108,7 +109,13 @@ public Frame score( Frame fr, boolean exact ) {
}.doAll(fr2);
// Return just the output columns
int x=_names.length-1, y=fr2.numCols();
return new Frame(Arrays.copyOfRange(fr2._names,x,y),Arrays.copyOfRange(fr2.vecs(),x,y));
Frame result = new Frame(Arrays.copyOfRange(fr2._names,x,y),Arrays.copyOfRange(fr2.vecs(),x,y));
// FIXME make a generic code in Frame
int[] col2rem = new int[y-x];
for (int i=0;i<col2rem.length;i++) col2rem[i] = x+i;
fr2.remove(col2rem);
fr2.remove();
return result;
}

/** Single row scoring, on a compatible Frame. */
Expand Down Expand Up @@ -188,17 +195,7 @@ private int[][] adapt( String names[], String domains[][], boolean exact ) {
throw new IllegalArgumentException("Incompatible column: '" + _names[c] + "', expected (trained on) categorical, was passed a numeric");
throw H2O.unimpl(); // Attempt an asEnum?
} else if( !Arrays.deepEquals(ms, ds) ) {
int emap[] = map[c] = new int[ds.length];
HashMap<String,Integer> md = new HashMap<String, Integer>();
for( int i = 0; i < ms.length; i++) md.put(ms[i], i);
for( int i = 0; i < ds.length; i++) {
Integer I = md.get(ds[i]);
if( I==null && exact )
throw new IllegalArgumentException("Column "+_names[c]+" was not trained with factor '"+ds[i]+"' which appears in the data");
emap[i] = I==null ? -1 : I;
}
for( int i = 0; i < ds.length; i++)
assert emap[i]==-1 || ms[emap[i]].equals(ds[i]);
map[c] = getDomainMapping(_names[c], ms, ds, exact);
} else {
// null mapping is equal to identity mapping
}
Expand All @@ -214,13 +211,14 @@ public Frame adapt( Frame fr, boolean exact ) {
int map[][] = adapt(fr.names(),fr.domains(),exact);
int cmap[] = map[_names.length-1];
Vec vecs[] = new Vec[_names.length-1];
for( int c=0; c<cmap.length; c++ ) {
for( int c=0; c<cmap.length; c++ ) { // iterate over columns
int d = cmap[c]; // Data index
if( d == -1 ) throw H2O.unimpl(); // Swap in a new all-NA Vec
else if( map[c] == null ) { // No or identity domain map?
vecs[c] = fr.vecs()[d]; // Just use the Vec as-is
} else {
throw H2O.unimpl(); // Domain mapping needed!
// Domain mapping - creates a new vector
vecs[c] = remapVecDomain(map[c], fr.vecs()[d]);
}
}
return new Frame(Arrays.copyOf(_names,_names.length-1),vecs);
Expand All @@ -244,4 +242,30 @@ protected float[] score0( Chunk chks[], int row_in_chunk, double[] tmp, float[]
// Convenience entry point for callers that already hold one row as a plain
// array (handy for JUnit tests).  The values must be supplied in the proper
// column order.  The row is handed to score0 together with a scratch array
// sized by nclasses(), and the result is whatever Utils.maxIndex picks from
// score0's output.
public double score(double [] data){
  float[] preds = new float[nclasses()];
  return Utils.maxIndex(score0(data,preds));
}

/**
 * Computes, for one column, how the levels seen in a dataset translate to the
 * levels the model knows.
 *
 * @param colName  column name, used only in the error message
 * @param modelDom levels of the model's domain; a level's array position is its index
 * @param dom      levels of the dataset's domain for the same column
 * @param exact    if true, a dataset level missing from {@code modelDom} is an error
 * @return array of length {@code dom.length}; element {@code i} is the index of
 *         {@code dom[i]} inside {@code modelDom}, or -1 when it is absent
 *         (only possible when {@code exact} is false)
 * @throws IllegalArgumentException if {@code exact} is set and a dataset level
 *         is not present in {@code modelDom}
 */
public static int[] getDomainMapping(String colName, String[] modelDom, String[] dom, boolean exact) {
  // Index every model level by name; a repeated name keeps its last position.
  HashMap<String,Integer> levelToIdx = new HashMap<String,Integer>();
  int pos = 0;
  for( String level : modelDom ) levelToIdx.put(level, pos++);

  // Translate each dataset level, flagging the unknown ones.
  int[] result = new int[dom.length];
  for( int j = 0; j < dom.length; j++ ) {
    Integer hit = levelToIdx.get(dom[j]);
    if( hit != null ) {
      result[j] = hit;
      continue;
    }
    if( exact )
      throw new IllegalArgumentException("Column "+colName+" was not trained with factor '"+dom[j]+"' which appears in the data");
    result[j] = -1;
  }

  // Sanity check: every mapped entry must point at an equal model level.
  for( int j = 0; j < result.length; j++ )
    assert result[j]==-1 || modelDom[result[j]].equals(dom[j]);
  return result;
}

/**
 * Produces a vector presenting the values of {@code vec} translated through
 * the given domain mapping.  The work is delegated to
 * {@code vec.makeTransf(map)}, which creates a vector that transforms the
 * original one on the fly.
 *
 * @param map domain mapping to apply
 * @param vec source vector; must carry a domain (only string enums are supported)
 * @return the transforming vector
 */
public static Vec remapVecDomain(int[] map, Vec vec) {
  // Only enum (string-domain) vectors can be remapped.
  assert vec._domain != null;
  return vec.makeTransf(map);
}
}
10 changes: 5 additions & 5 deletions src/main/java/water/fvec/Frame.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ public class Frame extends Iced {

public Frame( Frame fr ) { this(fr._names.clone(), fr.vecs().clone()); _col0 = fr._col0; }
public Frame( Vec... vecs ){ this(null,vecs);}
public Frame( String[] names, Vec[] vecs ) {
_names=names;
_vecs=vecs;
public Frame( String[] names, Vec[] vecs ) {
_names=names;
_vecs=vecs;
_keys = new Key[vecs.length];
for( int i=0; i<vecs.length; i++ ) {
Key k = _keys[i] = vecs[i]._key;
Expand All @@ -33,12 +33,12 @@ public Frame( String[] names, Vec[] vecs ) {
}
}

public final Vec[] vecs() {
public final Vec[] vecs() {
if( _vecs != null ) return _vecs;
_vecs = new Vec[_keys.length];
for( int i=0; i<_keys.length; i++ )
_vecs[i] = DKV.get(_keys[i]).get();
return _vecs;
return _vecs;
}
// Force a cache-flush & reload, assuming vec mappings were altered remotely
public final Vec[] reloadVecs() { _vecs=null; return vecs(); }
Expand Down
60 changes: 60 additions & 0 deletions src/main/java/water/fvec/TransfVec.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package water.fvec;

import water.*;

/**
 * Dummy vector that exposes the values of another ("master") vector
 * translated through a domain mapping.  Values are translated when they are
 * read from a chunk.
 *
 * A mapping entry of -1 means "no corresponding level in the target domain";
 * such values are reported as missing instead of leaking -1 to readers.
 */
public class TransfVec extends Vec {

  /** Key of the master vector whose values are translated. */
  final Key _masterVecKey;
  /** Translation table: master value (used as an index) to mapped value, -1 when unmapped. */
  final int[] _domMap;

  /** Locally cached master vector, fetched from the DKV on first use. */
  transient Vec _masterVec;

  public TransfVec(Key masterVecKey, int[] domMap, Key key, long[] espc) {
    super(key, espc);
    _masterVecKey = masterVecKey;
    _domMap = domMap;
  }

  /** Returns the master vector, looking it up by key the first time it is needed. */
  private Vec masterVec() {
    if (_masterVec==null) _masterVec = DKV.get(_masterVecKey).get();
    return _masterVec;
  }

  /** Wraps the master vector's chunk {@code cidx} so that reads go through the domain map. */
  @Override public Chunk elem2BV(int cidx) {
    Chunk c = masterVec().elem2BV(cidx);
    return new TransfChunk(c, _domMap);
  }

  /** Chunk wrapper translating every value read from the wrapped chunk. */
  static class TransfChunk extends Chunk {
    Chunk _c;
    int[] _domMap;

    public TransfChunk(Chunk c, int[] domMap) { super(); _c = c; _domMap = domMap; }

    /**
     * Mapped value as a double; NaN when the master value is missing or has
     * no mapping.
     * NOTE(review): assumes NaN is the missing-value marker expected from
     * atd_impl - confirm against the other Chunk implementations.
     */
    @Override protected double atd_impl(int idx) {
      // Check for NA first: casting a NaN to int yields 0, which would
      // silently turn a missing value into whatever level 0 maps to.
      if( _c.isNA_impl(idx) ) return Double.NaN;
      double val = _c.atd_impl(idx);
      if( Double.isNaN(val) ) return Double.NaN;
      int mapped = _domMap[(int)val];
      return mapped == -1 ? Double.NaN : mapped;
    }

    /**
     * Mapped value as a long.  Returns -1 for a value without a mapping;
     * {@link #isNA_impl(int)} reports such rows as missing.
     */
    @Override protected long at8_impl(int idx) {
      long val = _c.at8_impl(idx);
      return _domMap[(int)val];
    }

    /** A row is missing if the master row is missing or its value has no mapping (-1). */
    @Override protected boolean isNA_impl(int idx) {
      if( _c.isNA_impl(idx) ) return true;
      return _domMap[(int)_c.at8_impl(idx)] == -1;
    }

    // All setters return false without touching the wrapped chunk.
    @Override boolean set_impl(int idx, long l) { return false; }
    @Override boolean set_impl(int idx, double d) { return false; }
    @Override boolean set_impl(int idx, float f) { return false; }
    @Override boolean setNA_impl(int idx) { return false; }

    @Override boolean hasFloat() { return _c.hasFloat(); }

    // Inflation and (de)serialization are not supported for this wrapper.
    @Override NewChunk inflate_impl(NewChunk nc) { throw new UnsupportedOperationException(); }
    @Override public AutoBuffer write(AutoBuffer bb) { throw new UnsupportedOperationException(); }
    @Override public Chunk read(AutoBuffer bb) { throw new UnsupportedOperationException(); }
  }
}
10 changes: 10 additions & 0 deletions src/main/java/water/fvec/Vec.java
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,16 @@ public Vec makeCon( final double d ) {
return v0;
}

// Builds a TransfVec on top of this vector: it is constructed from this
// vector's key and _espc plus the supplied domain map, under a fresh key
// obtained from this vector's group.  The new vector is put into the DKV and
// the call blocks until the pending put has completed.
// Vectors without an _espc are not supported (H2O.unimpl()).
public Vec makeTransf(final int[] domMap) {
  Futures fs = new Futures();
  if( _espc == null )
    throw H2O.unimpl();
  Vec transf = new TransfVec(this._key, domMap, group().addVecs(1)[0], _espc);
  DKV.put(transf._key, transf, fs);
  fs.blockForPending();
  return transf;
}

/** Total number of elements in the vector, read from the last entry of
 *  {@code _espc}.  Subclasses that compute their length in an alternative
 *  way, such as file-backed Vecs, override this. */
public long length() {
  final int last = _espc.length-1;
  return _espc[last];
}
Expand Down
71 changes: 71 additions & 0 deletions src/test/java/hex/gbm/GBMDomainTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package hex.gbm;


import java.io.File;

import org.junit.BeforeClass;
import org.junit.Test;

import water.*;
import water.fvec.*;

/**
 * Checks that a GBM model trained on one dataset can score a second dataset
 * whose enum columns have different domains.
 */
public class GBMDomainTest extends TestUtil {

  /** Picks the response column out of a parsed training frame. */
  private static abstract class PrepData { abstract Vec prep(Frame fr); }

  @BeforeClass public static void stall() { stall_till_cloudsize(1); }

  @Test public void testModelAdapt() {
    runAndScoreGBM(
        "./smalldata/test/classifier/coldom_train.csv",
        "./smalldata/test/classifier/coldom_test.csv",
        new PrepData() { @Override Vec prep(Frame fr) { return fr.vecs()[fr.numCols()-1]; } });
  }

  /**
   * Trains a small GBM on {@code train}, scores {@code test} with it and
   * removes every key it created.  Any failure during training or scoring
   * propagates to the caller so the test actually fails.
   *
   * @param train    path of the training CSV
   * @param test     path of the CSV to score
   * @param prepData selects the response column of the training frame
   */
  void runAndScoreGBM(String train, String test, PrepData prepData) {
    File file1 = TestUtil.find_test_file(train);
    Key fkey1 = NFSFileVec.make(file1);
    Key dest1 = Key.make("train.hex");
    File file2 = TestUtil.find_test_file(test);
    Key fkey2 = NFSFileVec.make(file2);
    Key dest2 = Key.make("test.hex");
    GBM gbm = null;
    Frame preds = null;
    try {
      gbm = new GBM();
      gbm.source = ParseDataset2.parse(dest1,new Key[]{fkey1});
      gbm.response = prepData.prep(gbm.source);
      gbm.ntrees = 2;
      gbm.max_depth = 3;
      gbm.learn_rate = 0.2f;
      gbm.min_rows = 10;
      gbm.nbins = 1024;
      gbm.cols = new int[] {0,1,2};
      gbm.run();

      // The test data set uses only a subset of the enum levels seen in
      // training, so its column domains differ from the model's and have to
      // be mapped.
      Frame ftest = ParseDataset2.parse(dest2,new Key[]{fkey2});
      try {
        preds = gbm.score(ftest);
      } finally {
        // Delete the test frame even when scoring fails
        ftest.remove();
      }

      System.err.println(preds);

    } catch (RuntimeException e) {
      throw e;                    // already unchecked - let JUnit report it
    } catch (Exception e) {
      // Do not swallow: a failure while training/scoring must fail the test
      throw new RuntimeException("GBM domain adaptation test failed", e);
    } finally {
      UKV.remove(fkey1);
      UKV.remove(dest1); // Remove original hex frame key
      UKV.remove(fkey2);
      UKV.remove(dest2);
      if( gbm != null ) {
        UKV.remove(gbm.dest()); // Remove the model
        // response stays unset when parsing the training data failed
        if( gbm.response != null ) UKV.remove(gbm.response._key);
        gbm.remove(); // Remove GBM Job
      }
      if( preds != null ) preds.remove();
    }
  }

}

0 comments on commit 25a501f

Please sign in to comment.