Skip to content

Commit ff3fed4

Browse files
committed
Test non-0d-resulting keys in test_setitem
Also `ph.assert_array()` -> `ph.assert_array_elements()`
1 parent 34fda92 commit ff3fed4

File tree

5 files changed

+65
-47
lines changed

5 files changed

+65
-47
lines changed

array_api_tests/meta/test_pytest_helpers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ def test_assert_dtype():
1313
ph.assert_dtype("single_bool_func", [xp.uint8], xp.bool, xp.bool)
1414

1515

16-
def test_assert_array():
17-
ph.assert_array("int zeros", xp.asarray(0), xp.asarray(0))
18-
ph.assert_array("pos zeros", xp.asarray(0.0), xp.asarray(0.0))
16+
def test_assert_array_elements():
17+
ph.assert_array_elements("int zeros", xp.asarray(0), xp.asarray(0))
18+
ph.assert_array_elements("pos zeros", xp.asarray(0.0), xp.asarray(0.0))
1919
with raises(AssertionError):
20-
ph.assert_array("mixed sign zeros", xp.asarray(0.0), xp.asarray(-0.0))
20+
ph.assert_array_elements("mixed sign zeros", xp.asarray(0.0), xp.asarray(-0.0))
2121
with raises(AssertionError):
22-
ph.assert_array("mixed sign zeros", xp.asarray(-0.0), xp.asarray(0.0))
22+
ph.assert_array_elements("mixed sign zeros", xp.asarray(-0.0), xp.asarray(0.0))

array_api_tests/pytest_helpers.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
"assert_keepdimable_shape",
2626
"assert_0d_equals",
2727
"assert_fill",
28-
"assert_array",
28+
"assert_array_elements",
2929
]
3030

3131

@@ -374,28 +374,30 @@ def assert_fill(
374374
assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg
375375

376376

377-
def assert_array(func_name: str, out: Array, expected: Array, /, **kw):
377+
def assert_array_elements(
378+
func_name: str, out: Array, expected: Array, /, *, out_repr: str = "out", **kw
379+
):
378380
"""
379-
Assert array is (strictly) as expected, e.g.
381+
Assert array elements are (strictly) as expected, e.g.
380382
381383
>>> x = xp.arange(5)
382384
>>> out = xp.asarray(x)
383-
>>> assert_array('asarray', out, x)
385+
>>> assert_array_elements('asarray', out, x)
384386
385387
is equivalent to
386388
387389
>>> assert xp.all(out == x)
388390
389391
"""
390-
assert_dtype(func_name, out.dtype, expected.dtype)
391-
assert_shape(func_name, out.shape, expected.shape, **kw)
392+
dh.result_type(out.dtype, expected.dtype) # sanity check
393+
assert_shape(func_name, out.shape, expected.shape, **kw) # sanity check
392394
f_func = f"[{func_name}({fmt_kw(kw)})]"
393395
if dh.is_float_dtype(out.dtype):
394396
for idx in sh.ndindex(out.shape):
395397
at_out = out[idx]
396398
at_expected = expected[idx]
397399
msg = (
398-
f"{sh.fmt_idx('out', idx)}={at_out}, should be {at_expected} "
400+
f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} "
399401
f"{f_func}"
400402
)
401403
if xp.isnan(at_expected):
@@ -411,6 +413,6 @@ def assert_array(func_name: str, out: Array, expected: Array, /, **kw):
411413
else:
412414
assert at_out == at_expected, msg
413415
else:
414-
assert xp.all(out == expected), (
415-
f"out not as expected {f_func}\n" f"{out=}\n{expected=}"
416-
)
416+
assert xp.all(
417+
out == expected
418+
), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}"

array_api_tests/test_array_object.py

