Skip to content

Commit 5a2ac9a

Browse files
authored
Merge branch 'scikit-learn:main' into submodulev2
2 parents 6b4d0e7 + 18cf8d0 commit 5a2ac9a

File tree

10 files changed

+205
-124
lines changed

10 files changed

+205
-124
lines changed

doc/developers/develop.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,9 @@ The current set of estimator tags are:
535535
allow_nan (default=False)
536536
whether the estimator supports data with missing values encoded as np.nan
537537

538+
array_api_support (default=False)
539+
whether the estimator supports Array API compatible inputs.
540+
538541
binary_only (default=False)
539542
whether estimator supports binary classification but lacks multi-class
540543
classification support.

doc/modules/array_api.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,18 @@ Estimators with support for `Array API`-compatible inputs
9393
Coverage for more estimators is expected to grow over time. Please follow the
9494
dedicated `meta-issue on GitHub
9595
<https://github.com/scikit-learn/scikit-learn/issues/22352>`_ to track progress.
96+
97+
Common estimator checks
98+
=======================
99+
100+
Add the `array_api_support` tag to an estimator's set of tags to indicate that
101+
it supports the Array API. This will enable dedicated checks as part of the
102+
common tests to verify that the estimators result's are the same when using
103+
vanilla NumPy and Array API inputs.
104+
105+
To run these checks you need to install
106+
`array_api_compat <https://github.com/data-apis/array-api-compat>`_ in your
107+
test environment. To run the full set of checks you need to install both
108+
`PyTorch <https://pytorch.org/>`_ and `CuPy <https://cupy.dev/>`_ and have
109+
a GPU. Checks that can not be executed or have missing dependencies will be
110+
automatically skipped.

doc/whats_new/v1.3.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,10 @@ Changelog
628628
The `sample_interval_` attribute is deprecated and will be removed in 1.5.
629629
:pr:`25190` by :user:`Vincent Maladière <Vincent-Maladiere>`.
630630

631+
- |Fix| :class:`preprocessing.PowerTransformer` now correcly raises error when
632+
using `method="box-cox"` on data with a constant `np.nan` column.
633+
:pr:`26400` by :user:`Yao Xiao <Charlie-XIAO>`.
634+
631635
:mod:`sklearn.svm`
632636
..................
633637

sklearn/discriminant_analysis.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,9 @@ def decision_function(self, X):
745745
# Only override for the doc
746746
return super().decision_function(X)
747747

748+
def _more_tags(self):
749+
return {"array_api_support": True}
750+
748751

749752
class QuadraticDiscriminantAnalysis(ClassifierMixin, BaseEstimator):
750753
"""Quadratic Discriminant Analysis.

sklearn/preprocessing/_data.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3311,9 +3311,13 @@ def _box_cox_optimize(self, x):
33113311
33123312
We here use scipy builtins which uses the brent optimizer.
33133313
"""
3314+
mask = np.isnan(x)
3315+
if np.all(mask):
3316+
raise ValueError("Column must not be all nan.")
3317+
33143318
# the computation of lambda is influenced by NaNs so we need to
33153319
# get rid of them
3316-
_, lmbda = stats.boxcox(x[~np.isnan(x)], lmbda=None)
3320+
_, lmbda = stats.boxcox(x[~mask], lmbda=None)
33173321

33183322
return lmbda
33193323

sklearn/preprocessing/tests/test_data.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2527,6 +2527,21 @@ def test_power_transformer_copy_False(method, standardize):
25272527
assert X_trans is X_inv_trans
25282528

25292529

