@@ -71,6 +71,11 @@ def _check_reg_targets(y_true, y_pred, multioutput, dtype="numeric", xp=None):
71
71
dtype : str or list, default="numeric"
72
72
the dtype argument passed to check_array.
73
73
74
+ xp : module, default=None
75
+ Precomputed array namespace module. When passed, typically from a caller
76
+ that has already performed inspection of its own inputs, skips array
77
+ namespace inspection.
78
+
74
79
Returns
75
80
-------
76
81
type_true : one of {'continuous', continuous-multioutput'}
@@ -398,7 +403,7 @@ def mean_absolute_percentage_error(
398
403
dtype = _find_matching_floating_dtype (y_true , y_pred , sample_weight , xp = xp )
399
404
400
405
y_type , y_true , y_pred , multioutput = _check_reg_targets (
401
- y_true , y_pred , multioutput
406
+ y_true , y_pred , multioutput , dtype = dtype , xp = xp
402
407
)
403
408
check_consistent_length (y_true , y_pred , sample_weight )
404
409
epsilon = xp .asarray (xp .finfo (xp .float64 ).eps , dtype = dtype )
@@ -1253,7 +1258,7 @@ def max_error(y_true, y_pred):
1253
1258
np.int64(1)
1254
1259
"""
1255
1260
xp , _ = get_namespace (y_true , y_pred )
1256
- y_type , y_true , y_pred , _ = _check_reg_targets (y_true , y_pred , None )
1261
+ y_type , y_true , y_pred , _ = _check_reg_targets (y_true , y_pred , None , xp = xp )
1257
1262
if y_type == "continuous-multioutput" :
1258
1263
raise ValueError ("Multioutput not supported in max_error" )
1259
1264
return xp .max (xp .abs (y_true - y_pred ))
@@ -1352,7 +1357,7 @@ def mean_tweedie_deviance(y_true, y_pred, *, sample_weight=None, power=0):
1352
1357
"""
1353
1358
xp , _ = get_namespace (y_true , y_pred )
1354
1359
y_type , y_true , y_pred , _ = _check_reg_targets (
1355
- y_true , y_pred , None , dtype = [xp .float64 , xp .float32 ]
1360
+ y_true , y_pred , None , dtype = [xp .float64 , xp .float32 ], xp = xp
1356
1361
)
1357
1362
if y_type == "continuous-multioutput" :
1358
1363
raise ValueError ("Multioutput not supported in mean_tweedie_deviance" )
0 commit comments