Skip to content

Commit

Permalink
ENH make_sparse_spd_matrix use sparse memory layout (scikit-learn#2…
Browse files Browse the repository at this point in the history
  • Loading branch information
Charlie-XIAO authored Sep 26, 2023
1 parent f86f41d commit 011e209
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 17 deletions.
12 changes: 11 additions & 1 deletion doc/whats_new/v1.4.rst
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ and classes are impacted:
- :class:`impute.KNNImputer` in :pr:`27277` by :user:`Yao Xiao <Charlie-XIAO>`;
- :class:`kernel_approximation.PolynomialCountSketch` in :pr:`27301` by
:user:`Lohit SundaramahaLingam <lohitslohit>`;
- :class:`neural_network.BernoulliRBM` in :pr:`27252` by `Yao Xiao <Charlie-XIAO>`.
- :class:`neural_network.BernoulliRBM` in :pr:`27252` by
:user:`Yao Xiao <Charlie-XIAO>`.

Changelog
---------
Expand Down Expand Up @@ -168,6 +169,15 @@ Changelog
`kdtree` and `balltree` values will be removed in 1.6.
:pr:`26744` by :user:`Shreesha Kumar Bhat <Shreesha3112>`.

:mod:`sklearn.datasets`
.......................

- |Enhancement| :func:`datasets.make_sparse_spd_matrix` now uses a more
memory-efficient sparse layout. It also accepts a new keyword `sparse_format`
that allows specifying the output format of the sparse matrix. By default,
`sparse_format=None`, which returns a dense numpy ndarray as before.
:pr:`27438` by :user:`Yao Xiao <Charlie-XIAO>`.

:mod:`sklearn.decomposition`
............................

Expand Down
44 changes: 30 additions & 14 deletions sklearn/datasets/_samples_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1573,6 +1573,10 @@ def make_spd_matrix(n_dim, *, random_state=None):
"norm_diag": ["boolean"],
"smallest_coef": [Interval(Real, 0, 1, closed="both")],
"largest_coef": [Interval(Real, 0, 1, closed="both")],
"sparse_format": [
StrOptions({"bsr", "coo", "csc", "csr", "dia", "dok", "lil"}),
None,
],
"random_state": ["random_state"],
},
prefer_skip_nested_validation=True,
Expand All @@ -1584,6 +1588,7 @@ def make_sparse_spd_matrix(
norm_diag=False,
smallest_coef=0.1,
largest_coef=0.9,
sparse_format=None,
random_state=None,
):
"""Generate a sparse symmetric definite positive matrix.
Expand All @@ -1609,6 +1614,12 @@ def make_sparse_spd_matrix(
largest_coef : float, default=0.9
The value of the largest coefficient between 0 and 1.
sparse_format : str, default=None
String representing the output sparse format, such as 'csc', 'csr', etc.
If ``None``, return a dense numpy ndarray.
.. versionadded:: 1.4
random_state : int, RandomState instance or None, default=None
Determines random number generation for dataset creation. Pass an int
for reproducible output across multiple function calls.
Expand All @@ -1631,30 +1642,35 @@ def make_sparse_spd_matrix(
"""
random_state = check_random_state(random_state)

chol = -np.eye(dim)
aux = random_state.uniform(size=(dim, dim))
aux[aux < alpha] = 0
aux[aux > alpha] = smallest_coef + (
largest_coef - smallest_coef
) * random_state.uniform(size=np.sum(aux > alpha))
aux = np.tril(aux, k=-1)
chol = -sp.eye(dim)
aux = sp.random(
m=dim,
n=dim,
density=1 - alpha,
data_rvs=lambda x: random_state.uniform(
low=smallest_coef, high=largest_coef, size=x
),
random_state=random_state,
)
# We need to avoid "coo" format because it does not support slicing
aux = sp.tril(aux, k=-1, format="csc")

# Permute the lines: we don't want to have asymmetries in the final
# SPD matrix
permutation = random_state.permutation(dim)
aux = aux[permutation].T[permutation]
chol += aux
prec = np.dot(chol.T, chol)
prec = chol.T @ chol

if norm_diag:
# Form the diagonal vector into a row matrix
d = np.diag(prec).reshape(1, prec.shape[0])
d = 1.0 / np.sqrt(d)
d = sp.diags(1.0 / np.sqrt(prec.diagonal()))
prec = d @ prec @ d

prec *= d
prec *= d.T

return prec
if sparse_format is None:
return prec.toarray()
else:
return prec.asformat(sparse_format)


@validate_params(
Expand Down
38 changes: 36 additions & 2 deletions sklearn/datasets/tests/test_samples_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
make_regression,
make_s_curve,
make_sparse_coded_signal,
make_sparse_spd_matrix,
make_sparse_uncorrelated,
make_spd_matrix,
make_swiss_roll,
)
from sklearn.utils._testing import (
assert_allclose,
assert_allclose_dense_sparse,
assert_almost_equal,
assert_array_almost_equal,
assert_array_equal,
Expand Down Expand Up @@ -549,10 +551,42 @@ def test_make_spd_matrix():
from numpy.linalg import eig

eigenvalues, _ = eig(X)
assert_array_equal(
eigenvalues > 0, np.array([True] * 5), "X is not positive-definite"
assert np.all(eigenvalues > 0), "X is not positive-definite"


@pytest.mark.parametrize("norm_diag", [True, False])
@pytest.mark.parametrize(
    "sparse_format", [None, "bsr", "coo", "csc", "csr", "dia", "dok", "lil"]
)
def test_make_sparse_spd_matrix(norm_diag, sparse_format, global_random_seed):
    """Check the container type, symmetry and positive-definiteness of the output.

    For every requested ``sparse_format`` (or ``None`` for a dense result) the
    returned matrix must have shape ``(dim, dim)``, be symmetric, and have
    strictly positive eigenvalues. With ``norm_diag=True`` the diagonal must
    additionally be (almost) all ones.
    """
    dim = 5
    X = make_sparse_spd_matrix(
        dim=dim,
        norm_diag=norm_diag,
        sparse_format=sparse_format,
        random_state=global_random_seed,
    )
    assert X.shape == (dim, dim), "X shape mismatch"

    if sparse_format is not None:
        # A sparse container was requested: it must come back in that format.
        assert sp.issparse(X) and X.format == sparse_format
        assert_allclose_dense_sparse(X, X.T)
        dense = X.toarray()
    else:
        # No format requested: the result is expected to be a dense array.
        assert not sp.issparse(X)
        assert_allclose(X, X.T)
        dense = X

    # Use the dense solver on purpose: scipy.sparse.linalg.eigs cannot find
    # all eigenvalues.
    eigenvalues = np.linalg.eig(dense)[0]
    assert np.all(eigenvalues > 0), "X is not positive-definite"

    if norm_diag:
        # Check that leading diagonal elements are 1
        assert_array_almost_equal(dense.diagonal(), np.ones(dim))


@pytest.mark.parametrize("hole", [False, True])
def test_make_swiss_roll(hole):
Expand Down

0 comments on commit 011e209

Please sign in to comment.