Skip to content

Commit 8b06fa6

Browse files
JosephBARBIERDARNALglemaitreCharlie-XIAO
authored
FIX handle aliases in displays when used as default and provided by user (scikit-learn#30023)
Co-authored-by: Guillaume Lemaitre <guillaume@probabl.ai> Co-authored-by: Yao Xiao <108576690+Charlie-XIAO@users.noreply.github.com>
1 parent f84199e commit 8b06fa6

File tree

14 files changed

+282
-70
lines changed

14 files changed

+282
-70
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
- Classes :class:`metrics.ConfusionMatrixDisplay`,
2+
:class:`metrics.RocCurveDisplay`, :class:`calibration.CalibrationDisplay`,
3+
:class:`metrics.PrecisionRecallDisplay`, :class:`metrics.PredictionErrorDisplay` and
4+
:class:`inspection.PartialDependenceDisplay` now properly handle Matplotlib aliases
5+
for style parameters (e.g., `c` and `color`, `ls` and `linestyle`, etc).
6+
By :user:`Joseph Barbier <JosephBARBIERDARNAL>`

sklearn/calibration.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
StrOptions,
3939
validate_params,
4040
)
41-
from .utils._plotting import _BinaryClassifierCurveDisplayMixin
41+
from .utils._plotting import _BinaryClassifierCurveDisplayMixin, _validate_style_kwargs
4242
from .utils._response import _get_response_values, _process_predict_proba
4343
from .utils.metadata_routing import (
4444
MetadataRouter,
@@ -1150,10 +1150,10 @@ def plot(self, *, ax=None, name=None, ref_line=True, **kwargs):
11501150
f"(Positive class: {self.pos_label})" if self.pos_label is not None else ""
11511151
)
11521152

1153-
line_kwargs = {"marker": "s", "linestyle": "-"}
1153+
default_line_kwargs = {"marker": "s", "linestyle": "-"}
11541154
if name is not None:
1155-
line_kwargs["label"] = name
1156-
line_kwargs.update(**kwargs)
1155+
default_line_kwargs["label"] = name
1156+
line_kwargs = _validate_style_kwargs(default_line_kwargs, kwargs)
11571157

11581158
ref_line_label = "Perfectly calibrated"
11591159
existing_ref_line = ref_line_label in self.ax_.get_legend_handles_labels()[1]

