Skip to content

Commit 4480163

Browse files
authored
TST Add unit tests for _BinaryClassifierCurveDisplayMixin (scikit-learn#31193)
1 parent f16de74 commit 4480163

File tree

1 file changed

+108
-0
lines changed

1 file changed

+108
-0
lines changed

sklearn/utils/tests/test_plotting.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,120 @@
11
import numpy as np
22
import pytest
33

4+
from sklearn.linear_model import LogisticRegression
45
from sklearn.utils._plotting import (
6+
_BinaryClassifierCurveDisplayMixin,
57
_despine,
68
_interval_max_min_ratio,
79
_validate_score_name,
810
_validate_style_kwargs,
911
)
12+
from sklearn.utils._response import _get_response_values_binary
13+
from sklearn.utils._testing import assert_allclose
14+
15+
16+
@pytest.mark.parametrize("ax", [None, "Ax"])
17+
@pytest.mark.parametrize(
18+
"name, expected_name_out", [(None, "TestEstimator"), ("CustomName", "CustomName")]
19+
)
20+
def test_validate_plot_params(pyplot, ax, name, expected_name_out):
21+
"""Check `_validate_plot_params` returns the correct values."""
22+
display = _BinaryClassifierCurveDisplayMixin()
23+
display.estimator_name = "TestEstimator"
24+
if ax:
25+
_, ax = pyplot.subplots()
26+
ax_out, _, name_out = display._validate_plot_params(ax=ax, name=name)
27+
28+
assert name_out == expected_name_out
29+
30+
if ax:
31+
assert ax == ax_out
32+
33+
34+
@pytest.mark.parametrize("pos_label", [None, 0])
35+
@pytest.mark.parametrize("name", [None, "CustomName"])
36+
@pytest.mark.parametrize(
37+
"response_method", ["auto", "predict_proba", "decision_function"]
38+
)
39+
def test_validate_and_get_response_values(pyplot, pos_label, name, response_method):
40+
"""Check `_validate_and_get_response_values` returns the correct values."""
41+
X = np.array([[0, 0], [1, 1], [2, 2], [3, 3]])
42+
y = np.array([0, 0, 2, 2])
43+
estimator = LogisticRegression().fit(X, y)
44+
45+
y_pred, pos_label, name_out = (
46+
_BinaryClassifierCurveDisplayMixin._validate_and_get_response_values(
47+
estimator,
48+
X,
49+
y,
50+
response_method=response_method,
51+
pos_label=pos_label,
52+
name=name,
53+
)
54+
)
55+
56+
expected_y_pred, expected_pos_label = _get_response_values_binary(
57+
estimator, X, response_method=response_method, pos_label=pos_label
58+
)
59+
60+
assert_allclose(y_pred, expected_y_pred)
61+
assert pos_label == expected_pos_label
62+
63+
# Check name is handled correctly
64+
expected_name = name if name is not None else "LogisticRegression"
65+
assert name_out == expected_name
66+
67+
68+
@pytest.mark.parametrize(
69+
"y_true, error_message",
70+
[
71+
(np.array([0, 1, 2]), "The target y is not binary."),
72+
(np.array([0, 1]), "Found input variables with inconsistent"),
73+
(np.array([0, 2, 0, 2]), r"y_true takes value in \{0, 2\} and pos_label"),
74+
],
75+
)
76+
def test_validate_from_predictions_params_errors(pyplot, y_true, error_message):
77+
"""Check `_validate_from_predictions_params` raises the correct errors."""
78+
y_pred = np.array([0.1, 0.2, 0.3, 0.4])
79+
sample_weight = np.ones(4)
80+
81+
with pytest.raises(ValueError, match=error_message):
82+
_BinaryClassifierCurveDisplayMixin._validate_from_predictions_params(
83+
y_true=y_true,
84+
y_pred=y_pred,
85+
sample_weight=sample_weight,
86+
pos_label=None,
87+
)
88+
89+
90+
@pytest.mark.parametrize("name", [None, "CustomName"])
91+
@pytest.mark.parametrize(
92+
"pos_label, y_true",
93+
[
94+
(None, np.array([0, 1, 0, 1])),
95+
(2, np.array([0, 2, 0, 2])),
96+
],
97+
)
98+
def test_validate_from_predictions_params_returns(pyplot, name, pos_label, y_true):
99+
"""Check `_validate_from_predictions_params` returns the correct values."""
100+
y_pred = np.array([0.1, 0.2, 0.3, 0.4])
101+
pos_label_out, name_out = (
102+
_BinaryClassifierCurveDisplayMixin._validate_from_predictions_params(
103+
y_true=y_true,
104+
y_pred=y_pred,
105+
sample_weight=None,
106+
pos_label=pos_label,
107+
name=name,
108+
)
109+
)
110+
111+
# Check name is handled correctly
112+
expected_name = name if name is not None else "Classifier"
113+
assert name_out == expected_name
114+
115+
# Check pos_label is handled correctly
116+
expected_pos_label = pos_label if pos_label is not None else 1
117+
assert pos_label_out == expected_pos_label
10118

11119

12120
def metric():

0 commit comments

Comments
 (0)