Skip to content

Commit

Permalink
fix extratrees splitting with num.random.splits>1
Browse files Browse the repository at this point in the history
  • Loading branch information
mnwright committed Aug 22, 2019
1 parent d15cf25 commit ce29a66
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/TreeClassification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,9 @@ void TreeClassification::findBestSplitValueExtraTrees(size_t nodeID, size_t varI
for (size_t i = 0; i < num_random_splits; ++i) {
possible_split_values.push_back(udist(random_number_generator));
}
if (num_random_splits > 1) {
std::sort(possible_split_values.begin(), possible_split_values.end());
}

const size_t num_splits = possible_split_values.size();
if (memory_saving_splitting) {
Expand Down
3 changes: 3 additions & 0 deletions src/TreeProbability.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,9 @@ void TreeProbability::findBestSplitValueExtraTrees(size_t nodeID, size_t varID,
for (size_t i = 0; i < num_random_splits; ++i) {
possible_split_values.push_back(udist(random_number_generator));
}
if (num_random_splits > 1) {
std::sort(possible_split_values.begin(), possible_split_values.end());
}

const size_t num_splits = possible_split_values.size();
if (memory_saving_splitting) {
Expand Down
3 changes: 3 additions & 0 deletions src/TreeRegression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,9 @@ void TreeRegression::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, d
for (size_t i = 0; i < num_random_splits; ++i) {
possible_split_values.push_back(udist(random_number_generator));
}
if (num_random_splits > 1) {
std::sort(possible_split_values.begin(), possible_split_values.end());
}

const size_t num_splits = possible_split_values.size();
if (memory_saving_splitting) {
Expand Down
3 changes: 3 additions & 0 deletions src/TreeSurvival.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,9 @@ void TreeSurvival::findBestSplitValueExtraTrees(size_t nodeID, size_t varID, dou
for (size_t i = 0; i < num_random_splits; ++i) {
possible_split_values.push_back(udist(random_number_generator));
}
if (num_random_splits > 1) {
std::sort(possible_split_values.begin(), possible_split_values.end());
}

size_t num_splits = possible_split_values.size();

Expand Down

0 comments on commit ce29a66

Please sign in to comment.