Skip to content

Commit 7feaa28

Browse files
committed
Test different dtypes in test_setitem
1 parent ff3fed4 commit 7feaa28

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

array_api_tests/test_array_object.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from . import pytest_helpers as ph
1313
from . import shape_helpers as sh
1414
from . import xps
15+
from .test_operators_and_elementwise_functions import oneway_promotable_dtypes
1516
from .typing import DataType, Param, Scalar, ScalarType, Shape
1617

1718
pytestmark = pytest.mark.ci
@@ -78,14 +79,18 @@ def test_getitem(shape, dtype, data):
7879
ph.assert_array_elements("__getitem__", out, expected)
7980

8081

81-
@given(shape=hh.shapes(min_side=1), dtype=xps.scalar_dtypes(), data=st.data())
82-
def test_setitem(shape, dtype, data):
82+
@given(
83+
shape=hh.shapes(),
84+
dtypes=oneway_promotable_dtypes(dh.all_dtypes),
85+
data=st.data(),
86+
)
87+
def test_setitem(shape, dtypes, data):
8388
zero_sided = any(side == 0 for side in shape)
8489
if zero_sided:
85-
x = xp.zeros(shape, dtype=dtype)
90+
x = xp.zeros(shape, dtype=dtypes.result_dtype)
8691
else:
87-
obj = data.draw(scalar_objects(dtype, shape), label="obj")
88-
x = xp.asarray(obj, dtype=dtype)
92+
obj = data.draw(scalar_objects(dtypes.result_dtype, shape), label="obj")
93+
x = xp.asarray(obj, dtype=dtypes.result_dtype)
8994
note(f"{x=}")
9095
key = data.draw(xps.indices(shape=shape), label="key")
9196
_key = tuple(key) if isinstance(key, tuple) else (key,)
@@ -103,10 +108,10 @@ def test_setitem(shape, dtype, data):
103108
indices = range(side)[i]
104109
out_shape.append(len(indices))
105110
out_shape = tuple(out_shape)
106-
value_strat = xps.arrays(dtype=dtype, shape=out_shape)
111+
value_strat = xps.arrays(dtype=dtypes.result_dtype, shape=out_shape)
107112
if out_shape == ():
108113
# We can pass scalars if we're only indexing one element
109-
value_strat |= xps.from_dtype(dtype)
114+
value_strat |= xps.from_dtype(dtypes.result_dtype)
110115
value = data.draw(value_strat, label="value")
111116

112117
res = xp.asarray(x, copy=True)

0 commit comments

Comments
 (0)