sklearn/inspection/_plot/partial_dependence.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919
from ...utils._encode import _unique
2020
from ...utils._optional_dependencies import check_matplotlib_support # noqa
21+
from ...utils._plotting import _validate_style_kwargs
2122
from ...utils.parallel import Parallel, delayed
2223
from .. import partial_dependence
2324
from .._pd_utils import _check_feature_names, _get_feature_index
@@ -1294,7 +1295,7 @@ def plot(
12941295
if contour_kw is None:
12951296
contour_kw = {}
12961297
default_contour_kws = {"alpha": 0.75}
1297-
contour_kw = {**default_contour_kws, **contour_kw}
1298+
contour_kw = _validate_style_kwargs(default_contour_kws, contour_kw)
12981299

12991300
n_features = len(self.features)
13001301
is_average_plot = [kind_plot == "average" for kind_plot in kind]
@@ -1422,26 +1423,25 @@ def plot(
14221423
default_ice_lines_kws = {}
14231424
default_pd_lines_kws = {}
14241425

1425-
ice_lines_kw = {
1426-
**default_line_kws,
1427-
**default_ice_lines_kws,
1428-
**line_kw,
1429-
**ice_lines_kw,
1430-
}
1426+
default_ice_lines_kws = {**default_line_kws, **default_ice_lines_kws}
1427+
default_pd_lines_kws = {**default_line_kws, **default_pd_lines_kws}
1428+
1429+
line_kw = _validate_style_kwargs(default_line_kws, line_kw)
1430+
1431+
ice_lines_kw = _validate_style_kwargs(
1432+
_validate_style_kwargs(default_ice_lines_kws, line_kw), ice_lines_kw
1433+
)
14311434
del ice_lines_kw["label"]
14321435

1433-
pd_line_kw = {
1434-
**default_line_kws,
1435-
**default_pd_lines_kws,
1436-
**line_kw,
1437-
**pd_line_kw,
1438-
}
1436+
pd_line_kw = _validate_style_kwargs(
1437+
_validate_style_kwargs(default_pd_lines_kws, line_kw), pd_line_kw
1438+
)
14391439

14401440
default_bar_kws = {"color": "C0"}
1441-
bar_kw = {**default_bar_kws, **bar_kw}
1441+
bar_kw = _validate_style_kwargs(default_bar_kws, bar_kw)
14421442

14431443
default_heatmap_kw = {}
1444-
heatmap_kw = {**default_heatmap_kw, **heatmap_kw}
1444+
heatmap_kw = _validate_style_kwargs(default_heatmap_kw, heatmap_kw)
14451445

14461446
self._plot_one_way_partial_dependence(
14471447
kind_plot,

sklearn/inspection/_plot/tests/test_plot_partial_dependence.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,10 @@ def test_partial_dependence_kind_error(
970970
({"color": "r"}, {"color": "g"}, None, ("g", "r")),
971971
({"color": "r"}, None, None, ("r", "r")),
972972
({"color": "r"}, {"linestyle": "--"}, {"linestyle": "-."}, ("r", "r")),
973+
({"c": "r"}, None, None, ("r", "r")),
974+
({"c": "r", "ls": "-."}, {"color": "g"}, {"color": "b"}, ("g", "b")),
975+
({"c": "r"}, {"c": "g"}, {"c": "b"}, ("g", "b")),
976+
({"c": "r"}, {"ls": "--"}, {"ls": "-."}, ("r", "r")),
973977
],
974978
)
975979
def test_plot_partial_dependence_lines_kw(
@@ -999,16 +1003,26 @@ def test_plot_partial_dependence_lines_kw(
9991003
)
10001004

10011005
line = disp.lines_[0, 0, -1]
1002-
assert line.get_color() == expected_colors[0]
1003-
if pd_line_kw is not None and "linestyle" in pd_line_kw:
1004-
assert line.get_linestyle() == pd_line_kw["linestyle"]
1006+
assert line.get_color() == expected_colors[0], (
1007+
f"{line.get_color()}!={expected_colors[0]}\n" f"{line_kw} and {pd_line_kw}"
1008+
)
1009+
if pd_line_kw is not None:
1010+
if "linestyle" in pd_line_kw:
1011+
assert line.get_linestyle() == pd_line_kw["linestyle"]
1012+
elif "ls" in pd_line_kw:
1013+
assert line.get_linestyle() == pd_line_kw["ls"]
10051014
else:
10061015
assert line.get_linestyle() == "--"
10071016

10081017
line = disp.lines_[0, 0, 0]
1009-
assert line.get_color() == expected_colors[1]
1010-
if ice_lines_kw is not None and "linestyle" in ice_lines_kw:
1011-
assert line.get_linestyle() == ice_lines_kw["linestyle"]
1018+
assert (
1019+
line.get_color() == expected_colors[1]
1020+
), f"{line.get_color()}!={expected_colors[1]}"
1021+
if ice_lines_kw is not None:
1022+
if "linestyle" in ice_lines_kw:
1023+
assert line.get_linestyle() == ice_lines_kw["linestyle"]
1024+
elif "ls" in ice_lines_kw:
1025+
assert line.get_linestyle() == ice_lines_kw["ls"]
10121026
else:
10131027
assert line.get_linestyle() == "-"
10141028

sklearn/metrics/_plot/confusion_matrix.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from ...base import is_classifier
99
from ...utils._optional_dependencies import check_matplotlib_support
10+
from ...utils._plotting import _validate_style_kwargs
1011
from ...utils.multiclass import unique_labels
1112
from .. import confusion_matrix
1213

@@ -145,7 +146,7 @@ def plot(
145146

146147
default_im_kw = dict(interpolation="nearest", cmap=cmap)
147148
im_kw = im_kw or {}
148-
im_kw = {**default_im_kw, **im_kw}
149+
im_kw = _validate_style_kwargs(default_im_kw, im_kw)
149150
text_kw = text_kw or {}
150151

151152
self.im_ = ax.imshow(cm, **im_kw)
@@ -171,7 +172,7 @@ def plot(
171172
text_cm = format(cm[i, j], values_format)
172173

173174
default_text_kwargs = dict(ha="center", va="center", color=color)
174-
text_kwargs = {**default_text_kwargs, **text_kw}
175+
text_kwargs = _validate_style_kwargs(default_text_kwargs, text_kw)
175176

176177
self.text_[i, j] = ax.text(j, i, text_cm, **text_kwargs)
177178

sklearn/metrics/_plot/precision_recall_curve.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33

44
from collections import Counter
55

6-
from ...utils._plotting import _BinaryClassifierCurveDisplayMixin
6+
from ...utils._plotting import (
7+
_BinaryClassifierCurveDisplayMixin,
8+
_validate_style_kwargs,
9+
)
710
from .._ranking import average_precision_score, precision_recall_curve
811

912

@@ -178,14 +181,17 @@ def plot(
178181
"""
179182
self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name)
180183

181-
line_kwargs = {"drawstyle": "steps-post"}
184+
default_line_kwargs = {"drawstyle": "steps-post"}
182185
if self.average_precision is not None and name is not None:
183-
line_kwargs["label"] = f"{name} (AP = {self.average_precision:0.2f})"
186+
default_line_kwargs["label"] = (
187+
f"{name} (AP = {self.average_precision:0.2f})"
188+
)
184189
elif self.average_precision is not None:
185-
line_kwargs["label"] = f"AP = {self.average_precision:0.2f}"
190+
default_line_kwargs["label"] = f"AP = {self.average_precision:0.2f}"
186191
elif name is not None:
187-
line_kwargs["label"] = name
188-
line_kwargs.update(**kwargs)
192+
default_line_kwargs["label"] = name
193+
194+
line_kwargs = _validate_style_kwargs(default_line_kwargs, kwargs)
189195

190196
(self.line_,) = self.ax_.plot(self.recall, self.precision, **line_kwargs)
191197

@@ -214,13 +220,18 @@ def plot(
214220
"to automatically set prevalence_pos_label"
215221
)
216222

217-
chance_level_line_kw = {
223+
default_chance_level_line_kw = {
218224
"label": f"Chance level (AP = {self.prevalence_pos_label:0.2f})",
219225
"color": "k",
220226
"linestyle": "--",
221227
}
222-
if chance_level_kw is not None:
223-
chance_level_line_kw.update(chance_level_kw)
228+
229+
if chance_level_kw is None:
230+
chance_level_kw = {}
231+
232+
chance_level_line_kw = _validate_style_kwargs(
233+
default_chance_level_line_kw, chance_level_kw
234+
)
224235

225236
(self.chance_level_,) = self.ax_.plot(
226237
(0, 1),

sklearn/metrics/_plot/regression.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from ...utils import _safe_indexing, check_random_state
99
from ...utils._optional_dependencies import check_matplotlib_support
10+
from ...utils._plotting import _validate_style_kwargs
1011

1112

1213
class PredictionErrorDisplay:
@@ -142,6 +143,9 @@ def plot(
142143
default_scatter_kwargs = {"color": "tab:blue", "alpha": 0.8}
143144
default_line_kwargs = {"color": "black", "alpha": 0.7, "linestyle": "--"}
144145

146+
scatter_kwargs = _validate_style_kwargs(default_scatter_kwargs, scatter_kwargs)
147+
line_kwargs = _validate_style_kwargs(default_line_kwargs, line_kwargs)
148+
145149
scatter_kwargs = {**default_scatter_kwargs, **scatter_kwargs}
146150
line_kwargs = {**default_line_kwargs, **line_kwargs}
147151

sklearn/metrics/_plot/roc_curve.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
# Authors: The scikit-learn developers
22
# SPDX-License-Identifier: BSD-3-Clause
33

4-
from ...utils._plotting import _BinaryClassifierCurveDisplayMixin
4+
from ...utils._plotting import (
5+
_BinaryClassifierCurveDisplayMixin,
6+
_validate_style_kwargs,
7+
)
58
from .._ranking import auc, roc_curve
69

710

@@ -129,24 +132,28 @@ def plot(
129132
"""
130133
self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name)
131134

132-
line_kwargs = {}
135+
default_line_kwargs = {}
133136
if self.roc_auc is not None and name is not None:
134-
line_kwargs["label"] = f"{name} (AUC = {self.roc_auc:0.2f})"
137+
default_line_kwargs["label"] = f"{name} (AUC = {self.roc_auc:0.2f})"
135138
elif self.roc_auc is not None:
136-
line_kwargs["label"] = f"AUC = {self.roc_auc:0.2f}"
139+
default_line_kwargs["label"] = f"AUC = {self.roc_auc:0.2f}"
137140
elif name is not None:
138-
line_kwargs["label"] = name
141+
default_line_kwargs["label"] = name
139142

140-
line_kwargs.update(**kwargs)
143+
line_kwargs = _validate_style_kwargs(default_line_kwargs, kwargs)
141144

142-
chance_level_line_kw = {
145+
default_chance_level_line_kw = {
143146
"label": "Chance level (AUC = 0.5)",
144147
"color": "k",
145148
"linestyle": "--",
146149
}
147150

148-
if chance_level_kw is not None:
149-
chance_level_line_kw.update(**chance_level_kw)
151+
if chance_level_kw is None:
152+
chance_level_kw = {}
153+
154+
chance_level_kw = _validate_style_kwargs(
155+
default_chance_level_line_kw, chance_level_kw
156+
)
150157

151158
(self.line_,) = self.ax_.plot(self.fpr, self.tpr, **line_kwargs)
152159
info_pos_label = (
@@ -164,13 +171,11 @@ def plot(
164171
)
165172

166173
if plot_chance_level:
167-
(self.chance_level_,) = self.ax_.plot(
168-
(0, 1), (0, 1), **chance_level_line_kw
169-
)
174+
(self.chance_level_,) = self.ax_.plot((0, 1), (0, 1), **chance_level_kw)
170175
else:
171176
self.chance_level_ = None
172177

173-
if "label" in line_kwargs or "label" in chance_level_line_kw:
178+
if "label" in line_kwargs or "label" in chance_level_kw:
174179
self.ax_.legend(loc="lower right")
175180

176181
return self

sklearn/metrics/_plot/tests/test_precision_recall_display.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def test_precision_recall_display_plotting(
8282
assert display.chance_level_ is None
8383

8484

85-
@pytest.mark.parametrize("chance_level_kw", [None, {"color": "r"}])
85+
@pytest.mark.parametrize("chance_level_kw", [None, {"color": "r"}, {"c": "r"}])
8686
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
8787
def test_precision_recall_chance_level_line(
8888
pyplot,

sklearn/metrics/_plot/tests/test_predict_error_display.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,21 @@ def test_plot_prediction_error_ax(pyplot, regressor_fitted, class_method):
128128

129129

130130
@pytest.mark.parametrize("class_method", ["from_estimator", "from_predictions"])
131-
def test_prediction_error_custom_artist(pyplot, regressor_fitted, class_method):
132-
"""Check that we can tune the style of the lines."""
131+
@pytest.mark.parametrize(
132+
"scatter_kwargs",
133+
[None, {"color": "blue", "alpha": 0.9}, {"c": "blue", "alpha": 0.9}],
134+
)
135+
@pytest.mark.parametrize(
136+
"line_kwargs", [None, {"color": "red", "linestyle": "-"}, {"c": "red", "ls": "-"}]
137+
)
138+
def test_prediction_error_custom_artist(
139+
pyplot, regressor_fitted, class_method, scatter_kwargs, line_kwargs
140+
):
141+
"""Check that we can tune the style of the line and the scatter."""
133142
extra_params = {
134143
"kind": "actual_vs_predicted",
135-
"scatter_kwargs": {"color": "red"},
136-
"line_kwargs": {"color": "black"},
144+
"scatter_kwargs": scatter_kwargs,
145+
"line_kwargs": line_kwargs,
137146
}
138147
if class_method == "from_estimator":
139148
display = PredictionErrorDisplay.from_estimator(
@@ -145,17 +154,16 @@ def test_prediction_error_custom_artist(pyplot, regressor_fitted, class_method):
145154
y_true=y, y_pred=y_pred, **extra_params
146155
)
147156

148-
assert display.line_.get_color() == "black"
149-
assert_allclose(display.scatter_.get_edgecolor(), [[1.0, 0.0, 0.0, 0.8]])
150-
151-
# create a display with the default values
152-
if class_method == "from_estimator":
153-
display = PredictionErrorDisplay.from_estimator(regressor_fitted, X, y)
157+
if line_kwargs is not None:
158+
assert display.line_.get_linestyle() == "-"
159+
assert display.line_.get_color() == "red"
154160
else:
155-
y_pred = regressor_fitted.predict(X)
156-
display = PredictionErrorDisplay.from_predictions(y_true=y, y_pred=y_pred)
157-
pyplot.close("all")
161+
assert display.line_.get_linestyle() == "--"
162+
assert display.line_.get_color() == "black"
163+
assert display.line_.get_alpha() == 0.7
158164

159-
display.plot(**extra_params)
160-
assert display.line_.get_color() == "black"
161-
assert_allclose(display.scatter_.get_edgecolor(), [[1.0, 0.0, 0.0, 0.8]])
165+
if scatter_kwargs is not None:
166+
assert_allclose(display.scatter_.get_facecolor(), [[0.0, 0.0, 1.0, 0.9]])
167+
assert_allclose(display.scatter_.get_edgecolor(), [[0.0, 0.0, 1.0, 0.9]])
168+
else:
169+
assert display.scatter_.get_alpha() == 0.8

0 commit comments

Comments
 (0)