From f21d3eb4b18dbaee61f2d218e74de841d61ae692 Mon Sep 17 00:00:00 2001 From: Mateusz Kozak Date: Mon, 24 Mar 2025 11:13:38 +0100 Subject: [PATCH 1/4] Add path overloads in `sklearn.metrics._regression` Various sklearn metrics return floats or ndarrays based on the value of `multioutput` parameter. This commit adds overloads for the separate paths. --- stubs/sklearn/metrics/_regression.pyi | 191 ++++++++++++++++++++++---- 1 file changed, 162 insertions(+), 29 deletions(-) diff --git a/stubs/sklearn/metrics/_regression.pyi b/stubs/sklearn/metrics/_regression.pyi index c0cf69d3..e2d2bbd5 100644 --- a/stubs/sklearn/metrics/_regression.pyi +++ b/stubs/sklearn/metrics/_regression.pyi @@ -34,36 +34,84 @@ from ..utils.validation import ( __ALL__: list = ... +@overload +def mean_absolute_error( + y_true: MatrixLike | ArrayLike, + y_pred: MatrixLike | ArrayLike, + *, + sample_weight: None | ArrayLike = None, + multioutput: Literal["raw_values"], +) -> ndarray: ... +@overload def mean_absolute_error( y_true: MatrixLike | ArrayLike, y_pred: MatrixLike | ArrayLike, *, sample_weight: None | ArrayLike = None, - multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average", -) -> ndarray | Float: ... + multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average", +) -> Float: ... +@overload def mean_pinball_loss( y_true: MatrixLike | ArrayLike, y_pred: MatrixLike | ArrayLike, *, sample_weight: None | ArrayLike = None, alpha: float = 0.5, - multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average", -) -> ndarray | Float: ... + multioutput: Literal["raw_values"], +) -> ndarray: ... +@overload +def mean_pinball_loss( + y_true: MatrixLike | ArrayLike, + y_pred: MatrixLike | ArrayLike, + *, + sample_weight: None | ArrayLike = None, + alpha: float = 0.5, + multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average", +) -> Float: ... 
+@overload +def mean_absolute_percentage_error( + y_true: MatrixLike | ArrayLike, + y_pred: MatrixLike | ArrayLike, + *, + sample_weight: None | ArrayLike = None, + multioutput: Literal["raw_values"], +) -> ndarray: ... +@overload def mean_absolute_percentage_error( y_true: MatrixLike | ArrayLike, y_pred: MatrixLike | ArrayLike, *, sample_weight: None | ArrayLike = None, - multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average", -) -> ndarray | Float: ... + multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average", +) -> Float: ... @overload def mean_squared_error( y_true: MatrixLike | ArrayLike, y_pred: MatrixLike | ArrayLike, *, sample_weight: None | ArrayLike = None, - multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average", -) -> ndarray | Float: ... + multioutput: Literal["raw_values"], +) -> ndarray: ... +@overload +def mean_squared_error( + y_true: MatrixLike | ArrayLike, + y_pred: MatrixLike | ArrayLike, + *, + sample_weight: None | ArrayLike = None, + multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average", +) -> Float: ... +@deprecated( + "`squared` is deprecated in 1.4 and will be removed in 1.6. Use `root_mean_squared_error` instead to calculate the root mean squared error." +) +@overload +def mean_squared_error( + y_true: MatrixLike | ArrayLike, + y_pred: MatrixLike | ArrayLike, + *, + sample_weight: None | ArrayLike = None, + multioutput: Literal["raw_values"], + squared: bool, +) -> ndarray: ... @deprecated( "`squared` is deprecated in 1.4 and will be removed in 1.6. Use `root_mean_squared_error` instead to calculate the root mean squared error." 
) @@ -73,17 +121,25 @@ def mean_squared_error( y_pred: MatrixLike | ArrayLike, *, sample_weight: None | ArrayLike = None, - multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average", + multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average", squared: bool, -) -> ndarray | Float: ... +) -> Float: ... @overload def mean_squared_log_error( y_true: MatrixLike | ArrayLike, y_pred: MatrixLike | ArrayLike, *, sample_weight: None | ArrayLike = None, - multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average", -) -> float | ndarray: ... + multioutput: Literal["raw_values"], +) -> ndarray: ... +@overload +def mean_squared_log_error( + y_true: MatrixLike | ArrayLike, + y_pred: MatrixLike | ArrayLike, + *, + sample_weight: None | ArrayLike = None, + multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average", +) -> float: ... @deprecated( "`squared` is deprecated in 1.4 and will be removed in 1.6. Use `root_mean_squared_log_error` instead to calculate the root mean squared logarithmic error." ) @@ -93,32 +149,73 @@ def mean_squared_log_error( y_pred: MatrixLike | ArrayLike, *, sample_weight: None | ArrayLike = None, - multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average", + multioutput: Literal["raw_values"], squared: bool, -) -> float | ndarray: ... +) -> ndarray: ... +@deprecated( + "`squared` is deprecated in 1.4 and will be removed in 1.6. Use `root_mean_squared_log_error` instead to calculate the root mean squared logarithmic error." +) +@overload +def mean_squared_log_error( + y_true: MatrixLike | ArrayLike, + y_pred: MatrixLike | ArrayLike, + *, + sample_weight: None | ArrayLike = None, + multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average", + squared: bool, +) -> float: ... 
+@overload +def median_absolute_error( + y_true: MatrixLike | ArrayLike, + y_pred: MatrixLike | ArrayLike, + *, + multioutput: Literal["raw_values"], + sample_weight: None | ArrayLike = None, +) -> ndarray: ... +@overload def median_absolute_error( y_true: MatrixLike | ArrayLike, y_pred: MatrixLike | ArrayLike, *, - multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average", + multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average", + sample_weight: None | ArrayLike = None, +) -> Float: ... +@overload +def explained_variance_score( + y_true: MatrixLike | ArrayLike, + y_pred: MatrixLike | ArrayLike, + *, sample_weight: None | ArrayLike = None, -) -> ndarray | Float: ... + multioutput: Literal["raw_values"], + force_finite: bool = True, +) -> ndarray: ... +@overload def explained_variance_score( y_true: MatrixLike | ArrayLike, y_pred: MatrixLike | ArrayLike, *, sample_weight: None | ArrayLike = None, - multioutput: Literal["raw_values", "uniform_average", "variance_weighted"] | ArrayLike = "uniform_average", + multioutput: Literal["uniform_average", "variance_weighted"] | ArrayLike = "uniform_average", + force_finite: bool = True, +) -> float: ... +@overload +def r2_score( + y_true: MatrixLike | ArrayLike, + y_pred: MatrixLike | ArrayLike, + *, + sample_weight: None | ArrayLike = None, + multioutput: Literal["raw_values"], force_finite: bool = True, -) -> float | ndarray: ... +) -> ndarray: ... +@overload def r2_score( y_true: MatrixLike | ArrayLike, y_pred: MatrixLike | ArrayLike, *, sample_weight: None | ArrayLike = None, - multioutput: (Literal["raw_values", "uniform_average", "variance_weighted"] | None | ArrayLike) = "uniform_average", + multioutput: Literal["uniform_average", "variance_weighted"] | ArrayLike | None = "uniform_average", force_finite: bool = True, -) -> ndarray | Float: ... +) -> Float: ... def max_error(y_true: ArrayLike, y_pred: ArrayLike) -> float: ... 
def mean_tweedie_deviance( y_true: ArrayLike, @@ -135,33 +232,69 @@ def d2_tweedie_score( *, sample_weight: None | ArrayLike = None, power: Float = 0, -) -> float | ndarray: ... +@overload def d2_pinball_score( y_true: MatrixLike | ArrayLike, y_pred: MatrixLike | ArrayLike, *, sample_weight: None | ArrayLike = None, alpha: Float = 0.5, - multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average", -) -> float | ndarray: ... + multioutput: Literal["raw_values"], +) -> ndarray: ... +@overload +def d2_pinball_score( + y_true: MatrixLike | ArrayLike, + y_pred: MatrixLike | ArrayLike, + *, + sample_weight: None | ArrayLike = None, + alpha: Float = 0.5, + multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average", +) -> float: ... +@overload +def d2_absolute_error_score( + y_true: MatrixLike | ArrayLike, + y_pred: MatrixLike | ArrayLike, + *, + sample_weight: None | ArrayLike = None, + multioutput: Literal["raw_values"], +) -> ndarray: ... +@overload def d2_absolute_error_score( y_true: MatrixLike | ArrayLike, y_pred: MatrixLike | ArrayLike, *, sample_weight: None | ArrayLike = None, - multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average", -) -> float | ndarray: ... + multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average", +) -> float: ... +@overload +def root_mean_squared_error( + y_true: MatrixLike | ArrayLike, + y_pred: MatrixLike | ArrayLike, + *, + sample_weight: None | ArrayLike = None, + multioutput: Literal["raw_values"], +) -> ndarray: ... +@overload def root_mean_squared_error( y_true: MatrixLike | ArrayLike, y_pred: MatrixLike | ArrayLike, *, sample_weight: None | ArrayLike = None, - multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average", -) -> float | ndarray: ... + multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average", +) -> float: ... 
+@overload +def root_mean_squared_log_error( + y_true: MatrixLike | ArrayLike, + y_pred: MatrixLike | ArrayLike, + *, + sample_weight: None | ArrayLike = None, + multioutput: Literal["raw_values"], +) -> ndarray: ... +@overload def root_mean_squared_log_error( y_true: MatrixLike | ArrayLike, y_pred: MatrixLike | ArrayLike, *, sample_weight: None | ArrayLike = None, - multioutput: ArrayLike | Literal["raw_values", "uniform_average"] = "uniform_average", -) -> float | ndarray: ... + multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average", +) -> float: ... From 849eb4f1d69564d12511313dcefefc36c92f4d38 Mon Sep 17 00:00:00 2001 From: Mateusz Kozak Date: Mon, 24 Mar 2025 11:37:54 +0100 Subject: [PATCH 2/4] Fix float types in `sklearn.metrics._regression` Various sklearn metrics return either a standard Python float, or a numpy floating point scalar type. E.g. ``` >>> import numpy as np >>> from sklearn.metrics import mean_absolute_error, median_absolute_error >>> a = np.array([1,2,3]) >>> b = np.array([4,5,6]) >>> type(mean_absolute_error(a,b)) float >>> type(median_absolute_error(a,b)) numpy.float64 ``` This commit fixes the type annotations for the following functions: - `mean_absolute_error` - `mean_absolute_percentage_error` - `mean_squared_error` - `r2_score` - `mean_tweedie_deviance` - `d2_pinball_score` - `d2_absolute_error_score` --- stubs/sklearn/metrics/_regression.pyi | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/stubs/sklearn/metrics/_regression.pyi b/stubs/sklearn/metrics/_regression.pyi index e2d2bbd5..068a65ec 100644 --- a/stubs/sklearn/metrics/_regression.pyi +++ b/stubs/sklearn/metrics/_regression.pyi @@ -49,7 +49,7 @@ def mean_absolute_error( *, sample_weight: None | ArrayLike = None, multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average", -) -> Float: ... +) -> float: ... 
@overload def mean_pinball_loss( y_true: MatrixLike | ArrayLike, @@ -83,7 +83,7 @@ def mean_absolute_percentage_error( *, sample_weight: None | ArrayLike = None, multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average", -) -> Float: ... +) -> float: ... @overload def mean_squared_error( y_true: MatrixLike | ArrayLike, @@ -99,7 +99,7 @@ def mean_squared_error( *, sample_weight: None | ArrayLike = None, multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average", -) -> Float: ... +) -> float: ... @deprecated( "`squared` is deprecated in 1.4 and will be removed in 1.6. Use `root_mean_squared_error` instead to calculate the root mean squared error." ) @@ -123,7 +123,7 @@ def mean_squared_error( sample_weight: None | ArrayLike = None, multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average", squared: bool, -) -> Float: ... +) -> float: ... @overload def mean_squared_log_error( y_true: MatrixLike | ArrayLike, @@ -215,7 +215,7 @@ def r2_score( sample_weight: None | ArrayLike = None, multioutput: Literal["uniform_average", "variance_weighted"] | ArrayLike | None = "uniform_average", force_finite: bool = True, -) -> Float: ... +) -> float: ... def max_error(y_true: ArrayLike, y_pred: ArrayLike) -> float: ... def mean_tweedie_deviance( y_true: ArrayLike, @@ -223,7 +223,7 @@ def mean_tweedie_deviance( *, sample_weight: None | ArrayLike = None, power: Float = 0, -) -> Float: ... +) -> float: ... def mean_poisson_deviance(y_true: ArrayLike, y_pred: ArrayLike, *, sample_weight: None | ArrayLike = None) -> Float: ... def mean_gamma_deviance(y_true: ArrayLike, y_pred: ArrayLike, *, sample_weight: None | ArrayLike = None) -> float: ... def d2_tweedie_score( @@ -249,7 +249,7 @@ def d2_pinball_score( sample_weight: None | ArrayLike = None, alpha: Float = 0.5, multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average", -) -> float: ... +) -> Float: ... 
@overload def d2_absolute_error_score( y_true: MatrixLike | ArrayLike, @@ -265,7 +265,7 @@ def d2_absolute_error_score( *, sample_weight: None | ArrayLike = None, multioutput: Literal["uniform_average"] | ArrayLike = "uniform_average", -) -> float: ... +) -> Float: ... @overload def root_mean_squared_error( y_true: MatrixLike | ArrayLike, From bda81cc883d2f33ade2c21ddf45ea03bc66bfebe Mon Sep 17 00:00:00 2001 From: Mateusz Kozak Date: Mon, 24 Mar 2025 11:49:12 +0100 Subject: [PATCH 3/4] Fix `d2_tweedie_score` return The docs say float or ndarray but there is no ndarray return path. --- stubs/sklearn/metrics/_regression.pyi | 1 + 1 file changed, 1 insertion(+) diff --git a/stubs/sklearn/metrics/_regression.pyi b/stubs/sklearn/metrics/_regression.pyi index 068a65ec..7185e107 100644 --- a/stubs/sklearn/metrics/_regression.pyi +++ b/stubs/sklearn/metrics/_regression.pyi @@ -232,6 +232,7 @@ def d2_tweedie_score( *, sample_weight: None | ArrayLike = None, power: Float = 0, +) -> float: ... @overload def d2_pinball_score( y_true: MatrixLike | ArrayLike, y_pred: MatrixLike | ArrayLike, From c35a82fb549a5c25359e824ef3c9ae9af06a6218 Mon Sep 17 00:00:00 2001 From: Erik De Bonte Date: Thu, 29 May 2025 22:00:22 -0700 Subject: [PATCH 4/4] Undo removed @overloads from merge Originally these were not overloads, but now they are. --- stubs/sklearn/metrics/_regression.pyi | 2 ++ 1 file changed, 2 insertions(+) diff --git a/stubs/sklearn/metrics/_regression.pyi b/stubs/sklearn/metrics/_regression.pyi index e6ee667c..29236789 100644 --- a/stubs/sklearn/metrics/_regression.pyi +++ b/stubs/sklearn/metrics/_regression.pyi @@ -107,6 +107,7 @@ def mean_squared_error( @deprecated( "`squared` is deprecated in 1.4 and will be removed in 1.6. Use `root_mean_squared_error` instead to calculate the root mean squared error." 
) +@overload def mean_squared_error( y_true: MatrixLike | ArrayLike, y_pred: MatrixLike | ArrayLike, @@ -146,6 +147,7 @@ def mean_squared_log_error( @deprecated( "`squared` is deprecated in 1.4 and will be removed in 1.6. Use `root_mean_squared_log_error` instead to calculate the root mean squared logarithmic error." ) +@overload def mean_squared_log_error( y_true: MatrixLike | ArrayLike, y_pred: MatrixLike | ArrayLike,