@@ -25,11 +25,15 @@ def scalar_objects(dtype: DataType, shape: Shape) -> st.SearchStrategy[List[Scal
25
25
)
26
26
27
27
28
- @given (hh .shapes (min_side = 1 ), st .data ()) # TODO: test 0-sided arrays
28
+ @given (hh .shapes (), st .data ())
29
29
def test_getitem (shape , data ):
30
30
dtype = data .draw (xps .scalar_dtypes (), label = "dtype" )
31
- obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
32
- x = xp .asarray (obj , dtype = dtype )
31
+ zero_sided = any (side == 0 for side in shape )
32
+ if zero_sided :
33
+ x = xp .ones (shape , dtype = dtype )
34
+ else :
35
+ obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
36
+ x = xp .asarray (obj , dtype = dtype )
33
37
note (f"{ x = } " )
34
38
key = data .draw (xps .indices (shape = shape , allow_newaxis = True ), label = "key" )
35
39
@@ -62,16 +66,17 @@ def test_getitem(shape, data):
62
66
a += 1
63
67
out_shape = tuple (out_shape )
64
68
ph .assert_shape ("__getitem__" , out .shape , out_shape )
65
- assume (all (len (indices ) > 0 for indices in axes_indices ))
66
- out_obj = []
67
- for idx in product (* axes_indices ):
68
- val = obj
69
- for i in idx :
70
- val = val [i ]
71
- out_obj .append (val )
72
- out_obj = sh .reshape (out_obj , out_shape )
73
- expected = xp .asarray (out_obj , dtype = dtype )
74
- ph .assert_array ("__getitem__" , out , expected )
69
+ out_zero_sided = any (side == 0 for side in out_shape )
70
+ if not zero_sided and not out_zero_sided :
71
+ out_obj = []
72
+ for idx in product (* axes_indices ):
73
+ val = obj
74
+ for i in idx :
75
+ val = val [i ]
76
+ out_obj .append (val )
77
+ out_obj = sh .reshape (out_obj , out_shape )
78
+ expected = xp .asarray (out_obj , dtype = dtype )
79
+ ph .assert_array ("__getitem__" , out , expected )
75
80
76
81
77
82
@given (hh .shapes (min_side = 1 ), st .data ()) # TODO: test 0-sided arrays
0 commit comments