Skip to content

Commit

Permalink
Always respect forced splits, even when feature_fraction < 1.0 (fixes microsoft#4601) (microsoft#4725)
Browse files Browse the repository at this point in the history

* issue fix microsoft#4601

* fix issue 4601 it2

* add tests for issue 4601

* fix warning

* fix warning

* add new line at end

* remove last line at end

* fix lint warning

* address comments

* address comments

* address comments

* fix address

* address comments

* revert seed

* fix recursive force split issue

* fix build error

* fix lint warning
  • Loading branch information
tongwu-sh authored Nov 10, 2021
1 parent b1facf5 commit 33a2f9e
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 4 deletions.
41 changes: 37 additions & 4 deletions src/treelearner/serial_tree_learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include <algorithm>
#include <queue>
#include <set>
#include <unordered_map>
#include <utility>

Expand Down Expand Up @@ -322,10 +323,14 @@ bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int
}

void SerialTreeLearner::FindBestSplits(const Tree* tree) {
  // Forward to the two-argument overload with no forced-feature set, so no
  // feature is exempted from the by-tree column sampling filter.
  const std::set<int>* no_forced_features = nullptr;
  FindBestSplits(tree, no_forced_features);
}

void SerialTreeLearner::FindBestSplits(const Tree* tree, const std::set<int>* force_features) {
std::vector<int8_t> is_feature_used(num_features_, 0);
#pragma omp parallel for schedule(static, 256) if (num_features_ >= 512)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
if (!col_sampler_.is_feature_used_bytree()[feature_index]) continue;
if (!col_sampler_.is_feature_used_bytree()[feature_index] && (force_features == nullptr || force_features->find(feature_index) == force_features->end())) continue;
if (parent_leaf_histogram_array_ != nullptr
&& !parent_leaf_histogram_array_[feature_index].is_splittable()) {
smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
Expand Down Expand Up @@ -462,12 +467,14 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, int* left_leaf,
bool left_smaller = true;
std::unordered_map<int, SplitInfo> forceSplitMap;
q.push(std::make_pair(left, *left_leaf));

// Histogram construction requires every forced-split feature, even those dropped by column sampling.
std::set<int> force_split_features = FindAllForceFeatures(*forced_split_json_);
while (!q.empty()) {
// before processing next node from queue, store info for current left/right leaf
// store "best split" for left and right, even if they might be overwritten by forced split
if (BeforeFindBestSplit(tree, *left_leaf, *right_leaf)) {
FindBestSplits(tree);
FindBestSplits(tree, &force_split_features);
}

// then, compute own splits
SplitInfo left_split;
SplitInfo right_split;
Expand Down Expand Up @@ -561,6 +568,32 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, int* left_leaf,
return result_count;
}

/*!
 * Collect the inner feature index of every node in a forced-split JSON tree.
 * The tree is walked breadth-first starting from force_split_leaf_setting;
 * each node contributes the inner index of its "feature" entry, and its
 * "left" / "right" children are visited when those keys are present.
 */
std::set<int> SerialTreeLearner::FindAllForceFeatures(Json force_split_leaf_setting) {
  std::set<int> inner_indices;
  std::queue<Json> pending;
  pending.push(force_split_leaf_setting);

  while (!pending.empty()) {
    Json node = pending.front();
    pending.pop();

    // The JSON carries the feature index as written by the user; translate it
    // through the training data to the inner index before recording it.
    inner_indices.insert(train_data_->InnerFeatureIndex(node["feature"].int_value()));

    // Queue a child only when the node actually declares it.
    auto enqueue_child = [&node, &pending](const char* key) {
      if (node.object_items().count(key) > 0) {
        pending.push(node[key]);
      }
    };
    enqueue_child("left");
    enqueue_child("right");
  }

  return inner_indices;
}

void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
int* right_leaf, bool update_cnt) {
Common::FunctionTimer fun_timer("SerialTreeLearner::SplitInner", global_timer);
Expand Down
5 changes: 5 additions & 0 deletions src/treelearner/serial_tree_learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <memory>
#include <random>
#include <vector>
#include <set>

#include "col_sampler.hpp"
#include "data_partition.hpp"
Expand Down Expand Up @@ -142,6 +143,8 @@ class SerialTreeLearner: public TreeLearner {

virtual void FindBestSplits(const Tree* tree);

/*!
* \brief Find best splits, additionally considering a set of forced features
* \param tree The tree being grown
* \param force_features Inner feature indices that are still considered when
*        the by-tree column sampler did not select them; nullptr means no
*        feature is exempted from that filter
*/
virtual void FindBestSplits(const Tree* tree, const std::set<int>* force_features);

virtual void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);

virtual void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree*);
Expand All @@ -165,6 +168,8 @@ class SerialTreeLearner: public TreeLearner {
int32_t ForceSplits(Tree* tree, int* left_leaf, int* right_leaf,
int* cur_depth);

/*!
* \brief Collect the features referenced anywhere in a forced-split JSON tree
* \param force_split_leaf_setting Root node of the forced-split description;
*        each node's "feature" entry is read and its "left" / "right"
*        children are visited when present
* \return Set of inner feature indices of all visited nodes
*/
std::set<int> FindAllForceFeatures(Json force_split_leaf_setting);

/*!
* \brief Get the number of data in a leaf
* \param leaf_idx The index of leaf
Expand Down
38 changes: 38 additions & 0 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# coding: utf-8
import copy
import itertools
import json
import math
import pickle
import platform
Expand Down Expand Up @@ -2887,3 +2888,40 @@ def hook(obj):
dumped_model_str = str(bst.dump_model(5, 0, object_hook=hook))
assert "leaf_value" not in dumped_model_str
assert "LV" in dumped_model_str


def test_force_split_with_feature_fraction(tmp_path):
    """Check that every tree's root uses the forced feature when feature_fraction < 1.0."""
    features, target = load_boston(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(
        features, target, test_size=0.1, random_state=42
    )
    train_set = lgb.Dataset(X_train, y_train)

    # Root is forced onto feature 0, and its right child onto feature 2.
    right_child = {"feature": 2, "threshold": 10.0}
    forced_split = {"feature": 0, "threshold": 0.5, "right": right_child}

    tmp_split_file = tmp_path / "forced_split.json"
    with open(tmp_split_file, "w") as f:
        json.dump(forced_split, f)

    params = {
        "objective": "regression",
        "feature_fraction": 0.6,
        "force_col_wise": True,
        "feature_fraction_seed": 1,
        "forcedsplits_filename": tmp_split_file,
    }

    gbm = lgb.train(params, train_set)
    mae = mean_absolute_error(y_test, gbm.predict(X_test))
    assert mae < 2.0

    tree_info = gbm.dump_model()["tree_info"]
    assert len(tree_info) > 1
    # The root split of every tree must be on the forced feature.
    for tree in tree_info:
        root = tree["tree_structure"]
        assert root["split_feature"] == 0

0 comments on commit 33a2f9e

Please sign in to comment.