Skip to content

Commit 20dad58

Browse files
StefanieSengeradrinjalaliogrisel
authored
DOC documentation and error message for mismatching output formats in transformers with a sparse output (scikit-learn#26919)
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 7b6b657 commit 20dad58

File tree

5 files changed

+26
-9
lines changed

5 files changed

+26
-9
lines changed

sklearn/preprocessing/_encoders.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ class OneHotEncoder(_BaseEncoder):
450450
The features are encoded using a one-hot (aka 'one-of-K' or 'dummy')
451451
encoding scheme. This creates a binary column for each category and
452452
returns a sparse matrix or dense array (depending on the ``sparse_output``
453-
parameter)
453+
parameter).
454454
455455
By default, the encoder derives the categories based on the unique values
456456
in each feature. Alternatively, you can also specify the `categories`
@@ -522,7 +522,8 @@ class OneHotEncoder(_BaseEncoder):
522522
`sparse_output` instead.
523523
524524
sparse_output : bool, default=True
525-
Will return sparse matrix if set True else will return an array.
525+
When ``True``, it returns a :class:`scipy.sparse.csr_matrix`,
526+
i.e. a sparse matrix in "Compressed Sparse Row" (CSR) format.
526527
527528
.. versionadded:: 1.2
528529
`sparse` was renamed to `sparse_output`
@@ -995,8 +996,12 @@ def transform(self, X):
995996
"""
996997
Transform X using one-hot encoding.
997998
998-
If there are infrequent categories for a feature, the infrequent
999-
categories will be grouped into a single category.
999+
If `sparse_output=True` (default), it returns an instance of
1000+
:class:`scipy.sparse._csr.csr_matrix` (CSR format).
1001+
1002+
If there are infrequent categories for a feature, set by specifying
1003+
`max_categories` or `min_frequency`, the infrequent categories are
1004+
grouped into a single category.
10001005
10011006
Parameters
10021007
----------

sklearn/preprocessing/tests/test_encoders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1979,7 +1979,7 @@ def test_one_hot_encoder_set_output():
19791979

19801980
ohe.set_output(transform="pandas")
19811981

1982-
match = "Pandas output does not support sparse data"
1982+
match = "Pandas output does not support sparse data. Set sparse_output=False"
19831983
with pytest.raises(ValueError, match=match):
19841984
ohe.fit_transform(X_df)
19851985

sklearn/utils/_set_output.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,11 @@ def _wrap_in_pandas_container(
4343
Container with column names or unchanged `output`.
4444
"""
4545
if issparse(data_to_wrap):
46-
raise ValueError("Pandas output does not support sparse data.")
46+
raise ValueError(
47+
"The transformer outputs a scipy sparse matrix. "
48+
"Try to set the transformer output to a dense array or disable "
49+
"pandas output with set_output(transform='default')."
50+
)
4751

4852
if callable(columns):
4953
try:

sklearn/utils/estimator_checks.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4557,7 +4557,11 @@ def check_set_output_transform_pandas(name, transformer_orig):
45574557
outputs_pandas = _output_from_fit_transform(transformer_pandas, name, X, df, y)
45584558
except ValueError as e:
45594559
# transformer does not support sparse data
4560-
assert "Pandas output does not support sparse data." in str(e), e
4560+
error_message = str(e)
4561+
assert (
4562+
"Pandas output does not support sparse data." in error_message
4563+
or "The transformer outputs a scipy sparse matrix." in error_message
4564+
), e
45614565
return
45624566

45634567
for case in outputs_default:
@@ -4603,7 +4607,11 @@ def check_global_output_transform_pandas(name, transformer_orig):
46034607
)
46044608
except ValueError as e:
46054609
# transformer does not support sparse data
4606-
assert "Pandas output does not support sparse data." in str(e), e
4610+
error_message = str(e)
4611+
assert (
4612+
"Pandas output does not support sparse data." in error_message
4613+
or "The transformer outputs a scipy sparse matrix." in error_message
4614+
), e
46074615
return
46084616

46094617
for case in outputs_default:

sklearn/utils/tests/test_set_output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test__wrap_in_pandas_container_error_validation(csr_container):
4646
"""Check errors in _wrap_in_pandas_container."""
4747
X = np.asarray([[1, 0, 3], [0, 0, 1]])
4848
X_csr = csr_container(X)
49-
match = "Pandas output does not support sparse data"
49+
match = "The transformer outputs a scipy sparse matrix."
5050
with pytest.raises(ValueError, match=match):
5151
_wrap_in_pandas_container(X_csr, columns=["a", "b", "c"])
5252

0 commit comments

Comments
 (0)