Skip to content

Commit

Permalink
Add support for desired class distribution when balance_classes is en…
Browse files Browse the repository at this point in the history
…abled.
  • Loading branch information
arnocandel committed Sep 12, 2014
1 parent 268f51e commit 77c1912
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions src/main/java/hex/deeplearning/DeepLearning.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package hex.deeplearning;

import com.amazonaws.services.cloudfront.model.InvalidArgumentException;
import hex.*;
import water.*;
import water.util.*;
Expand Down Expand Up @@ -384,6 +385,12 @@ public class DeepLearning extends Job.ValidatedJob {
@API(help = "Balance training data class counts via over/under-sampling (for imbalanced data)", filter = Default.class, json = true, importance = ParamImportance.EXPERT)
public boolean balance_classes = false;

/**
* Desired relative class ratios of the training data after over/under-sampling.
* Only used when balance_classes is enabled (and the model is a classifier).
* Default is 1/N for each class, where N is the number of classes; leave this
* unset (null) to get that default.
* When set, the array must have exactly one entry per class, i.e. its length must
* equal the number of levels in the domain of the training frame's last column;
* otherwise an IllegalArgumentException is thrown when training starts.
*/
@API(help = "Desired relative class ratios of the training data after over/under-sampling.", filter = Default.class, dmin = 1, json = true, importance = ParamImportance.SECONDARY)
public float[] balance_class_model_distribution;

/**
* When classes are balanced, limit the resulting dataset size to the
* specified multiple of the original dataset size.
Expand Down Expand Up @@ -1004,9 +1011,13 @@ public final DeepLearningModel trainModel(DeepLearningModel model) {
if (!quiet_mode) Log.info("Number of model parameters (weights/biases): " + String.format("%,d", model_size));
train = model.model_info().data_info()._adaptedFrame;
if (mp.force_load_balance) train = updateFrame(train, reBalance(train, mp.replicate_training_data /*rebalance into only 4*cores per node*/));
float[] trainSamplingFactors;
if (mp.classification && mp.balance_classes) {
trainSamplingFactors = new float[train.lastVec().domain().length]; //leave initialized to 0 -> will be filled up below
float[] trainSamplingFactors = new float[train.lastVec().domain().length]; //leave initialized to 0 -> will be filled up below
if (balance_class_model_distribution != null) {
if (balance_class_model_distribution.length != train.lastVec().domain().length)
throw new IllegalArgumentException("balance_class_model_distribution must have " + train.lastVec().domain().length + " elements");
trainSamplingFactors = balance_class_model_distribution;
}
train = updateFrame(train, sampleFrameStratified(
train, train.lastVec(), trainSamplingFactors, (long)(mp.max_after_balance_size*train.numRows()), mp.seed, true, false));
model.setModelClassDistribution(new MRUtils.ClassDist(train.lastVec()).doAll(train.lastVec()).rel_dist());
Expand Down Expand Up @@ -1049,6 +1060,7 @@ public final DeepLearningModel trainModel(DeepLearningModel model) {
model.update(self());
Log.info("Starting to train the Deep Learning model.");

if (n_folds == 0 || xval_models == null)
//main loop
do model.set_model_info(H2O.CLOUD.size() > 1 && mp.replicate_training_data ? ( mp.single_node_mode ?
new DeepLearningTask2(train, model.model_info(), rowFraction(train, mp, model)).invoke(Key.make()).model_info() : //replicated data + single node mode
Expand Down

0 comments on commit 77c1912

Please sign in to comment.