Skip to content

Commit 011e209

Browse files
authored
ENH make_sparse_spd_matrix use sparse memory layout (scikit-learn#27438)
1 parent f86f41d commit 011e209

File tree

3 files changed

+77
-17
lines changed

3 files changed

+77
-17
lines changed

doc/whats_new/v1.4.rst

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,8 @@ and classes are impacted:
113113
- :class:`impute.KNNImputer` in :pr:`27277` by :user:`Yao Xiao <Charlie-XIAO>`;
114114
- :class:`kernel_approximation.PolynomialCountSketch` in :pr:`27301` by
115115
:user:`Lohit SundaramahaLingam <lohitslohit>`;
116-
- :class:`neural_network.BernoulliRBM` in :pr:`27252` by `Yao Xiao <Charlie-XIAO>`.
116+
- :class:`neural_network.BernoulliRBM` in :pr:`27252` by
117+
:user:`Yao Xiao <Charlie-XIAO>`.
117118

118119
Changelog
119120
---------
@@ -168,6 +169,15 @@ Changelog
168169
`kdtree` and `balltree` values will be removed in 1.6.
169170
:pr:`26744` by :user:`Shreesha Kumar Bhat <Shreesha3112>`.
170171

172+
:mod:`sklearn.datasets`
173+
.......................
174+
175+
- |Enhancement| :func:`datasets.make_sparse_spd_matrix` now uses a more
176+
memory-efficient sparse layout. It also accepts a new keyword `sparse_format`
177+
that allows specifying the output format of the sparse matrix. By default
178+
`sparse_format=None`, which returns a dense numpy ndarray as before.
179+
:pr:`27438` by :user:`Yao Xiao <Charlie-XIAO>`.
180+
171181
:mod:`sklearn.decomposition`
172182
............................
173183

sklearn/datasets/_samples_generator.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,6 +1573,10 @@ def make_spd_matrix(n_dim, *, random_state=None):
15731573
"norm_diag": ["boolean"],
15741574
"smallest_coef": [Interval(Real, 0, 1, closed="both")],
15751575
"largest_coef": [Interval(Real, 0, 1, closed="both")],
1576+
"sparse_format": [
1577+
StrOptions({"bsr", "coo", "csc", "csr", "dia", "dok", "lil"}),
1578+
None,
1579+
],
15761580
"random_state": ["random_state"],
15771581
},
15781582
prefer_skip_nested_validation=True,
@@ -1584,6 +1588,7 @@ def make_sparse_spd_matrix(
15841588
norm_diag=False,
15851589
smallest_coef=0.1,
15861590
largest_coef=0.9,
1591+
sparse_format=None,
15871592
random_state=None,
15881593
):
15891594
"""Generate a sparse symmetric definite positive matrix.
@@ -1609,6 +1614,12 @@ def make_sparse_spd_matrix(
16091614
largest_coef : float, default=0.9
16101615
The value of the largest coefficient between 0 and 1.
16111616
1617+
sparse_format : str, default=None
1618+
String representing the output sparse format, such as 'csc', 'csr', etc.
1619+
If ``None``, return a dense numpy ndarray.
1620+
1621+
.. versionadded:: 1.4
1622+
16121623
random_state : int, RandomState instance or None, default=None
16131624
Determines random number generation for dataset creation. Pass an int
16141625
for reproducible output across multiple function calls.
@@ -1631,30 +1642,35 @@ def make_sparse_spd_matrix(
16311642
"""
16321643
random_state = check_random_state(random_state)
16331644

