Skip to content

Commit

Permalink
GLM2 update. Updated scoring with different lambda (GetScoringModelTa…
Browse files Browse the repository at this point in the history
…sk had synchronization bug).
  • Loading branch information
tomasnykodym committed Aug 6, 2014
1 parent fb9f123 commit 3fa93c4
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 85 deletions.
24 changes: 11 additions & 13 deletions src/main/java/hex/glm/GLM2.java
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ public void callback(H2OCountedCompleter cc) {
};
c.addToPendingCount(tasks.length-1);
for(int i = 0; i < tasks.length; ++i)
(tasks[i] = new GLMModel.GetScoringModelTask(c,pgs._glms[i].self(),pgs._glms[i].dest(),curentLambda)).fork();
(tasks[i] = new GLMModel.GetScoringModelTask(c,pgs._glms[i].dest(),curentLambda)).forkTask();
}
}

Expand Down Expand Up @@ -1161,37 +1161,35 @@ public GLMGrid (Key gridKey,Key jobKey, GLM2 [] jobs){
_startTime = System.currentTimeMillis();
}

public static class UnlockGridTsk extends DTask.DKeyTask {
public static class UnlockGridTsk extends DTask.DKeyTask<UnlockGridTsk,GLMGrid> {
final Key _jobKey;

public UnlockGridTsk(Key gridKey, Key jobKey, H2OCountedCompleter cc){
super(cc,gridKey);
_jobKey = jobKey;
}
@Override
public void compute2() {
GLMGrid g = H2O.get(_key).get();
addToPendingCount(g.destination_keys.length);
public void map(GLMGrid g) {
H2OCountedCompleter t = getCurrentTask();
t.addToPendingCount(g.destination_keys.length);
for(Key k:g.destination_keys)
new GLMModel.UnlockModelTask(this,k,_jobKey).forkTask();
new GLMModel.UnlockModelTask(t,k,_jobKey).forkTask();
g.unlock(_jobKey);
tryComplete();
}
}

public static class DeleteGridTsk extends DTask.DKeyTask {
public static class DeleteGridTsk extends DTask.DKeyTask<DeleteGridTsk,GLMGrid> {
public DeleteGridTsk(H2OCountedCompleter cc, Key gridKey){
super(cc,gridKey);
}
@Override
public void compute2() {
GLMGrid g = H2O.get(_key).get();
addToPendingCount(g.destination_keys.length);
public void map(GLMGrid g) {
H2OCountedCompleter t = getCurrentTask();
t.addToPendingCount(g.destination_keys.length);
for(Key k:g.destination_keys)
new GLMModel.DeleteModelTask(this,k).forkTask();
new GLMModel.DeleteModelTask(t,k).forkTask();
assert g.is_unlocked():"not unlocked??";
g.delete();
tryComplete();
}
}
@Override
Expand Down
68 changes: 24 additions & 44 deletions src/main/java/hex/glm/GLMModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -82,49 +82,42 @@ public double devExplained(){
return 1.0 - val.residual_deviance/val.null_deviance;
}

public static class UnlockModelTask extends DTask.DKeyTask{
final Key _modelKey;
public static class UnlockModelTask extends DTask.DKeyTask<UnlockModelTask,GLMModel>{
final Key _jobKey;

public UnlockModelTask(H2OCountedCompleter cmp, Key modelKey, Key jobKey){
super(cmp,modelKey);
_modelKey = modelKey;
_jobKey = jobKey;
}
@Override
public void compute2() {
GLMModel m = H2O.get(_modelKey).get();
public void map(GLMModel m) {
Key [] xvals = m.xvalModels();
if(xvals != null){
addToPendingCount(xvals.length);
H2OCountedCompleter t = getCurrentTask();
t.addToPendingCount(xvals.length);
for(int i = 0; i < xvals.length; ++i)
new UnlockModelTask(this,xvals[i],_jobKey).forkTask();
new UnlockModelTask(t,xvals[i],_jobKey).forkTask();
}
m.unlock(_jobKey);
tryComplete();
}
}

public static class DeleteModelTask extends DTask.DKeyTask{
public static class DeleteModelTask extends DTask.DKeyTask<DeleteModelTask,GLMModel>{
final Key _modelKey;

public DeleteModelTask(H2OCountedCompleter cmp, Key modelKey){
super(cmp,modelKey);
_modelKey = modelKey;
}
@Override
public void compute2() {
if(H2O.get(_modelKey) != null) {
GLMModel m = H2O.get(_modelKey).get();
Key[] xvals = m.xvalModels();
if (xvals != null) {
addToPendingCount(xvals.length);
for (int i = 0; i < xvals.length; ++i)
new DeleteModelTask(this, xvals[i]).forkTask();
}
m.delete();
public void map(GLMModel m) {
Key[] xvals = m.xvalModels();
if (xvals != null) {
H2OCountedCompleter t = getCurrentTask();
t.addToPendingCount(xvals.length);
for (int i = 0; i < xvals.length; ++i)
new DeleteModelTask(t, xvals[i]).forkTask();
}
tryComplete();
m.delete();
}
}

Expand Down Expand Up @@ -308,34 +301,21 @@ public static void setSubmodel(H2OCountedCompleter cmp, Key modelKey, final doub
setSubmodel(cmp,modelKey,lambda,beta,norm_beta,iteration,runtime,sparseCoef,null);
}

public static class GetScoringModelTask extends DTask<GetScoringModelTask>{
final Key _modelKey;
final Key _jobKey;
public static class GetScoringModelTask extends DTask.DKeyTask<GetScoringModelTask,GLMModel> {
final double _lambda;
public GLMModel _res;
public GetScoringModelTask(H2OCountedCompleter cmp, Key jobKey, Key modelKey, double lambda){
super(cmp);
_jobKey = jobKey;
_modelKey = modelKey;
public GetScoringModelTask(H2OCountedCompleter cmp, Key modelKey, double lambda){
super(cmp,modelKey);
_lambda = lambda;
}
@Override
public void compute2() {
if(_modelKey.home()){
Value v = H2O.get(_modelKey);
if(v == null && _jobKey != null){
assert !Job.isRunning(_jobKey):"missing model (" + _modelKey + " ) while job is still running";
throw new Job.JobCancelledException();
} else {
_res = (GLMModel) v.get().clone();
Submodel sm = _res.submodelForLambda(_lambda);
assert sm != null : "GLM[" + _modelKey + "]: missing submodel for lambda " + _lambda;
sm = (Submodel) sm.clone();
_res.submodels = new Submodel[]{sm};
_res.setSubmodelIdx(0);
}
tryComplete();
} else new RPC(_modelKey.home_node(),this).call();
public void map(GLMModel m) {
_res = m.clone();
Submodel sm = _res.submodelForLambda(_lambda);
assert sm != null : "GLM[" + m._key + "]: missing submodel for lambda " + _lambda;
sm = (Submodel) sm.clone();
_res.submodels = new Submodel[]{sm};
_res.setSubmodelIdx(0);
}
}

Expand Down
52 changes: 38 additions & 14 deletions src/main/java/water/DTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -137,23 +137,47 @@ private RuntimeException barf(String method) {

/**
* Task to be executed at home of the given key.
* Calling submit will either submitTask if key is local or
* invoke RPC to the key's home node.
* Basically a wrapper around DTask which enables us to bypass
* remote/local distinction (RPC versus submitTask).
*/
public static abstract class DKeyTask extends DTask {
protected final Key _key;
public static abstract class DKeyTask<T extends DKeyTask,V extends Iced> extends Iced{
private transient H2OCountedCompleter _task;
public DKeyTask(final Key k) {this(null,k);}
public DKeyTask(H2OCountedCompleter cmp,final Key k) {
final DKeyTask dk = this;

public DKeyTask(H2OCountedCompleter cmp,Key k) {
super(cmp);
_key = k;
final DTask dt = new DTask(cmp) {
@Override
public void compute2() {
Value val = H2O.get(k);
if(val != null) {
V v = val.get();
dk._task = this;
dk.map(v);
}
tryComplete();
}
};
if(k.home()) _task = dt;
else {
_task = new H2OCountedCompleter() {
@Override
public void compute2() {
new RPC(k.home_node(),dt).addCompleter(this).call();
}
};
}
}
public void submitTask() {
if (_key.home()) H2O.submitTask(this);
else RPC.call(_key.home_node(), this);
}
public void forkTask() {
if (_key.home()) fork();
else RPC.call(_key.home_node(), this);
protected H2OCountedCompleter getCurrentTask(){ return _task;}
protected abstract void map(V v);
public void submitTask() {H2O.submitTask(_task);}
public void forkTask() {_task.fork();}

public T invokeTask() {
assert _task.getCompleter() == null;
submitTask();
_task.join();
return (T)this;
}
}
}
11 changes: 3 additions & 8 deletions src/main/java/water/Lockable.java
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,10 @@ private class Update extends TAtomic<Lockable> {
}
}
public static void unlock_lockable(final Key lockable, final Key job){
H2O.H2OEmptyCompleter cmp = new H2O.H2OEmptyCompleter();
new DTask.DKeyTask(cmp,lockable){
new DTask.DKeyTask<DTask.DKeyTask,Lockable>(null,lockable){
@Override
public void compute2() {
H2O.get(_key).<Lockable>get().unlock(job);
tryComplete();
}
}.forkTask();
cmp.join();
public void map(Lockable l) { l.unlock(job);}
}.invokeTask();
}
// -----------
// Atomically set a new version of self & unlock.
Expand Down
8 changes: 2 additions & 6 deletions src/main/java/water/api/GLMPredict.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,12 @@ public static String link(Key k, double lambda, String content) {
if( model == null )
throw new IllegalArgumentException("Model is required to perform validation!");
final Key predictionKey = ( prediction == null )?Key.make("__Prediction_" + Key.make()):prediction;
GLMModel.GetScoringModelTask task = new GLMModel.GetScoringModelTask(null,null, model,lambda);
H2O.submitTask(task);
task.get();
GLMModel model= task._res;
GLMModel m = new GLMModel.GetScoringModelTask(null, model,lambda).invokeTask()._res;
// Create a new random key
if ( prediction == null )
prediction = Key.make("__Prediction_" + Key.make());
Frame fr = new Frame(prediction,new String[0],new Vec[0]).delete_and_lock(null);
if( model instanceof Model )
fr = (( Model)model).score(data);
fr = m.score(data);
fr = new Frame(prediction,fr._names,fr.vecs()); // Jam in the frame key
fr.unlock(null);
return Inspect2.redirect(this, prediction.toString());
Expand Down

0 comments on commit 3fa93c4

Please sign in to comment.