Skip to content

Commit

Permalink
Added prior test to GLM.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasnykodym committed Apr 17, 2015
1 parent 03ba976 commit 146d651
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 3 deletions.
5 changes: 4 additions & 1 deletion src/main/java/hex/glm/GLM2.java
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,10 @@ public GLM2 setBetaConstraints(Frame f){
beta_constraints = f;
return this;
}

public GLM2 setPrior(double p){
this.prior = p;
return this;
}
static String arrayToString (double[] arr) {
if (arr == null) {
return "(null)";
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/hex/glm/GLMModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public class GLMModel extends Model implements Comparable<GLMModel> {
@API(help="lambda_value max, smallest lambda_value which drives all coefficients to zero")
final double lambda_max;
@API(help="mean of response in the training dataset")
final double ymu;
public final double ymu;

@API(help="actual expected mean of the response (given by the user before running the model or ymu)")
final double prior;
Expand Down
9 changes: 8 additions & 1 deletion src/test/java/hex/GLMTest2.java
Original file line number Diff line number Diff line change
Expand Up @@ -289,13 +289,20 @@ final static public void testHTML(GLMModel m) {
assertEquals(512.3, model.null_validation.residualDeviance(), 1e-1);
assertEquals(378.3, val.residualDeviance(),1e-1);
assertEquals(396.3, val.aic(), 1e-1);
double prior = 1e-5;
// test the same data and model with prior, should get the same model except for the intercept
new GLM2("GLM test on prostate.",Key.make(),modelKey,new Source(fr,fr.lastVec(),false),Family.binomial).setRegularization(new double []{0},new double[]{0}).setPrior(prior).doInit().fork().get();
GLMModel model2 = DKV.get(modelKey).get();
for(int i = 0; i < model2.beta().length-1; ++i)
assertEquals(model.beta()[i], model2.beta()[i], 1e-8);
assertEquals(model.beta()[model.beta().length-1] -Math.log(model.ymu * (1-prior)/(prior * (1-model.ymu))),model2.beta()[model.beta().length-1],1e-10);
} finally {
fr.delete();
if(model != null)model.delete();
}

}


@Test public void testNoNNegative() {
// glmnet's result:
// res2 <- glmnet(x=M,y=D$CAPSULE,lower.limits=0,family='binomial')
Expand Down

0 comments on commit 146d651

Please sign in to comment.