-
Notifications
You must be signed in to change notification settings - Fork 3.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[python-package] require scikit-learn>=0.24.2
, make scikit-learn estimators compatible with scikit-learn>=1.6.0dev
#6651
Changes from 1 commit
1adb77b
8ed87d2
32ec431
ade9798
2085a12
a9ec348
fcc4e12
3b15646
34d9eb4
6d20ef8
d715311
c4ec9a4
b0a4703
7eb861a
b137ba2
d1915c0
6cf2158
b5663aa
118efd9
4fb82f3
c42c53d
58d77e7
33fb5b6
6689faa
9a05670
815433f
ffebe41
722474d
f2cb2fe
86b5ab3
125f4ea
4233d70
330df3f
e8e4cdb
0b0ea24
f22e494
b124797
beab71c
c6e6fad
e3eabac
d8762e5
8ef1deb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -4,7 +4,7 @@ | |||||||||||||||||||||||||||||||||
import copy | ||||||||||||||||||||||||||||||||||
from inspect import signature | ||||||||||||||||||||||||||||||||||
from pathlib import Path | ||||||||||||||||||||||||||||||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union | ||||||||||||||||||||||||||||||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
import numpy as np | ||||||||||||||||||||||||||||||||||
import scipy.sparse | ||||||||||||||||||||||||||||||||||
|
@@ -46,6 +46,12 @@ | |||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
from .engine import train | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
if TYPE_CHECKING: | ||||||||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||||||||
from sklearn.utils import Tags as _sklearn_Tags | ||||||||||||||||||||||||||||||||||
except ImportError: | ||||||||||||||||||||||||||||||||||
_sklearn_Tags = None | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
__all__ = [ | ||||||||||||||||||||||||||||||||||
"LGBMClassifier", | ||||||||||||||||||||||||||||||||||
"LGBMModel", | ||||||||||||||||||||||||||||||||||
|
@@ -673,41 +679,45 @@ def _more_tags(self) -> Dict[str, Any]: | |||||||||||||||||||||||||||||||||
"_xfail_checks": { | ||||||||||||||||||||||||||||||||||
"check_no_attributes_set_in_init": "scikit-learn incorrectly asserts that private attributes " | ||||||||||||||||||||||||||||||||||
"cannot be set in __init__: " | ||||||||||||||||||||||||||||||||||
"(see https://github.com/microsoft/LightGBM/issues/2628)" | ||||||||||||||||||||||||||||||||||
"(see https://github.com/microsoft/LightGBM/issues/2628)", | ||||||||||||||||||||||||||||||||||
"check_n_features_in_after_fitting": ( | ||||||||||||||||||||||||||||||||||
"validate_data() was first added in scikit-learn 1.6 and lightgbm" | ||||||||||||||||||||||||||||||||||
"supports much older versions than that" | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should add in this comment that LightGBM supports There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmmm I don't agree with that suggestion. If if not params.get("predict_disable_shape_check", True):
self.validate_data(X) LightGBM offering the ability to opt out of this validation wouldn't change that fact that the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Didn't think about such conditional approach. I like it! Then, I think we can adopt LightGBM/python-package/lightgbm/compat.py Lines 22 to 30 in 59a3432
We can try something like the following: try:
from sklearn.utils.validation import validate_data
except ImportError:
# dummy function to support older version of scikit-learn
def validate_data(_estimator, /, X='no_validation', y='no_validation', reset=True, validate_separately=False, skip_check_array=False, **check_params)
check_array(X, y)
check_X_y(X, y)
_estimator._n_features_in = _estimator._n_features
return X, y There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh yeah, good idea!! I've started on this, still a bit of local testing to do... will try to push something here tomorrow. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. While testing this, I realized that LightGBM's See this in LightGBM/python-package/lightgbm/sklearn.py Lines 1008 to 1014 in d1d218c
I think that behavior should be kept... if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Agree! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In recent commits, I've done the following:
|
||||||||||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On the
We should ignore this check here in LightGBM...
We have other mechanisms further down in LightGBM to check shape mismatches between training data and the data provided at scoring time. I'd rather rely on those than take on the complexity of try-catching a call to this new-in-v1.6 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Understand you don't want to use You probably also want to make sure you store I would personally go down the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
How could we avoid the
We do. LightGBM/python-package/lightgbm/sklearn.py Lines 1063 to 1068 in 41ba9e8
|
||||||||||||||||||||||||||||||||||
}, | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
@staticmethod | ||||||||||||||||||||||||||||||||||
def _update_sklearn_tags_from_dict( | ||||||||||||||||||||||||||||||||||
*, | ||||||||||||||||||||||||||||||||||
tags: "sklearn.utils.Tags", | ||||||||||||||||||||||||||||||||||
tags_dict: Dict[str, Any] | ||||||||||||||||||||||||||||||||||
) -> "sklearn.utils.Tags": | ||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||
scikit-learn 1.6 introduced a dataclass-based interface for estimator tags. | ||||||||||||||||||||||||||||||||||
tags: "_sklearn_Tags", | ||||||||||||||||||||||||||||||||||
tags_dict: Dict[str, Any], | ||||||||||||||||||||||||||||||||||
) -> "_sklearn_Tags": | ||||||||||||||||||||||||||||||||||
"""Update ``sklearn.utils.Tags`` inherited from ``scikit-learn`` base classes. | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
``scikit-learn`` 1.6 introduced a dataclass-based interface for estimator tags. | ||||||||||||||||||||||||||||||||||
ref: https://github.com/scikit-learn/scikit-learn/pull/29677 | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
That interface means that each | ||||||||||||||||||||||||||||||||||
This method handles updating that instance based on the value in ``self._more_tags()``. | ||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||
tags.input_tags.allow_nan = more_tags["allow_nan"] | ||||||||||||||||||||||||||||||||||
tags.input_tags.sparse = "sparse" in more_tags["X_types"] | ||||||||||||||||||||||||||||||||||
tags.target_tags.one_d_labels = "1dlabels" in more_tags["X_types"] | ||||||||||||||||||||||||||||||||||
tags._xfail_checks = more_tags["_xfail_checks"] | ||||||||||||||||||||||||||||||||||
tags.input_tags.allow_nan = tags_dict["allow_nan"] | ||||||||||||||||||||||||||||||||||
tags.input_tags.sparse = "sparse" in tags_dict["X_types"] | ||||||||||||||||||||||||||||||||||
tags.target_tags.one_d_labels = "1dlabels" in tags_dict["X_types"] | ||||||||||||||||||||||||||||||||||
tags._xfail_checks = tags_dict["_xfail_checks"] | ||||||||||||||||||||||||||||||||||
return tags | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
def __sklearn_tags__(self): | ||||||||||||||||||||||||||||||||||
# super().__sklearn_tags__() cannot be called unconditionally, | ||||||||||||||||||||||||||||||||||
def __sklearn_tags__(self) -> Optional["_sklearn_Tags"]: | ||||||||||||||||||||||||||||||||||
# _LGBMModelBase.__sklearn_tags__() cannot be called unconditionally, | ||||||||||||||||||||||||||||||||||
# because that method isn't defined for scikit-learn<1.6 | ||||||||||||||||||||||||||||||||||
if not callable(getattr(super(), "__sklearn_tags__", None)): | ||||||||||||||||||||||||||||||||||
if not callable(getattr(_LGBMModelBase, "__sklearn_tags__", None)): | ||||||||||||||||||||||||||||||||||
return None | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# take whatever tags are provided by BaseEstimator, then modify | ||||||||||||||||||||||||||||||||||
# them with LightGBM-specific values | ||||||||||||||||||||||||||||||||||
tags = self._update_sklearn_tags_from_dict( | ||||||||||||||||||||||||||||||||||
tags=super().__sklearn_tags__(), | ||||||||||||||||||||||||||||||||||
tags_dict=self._more_tags() | ||||||||||||||||||||||||||||||||||
return self._update_sklearn_tags_from_dict( | ||||||||||||||||||||||||||||||||||
tags=_LGBMModelBase.__sklearn_tags__(self), | ||||||||||||||||||||||||||||||||||
tags_dict=self._more_tags(), | ||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
return tags | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
def __sklearn_is_fitted__(self) -> bool: | ||||||||||||||||||||||||||||||||||
return getattr(self, "fitted_", False) | ||||||||||||||||||||||||||||||||||
|
@@ -1206,15 +1216,17 @@ class LGBMRegressor(_LGBMRegressorBase, LGBMModel): | |||||||||||||||||||||||||||||||||
"""LightGBM regressor.""" | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
def _more_tags(self) -> Dict[str, Any]: | ||||||||||||||||||||||||||||||||||
tags = super(LGBMModel, self)._more_tags() | ||||||||||||||||||||||||||||||||||
tags.update(super(_LGBMRegressorBase, self)._more_tags()) | ||||||||||||||||||||||||||||||||||
# handle the case where ClassifierMixin possibly provides _more_tags() | ||||||||||||||||||||||||||||||||||
if callable(getattr(_LGBMClassifierBase, "_more_tags", None)): | ||||||||||||||||||||||||||||||||||
tags = _LGBMClassifierBase._more_tags(self) | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Proposing all these uses of
Using
jameslamb marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||
tags = {} | ||||||||||||||||||||||||||||||||||
# override those with LightGBM-specific preferences | ||||||||||||||||||||||||||||||||||
tags.update(LGBMModel._more_tags(self)) | ||||||||||||||||||||||||||||||||||
return tags | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
def __sklearn_tags__(self): | ||||||||||||||||||||||||||||||||||
tags = super().__sklearn_tags__() | ||||||||||||||||||||||||||||||||||
if tags is None: | ||||||||||||||||||||||||||||||||||
return None | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
def __sklearn_tags__(self) -> Optional["_sklearn_Tags"]: | ||||||||||||||||||||||||||||||||||
return LGBMModel.__sklearn_tags__(self) | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah exactly. Because the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand what you mean by this comment. Would appreciate your thoughts on #6651 (comment) if you have time. |
||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
def fit( # type: ignore[override] | ||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||
|
@@ -1263,12 +1275,17 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel): | |||||||||||||||||||||||||||||||||
"""LightGBM classifier.""" | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
def _more_tags(self) -> Dict[str, Any]: | ||||||||||||||||||||||||||||||||||
tags = super(LGBMModel, self)._more_tags() | ||||||||||||||||||||||||||||||||||
tags.update(super(_LGBMClassifierBase, self)._more_tags()) | ||||||||||||||||||||||||||||||||||
# handle the case where ClassifierMixin possibly provides _more_tags() | ||||||||||||||||||||||||||||||||||
if callable(getattr(_LGBMClassifierBase, "_more_tags", None)): | ||||||||||||||||||||||||||||||||||
tags = _LGBMClassifierBase._more_tags(self) | ||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||
tags = {} | ||||||||||||||||||||||||||||||||||
# override those with LightGBM-specific preferences | ||||||||||||||||||||||||||||||||||
tags.update(LGBMModel._more_tags(self)) | ||||||||||||||||||||||||||||||||||
return tags | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
def __sklearn_tags__(self): | ||||||||||||||||||||||||||||||||||
return super().__ | ||||||||||||||||||||||||||||||||||
def __sklearn_tags__(self) -> Optional["_sklearn_Tags"]: | ||||||||||||||||||||||||||||||||||
return LGBMModel.__sklearn_tags__(self) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
def fit( # type: ignore[override] | ||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is to ensure that
mypy
checksscikit-learn
imports. Extra important now that I'm proposing adding an optional type hint on this newsklearn.utils.Tags
.