Skip to content

Commit ea6c77b

Browse files
TamaraAtanasoska, glemaitre, jeremiedbb
authored
ENH Add "ensure_non_negative" option to check_array (scikit-learn#29540)
Co-authored-by: Guillaume Lemaitre <guillaume@probabl.ai> Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>
1 parent 9a6b7d6 commit ea6c77b

File tree

10 files changed

+66
-30
lines changed

10 files changed

+66
-30
lines changed

doc/whats_new/v1.6.rst

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,9 @@ Changelog
177177
.......................
178178

179179
- |Efficiency| Small runtime improvement of fitting
180-
:class:`ensemble.HistGradientBoostingClassifier` and :class:`ensemble.HistGradientBoostingRegressor`
181-
by parallelizing the initial search for bin thresholds
180+
:class:`ensemble.HistGradientBoostingClassifier` and
181+
:class:`ensemble.HistGradientBoostingRegressor` by parallelizing the initial search
182+
for bin thresholds.
182183
:pr:`28064` by :user:`Christian Lorentzen <lorentzenchr>`.
183184

184185
- |Enhancement| The verbosity of :class:`ensemble.HistGradientBoostingClassifier`
@@ -193,9 +194,10 @@ Changelog
193194
:pr:`28622` by :user:`Adam Li <adam2392>` and
194195
:user:`Sérgio Pereira <sergiormpereira>`.
195196

196-
- |Feature| :class:`ensemble.ExtraTreesClassifier` and :class:`ensemble.ExtraTreesRegressor` now support
197-
missing-values in the data matrix `X`. Missing-values are handled by randomly moving all of
198-
the samples to the left, or right child node as the tree is traversed.
197+
- |Feature| :class:`ensemble.ExtraTreesClassifier` and
198+
:class:`ensemble.ExtraTreesRegressor` now support missing-values in the data matrix
199+
`X`. Missing-values are handled by randomly moving all of the samples to the left, or
200+
right child node as the tree is traversed.
199201
:pr:`28268` by :user:`Adam Li <adam2392>`.
200202

201203
:mod:`sklearn.impute`
@@ -249,7 +251,8 @@ Changelog
249251
estimator without re-fitting it.
250252
:pr:`29067` by :user:`Guillaume Lemaitre <glemaitre>`.
251253

252-
- |Fix| Improve error message when :func:`model_selection.RepeatedStratifiedKFold.split` is called without a `y` argument
254+
- |Fix| Improve error message when :func:`model_selection.RepeatedStratifiedKFold.split`
255+
is called without a `y` argument.
253256
:pr:`29402` by :user:`Anurag Varma <Anurag-Varma>`.
254257

255258
:mod:`sklearn.neighbors`
@@ -285,6 +288,11 @@ Changelog
285288
:mod:`sklearn.utils`
286289
....................
287290

291+
- |Enhancement| :func:`utils.validation.check_array` now accepts `ensure_non_negative`
292+
to check for negative values in the passed array, until now only available through
293+
calling :func:`utils.validation.check_non_negative`.
294+
:pr:`29540` by :user:`Tamara Atanasoska <tamaraatanasoska>`.
295+
288296
- |API| the `force_all_finite` parameter of functions :func:`utils.check_array`,
289297
:func:`utils.check_X_y`, :func:`utils.as_float_array` is renamed into
290298
`ensure_all_finite`. `force_all_finite` will be removed in 1.8.

sklearn/decomposition/_nmf.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1700,8 +1700,6 @@ def _fit_transform(self, X, y=None, W=None, H=None, update_H=True):
17001700
n_iter_ : int
17011701
Actual number of iterations.
17021702
"""
1703-
check_non_negative(X, "NMF (input X)")
1704-
17051703
# check parameters
17061704
self._check_params(X)
17071705

@@ -1777,7 +1775,11 @@ def transform(self, X):
17771775
"""
17781776
check_is_fitted(self)
17791777
X = self._validate_data(
1780-
X, accept_sparse=("csr", "csc"), dtype=[np.float64, np.float32], reset=False
1778+
X,
1779+
accept_sparse=("csr", "csc"),
1780+
dtype=[np.float64, np.float32],
1781+
reset=False,
1782+
ensure_non_negative=True,
17811783
)
17821784

17831785
with config_context(assume_finite=True):

sklearn/ensemble/_weight_boosting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def fit(self, X, y, sample_weight=None):
137137
)
138138

