@@ -479,27 +479,26 @@ def test_linspace(num, dtype, endpoint, data):
479
479
ah .assert_exactly_equal (out , expected )
480
480
481
481
482
- @given (
482
+ @given (dtype = xps .numeric_dtypes (), data = st .data ())
483
+ def test_meshgrid (dtype , data ):
483
484
# The number and size of generated arrays is arbitrarily limited to prevent
484
485
# meshgrid() running out of memory.
485
- dtypes = hh .mutually_promotable_dtypes (5 , dtypes = dh .numeric_dtypes ),
486
- data = st .data (),
487
- )
488
- def test_meshgrid (dtypes , data ):
489
- arrays = []
490
486
shapes = data .draw (
491
- hh .mutually_broadcastable_shapes (
492
- len (dtypes ), min_dims = 1 , max_dims = 1 , max_side = 5
487
+ st .integers (1 , 5 ).flatmap (
488
+ lambda n : hh .mutually_broadcastable_shapes (
489
+ n , min_dims = 1 , max_dims = 1 , max_side = 5
490
+ )
493
491
),
494
492
label = "shapes" ,
495
493
)
496
- for i , (dtype , shape ) in enumerate (zip (dtypes , shapes ), 1 ):
494
+ arrays = []
495
+ for i , shape in enumerate (shapes , 1 ):
497
496
x = data .draw (xps .arrays (dtype = dtype , shape = shape ), label = f"x{ i } " )
498
497
arrays .append (x )
499
498
assert math .prod (x .size for x in arrays ) <= hh .MAX_ARRAY_SIZE # sanity check
500
499
out = xp .meshgrid (* arrays )
501
500
for i , x in enumerate (out ):
502
- ph .assert_dtype ("meshgrid" , dtypes , x .dtype , repr_name = f"out[{ i } ].dtype" )
501
+ ph .assert_dtype ("meshgrid" , dtype , x .dtype , repr_name = f"out[{ i } ].dtype" )
503
502
504
503
505
504
def make_one (dtype : DataType ) -> Scalar :
0 commit comments