Skip to content

Commit 85b3fba

Browse files
committed
Address review comments for isin tests
1 parent cf28258 commit 85b3fba

File tree

1 file changed

+33
-26
lines changed

1 file changed

+33
-26
lines changed

dpctl/tests/test_tensor_isin.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,16 @@ def test_isin_basic(dtype):
4444
skip_if_dtype_not_supported(dtype, q)
4545

4646
n = 100
47-
x = dpt.arange(n, dtype=dtype)
48-
test = dpt.arange(n - 1, dtype=dtype)
47+
x = dpt.arange(n, dtype=dtype, sycl_queue=q)
48+
test = dpt.arange(n - 1, dtype=dtype, sycl_queue=q)
4949
r1 = dpt.isin(x, test)
5050
assert dpt.all(r1[:-1])
5151
assert not r1[-1]
5252
assert r1.shape == x.shape
5353

5454
# test with invert keyword
5555
r2 = dpt.isin(x, test, invert=True)
56-
assert not dpt.all(r2[:-1])
56+
assert not dpt.any(r2[:-1])
5757
assert r2[-1]
5858
assert r2.shape == x.shape
5959

@@ -70,7 +70,7 @@ def test_isin_basic_bool():
7070
assert r1.shape == x.shape
7171

7272
r2 = dpt.isin(x, test, invert=True)
73-
assert not dpt.all(r2[:-1])
73+
assert not dpt.any(r2[:-1])
7474
assert r2[-1]
7575
assert r2.shape == x.shape
7676

@@ -98,37 +98,44 @@ def test_isin_strided(dtype):
9898
skip_if_dtype_not_supported(dtype, q)
9999

100100
n, m = 100, 20
101-
x = dpt.zeros((n, m), dtype=dtype, order="F")
102-
x[:, ::2] = dpt.arange(1, (m / 2) + 1, dtype=dtype)
103-
test = dpt.arange(1, (m / 2) + 1, dtype=dtype)
104-
r1 = dpt.isin(x, test)
105-
assert dpt.all(r1[:, ::2])
106-
assert not dpt.all(r1[:, 1::2])
107-
assert r1.shape == x.shape
101+
x = dpt.zeros((n, m), dtype=dtype, order="F", sycl_queue=q)
102+
x[:, ::2] = dpt.arange(1, (m / 2) + 1, dtype=dtype, sycl_queue=q)
103+
x_s = x[:, ::2]
104+
test = dpt.arange(1, (m / 2), dtype=dtype, sycl_queue=q)
105+
r1 = dpt.isin(x_s, test)
106+
assert dpt.all(r1[:, :-1])
107+
assert not dpt.any(r1[:, -1])
108+
assert not dpt.any(x[:, 1::2])
109+
assert r1.shape == x_s.shape
108110

109111
# test with invert keyword
110-
r2 = dpt.isin(x, test, invert=True)
111-
assert not dpt.all(r2[:, ::2])
112-
assert dpt.all(r2[:, 1::2])
113-
assert r2.shape == x.shape
112+
r2 = dpt.isin(x_s, test, invert=True)
113+
assert not dpt.any(r2[:, :-1])
114+
assert dpt.all(r2[:, -1])
115+
assert not dpt.any(x[:, 1:2])
116+
assert r2.shape == x_s.shape
114117

115118

116119
def test_isin_strided_bool():
117120
dt = dpt.bool
121+
118122
n, m = 100, 20
119-
x = dpt.ones((n, m), dtype=dt, order="F")
120-
x[:, ::2] = False
121-
test = dpt.zeros((), dtype=dt)
122-
r1 = dpt.isin(x, test)
123-
assert dpt.all(r1[:, ::2])
124-
assert not dpt.all(r1[:, 1::2])
125-
assert r1.shape == x.shape
123+
x = dpt.zeros((n, m), dtype=dt, order="F")
124+
x[:, :-2:2] = True
125+
x_s = x[:, ::2]
126+
test = dpt.ones((), dtype=dt)
127+
r1 = dpt.isin(x_s, test)
128+
assert dpt.all(r1[:, :-1])
129+
assert not dpt.any(r1[:, -1])
130+
assert not dpt.any(x[:, 1::2])
131+
assert r1.shape == x_s.shape
126132

127133
# test with invert keyword
128-
r2 = dpt.isin(x, test, invert=True)
129-
assert not dpt.all(r2[:, ::2])
130-
assert dpt.all(r2[:, 1::2])
131-
assert r2.shape == x.shape
134+
r2 = dpt.isin(x_s, test, invert=True)
135+
assert not dpt.any(r2[:, :-1])
136+
assert dpt.all(r2[:, -1])
137+
assert not dpt.any(x[:, 1:2])
138+
assert r2.shape == x_s.shape
132139

133140

134141
def test_isin_empty_inputs():

0 commit comments

Comments
 (0)