forked from scikit-learn/scikit-learn
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
MAINT Param validation for SpectralClustering (scikit-learn#23851)
Co-authored-by: jeremiedbb <[email protected]> Co-authored-by: Meekail Zain <[email protected]>
- Loading branch information
1 parent
cb8ec24
commit 46623ef
Showing
3 changed files
with
36 additions
and
134 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
# Andrew Knyazev <[email protected]> | ||
# License: BSD 3 clause | ||
|
||
import numbers | ||
from numbers import Integral, Real | ||
import warnings | ||
|
||
import numpy as np | ||
|
@@ -15,8 +15,9 @@ | |
from scipy.sparse import csc_matrix | ||
|
||
from ..base import BaseEstimator, ClusterMixin | ||
from ..utils import check_random_state, as_float_array, check_scalar | ||
from ..metrics.pairwise import pairwise_kernels | ||
from ..utils._param_validation import Interval, StrOptions | ||
from ..utils import check_random_state, as_float_array | ||
from ..metrics.pairwise import pairwise_kernels, KERNEL_PARAMS | ||
from ..neighbors import kneighbors_graph, NearestNeighbors | ||
from ..manifold import spectral_embedding | ||
from ._kmeans import k_means | ||
|
@@ -426,8 +427,9 @@ class SpectralClustering(ClusterMixin, BaseEstimator): | |
but may also lead to instabilities. If None, then ``'arpack'`` is | ||
used. See [4]_ for more details regarding `'lobpcg'`. | ||
n_components : int, default=n_clusters | ||
Number of eigenvectors to use for the spectral embedding. | ||
n_components : int, default=None | ||
Number of eigenvectors to use for the spectral embedding. If None, | ||
defaults to `n_clusters`. | ||
random_state : int, RandomState instance, default=None | ||
A pseudo random number generator used for the initialization | ||
|
@@ -615,6 +617,33 @@ class SpectralClustering(ClusterMixin, BaseEstimator): | |
random_state=0) | ||
""" | ||
|
||
_parameter_constraints = { | ||
"n_clusters": [Interval(Integral, 1, None, closed="left")], | ||
"eigen_solver": [StrOptions({"arpack", "lobpcg", "amg"}), None], | ||
"n_components": [Interval(Integral, 1, None, closed="left"), None], | ||
"random_state": ["random_state"], | ||
"n_init": [Interval(Integral, 1, None, closed="left")], | ||
"gamma": [Interval(Real, 0, None, closed="neither")], | ||
"affinity": [ | ||
callable, | ||
StrOptions( | ||
set(KERNEL_PARAMS) | ||
| {"nearest_neighbors", "precomputed", "precomputed_nearest_neighbors"} | ||
), | ||
], | ||
"n_neighbors": [Interval(Integral, 1, None, closed="left")], | ||
"eigen_tol": [ | ||
Interval(Real, 0.0, None, closed="left"), | ||
StrOptions({"auto"}), | ||
], | ||
"assign_labels": [StrOptions({"kmeans", "discretize", "cluster_qr"})], | ||
"degree": [Interval(Integral, 1, None, closed="left")], | ||
"coef0": [Interval(Real, None, None, closed="neither")], | ||
"kernel_params": [dict, None], | ||
"n_jobs": [Integral, None], | ||
"verbose": ["verbose"], | ||
} | ||
|
||
def __init__( | ||
self, | ||
n_clusters=8, | ||
|
@@ -672,6 +701,8 @@ def fit(self, X, y=None): | |
self : object | ||
A fitted instance of the estimator. | ||
""" | ||
self._validate_params() | ||
|
||
X = self._validate_data( | ||
X, | ||
accept_sparse=["csr", "csc", "coo"], | ||
|
@@ -690,55 +721,6 @@ def fit(self, X, y=None): | |
"set ``affinity=precomputed``." | ||
) | ||
|
||
check_scalar( | ||
self.n_clusters, | ||
"n_clusters", | ||
target_type=numbers.Integral, | ||
min_val=1, | ||
include_boundaries="left", | ||
) | ||
|
||
check_scalar( | ||
self.n_init, | ||
"n_init", | ||
target_type=numbers.Integral, | ||
min_val=1, | ||
include_boundaries="left", | ||
) | ||
|
||
check_scalar( | ||
self.gamma, | ||
"gamma", | ||
target_type=numbers.Real, | ||
min_val=1.0, | ||
include_boundaries="left", | ||
) | ||
|
||
check_scalar( | ||
self.n_neighbors, | ||
"n_neighbors", | ||
target_type=numbers.Integral, | ||
min_val=1, | ||
include_boundaries="left", | ||
) | ||
|
||
if self.eigen_tol != "auto": | ||
check_scalar( | ||
self.eigen_tol, | ||
"eigen_tol", | ||
target_type=numbers.Real, | ||
min_val=0, | ||
include_boundaries="left", | ||
) | ||
|
||
check_scalar( | ||
self.degree, | ||
"degree", | ||
target_type=numbers.Integral, | ||
min_val=1, | ||
include_boundaries="left", | ||
) | ||
|
||
if self.affinity == "nearest_neighbors": | ||
connectivity = kneighbors_graph( | ||
X, n_neighbors=self.n_neighbors, include_self=True, n_jobs=self.n_jobs | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters