Skip to content

Commit

Permalink
Merge pull request imbs-hl#426 from imbs-hl/fix_extratrees
Browse files (browse the repository at this point in the history)
Fix extratrees splitting with num.random.splits>1
  • Loading branch information
mnwright authored Aug 22, 2019
2 parents: d15cf25 + 46c2f4c; commit: 3c25198
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 3 deletions.
2 changes: 1 addition & 1 deletion cpp_version/src/utility/ArgumentHandler.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -20,7 +20,7 @@ R package "ranger" under GPL3 license.
namespace ranger {

ArgumentHandler::ArgumentHandler(int argc, char **argv) :
caseweights(""), depvarname(""), fraction(1), holdout(false), memmode(MEM_DOUBLE), savemem(false), skipoob(false), predict(
caseweights(""), depvarname(""), fraction(0), holdout(false), memmode(MEM_DOUBLE), savemem(false), skipoob(false), predict(
""), predictiontype(DEFAULT_PREDICTIONTYPE), randomsplits(DEFAULT_NUM_RANDOM_SPLITS), splitweights(""), nthreads(
DEFAULT_NUM_THREADS), predall(false), alpha(DEFAULT_ALPHA), minprop(DEFAULT_MINPROP), maxdepth(
DEFAULT_MAXDEPTH), file(""), impmeasure(DEFAULT_IMPORTANCE_MODE), targetpartitionsize(0), mtry(0), outprefix(
Expand Down
9 changes: 8 additions & 1 deletion src/Forest.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -80,7 +80,14 @@ void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode
prediction_mode = true;
}

// Sample fraction to vector
// Sample fraction default and convert to vector
if (sample_fraction == 0) {
if (sample_with_replacement) {
sample_fraction = DEFAULT_SAMPLE_FRACTION_REPLACE;
} else {
sample_fraction = DEFAULT_SAMPLE_FRACTION_NOREPLACE;
}
}
std::vector<double> sample_fraction_vector = { sample_fraction };

// Call other init function
Expand Down
3 changes: 3 additions & 0 deletions src/TreeClassification.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -492,6 +492,9 @@ void TreeClassification::findBestSplitValueExtraTrees(size_t nodeID, size_t varI
for (size_t i = 0; i < num_random_splits; ++i) {
possible_split_values.push_back(udist(random_number_generator));
}
if (num_random_splits > 1) {
std::sort(possible_split_values.begin(), possible_split_values.end());
}

const size_t num_splits = possible_split_values.size();
if (memory_saving_splitting) {
Expand Down
3 changes: 3 additions & 0 deletions src/TreeProbability.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -496,6 +496,9 @@ void TreeProbability::findBestSplitValueExtraTrees(size_t nodeID, size_t varID,
for (size_t i = 0; i < num_random_splits; ++i) {
possible_split_values.push_back(udist(random_number_generator));
}
if (num_random_splits > 1) {
std::sort(possible_split_values.begin(), possible_split_values.end());
}

const size_t num_splits = possible_split_values.size();
if (memory_saving_splitting) {
Expand Down
3 changes: 3 additions & 0 deletions src/TreeRegression.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -541,6 +541,9 @@ void TreeRegression::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, d
for (size_t i = 0; i < num_random_splits; ++i) {
possible_split_values.push_back(udist(random_number_generator));
}
if (num_random_splits > 1) {
std::sort(possible_split_values.begin(), possible_split_values.end());
}

const size_t num_splits = possible_split_values.size();
if (memory_saving_splitting) {
Expand Down
3 changes: 3 additions & 0 deletions src/TreeSurvival.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -708,6 +708,9 @@ void TreeSurvival::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, dou
for (size_t i = 0; i < num_random_splits; ++i) {
possible_split_values.push_back(udist(random_number_generator));
}
if (num_random_splits > 1) {
std::sort(possible_split_values.begin(), possible_split_values.end());
}

size_t num_splits = possible_split_values.size();

Expand Down
3 changes: 2 additions & 1 deletion src/globals.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -94,7 +94,8 @@ const uint DEFAULT_MAXDEPTH = 0;
const PredictionType DEFAULT_PREDICTIONTYPE = RESPONSE;
const uint DEFAULT_NUM_RANDOM_SPLITS = 1;

//const std::vector<double> DEFAULT_SAMPLE_FRACTION = std::vector<double>({1});
const double DEFAULT_SAMPLE_FRACTION_REPLACE = 1;
const double DEFAULT_SAMPLE_FRACTION_NOREPLACE = 0.632;

// Interval to print progress in seconds
const double STATUS_INTERVAL = 30.0;
Expand Down

0 comments on commit 3c25198

Please sign in to comment.