Skip to content

Commit 78a5b20

Browse files
committed
Merge branch 'master' into xfails-file
2 parents 44c7498 + 0c2c0f7 commit 78a5b20

File tree

6 files changed

+44
-25
lines changed

6 files changed

+44
-25
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -175,16 +175,17 @@ def test_arange(dtype, data):
175175
#
176176
min_size = math.floor(size * 0.9)
177177
max_size = max(math.ceil(size * 1.1), 1)
178+
out_size = math.prod(out.shape)
178179
assert (
179-
min_size <= out.size <= max_size
180-
), f"{out.size=}, but should be roughly {size} {f_func}"
180+
min_size <= out_size <= max_size
181+
), f"prod(out.shape)={out_size}, but should be roughly {size} {f_func}"
181182
if dh.is_int_dtype(_dtype):
182183
elements = list(r)
183-
assume(out.size == len(elements))
184+
assume(out_size == len(elements))
184185
ph.assert_array_elements("arange", out, xp.asarray(elements, dtype=_dtype))
185186
else:
186-
assume(out.size == size)
187-
if out.size > 0:
187+
assume(out_size == size)
188+
if out_size > 0:
188189
assert xp.equal(
189190
out[0], xp.asarray(_start, dtype=out.dtype)
190191
), f"out[0]={out[0]}, but should be {_start} {f_func}"
@@ -497,7 +498,8 @@ def test_meshgrid(dtype, data):
497498
for i, shape in enumerate(shapes, 1):
498499
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}")
499500
arrays.append(x)
500-
assert math.prod(x.size for x in arrays) <= hh.MAX_ARRAY_SIZE # sanity check
501+
# sanity check
502+
assert math.prod(math.prod(x.shape) for x in arrays) <= hh.MAX_ARRAY_SIZE
501503
out = xp.meshgrid(*arrays)
502504
for i, x in enumerate(out):
503505
ph.assert_dtype("meshgrid", dtype, x.dtype, repr_name=f"out[{i}].dtype")

array_api_tests/test_manipulation_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def test_concat(dtypes, base_shape, data):
9191
ph.assert_result_shape("concat", shapes, out.shape, shape, **kw)
9292

9393
if _axis is None:
94-
out_indices = (i for i in range(out.size))
94+
out_indices = (i for i in range(math.prod(out.shape)))
9595
for x_num, x in enumerate(arrays, 1):
9696
for x_idx in sh.ndindex(x.shape):
9797
out_i = next(out_indices)

array_api_tests/test_searching_functions.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import math
2+
13
import pytest
24
from hypothesis import given
35
from hypothesis import strategies as st
@@ -90,12 +92,14 @@ def test_nonzero(x):
9092
assert len(out) == 1, f"{len(out)=}, but should be 1 for 0-dimensional arrays"
9193
else:
9294
assert len(out) == x.ndim, f"{len(out)=}, but should be {x.ndim=}"
93-
size = out[0].size
95+
out_size = math.prod(out[0].shape)
9496
for i in range(len(out)):
9597
assert out[i].ndim == 1, f"out[{i}].ndim={x.ndim}, but should be 1"
96-
assert (
97-
out[i].size == size
98-
), f"out[{i}].size={x.size}, but should be out[0].size={size}"
98+
size_at = math.prod(out[i].shape)
99+
assert size_at == out_size, (
100+
f"prod(out[{i}].shape)={size_at}, "
101+
f"but should be prod(out[0].shape)={out_size}"
102+
)
99103
ph.assert_default_index("nonzero", out[i].dtype, repr_name=f"out[{i}].dtype")
100104
indices = []
101105
if x.dtype == xp.bool:
@@ -107,11 +111,11 @@ def test_nonzero(x):
107111
if x[idx] != 0:
108112
indices.append(idx)
109113
if x.ndim == 0:
110-
assert out[0].size == len(
114+
assert out_size == len(
111115
indices
112-
), f"{out[0].size=}, but should be {len(indices)}"
116+
), f"prod(out[0].shape)={out_size}, but should be {len(indices)}"
113117
else:
114-
for i in range(size):
118+
for i in range(out_size):
115119
idx = tuple(int(x[i]) for x in out)
116120
f_idx = f"Extrapolated index (x[{i}] for x in out)={idx}"
117121
f_element = f"x[{idx}]={x[idx]}"

array_api_tests/test_set_functions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def test_unique_all(x):
110110
vals_idx[val] = idx
111111

112112
if dh.is_float_dtype(out.values.dtype):
113-
assume(x.size <= 128) # may not be representable
113+
assume(math.prod(x.shape) <= 128) # may not be representable
114114
expected = sum(v for k, v in counts.items() if math.isnan(k))
115115
assert nans == expected, f"{nans} NaNs in out, but should be {expected}"
116116

