|
26 | 26 | from sklearn.decomposition import PCA
|
27 | 27 | from sklearn.ensemble import IsolationForest
|
28 | 28 | from sklearn.exceptions import InconsistentVersionWarning
|
29 |
| -from sklearn.model_selection import GridSearchCV |
| 29 | +from sklearn.metrics import get_scorer |
| 30 | +from sklearn.model_selection import GridSearchCV, KFold |
30 | 31 | from sklearn.pipeline import Pipeline
|
31 | 32 | from sklearn.preprocessing import StandardScaler
|
32 | 33 | from sklearn.svm import SVC, SVR
|
@@ -1000,3 +1001,81 @@ def test_get_params_html():
|
1000 | 1001 |
|
1001 | 1002 | assert est._get_params_html() == {"l1": 0, "empty": "test"}
|
1002 | 1003 | assert est._get_params_html().non_default == ("empty",)
|
| 1004 | + |
| 1005 | + |
| 1006 | +def make_estimator_with_param(default_value): |
| 1007 | + class DynamicEstimator(BaseEstimator): |
| 1008 | + def __init__(self, param=default_value): |
| 1009 | + self.param = param |
| 1010 | + |
| 1011 | + return DynamicEstimator |
| 1012 | + |
| 1013 | + |
| 1014 | +@pytest.mark.parametrize( |
| 1015 | + "default_value, test_value", |
| 1016 | + [ |
| 1017 | + ((), (1,)), |
| 1018 | + ((), [1]), |
| 1019 | + ((), np.array([1])), |
| 1020 | + ((1, 2), (3, 4)), |
| 1021 | + ((1, 2), [3, 4]), |
| 1022 | + ((1, 2), np.array([3, 4])), |
| 1023 | + (None, 1), |
| 1024 | + (None, []), |
| 1025 | + (None, lambda x: x), |
| 1026 | + (np.nan, 1.0), |
| 1027 | + (np.nan, np.array([np.nan])), |
| 1028 | + ("abc", "def"), |
| 1029 | + ("abc", ["abc"]), |
| 1030 | + (True, False), |
| 1031 | + (1, 2), |
| 1032 | + (1, [1]), |
| 1033 | + (1, np.array([1])), |
| 1034 | + (1.0, 2.0), |
| 1035 | + (1.0, [1.0]), |
| 1036 | + (1.0, np.array([1.0])), |
| 1037 | + ([1, 2], [3]), |
| 1038 | + (np.array([1]), [2, 3]), |
| 1039 | + (None, KFold()), |
| 1040 | + (None, get_scorer("accuracy")), |
| 1041 | + ], |
| 1042 | +) |
| 1043 | +def test_param_is_non_default(default_value, test_value): |
| 1044 | + """Check that we detect non-default parameters with various types. |
| 1045 | +
|
| 1046 | + Non-regression test for: |
| 1047 | + https://github.com/scikit-learn/scikit-learn/issues/31525 |
| 1048 | + """ |
| 1049 | + estimator = make_estimator_with_param(default_value)(param=test_value) |
| 1050 | + non_default = estimator._get_params_html().non_default |
| 1051 | + assert "param" in non_default |
| 1052 | + |
| 1053 | + |
| 1054 | +@pytest.mark.parametrize( |
| 1055 | + "default_value, test_value", |
| 1056 | + [ |
| 1057 | + (None, None), |
| 1058 | + ((), ()), |
| 1059 | + ((), []), |
| 1060 | + ((), np.array([])), |
| 1061 | + ((1, 2, 3), (1, 2, 3)), |
| 1062 | + ((1, 2, 3), [1, 2, 3]), |
| 1063 | + ((1, 2, 3), np.array([1, 2, 3])), |
| 1064 | + (np.nan, np.nan), |
| 1065 | + ("abc", "abc"), |
| 1066 | + (True, True), |
| 1067 | + (1, 1), |
| 1068 | + (1.0, 1.0), |
| 1069 | + (2, 2.0), |
| 1070 | + ], |
| 1071 | +) |
| 1072 | +def test_param_is_default(default_value, test_value): |
| 1073 | + """Check that we detect the default parameters and values in an array-like will |
| 1074 | + be reported as default as well. |
| 1075 | +
|
| 1076 | + Non-regression test for: |
| 1077 | + https://github.com/scikit-learn/scikit-learn/issues/31525 |
| 1078 | + """ |
| 1079 | + estimator = make_estimator_with_param(default_value)(param=test_value) |
| 1080 | + non_default = estimator._get_params_html().non_default |
| 1081 | + assert "param" not in non_default |
0 commit comments