diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp
index 8b1725a64992..304c712f0723 100644
--- a/src/treelearner/serial_tree_learner.cpp
+++ b/src/treelearner/serial_tree_learner.cpp
@@ -11,6 +11,7 @@

 #include
 #include
+#include <set>
 #include
 #include

@@ -322,10 +323,14 @@ bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int
 }

 void SerialTreeLearner::FindBestSplits(const Tree* tree) {
+  FindBestSplits(tree, nullptr);
+}
+
+void SerialTreeLearner::FindBestSplits(const Tree* tree, const std::set<int>* force_features) {
   std::vector is_feature_used(num_features_, 0);
 #pragma omp parallel for schedule(static, 256) if (num_features_ >= 512)
   for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
-    if (!col_sampler_.is_feature_used_bytree()[feature_index]) continue;
+    if (!col_sampler_.is_feature_used_bytree()[feature_index] && (force_features == nullptr || force_features->count(feature_index) == 0)) continue;
     if (parent_leaf_histogram_array_ != nullptr
         && !parent_leaf_histogram_array_[feature_index].is_splittable()) {
       smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
@@ -462,12 +467,14 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, int* left_leaf,
   bool left_smaller = true;
   std::unordered_map forceSplitMap;
   q.push(std::make_pair(left, *left_leaf));
+
+  // Histogram construction requires the parent (forced-split) features, so collect them up front.
+  const std::set<int> force_split_features = FindAllForceFeatures(*forced_split_json_);
   while (!q.empty()) {
-    // before processing next node from queue, store info for current left/right leaf
-    // store "best split" for left and right, even if they might be overwritten by forced split
     if (BeforeFindBestSplit(tree, *left_leaf, *right_leaf)) {
-      FindBestSplits(tree);
+      FindBestSplits(tree, &force_split_features);
     }
+
     // then, compute own splits
     SplitInfo left_split;
     SplitInfo right_split;
@@ -561,6 +568,37 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, int* left_leaf,
   return result_count;
 }

+// Collects the inner index of every feature referenced by the forced-split JSON tree.
+// The tree is walked breadth-first from force_split_leaf_setting through each node's
+// optional "left" and "right" children; each node's "feature" value is mapped with
+// train_data_->InnerFeatureIndex() before being stored in the returned set.
+std::set<int> SerialTreeLearner::FindAllForceFeatures(Json force_split_leaf_setting) {
+  std::set<int> force_features;
+  std::queue<Json> force_split_leafs;
+
+  force_split_leafs.push(force_split_leaf_setting);
+
+  while (!force_split_leafs.empty()) {
+    // Copy the node out before pop(), which destroys the queue's front element.
+    Json split_leaf = force_split_leafs.front();
+    force_split_leafs.pop();
+
+    const int feature_index = split_leaf["feature"].int_value();
+    const int feature_inner_index = train_data_->InnerFeatureIndex(feature_index);
+    force_features.insert(feature_inner_index);
+
+    if (split_leaf.object_items().count("left") > 0) {
+      force_split_leafs.push(split_leaf["left"]);
+    }
+
+    if (split_leaf.object_items().count("right") > 0) {
+      force_split_leafs.push(split_leaf["right"]);
+    }
+  }
+
+  return force_features;
+}
+
 void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
                                    int* right_leaf, bool update_cnt) {
   Common::FunctionTimer fun_timer("SerialTreeLearner::SplitInner", global_timer);
diff --git a/src/treelearner/serial_tree_learner.h b/src/treelearner/serial_tree_learner.h
index 0466cafa43b3..7dfadf05d119 100644
--- a/src/treelearner/serial_tree_learner.h
+++ b/src/treelearner/serial_tree_learner.h
@@ -19,6 +19,7 @@
 #include
 #include
 #include
+#include <set>

 #include "col_sampler.hpp"
 #include "data_partition.hpp"
@@ -142,6 +143,8 @@ class SerialTreeLearner: public TreeLearner {

   virtual void FindBestSplits(const Tree* tree);

+  virtual void
FindBestSplits(const Tree* tree, const std::set<int>* force_features);
+
   virtual void ConstructHistograms(const std::vector& is_feature_used, bool use_subtract);

   virtual void FindBestSplitsFromHistograms(const std::vector& is_feature_used, bool use_subtract, const Tree*);
@@ -165,6 +168,13 @@ class SerialTreeLearner: public TreeLearner {
   int32_t ForceSplits(Tree* tree, int* left_leaf, int* right_leaf,
                       int* cur_depth);

+  /*!
+  * \brief Collect the inner indices of all features referenced by forced splits
+  * \param force_split_leaf_setting Root node of the forced-split JSON description
+  * \return Set of inner feature indices found in the node and its "left"/"right" descendants
+  */
+  std::set<int> FindAllForceFeatures(Json force_split_leaf_setting);
+
   /*!
   * \brief Get the number of data in a leaf
   * \param leaf_idx The index of leaf
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index eaf7244fe150..547d662e123a 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -1,6 +1,7 @@
 # coding: utf-8
 import copy
 import itertools
+import json
 import math
 import pickle
 import platform
@@ -2887,3 +2888,42 @@ def hook(obj):
     dumped_model_str = str(bst.dump_model(5, 0, object_hook=hook))
     assert "leaf_value" not in dumped_model_str
     assert "LV" in dumped_model_str
+
+
+def test_force_split_with_feature_fraction(tmp_path):
+    """Train with forced splits and feature_fraction < 1 and check the forced root split."""
+    X, y = load_boston(return_X_y=True)
+    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
+    lgb_train = lgb.Dataset(X_train, y_train)
+
+    forced_split = {
+        "feature": 0,
+        "threshold": 0.5,
+        "right": {
+            "feature": 2,
+            "threshold": 10.0
+        }
+    }
+
+    tmp_split_file = tmp_path / "forced_split.json"
+    with open(tmp_split_file, "w") as f:
+        f.write(json.dumps(forced_split))
+
+    params = {
+        "objective": "regression",
+        "feature_fraction": 0.6,
+        "force_col_wise": True,
+        "feature_fraction_seed": 1,
+        "forcedsplits_filename": str(tmp_split_file)
+    }
+
+    gbm = lgb.train(params, lgb_train)
+    ret = mean_absolute_error(y_test, gbm.predict(X_test))
+    assert ret < 2.0
+
+    tree_info = gbm.dump_model()["tree_info"]
+    assert len(tree_info) > 1
+    # The root of every tree must split on the forced feature 0.
+    for tree in tree_info:
+        tree_structure = tree["tree_structure"]
+        assert tree_structure["split_feature"] == 0