|
2 | 2 |
|
3 | 3 | import java.util.Arrays;
|
4 | 4 |
|
5 |
| -import com.google.gson.JsonObject; |
6 |
| - |
| 5 | +import water.ValueArray.Column; |
7 | 6 | import water.api.Constants;
|
8 | 7 |
|
| 8 | +import com.google.gson.JsonObject; |
| 9 | + |
9 | 10 | /**
|
10 | 11 | * A Model models reality (hopefully).
|
11 | 12 | * A model can be used to 'score' a row, or a collection of rows on any
|
@@ -168,36 +169,116 @@ public final boolean isCompatible( ValueArray data ) {
|
168 | 169 | //return isCompatible(data.colNames());
|
169 | 170 | }
|
170 | 171 |
|
171 |
| - /** Single row scoring. Data can be in any order. No checking on a sane |
172 |
| - * mapping. */ |
173 |
| - public final double score( double[] data, int[] mapping ) { |
174 |
| - assert isCompatible(mapping); |
175 |
| - if( identityMap(mapping) ) { // Shortcut for well-behaved data |
176 |
| - assert data.length == _va._cols.length; |
177 |
| - return score0(data); |
| 172 | + private static class ModelDataAdaptor extends Model { |
| 173 | + final Model M; |
| 174 | + final int _yCol; |
| 175 | + final int [] _xCols; |
| 176 | + final int [][] _catMap; |
| 177 | + final double [] _row; |
| 178 | + |
| 179 | + public ModelDataAdaptor(Model M, int yCol, int [] cols, int [][] catMap){ |
| 180 | + this.M = M; |
| 181 | + _row = MemoryManager.malloc8d(cols.length); |
| 182 | + _xCols = cols; |
| 183 | + _catMap = catMap; |
| 184 | + _yCol = yCol; |
| 185 | + } |
| 186 | + private final double translateCat(int col, int val){ |
| 187 | + int res = _catMap[col][val]; |
| 188 | + return res == -1?Double.NaN:res; |
| 189 | + } |
| 190 | + private final double translateCat(int col, double val){ |
| 191 | + if(Double.isNaN(val))return val; |
| 192 | + assert val == (int)val; |
| 193 | + return translateCat(col, (int)val); |
| 194 | + } |
| 195 | + |
| 196 | + @Override public final double score(double[] data) { |
| 197 | + int j = 0; |
| 198 | + for(int i = 0; i < _xCols.length; ++i) |
| 199 | + _row[j++] = (_catMap == null || _catMap[i] == null)?data[_xCols[i]]:translateCat(i, data[_xCols[i]]); |
| 200 | + return M.score0(_row); |
| 201 | + } |
| 202 | + @Override protected double score0(ValueArray data, int row) { |
| 203 | + int j = 0; |
| 204 | + for(int c:_xCols) |
| 205 | + _row[j++] = (_catMap == null || _catMap[c] == null) |
| 206 | + ?data.datad(row, c) |
| 207 | + :translateCat(c,(int)data.data(row, c)); |
| 208 | + return M.score0(_row); |
| 209 | + } |
| 210 | + @Override protected double score0(ValueArray data, AutoBuffer ab, int row) { |
| 211 | + int j = 0; |
| 212 | + for(int c:_xCols) |
| 213 | + _row[j++] = (_catMap == null || _catMap[c] == null) |
| 214 | + ?data.datad(ab,row, c) |
| 215 | + :translateCat(c,(int)data.data(ab,row, c)); |
| 216 | + return M.score0(_row); |
178 | 217 | }
|
179 |
| - // Build a mapped row and score it. Explodes if mapping is busted. |
180 |
| - double[] d = new double[_va._cols.length]; |
181 |
| - for( int i=0; i<_va._cols.length-1; i++ ) |
182 |
| - d[i] = data[mapping[i]]; |
183 |
| - return score0(d); |
| 218 | + // callers should always call M.score0 directly... |
| 219 | + @Override protected final double score0(double[] data) { |
| 220 | + throw new RuntimeException("should NEVER be called!"); |
| 221 | + } |
| 222 | + @Override public JsonObject toJson() {return M.toJson();} |
| 223 | + // keep only one adaptor layer! (just in case there would be multiple adapt calls...) |
| 224 | + @Override public final Model adapt(ValueArray ary){return M.adapt(ary);} |
| 225 | + @Override public final Model adapt(String [] cols){return M.adapt(cols);} |
| 226 | + } |
| 227 | + |
| 228 | + /** |
| 229 | + * Adapt model for the given dataset. |
| 230 | + * Default behavior is to map columns and categoricals to their original indexes. |
| 231 | + * |
| 232 | + * @param ary - test dataset |
| 233 | + * @return Model - model adapted to be applied on the given data |
| 234 | + */ |
| 235 | + public Model adapt(ValueArray ary){ |
| 236 | + boolean id = true; |
| 237 | + final int [] colMap = columnMapping(ary.colNames()); |
| 238 | + if(!isCompatible(colMap))throw new IllegalArgumentException("This model uses different columns than those provided"); |
| 239 | + final int[][] catMap = new int[colMap.length][]; |
| 240 | + for(int i = 0; i < colMap.length-1; ++i){ |
| 241 | + Column c = ary._cols[colMap[i]]; |
| 242 | + if(c.isEnum() && !Arrays.deepEquals(_va._cols[i]._domain, c._domain)){ |
| 243 | + id = false; |
| 244 | + catMap[i] = new int[c._domain.length]; |
| 245 | + for(int j = 0; j < c._domain.length; ++j) |
| 246 | + catMap[i][j] = find(c._domain[j],_va._cols[i]._domain); |
| 247 | + } |
| 248 | + } |
| 249 | + return (id&&identityMap(colMap))?this:new ModelDataAdaptor(this,colMap[colMap.length-1],Arrays.copyOf(colMap,colMap.length-1),catMap); |
| 250 | + } |
| 251 | + /** |
| 252 | + * Adapt model for given columns. |
| 253 | + * Only permutes the columns by the column names (factor levels MUST match the training dataset). |
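| 254 | + * @param colNames - names of the given columns |
| 255 | + * @return this model if the columns already line up, otherwise an adaptor that permutes them |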
| 256 | + */ |
| 257 | + public Model adapt(String [] colNames){ |
| 258 | + final int [] colMap = columnMapping(colNames); |
| 259 | + if(!isCompatible(colMap))throw new IllegalArgumentException("This model uses different columns than those provided"); |
| 260 | + if(identityMap(colMap))return this; |
| 261 | + return new ModelDataAdaptor(this, colMap[colMap.length-1],Arrays.copyOf(colMap,colMap.length-1), null); |
| 262 | + } |
| 263 | + public double score(double [] data){ |
| 264 | + return score0(data); |
184 | 265 | }
|
185 | 266 |
|
186 | 267 | // Subclasses implement the scoring logic. They can assume all datasets are
|
187 | 268 | // compatible already
|
| 269 | + protected abstract double score0(double [] data); |
188 | 270 |
|
189 |
| - /** Single row scoring, on properly ordered data */ |
190 |
| - protected abstract double score0( double[] data ); |
191 | 271 |
|
192 | 272 | /** Single row scoring, on a compatible ValueArray (when pushed through the mapping) */
|
193 |
| - protected abstract double score0( ValueArray data, int row, int[] mapping ); |
| 273 | + protected abstract double score0( ValueArray data, int row); |
194 | 274 |
|
195 | 275 | /** Bulk scoring API, on a compatible ValueArray (when pushed through the mapping) */
|
196 |
| - protected abstract double score0( ValueArray data, AutoBuffer ab, int row_in_chunk, int[] mapping ); |
| 276 | + protected abstract double score0( ValueArray data, AutoBuffer ab, int row_in_chunk); |
197 | 277 |
|
198 | 278 | public abstract JsonObject toJson();
|
199 | 279 |
|
200 | 280 | public void fromJson(JsonObject json) {
|
201 | 281 | // TODO
|
202 | 282 | }
|
| 283 | + |
203 | 284 | }
|
0 commit comments