Skip to content

Commit 074a722

Browse files
committed
Adapted models to deal with different sets of categoricals in training and tst set.
Model now has method adapt(ValueArray) and adapt(String[]) which produces model adapted for different dataset. Default implementations provided but the Model.class which simply maps the columns and categoricals to their matching pair. Categoricals with no matching value (seen only in testing data, not in training) are mapped to NaN. Also, I removed mapping argument from all score methods.
1 parent 6e80584 commit 074a722

File tree

7 files changed

+126
-41
lines changed

7 files changed

+126
-41
lines changed

src/main/java/hex/DGLM.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -682,8 +682,8 @@ protected double score0( double[] data ) {
682682
p += _beta[idx]*d;
683683
} else {
684684
int d = (int)data[i]; // Enum value
685-
idx += d; // Which expanded column to use
686-
if( idx < _colCatMap[i+1] )
685+
// d can be -1 if we got enum values not seen in training
686+
if(d > 0 && (idx += d) < _colCatMap[i+1])
687687
p += _beta[idx]/* *1.0 */;
688688
else // Enum out of range?
689689
p = Double.NaN;// Can use a zero, or a NaN
@@ -697,12 +697,12 @@ protected double score0( double[] data ) {
697697
}
698698

699699
/** Single row scoring, on a compatible ValueArray (when pushed throw the mapping) */
700-
protected double score0( ValueArray data, int row, int[] mapping ) {
700+
protected double score0( ValueArray data, int row) {
701701
throw H2O.unimpl();
702702
}
703703

704704
/** Bulk scoring API, on a compatible ValueArray (when pushed throw the mapping) */
705-
protected double score0( ValueArray data, AutoBuffer ab, int row_in_chunk, int[] mapping ) {
705+
protected double score0( ValueArray data, AutoBuffer ab, int row_in_chunk) {
706706
throw H2O.unimpl();
707707
}
708708
}

src/main/java/hex/KMeansModel.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,12 @@ protected double score0(double[] data) {
8888
}
8989

9090
/** Single row scoring, on a compatible ValueArray (when pushed throw the mapping) */
91-
protected double score0(ValueArray data, int row, int[] mapping) {
91+
protected double score0(ValueArray data, int row) {
9292
throw H2O.unimpl();
9393
}
9494

9595
/** Bulk scoring API, on a compatible ValueArray (when pushed throw the mapping) */
96-
protected double score0(ValueArray data, AutoBuffer ab, int row_in_chunk, int[] mapping) {
96+
protected double score0(ValueArray data, AutoBuffer ab, int row_in_chunk) {
9797
throw H2O.unimpl();
9898
}
9999

src/main/java/hex/rf/RFModel.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,10 @@ protected double score0(double[] data) {
203203
}
204204

205205
/** Single row scoring, on a compatible ValueArray (when pushed throw the mapping) */
206-
protected double score0( ValueArray data, int row, int[] mapping ) { throw H2O.unimpl(); }
206+
protected double score0( ValueArray data, int row) { throw H2O.unimpl(); }
207207

208208
/** Bulk scoring API, on a compatible ValueArray (when pushed throw the mapping) */
209-
protected double score0(ValueArray data, AutoBuffer ab, int row_in_chunk, int[] mapping) { throw H2O.unimpl(); }
209+
protected double score0(ValueArray data, AutoBuffer ab, int row_in_chunk) { throw H2O.unimpl(); }
210210

211211
@Override public JsonObject toJson() {
212212
JsonObject res = new JsonObject();

src/main/java/water/InternalInterface.java

+10-6
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,18 @@ public double scoreKey( Object modelKey, String [] colNames, double[] row ) {
3333
}
3434
}
3535

36+
public Model adaptModel(Object model, String [] colNames){
37+
return ((Model)model).adapt(colNames);
38+
}
3639
// Call to map the columns and score
37-
public double scoreModel( Object model, String [] colNames, double[] row ) {
38-
Model M = (Model)model;
39-
int[] map = M.columnMapping( colNames);
40-
if( !Model.isCompatible(map) )
41-
throw new IllegalArgumentException("This model uses different columns than those provided");
42-
return M.score(row,map);
40+
public double scoreModel( Object model,double[] row ) {
41+
return ((Model)model).score(row);
4342
}
4443

4544
public JsonObject cloudStatus( ) { return new Cloud().serve().toJson(); }
45+
46+
@Override public double scoreModel(Object model, String[] colNames, double[] row) {
47+
return adaptModel(model, colNames).score(row);
48+
}
49+
4650
}

src/main/java/water/Model.java

+99-18
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
import java.util.Arrays;
44

5-
import com.google.gson.JsonObject;
6-
5+
import water.ValueArray.Column;
76
import water.api.Constants;
87

8+
import com.google.gson.JsonObject;
9+
910
/**
1011
* A Model models reality (hopefully).
1112
* A model can be used to 'score' a row, or a collection of rows on any
@@ -168,36 +169,116 @@ public final boolean isCompatible( ValueArray data ) {
168169
//return isCompatible(data.colNames());
169170
}
170171

171-
/** Single row scoring. Data can be in any order. No checking on a sane
172-
* mapping. */
173-
public final double score( double[] data, int[] mapping ) {
174-
assert isCompatible(mapping);
175-
if( identityMap(mapping) ) { // Shortcut for well-behaved data
176-
assert data.length == _va._cols.length;
177-
return score0(data);
172+
private static class ModelDataAdaptor extends Model {
173+
final Model M;
174+
final int _yCol;
175+
final int [] _xCols;
176+
final int [][] _catMap;
177+
final double [] _row;
178+
179+
public ModelDataAdaptor(Model M, int yCol, int [] cols, int [][] catMap){
180+
this.M = M;
181+
_row = MemoryManager.malloc8d(cols.length);
182+
_xCols = cols;
183+
_catMap = catMap;
184+
_yCol = yCol;
185+
}
186+
private final double translateCat(int col, int val){
187+
int res = _catMap[col][val];
188+
return res == -1?Double.NaN:res;
189+
}
190+
private final double translateCat(int col, double val){
191+
if(Double.isNaN(val))return val;
192+
assert val == (int)val;
193+
return translateCat(col, (int)val);
194+
}
195+
196+
@Override public final double score(double[] data) {
197+
int j = 0;
198+
for(int i = 0; i < _xCols.length; ++i)
199+
_row[j++] = (_catMap == null || _catMap[i] == null)?data[_xCols[i]]:translateCat(i, data[_xCols[i]]);
200+
return M.score0(_row);
201+
}
202+
@Override protected double score0(ValueArray data, int row) {
203+
int j = 0;
204+
for(int c:_xCols)
205+
_row[j++] = (_catMap == null || _catMap[c] == null)
206+
?data.datad(row, c)
207+
:translateCat(c,(int)data.data(row, c));
208+
return M.score0(_row);
209+
}
210+
@Override protected double score0(ValueArray data, AutoBuffer ab, int row) {
211+
int j = 0;
212+
for(int c:_xCols)
213+
_row[j++] = (_catMap == null || _catMap[c] == null)
214+
?data.datad(ab,row, c)
215+
:translateCat(c,(int)data.data(ab,row, c));
216+
return M.score0(_row);
178217
}
179-
// Build a mapped row and score it. Explodes if mapping is busted.
180-
double[] d = new double[_va._cols.length];
181-
for( int i=0; i<_va._cols.length-1; i++ )
182-
d[i] = data[mapping[i]];
183-
return score0(d);
218+
// always should call directly M.score0...
219+
@Override protected final double score0(double[] data) {
220+
throw new RuntimeException("should NEVER be called!");
221+
}
222+
@Override public JsonObject toJson() {return M.toJson();}
223+
// keep only one adaptor layer! (just in case there would be multiple adapt calls...)
224+
@Override public final Model adapt(ValueArray ary){return M.adapt(ary);}
225+
@Override public final Model adapt(String [] cols){return M.adapt(cols);}
226+
}
227+
228+
/**
229+
* Adapt model for the given dataset.
230+
* Default behavior is to map columns and categoricals to their original indexes.
231+
*
232+
* @param ary - tst dataset
233+
* @return Model - model adapted to be applied on the given data
234+
*/
235+
public Model adapt(ValueArray ary){
236+
boolean id = true;
237+
final int [] colMap = columnMapping(ary.colNames());
238+
if(!isCompatible(colMap))throw new IllegalArgumentException("This model uses different columns than those provided");
239+
final int[][] catMap = new int[colMap.length][];
240+
for(int i = 0; i < colMap.length-1; ++i){
241+
Column c = ary._cols[colMap[i]];
242+
if(c.isEnum() && !Arrays.deepEquals(_va._cols[i]._domain, c._domain)){
243+
id = false;
244+
catMap[i] = new int[c._domain.length];
245+
for(int j = 0; j < c._domain.length; ++j)
246+
catMap[i][j] = find(c._domain[j],_va._cols[i]._domain);
247+
}
248+
}
249+
return (id&&identityMap(colMap))?this:new ModelDataAdaptor(this,colMap[colMap.length-1],Arrays.copyOf(colMap,colMap.length-1),catMap);
250+
}
251+
/**
252+
* Adapt model for given columns.
253+
* Only permutes the columns by the column names (factor levels MUST match the training dataset).
254+
* @param colNames
255+
* @return
256+
*/
257+
public Model adapt(String [] colNames){
258+
final int [] colMap = columnMapping(colNames);
259+
if(!isCompatible(colMap))throw new IllegalArgumentException("This model uses different columns than those provided");
260+
if(identityMap(colMap))return this;
261+
return new ModelDataAdaptor(this, colMap[colMap.length-1],Arrays.copyOf(colMap,colMap.length-1), null);
262+
}
263+
public double score(double [] data){
264+
return score0(data);
184265
}
185266

186267
// Subclasses implement the scoring logic. They can assume all datasets are
187268
// compatible already
269+
protected abstract double score0(double [] data);
188270

189-
/** Single row scoring, on properly ordered data */
190-
protected abstract double score0( double[] data );
191271

192272
/** Single row scoring, on a compatible ValueArray (when pushed throw the mapping) */
193-
protected abstract double score0( ValueArray data, int row, int[] mapping );
273+
protected abstract double score0( ValueArray data, int row);
194274

195275
/** Bulk scoring API, on a compatible ValueArray (when pushed throw the mapping) */
196-
protected abstract double score0( ValueArray data, AutoBuffer ab, int row_in_chunk, int[] mapping );
276+
protected abstract double score0( ValueArray data, AutoBuffer ab, int row_in_chunk);
197277

198278
public abstract JsonObject toJson();
199279

200280
public void fromJson(JsonObject json) {
201281
// TODO
202282
}
283+
203284
}

src/main/java/water/api/Score.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public static Score create( Properties parms ) {
8787
res.add(ROWS, rows);
8888

8989
// Score the row on the model. May destroy 'd'.
90-
double response = M.score(d,null);
90+
double response = M.score(d);
9191
res.addProperty(CLASS, response);
9292

9393
// Display HTML setup

src/test/java/water/parser/RReaderTest.java

+8-8
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,15 @@ public void testIrisModel() throws Exception {
2222
Key key = Key.make("irisModel");
2323
RReader.run(key, new FileInputStream(file));
2424
RFModel model = UKV.get(key);
25-
int[] map = model.columnMapping(iris.colNames());
26-
Assert.assertTrue(Model.isCompatible(map));
25+
Model m = model.adapt(iris.colNames());
2726

2827
// Can I score on the model now?
29-
double[] row = new double[map.length];
28+
double[] row = new double[iris._cols.length];
3029
for( int i=0; i<iris._numrows; i++ ) {
31-
for( int j=0; j<map.length-1; j++ )
30+
for( int j=0; j<iris._cols.length-1; j++ )
3231
row[j] = iris.datad(i,j);
33-
assertEquals(iris.datad(i,map.length-1),model.score(row,map),0.0001);
32+
assertEquals(iris.datad(i,iris._cols.length-1),m.score(row),0.0001);
3433
}
35-
3634
model.deleteKeys();
3735
UKV.remove(key);
3836
UKV.remove(irisk);
@@ -54,10 +52,11 @@ public void testProstateModel() throws Exception {
5452
// Can I score on the model now?
5553
double[] row = new double[map.length];
5654
int errs = 0;
55+
Model M = model.adapt(pro);
5756
for( int i=0; i<pro._numrows; i++ ) {
5857
for( int j=0; j<map.length; j++ )
5958
row[j] = pro.datad(i,j);
60-
double score = model.score(row,map);
59+
double score = M.score(row);
6160
if( Math.abs(pro.datad(i,classCol) - score) > 0.0001 ) errs++;
6261
}
6362
assertEquals(100,errs);
@@ -84,11 +83,12 @@ public void testCovtypeModel() throws Exception {
8483
// Can I score on the model now?
8584
long start = System.currentTimeMillis();
8685
double[] row = new double[map.length];
86+
Model M = model.adapt(pro);
8787
int errs = 0;
8888
for( int i=0; i<pro._numrows; i++ ) {
8989
for( int j=0; j<map.length; j++ )
9090
row[j] = pro.datad(i,j);
91-
double score = model.score(row,map);
91+
double score = M.score(row);
9292
System.out.println(" "+i+" "+score+" "+(pro.datad(i,classCol)));
9393
if( Math.abs(pro.datad(i,classCol) - score) > 0.0001 ) errs++;
9494
}

0 commit comments

Comments
 (0)