@@ -25,12 +25,11 @@ def scalar_objects(dtype: DataType, shape: Shape) -> st.SearchStrategy[List[Scal
25
25
)
26
26
27
27
28
- @given (hh .shapes (), st .data ())
29
- def test_getitem (shape , data ):
30
- dtype = data .draw (xps .scalar_dtypes (), label = "dtype" )
28
+ @given (shape = hh .shapes (), dtype = xps .scalar_dtypes (), data = st .data ())
29
+ def test_getitem (shape , dtype , data ):
31
30
zero_sided = any (side == 0 for side in shape )
32
31
if zero_sided :
33
- x = xp .ones (shape , dtype = dtype )
32
+ x = xp .zeros (shape , dtype = dtype )
34
33
else :
35
34
obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
36
35
x = xp .asarray (obj , dtype = dtype )
@@ -76,45 +75,62 @@ def test_getitem(shape, data):
76
75
out_obj .append (val )
77
76
out_obj = sh .reshape (out_obj , out_shape )
78
77
expected = xp .asarray (out_obj , dtype = dtype )
79
- ph .assert_array ("__getitem__" , out , expected )
78
+ ph .assert_array_elements ("__getitem__" , out , expected )
80
79
81
80
82
- @given (hh .shapes (min_side = 1 ), st .data ()) # TODO: test 0-sided arrays
83
- def test_setitem (shape , data ):
84
- dtype = data .draw (xps .scalar_dtypes (), label = "dtype" )
85
- obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
86
- x = xp .asarray (obj , dtype = dtype )
81
+ @given (shape = hh .shapes (min_side = 1 ), dtype = xps .scalar_dtypes (), data = st .data ())
82
+ def test_setitem (shape , dtype , data ):
83
+ zero_sided = any (side == 0 for side in shape )
84
+ if zero_sided :
85
+ x = xp .zeros (shape , dtype = dtype )
86
+ else :
87
+ obj = data .draw (scalar_objects (dtype , shape ), label = "obj" )
88
+ x = xp .asarray (obj , dtype = dtype )
87
89
note (f"{ x = } " )
88
- # TODO: test setting non-0d arrays
89
- key = data .draw (xps .indices (shape = shape , max_dims = 0 ), label = "key" )
90
- value = data .draw (
91
- xps .from_dtype (dtype ) | xps .arrays (dtype = dtype , shape = ()), label = "value"
92
- )
90
+ key = data .draw (xps .indices (shape = shape ), label = "key" )
91
+ _key = tuple (key ) if isinstance (key , tuple ) else (key ,)
92
+ if Ellipsis in _key :
93
+ nonexpanding_key = tuple (i for i in _key if i is not None )
94
+ start_a = nonexpanding_key .index (Ellipsis )
95
+ stop_a = start_a + (len (shape ) - (len (nonexpanding_key ) - 1 ))
96
+ slices = tuple (slice (None ) for _ in range (start_a , stop_a ))
97
+ start_pos = _key .index (Ellipsis )
98
+ _key = _key [:start_pos ] + slices + _key [start_pos + 1 :]
99
+ out_shape = []
100
+ for a , i in enumerate (_key ):
101
+ if isinstance (i , slice ):
102
+ side = shape [a ]
103
+ indices = range (side )[i ]
104
+ out_shape .append (len (indices ))
105
+ out_shape = tuple (out_shape )
106
+ value_strat = xps .arrays (dtype = dtype , shape = out_shape )
107
+ if out_shape == ():
108
+ # We can pass scalars if we're only indexing one element
109
+ value_strat |= xps .from_dtype (dtype )
110
+ value = data .draw (value_strat , label = "value" )
93
111
94
112
res = xp .asarray (x , copy = True )
95
113
res [key ] = value
96
114
97
115
ph .assert_dtype ("__setitem__" , x .dtype , res .dtype , repr_name = "x.dtype" )
98
116
ph .assert_shape ("__setitem__" , res .shape , x .shape , repr_name = "x.shape" )
117
+ f_res = f"res[{ sh .fmt_idx ('x' , key )} ]"
99
118
if isinstance (value , get_args (Scalar )):
100
- msg = f"x[ { key } ] ={ res [key ]!r} , but should be { value = } [__setitem__()]"
119
+ msg = f"{ f_res } ={ res [key ]!r} , but should be { value = } [__setitem__()]"
101
120
if math .isnan (value ):
102
121
assert xp .isnan (res [key ]), msg
103
122
else :
104
123
assert res [key ] == value , msg
105
124
else :
106
- ph .assert_0d_equals (
107
- "__setitem__" , "value" , value , f"modified x[{ key } ]" , res [key ]
108
- )
109
- _key = key if isinstance (key , tuple ) else (key ,)
110
- assume (all (isinstance (i , int ) for i in _key )) # TODO: normalise slices and ellipsis
111
- _key = tuple (i if i >= 0 else s + i for i , s in zip (_key , x .shape ))
112
- unaffected_indices = list (sh .ndindex (res .shape ))
113
- unaffected_indices .remove (_key )
114
- for idx in unaffected_indices :
115
- ph .assert_0d_equals (
116
- "__setitem__" , f"old x[{ idx } ]" , x [idx ], f"modified x[{ idx } ]" , res [idx ]
117
- )
125
+ ph .assert_array_elements ("__setitem__" , res [key ], value , out_repr = f_res )
126
+ if all (isinstance (i , int ) for i in _key ): # TODO: normalise slices and ellipsis
127
+ _key = tuple (i if i >= 0 else s + i for i , s in zip (_key , x .shape ))
128
+ unaffected_indices = list (sh .ndindex (res .shape ))
129
+ unaffected_indices .remove (_key )
130
+ for idx in unaffected_indices :
131
+ ph .assert_0d_equals (
132
+ "__setitem__" , f"old x[{ idx } ]" , x [idx ], f"modified x[{ idx } ]" , res [idx ]
133
+ )
118
134
119
135
120
136
@pytest .mark .data_dependent_shapes
0 commit comments