Skip to content

Commit 21ab5e1

Browse files
OmarManzoorlesteve
andauthored
FIX avoid error for metrics on polars series for numpy<1.21 (scikit-learn#29490)
Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
1 parent 6fd2f83 commit 21ab5e1

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

sklearn/metrics/tests/test_common.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2039,3 +2039,20 @@ def yield_metric_checker_combinations(metric_checkers=array_api_metric_checkers)
20392039
@pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations())
20402040
def test_array_api_compliance(metric, array_namespace, device, dtype_name, check_func):
20412041
check_func(metric, array_namespace, device, dtype_name)
2042+
2043+
2044+
@pytest.mark.parametrize("df_lib_name", ["pandas", "polars"])
2045+
@pytest.mark.parametrize("metric_name", sorted(ALL_METRICS))
2046+
def test_metrics_dataframe_series(metric_name, df_lib_name):
2047+
df_lib = pytest.importorskip(df_lib_name)
2048+
2049+
y_pred = df_lib.Series([0.0, 1.0, 0, 1.0])
2050+
y_true = df_lib.Series([1.0, 0.0, 0.0, 0.0])
2051+
2052+
metric = ALL_METRICS[metric_name]
2053+
try:
2054+
expected_metric = metric(y_pred.to_numpy(), y_true.to_numpy())
2055+
except ValueError:
2056+
pytest.skip(f"{metric_name} can not deal with 1d inputs")
2057+
2058+
assert_allclose(metric(y_pred, y_true), expected_metric)

sklearn/utils/_array_api.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,15 @@ def reshape(self, x, shape, *, copy=None):
439439
return numpy.reshape(x, shape)
440440

441441
def isdtype(self, dtype, kind):
442-
return isdtype(dtype, kind, xp=self)
442+
try:
443+
return isdtype(dtype, kind, xp=self)
444+
except TypeError:
445+
# In older versions of numpy, data types that arise from outside
446+
# numpy like from a Polars Series raise a TypeError.
447+
# e.g. TypeError: Cannot interpret 'Int64' as a data type.
448+
# Therefore, we return False.
449+
# TODO: Remove when minimum supported version of numpy is >= 1.21.
450+
return False
443451

444452
def pow(self, x1, x2):
445453
return numpy.power(x1, x2)

0 commit comments

Comments
 (0)