Skip to content

Commit

Permalink
Handle cancel across the board
Browse files Browse the repository at this point in the history
Rename the cancel function, which was confusing... and leading people to
use the wrong cancel function, which didn't actually test for the job
being canceled (except for the API driver code).
  • Loading branch information
cliffclick committed Jan 31, 2014
1 parent ab7088f commit c657cdd
Show file tree
Hide file tree
Showing 23 changed files with 52 additions and 48 deletions.
2 changes: 1 addition & 1 deletion h2o-samples/src/main/java/samples/NeuralNetMnist.java
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ protected void startTraining(Layer[] ls) {
final AtomicInteger evals = new AtomicInteger(1);
timer.schedule(new TimerTask() {
@Override public void run() {
if( NeuralNetMnist.this.cancelled() )
if( !Job.isRunning(self()) )
timer.cancel();
else {
double time = (System.nanoTime() - start) / 1e9;
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/hex/DGLM.java
Original file line number Diff line number Diff line change
Expand Up @@ -1253,7 +1253,7 @@ public GLMValidation xvalidate(Job job, ValueArray ary, int folds, double[] thre
tsk.reduce(child);
}
}
if( job.cancelled() ) throw new JobCancelledException();
if( !Job.isRunning(job.self()) ) throw new JobCancelledException();
GLMValidation res = new GLMValidation(_key, tsk._models, ErrMetric.SUMC, thresholds, System.currentTimeMillis() - t1);
if( _vals == null ) _vals = new GLMValidation[] { res };
else {
Expand Down Expand Up @@ -1781,7 +1781,7 @@ public GLMValidationFunc(GLMModel m, GLMParams params, double[] beta, double[] t
}
NewRowVecTask<GLMValidation> tsk = new NewRowVecTask<GLMValidation>(job, this, data);
tsk.invoke(data._ary._key);
if( job != null && job.cancelled() ) throw new JobCancelledException();
if( job != null && !Job.isRunning(job.self()) ) throw new JobCancelledException();
GLMValidation res = tsk._result;
res._time = System.currentTimeMillis() - t1;
if( _glmp._family._family != Family.binomial ) res._err = Math.sqrt(res._err / res._n);
Expand Down Expand Up @@ -1915,7 +1915,7 @@ public static GLMJob startGLMJob(Key dest, final DataFrame data, final LSMSolver
@Override public void compute2() {
try {
buildModel(job, job.dest(), data, lsm, params, beta, xval, parallel);
assert !job.cancelled();
assert Job.isRunning(job.self());
job.remove();
} catch( JobCancelledException e ) {
Lockable.delete(job.dest());
Expand Down Expand Up @@ -1986,7 +1986,7 @@ private static GLMModel buildModel(Job job, Key resKey, DataFrame data, LSMSolve
currentModel.delete_and_lock(job.self()); // Lock the new model
if( params._family._family != Family.gaussian ) do { // IRLSM
if( oldBeta == null ) oldBeta = MemoryManager.malloc8d(data.expandedSz());
if( job.cancelled() ) throw new JobCancelledException();
if( !Job.isRunning(job.self()) ) throw new JobCancelledException();
double[] b = oldBeta;
oldBeta = (gramF._beta = newBeta);
newBeta = b;
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/hex/FrameTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ protected void chunkDone(){}
* and adapts response according to the CaseMode/CaseValue if set.
*/
@Override public final void map(Chunk [] chunks, NewChunk [] outputs){
if(_job != null && _job.cancelled())throw new JobCancelledException();
if(_job != null && !Job.isRunning(_job.self()))throw new JobCancelledException();
chunkInit();
final int nrows = chunks[0]._len;
double [] nums = MemoryManager.malloc8d(_dinfo._nums);
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/hex/GLMGrid.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public void compute2() {
Futures fs = new Futures();
ValueArray ary = DKV.get(_aryKey).get();
try{
for( int l1 = 1; l1 <= _job._lambdas.length && !_job.cancelled(); l1++ ) {
for( int l1 = 1; l1 <= _job._lambdas.length && Job.isRunning(_job.self()); l1++ ) {
Key mkey = GLMModel.makeKey(false);
GLMModel m = DGLM.buildModel(_job,mkey,ary, _job._xs, _standardize, new ADMMSolver(_job._lambdas[N-l1], _job._alphas[_aidx]), _job._glmp,beta,_job._xfold, _parallel);
beta = m._normBeta.clone();
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/hex/GridSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public class GridSearch extends Job {
UKV.put(destination_key, this);
int max = jobs[0].gridParallelism();
int head = 0, tail = 0;
while( head < jobs.length && !cancelled() ) {
while( head < jobs.length && isRunning(self()) ) {
if( tail - head < max && tail < jobs.length )
jobs[tail++].fork();
else {
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/hex/KMeans.java
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ private void run(KMeansModel res, ValueArray va, int k, Initialization init, int
sampler.invoke(va._key);
clusters = Utils.append(clusters, sampler._clust2);

if( cancelled() ) {
if( !isRunning(self()) ) {
remove();
return;
}
Expand Down Expand Up @@ -140,7 +140,7 @@ private void run(KMeansModel res, ValueArray va, int k, Initialization init, int
res.update(self());
if( res._iteration >= res._maxIter )
break;
if( cancelled() )
if( !isRunning(self()) )
break;
}
res.unlock(self());
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/hex/KMeans2.java
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public KMeans2() {
sampler.doAll(vecs);
clusters = Utils.append(clusters, sampler._sampled);

if( cancelled() )
if( !isRunning(self()) )
return Status.Done;
model.centers = normalize ? denormalize(clusters, vecs) : clusters;
model.total_within_SS = sqr._sqr;
Expand Down Expand Up @@ -149,7 +149,7 @@ public KMeans2() {
fr2.delete_and_lock(self()).unlock(self());
break;
}
if( cancelled() )
if( !isRunning(self()) )
break;
}
model.unlock(self());
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/hex/KMeansModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ public static Job run(final Key dest, final KMeansModel model, final ValueArray
*/
@Override public void map(Key key) {
assert key.home();
if( !_job.cancelled() ) {
if( Job.isRunning(_job.self()) ) {
ValueArray va = DKV.get(_arykey).get();
AutoBuffer bits = va.getChunk(key);
long startRow = va.startRow(ValueArray.getChunkIndex(key));
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/hex/NeuralNet.java
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ void startTrain() {

if (num >= num_samples_total) break;
if (mode != MapReduce) {
if (cancelled() || !running) break;
if (!isRunning(self()) || !running) break;
} else {
if (!running) break; //MapReduce calls cancel() early, we are waiting for running = false
}
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/hex/NewRowVecTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public static abstract class RowFunc<T extends Iced> extends Iced {
public T apply(Job j, DataFrame data) throws JobCancelledException{
NewRowVecTask<T> tsk = new NewRowVecTask<T>(j,this, data);
tsk.invoke(data._ary._key);
if(j != null && j.cancelled())throw new JobCancelledException();
if(j != null && !Job.isRunning(j.self()))throw new JobCancelledException();
return tsk._result;
}

Expand Down Expand Up @@ -213,7 +213,7 @@ public long memOverheadPerChunk(){

@Override
public void map(Key key) {
if(_job != null && _job.cancelled())return;
if(_job != null && !Job.isRunning(_job.self()))return;
T result = _func.newResult();
Sampling s = _data.getSampling();
AutoBuffer bits = _data._ary.getChunk(key);
Expand Down Expand Up @@ -274,7 +274,7 @@ public void map(Key key) {
}

@Override public void reduce(DRemoteTask drt) {
if(_job != null && _job.cancelled()) return;
if(_job != null && !Job.isRunning(_job.self())) return;
NewRowVecTask<T> rv = (NewRowVecTask<T>)drt;
assert _result != rv._result;
_result = (_result != null)?_func.reduce(_result, rv._result):rv._result;
Expand Down
10 changes: 5 additions & 5 deletions src/main/java/hex/Trainer.java
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public void run() {
for( ; _limit == 0 || _processed < _limit; _processed++ ) {
step();
input.move();
if( _job != null && (Job.cancelled(_job) || !NeuralNet.running ) )
if( _job != null && (!Job.isRunning(_job) || !NeuralNet.running ) )
break;
}
}
Expand Down Expand Up @@ -169,7 +169,7 @@ public Threaded(Layer[] ls, double epochs, final Key job, int threads) {
_threads[t] = new Thread("H2O Trainer " + t) {
@Override public void run() {
for( long i = 0; _stepsPerThread == 0 || i < _stepsPerThread; i++ ) {
if( job != null && (Job.cancelled(job) || !NeuralNet.running ) )
if( job != null && (!Job.isRunning(job) || !NeuralNet.running ) )
break;
try {
trainer.step();
Expand Down Expand Up @@ -326,7 +326,7 @@ static class Descent extends MRTask2<Descent> {
final boolean home = _key.home();
Thread thread = new Thread() {
@Override public void run() {
while( _job == null || !Job.cancelled(_job) ) {
while( _job == null || Job.isRunning(_job) ) {
if( !home )
_node.sync();
else {
Expand Down Expand Up @@ -377,7 +377,7 @@ private static class DescentEpoch extends NodeTask {
int _count;

@Override public void compute2() {
if( (_count < 0 || --_count >= 0) && (_node._job == null || !Job.cancelled(_node._job)) ) {
if( (_count < 0 || --_count >= 0) && (_node._job == null || Job.isRunning(_node._job)) ) {
for( Chunk[] cs : _node._chunks ) {
DescentChunk task = new DescentChunk();
task._node = _node;
Expand All @@ -397,7 +397,7 @@ static class DescentChunk extends NodeTask {
Chunk[] _cs;

@Override public void compute2() {
if( _node._job == null || (!Job.cancelled(_node._job) && NeuralNet.running)) {
if( _node._job == null || (Job.isRunning(_node._job) && NeuralNet.running)) {
Layer[] clones = new Layer[_node._ls.length];
ChunksInput input = new ChunksInput(Utils.remove(_cs, _cs.length - 1), (VecsInput) _node._ls[0]);
clones[0] = input;
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/hex/drf/DRF.java
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ public static String link(Key k, String content) {
// TODO: parallelize more? build more than k trees at each time, we need to care about temporary data
// Idea: launch more DRF at once.
ktrees = buildNextKTrees(fr,_mtry,sample_rate,rand);
if( cancelled() ) break; // If canceled during building, do not bulkscore
if( !Job.isRunning(self()) ) break; // If canceled during building, do not bulkscore

// Check latest predictions
tstats.updateBy(ktrees);
Expand Down Expand Up @@ -303,7 +303,7 @@ private DTree[] buildNextKTrees(Frame fr, int mtrys, float sample_rate, Random r
// Adds a layer to the trees each pass.
int depth=0;
for( ; depth<max_depth; depth++ ) {
if( cancelled() ) return null;
if( !Job.isRunning(self()) ) return null;

hcs = buildLayer(fr, ktrees, leafs, hcs, true, build_tree_per_node);

Expand Down
4 changes: 2 additions & 2 deletions src/main/java/hex/gbm/GBM.java
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ public static String link(Key k, String content) {

// ESL2, page 387, Step 2b ii, iii, iv
ktrees = buildNextKTrees(fr);
if( cancelled() ) break; // If canceled during building, do not bulkscore
if( !Job.isRunning(self()) ) break; // If canceled during building, do not bulkscore

// Check latest predictions
tstats.updateBy(ktrees);
Expand Down Expand Up @@ -282,7 +282,7 @@ private DTree[] buildNextKTrees(Frame fr) {
// Adds a layer to the trees each pass.
int depth=0;
for( ; depth<max_depth; depth++ ) {
if( cancelled() ) return null;
if( !Job.isRunning(self()) ) return null;

hcs = buildLayer(fr, ktrees, leafs, hcs, false, false);

Expand Down
10 changes: 6 additions & 4 deletions src/main/java/hex/glm/GLM2.java
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ public Iteration(GLMModel model, LSMSolver solver, DataInfo dinfo, H2OCountedCom

@Override public Iteration clone(){return new Iteration(_model,_solver,_dinfo,_fjt);}
@Override public void callback(final GLMIterationTask glmt) {
if(!cancelled()){
if(isRunning(self())){
double [] newBeta = MemoryManager.malloc8d(glmt._xy.length);
double [] newBetaDeNorm = null;
_solver.solve(glmt._gram, glmt._xy, glmt._yy, newBeta);
Expand All @@ -255,19 +255,21 @@ public Iteration(GLMModel model, LSMSolver solver, DataInfo dinfo, H2OCountedCom
}
boolean done = false;
// _model = _oldModel.clone();
done = done || _glm.family == Family.gaussian || (glmt._iter+1) == max_iter || beta_diff(glmt._beta, newBeta) < beta_epsilon || cancelled();
done = done || _glm.family == Family.gaussian || (glmt._iter+1) == max_iter || beta_diff(glmt._beta, newBeta) < beta_epsilon || !isRunning(self());
_model.setLambdaSubmodel(_lambdaIdx,newBetaDeNorm == null?newBeta:newBetaDeNorm, newBetaDeNorm==null?null:newBeta, glmt._iter+1);
if(done){
H2OCallback fin = new H2OCallback<GLMValidationTask>() {
@Override public void callback(GLMValidationTask tsk) {
boolean improved = _model.setAndTestValidation(_lambdaIdx,tsk._res);
_model.unlock(self());
if(!diverged && (improved || _runAllLambdas) && _lambdaIdx < (lambda.length-1) ){ // continue with next lambda value?
_model.update(self());
_solver = new ADMMSolver(lambda[++_lambdaIdx],alpha[0]);
glmt._val = null;
Iteration.this.callback(glmt);
} else // nope, we're done
} else { // nope, we're done
_model.unlock(self());
_fjt.tryComplete(); // signal we're done to anyone waiting for the job
}
}
@Override public boolean onExceptionalCompletion(Throwable ex, CountedCompleter cc){
_fjt.completeExceptionally(ex);
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/hex/glm/GLMModelView.java
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ private void pprintTime(StringBuilder sb, long t){
glm_model = v.get();
if(Double.isNaN(lambda))lambda = glm_model.lambdas[glm_model.best_lambda_idx];
}
if( jjob == null || jjob.end_time > 0 || jjob.cancelled() )
if( jjob == null || jjob.end_time > 0 || jjob.isCancelled() )
return Response.done(this);
return Response.poll(this,(int)(100*jjob.progress()),100,"_modelKey",_modelKey.toString());
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/hex/rf/Tree.java
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ private StringBuffer computeStatistics() {

// Actually build the tree
@Override public void compute2() {
if(!_job.cancelled()) {
if(Job.isRunning(_job.self())) {
Timer timer = new Timer();
_stats[0] = new ThreadLocal<Statistic>();
_stats[1] = new ThreadLocal<Statistic>();
Expand Down
16 changes: 9 additions & 7 deletions src/main/java/water/Job.java
Original file line number Diff line number Diff line change
Expand Up @@ -398,13 +398,15 @@ public void cancel(final String msg) {
}
}.invoke(LIST);
}

protected void onCancelled() {
}
public boolean cancelled() { return end_time == CANCELLED_END_TIME; }
public static boolean cancelled(Key key) {
return DKV.get(key) == null;
}
// This querys the *current object* for its status.
// Only valid if you have a Job object that is being updated by somebody.
public boolean isCancelled() { return end_time == CANCELLED_END_TIME; }

// Check the K/V store to see the Job is still running
public static boolean isRunning(Key key) { return DKV.get(key) != null; }

public void remove() {
end_time = System.currentTimeMillis();
Expand Down Expand Up @@ -563,7 +565,7 @@ public static boolean isJobEnded(Key jobkey) {
done = true;
}

if (jobs[i].cancelled()) {
if (jobs[i].isCancelled()) {
done = true;
}

Expand Down Expand Up @@ -663,7 +665,7 @@ public ChunkProgressJob(long chunksTotal, Key destinationKey) {
}

public void updateProgress(final int c) { // c == number of processed chunks
if( !cancelled() ) {
if( isRunning(self()) ) {
new TAtomic<ChunkProgress>() {
@Override public ChunkProgress atomic(ChunkProgress old) {
if( old == null ) return null;
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/water/ValueArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ static private Futures readPut(Key key, InputStream is, Job job, final Futures f
off+=sz;
szl += off;
if( off<CHUNK_SZ ) break;
if( job != null && job.cancelled() ) break;
if( job != null && !Job.isRunning(job.self()) ) break;
final Key ckey = getChunkKey(cidx++,key);
final Value val = new Value(ckey,buf);
// Do the 'DKV.put' in a F/J task. For multi-JVM setups, this step often
Expand All @@ -420,7 +420,7 @@ static private Futures readPut(Key key, InputStream is, Job job, final Futures f
dkv_fs.add(subtask);
f_last = subtask;
}
assert is.read(new byte[1]) == -1 || job.cancelled();
assert is.read(new byte[1]) == -1 || !Job.isRunning(job.self());

// Last chunk is short, read it; combine buffers and make the last chunk larger
if( cidx > 0 ) {
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/water/api/Jobs.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public String caption(JsonArray array, String name) {
@Override
public String elementToString(JsonElement elm, String contextName) {
String html;
if( Job.cancelled(Key.make(elm.getAsString())) )
if( !Job.isRunning(Key.make(elm.getAsString())) )
html = "<button disabled class='btn btn-mini'>X</button>";
else {
String keyParam = KEY + "=" + elm.getAsString();
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/water/api/Progress.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ protected Response serve() {
Job job = findJob();
JsonObject jsonResponse = defaultJsonResponse();

if(job != null && job.cancelled()){
if( job != null && job.isCancelled() ) {
String msg = job.exception == null?"Job was cancelled by user!":job.exception;
return Response.error(msg);
}
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/water/api/Progress2.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ public static Response redirect(Request req, Key jobkey, Key dest) {
Job jjob = null;
if( job_key != null )
jjob = Job.findJob(job_key);
if( jjob != null && jjob.cancelled())
if( jjob != null && jjob.isCancelled())
return Response.error(jjob.exception == null ? "Job was cancelled by user!" : jjob.exception);
if( jjob == null || jjob.end_time > 0 || jjob.cancelled() )
if( jjob == null || jjob.end_time > 0 || jjob.isCancelled() )
return jobDone(jjob, destination_key);
return jobInProgress(jjob, destination_key);
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/water/api/RFView.java
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ protected Response jobDone(JsonObject jsonResp) {
// Handle cancelled/aborted jobs
if (_job.value()!=null) {
Job jjob = Job.findJob(_job.value());
if (jjob!=null && jjob.cancelled())
if (jjob!=null && jjob.isCancelled())
return Response.error(jjob.exception == null ? "Job was cancelled by user!" : jjob.exception);
}

Expand Down
4 changes: 2 additions & 2 deletions src/main/java/water/parser/DParseTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ public void phaseTwoInitialize() {
* splitting it into equal sized chunks.
*/
@Override public void map(Key key) {
if(_job.cancelled())
if(!Job.isRunning(_job.self()))
throw new JobCancelledException();
_map = true;
Key aryKey = null;
Expand Down Expand Up @@ -583,7 +583,7 @@ public void phaseTwoInitialize() {

@Override
public void reduce(DParseTask dpt) {
if(_job.cancelled())
if(!Job.isRunning(_job.self()))
return;
assert dpt._map;
if(_sigma == null)_sigma = dpt._sigma;
Expand Down

0 comments on commit c657cdd

Please sign in to comment.