Skip to content

Commit dd83d68

Browse files
committed
Fix test_meshgrid generating multiple dtypes
1 parent 91026c0 commit dd83d68

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

array_api_tests/test_creation_functions.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -479,27 +479,26 @@ def test_linspace(num, dtype, endpoint, data):
479479
ah.assert_exactly_equal(out, expected)
480480

481481

482-
@given(
482+
@given(dtype=xps.numeric_dtypes(), data=st.data())
483+
def test_meshgrid(dtype, data):
483484
# The number and size of generated arrays is arbitrarily limited to prevent
484485
# meshgrid() running out of memory.
485-
dtypes=hh.mutually_promotable_dtypes(5, dtypes=dh.numeric_dtypes),
486-
data=st.data(),
487-
)
488-
def test_meshgrid(dtypes, data):
489-
arrays = []
490486
shapes = data.draw(
491-
hh.mutually_broadcastable_shapes(
492-
len(dtypes), min_dims=1, max_dims=1, max_side=5
487+
st.integers(1, 5).flatmap(
488+
lambda n: hh.mutually_broadcastable_shapes(
489+
n, min_dims=1, max_dims=1, max_side=5
490+
)
493491
),
494492
label="shapes",
495493
)
496-
for i, (dtype, shape) in enumerate(zip(dtypes, shapes), 1):
494+
arrays = []
495+
for i, shape in enumerate(shapes, 1):
497496
x = data.draw(xps.arrays(dtype=dtype, shape=shape), label=f"x{i}")
498497
arrays.append(x)
499498
assert math.prod(x.size for x in arrays) <= hh.MAX_ARRAY_SIZE # sanity check
500499
out = xp.meshgrid(*arrays)
501500
for i, x in enumerate(out):
502-
ph.assert_dtype("meshgrid", dtypes, x.dtype, repr_name=f"out[{i}].dtype")
501+
ph.assert_dtype("meshgrid", dtype, x.dtype, repr_name=f"out[{i}].dtype")
503502

504503

505504
def make_one(dtype: DataType) -> Scalar:

0 commit comments

Comments
 (0)