Skip to content

Commit bf08cb3

Browse files
MarcoGorellilestevejeremiedbb
authored
Fix a regression in GridSearchCV for parameter grids that have arrays of different sizes as parameter values (scikit-learn#29314)
Co-authored-by: Loïc Estève <loic.esteve@ymail.com> Co-authored-by: Jérémie du Boisberranger <jeremie@probabl.ai>
1 parent 3ef8bf5 commit bf08cb3

File tree

3 files changed

+172
-41
lines changed

3 files changed

+172
-41
lines changed

doc/whats_new/v1.5.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ Changelog
6363
grids that have estimators as parameter values.
6464
:pr:`29179` by :user:`Marco Gorelli<MarcoGorelli>`.
6565

66+
- |Fix| Fix a regression in :class:`model_selection.GridSearchCV` for parameter
67+
grids that have arrays of different sizes as parameter values.
68+
:pr:`29314` by :user:`Marco Gorelli<MarcoGorelli>`.
69+
6670
:mod:`sklearn.tree`
6771
...................
6872

sklearn/model_selection/_search.py

Lines changed: 52 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,56 @@ def check(self):
379379
return check
380380

381381

382+
def _yield_masked_array_for_each_param(candidate_params):
383+
"""
384+
Yield a masked array for each candidate param.
385+
386+
`candidate_params` is a sequence of params which were used in
387+
a `GridSearchCV`. We use masked arrays for the results, as not
388+
all params are necessarily present in each element of
389+
`candidate_params`. For example, if using `GridSearchCV` with
390+
a `SVC` model, then one might search over params like:
391+
392+
- kernel=["rbf"], gamma=[0.1, 1]
393+
- kernel=["poly"], degree=[1, 2]
394+
395+
and then param `'gamma'` would not be present in entries of
396+
`candidate_params` corresponding to `kernel='poly'`.
397+
"""
398+
n_candidates = len(candidate_params)
399+
param_results = defaultdict(dict)
400+
401+
for cand_idx, params in enumerate(candidate_params):
402+
for name, value in params.items():
403+
param_results["param_%s" % name][cand_idx] = value
404+
405+
for key, param_result in param_results.items():
406+
param_list = list(param_result.values())
407+
try:
408+
arr = np.array(param_list)
409+
except ValueError:
410+
# This can happen when param_list contains lists of different
411+
# lengths, for example:
412+
# param_list=[[1], [2, 3]]
413+
arr_dtype = np.dtype(object)
414+
else:
415+
# There are two cases when we don't use the automatically inferred
416+
# dtype when creating the array and we use object instead:
417+
# - string dtype
418+
# - when array.ndim > 1, that means that param_list was something
419+
# like a list of same-size sequences, which gets turned into a
420+
# multi-dimensional array but we want a 1d array
421+
arr_dtype = arr.dtype if arr.dtype.kind != "U" and arr.ndim == 1 else object
422+
423+
# Use one MaskedArray and mask all the places where the param is not
424+
# applicable for that candidate (which may not contain all the params).
425+
ma = MaskedArray(np.empty(n_candidates), mask=True, dtype=arr_dtype)
426+
for index, value in param_result.items():
427+
# Setting the value at an index unmasks that index
428+
ma[index] = value
429+
yield (key, ma)
430+
431+
382432
class BaseSearchCV(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta):
383433
"""Abstract base class for hyper parameter search with cross-validation."""
384434

@@ -1079,45 +1129,9 @@ def _store(key_name, array, weights=None, splits=False, rank=False):
10791129

10801130
_store("fit_time", out["fit_time"])
10811131
_store("score_time", out["score_time"])
1082-
param_results = defaultdict(dict)
1083-
for cand_idx, params in enumerate(candidate_params):
1084-
for name, value in params.items():
1085-
param_results["param_%s" % name][cand_idx] = value
1086-
for key, param_result in param_results.items():
1087-
param_list = list(param_result.values())
1088-
try:
1089-
with warnings.catch_warnings():
1090-
warnings.filterwarnings(
1091-
"ignore",
1092-
message="in the future the `.dtype` attribute",
1093-
category=DeprecationWarning,
1094-
)
1095-
# Warning raised by NumPy 1.20+
1096-
arr_dtype = np.result_type(*param_list)
1097-
except (TypeError, ValueError):
1098-
arr_dtype = np.dtype(object)
1099-
else:
1100-
if any(np.min_scalar_type(x) == object for x in param_list):
1101-
# `np.result_type` might get thrown off by `.dtype` properties
1102-
# (which some estimators have).
1103-
# If finding the result dtype this way would give object,
1104-
# then we use object.
1105-
# https://github.com/scikit-learn/scikit-learn/issues/29157
1106-
arr_dtype = np.dtype(object)
1107-
if len(param_list) == n_candidates and arr_dtype != object:
1108-
# Exclude `object` else the numpy constructor might infer a list of
1109-
# tuples to be a 2d array.
1110-
results[key] = MaskedArray(param_list, mask=False, dtype=arr_dtype)
1111-
else:
1112-
# Use one MaskedArray and mask all the places where the param is not
1113-
# applicable for that candidate (which may not contain all the params).
1114-
ma = MaskedArray(np.empty(n_candidates), mask=True, dtype=arr_dtype)
1115-
for index, value in param_result.items():
1116-
# Setting the value at an index unmasks that index
1117-
ma[index] = value
1118-
results[key] = ma
1119-
11201132
# Store a list of param dicts at the key 'params'
1133+
for param, ma in _yield_masked_array_for_each_param(candidate_params):
1134+
results[param] = ma
11211135
results["params"] = candidate_params
11221136

11231137
test_scores_dict = _normalize_score_results(out["test_scores"])

sklearn/model_selection/tests/test_search.py

Lines changed: 116 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,20 @@
6161
StratifiedShuffleSplit,
6262
train_test_split,
6363
)
64-
from sklearn.model_selection._search import BaseSearchCV
64+
from sklearn.model_selection._search import (
65+
BaseSearchCV,
66+
_yield_masked_array_for_each_param,
67+
)
6568
from sklearn.model_selection.tests.common import OneTimeSplitter
6669
from sklearn.naive_bayes import ComplementNB
6770
from sklearn.neighbors import KernelDensity, KNeighborsClassifier, LocalOutlierFactor
68-
from sklearn.pipeline import Pipeline
69-
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, StandardScaler
71+
from sklearn.pipeline import Pipeline, make_pipeline
72+
from sklearn.preprocessing import (
73+
OneHotEncoder,
74+
OrdinalEncoder,
75+
SplineTransformer,
76+
StandardScaler,
77+
)
7078
from sklearn.svm import SVC, LinearSVC
7179
from sklearn.tests.metadata_routing_common import (
7280
ConsumingScorer,
@@ -2724,6 +2732,37 @@ def test_search_with_estimators_issue_29157():
27242732
assert grid_search.cv_results_["param_enc__enc"].dtype == object
27252733

27262734

2735+
def test_cv_results_multi_size_array():
2736+
"""Check that GridSearchCV works with params that are arrays of different sizes.
2737+
2738+
Non-regression test for #29277.
2739+
"""
2740+
n_features = 10
2741+
X, y = make_classification(n_features=10)
2742+
2743+
spline_reg_pipe = make_pipeline(
2744+
SplineTransformer(extrapolation="periodic"),
2745+
LogisticRegression(),
2746+
)
2747+
2748+
n_knots_list = [n_features * i for i in [10, 11, 12]]
2749+
knots_list = [
2750+
np.linspace(0, np.pi * 2, n_knots).reshape((-1, n_features))
2751+
for n_knots in n_knots_list
2752+
]
2753+
spline_reg_pipe_cv = GridSearchCV(
2754+
estimator=spline_reg_pipe,
2755+
param_grid={
2756+
"splinetransformer__knots": knots_list,
2757+
},
2758+
)
2759+
2760+
spline_reg_pipe_cv.fit(X, y)
2761+
assert (
2762+
spline_reg_pipe_cv.cv_results_["param_splinetransformer__knots"].dtype == object
2763+
)
2764+
2765+
27272766
@pytest.mark.parametrize(
27282767
"array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
27292768
)
@@ -2747,3 +2786,77 @@ def test_array_api_search_cv_classifier(SearchCV, array_namespace, device, dtype
27472786
)
27482787
searcher.fit(X_xp, y_xp)
27492788
searcher.score(X_xp, y_xp)
2789+
2790+
2791+
# Construct these outside the tests so that the same object is used
2792+
# for both input and `expected`
2793+
one_hot_encoder = OneHotEncoder()
2794+
ordinal_encoder = OrdinalEncoder()
2795+
2796+
# If we construct this directly via `MaskedArray`, the list of tuples
2797+
# gets auto-converted to a 2D array.
2798+
ma_with_tuples = np.ma.MaskedArray(np.empty(2), mask=True, dtype=object)
2799+
ma_with_tuples[0] = (1, 2)
2800+
ma_with_tuples[1] = (3, 4)
2801+
2802+
2803+
@pytest.mark.parametrize(
2804+
("candidate_params", "expected"),
2805+
[
2806+
pytest.param(
2807+
[{"foo": 1}, {"foo": 2}],
2808+
[
2809+
("param_foo", np.ma.MaskedArray(np.array([1, 2]))),
2810+
],
2811+
id="simple numeric, single param",
2812+
),
2813+
pytest.param(
2814+
[{"foo": 1, "bar": 3}, {"foo": 2, "bar": 4}, {"foo": 3}],
2815+
[
2816+
("param_foo", np.ma.MaskedArray(np.array([1, 2, 3]))),
2817+
(
2818+
"param_bar",
2819+
np.ma.MaskedArray(np.array([3, 4, 0]), mask=[False, False, True]),
2820+
),
2821+
],
2822+
id="simple numeric, one param is missing in one round",
2823+
),
2824+
pytest.param(
2825+
[{"foo": [[1], [2], [3]]}, {"foo": [[1], [2]]}],
2826+
[
2827+
(
2828+
"param_foo",
2829+
np.ma.MaskedArray([[[1], [2], [3]], [[1], [2]]], dtype=object),
2830+
),
2831+
],
2832+
id="lists of different lengths",
2833+
),
2834+
pytest.param(
2835+
[{"foo": (1, 2)}, {"foo": (3, 4)}],
2836+
[
2837+
(
2838+
"param_foo",
2839+
ma_with_tuples,
2840+
),
2841+
],
2842+
id="lists tuples",
2843+
),
2844+
pytest.param(
2845+
[{"foo": ordinal_encoder}, {"foo": one_hot_encoder}],
2846+
[
2847+
(
2848+
"param_foo",
2849+
np.ma.MaskedArray([ordinal_encoder, one_hot_encoder], dtype=object),
2850+
),
2851+
],
2852+
id="estimators",
2853+
),
2854+
],
2855+
)
2856+
def test_yield_masked_array_for_each_param(candidate_params, expected):
2857+
result = list(_yield_masked_array_for_each_param(candidate_params))
2858+
for (key, value), (expected_key, expected_value) in zip(result, expected):
2859+
assert key == expected_key
2860+
assert value.dtype == expected_value.dtype
2861+
np.testing.assert_array_equal(value, expected_value)
2862+
np.testing.assert_array_equal(value.mask, expected_value.mask)

0 commit comments

Comments
 (0)