139139
sample_weight = _check_sample_weight(
140-
sample_weight, X, np.float64, copy=True, only_non_negative=True
140+
sample_weight, X, np.float64, copy=True, ensure_non_negative=True
141141
)
142142
sample_weight /= sample_weight.sum()
143143

sklearn/kernel_approximation.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from .utils.validation import (
2525
_check_feature_names_in,
2626
check_is_fitted,
27-
check_non_negative,
2827
)
2928

3029

@@ -674,8 +673,7 @@ def fit(self, X, y=None):
674673
self : object
675674
Returns the transformer.
676675
"""
677-
X = self._validate_data(X, accept_sparse="csr")
678-
check_non_negative(X, "X in AdditiveChi2Sampler.fit")
676+
X = self._validate_data(X, accept_sparse="csr", ensure_non_negative=True)
679677

680678
if self.sample_interval is None and self.sample_steps not in (1, 2, 3):
681679
raise ValueError(
@@ -701,8 +699,9 @@ def transform(self, X):
701699
Whether the return value is an array or sparse matrix depends on
702700
the type of the input X.
703701
"""
704-
X = self._validate_data(X, accept_sparse="csr", reset=False)
705-
check_non_negative(X, "X in AdditiveChi2Sampler.transform")
702+
X = self._validate_data(
703+
X, accept_sparse="csr", reset=False, ensure_non_negative=True
704+
)
706705
sparse = sp.issparse(X)
707706

708707
if self.sample_interval is None:

sklearn/linear_model/_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ def fit(self, X, y, sample_weight=None):
609609
has_sw = sample_weight is not None
610610
if has_sw:
611611
sample_weight = _check_sample_weight(
612-
sample_weight, X, dtype=X.dtype, only_non_negative=True
612+
sample_weight, X, dtype=X.dtype, ensure_non_negative=True
613613
)
614614

615615
# Note that neither _rescale_data nor the rest of the fit method of

sklearn/neighbors/_base.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from ..utils.fixes import parse_version, sp_base_version
3131
from ..utils.multiclass import check_classification_targets
3232
from ..utils.parallel import Parallel, delayed
33-
from ..utils.validation import _to_object_array, check_is_fitted, check_non_negative
33+
from ..utils.validation import _to_object_array, check_is_fitted
3434
from ._ball_tree import BallTree
3535
from ._kd_tree import KDTree
3636

