Skip to content

Commit

Permalink
GLM fix: fixed issue with constant offset column (was incorrectly being ignored) and added test for it.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasnykodym committed Mar 23, 2015
1 parent f341727 commit 1ee5e75
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 6 deletions.
5 changes: 5 additions & 0 deletions smalldata/glm_test/abcd.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
A B C D E
1 0 0 0 1
0 1 0 0 1
0 0 1 0 1
1 1 0 1 1
19 changes: 14 additions & 5 deletions src/main/java/hex/glm/GLM2.java
Original file line number Diff line number Diff line change
Expand Up @@ -469,22 +469,30 @@ private double computeIntercept(DataInfo dinfo, double ymu, Vec offset, Vec resp
//pass
}
toEnum = family == Family.binomial && (!response.isEnum() && (response.min() < 0 || response.max() > 1));
String offsetName = "";
int offsetId = -1;
if(offset != null) {
offsetId = source2.find(offset);
offsetName = source2.names()[offsetId];
source2.remove(offsetId);
}

Frame fr = DataInfo.prepareFrame(source2, response, ignored_cols, toEnum, true, true);
if(offset != null){ // now put the offset just before response
int id = fr.find(offset);
String offsetName = fr.names()[id];
String responseName = fr.names()[fr.numCols()-1];
Vec responseVec = fr.remove(fr.numCols()-1);
fr.add(offsetName, fr.remove(id));
fr.add(offsetName, offset);
fr.add(responseName,responseVec);
}
TransformType dt = TransformType.NONE;
if (standardize)
dt = intercept ? TransformType.STANDARDIZE : TransformType.DESCALE;
_srcDinfo = new DataInfo(fr, 1, intercept, use_all_factor_levels || lambda_search, dt, DataInfo.TransformType.NONE);
if(offset != null && dt != TransformType.NONE) { // do not standardize offset
_srcDinfo._normMul[_srcDinfo._normMul.length-1] = 1;
_srcDinfo._normSub[_srcDinfo._normSub.length-1] = 0;
if(_srcDinfo._normMul != null)
_srcDinfo._normMul[_srcDinfo._normMul.length-1] = 1;
if(_srcDinfo._normSub != null)
_srcDinfo._normSub[_srcDinfo._normSub.length-1] = 0;
}
if (!intercept && _srcDinfo._cats > 0)
throw new IllegalArgumentException("Models with no intercept are only supported with all-numeric predictors.");
Expand Down Expand Up @@ -563,6 +571,7 @@ else if(_bgs != null)
_lbs[i] = 0;
}
} catch(RuntimeException e) {
e.printStackTrace();
cleanup();
throw e;
}
Expand Down
1 change: 1 addition & 0 deletions src/main/java/hex/glm/GLMValidation.java
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ public void add(GLMValidation v){
else for(int i = 0; i < _cms.length; ++i)_cms[i].add(v._cms[i]);
}
/**
 * Accessor for the residual deviance.
 *
 * @return the current value of the {@code residual_deviance} field
 */
public final double residualDeviance() {
  return residual_deviance;
}
/**
 * Accessor for the null deviance.
 *
 * @return the current value of the {@code null_deviance} field
 */
public final double nullDeviance() {
  return null_deviance;
}
/**
 * Computes {@code nobs - _rank - 1}.
 * NOTE(review): by its name this is presumably the residual degrees of freedom
 * (observations minus model rank minus one) -- confirm against the field definitions.
 *
 * @return {@code nobs - _rank - 1}
 */
public final long resDOF() {
  final long dof = nobs - _rank - 1;
  return dof;
}
/**
 * Accessor for the {@code auc} field.
 *
 * @return the current value of {@code auc}
 */
public double auc() {
  return auc;
}
/**
 * Accessor for the {@code aic} field.
 *
 * @return the current value of {@code aic}
 */
public double aic() {
  return aic;
}
Expand Down
8 changes: 7 additions & 1 deletion src/test/java/hex/GLMTest2.java
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,13 @@ final static public void testHTML(GLMModel m) {
assertEquals(2015, model.null_validation.residualDeviance(),1e-1);
assertEquals(1516, val.residualDeviance(),1e-1);
assertEquals(1532, val.aic(),1e-1);
// test constant offset (had issues with constant-column filtering)
fr = getFrameForFile(parsed, "smalldata/glm_test/abcd.csv", new String[0], "D");
new GLM2("GLM testing constant offset on a toy dataset.",Key.make(),modelKey,new GLM2.Source(fr,fr.vec("D"),false,false,fr.vec("E")),Family.gaussian).setRegularization(new double []{0},new double[]{0}).doInit().fork().get();
// just test it does not blow up and the model is sane
model = DKV.get(modelKey).get();
assertEquals(model.coefficients().get("E"),1,0); // should be exactly 1
assertTrue(model.validation().residualDeviance() <= model.validation().nullDeviance());
} finally {
fr.delete();
if(model != null)model.delete();
Expand Down Expand Up @@ -390,7 +397,6 @@ final static public void testHTML(GLMModel m) {
fr = DKV.get(k).get();
fr.remove("ID");
Key betaConsKey = Key.make("beta_constraints");

//String[] cfs1 = new String[]{"RACE", "AGE", "DPROS", "DCAPS", "PSA", "VOL", "GLEASON","Intercept"};
//double[] vals = new double[]{0, 0, 0.54788332,0.53816534, 0.02380097, 0, 0.98115670,-8.945984};
// [AGE, RACE, DPROS, DCAPS, PSA, VOL, GLEASON, Intercept]
Expand Down

0 comments on commit 1ee5e75

Please sign in to comment.