Skip to content

Commit

Permalink
added test of proximal interface of glm.
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasnykodym committed Feb 28, 2015
1 parent d8605e7 commit 4cc7aa5
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions src/test/java/hex/GLMTest2.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package hex;

import hex.FrameTask.DataInfo.TransformType;
import hex.glm.GLM2.Source;
import hex.glm.GLMParams.Link;
import hex.glm.GLMTask.GLMIterationTask;
import org.junit.Assert;

import static junit.framework.Assert.assertTrue;
Expand All @@ -14,6 +16,7 @@
import water.deploy.Node;
import water.deploy.NodeVM;
import water.fvec.*;
import water.util.ModelUtils;

import java.io.File;
import java.util.Arrays;
Expand Down Expand Up @@ -367,6 +370,53 @@ final static public void testHTML(GLMModel m) {
}
}

@Test public void testProximal() {
// glmnet's result:
// res2 <- glmnet(x=M,y=D$CAPSULE,lower.limits=-.5,upper.limits=.5,family='binomial')
// res2$beta[,58]
// AGE RACE DPROS PSA VOL GLEASON
// -0.00616326 -0.50000000 0.50000000 0.03628192 -0.01249324 0.50000000 // res2$a0[100]
// res2$a0[58]
// s57
// -4.155864
// lambda = 0.001108, null dev = 512.2888, res dev = 379.7597
Key parsed = Key.make("prostate_parsed");
Key modelKey = Key.make("prostate_model");
GLMModel model = null;
Frame fr = getFrameForFile(parsed, "smalldata/logreg/prostate.csv", new String[]{"ID"}, "CAPSULE");
Key k = Key.make("rebalanced");
H2O.submitTask(new RebalanceDataSet(fr, k, 64)).join();
fr.delete();
fr = DKV.get(k).get();
fr.remove("ID");
Key betaConsKey = Key.make("beta_constraints");

//String[] cfs1 = new String[]{"RACE", "AGE", "DPROS", "DCAPS", "PSA", "VOL", "GLEASON","Intercept"};
//double[] vals = new double[]{0, 0, 0.54788332,0.53816534, 0.02380097, 0, 0.98115670,-8.945984};
// [AGE, RACE, DPROS, DCAPS, PSA, VOL, GLEASON, Intercept]
FVecTest.makeByteVec(betaConsKey, "names, beta_given, rho\n AGE, .5, 2\n RACE, .75, 1 \n DPROS, -.5, 10 \n DCAPS, .4, .5 \n PSA, -.15, 25\n VOL, .1, .5\nGLEASON, -.5, .5\n Intercept, 0, 0 \n");
Frame betaConstraints = ParseDataset2.parse(parsed, new Key[]{betaConsKey});
try {
// H2O differs on intercept and race, same residual deviance though
GLM2.Source src = new GLM2.Source((Frame)fr.clone(), fr.vec("CAPSULE"), true, true);
new GLM2("GLM offset test on prostate.", Key.make(), modelKey, src, Family.binomial).setNonNegative(false).setRegularization(new double[]{0},new double[]{0.000}).setBetaConstraints(betaConstraints).doInit().fork().get(); //.setHighAccuracy().doInit().fork().get();
model = DKV.get(modelKey).get();
System.out.println(Arrays.toString(model.coefficients_names));
System.out.println(model.coefficients());
fr.add("CAPSULE", fr.remove("CAPSULE"));
// public GLMIterationTask(int noff, Key jobKey, DataInfo dinfo, GLMParams glm, boolean computeGram, boolean validate, boolean computeGradient, double[] beta, double ymu, double reg, float[] thresholds, H2OCountedCompleter cmp) {
DataInfo dinfo = new DataInfo(fr, 1, true, false, TransformType.NONE, DataInfo.TransformType.NONE);
GLMIterationTask glmt = new GLMTask.GLMIterationTask(0,null, dinfo, new GLMParams(Family.binomial),false, true, true, model.beta(), 0, 1.0/380, ModelUtils.DEFAULT_THRESHOLDS, null).doAll(dinfo._adaptedFrame);
double [] beta = model.beta();
double [] grad = glmt.gradient(0,0);
// for(int i = 0; i < beta.length; ++i) {
// System.out.println("grad[" + i + "] = " + grad[i] + ", penaltyGrad = " + betaConstraints.vec("rho").at(i) * (beta[i] - betaConstraints.vec("beta_given").at(i)) + ", res = " + (grad[i] + betaConstraints.vec("rho").at(i) * (beta[i] - betaConstraints.vec("beta_given").at(i))));
// }
} finally {
fr.delete();
if(model != null)model.delete();
}
}


@Test public void testNoIntercept(){
Expand All @@ -390,6 +440,7 @@ final static public void testHTML(GLMModel m) {
Frame score = null;
try{
// H2O differs on intercept and race, same residual deviance though
// VOL=0.07800102000216368, AGE=0.4763358013317742, Intercept=-33.31091128704016, DCAPS=1.3862187988985282, PSA=-0.14941426719071213, DPROS=-0.47448094905875854, RACE=0.7167868302661414, GLEASON=0.013676100109077757}
String [] cfs1 = new String [] {"RACE", "AGE", "DPROS", "DCAPS", "PSA", "VOL", "GLEASON"};
double [] vals = new double [] { -1.23262,-0.07205, 0.47899, 0.13934, 0.03626, -0.01155, 0.63645};
new GLM2("GLM offset test on prostate.",Key.make(),modelKey,new GLM2.Source((Frame)fr.clone(),fr.vec("CAPSULE"),false,false),Family.binomial).setRegularization(new double[]{0},new double[]{0}).doInit().fork().get(); //.setHighAccuracy().doInit().fork().get();
Expand Down

0 comments on commit 4cc7aa5

Please sign in to comment.