Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] require scikit-learn>=0.24.2, make scikit-learn estimators compatible with scikit-learn>=1.6.0dev #6651

Merged

Conversation

vnherdeiro
Copy link
Contributor

@vnherdeiro vnherdeiro commented Sep 11, 2024

Fixes #6653

(edit: taken over by @jameslamb, description re-written below)

  • raises minimum supported version to scikit-learn>=0.24.2
  • implements __sklearn_tags__() (replacement for _more_tags()) for scikit-learn estimators
  • starts using sklearn.utils.validation.validate_data() in fit() and predict()
  • adds tests confirming that scikit-learn estimators reject inputs with the wrong number of features

Notes for Reviewers

see https://scikit-learn.org/dev/whats_new/v1.6.html and scikit-learn/scikit-learn#29677

@vnherdeiro
Copy link
Contributor Author

Update:

The change introduced in scikit-learn/scikit-learn#29677 makes it hard to subclass a sklearn estimator in a codebase while being compatible with sklearn < 1.6.0 and sklearn >= 1.6.0. Essentially the former looks up ._more_tags() and ignore __sklearn_tags__() while the former looks up __sklearn_tags__() and forbids existence of a
._more_tags() tags method.

The issue is discussed here:
scikit-learn/scikit-learn#29801

and it looks like a relaxation of the impossibility of having both ._more_tags() and __sklearn_tags__() simulatenously will be relaxed. If it goes through let's park this MR until lightgbm decides to force a scikit-learn>=1.6.0 dependency.

@adrinjalali
Copy link

@vnherdeiro note that it's possible already to support both with this method (scikit-learn/scikit-learn#29677 (comment)), however, the version check and @available_if are going to be unnecessary once we merge scikit-learn/scikit-learn#29801

@vnherdeiro
Copy link
Contributor Author

vnherdeiro commented Sep 12, 2024 via email

@jameslamb
Copy link
Collaborator

jameslamb commented Sep 15, 2024

Thanks for starting on this @vnherdeiro . I've documented it in an issue: #6653 (and added that to the PR description).

Note there that I intentionally put the exact errors messages in plain text instead of just referring to _more_tags() ... that helps people to find this work from search engines.

Note also that the _more_tags() thing is only 1 of 3 breaking changes in scikit-learn that lightgbm will have to adjust to to get those tests passing again with scikit-learn==1.6.0.

Copy link
Collaborator

@jameslamb jameslamb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for starting on this! Please see scikit-learn/scikit-learn#29801 (comment):

The story becomes "If you want to support multiple scikit-learn versions, define both."

I think we should leave _more_tags() untouched and add __sklearn_tags__(). And have self.__sklearn_tags__() call self._more_tags() to get its data, so we don't define things like _xfail_checks twice.

Do you have time to do that in the next few days? We need to fix this to unblock CI here, so if you don't have time to fix it this week please let me know and I will work on this.

@jameslamb jameslamb changed the title __sklearn_tags__ replacing sklearn's BaseEstimator._more_tags_ [python-package] make scikit-learn tags compatible with scikit-learn>=1.16 Sep 15, 2024
@jameslamb jameslamb changed the title [python-package] make scikit-learn tags compatible with scikit-learn>=1.16 [python-package] make scikit-learn estimator tags compatible with scikit-learn>=1.16 Sep 15, 2024
@vnherdeiro
Copy link
Contributor Author

