Skip to content

Commit 46ef5cc

Browse files
tomwhitethodson-usgs
authored andcommitted
More lenient dtype support (#550)
* Allow `bool` in `sum` and `prod` * Make dtype checking more lenient in case of 'all'
1 parent 65c58fa commit 46ef5cc

File tree

3 files changed

+31
-8
lines changed

3 files changed

+31
-8
lines changed

cubed/array_api/array_object.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -399,12 +399,18 @@ def __int__(self, /):
399399
# Utility methods
400400

401401
def _check_allowed_dtypes(self, other, dtype_category, op):
402-
if self.dtype not in _dtype_categories[dtype_category]:
402+
if (
403+
dtype_category != "all"
404+
and self.dtype not in _dtype_categories[dtype_category]
405+
):
403406
raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}")
404407
if isinstance(other, (int, complex, float, bool)):
405408
other = self._promote_scalar(other)
406409
elif isinstance(other, CoreArray):
407-
if other.dtype not in _dtype_categories[dtype_category]:
410+
if (
411+
dtype_category != "all"
412+
and other.dtype not in _dtype_categories[dtype_category]
413+
):
408414
raise TypeError(f"Only {dtype_category} dtypes are allowed in {op}")
409415
else:
410416
return NotImplemented

cubed/array_api/statistical_functions.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import math
22

33
from cubed.array_api.dtypes import (
4+
_boolean_dtypes,
45
_numeric_dtypes,
56
_real_floating_dtypes,
67
_real_numeric_dtypes,
@@ -124,10 +125,13 @@ def min(x, /, *, axis=None, keepdims=False, use_new_impl=True, split_every=None)
124125
def prod(
125126
x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None
126127
):
127-
if x.dtype not in _numeric_dtypes:
128-
raise TypeError("Only numeric dtypes are allowed in prod")
128+
# boolean is allowed by numpy
129+
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
130+
raise TypeError("Only numeric or boolean dtypes are allowed in prod")
129131
if dtype is None:
130-
if x.dtype in _signed_integer_dtypes:
132+
if x.dtype in _boolean_dtypes:
133+
dtype = int64
134+
elif x.dtype in _signed_integer_dtypes:
131135
dtype = int64
132136
elif x.dtype in _unsigned_integer_dtypes:
133137
dtype = uint64
@@ -153,10 +157,13 @@ def prod(
153157
def sum(
154158
x, /, *, axis=None, dtype=None, keepdims=False, use_new_impl=True, split_every=None
155159
):
156-
if x.dtype not in _numeric_dtypes:
157-
raise TypeError("Only numeric dtypes are allowed in sum")
160+
# boolean is allowed by numpy
161+
if x.dtype not in _numeric_dtypes and x.dtype not in _boolean_dtypes:
162+
raise TypeError("Only numeric or boolean dtypes are allowed in sum")
158163
if dtype is None:
159-
if x.dtype in _signed_integer_dtypes:
164+
if x.dtype in _boolean_dtypes:
165+
dtype = int64
166+
elif x.dtype in _signed_integer_dtypes:
160167
dtype = int64
161168
elif x.dtype in _unsigned_integer_dtypes:
162169
dtype = uint64

cubed/tests/test_types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from numpy.testing import assert_array_equal
2+
3+
import cubed.array_api as xp
4+
5+
6+
# This is less strict than the spec, but is supported by implementations like NumPy
7+
def test_prod_sum_bool():
8+
a = xp.ones((2,), dtype=xp.bool)
9+
assert_array_equal(xp.prod(a).compute(), xp.asarray([1], dtype=xp.int64))
10+
assert_array_equal(xp.sum(a).compute(), xp.asarray([2], dtype=xp.int64))

0 commit comments

Comments
 (0)