Skip to content

Commit 4cb0e74

Browse files
authored
Improve type annotations in sklearn.metrics._regression (#357)
1 parent b311fd7 commit 4cb0e74

File tree

1 file changed

+164
-30
lines changed

1 file changed

+164
-30
lines changed

stubs/sklearn/metrics/_regression.pyi

Lines changed: 164 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -26,99 +26,196 @@ from .._typing import ArrayLike, Float, MatrixLike
2626

2727
__ALL__: list = ...
2828

29+
@overload
30+
def mean_absolute_error(
31+
y_true: MatrixLike | ArrayLike,
32+
y_pred: MatrixLike | ArrayLike,
33+
*,
34+
sample_weight: None | ArrayLike = None,
35+
multioutput: Literal["raw_values"],
36+
) -> ndarray: ...
37+
@overload
2938
def mean_absolute_error(
3039
y_true: MatrixLike | ArrayLike,
3140
y_pred: MatrixLike | ArrayLike,
3241
*,
3342
sample_weight: None | ArrayLike = None,
34-
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
35-
) -> ndarray | Float: ...
43+
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
44+
) -> float: ...
45+
@overload
46+
def mean_pinball_loss(
47+
y_true: MatrixLike | ArrayLike,
48+
y_pred: MatrixLike | ArrayLike,
49+
*,
50+
sample_weight: None | ArrayLike = None,
51+
alpha: float = 0.5,
52+
multioutput: Literal["raw_values"],
53+
) -> ndarray: ...
54+
@overload
3655
def mean_pinball_loss(
3756
y_true: MatrixLike | ArrayLike,
3857
y_pred: MatrixLike | ArrayLike,
3958
*,
4059
sample_weight: None | ArrayLike = None,
4160
alpha: float = 0.5,
42-
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
43-
) -> ndarray | Float: ...
61+
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
62+
) -> Float: ...
63+
@overload
64+
def mean_absolute_percentage_error(
65+
y_true: MatrixLike | ArrayLike,
66+
y_pred: MatrixLike | ArrayLike,
67+
*,
68+
sample_weight: None | ArrayLike = None,
69+
multioutput: Literal["raw_values"],
70+
) -> ndarray: ...
71+
@overload
4472
def mean_absolute_percentage_error(
4573
y_true: MatrixLike | ArrayLike,
4674
y_pred: MatrixLike | ArrayLike,
4775
*,
4876
sample_weight: None | ArrayLike = None,
49-
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
50-
) -> ndarray | Float: ...
77+
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
78+
) -> float: ...
79+
@overload
80+
def mean_squared_error(
81+
y_true: MatrixLike | ArrayLike,
82+
y_pred: MatrixLike | ArrayLike,
83+
*,
84+
sample_weight: None | ArrayLike = None,
85+
multioutput: Literal["raw_values"],
86+
) -> ndarray: ...
5187
@overload
5288
def mean_squared_error(
5389
y_true: MatrixLike | ArrayLike,
5490
y_pred: MatrixLike | ArrayLike,
5591
*,
5692
sample_weight: None | ArrayLike = None,
57-
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
58-
) -> ndarray | Float: ...
93+
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
94+
) -> float: ...
95+
@deprecated(
96+
"`squared` is deprecated in 1.4 and will be removed in 1.6. Use `root_mean_squared_error` instead to calculate the root mean squared error."
97+
)
5998
@overload
99+
def mean_squared_error(
100+
y_true: MatrixLike | ArrayLike,
101+
y_pred: MatrixLike | ArrayLike,
102+
*,
103+
sample_weight: None | ArrayLike = None,
104+
multioutput: Literal["raw_values"],
105+
squared: bool,
106+
) -> ndarray: ...
60107
@deprecated(
61108
"`squared` is deprecated in 1.4 and will be removed in 1.6. Use `root_mean_squared_error` instead to calculate the root mean squared error."
62109
)
110+
@overload
63111
def mean_squared_error(
64112
y_true: MatrixLike | ArrayLike,
65113
y_pred: MatrixLike | ArrayLike,
66114
*,
67115
sample_weight: None | ArrayLike = None,
68-
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
116+
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
69117
squared: bool,
70-
) -> ndarray | Float: ...
118+
) -> float: ...
119+
@overload
120+
def mean_squared_log_error(
121+
y_true: MatrixLike | ArrayLike,
122+
y_pred: MatrixLike | ArrayLike,
123+
*,
124+
sample_weight: None | ArrayLike = None,
125+
multioutput: Literal["raw_values"],
126+
) -> ndarray: ...
71127
@overload
72128
def mean_squared_log_error(
73129
y_true: MatrixLike | ArrayLike,
74130
y_pred: MatrixLike | ArrayLike,
75131
*,
76132
sample_weight: None | ArrayLike = None,
77-
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
78-
) -> float | ndarray: ...
133+
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
134+
) -> float: ...
135+
@deprecated(
136+
"`squared` is deprecated in 1.4 and will be removed in 1.6. Use `root_mean_squared_log_error` instead to calculate the root mean squared logarithmic error."
137+
)
79138
@overload
139+
def mean_squared_log_error(
140+
y_true: MatrixLike | ArrayLike,
141+
y_pred: MatrixLike | ArrayLike,
142+
*,
143+
sample_weight: None | ArrayLike = None,
144+
multioutput: Literal["raw_values"],
145+
squared: bool,
146+
) -> ndarray: ...
80147
@deprecated(
81148
"`squared` is deprecated in 1.4 and will be removed in 1.6. Use `root_mean_squared_log_error` instead to calculate the root mean squared logarithmic error."
82149
)
150+
@overload
83151
def mean_squared_log_error(
84152
y_true: MatrixLike | ArrayLike,
85153
y_pred: MatrixLike | ArrayLike,
86154
*,
87155
sample_weight: None | ArrayLike = None,
88-
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
156+
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
89157
squared: bool,
90-
) -> float | ndarray: ...
158+
) -> float: ...
159+
@overload
91160
def median_absolute_error(
92161
y_true: MatrixLike | ArrayLike,
93162
y_pred: MatrixLike | ArrayLike,
94163
*,
95-
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
164+
multioutput: Literal["raw_values"],
96165
sample_weight: None | ArrayLike = None,
97-
) -> ndarray | Float: ...
166+
) -> ndarray: ...
167+
@overload
168+
def median_absolute_error(
169+
y_true: MatrixLike | ArrayLike,
170+
y_pred: MatrixLike | ArrayLike,
171+
*,
172+
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
173+
sample_weight: None | ArrayLike = None,
174+
) -> Float: ...
175+
@overload
176+
def explained_variance_score(
177+
y_true: MatrixLike | ArrayLike,
178+
y_pred: MatrixLike | ArrayLike,
179+
*,
180+
sample_weight: None | ArrayLike = None,
181+
multioutput: Literal["raw_values"],
182+
force_finite: bool = True,
183+
) -> ndarray: ...
184+
@overload
98185
def explained_variance_score(
99186
y_true: MatrixLike | ArrayLike,
100187
y_pred: MatrixLike | ArrayLike,
101188
*,
102189
sample_weight: None | ArrayLike = None,
103-
multioutput: Literal["raw_values", "uniform_average", "variance_weighted"] | ArrayLike = "uniform_average",
190+
multioutput: Literal["uniform_average", "variance_weighted"] | ArrayLike = "uniform_average",
191+
force_finite: bool = True,
192+
) -> float: ...
193+
@overload
194+
def r2_score(
195+
y_true: MatrixLike | ArrayLike,
196+
y_pred: MatrixLike | ArrayLike,
197+
*,
198+
sample_weight: None | ArrayLike = None,
199+
multioutput: Literal["raw_values"],
104200
force_finite: bool = True,
105-
) -> float | ndarray: ...
201+
) -> ndarray: ...
202+
@overload
106203
def r2_score(
107204
y_true: MatrixLike | ArrayLike,
108205
y_pred: MatrixLike | ArrayLike,
109206
*,
110207
sample_weight: None | ArrayLike = None,
111-
multioutput: (Literal["raw_values", "uniform_average", "variance_weighted"] | None | ArrayLike) = "uniform_average",
208+
multioutput: Literal["uniform_average", "variance_weighted"] | ArrayLike | None = "uniform_average",
112209
force_finite: bool = True,
113-
) -> ndarray | Float: ...
210+
) -> float: ...
114211
def max_error(y_true: ArrayLike, y_pred: ArrayLike) -> float: ...
115212
def mean_tweedie_deviance(
116213
y_true: ArrayLike,
117214
y_pred: ArrayLike,
118215
*,
119216
sample_weight: None | ArrayLike = None,
120217
power: Float = 0,
121-
) -> Float: ...
218+
) -> float: ...
122219
def mean_poisson_deviance(y_true: ArrayLike, y_pred: ArrayLike, *, sample_weight: None | ArrayLike = None) -> Float: ...
123220
def mean_gamma_deviance(y_true: ArrayLike, y_pred: ArrayLike, *, sample_weight: None | ArrayLike = None) -> float: ...
124221
def d2_tweedie_score(
@@ -127,33 +224,70 @@ def d2_tweedie_score(
127224
*,
128225
sample_weight: None | ArrayLike = None,
129226
power: Float = 0,
130-
) -> float | ndarray: ...
227+
) -> float: ...
228+
@overload
131229
def d2_pinball_score(
132230
y_true: MatrixLike | ArrayLike,
133231
y_pred: MatrixLike | ArrayLike,
134232
*,
135233
sample_weight: None | ArrayLike = None,
136234
alpha: Float = 0.5,
137-
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
138-
) -> float | ndarray: ...
235+
multioutput: Literal["raw_values"],
236+
) -> ndarray: ...
237+
@overload
238+
def d2_pinball_score(
239+
y_true: MatrixLike | ArrayLike,
240+
y_pred: MatrixLike | ArrayLike,
241+
*,
242+
sample_weight: None | ArrayLike = None,
243+
alpha: Float = 0.5,
244+
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
245+
) -> Float: ...
246+
@overload
247+
def d2_absolute_error_score(
248+
y_true: MatrixLike | ArrayLike,
249+
y_pred: MatrixLike | ArrayLike,
250+
*,
251+
sample_weight: None | ArrayLike = None,
252+
multioutput: Literal["raw_values"],
253+
) -> ndarray: ...
254+
@overload
139255
def d2_absolute_error_score(
140256
y_true: MatrixLike | ArrayLike,
141257
y_pred: MatrixLike | ArrayLike,
142258
*,
143259
sample_weight: None | ArrayLike = None,
144-
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
145-
) -> float | ndarray: ...
260+
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
261+
) -> Float: ...
262+
@overload
146263
def root_mean_squared_error(
147264
y_true: MatrixLike | ArrayLike,
148265
y_pred: MatrixLike | ArrayLike,
149266
*,
150267
sample_weight: None | ArrayLike = None,
151-
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
152-
) -> float | ndarray: ...
268+
multioutput: Literal["raw_values"],
269+
) -> ndarray: ...
270+
@overload
271+
def root_mean_squared_error(
272+
y_true: MatrixLike | ArrayLike,
273+
y_pred: MatrixLike | ArrayLike,
274+
*,
275+
sample_weight: None | ArrayLike = None,
276+
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
277+
) -> float: ...
278+
@overload
279+
def root_mean_squared_log_error(
280+
y_true: MatrixLike | ArrayLike,
281+
y_pred: MatrixLike | ArrayLike,
282+
*,
283+
sample_weight: None | ArrayLike = None,
284+
multioutput: Literal["raw_values"],
285+
) -> ndarray: ...
286+
@overload
153287
def root_mean_squared_log_error(
154288
y_true: MatrixLike | ArrayLike,
155289
y_pred: MatrixLike | ArrayLike,
156290
*,
157291
sample_weight: None | ArrayLike = None,
158-
multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average",
159-
) -> float | ndarray: ...
292+
multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average",
293+
) -> float: ...

0 commit comments

Comments
 (0)