
[python-package] require scikit-learn>=0.24.2, make scikit-learn estimators compatible with scikit-learn>=1.6.0dev #6651

Merged

Changes from 1 commit (42 commits total)
1adb77b
__sklearn_tags__ replacing sklearn's BaseEstimator._more_tags_
vnherdeiro Sep 11, 2024
8ed87d2
fixing tags dict -> dataclass
vnherdeiro Sep 11, 2024
32ec431
fixing wrong import
vnherdeiro Sep 11, 2024
ade9798
remove type hint
vnherdeiro Sep 11, 2024
2085a12
remove type hint
vnherdeiro Sep 11, 2024
a9ec348
fix linting
vnherdeiro Sep 11, 2024
fcc4e12
triggering new CI (scikit-learn dev has changed)
vnherdeiro Sep 14, 2024
3b15646
bringing back _more_tags, adding convertsion from more_tags to sklear…
vnherdeiro Sep 15, 2024
34d9eb4
lint fix
vnherdeiro Sep 15, 2024
6d20ef8
Update python-package/lightgbm/sklearn.py
vnherdeiro Sep 16, 2024
d715311
adressing PR comments
vnherdeiro Sep 16, 2024
c4ec9a4
move comment
jameslamb Sep 16, 2024
b0a4703
updates
jameslamb Sep 21, 2024
7eb861a
remove uses of super()
jameslamb Sep 24, 2024
b137ba2
fix version constraint in lint job, add one more comment
jameslamb Sep 24, 2024
d1915c0
Update python-package/lightgbm/sklearn.py
jameslamb Sep 24, 2024
6cf2158
Merge branch 'master' into fix_sklearn_more_tags_deprecation
jameslamb Sep 26, 2024
b5663aa
Merge branch 'fix_sklearn_more_tags_deprecation' of github.com:vnherd…
jameslamb Sep 26, 2024
118efd9
use scikit-learn 1.6 nightlies again, move some code to compat.py, re…
jameslamb Oct 2, 2024
4fb82f3
optionally use validate_data(), starting in scikit-learn 1.6
jameslamb Oct 3, 2024
c42c53d
fix validate_data() for older versions, update tests
jameslamb Oct 4, 2024
58d77e7
Merge branch 'master' of github.com:microsoft/LightGBM into fix_sklea…
jameslamb Oct 4, 2024
33fb5b6
more changes
jameslamb Oct 4, 2024
6689faa
fix n_features_in setting
jameslamb Oct 5, 2024
9a05670
fix return type
jameslamb Oct 5, 2024
815433f
remove now-unnecessary _LGBMCheckXY()
jameslamb Oct 5, 2024
ffebe41
correct comment
jameslamb Oct 5, 2024
722474d
Merge branch 'master' of github.com:microsoft/LightGBM into fix_sklea…
jameslamb Oct 6, 2024
f2cb2fe
Apply suggestions from code review
jameslamb Oct 6, 2024
86b5ab3
move __version__ import to compat.py, test with all ML tasks
jameslamb Oct 6, 2024
125f4ea
just set the setters and deleters
jameslamb Oct 6, 2024
4233d70
set floor of scikit-learn>=0.24.2, fix ordering of n_features_in_ set…
jameslamb Oct 6, 2024
330df3f
fix conflicts
jameslamb Oct 6, 2024
e8e4cdb
Update python-package/lightgbm/sklearn.py
jameslamb Oct 6, 2024
0b0ea24
Merge branch 'master' into fix_sklearn_more_tags_deprecation
jameslamb Oct 6, 2024
f22e494
forgot to commit ... fix comment
jameslamb Oct 7, 2024
b124797
Merge branch 'master' of github.com:microsoft/LightGBM into fix_sklea…
jameslamb Oct 7, 2024
beab71c
Merge branch 'fix_sklearn_more_tags_deprecation' of github.com:vnherd…
jameslamb Oct 7, 2024
c6e6fad
Apply suggestions from code review
jameslamb Oct 8, 2024
e3eabac
Merge branch 'master' into fix_sklearn_more_tags_deprecation
jameslamb Oct 9, 2024
d8762e5
predict_proba() shape is (num_data, num_classes) for multi-class
jameslamb Oct 9, 2024
8ef1deb
pass ensure_min_samples=1 at predict() time too
jameslamb Oct 9, 2024
remove uses of super()
jameslamb committed Sep 24, 2024
commit 7eb861addebd6d446be511c35a00b85158d11fa9
1 change: 1 addition & 0 deletions .ci/test.sh
@@ -103,6 +103,7 @@ if [[ $TASK == "lint" ]]; then
'mypy>=1.11.1' \
'pre-commit>=3.8.0' \
'pyarrow-core>=17.0' \
'scikit-learn>=1.15.0' \
Collaborator:

This is to ensure that mypy checks scikit-learn imports. Extra important now that I'm proposing adding an optional type hint on this new sklearn.utils.Tags.

'r-lintr>=3.1.2'
source activate $CONDA_ENV
echo "Linting Python code"
77 changes: 47 additions & 30 deletions python-package/lightgbm/sklearn.py
@@ -4,7 +4,7 @@
import copy
from inspect import signature
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import scipy.sparse
@@ -46,6 +46,12 @@
)
from .engine import train

if TYPE_CHECKING:
try:
from sklearn.utils import Tags as _sklearn_Tags
except ImportError:
_sklearn_Tags = None

__all__ = [
"LGBMClassifier",
"LGBMModel",
@@ -673,41 +679,45 @@ def _more_tags(self) -> Dict[str, Any]:
"_xfail_checks": {
"check_no_attributes_set_in_init": "scikit-learn incorrectly asserts that private attributes "
"cannot be set in __init__: "
"(see https://github.com/microsoft/LightGBM/issues/2628)"
"(see https://github.com/microsoft/LightGBM/issues/2628)",
"check_n_features_in_after_fitting": (
"validate_data() was first added in scikit-learn 1.6 and lightgbm"
"supports much older versions than that"
Collaborator:

I think we should add in this comment that LightGBM supports predict_disable_shape_check=True and we won't call validate_data() even after minimum sklearn version bump.

Collaborator:

hmmm I don't agree with that suggestion.

If lightgbm's minimum floor on scikit-learn was >=1.6, I think we could safely consider calling validate_data(), maybe like:

if not params.get("predict_disable_shape_check", False):
    self.validate_data(X)

LightGBM offering the ability to opt out of this validation wouldn't change that fact that the check_n_features_in_after_fitting test from scikit-learn would pass by default.

Collaborator:

Didn't think about such conditional approach. I like it!

Then, I think we can adopt validate_data() right now like we do with _check_sample_weight():

try:
    from sklearn.utils.validation import _check_sample_weight
except ImportError:
    from sklearn.utils.validation import check_consistent_length

    # dummy function to support older version of scikit-learn
    def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
        check_consistent_length(sample_weight, X)
        return sample_weight

We can try something like the following:

try:
    from sklearn.utils.validation import validate_data
except ImportError:
    # dummy function to support older version of scikit-learn
    def validate_data(_estimator, /, X='no_validation', y='no_validation', reset=True, validate_separately=False, skip_check_array=False, **check_params):
        check_array(X, y)
        check_X_y(X, y)
        _estimator._n_features_in = _estimator._n_features
        return X, y

Collaborator:

oh yeah, good idea!! I've started on this, still a bit of local testing to do... will try to push something here tomorrow.

Collaborator:

While testing this, I realized that LightGBM's scikit-learn interface already does not support the predict_disable_shape_check mechanism.

See this in LGBMModel.predict().

n_features = X.shape[1]
if self._n_features != n_features:
    raise ValueError(
        "Number of features of the model must "
        f"match the input. Model n_features_ is {self._n_features} and "
        f"input n_features is {n_features}"
    )

I think that behavior should be kept... if scikit-learn's API does not support passing inputs to predict() with a different number of features than were present when the model was fit(), then neither should LightGBM's scikit-learn estimators.

Collaborator:

I think that behavior should be kept... if scikit-learn's API does not support passing inputs to predict() with a different number of features than were present when the model was fit(), then neither should LightGBM's scikit-learn estimators.

Agree!

Collaborator:

In recent commits, I've done the following:

  • moved that number-of-features validation code out of fit() and into the validate_data() now defined in compat.py
  • reworded the error message so it's identical to the one sklearn.utils.validation.validate_data() raises
  • added a unit tests confirming that that error is raised, with that exact message, to be sure we'll find out if scikit-learn removes that check in the internals of validate_data() in the future

),
Collaborator:

On the 1.6.dev nightlies, scikit-learn is raising this error:

E AssertionError: LGBMRegressor.predict() does not check for consistency between input number
E of features with LGBMRegressor.fit(), via the n_features_in_ attribute.
E You might want to use sklearn.utils.validation.validate_data instead
E of check_array in LGBMRegressor.fit() and LGBMRegressor.predict()`. This can be done
E like the following:
E from sklearn.utils.validation import validate_data

We should ignore this check here in LightGBM... validate_data() will be added for the first time in scikit-learn 1.6.

We have other mechanisms further down in LightGBM to check shape mismatches between training data and the data provided at scoring time. I'd rather rely on those than take on the complexity of try-catching a call to this new-in-v1.6 validate_data() function.

Comment:

Understand you don't want to use validate_data here, but you can still conform to the API with your own tools.

You probably also want to make sure you store n_feature_in_ as well, to better imitate sklearn's behavior.

I would personally go down the fixes.py path though.

Collaborator:

Understand you don't want to use validate_data here, but you can still conform to the API with your own tools.

How could we avoid the check_n_features_in_after_fitting check failing without calling validate_data()? Could you point to a doc I could reference?

You probably also want to make sure you store n_feature_in_ as well, to better imitate sklearn's behavior.

We do.

@property
def n_features_in_(self) -> int:
    """:obj:`int`: The number of features of fitted model."""
    if not self.__sklearn_is_fitted__():
        raise LGBMNotFittedError("No n_features_in found. Need to call fit beforehand.")
    return self._n_features_in

},
}

@staticmethod
def _update_sklearn_tags_from_dict(
*,
tags: "sklearn.utils.Tags",
tags_dict: Dict[str, Any]
) -> "sklearn.utils.Tags":
"""
scikit-learn 1.6 introduced a dataclass-based interface for estimator tags.
tags: "_sklearn_Tags",
tags_dict: Dict[str, Any],
) -> "_sklearn_Tags":
"""Update ``sklearn.utils.Tags`` inherited from ``scikit-learn`` base classes.

``scikit-learn`` 1.6 introduced a dataclass-based interface for estimator tags.
ref: https://github.com/scikit-learn/scikit-learn/pull/29677

That interface means that each
This method handles updating that instance based on the value in ``self._more_tags()``.
"""
tags.input_tags.allow_nan = more_tags["allow_nan"]
tags.input_tags.sparse = "sparse" in more_tags["X_types"]
tags.target_tags.one_d_labels = "1dlabels" in more_tags["X_types"]
tags._xfail_checks = more_tags["_xfail_checks"]
tags.input_tags.allow_nan = tags_dict["allow_nan"]
tags.input_tags.sparse = "sparse" in tags_dict["X_types"]
tags.target_tags.one_d_labels = "1dlabels" in tags_dict["X_types"]
tags._xfail_checks = tags_dict["_xfail_checks"]
return tags

def __sklearn_tags__(self):
# super().__sklearn_tags__() cannot be called unconditionally,
def __sklearn_tags__(self) -> Optional["_sklearn_Tags"]:
# _LGBMModelBase.__sklearn_tags__() cannot be called unconditionally,
# because that method isn't defined for scikit-learn<1.6
if not callable(getattr(super(), "__sklearn_tags__", None)):
if not callable(getattr(_LGBMModelBase, "__sklearn_tags__", None)):
return None

# take whatever tags are provided by BaseEstimator, then modify
# them with LightGBM-specific values
tags = self._update_sklearn_tags_from_dict(
tags=super().__sklearn_tags__(),
tags_dict=self._more_tags()
return self._update_sklearn_tags_from_dict(
tags=_LGBMModelBase.__sklearn_tags__(self),
tags_dict=self._more_tags(),
)
return tags

def __sklearn_is_fitted__(self) -> bool:
return getattr(self, "fitted_", False)
@@ -1206,15 +1216,17 @@ class LGBMRegressor(_LGBMRegressorBase, LGBMModel):
"""LightGBM regressor."""

def _more_tags(self) -> Dict[str, Any]:
tags = super(LGBMModel, self)._more_tags()
tags.update(super(_LGBMRegressorBase, self)._more_tags())
# handle the case where ClassifierMixin possibly provides _more_tags()
if callable(getattr(_LGBMClassifierBase, "_more_tags", None)):
tags = _LGBMClassifierBase._more_tags(self)
Collaborator:

Proposing all these uses of {some_class}.{some_method} instead of super().{some_method} because we follow this advice from scikit-learn's docs (https://scikit-learn.org/stable/developers/develop.html#rolling-your-own-estimator):

...mixins should be “on the left” while the BaseEstimator should be “on the right” in the inheritance list for proper MRO.

Using super() would get the _more_tags() / __sklearn_tags__() from e.g. sklearn.base.RegressorMixin, but we want to use LightGBM's specific tags.
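The MRO point can be illustrated with a toy hierarchy (all class names here are hypothetical; `Mixin` plays the role of a scikit-learn mixin and `Base` the role of LGBMModel):

```python
# With the mixin "on the left" in the bases list, super() inside the
# estimator resolves to the mixin's method before the base class's,
# which is why an explicit Class.method(self) call is used instead.
class Base:
    def _more_tags(self):
        return {"source": "Base"}

class Mixin(Base):
    def _more_tags(self):
        return {"source": "Mixin"}

class Model(Mixin, Base):
    def tags_via_super(self):
        return super()._more_tags()   # MRO picks Mixin first

    def tags_via_explicit(self):
        return Base._more_tags(self)  # pins the Base implementation

m = Model()
print(m.tags_via_super()["source"])     # Mixin
print(m.tags_via_explicit()["source"])  # Base
```

Here `Model.__mro__` is `Model -> Mixin -> Base -> object`, so `super()` would return the mixin's tags; the explicit call is how the PR keeps LightGBM-specific values in control.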

else:
tags = {}
# override those with LightGBM-specific preferences
tags.update(LGBMModel._more_tags(self))
return tags

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
if tags is None:
return None

def __sklearn_tags__(self) -> Optional["_sklearn_Tags"]:
return LGBMModel.__sklearn_tags__(self)
Collaborator:

We need __sklearn_tags__() in LGBMRegressor and LGBMClassifier due to MRO again, right?

Collaborator:

Yeah exactly. Because the scikit-learn mixin classes come first (#6651 (comment)), if/when they define a __sklearn_tags__(), it would take precedence over the one from LGBMModel.

Comment:

__sklearn_tags__ doesn't have any corresponding manual MRO walk the same way that _more_tags did. You should treat it like any other normal overriding of a method in python's OOP model.

Collaborator:

I don't understand what you mean by this comment. Would appreciate your thoughts on #6651 (comment) if you have time.


def fit( # type: ignore[override]
self,
@@ -1263,12 +1275,17 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
"""LightGBM classifier."""

def _more_tags(self) -> Dict[str, Any]:
tags = super(LGBMModel, self)._more_tags()
tags.update(super(_LGBMClassifierBase, self)._more_tags())
# handle the case where ClassifierMixin possibly provides _more_tags()
if callable(getattr(_LGBMClassifierBase, "_more_tags", None)):
tags = _LGBMClassifierBase._more_tags(self)
else:
tags = {}
# override those with LightGBM-specific preferences
tags.update(LGBMModel._more_tags(self))
return tags

def __sklearn_tags__(self):
return super().__
def __sklearn_tags__(self) -> Optional["_sklearn_Tags"]:
return LGBMModel.__sklearn_tags__(self)

def fit( # type: ignore[override]
self,
13 changes: 9 additions & 4 deletions tests/python_package_test/test_sklearn.py
@@ -1441,11 +1441,16 @@ def test_sklearn_integration(estimator, check):
def test_sklearn_tags_should_correctly_reflect_lightgbm_specific_values(estimator_class):
est = estimator_class()
more_tags = est._more_tags()
assert (
more_tags["X_types"] == ["2darray", "sparse", "1dlabels"],
"List of supported X_types has changed. Update LGBMModel.__sklearn__tags() to match.",
)
err_msg = "List of supported X_types has changed. Update LGBMModel.__sklearn__tags() to match."
assert more_tags["X_types"] == ["2darray", "sparse", "1dlabels"], err_msg
sklearn_tags = est.__sklearn_tags__()
# these tests should be run unconditionally (no 'if') once lightgbm's
# minimum scikit-learn version is 1.6 or higher
if sklearn_tags is not None:
assert sklearn_tags.input_tags.allow_nan is True
assert sklearn_tags.input_tags.sparse is True
assert sklearn_tags.target_tags.one_d_labels is True
assert sklearn_tags._xfail_checks == more_tags["_xfail_checks"]


@pytest.mark.parametrize("task", ["binary-classification", "multiclass-classification", "ranking", "regression"])