2530+
def test_power_transformer_box_cox_raise_all_nans_col():
2531+
"""Check that box-cox raises informative when a column contains all nans.
2532+
2533+
Non-regression test for gh-26303
2534+
"""
2535+
X = rng.random_sample((4, 5))
2536+
X[:, 0] = np.nan
2537+
2538+
err_msg = "Column must not be all nan."
2539+
2540+
pt = PowerTransformer(method="box-cox")
2541+
with pytest.raises(ValueError, match=err_msg):
2542+
pt.fit_transform(X)
2543+
2544+
25302545
@pytest.mark.parametrize(
25312546
"X_2",
25322547
[

sklearn/tests/test_discriminant_analysis.py

Lines changed: 0 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,12 @@
44

55
from scipy import linalg
66

7-
from sklearn.base import clone
8-
from sklearn._config import config_context
97
from sklearn.utils import check_random_state
108
from sklearn.utils._testing import assert_array_equal
119
from sklearn.utils._testing import assert_array_almost_equal
1210
from sklearn.utils._testing import assert_allclose
1311
from sklearn.utils._testing import assert_almost_equal
14-
from sklearn.utils._array_api import _convert_to_numpy
1512
from sklearn.utils._testing import _convert_container
16-
from sklearn.utils._testing import skip_if_array_api_compat_not_configured
1713

1814
from sklearn.datasets import make_blobs
1915
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
@@ -675,121 +671,3 @@ def test_get_feature_names_out():
675671
dtype=object,
676672
)
677673
assert_array_equal(names_out, expected_names_out)
678-
679-
680-
@skip_if_array_api_compat_not_configured
681-
@pytest.mark.parametrize("array_namespace", ["numpy.array_api", "cupy.array_api"])
682-
def test_lda_array_api(array_namespace):
683-
"""Check that the array_api Array gives the same results as ndarrays."""
684-
xp = pytest.importorskip(array_namespace)
685-
686-
X_xp = xp.asarray(X)
687-
y_xp = xp.asarray(y3)
688-
689-
lda = LinearDiscriminantAnalysis()
690-
lda.fit(X, y3)
691-
692-
array_attributes = {
693-
key: value for key, value in vars(lda).items() if isinstance(value, np.ndarray)
694-
}
695-
696-
lda_xp = clone(lda)
697-
with config_context(array_api_dispatch=True):
698-
lda_xp.fit(X_xp, y_xp)
699-
700-
# Fitted-attributes which are arrays must have the same
701-
# namespace than the one of the training data.
702-
for key, attribute in array_attributes.items():
703-
lda_xp_param = getattr(lda_xp, key)
704-
assert hasattr(lda_xp_param, "__array_namespace__")
705-
706-
lda_xp_param_np = _convert_to_numpy(lda_xp_param, xp=xp)
707-
assert_allclose(
708-
attribute, lda_xp_param_np, err_msg=f"{key} not the same", atol=1e-3
709-
)
710-
711-
# Check predictions are the same
712-
methods = (
713-
"decision_function",
714-
"predict",
715-
"predict_log_proba",
716-
"predict_proba",
717-
"transform",
718-
)
719-
720-
for method in methods:
721-
result = getattr(lda, method)(X)
722-
with config_context(array_api_dispatch=True):
723-
result_xp = getattr(lda_xp, method)(X_xp)
724-
assert hasattr(
725-
result_xp, "__array_namespace__"
726-
), f"{method} did not output an array_namespace"
727-
728-
result_xp_np = _convert_to_numpy(result_xp, xp=xp)
729-
730-
assert_allclose(
731-
result,
732-
result_xp_np,
733-
err_msg=f"{method} did not the return the same result",
734-
atol=1e-5,
735-
)
736-
737-
738-
@skip_if_array_api_compat_not_configured
739-
@pytest.mark.parametrize("device", ["cuda", "cpu"])
740-
@pytest.mark.parametrize("dtype", ["float32", "float64"])
741-
def test_lda_array_torch(device, dtype):
742-
"""Check running on PyTorch Tensors gives the same results as NumPy"""
743-
torch = pytest.importorskip("torch")
744-
if device == "cuda" and not torch.has_cuda:
745-
pytest.skip("test requires cuda")
746-
747-
lda = LinearDiscriminantAnalysis()
748-
X_np = X6.astype(dtype)
749-
y_np = y6.astype(dtype)
750-
lda.fit(X_np, y_np)
751-
752-
X_torch = torch.asarray(X_np, device=device)
753-
y_torch = torch.asarray(y_np, device=device)
754-
lda_xp = clone(lda)
755-
with config_context(array_api_dispatch=True):
756-
lda_xp.fit(X_torch, y_torch)
757-
758-
array_attributes = {
759-
key: value for key, value in vars(lda).items() if isinstance(value, np.ndarray)
760-
}
761-
762-
for key, attribute in array_attributes.items():
763-
lda_xp_param = getattr(lda_xp, key)
764-
assert isinstance(lda_xp_param, torch.Tensor)
765-
assert lda_xp_param.device.type == device
766-
767-
lda_xp_param_np = _convert_to_numpy(lda_xp_param, xp=torch)
768-
assert_allclose(
769-
attribute, lda_xp_param_np, err_msg=f"{key} not the same", atol=1e-3
770-
)
771-
772-
# Check predictions are the same
773-
methods = (
774-
"decision_function",
775-
"predict",
776-
"predict_log_proba",
777-
"predict_proba",
778-
"transform",
779-
)
780-
for method in methods:
781-
result = getattr(lda, method)(X_np)
782-
with config_context(array_api_dispatch=True):
783-
result_xp = getattr(lda_xp, method)(X_torch)
784-
785-
assert isinstance(result_xp, torch.Tensor)
786-
assert result_xp.device.type == device
787-
788-
result_xp_np = _convert_to_numpy(result_xp, xp=torch)
789-
790-
assert_allclose(
791-
result,
792-
result_xp_np,
793-
err_msg=f"{method} did not the return the same result",
794-
atol=1e-6,
795-
)

sklearn/utils/_tags.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22

33
_DEFAULT_TAGS = {
4+
"array_api_support": False,
45
"non_deterministic": False,
56
"requires_positive_X": False,
67
"requires_positive_y": False,

sklearn/utils/estimator_checks.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import warnings
2+
import importlib
3+
import itertools
24
import pickle
35
import re
46
from copy import deepcopy
@@ -58,6 +60,7 @@
5860
from ..utils.fixes import sp_version
5961
from ..utils.fixes import parse_version
6062
from ..utils.validation import check_is_fitted
63+
from ..utils._array_api import _convert_to_numpy, get_namespace, device as array_device
6164
from ..utils._param_validation import make_constraint
6265
from ..utils._param_validation import generate_invalid_param_val
6366
from ..utils._param_validation import InvalidParameterError
@@ -73,6 +76,7 @@
7376
from ..datasets import (
7477
load_iris,
7578
make_blobs,
79+
make_classification,
7680
make_multilabel_classification,
7781
make_regression,
7882
)
@@ -133,6 +137,21 @@ def _yield_checks(estimator):
133137

134138
yield check_estimator_get_tags_default_keys
135139

140+
if tags["array_api_support"]:
141+
for array_namespace in ["numpy.array_api", "cupy.array_api", "cupy", "torch"]:
142+
if array_namespace == "torch":
143+
for device, dtype in itertools.product(
144+
("cpu", "cuda"), ("float64", "float32")
145+
):
146+
yield partial(
147+
check_array_api_input,
148+
array_namespace=array_namespace,
149+
dtype=dtype,
150+
device=device,
151+
)
152+
else:
153+
yield partial(check_array_api_input, array_namespace=array_namespace)
154+
136155

137156
def _yield_classifier_checks(classifier):
138157
tags = _safe_tags(classifier)
@@ -831,6 +850,111 @@ def _generate_sparse_matrix(X_csr):
831850
yield sparse_format + "_64", X
832851

833852

853+
def check_array_api_input(
854+
name, estimator_orig, *, array_namespace, device=None, dtype="float64"
855+
):
856+
"""Check that the array_api Array gives the same results as ndarrays."""
857+
try:
858+
array_mod = importlib.import_module(array_namespace)
859+
except ModuleNotFoundError:
860+
raise SkipTest(
861+
f"{array_namespace} is not installed: not checking array_api input"
862+
)
863+
try:
864+
import array_api_compat # noqa
865+
except ImportError:
866+
raise SkipTest(
867+
"array_api_compat is not installed: not checking array_api input"
868+
)
869+
870+
# First create an array using the chosen array module and then get the
871+
# corresponding (compatibility wrapped) array namespace based on it.
872+
# This is because `cupy` is not the same as the compatibility wrapped
873+
# namespace of a CuPy array.
874+
xp = array_api_compat.get_namespace(array_mod.asarray(1))
875+
876+
if array_namespace == "torch" and device == "cuda" and not xp.has_cuda:
877+
raise SkipTest("PyTorch test requires cuda, which is not available")
878+
elif array_namespace in {"cupy", "cupy.array_api"}: # pragma: nocover
879+
import cupy
880+
881+
if cupy.cuda.runtime.getDeviceCount() == 0:
882+
raise SkipTest("CuPy test requires cuda, which is not available")
883+
884+
X, y = make_classification(random_state=42)
885+
X = X.astype(dtype, copy=False)
886+
887+
X = _enforce_estimator_tags_X(estimator_orig, X)
888+
y = _enforce_estimator_tags_y(estimator_orig, y)
889+
890+
est = clone(estimator_orig)
891+
892+
X_xp = xp.asarray(X, device=device)
893+
y_xp = xp.asarray(y, device=device)
894+
895+
est.fit(X, y)
896+
897+
array_attributes = {
898+
key: value for key, value in vars(est).items() if isinstance(value, np.ndarray)
899+
}
900+
901+
est_xp = clone(est)
902+
with config_context(array_api_dispatch=True):
903+
est_xp.fit(X_xp, y_xp)
904+
905+
# Fitted attributes which are arrays must have the same
906+
# namespace as the one of the training data.
907+
for key, attribute in array_attributes.items():
908+
est_xp_param = getattr(est_xp, key)
909+
assert (
910+
get_namespace(est_xp_param)[0] == get_namespace(X_xp)[0]
911+
), f"'{key}' attribute is in wrong namespace"
912+
913+
assert array_device(est_xp_param) == array_device(X_xp)
914+
915+
est_xp_param_np = _convert_to_numpy(est_xp_param, xp=xp)
916+
assert_allclose(
917+
attribute,
918+
est_xp_param_np,
919+
err_msg=f"{key} not the same",
920+
atol=np.finfo(X.dtype).eps * 100,
921+
)
922+
923+
# Check estimator methods, if supported, give the same results
924+
methods = (
925+
"decision_function",
926+
"predict",
927+
"predict_log_proba",
928+
"predict_proba",
929+
"transform",
930+
"inverse_transform",
931+
)
932+
933+
for method_name in methods:
934+
method = getattr(est, method_name, None)
935+
if method is None:
936+
continue
937+
938+
result = method(X)
939+
with config_context(array_api_dispatch=True):
940+
result_xp = getattr(est_xp, method_name)(X_xp)
941+
942+
assert (
943+
get_namespace(result_xp)[0] == get_namespace(X_xp)[0]
944+
), f"'{method}' output is in wrong namespace"
945+
946+
assert array_device(result_xp) == array_device(X_xp)
947+
948+
result_xp_np = _convert_to_numpy(result_xp, xp=xp)
949+
950+
assert_allclose(
951+
result,
952+
result_xp_np,
953+
err_msg=f"{method} did not the return the same result",
954+
atol=np.finfo(X.dtype).eps * 100,
955+
)
956+
957+
834958
def check_estimator_sparse_data(name, estimator_orig):
835959
rng = np.random.RandomState(0)
836960
X = rng.uniform(size=(40, 3))

0 commit comments

Comments
 (0)