@@ -142,6 +142,35 @@ def test_isin_strided_bool():
142
142
assert r2 .shape == x_s .shape
143
143
144
144
145
+ @pytest .mark .parametrize ("dt1" , _numeric_dtypes )
146
+ @pytest .mark .parametrize ("dt2" , _numeric_dtypes )
147
+ def test_isin_dtype_matrix (dt1 , dt2 ):
148
+ q = get_queue_or_skip ()
149
+ skip_if_dtype_not_supported (dt1 , q )
150
+ skip_if_dtype_not_supported (dt2 , q )
151
+
152
+ sz = 10
153
+ x = dpt .asarray ([0 , 1 , 11 ], dtype = dt1 , sycl_queue = q )
154
+ test1 = dpt .arange (sz , dtype = dt2 , sycl_queue = q )
155
+
156
+ r1 = dpt .isin (x , test1 )
157
+ assert isinstance (r1 , dpt .usm_ndarray )
158
+ assert r1 .dtype == dpt .bool
159
+ assert r1 .shape == x .shape
160
+ assert not r1 [- 1 ]
161
+ assert dpt .all (r1 [0 :- 1 ])
162
+ assert r1 .sycl_queue == x .sycl_queue
163
+
164
+ test2 = dpt .tile (dpt .asarray ([[0 , 1 ]], dtype = dt2 , sycl_queue = q ).mT , 2 )
165
+ r2 = dpt .isin (x , test2 )
166
+ assert isinstance (r2 , dpt .usm_ndarray )
167
+ assert r2 .dtype == dpt .bool
168
+ assert r2 .shape == x .shape
169
+ assert not r2 [- 1 ]
170
+ assert dpt .all (r1 [0 :- 1 ])
171
+ assert r2 .sycl_queue == x .sycl_queue
172
+
173
+
145
174
def test_isin_empty_inputs ():
146
175
get_queue_or_skip ()
147
176
0 commit comments