Skip to content

Commit b39ab89

Browse files
authored
FIX fix comparison between array-like parameters when detecting non-default params for HTML representation (scikit-learn#31528)
1 parent d1479da commit b39ab89

File tree

3 files changed

+87
-3
lines changed

3 files changed

+87
-3
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
- Fix regression in HTML representation when detecting the non-default parameters
2+
that where of array-like types.
3+
By :user:`Dea María Léon <deamarialeon>`

sklearn/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,12 +292,14 @@ def is_non_default(param_name, param_value):
292292
init_default_params[param_name]
293293
):
294294
return True
295-
296-
if param_value != init_default_params[param_name] and not (
295+
if not np.array_equal(
296+
param_value, init_default_params[param_name]
297+
) and not (
297298
is_scalar_nan(init_default_params[param_name])
298299
and is_scalar_nan(param_value)
299300
):
300301
return True
302+
301303
return False
302304

303305
# reorder the parameters from `self.get_params` using the `__init__`

sklearn/tests/test_base.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
from sklearn.decomposition import PCA
2727
from sklearn.ensemble import IsolationForest
2828
from sklearn.exceptions import InconsistentVersionWarning
29-
from sklearn.model_selection import GridSearchCV
29+
from sklearn.metrics import get_scorer
30+
from sklearn.model_selection import GridSearchCV, KFold
3031
from sklearn.pipeline import Pipeline
3132
from sklearn.preprocessing import StandardScaler
3233
from sklearn.svm import SVC, SVR
@@ -1000,3 +1001,81 @@ def test_get_params_html():
10001001

10011002
assert est._get_params_html() == {"l1": 0, "empty": "test"}
10021003
assert est._get_params_html().non_default == ("empty",)
1004+
1005+
1006+
def make_estimator_with_param(default_value):
1007+
class DynamicEstimator(BaseEstimator):
1008+
def __init__(self, param=default_value):
1009+
self.param = param
1010+
1011+
return DynamicEstimator
1012+
1013+
1014+
@pytest.mark.parametrize(
1015+
"default_value, test_value",
1016+
[
1017+
((), (1,)),
1018+
((), [1]),
1019+
((), np.array([1])),
1020+
((1, 2), (3, 4)),
1021+
((1, 2), [3, 4]),
1022+
((1, 2), np.array([3, 4])),
1023+
(None, 1),
1024+
(None, []),
1025+
(None, lambda x: x),
1026+
(np.nan, 1.0),
1027+
(np.nan, np.array([np.nan])),
1028+
("abc", "def"),
1029+
("abc", ["abc"]),
1030+
(True, False),
1031+
(1, 2),
1032+
(1, [1]),
1033+
(1, np.array([1])),
1034+
(1.0, 2.0),
1035+
(1.0, [1.0]),
1036+
(1.0, np.array([1.0])),
1037+
([1, 2], [3]),
1038+
(np.array([1]), [2, 3]),
1039+
(None, KFold()),
1040+
(None, get_scorer("accuracy")),
1041+
],
1042+
)
1043+
def test_param_is_non_default(default_value, test_value):
1044+
"""Check that we detect non-default parameters with various types.
1045+
1046+
Non-regression test for:
1047+
https://github.com/scikit-learn/scikit-learn/issues/31525
1048+
"""
1049+
estimator = make_estimator_with_param(default_value)(param=test_value)
1050+
non_default = estimator._get_params_html().non_default
1051+
assert "param" in non_default
1052+
1053+
1054+
@pytest.mark.parametrize(
1055+
"default_value, test_value",
1056+
[
1057+
(None, None),
1058+
((), ()),
1059+
((), []),
1060+
((), np.array([])),
1061+
((1, 2, 3), (1, 2, 3)),
1062+
((1, 2, 3), [1, 2, 3]),
1063+
((1, 2, 3), np.array([1, 2, 3])),
1064+
(np.nan, np.nan),
1065+
("abc", "abc"),
1066+
(True, True),
1067+
(1, 1),
1068+
(1.0, 1.0),
1069+
(2, 2.0),
1070+
],
1071+
)
1072+
def test_param_is_default(default_value, test_value):
1073+
"""Check that we detect the default parameters and values in an array-like will
1074+
be reported as default as well.
1075+
1076+
Non-regression test for:
1077+
https://github.com/scikit-learn/scikit-learn/issues/31525
1078+
"""
1079+
estimator = make_estimator_with_param(default_value)(param=test_value)
1080+
non_default = estimator._get_params_html().non_default
1081+
assert "param" not in non_default

0 commit comments

Comments
 (0)