Skip to content

Commit 9ee8381

Browse files
committed
Basic mask tests
1 parent 081d700 commit 9ee8381

File tree

1 file changed

+44
-1
lines changed

1 file changed

+44
-1
lines changed

array_api_tests/test_array_object.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,50 @@ def test_setitem(shape, data):
107107
ph.assert_0d_equals("__setitem__", "value", value, f"x[{key}]", res[key])
108108

109109

110-
# TODO: test boolean indexing
110+
# TODO: make mask tests optional
111+
112+
113+
@given(hh.shapes(), st.data())
114+
def test_getitem_mask(shape, data):
115+
x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x")
116+
mask_shapes = st.one_of(
117+
st.sampled_from([x.shape, ()]),
118+
st.lists(st.booleans(), min_size=x.ndim, max_size=x.ndim).map(
119+
lambda l: tuple(s if b else 0 for s, b in zip(x.shape, l))
120+
),
121+
hh.shapes(),
122+
)
123+
key = data.draw(xps.arrays(dtype=xp.bool, shape=mask_shapes), label="key")
124+
125+
if key.ndim > x.ndim or not all(
126+
ks in (xs, 0) for xs, ks in zip(x.shape, key.shape)
127+
):
128+
with pytest.raises(IndexError):
129+
x[key]
130+
return
131+
132+
out = x[key]
133+
134+
ph.assert_dtype("__getitem__", x.dtype, out.dtype)
135+
if key.ndim == 0:
136+
out_shape = (1,) if key else (0,)
137+
out_shape += x.shape
138+
else:
139+
size = int(xp.sum(xp.astype(key, xp.uint8)))
140+
out_shape = (size,) + x.shape[key.ndim :]
141+
ph.assert_shape("__getitem__", out.shape, out_shape)
142+
143+
144+
@given(hh.shapes(min_side=1), st.data())
145+
def test_setitem_mask(shape, data):
146+
x = data.draw(xps.arrays(xps.scalar_dtypes(), shape=shape), label="x")
147+
key = data.draw(xps.arrays(dtype=xp.bool, shape=shape), label="key")
148+
value = data.draw(xps.from_dtype(x.dtype), label="value")
149+
150+
res = xp.asarray(x, copy=True)
151+
res[key] = value
152+
153+
# TODO
111154

112155

113156
def make_param(method_name: str, dtype: DataType, stype: ScalarType) -> Param:

0 commit comments

Comments
 (0)