1634-
chol = -np.eye(dim)
1635-
aux = random_state.uniform(size=(dim, dim))
1636-
aux[aux < alpha] = 0
1637-
aux[aux > alpha] = smallest_coef + (
1638-
largest_coef - smallest_coef
1639-
) * random_state.uniform(size=np.sum(aux > alpha))
1640-
aux = np.tril(aux, k=-1)
1645+
chol = -sp.eye(dim)
1646+
aux = sp.random(
1647+
m=dim,
1648+
n=dim,
1649+
density=1 - alpha,
1650+
data_rvs=lambda x: random_state.uniform(
1651+
low=smallest_coef, high=largest_coef, size=x
1652+
),
1653+
random_state=random_state,
1654+
)
1655+
# We need to avoid "coo" format because it does not support slicing
1656+
aux = sp.tril(aux, k=-1, format="csc")
16411657

16421658
# Permute the rows and columns: we don't want to have asymmetries in the final
16431659
# SPD matrix
16441660
permutation = random_state.permutation(dim)
16451661
aux = aux[permutation].T[permutation]
16461662
chol += aux
1647-
prec = np.dot(chol.T, chol)
1663+
prec = chol.T @ chol
16481664

16491665
if norm_diag:
16501666
# Normalize to a unit diagonal: scale rows and columns by 1 / sqrt(diagonal)
1651-
d = np.diag(prec).reshape(1, prec.shape[0])
1652-
d = 1.0 / np.sqrt(d)
1667+
d = sp.diags(1.0 / np.sqrt(prec.diagonal()))
1668+
prec = d @ prec @ d
16531669

1654-
prec *= d
1655-
prec *= d.T
1656-
1657-
return prec
1670+
if sparse_format is None:
1671+
return prec.toarray()
1672+
else:
1673+
return prec.asformat(sparse_format)
16581674

16591675

16601676
@validate_params(

sklearn/datasets/tests/test_samples_generator.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@
2222
make_regression,
2323
make_s_curve,
2424
make_sparse_coded_signal,
25+
make_sparse_spd_matrix,
2526
make_sparse_uncorrelated,
2627
make_spd_matrix,
2728
make_swiss_roll,
2829
)
2930
from sklearn.utils._testing import (
3031
assert_allclose,
32+
assert_allclose_dense_sparse,
3133
assert_almost_equal,
3234
assert_array_almost_equal,
3335
assert_array_equal,
@@ -549,10 +551,42 @@ def test_make_spd_matrix():
549551
from numpy.linalg import eig
550552

551553
eigenvalues, _ = eig(X)
552-
assert_array_equal(
553-
eigenvalues > 0, np.array([True] * 5), "X is not positive-definite"
554+
assert np.all(eigenvalues > 0), "X is not positive-definite"
555+
556+
557+
@pytest.mark.parametrize("norm_diag", [True, False])
558+
@pytest.mark.parametrize(
559+
"sparse_format", [None, "bsr", "coo", "csc", "csr", "dia", "dok", "lil"]
560+
)
561+
def test_make_sparse_spd_matrix(norm_diag, sparse_format, global_random_seed):
562+
dim = 5
563+
X = make_sparse_spd_matrix(
564+
dim=dim,
565+
norm_diag=norm_diag,
566+
sparse_format=sparse_format,
567+
random_state=global_random_seed,
554568
)
555569

570+
assert X.shape == (dim, dim), "X shape mismatch"
571+
if sparse_format is None:
572+
assert not sp.issparse(X)
573+
assert_allclose(X, X.T)
574+
Xarr = X
575+
else:
576+
assert sp.issparse(X) and X.format == sparse_format
577+
assert_allclose_dense_sparse(X, X.T)
578+
Xarr = X.toarray()
579+
580+
from numpy.linalg import eig
581+
582+
# Do not use scipy.sparse.linalg.eigs because it cannot find all eigenvalues
583+
eigenvalues, _ = eig(Xarr)
584+
assert np.all(eigenvalues > 0), "X is not positive-definite"
585+
586+
if norm_diag:
587+
# Check that leading diagonal elements are 1
588+
assert_array_almost_equal(Xarr.diagonal(), np.ones(dim))
589+
556590

557591
@pytest.mark.parametrize("hole", [False, True])
558592
def test_make_swiss_roll(hole):

0 commit comments

Comments
 (0)