@@ -107,7 +107,50 @@ def test_setitem(shape, data):
107
107
ph .assert_0d_equals ("__setitem__" , "value" , value , f"x[{ key } ]" , res [key ])
108
108
109
109
110
- # TODO: test boolean indexing
110
+ # TODO: make mask tests optional
111
+
112
+
113
+ @given (hh .shapes (), st .data ())
114
+ def test_getitem_mask (shape , data ):
115
+ x = data .draw (xps .arrays (xps .scalar_dtypes (), shape = shape ), label = "x" )
116
+ mask_shapes = st .one_of (
117
+ st .sampled_from ([x .shape , ()]),
118
+ st .lists (st .booleans (), min_size = x .ndim , max_size = x .ndim ).map (
119
+ lambda l : tuple (s if b else 0 for s , b in zip (x .shape , l ))
120
+ ),
121
+ hh .shapes (),
122
+ )
123
+ key = data .draw (xps .arrays (dtype = xp .bool , shape = mask_shapes ), label = "key" )
124
+
125
+ if key .ndim > x .ndim or not all (
126
+ ks in (xs , 0 ) for xs , ks in zip (x .shape , key .shape )
127
+ ):
128
+ with pytest .raises (IndexError ):
129
+ x [key ]
130
+ return
131
+
132
+ out = x [key ]
133
+
134
+ ph .assert_dtype ("__getitem__" , x .dtype , out .dtype )
135
+ if key .ndim == 0 :
136
+ out_shape = (1 ,) if key else (0 ,)
137
+ out_shape += x .shape
138
+ else :
139
+ size = int (xp .sum (xp .astype (key , xp .uint8 )))
140
+ out_shape = (size ,) + x .shape [key .ndim :]
141
+ ph .assert_shape ("__getitem__" , out .shape , out_shape )
142
+
143
+
144
+ @given (hh .shapes (min_side = 1 ), st .data ())
145
+ def test_setitem_mask (shape , data ):
146
+ x = data .draw (xps .arrays (xps .scalar_dtypes (), shape = shape ), label = "x" )
147
+ key = data .draw (xps .arrays (dtype = xp .bool , shape = shape ), label = "key" )
148
+ value = data .draw (xps .from_dtype (x .dtype ), label = "value" )
149
+
150
+ res = xp .asarray (x , copy = True )
151
+ res [key ] = value
152
+
153
+ # TODO
111
154
112
155
113
156
def make_param (method_name : str , dtype : DataType , stype : ScalarType ) -> Param :
0 commit comments