@@ -157,7 +157,7 @@ def test_unique_counts(x):
157157
), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]"
158158
vals_idx[val] = idx
159159
if dh.is_float_dtype(out.values.dtype):
160-
assume(x.size <= 128) # may not be representable
160+
assume(math.prod(x.shape) <= 128) # may not be representable
161161
expected = sum(v for k, v in counts.items() if math.isnan(k))
162162
assert nans == expected, f"{nans} NaNs in out, but should be {expected}"
163163

@@ -210,7 +210,7 @@ def test_unique_inverse(x):
210210
else:
211211
assert val == expected, msg
212212
if dh.is_float_dtype(out.values.dtype):
213-
assume(x.size <= 128) # may not be representable
213+
assume(math.prod(x.shape) <= 128) # may not be representable
214214
expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8))
215215
assert nans == expected, f"{nans} NaNs in out.values, but should be {expected}"
216216

@@ -234,6 +234,6 @@ def test_unique_values(x):
234234
), f"out[{idx}]={val}, but {val} is also in out[{vals_idx[val]}]"
235235
vals_idx[val] = idx
236236
if dh.is_float_dtype(out.dtype):
237-
assume(x.size <= 128) # may not be representable
237+
assume(math.prod(x.shape) <= 128) # may not be representable
238238
expected = xp.sum(xp.astype(xp.isnan(x), xp.uint8))
239239
assert nans == expected, f"{nans} NaNs in out, but should be {expected}"

array_api_tests/test_statistical_functions.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]:
2222
dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype]
23+
if hh.FILTER_UNDEFINED_DTYPES:
24+
dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
2325
return st.none() | st.sampled_from(dtypes)
2426

2527

@@ -145,9 +147,13 @@ def test_prod(x, data):
145147
_dtype = dh.default_float
146148
else:
147149
_dtype = dtype
148-
# We ignore asserting the out dtype if what we expect is undefined
149-
# See https://github.com/data-apis/array-api-tests/issues/106
150-
if not isinstance(_dtype, _UndefinedStub):
150+
if isinstance(_dtype, _UndefinedStub):
151+
# If a default uint cannot exist (i.e. in PyTorch which doesn't support
152+
# uint32 or uint64), we skip testing the output dtype.
153+
# See https://github.com/data-apis/array-api-tests/issues/106
154+
if _dtype in dh.uint_dtypes:
155+
assert dh.is_int_dtype(out.dtype) # sanity check
156+
else:
151157
ph.assert_dtype("prod", x.dtype, out.dtype, _dtype)
152158
_axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
153159
ph.assert_keepdimable_shape(
@@ -173,7 +179,7 @@ def test_prod(x, data):
173179
dtype=xps.floating_dtypes(),
174180
shape=hh.shapes(min_side=1),
175181
elements={"allow_nan": False},
176-
).filter(lambda x: x.size >= 2),
182+
).filter(lambda x: math.prod(x.shape) >= 2),
177183
data=st.data(),
178184
)
179185
def test_std(x, data):
@@ -246,7 +252,14 @@ def test_sum(x, data):
246252
_dtype = dh.default_float
247253
else:
248254
_dtype = dtype
249-
ph.assert_dtype("sum", x.dtype, out.dtype, _dtype)
255+
if isinstance(_dtype, _UndefinedStub):
256+
# If a default uint cannot exist (i.e. in PyTorch which doesn't support
257+
# uint32 or uint64), we skip testing the output dtype.
258+
# See https://github.com/data-apis/array-api-tests/issues/160
259+
if _dtype in dh.uint_dtypes:
260+
assert dh.is_int_dtype(out.dtype) # sanity check
261+
else:
262+
ph.assert_dtype("sum", x.dtype, out.dtype, _dtype)
250263
_axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
251264
ph.assert_keepdimable_shape(
252265
"sum", x.shape, out.shape, _axes, kw.get("keepdims", False), **kw
@@ -271,7 +284,7 @@ def test_sum(x, data):
271284
dtype=xps.floating_dtypes(),
272285
shape=hh.shapes(min_side=1),
273286
elements={"allow_nan": False},
274-
).filter(lambda x: x.size >= 2),
287+
).filter(lambda x: math.prod(x.shape) >= 2),
275288
data=st.data(),
276289
)
277290
def test_var(x, data):

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
pytest
22
pytest-json-report
3-
hypothesis>=6.55.0
3+
hypothesis>=6.62.1
44
ndindex>=1.6

0 commit comments

Comments
 (0)