@jameslamb Have just pushe a sklearn_tags trying a conversion from _more_tags. I added a out of current argument scope warning to catch a change from the arguments in _more_tags (they don't seem to change much though).

@vnherdeiro vnherdeiro changed the title [python-package] make scikit-learn estimator tags compatible with scikit-learn>=1.16 [python-package] make scikit-learn estimator tags compatible with scikit-learn>=1.6.0dev Sep 15, 2024
Copy link

@adrinjalali adrinjalali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a maintainer here, but coming from sklearn side. Leaving thoughts hoping it'd help.

python-package/lightgbm/sklearn.py Outdated Show resolved Hide resolved
python-package/lightgbm/sklearn.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@jameslamb jameslamb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this.

I've reviewed the dataclasses at https://github.com/scikit-learn/scikit-learn/blob/e2ee93156bd3692722a39130c011eea313628690/sklearn/utils/_tags.py and agree with the choices you've made about how to map the dictionary-formatted values from _more_tags() to the dataclass attributes scikit-learn now prefers.

Please see the other comments about simplifying this.

python-package/lightgbm/sklearn.py Outdated Show resolved Hide resolved
python-package/lightgbm/sklearn.py Outdated Show resolved Hide resolved
@@ -1067,6 +1137,21 @@ def n_features_in_(self) -> int:
raise LGBMNotFittedError("No n_features_in found. Need to call fit beforehand.")
return self._n_features_in

@n_features_in_.setter
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you pass reset=True to sklearn.utils.validation.validate_data(), it will try to:

We want the "set estimator.n_features_in_" behavior, because without it we have to manually set estimator.n_features_in_ in fit().

Doing that requires determining the number of features in X, which requires either re-implementing something like sklearn.utils.validation._num_features() (as I originally tried to do) or just calling that function directly. But that function can't safely be called directly before calling check_array(), because it raises a TypeError on 1-D inputs, which violates the check_fit1d estimator check (code link).

So here, I'm proposing that we do the following:

  • add a setter for n_features_in_ and a deleter for feature_names_in_
  • pass reset=True at fit() time to validate_data()
  • modify the pre-1.6 implementation of validate_data() in compat.py to match

@jameslamb jameslamb changed the title [python-package] make scikit-learn estimator tags compatible with scikit-learn>=1.6.0dev [python-package] require scikit-learn>=0.24.2, make scikit-learn estimators compatible with scikit-learn>=1.6.0dev Oct 6, 2024
@jameslamb jameslamb requested a review from StrikerRUS October 6, 2024 06:49
@jameslamb
Copy link
Collaborator

@StrikerRUS your comments were definitely not "minor", they really helped a lot! I've re-thought a lot of this PR based on trying to implement those suggestions.

This is ready for another review. Thank you for all your reviewing effort here, I know this change has become quite complex and there are many competing constraints it's trying to satisfy.

Copy link
Collaborator

@StrikerRUS StrikerRUS left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hope this time my review comments will be really minor 😄


# NOTE: check_X_y() calls check_array() internally, so only need to call one or the other of them here
if no_val_y:
X = check_array(X, accept_sparse=accept_sparse, force_all_finite=ensure_all_finite)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adds ensure_min_samples=ensure_min_samples,

Suggested change
X = check_array(X, accept_sparse=accept_sparse, force_all_finite=ensure_all_finite)
X = check_array(
X,
accept_sparse=accept_sparse,
force_all_finite=ensure_all_finite,
ensure_min_samples=ensure_min_samples,
)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I intentionally omitted ensure_min_samples. It's already not being passed in the one place it's used on master:

X = _LGBMCheckArray(X, accept_sparse=True, force_all_finite=False)

This call to check_array() only happens in predict(), so I also think we should avoid any more validation than absolutely necessary to comply with the scikit-learn API, since applications calling predict() might care more about latency than those calling fit().

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. Fine with me. However, I think you should be aware that the default argument is 1, not None

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not realize that! I'm glad you mentioned it, just checked and it looks like omitting this argument from the call still results in that validation being performed.

    if ensure_min_samples > 0:
        n_samples = _num_samples(array)
        if n_samples < ensure_min_samples:
            raise ValueError(
                "Found array with %d sample(s) (shape=%s) while a"
                " minimum of %d is required%s."
                % (n_samples, array.shape, ensure_min_samples, context)
            )

https://github.com/scikit-learn/scikit-learn/blob/be52df50f1e9e9a6546248ccd7160a0a289f482c/sklearn/utils/validation.py#L1125-L1132

If that's the case, then my point about avoiding the overhead at predict() time doesn't matter... we're getting that overhead anyway. I guess the scikit-learn interface is probably not what you'd choose for low-latency predictions anyway... I will change this to pass through ensure_min_samples and set that to 1 in the call in sklearn.py, to make it explicit.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just made this change in 8ef1deb. Now ensure_min_samples=1 will be passed at predict() time.

Thanks for the suggestion and talking through it with me.

python-package/lightgbm/sklearn.py Outdated Show resolved Hide resolved
python-package/lightgbm/sklearn.py Outdated Show resolved Hide resolved
python-package/lightgbm/sklearn.py Show resolved Hide resolved
tests/python_package_test/test_sklearn.py Show resolved Hide resolved
tests/python_package_test/test_sklearn.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@StrikerRUS StrikerRUS left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

That was really challenging!

@vnherdeiro Thank you for starting the work!

@adrinjalali Thanks for your help here!

@jameslamb Thanks a ton for the huge work done here!

@jameslamb
Copy link
Collaborator

Thanks everyone for the help, and especially @StrikerRUS for a thorough review of a very complex change!

@jameslamb jameslamb merged commit 7eae66a into microsoft:master Oct 9, 2024
48 checks passed
@vnherdeiro
Copy link
Contributor Author

Thanks for all the work @jameslamb Feeling glad this went in!

@StrikerRUS
Copy link
Collaborator

@jameslamb I think it's more important to set breaking than fix label to this PR due to scikit-learn minimum version bump. WDYT?

@jameslamb jameslamb added breaking and removed fix labels Oct 20, 2024
@jameslamb
Copy link
Collaborator

Sure, that is ok with me. I just made that change, it should be reflected the next time release-drafter regenerates the release notes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[ci] [python-package] scikit-learn compatibility tests fail with scikit-learn 1.6.dev0
5 participants