Skip to content

Commit

Permalink
Add warning and list remedies if the input layer is bigger than 100k …
Browse files Browse the repository at this point in the history
…neurons.
  • Loading branch information
arnocandel committed Jan 30, 2015
1 parent ab04e26 commit ae78448
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions src/main/java/hex/deeplearning/DeepLearningModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
import water.fvec.Vec;
import water.util.*;

import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Random;

/**
Expand Down Expand Up @@ -326,6 +328,34 @@ public DeepLearningModelInfo(final Job job, final DataInfo dinfo) {
units[0] = num_input;
System.arraycopy(get_params().hidden, 0, units, 1, layers);
units[layers+1] = num_output;

// Warn when the input layer has more than 100k neurons, name the categorical
// columns with the highest cardinality, and list possible remedies.
if ((long)units[0] > 100000L) {
  final String[][] domains = dinfo._adaptedFrame.domains();
  // Per-column cardinality; a column without a domain counts as 0 levels.
  final int[] levels = new int[domains.length];
  for (int i=0; i<levels.length; ++i) {
    levels[i] = domains[i] != null ? domains[i].length : 0;
  }
  Arrays.sort(levels); // ascending: the largest cardinality ends up last
  Log.warn("===================================================================================================================================");
  Log.warn(num_input + " input features" + (dinfo._cats > 0 ? " (after categorical one-hot encoding)" : "") + ". Can be slow and require a lot of memory.");
  // Only list columns if there is at least one column with a non-empty domain.
  if (levels.length > 0 && levels[levels.length-1] > 0) {
    // Cardinality threshold for being listed. The index is clamped to 0 so that
    // frames with 10 or fewer columns no longer index levels[-1].
    final int levelcutoff = levels[Math.max(0, levels.length-1-10)];
    // The last column is skipped unless running as an autoencoder
    // (presumably the response column -- TODO confirm against DataInfo).
    final int ncols = dinfo._adaptedFrame.numCols() - (get_params().autoencoder ? 0 : 1);
    int count = 0; // number of columns reported so far, capped at 10
    for (int i=0; i<ncols && count < 10; ++i) {
      if (domains[i] != null && domains[i].length >= levelcutoff) {
        Log.warn("Categorical feature '" + dinfo._adaptedFrame._names[i] + "' has cardinality " + domains[i].length + ".");
        ++count; // without this the "count < 10" cap never takes effect
      }
    }
  }
  Log.warn("Suggestions:");
  Log.warn(" *) Limit the size of the first hidden layer");
  if (dinfo._cats > 0) {
    Log.warn(" *) Limit the total number of one-hot encoded features with the parameter 'max_categorical_features'");
    Log.warn(" *) Run h2o.interaction(...,pairwise=F) on high-cardinality categorical columns to limit the factor count, see http://learn.h2o.ai");
  }
  Log.warn("===================================================================================================================================");
}

// weights (to connect layers)
dense_row_weights = new Neurons.DenseRowMatrix[layers+1];
dense_col_weights = new Neurons.DenseColMatrix[layers+1];
Expand Down

0 comments on commit ae78448

Please sign in to comment.