Skip to content

Commit d970efb

Browse files
BUG: Fixes for old sklearn (#13308)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
1 parent 5009ee8 commit d970efb

File tree

13 files changed

+165
-18
lines changed

13 files changed

+165
-18
lines changed

environment.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,14 @@ dependencies:
4242
- pyarrow
4343
- pybv
4444
- pymatreader
45-
- PySide6 !=6.8.0,!=6.8.0.1
45+
- PySide6 !=6.9.1
4646
- python-neo
4747
- python-picard
4848
- pyvista >=0.32,!=0.35.2,!=0.38.0,!=0.38.1,!=0.38.2,!=0.38.3,!=0.38.4,!=0.38.5,!=0.38.6,!=0.42.0
4949
- pyvistaqt >=0.4
5050
- qdarkstyle !=3.2.2
5151
- qtpy
52-
- scikit-learn
52+
- scikit-learn >=1.3.0
5353
- scipy >=1.11
5454
- sip
5555
- snirf
@@ -60,5 +60,5 @@ dependencies:
6060
- trame
6161
- trame-vtk
6262
- trame-vuetify
63-
- vtk =9.3.1=qt_*
63+
- vtk >=9.2
6464
- xlrd

mne/decoding/_fixes.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Authors: The MNE-Python contributors.
2+
# License: BSD-3-Clause
3+
# Copyright the MNE-Python contributors.
4+
5+
try:
6+
from sklearn.utils.validation import validate_data
7+
except ImportError:
8+
from sklearn.utils.validation import check_array, check_X_y
9+
10+
# Use a limited version pulled from sklearn 1.7
11+
def validate_data(
12+
_estimator,
13+
/,
14+
X="no_validation",
15+
y="no_validation",
16+
reset=True,
17+
validate_separately=False,
18+
skip_check_array=False,
19+
**check_params,
20+
):
21+
"""Validate input data and set or check feature names and counts of the input.
22+
23+
This helper function should be used in an estimator that requires input
24+
validation. This mutates the estimator and sets the `n_features_in_` and
25+
`feature_names_in_` attributes if `reset=True`.
26+
27+
.. versionadded:: 1.6
28+
29+
Parameters
30+
----------
31+
_estimator : estimator instance
32+
The estimator to validate the input for.
33+
34+
X : {array-like, sparse matrix, dataframe} of shape \
35+
(n_samples, n_features), default='no validation'
36+
The input samples.
37+
If `'no_validation'`, no validation is performed on `X`. This is
38+
useful for meta-estimator which can delegate input validation to
39+
their underlying estimator(s). In that case `y` must be passed and
40+
the only accepted `check_params` are `multi_output` and
41+
`y_numeric`.
42+
43+
y : array-like of shape (n_samples,), default='no_validation'
44+
The targets.
45+
46+
- If `None`, :func:`~sklearn.utils.check_array` is called on `X`. If
47+
the estimator's `requires_y` tag is True, then an error will be raised.
48+
- If `'no_validation'`, :func:`~sklearn.utils.check_array` is called
49+
on `X` and the estimator's `requires_y` tag is ignored. This is a default
50+
placeholder and is never meant to be explicitly set. In that case `X` must
51+
be passed.
52+
- Otherwise, only `y` with `_check_y` or both `X` and `y` are checked with
53+
either :func:`~sklearn.utils.check_array` or
54+
:func:`~sklearn.utils.check_X_y` depending on `validate_separately`.
55+
56+
reset : bool, default=True
57+
Whether to reset the `n_features_in_` attribute.
58+
If False, the input will be checked for consistency with data
59+
provided when reset was last True.
60+
61+
.. note::
62+
63+
It is recommended to call `reset=True` in `fit` and in the first
64+
call to `partial_fit`. All other methods that validate `X`
65+
should set `reset=False`.
66+
67+
validate_separately : False or tuple of dicts, default=False
68+
Only used if `y` is not `None`.
69+
If `False`, call :func:`~sklearn.utils.check_X_y`. Else, it must be a tuple
70+
of kwargs to be used for calling :func:`~sklearn.utils.check_array` on `X`
71+
and `y` respectively.
72+
73+
`estimator=self` is automatically added to these dicts to generate
74+
more informative error message in case of invalid input data.
75+
76+
skip_check_array : bool, default=False
77+
If `True`, `X` and `y` are unchanged and only `feature_names_in_` and
78+
`n_features_in_` are checked. Otherwise, :func:`~sklearn.utils.check_array`
79+
is called on `X` and `y`.
80+
81+
**check_params : kwargs
82+
Parameters passed to :func:`~sklearn.utils.check_array` or
83+
:func:`~sklearn.utils.check_X_y`. Ignored if validate_separately
84+
is not False.
85+
86+
`estimator=self` is automatically added to these params to generate
87+
more informative error message in case of invalid input data.
88+
89+
Returns
90+
-------
91+
out : {ndarray, sparse matrix} or tuple of these
92+
The validated input. A tuple is returned if both `X` and `y` are
93+
validated.
94+
"""
95+
no_val_X = isinstance(X, str) and X == "no_validation"
96+
no_val_y = y is None or (isinstance(y, str) and y == "no_validation")
97+
98+
if no_val_X and no_val_y:
99+
raise ValueError("Validation should be done on X, y or both.")
100+
101+
default_check_params = {"estimator": _estimator}
102+
check_params = {**default_check_params, **check_params}
103+
104+
if skip_check_array:
105+
if not no_val_X and no_val_y:
106+
out = X
107+
elif no_val_X and not no_val_y:
108+
out = y
109+
else:
110+
out = X, y
111+
elif not no_val_X and no_val_y:
112+
out = check_array(X, input_name="X", **check_params)
113+
elif no_val_X and not no_val_y:
114+
out = check_array(y, input_name="y", **check_params)
115+
else:
116+
if validate_separately:
117+
# We need this because some estimators validate X and y
118+
# separately, and in general, separately calling check_array()
119+
# on X and y isn't equivalent to just calling check_X_y()
120+
# :(
121+
check_X_params, check_y_params = validate_separately
122+
if "estimator" not in check_X_params:
123+
check_X_params = {**default_check_params, **check_X_params}
124+
X = check_array(X, input_name="X", **check_X_params)
125+
if "estimator" not in check_y_params:
126+
check_y_params = {**default_check_params, **check_y_params}
127+
y = check_array(y, input_name="y", **check_y_params)
128+
else:
129+
X, y = check_X_y(X, y, **check_params)
130+
out = X, y
131+
132+
return out

mne/decoding/tests/test_base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
get_coef,
5050
)
5151
from mne.decoding.search_light import SlidingEstimator
52+
from mne.utils import check_version
5253

