From 77c191240e9f875c3060f481217a89af2503647e Mon Sep 17 00:00:00 2001 From: Arno Candel Date: Thu, 11 Sep 2014 23:48:50 -0700 Subject: [PATCH] Add support for desired class distribution when balance_classes is enabled. --- src/main/java/hex/deeplearning/DeepLearning.java | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/main/java/hex/deeplearning/DeepLearning.java b/src/main/java/hex/deeplearning/DeepLearning.java index 522c3258c3..2db4780849 100644 --- a/src/main/java/hex/deeplearning/DeepLearning.java +++ b/src/main/java/hex/deeplearning/DeepLearning.java @@ -1,5 +1,6 @@ package hex.deeplearning; +import com.amazonaws.services.cloudfront.model.InvalidArgumentException; import hex.*; import water.*; import water.util.*; @@ -384,6 +385,12 @@ public class DeepLearning extends Job.ValidatedJob { @API(help = "Balance training data class counts via over/under-sampling (for imbalanced data)", filter = Default.class, json = true, importance = ParamImportance.EXPERT) public boolean balance_classes = false; + /** + * Desired relative class ratios of the training data after over/under-sampling. Only when balance_classes is enabled. Default is 1/N for each class, where N is the number of classes. + */ + @API(help = "Desired relative class ratios of the training data after over/under-sampling.", filter = Default.class, dmin = 1, json = true, importance = ParamImportance.SECONDARY) + public float[] balance_class_model_distribution; + /** * When classes are balanced, limit the resulting dataset size to the * specified multiple of the original dataset size. @@ -1004,9 +1011,13 @@ public final DeepLearningModel trainModel(DeepLearningModel model) { if (!quiet_mode) Log.info("Number of model parameters (weights/biases): " + String.format("%,d", model_size)); train = model.model_info().data_info()._adaptedFrame; if (mp.force_load_balance) train = updateFrame(train, reBalance(train, mp.replicate_training_data /*rebalance into only 4*cores per node*/)); - float[] trainSamplingFactors; if (mp.classification && mp.balance_classes) { - trainSamplingFactors = new float[train.lastVec().domain().length]; //leave initialized to 0 -> will be filled up below + float[] trainSamplingFactors = new float[train.lastVec().domain().length]; //leave initialized to 0 -> will be filled up below + if (balance_class_model_distribution != null) { + if (balance_class_model_distribution.length != train.lastVec().domain().length) + throw new IllegalArgumentException("balance_class_model_distribution must have " + train.lastVec().domain().length + " elements"); + trainSamplingFactors = balance_class_model_distribution; + } train = updateFrame(train, sampleFrameStratified( train, train.lastVec(), trainSamplingFactors, (long)(mp.max_after_balance_size*train.numRows()), mp.seed, true, false)); model.setModelClassDistribution(new MRUtils.ClassDist(train.lastVec()).doAll(train.lastVec()).rel_dist()); @@ -1049,6 +1060,7 @@ public final DeepLearningModel trainModel(DeepLearningModel model) { model.update(self()); Log.info("Starting to train the Deep Learning model."); + if (n_folds == 0 || xval_models == null) //main loop do model.set_model_info(H2O.CLOUD.size() > 1 && mp.replicate_training_data ? ( mp.single_node_mode ? new DeepLearningTask2(train, model.model_info(), rowFraction(train, mp, model)).invoke(Key.make()).model_info() : //replicated data + single node mode