@@ -167,8 +167,7 @@ def _check_precomputed(X):
167167
case only non-zero elements may be considered neighbors.
168168
"""
169169
if not issparse(X):
170-
X = check_array(X)
171-
check_non_negative(X, whom="precomputed distance matrix.")
170+
X = check_array(X, ensure_non_negative=True, input_name="X")
172171
return X
173172
else:
174173
graph = X
@@ -179,8 +178,12 @@ def _check_precomputed(X):
179178
"its handling of explicit zeros".format(graph.format)
180179
)
181180
copied = graph.format != "csr"
182-
graph = check_array(graph, accept_sparse="csr")
183-
check_non_negative(graph, whom="precomputed distance matrix.")
181+
graph = check_array(
182+
graph,
183+
accept_sparse="csr",
184+
ensure_non_negative=True,
185+
input_name="precomputed distance matrix",
186+
)
184187
graph = sort_graph_by_row_values(graph, copy=not copied, warn_when_not_sorted=True)
185188

186189
return graph

sklearn/neighbors/_kde.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def fit(self, X, y=None, sample_weight=None):
230230

231231
if sample_weight is not None:
232232
sample_weight = _check_sample_weight(
233-
sample_weight, X, dtype=np.float64, only_non_negative=True
233+
sample_weight, X, dtype=np.float64, ensure_non_negative=True
234234
)
235235

236236
kwargs = self.metric_params

sklearn/tests/test_kernel_approximation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,9 @@ def test_additive_chi2_sampler_exceptions():
200200
transformer = AdditiveChi2Sampler()
201201
X_neg = X.copy()
202202
X_neg[0, 0] = -1
203-
with pytest.raises(ValueError, match="X in AdditiveChi2Sampler.fit"):
203+
with pytest.raises(ValueError, match="X in AdditiveChi2Sampler"):
204204
transformer.fit(X_neg)
205-
with pytest.raises(ValueError, match="X in AdditiveChi2Sampler.transform"):
205+
with pytest.raises(ValueError, match="X in AdditiveChi2Sampler"):
206206
transformer.fit(X)
207207
transformer.transform(X_neg)
208208

sklearn/utils/tests/test_validation.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,17 @@ def test_check_array():
462462
result = check_array(X_no_array)
463463
assert isinstance(result, np.ndarray)
464464

465+
# check negative values when ensure_non_negative=True
466+
X_neg = check_array([[1, 2], [-3, 4]])
467+
err_msg = "Negative values in data passed to X in RandomForestRegressor"
468+
with pytest.raises(ValueError, match=err_msg):
469+
check_array(
470+
X_neg,
471+
ensure_non_negative=True,
472+
input_name="X",
473+
estimator=RandomForestRegressor(),
474+
)
475+
465476

466477
@pytest.mark.parametrize(
467478
"X",
@@ -1480,13 +1491,13 @@ def test_check_sample_weight():
14801491
sample_weight = _check_sample_weight(None, X, dtype=X.dtype)
14811492
assert sample_weight.dtype == np.float64
14821493

1483-
# check negative weight when only_non_negative=True
1494+
# check negative weight when ensure_non_negative=True
14841495
X = np.ones((5, 2))
14851496
sample_weight = np.ones(_num_samples(X))
14861497
sample_weight[-1] = -10
14871498
err_msg = "Negative values in data passed to `sample_weight`"
14881499
with pytest.raises(ValueError, match=err_msg):
1489-
_check_sample_weight(sample_weight, X, only_non_negative=True)
1500+
_check_sample_weight(sample_weight, X, ensure_non_negative=True)
14901501

14911502

14921503
@pytest.mark.parametrize("toarray", [np.array, sp.csr_matrix, sp.csc_matrix])

sklearn/utils/validation.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,7 @@ def check_array(
741741
force_writeable=False,
742742
force_all_finite="deprecated",
743743
ensure_all_finite=None,
744+
ensure_non_negative=False,
744745
ensure_2d=True,
745746
allow_nd=False,
746747
ensure_min_samples=1,
@@ -828,6 +829,12 @@ def check_array(
828829
.. versionadded:: 1.6
829830
`force_all_finite` was renamed to `ensure_all_finite`.
830831
832+
ensure_non_negative : bool, default=False
833+
Make sure the array has only non-negative values. If True, an array that
834+
contains negative values will raise a ValueError.
835+
836+
.. versionadded:: 1.6
837+
831838
ensure_2d : bool, default=True
832839
Whether to raise a value error if array is not 2D.
833840
@@ -1132,6 +1139,12 @@ def is_sparse(dtype):
11321139
% (n_features, array.shape, ensure_min_features, context)
11331140
)
11341141

1142+
if ensure_non_negative:
1143+
whom = input_name
1144+
if estimator_name:
1145+
whom += f" in {estimator_name}"
1146+
check_non_negative(array, whom)
1147+
11351148
if force_writeable:
11361149
# By default, array.copy() creates a C-ordered copy. We set order=K to
11371150
# preserve the order of the array.
@@ -1739,7 +1752,7 @@ def check_non_negative(X, whom):
17391752
X_min = xp.min(X)
17401753

17411754
if X_min < 0:
1742-
raise ValueError("Negative values in data passed to %s" % whom)
1755+
raise ValueError(f"Negative values in data passed to {whom}.")
17431756

17441757

17451758
def check_scalar(
@@ -2044,7 +2057,7 @@ def _check_psd_eigenvalues(lambdas, enable_warnings=False):
20442057

20452058

20462059
def _check_sample_weight(
2047-
sample_weight, X, dtype=None, copy=False, only_non_negative=False
2060+
sample_weight, X, dtype=None, copy=False, ensure_non_negative=False
20482061
):
20492062
"""Validate sample weights.
20502063
@@ -2061,7 +2074,7 @@ def _check_sample_weight(
20612074
X : {ndarray, list, sparse matrix}
20622075
Input data.
20632076
2064-
only_non_negative : bool, default=False,
2077+
ensure_non_negative : bool, default=False,
20652078
Whether or not the weights are expected to be non-negative.
20662079
20672080
.. versionadded:: 1.0
@@ -2112,7 +2125,7 @@ def _check_sample_weight(
21122125
)
21132126
)
21142127

2115-
if only_non_negative:
2128+
if ensure_non_negative:
21162129
check_non_negative(sample_weight, "`sample_weight`")
21172130

21182131
return sample_weight

0 commit comments

Comments
 (0)