Add force_load_balance option to redistribute the data.
arnocandel committed Feb 25, 2014
1 parent 219905b commit 7bbde3e
Showing 7 changed files with 107 additions and 119 deletions.
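For context, a minimal sketch of how the new option is used from caller code. This is not part of the commit; the field names are taken from the diff below, and the data-loading setup (trainf) is assumed:

  NN p = new NN();
  p.source = trainf;                // training Frame, parsed elsewhere (assumed)
  p.response = trainf.lastVec();
  p.force_load_balance = true;      // new flag: redistribute chunks across the cloud, even without shuffling
  p.shuffle_training_data = false;  // rebalancing is applied when either flag is set (see reBalance in NN.java)
  p.exec();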
23 changes: 13 additions & 10 deletions h2o-samples/src/main/java/samples/NeuralNetMnist2.java
@@ -20,9 +20,10 @@ public static void main(String[] args) throws Exception {
     // samples.launchers.CloudProcess.launch(job, 4);
     //samples.launchers.CloudConnect.launch(job, "localhost:54321");
     // samples.launchers.CloudRemote.launchIPs(job, "192.168.1.171", "192.168.1.172", "192.168.1.173", "192.168.1.174", "192.168.1.175");
-    samples.launchers.CloudRemote.launchIPs(job, "192.168.1.161", "192.168.1.162", "192.168.1.163", "192.168.1.164");
-    // samples.launchers.CloudRemote.launchIPs(job, "192.168.1.161", "192.168.1.162");
-    // samples.launchers.CloudRemote.launchIPs(job, "192.168.1.161");
+    // samples.launchers.CloudRemote.launchIPs(job, "192.168.1.161", "192.168.1.162", "192.168.1.163", "192.168.1.164");
+    samples.launchers.CloudRemote.launchIPs(job, "192.168.1.161", "192.168.1.162", "192.168.1.164");
+    // samples.launchers.CloudRemote.launchIPs(job, "192.168.1.162", "192.168.1.164");
+    // samples.launchers.CloudRemote.launchIPs(job, "192.168.1.162");
     // samples.launchers.CloudRemote.launchEC2(job, 4);
   }

@@ -31,9 +32,9 @@ public static void main(String[] args) throws Exception {
     //long seed = 0xC0FFEE;
     long seed = new Random().nextLong();
     double fraction = 1.0;
-    Frame trainf = sampleFrame(TestUtil.parseFromH2OFolder("smalldata/mnist/train10x.csv"), (long)(600000*fraction), seed);
-    // Frame trainf = sampleFrame(TestUtil.parseFromH2OFolder("smalldata/mnist/train.csv"), (long)(60000*fraction), seed);
-    Frame testf = sampleFrame(TestUtil.parseFromH2OFolder("smalldata/mnist/test.csv"), (long)(10000*fraction), seed+1);
+    // Frame trainf = sampleFrame(TestUtil.parseFromH2OFolder("smalldata/mnist/train10x.csv"), (long)(600000*fraction), seed);
+    Frame trainf = sampleFrame(TestUtil.parseFromH2OFolder("smalldata/mnist/train.csv.gz"), (long)(60000*fraction), seed);
+    Frame testf = sampleFrame(TestUtil.parseFromH2OFolder("smalldata/mnist/test.csv.gz"), (long)(10000*fraction), seed+1);
     Log.info("Done.");
 
     NN p = new NN();
@@ -47,7 +48,7 @@ public static void main(String[] args) throws Exception {
     p.loss = NN.Loss.CrossEntropy;
     p.input_dropout_ratio = 0.2;
     p.max_w2 = 15;
-    p.epochs = 10;
+    p.epochs = 10000;
     p.l1 = 1e-5;
     p.l2 = 0;
     p.momentum_start = 0.5;
@@ -64,12 +65,14 @@ public static void main(String[] args) throws Exception {
     p.source = trainf;
     p.response = trainf.lastVec();
     p.ignored_cols = null;
-    p.mini_batch = 60000;
-    p.score_interval = 600;
+    p.mini_batch = 240000;
+    p.score_interval = 60;
 
     p.fast_mode = true; //to match old NeuralNet behavior
     p.ignore_const_cols = true;
-    p.shuffle_training_data = true;
+    p.shuffle_training_data = false;
+    p.force_load_balance = true;
     p.quiet_mode = false;
     return p.exec();
   }
 }
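A note on the new sample settings above: with train.csv.gz the training frame has 60000 rows, so the requested mini_batch of 240000 exceeds the row count and buildModel() (see NN.java below) clamps it to 60000. The sync fraction then works out to 60000/60000 = 1.0, i.e. the model is synchronized once per full pass over the data.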
177 changes: 79 additions & 98 deletions src/main/java/hex/nn/NN.java
@@ -10,6 +10,7 @@
 import water.api.RequestServer;
 import water.fvec.Frame;
 import water.util.Log;
+import water.util.MRUtils;
 import water.util.RString;
 
 import java.util.Random;
@@ -96,6 +97,9 @@ public class NN extends Job.ValidatedJob {
   @API(help = "Ignore constant training columns", filter = Default.class, json = true)
   public boolean ignore_const_cols = true;
 
+  @API(help = "Force load balancing to increase speed for small datasets (<200MB/node)", filter = Default.class, json = true)
+  public boolean force_load_balance = true;
+
   @API(help = "Enable periodic shuffling of training data (can increase stochastic gradient descent performance)", filter = Default.class, json = true)
   public boolean shuffle_training_data = false;
 
@@ -171,8 +175,10 @@ protected void registered(RequestServer.API_VERSION ver) {
           || arg._name.equals("mini_batch")
           || arg._name.equals("fast_mode")
           || arg._name.equals("ignore_const_cols")
+          || arg._name.equals("force_load_balance")
           || arg._name.equals("shuffle_training_data")
-          || arg._name.equals("nesterov_accelerated_gradient") || arg._name.equals("classification_stop")
+          || arg._name.equals("nesterov_accelerated_gradient")
+          || arg._name.equals("classification_stop")
           || arg._name.equals("regression_stop")
           || arg._name.equals("quiet_mode")
           ) {
@@ -240,67 +246,83 @@ void checkParams() {
   }
 
   public final NNModel initModel() {
-    checkParams();
-    lock_data();
-    final Frame train = FrameTask.DataInfo.prepareFrame(source, response, ignored_cols, classification, ignore_const_cols);
-    final DataInfo dinfo = new FrameTask.DataInfo(train, 1, true, !classification);
-    final NNModel model = new NNModel(dest(), self(), source._key, dinfo, this);
-    unlock_data();
-    model.model_info().initializeMembers();
-    return model;
+    try {
+      lock_data();
+      checkParams();
+      final Frame train = FrameTask.DataInfo.prepareFrame(source, response, ignored_cols, classification, ignore_const_cols);
+      final DataInfo dinfo = new FrameTask.DataInfo(train, 1, true, !classification);
+      final NNModel model = new NNModel(dest(), self(), source._key, dinfo, this);
+      model.model_info().initializeMembers();
+      return model;
+    }
+    finally {
+      unlock_data();
+    }
   }
 
   public final NNModel buildModel(NNModel model) {
-    lock_data();
-    logStart();
-    Log.info("Number of chunks of the training data: " + source.anyVec().nChunks());
-    if (validation != null)
-      Log.info("Number of chunks of the validation data: " + validation.anyVec().nChunks());
-    if (model == null) {
-      model = UKV.get(dest());
-    }
-    model.write_lock(self());
-    if (!quiet_mode) Log.info("Initial model:\n" + model.model_info());
-
-    final long model_size = model.model_info().size();
-    Log.info("Number of model parameters (weights/biases): " + String.format("%,d", model_size));
-    Log.info("Memory usage of the model: " + String.format("%.2f", (double)model_size*Float.SIZE / (1<<23)) + " MB.");
-
-    final Frame train = model.model_info().data_info()._adaptedFrame;
-    Frame trainScoreFrame = sampleFrame(train, score_training_samples, seed);
-
     Frame[] valid_adapted = null;
-    Frame valid = null;
-    Frame validScoreFrame = null;
-    if (validation != null) {
-      valid_adapted = model.adapt(validation, false);
-      valid = valid_adapted[0];
-      validScoreFrame = valid != validation ? sampleFrame(valid, score_validation_samples, seed+1) : null;
-    }
-
-    if (mini_batch > train.numRows()) {
-      Log.warn("Setting mini_batch (" + mini_batch
-              + ") to the number of rows of the training data (" + (mini_batch=train.numRows()) + ").");
-    }
-    // determines the number of rows processed during NNTask, affects synchronization (happens at the end of each NNTask)
-    final float sync_fraction = mini_batch == 0l ? 1.0f : (float)mini_batch / train.numRows();
+    Frame valid = null, validScoreFrame = null;
+    Frame train = null, trainScoreFrame = null;
+    try {
+      lock_data();
+      logStart();
+      if (model == null) {
+        model = UKV.get(dest());
+      }
+      model.write_lock(self());
+      final long model_size = model.model_info().size();
+      Log.info("Number of model parameters (weights/biases): " + String.format("%,d", model_size));
+      Log.info("Memory usage of the model: " + String.format("%.2f", (double)model_size*Float.SIZE / (1<<23)) + " MB.");
+      train = reBalance(model.model_info().data_info()._adaptedFrame, seed);
+      trainScoreFrame = sampleFrame(train, score_training_samples, seed);
+      Log.info("Number of chunks of the training data: " + train.anyVec().nChunks());
+      if (validation != null) {
+        valid_adapted = model.adapt(validation, false);
+        valid = reBalance(valid_adapted[0], seed+1);
+        validScoreFrame = sampleFrame(valid, score_validation_samples, seed+1);
+        Log.info("Number of chunks of the validation data: " + valid.anyVec().nChunks());
+      }
+      if (mini_batch > train.numRows()) {
+        Log.warn("Setting mini_batch (" + mini_batch
+                + ") to the number of rows of the training data (" + (mini_batch=train.numRows()) + ").");
+      }
+      // determines the number of rows processed during NNTask, affects synchronization (happens at the end of each NNTask)
+      final float sync_fraction = mini_batch == 0l ? 1.0f : (float)mini_batch / train.numRows();
 
-    Log.info("Starting to train the Neural Net model.");
-    long timeStart = System.currentTimeMillis();
+      if (!quiet_mode) Log.info("Initial model:\n" + model.model_info());
 
-    //main loop
-    do model.set_model_info(new NNTask(model.model_info(), true /*train*/, sync_fraction).doAll(train).model_info());
-    while (model.doScoring(trainScoreFrame, validScoreFrame, timeStart, self()));
-    model.unlock(self());
+      Log.info("Starting to train the Neural Net model.");
+      long timeStart = System.currentTimeMillis();
 
-    //clean up
-    if (validScoreFrame != null && validScoreFrame != valid) validScoreFrame.delete();
-    if (trainScoreFrame != null && trainScoreFrame != train) trainScoreFrame.delete();
-    if (validation != null) valid_adapted[1].delete(); //just deleted the adapted frames for validation
-    unlock_data();
-    Log.info("Finished training the Neural Net model.");
-    return model;
+      //main loop
+      long iter = 0;
+      Frame newtrain = new Frame(train);
+      do {
+        model.set_model_info(new NNTask(model.model_info(), sync_fraction).doAll(newtrain).model_info());
+        if (++iter % 10 != 0 && shuffle_training_data) {
+          Frame newtrain2 = reBalance(newtrain, seed+iter);
+          if (newtrain != newtrain2) {
+            newtrain.delete();
+            newtrain = newtrain2;
+            trainScoreFrame = sampleFrame(newtrain, score_training_samples, seed+iter+0xDADDAAAA);
+          }
+        }
+      }
+      while (model.doScoring(trainScoreFrame, validScoreFrame, timeStart, self()));
+
+      Log.info("Finished training the Neural Net model.");
+      return model;
+    }
+    finally {
+      model.unlock(self());
+      //clean up
+      if (validScoreFrame != null && validScoreFrame != valid) validScoreFrame.delete();
+      if (trainScoreFrame != null && trainScoreFrame != train) trainScoreFrame.delete();
+      if (validation != null) valid_adapted[1].delete(); //just deleted the adapted frames for validation
+//      if (_newsource != null && _newsource != source) _newsource.delete();
+      unlock_data();
+    }
   }
 
   private void lock_data() {
@@ -321,49 +343,8 @@ public void delete() {
     remove();
   }
 
-  /*
-  long _iter = 0;
-  private void reBalance(Frame fr) {
-    shuffleAndBalance(fr, seed+_iter++, shuffle_training_data);
-    fr.reloadVecs();
-    Log.info("Number of chunks of " + fr.toString() + ": " + fr.anyVec().nChunks());
+  private Frame reBalance(final Frame fr, long seed) {
+    return force_load_balance || shuffle_training_data ? MRUtils.shuffleAndBalance(fr, seed, shuffle_training_data) : fr;
   }
 
-  // master node collects all rows, and distributes them across the cluster - slow
-  private static void shuffleAndBalance(Frame fr, long seed, final boolean shuffle) {
-    int cores = 0;
-    for( H2ONode node : H2O.CLOUD._memary )
-      cores += node._heartbeat._num_cpus;
-    final int splits = cores;
-    long[] idx = null;
-    if (shuffle) {
-      idx = new long[(int)fr.numRows()]; //HACK: int instead of of long
-      for (int r=0; r<idx.length; ++r) idx[r] = r;
-      Utils.shuffleArray(idx, seed);
-    }
-    Vec[] vecs = fr.vecs();
-    if( vecs[0].nChunks() < splits || shuffle ) {
-      Key keys[] = new Vec.VectorGroup().addVecs(vecs.length);
-      for( int v = 0; v < vecs.length; v++ ) {
-        AppendableVec vec = new AppendableVec(keys[v]);
-        final long rows = fr.numRows();
-        for( int split = 0; split < splits; split++ ) {
-          long off = rows * split / splits;
-          long lim = rows * (split + 1) / splits;
-          NewChunk chunk = new NewChunk(vec, split);
-          for( long r = off; r < lim; r++ ) {
-            if (shuffle) chunk.addNum(fr.vecs()[v].at(idx[(int)r]));
-            else chunk.addNum(fr.vecs()[v].at(r));
-          }
-          chunk.close(split, null);
-        }
-        Vec t = vec.close(null);
-        t._domain = vecs[v]._domain;
-        vecs[v] = t;
-      }
-    }
-  }
-  */
 }
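The new reBalance() helper returns the input Frame unchanged unless force_load_balance or shuffle_training_data is set, so callers can apply it unconditionally and use an identity check to detect whether a redistributed copy was created. A minimal caller-side sketch of that pattern (illustrative only, mirroring the main loop in buildModel() above):

  Frame balanced = reBalance(fr, seed);  // may return fr itself
  if (balanced != fr) {
    // a new, redistributed Frame was created; the caller owns it
    // and must delete() it later, as buildModel() does for newtrain
  }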
3 changes: 1 addition & 2 deletions src/main/java/hex/nn/NNModel.java
@@ -191,8 +191,7 @@ public long size() {
     // package local helpers
     final int[] units; //number of neurons per layer, extracted from parameters and from datainfo
 
-    // public NNModelInfo(NN params, int num_input, int num_output) {
-    public NNModelInfo(NN params, DataInfo dinfo) {
+    public NNModelInfo(final NN params, final DataInfo dinfo) {
       data_info = dinfo; //should be deep_clone()?
       final int num_input = dinfo.fullN();
       final int num_output = params.classification ? dinfo._adaptedFrame.lastVec().domain().length : 1;
6 changes: 3 additions & 3 deletions src/main/java/hex/nn/NNTask.java
@@ -18,10 +18,10 @@ public class NNTask extends FrameTask<NNTask> {
 
   int _chunk_node_count = 1;
 
-  public NNTask(NNModel.NNModelInfo input, boolean training, float fraction){this(input,training,fraction,null);}
-  private NNTask(NNModel.NNModelInfo input, boolean training, float fraction, H2OCountedCompleter cmp){
+  public NNTask(NNModel.NNModelInfo input, float fraction){this(input,fraction,null);}
+  private NNTask(NNModel.NNModelInfo input, float fraction, H2OCountedCompleter cmp){
     super(input.job(),input.data_info(),cmp);
-    _training=training;
+    _training=true;
     _input=input;
     _useFraction=fraction;
     _shuffle = _input.get_params().shuffle_training_data;
1 change: 1 addition & 0 deletions src/test/java/hex/NNvsNeuralNet.java
@@ -178,6 +178,7 @@ else if (Math.abs(a - b) <= abseps) {
     p.shuffle_training_data = false; //same as old NeuralNet code
     p.nesterov_accelerated_gradient = true; //same as old NeuralNet code
     p.classification_stop = -1; //don't stop early -> need to compare against old NeuralNet code, which doesn't stop either
+    p.force_load_balance = false; //keep 1 chunk for reproducibility
     p.exec();
 
     mymodel = UKV.get(p.dest());
1 change: 1 addition & 0 deletions src/test/java/hex/NeuralNetIrisTest2.java
@@ -159,6 +159,7 @@ else if (Math.abs(a - b) <= abseps) {
     p.ignore_const_cols = false;
     p.shuffle_training_data = false;
     p.classification_stop = -1; //don't stop early -> need to compare against reference, which doesn't stop either
+    p.force_load_balance = false; //keep just 1 chunk for reproducibility
     NNModel mymodel = p.initModel(); //randomize weights, but don't start training yet
 
     Neurons[] neurons = NNTask.makeNeuronsForTraining(mymodel.model_info());
15 changes: 9 additions & 6 deletions src/test/java/hex/NeuralNetSpiralsTest2.java
@@ -14,6 +14,8 @@
 import water.fvec.ParseDataset2;
 import water.util.Log;
 
+import java.util.Random;
+
 public class NeuralNetSpiralsTest2 extends TestUtil {
   @BeforeClass public static void stall() {
     stall_till_cloudsize(JUnitRunnerDebug.NODES);
@@ -28,10 +30,10 @@ public class NeuralNetSpiralsTest2 extends TestUtil {
     // build the model
     {
       NN p = new NN();
-      p.seed = 7401699394609084302l;
+      p.seed = new Random().nextLong();
       p.rate = 0.007;
       p.rate_annealing = 0;
-      p.epochs = 11000;
+      p.epochs = 20000;
       p.hidden = new int[]{100};
       p.activation = NN.Activation.Tanh;
       p.max_w2 = Double.MAX_VALUE;
@@ -50,15 +52,16 @@ public class NeuralNetSpiralsTest2 extends TestUtil {
       p.ignored_cols = null;
       p.mini_batch = 0; //sync once per period
       p.quiet_mode = true;
-      p.fast_mode = true; //same as old NeuralNet code
-      p.ignore_const_cols = false; //same as old NeuralNet code
-      p.shuffle_training_data = false; //same as old NeuralNet code
-      p.nesterov_accelerated_gradient = true; //same as old NeuralNet code
+      p.fast_mode = true;
+      p.ignore_const_cols = true;
+      p.nesterov_accelerated_gradient = true;
       p.classification = true;
       p.diagnostics = true;
      p.expert_mode = true;
       p.score_training_samples = 1000;
       p.score_validation_samples = 10000;
+      p.shuffle_training_data = true;
+      p.force_load_balance = true; //make it multi-threaded
       p.destination_key = dest;
       p.exec();
     }