5354

5455
def _make_data(n_samples=1000, n_features=5, n_targets=3):
@@ -97,7 +98,8 @@ def test_get_coef():
9798
"""Test getting linear coefficients (filters/patterns) from estimators."""
9899
lm_classification = LinearModel()
99100
assert hasattr(lm_classification, "__sklearn_tags__")
100-
print(lm_classification.__sklearn_tags__())
101+
if check_version("sklearn", "1.4"):
102+
print(lm_classification.__sklearn_tags__())
101103
assert is_classifier(lm_classification.model)
102104
assert is_classifier(lm_classification)
103105
assert not is_regressor(lm_classification.model)

mne/decoding/tests/test_csp.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,4 +495,5 @@ def test_csp_component_ordering():
495495
@parametrize_with_checks([CSP(), SPoC()])
496496
def test_sklearn_compliance(estimator, check):
497497
"""Test compliance with sklearn."""
498+
pytest.importorskip("sklearn", minversion="1.4") # TODO VERSION remove on 1.4+
498499
check(estimator)

mne/decoding/tests/test_ems.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,4 +97,5 @@ def test_ems():
9797
@parametrize_with_checks([EMS()])
9898
def test_sklearn_compliance(estimator, check):
9999
"""Test compliance with sklearn."""
100+
pytest.importorskip("sklearn", minversion="1.4") # TODO VERSION remove on 1.4+
100101
check(estimator)

