Skip to content

Commit

Permalink
controlled issuing of parallel grid search jobs.
Browse files Browse the repository at this point in the history
  • Loading branch information
tongxin committed Nov 15, 2013
1 parent 68960bc commit c98af8c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 12 deletions.
39 changes: 28 additions & 11 deletions src/main/java/hex/GLMGrid.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
import hex.NewRowVecTask.JobCancelledException;

import java.util.*;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;

import water.*;
import water.H2O.H2OCountedCompleter;

import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import water.util.Log;

public class GLMGrid extends Job {
Key _datakey; // Data to work on
Expand All @@ -22,10 +25,11 @@ public class GLMGrid extends Job {
double[] _ts; // Thresholds
double[] _alphas; // Grid search values
int _xfold;
boolean _parallel;
boolean _parallelFlag;
int _parallelism;
GLMParams _glmp;

public GLMGrid(Key dest, ValueArray va, GLMParams glmp, int[] xs, double[] ls, double[] as, double[] thresholds, int xfold, boolean parallel) {
public GLMGrid(Key dest, ValueArray va, GLMParams glmp, int[] xs, double[] ls, double[] as, double[] thresholds, int xfold, boolean pflag, int par) {
destination_key = dest;
_ary = va; // VA is large, and already in a Key so make it transient
_datakey = va._key; // ... and use the data key instead when reloading
Expand All @@ -36,7 +40,8 @@ public GLMGrid(Key dest, ValueArray va, GLMParams glmp, int[] xs, double[] ls, d
_ts = thresholds;
_alphas = as;
_xfold = xfold;
_parallel = parallel;
_parallelFlag = pflag;
_parallelism = par;
_glmp.checkResponseCol(_ary._cols[xs[xs.length-1]], new ArrayList<String>()); // ignore warnings here, they will be shown for each model anyways
}

Expand Down Expand Up @@ -104,20 +109,32 @@ public void start() {
UKV.put(dest(), new GLMModels(_lambdas.length * _alphas.length));
H2OCountedCompleter fjtask = new H2OCountedCompleter() {
@Override public void compute2() {
if(_parallel) {
if(_parallelFlag) {
final int cloudsize = H2O.CLOUD._memary.length;
int myId = H2O.SELF.index();
for( int a = 0; a < _alphas.length; a++ ) {
GridTask t = new GridTask(GLMGrid.this, a, _parallel);
int nodeId = (myId+a)%cloudsize;
if(nodeId == myId)
H2O.submitTask(t);
else
int all = 0, done = 0;
Future[] active = new GridTask[_parallelism];
for (int job = 0; job < _alphas.length; job++) {
GridTask t = new GridTask(GLMGrid.this, job, true);
int nodeId = (myId+job)%cloudsize;
if (nodeId != myId) {
RPC.call(H2O.CLOUD._memary[nodeId],t);
continue;
}
if (all - done >= _parallelism) {
try {
active[done++%_parallelism].get();
} catch( InterruptedException e ) {
throw Log.errRTExcept(e);
} catch( ExecutionException e ) {
throw Log.errRTExcept(e);
}
}
active[all++%_parallelism] = t.fork();
}
} else {
for( int a = 0; a < _alphas.length; a++ ) {
GridTask t = new GridTask(GLMGrid.this, a, _parallel);
GridTask t = new GridTask(GLMGrid.this, a, false);
t.compute2();
}
remove();
Expand Down
5 changes: 4 additions & 1 deletion src/main/java/water/api/GLMGrid.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class GLMGrid extends Request {
public static final String JSON_ROWS = "rows";
public static final String JSON_TIME = "time";
public static final String JSON_COEFFICIENTS = "coefficients";
public static final String JSON_PARALLELISM = "parallelism";

// Need a HEX key for GLM
protected final H2OHexKey _key = new H2OHexKey(KEY);
Expand Down Expand Up @@ -62,6 +63,7 @@ public class GLMGrid extends Request {
protected final RSeq _thresholds = new RSeq(Constants.DTHRESHOLDS, false, new NumberSequence("0:1:0.01",false,0.1),false);

protected final Bool _parallel = new Bool(PARALLEL, true, "Build models in parallel");
protected final Int _parallelism = new Int(JSON_PARALLELISM, 1, 256);

public GLMGrid(){
_requestHelp = "Perform grid search over GLM parameters. Calls glm with all parameter combination from user-defined parameter range. Results are ordered according to AUC. For more details see <a href='GLM.help'>GLM help</a>.";
Expand Down Expand Up @@ -127,7 +129,8 @@ private int[] getCols(int [] xs, int y){
_alpha.value()._arr, // Grid ranges
ts,
_xval.value(),
_parallel.value());
_parallel.value(),
_parallelism.value());
job.start();

// Redirect to the grid-search status page
Expand Down

0 comments on commit c98af8c

Please sign in to comment.