Skip to content

Commit 5879f2e

Browse files
authored
FIX pass xp to avoid redundant namespace inspection (scikit-learn#30092)
1 parent 325930e commit 5879f2e

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

sklearn/metrics/_regression.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
7171
dtype : str or list, default="numeric"
7272
the dtype argument passed to check_array.
7373
74+
xp : module, default=None
75+
Precomputed array namespace module. When passed, typically from a caller
76+
that has already performed inspection of its own inputs, skips array
77+
namespace inspection.
78+
7479
Returns
7580
-------
7681
type_true : one of {'continuous', continuous-multioutput'}
@@ -398,7 +403,7 @@ def mean_absolute_percentage_error(
398403
dtype = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp)
399404

400405
y_type, y_true, y_pred, multioutput = _check_reg_targets(
401-
y_true, y_pred, multioutput
406+
y_true, y_pred, multioutput, dtype=dtype, xp=xp
402407
)
403408
check_consistent_length(y_true, y_pred, sample_weight)
404409
epsilon = xp.asarray(xp.finfo(xp.float64).eps, dtype=dtype)
@@ -1253,7 +1258,7 @@ def max_error(y_true, y_pred):
12531258
np.int64(1)
12541259
"""
12551260
xp, _ = get_namespace(y_true, y_pred)
1256-
y_type, y_true, y_pred, _ = _check_reg_targets(y_true, y_pred, None)
1261+
y_type, y_true, y_pred, _ = _check_reg_targets(y_true, y_pred, None, xp=xp)
12571262
if y_type == "continuous-multioutput":
12581263
raise ValueError("Multioutput not supported in max_error")
12591264
return xp.max(xp.abs(y_true - y_pred))
@@ -1352,7 +1357,7 @@ def mean_tweedie_deviance(y_true, y_pred, *, sample_weight=None, power=0):
13521357
"""
13531358
xp, _ = get_namespace(y_true, y_pred)
13541359
y_type, y_true, y_pred, _ = _check_reg_targets(
1355-
y_true, y_pred, None, dtype=[xp.float64, xp.float32]
1360+
y_true, y_pred, None, dtype=[xp.float64, xp.float32], xp=xp
13561361
)
13571362
if y_type == "continuous-multioutput":
13581363
raise ValueError("Multioutput not supported in mean_tweedie_deviance")

0 commit comments

Comments
 (0)