mne/decoding/tests/test_ssd.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,7 @@ def test_non_full_rank_data():
499499
)
500500
def test_sklearn_compliance(estimator, check):
501501
"""Test LinearModel compliance with sklearn."""
502+
pytest.importorskip("sklearn", minversion="1.4") # TODO VERSION remove on 1.4+
502503
ignores = (
503504
"check_methods_sample_order_invariance",
504505
# Shape stuff

mne/decoding/tests/test_time_frequency.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,5 @@ def test_timefrequency_basic():
5757
@parametrize_with_checks([TimeFrequency([300, 400], 1000.0, n_cycles=0.25)])
5858
def test_sklearn_compliance(estimator, check):
5959
"""Test LinearModel compliance with sklearn."""
60+
pytest.importorskip("sklearn", minversion="1.4") # TODO VERSION remove on 1.4+
6061
check(estimator)

mne/decoding/tests/test_transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ def test_bad_triage():
339339
)
340340
def test_sklearn_compliance(estimator, check):
341341
"""Test LinearModel compliance with sklearn."""
342+
pytest.importorskip("sklearn", minversion="1.4") # TODO VERSION remove on 1.4+
342343
ignores = []
343344
if estimator.__class__.__name__ == "FilterEstimator":
344345
ignores += [

mne/decoding/transformer.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from sklearn.base import BaseEstimator, TransformerMixin, check_array, clone
77
from sklearn.preprocessing import RobustScaler, StandardScaler
88
from sklearn.utils import check_X_y
9-
from sklearn.utils.validation import check_is_fitted, validate_data
9+
from sklearn.utils.validation import check_is_fitted
1010

1111
from .._fiff.pick import (
1212
_pick_data_channels,
@@ -18,7 +18,8 @@
1818
from ..epochs import BaseEpochs
1919
from ..filter import filter_data
2020
from ..time_frequency import psd_array_multitaper
21-
from ..utils import _check_option, _validate_type, fill_doc
21+
from ..utils import _check_option, _validate_type, check_version, fill_doc
22+
from ._fixes import validate_data # TODO VERSION remove with sklearn 1.4+
2223

2324

2425
class MNETransformerMixin(TransformerMixin):
@@ -41,7 +42,9 @@ def _check_data(
4142
# array to make everyone happy.
4243
if isinstance(epochs_data, BaseEpochs):
4344
epochs_data = epochs_data.get_data(copy=False)
44-
kwargs = dict(dtype=np.float64, allow_nd=True, order="C", force_writeable=True)
45+
kwargs = dict(dtype=np.float64, allow_nd=True, order="C")
46+
if check_version("sklearn", "1.4"): # TODO VERSION sklearn 1.4+
47+
kwargs["force_writeable"] = True
4548
if hasattr(self, "n_features_in_") and check_n_features:
4649
if y is None:
4750
epochs_data = validate_data(

mne/inverse_sparse/tests/test_mxne_inverse.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -589,8 +589,12 @@ def data_fun(times):
589589
forward, stc, info, noise_cov, nave=nave, use_cps=False, iir_filter=None
590590
)
591591
evoked = evoked.crop(tmin=0, tmax=10e-3)
592-
stc_ = mixed_norm(evoked, forward, noise_cov, loose=0.9, n_mxne_iter=5, depth=0.9)
593-
assert_array_equal(stc_.vertices, stc.vertices)
592+
stc_ = mixed_norm(
593+
evoked, forward, noise_cov, loose=0.9, n_mxne_iter=5, depth=0.9, random_state=0
594+
)
595+
assert len(stc_.vertices) == len(stc.vertices) == 2
596+
for si in range(len(stc_.vertices)):
597+
assert_array_equal(stc_.vertices[si], stc.vertices[si], err_msg=f"{si=}")
594598

595599

596600
@pytest.mark.slowtest # slow on Azure
@@ -609,7 +613,13 @@ def test_mxne_inverse_empty():
609613
cov = read_cov(fname_cov)
610614
with pytest.warns(RuntimeWarning, match="too big"):
611615
stc, residual = mixed_norm(
612-
evoked, forward, cov, n_mxne_iter=3, alpha=99, return_residual=True
616+
evoked,
617+
forward,
618+
cov,
619+
n_mxne_iter=3,
620+
alpha=99,
621+
return_residual=True,
622+
random_state=0,
613623
)
614624
assert stc.data.size == 0
615625
assert stc.vertices[0].size == 0

0 commit comments

Comments
 (0)