Skip to content

Commit 0198e2c

Browse files
Giorgio Patriniamueller
authored andcommitted
warning for PCA with sparse input (scikit-learn#7649)
1 parent fa59873 commit 0198e2c

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

sklearn/decomposition/pca.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import numpy as np
1616
from scipy import linalg
1717
from scipy.special import gammaln
18+
from scipy.sparse import issparse
1819

1920
from ..externals import six
2021

@@ -116,6 +117,9 @@ class PCA(_BasePCA):
116117
It can also use the scipy.sparse.linalg ARPACK implementation of the
117118
truncated SVD.
118119
120+
Notice that this class does not support sparse input. See
121+
:ref:`<TruncatedSVD>` for an alternative with sparse data.
122+
119123
Read more in the :ref:`User Guide <PCA>`.
120124
121125
Parameters
@@ -332,6 +336,13 @@ def fit_transform(self, X, y=None):
332336

333337
def _fit(self, X):
334338
"""Dispatch to the right submethod depending on the chosen solver."""
339+
340+
# Raise an error for sparse input.
341+
# This is more informative than the generic one raised by check_array.
342+
if issparse(X):
343+
raise TypeError('PCA does not support sparse input. See '
344+
'TruncatedSVD for a possible alternative.')
345+
335346
X = check_array(X, dtype=[np.float64], ensure_2d=True,
336347
copy=self.copy)
337348

sklearn/decomposition/tests/test_pca.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import scipy as sp
23
from itertools import product
34

45
from sklearn.utils.testing import assert_almost_equal
@@ -508,3 +509,15 @@ def fit_deprecated(X):
508509
assert_warns_message(DeprecationWarning, depr_message, fit_deprecated, X)
509510
Y_pca = PCA(svd_solver='randomized', random_state=0).fit_transform(X)
510511
assert_array_almost_equal(Y, Y_pca)
512+
513+
514+
def test_pca_spase_input():
515+
516+
X = np.random.RandomState(0).rand(5, 4)
517+
X = sp.sparse.csr_matrix(X)
518+
assert(sp.sparse.issparse(X))
519+
520+
for svd_solver in solver_list:
521+
pca = PCA(n_components=3, svd_solver=svd_solver)
522+
523+
assert_raises(TypeError, pca.fit, X)

sklearn/decomposition/truncated_svd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class TruncatedSVD(BaseEstimator, TransformerMixin):
3535
returned by the vectorizers in sklearn.feature_extraction.text. In that
3636
context, it is known as latent semantic analysis (LSA).
3737
38-
This estimator supports two algorithm: a fast randomized SVD solver, and
38+
This estimator supports two algorithms: a fast randomized SVD solver, and
3939
a "naive" algorithm that uses ARPACK as an eigensolver on (X * X.T) or
4040
(X.T * X), whichever is more efficient.
4141

0 commit comments

Comments
 (0)