WIP: [python-package] use sklearn_compat for multi-version scikit-learn compatibilities #6740

Draft · wants to merge 10 commits into base: master
Changes from 1 commit
MAINT use sklearn_compat for multi-version scikit-learn compatibilities
glemaitre committed Dec 8, 2024
commit fed9115e18c17d76e1aec35ed0fbc99aec272611
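
The commit swaps lightgbm's hand-rolled try/except import shims for imports from a vendored copy of the sklearn_compat package, which does the scikit-learn version dispatch in one place. A rough sketch of that dispatch pattern, for orientation only (the vendored module's real implementation is not part of this diff, and the packaging dependency below is used only for the sketch):

# Illustrative version-dispatch pattern behind a compatibility layer such as
# sklearn_compat; NOT the vendored module's actual source.
import sklearn
from packaging.version import parse as parse_version

if parse_version(sklearn.__version__) >= parse_version("1.6"):
    # new enough scikit-learn: re-export the real helper
    from sklearn.utils.validation import validate_data
else:
    # older scikit-learn: provide a minimal stand-in, in the same spirit as
    # the shim this commit deletes from lightgbm/compat.py
    from sklearn.utils.validation import check_X_y

    def validate_data(_estimator, X, y, accept_sparse=True):
        X, y = check_X_y(X, y, accept_sparse=accept_sparse)
        _estimator.n_features_in_ = X.shape[1]
        return X, y

With that dispatch living in the vendored package, compat.py only needs the single imports shown in the diff below.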
98 changes: 6 additions & 92 deletions python-package/lightgbm/compat.py
@@ -12,91 +12,11 @@
    from sklearn.preprocessing import LabelEncoder
    from sklearn.utils.class_weight import compute_sample_weight
    from sklearn.utils.multiclass import check_classification_targets
-    from sklearn.utils.validation import assert_all_finite, check_array, check_X_y
-
-    try:
-        from sklearn.exceptions import NotFittedError
-        from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
-    except ImportError:
-        from sklearn.cross_validation import BaseCrossValidator, GroupKFold, StratifiedKFold
-        from sklearn.utils.validation import NotFittedError
-    try:
-        from sklearn.utils.validation import _check_sample_weight
-    except ImportError:
-        from sklearn.utils.validation import check_consistent_length
-
-        # dummy function to support older version of scikit-learn
-        def _check_sample_weight(sample_weight: Any, X: Any, dtype: Any = None) -> Any:
-            check_consistent_length(sample_weight, X)
-            return sample_weight
-
-    try:
-        from sklearn.utils.validation import validate_data
-    except ImportError:
-        # validate_data() was added in scikit-learn 1.6, this function roughly imitates it for older versions.
-        # It can be removed when lightgbm's minimum scikit-learn version is at least 1.6.
-        def validate_data(
-            _estimator,
-            X,
-            y="no_validation",
-            accept_sparse: bool = True,
-            # 'force_all_finite' was renamed to 'ensure_all_finite' in scikit-learn 1.6
-            ensure_all_finite: bool = False,
-            ensure_min_samples: int = 1,
-            # trap other keyword arguments that only work on scikit-learn >=1.6, like 'reset'
-            **ignored_kwargs,
-        ):
-            # it's safe to import _num_features unconditionally because:
-            #
-            # * it was first added in scikit-learn 0.24.2
-            # * lightgbm cannot be used with scikit-learn versions older than that
-            # * this validate_data() re-implementation will not be called in scikit-learn>=1.6
-            #
-            from sklearn.utils.validation import _num_features
-
-            # _num_features() raises a TypeError on 1-dimensional input. That's a problem
-            # because scikit-learn's 'check_fit1d' estimator check sets that expectation that
-            # estimators must raise a ValueError when a 1-dimensional input is passed to fit().
-            #
-            # So here, lightgbm avoids calling _num_features() on 1-dimensional inputs.
-            if hasattr(X, "shape") and len(X.shape) == 1:
-                n_features_in_ = 1
-            else:
-                n_features_in_ = _num_features(X)
-
-            no_val_y = isinstance(y, str) and y == "no_validation"
-
-            # NOTE: check_X_y() calls check_array() internally, so only need to call one or the other of them here
-            if no_val_y:
-                X = check_array(
-                    X,
-                    accept_sparse=accept_sparse,
-                    force_all_finite=ensure_all_finite,
-                    ensure_min_samples=ensure_min_samples,
-                )
-            else:
-                X, y = check_X_y(
-                    X,
-                    y,
-                    accept_sparse=accept_sparse,
-                    force_all_finite=ensure_all_finite,
-                    ensure_min_samples=ensure_min_samples,
-                )
-
-            # this only needs to be updated at fit() time
-            _estimator.n_features_in_ = n_features_in_
-
-            # raise the same error that scikit-learn's `validate_data()` does on scikit-learn>=1.6
-            if _estimator.__sklearn_is_fitted__() and _estimator._n_features != n_features_in_:
-                raise ValueError(
-                    f"X has {n_features_in_} features, but {_estimator.__class__.__name__} "
-                    f"is expecting {_estimator._n_features} features as input."
-                )
-
-            if no_val_y:
-                return X
-            else:
-                return X, y
+    from sklearn.utils.validation import assert_all_finite
+    from sklearn.exceptions import NotFittedError
+    from sklearn.model_selection import BaseCrossValidator, GroupKFold, StratifiedKFold
+    from sklearn.utils.validation import _check_sample_weight
+    from .sklearn_compat.utils.validation import validate_data

    SKLEARN_INSTALLED = True
    _LGBMBaseCrossValidator = BaseCrossValidator
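
The deleted shim above reproduced two behaviours of scikit-learn 1.6's validate_data(): recording n_features_in_ at fit() time and raising a ValueError on a feature-count mismatch for an already-fitted estimator. A minimal sketch of how the two helpers imported in this hunk are used at fit() time; the _DemoEstimator class is hypothetical (lightgbm's own estimators route the call through compat.py) and assumes scikit-learn >= 1.6:

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.utils.validation import _check_sample_weight, validate_data


class _DemoEstimator(BaseEstimator):
    def fit(self, X, y, sample_weight=None):
        # validates X and y and records n_features_in_ on the estimator;
        # passing reset=False instead would enforce a matching feature count
        X, y = validate_data(self, X, y, accept_sparse=True)
        # returns an all-ones weight vector when sample_weight is None
        sample_weight = _check_sample_weight(sample_weight, X)
        return self


est = _DemoEstimator().fit(np.array([[1.0, 2.0], [3.0, 4.0]]), np.array([0, 1]))
print(est.n_features_in_)  # 2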
@@ -144,13 +64,7 @@ class _LGBMRegressorBase: # type: ignore

# additional scikit-learn imports only for type hints
if TYPE_CHECKING:
-    # sklearn.utils.Tags can be imported unconditionally once
-    # lightgbm's minimum scikit-learn version is 1.6 or higher
-    try:
-        from sklearn.utils import Tags as _sklearn_Tags
-    except ImportError:
-        _sklearn_Tags = None
-
+    from .sklearn_compat.utils import Tags as _sklearn_Tags

"""pandas"""
try:
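
In the second hunk the Tags import stays under TYPE_CHECKING, so it is only resolved by static type checkers and never at runtime. A hedged sketch of the kind of annotation this enables; the _DemoClassifier class and the allow_nan override are illustrative (assuming scikit-learn >= 1.6), not lightgbm's actual tag handling:

from __future__ import annotations

from typing import TYPE_CHECKING

from sklearn.base import BaseEstimator, ClassifierMixin

if TYPE_CHECKING:
    # resolved by mypy/pyright only; no runtime import of sklearn.utils.Tags
    from sklearn.utils import Tags


class _DemoClassifier(ClassifierMixin, BaseEstimator):
    def __sklearn_tags__(self) -> "Tags":
        tags = super().__sklearn_tags__()
        tags.input_tags.allow_nan = True  # e.g. the estimator handles missing values itself
        return tags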