Lines changed: 44 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,11 @@ def scalar_objects(dtype: DataType, shape: Shape) -> st.SearchStrategy[List[Scal
2525
)
2626

2727

28-
@given(hh.shapes(), st.data())
29-
def test_getitem(shape, data):
30-
dtype = data.draw(xps.scalar_dtypes(), label="dtype")
28+
@given(shape=hh.shapes(), dtype=xps.scalar_dtypes(), data=st.data())
29+
def test_getitem(shape, dtype, data):
3130
zero_sided = any(side == 0 for side in shape)
3231
if zero_sided:
33-
x = xp.ones(shape, dtype=dtype)
32+
x = xp.zeros(shape, dtype=dtype)
3433
else:
3534
obj = data.draw(scalar_objects(dtype, shape), label="obj")
3635
x = xp.asarray(obj, dtype=dtype)
@@ -76,45 +75,62 @@ def test_getitem(shape, data):
7675
out_obj.append(val)
7776
out_obj = sh.reshape(out_obj, out_shape)
7877
expected = xp.asarray(out_obj, dtype=dtype)
79-
ph.assert_array("__getitem__", out, expected)
78+
ph.assert_array_elements("__getitem__", out, expected)
8079

8180

82-
@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays
83-
def test_setitem(shape, data):
84-
dtype = data.draw(xps.scalar_dtypes(), label="dtype")
85-
obj = data.draw(scalar_objects(dtype, shape), label="obj")
86-
x = xp.asarray(obj, dtype=dtype)
81+
@given(shape=hh.shapes(min_side=1), dtype=xps.scalar_dtypes(), data=st.data())
82+
def test_setitem(shape, dtype, data):
83+
zero_sided = any(side == 0 for side in shape)
84+
if zero_sided:
85+
x = xp.zeros(shape, dtype=dtype)
86+
else:
87+
obj = data.draw(scalar_objects(dtype, shape), label="obj")
88+
x = xp.asarray(obj, dtype=dtype)
8789
note(f"{x=}")
88-
# TODO: test setting non-0d arrays
89-
key = data.draw(xps.indices(shape=shape, max_dims=0), label="key")
90-
value = data.draw(
91-
xps.from_dtype(dtype) | xps.arrays(dtype=dtype, shape=()), label="value"
92-
)
90+
key = data.draw(xps.indices(shape=shape), label="key")
91+
_key = tuple(key) if isinstance(key, tuple) else (key,)
92+
if Ellipsis in _key:
93+
nonexpanding_key = tuple(i for i in _key if i is not None)
94+
start_a = nonexpanding_key.index(Ellipsis)
95+
stop_a = start_a + (len(shape) - (len(nonexpanding_key) - 1))
96+
slices = tuple(slice(None) for _ in range(start_a, stop_a))
97+
start_pos = _key.index(Ellipsis)
98+
_key = _key[:start_pos] + slices + _key[start_pos + 1 :]
99+
out_shape = []
100+
for a, i in enumerate(_key):
101+
if isinstance(i, slice):
102+
side = shape[a]
103+
indices = range(side)[i]
104+
out_shape.append(len(indices))
105+
out_shape = tuple(out_shape)
106+
value_strat = xps.arrays(dtype=dtype, shape=out_shape)
107+
if out_shape == ():
108+
# We can pass scalars if we're only indexing one element
109+
value_strat |= xps.from_dtype(dtype)
110+
value = data.draw(value_strat, label="value")
93111

94112
res = xp.asarray(x, copy=True)
95113
res[key] = value
96114

97115
ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype")
98116
ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.shape")
117+
f_res = f"res[{sh.fmt_idx('x', key)}]"
99118
if isinstance(value, get_args(Scalar)):
100-
msg = f"x[{key}]={res[key]!r}, but should be {value=} [__setitem__()]"
119+
msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]"
101120
if math.isnan(value):
102121
assert xp.isnan(res[key]), msg
103122
else:
104123
assert res[key] == value, msg
105124
else:
106-
ph.assert_0d_equals(
107-
"__setitem__", "value", value, f"modified x[{key}]", res[key]
108-
)
109-
_key = key if isinstance(key, tuple) else (key,)
110-
assume(all(isinstance(i, int) for i in _key)) # TODO: normalise slices and ellipsis
111-
_key = tuple(i if i >= 0 else s + i for i, s in zip(_key, x.shape))
112-
unaffected_indices = list(sh.ndindex(res.shape))
113-
unaffected_indices.remove(_key)
114-
for idx in unaffected_indices:
115-
ph.assert_0d_equals(
116-
"__setitem__", f"old x[{idx}]", x[idx], f"modified x[{idx}]", res[idx]
117-
)
125+
ph.assert_array_elements("__setitem__", res[key], value, out_repr=f_res)
126+
if all(isinstance(i, int) for i in _key): # TODO: normalise slices and ellipsis
127+
_key = tuple(i if i >= 0 else s + i for i, s in zip(_key, x.shape))
128+
unaffected_indices = list(sh.ndindex(res.shape))
129+
unaffected_indices.remove(_key)
130+
for idx in unaffected_indices:
131+
ph.assert_0d_equals(
132+
"__setitem__", f"old x[{idx}]", x[idx], f"modified x[{idx}]", res[idx]
133+
)
118134

119135

120136
@pytest.mark.data_dependent_shapes

array_api_tests/test_creation_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def test_arange(dtype, data):
180180
if dh.is_int_dtype(_dtype):
181181
elements = list(r)
182182
assume(out.size == len(elements))
183-
ph.assert_array("arange", out, xp.asarray(elements, dtype=_dtype))
183+
ph.assert_array_elements("arange", out, xp.asarray(elements, dtype=_dtype))
184184
else:
185185
assume(out.size == size)
186186
if out.size > 0:
@@ -262,7 +262,7 @@ def test_asarray_arrays(x, data):
262262
ph.assert_kw_dtype("asarray", dtype, out.dtype)
263263
ph.assert_shape("asarray", out.shape, x.shape)
264264
if dtype is None or dtype == x.dtype:
265-
ph.assert_array("asarray", out, x, **kw)
265+
ph.assert_array_elements("asarray", out, x, **kw)
266266
else:
267267
pass # TODO
268268
copy = kw.get("copy", None)
@@ -452,7 +452,7 @@ def test_linspace(num, dtype, endpoint, data):
452452
# the first num elements when endpoint=False
453453
expected = xp.linspace(start, stop, num + 1, dtype=dtype, endpoint=True)
454454
expected = expected[:-1]
455-
ph.assert_array("linspace", out, expected)
455+
ph.assert_array_elements("linspace", out, expected)
456456

457457

458458
@given(dtype=xps.numeric_dtypes(), data=st.data())

array_api_tests/test_operators_and_elementwise_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1124,7 +1124,7 @@ def test_positive(ctx, data):
11241124

11251125
ph.assert_dtype(ctx.func_name, x.dtype, out.dtype)
11261126
ph.assert_shape(ctx.func_name, out.shape, x.shape)
1127-
ph.assert_array(ctx.func_name, out, x)
1127+
ph.assert_array_elements(ctx.func_name, out, x)
11281128

11291129

11301130
@pytest.mark.parametrize("ctx", make_binary_params("pow", dh.numeric_dtypes))

0 commit comments

Comments
 (0)