diff --git a/src/main/java/hex/GLMGrid.java b/src/main/java/hex/GLMGrid.java index f7d8aeb304..f2dd57b77d 100644 --- a/src/main/java/hex/GLMGrid.java +++ b/src/main/java/hex/GLMGrid.java @@ -7,12 +7,15 @@ import hex.NewRowVecTask.JobCancelledException; import java.util.*; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; import water.*; import water.H2O.H2OCountedCompleter; import com.google.gson.JsonArray; import com.google.gson.JsonObject; +import water.util.Log; public class GLMGrid extends Job { Key _datakey; // Data to work on @@ -22,10 +25,11 @@ public class GLMGrid extends Job { double[] _ts; // Thresholds double[] _alphas; // Grid search values int _xfold; - boolean _parallel; + boolean _parallelFlag; + int _parallelism; GLMParams _glmp; - public GLMGrid(Key dest, ValueArray va, GLMParams glmp, int[] xs, double[] ls, double[] as, double[] thresholds, int xfold, boolean parallel) { + public GLMGrid(Key dest, ValueArray va, GLMParams glmp, int[] xs, double[] ls, double[] as, double[] thresholds, int xfold, boolean pflag, int par) { destination_key = dest; _ary = va; // VA is large, and already in a Key so make it transient _datakey = va._key; // ... and use the data key instead when reloading @@ -36,7 +40,8 @@ public GLMGrid(Key dest, ValueArray va, GLMParams glmp, int[] xs, double[] ls, d _ts = thresholds; _alphas = as; _xfold = xfold; - _parallel = parallel; + _parallelFlag = pflag; + _parallelism = par; _glmp.checkResponseCol(_ary._cols[xs[xs.length-1]], new ArrayList()); // ignore warnings here, they will be shown for each mdoel anyways } @@ -104,20 +109,32 @@ public void start() { UKV.put(dest(), new GLMModels(_lambdas.length * _alphas.length)); H2OCountedCompleter fjtask = new H2OCountedCompleter() { @Override public void compute2() { - if(_parallel) { + if(_parallelFlag) { final int cloudsize = H2O.CLOUD._memary.length; int myId = H2O.SELF.index(); - for( int a = 0; a < _alphas.length; a++ ) { - GridTask t = new GridTask(GLMGrid.this, a, _parallel); - int nodeId = (myId+a)%cloudsize; - if(nodeId == myId) - H2O.submitTask(t); - else + int all = 0, done = 0; + Future[] active = new GridTask[_parallelism]; + for (int job = 0; job < _alphas.length; job++) { + GridTask t = new GridTask(GLMGrid.this, job, true); + int nodeId = (myId+job)%cloudsize; + if (nodeId != myId) { RPC.call(H2O.CLOUD._memary[nodeId],t); + continue; + } + if (all - done >= _parallelism) { + try { + active[done++%_parallelism].get(); + } catch( InterruptedException e ) { + throw Log.errRTExcept(e); + } catch( ExecutionException e ) { + throw Log.errRTExcept(e); + } + } + active[all++%_parallelism] = t.fork(); } } else { for( int a = 0; a < _alphas.length; a++ ) { - GridTask t = new GridTask(GLMGrid.this, a, _parallel); + GridTask t = new GridTask(GLMGrid.this, a, false); t.compute2(); } remove(); diff --git a/src/main/java/water/api/GLMGrid.java b/src/main/java/water/api/GLMGrid.java index a232dff9ab..4b754abed2 100644 --- a/src/main/java/water/api/GLMGrid.java +++ b/src/main/java/water/api/GLMGrid.java @@ -35,6 +35,7 @@ public class GLMGrid extends Request { public static final String JSON_ROWS = "rows"; public static final String JSON_TIME = "time"; public static final String JSON_COEFFICIENTS = "coefficients"; + public static final String JSON_PARALLELISM = "parallelism"; // Need a HEX key for GLM protected final H2OHexKey _key = new H2OHexKey(KEY); @@ -62,6 +63,7 @@ public class GLMGrid extends Request { protected final RSeq _thresholds = new RSeq(Constants.DTHRESHOLDS, false, new NumberSequence("0:1:0.01",false,0.1),false); protected final Bool _parallel = new Bool(PARALLEL, true, "Build models in parallel"); + protected final Int _parallelism = new Int(JSON_PARALLELISM, 1, 256); public GLMGrid(){ _requestHelp = "Perform grid search over GLM parameters. Calls glm with all parameter combination from user-defined parameter range. Results are ordered according to AUC. For more details see GLM help."; @@ -127,7 +129,8 @@ private int[] getCols(int [] xs, int y){ _alpha.value()._arr, // Grid ranges ts, _xval.value(), - _parallel.value()); + _parallel.value(), + _parallelism.value()); job.start(); // Redirect to the grid-search status page