Skip to content

Commit 05ae945

Browse files
Fixed tests to run on Iris Xe
Also ensure that test_add_order exercises non-same dtypes to improve coverage.
1 parent 121f819 commit 05ae945

File tree

1 file changed

+46
-34
lines changed

1 file changed

+46
-34
lines changed

dpctl/tests/elementwise/test_add.py

Lines changed: 46 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_add_dtype_matrix(op1_dtype, op2_dtype):
3333
assert (dpt.asnumpy(r) == np.full(r.shape, 2, dtype=r.dtype)).all()
3434
assert r.sycl_queue == ar1.sycl_queue
3535

36-
out = dpt.empty_like(ar1, dtype=expected_dtype)
36+
out = dpt.empty_like(ar1, dtype=r.dtype)
3737
dpt.add(ar1, ar2, out)
3838
assert (dpt.asnumpy(out) == np.full(out.shape, 2, dtype=out.dtype)).all()
3939

@@ -49,7 +49,7 @@ def test_add_dtype_matrix(op1_dtype, op2_dtype):
4949
assert r.shape == ar3.shape
5050
assert (dpt.asnumpy(r) == np.full(r.shape, 2, dtype=r.dtype)).all()
5151

52-
out = dpt.empty_like(ar1, dtype=expected_dtype)
52+
out = dpt.empty_like(ar1, dtype=r.dtype)
5353
dpt.add(ar3[::-1], ar4[::2], out)
5454
assert (dpt.asnumpy(out) == np.full(out.shape, 2, dtype=out.dtype)).all()
5555

@@ -74,37 +74,49 @@ def test_add_usm_type_matrix(op1_usm_type, op2_usm_type):
7474
def test_add_order():
7575
get_queue_or_skip()
7676

77-
ar1 = dpt.ones((20, 20), dtype="i4", order="C")
78-
ar2 = dpt.ones((20, 20), dtype="i4", order="C")
79-
r1 = dpt.add(ar1, ar2, order="C")
80-
assert r1.flags.c_contiguous
81-
r2 = dpt.add(ar1, ar2, order="F")
82-
assert r2.flags.f_contiguous
83-
r3 = dpt.add(ar1, ar2, order="A")
84-
assert r3.flags.c_contiguous
85-
r4 = dpt.add(ar1, ar2, order="K")
86-
assert r4.flags.c_contiguous
87-
88-
ar1 = dpt.ones((20, 20), dtype="i4", order="F")
89-
ar2 = dpt.ones((20, 20), dtype="i4", order="F")
90-
r1 = dpt.add(ar1, ar2, order="C")
91-
assert r1.flags.c_contiguous
92-
r2 = dpt.add(ar1, ar2, order="F")
93-
assert r2.flags.f_contiguous
94-
r3 = dpt.add(ar1, ar2, order="A")
95-
assert r3.flags.f_contiguous
96-
r4 = dpt.add(ar1, ar2, order="K")
97-
assert r4.flags.f_contiguous
98-
99-
ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2]
100-
ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2]
101-
r4 = dpt.add(ar1, ar2, order="K")
102-
assert r4.strides == (20, -1)
103-
104-
ar1 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT
105-
ar2 = dpt.ones((40, 40), dtype="i4", order="C")[:20, ::-2].mT
106-
r4 = dpt.add(ar1, ar2, order="K")
107-
assert r4.strides == (-1, 20)
77+
test_shape = (
78+
20,
79+
20,
80+
)
81+
test_shape2 = tuple(2 * dim for dim in test_shape)
82+
n = test_shape[-1]
83+
84+
for dt1, dt2 in zip(["i4", "i4", "f4"], ["i4", "f4", "i4"]):
85+
ar1 = dpt.ones(test_shape, dtype=dt1, order="C")
86+
ar2 = dpt.ones(test_shape, dtype=dt2, order="C")
87+
r1 = dpt.add(ar1, ar2, order="C")
88+
assert r1.flags.c_contiguous
89+
r2 = dpt.add(ar1, ar2, order="F")
90+
assert r2.flags.f_contiguous
91+
r3 = dpt.add(ar1, ar2, order="A")
92+
assert r3.flags.c_contiguous
93+
r4 = dpt.add(ar1, ar2, order="K")
94+
assert r4.flags.c_contiguous
95+
96+
ar1 = dpt.ones(test_shape, dtype=dt1, order="F")
97+
ar2 = dpt.ones(test_shape, dtype=dt2, order="F")
98+
r1 = dpt.add(ar1, ar2, order="C")
99+
assert r1.flags.c_contiguous
100+
r2 = dpt.add(ar1, ar2, order="F")
101+
assert r2.flags.f_contiguous
102+
r3 = dpt.add(ar1, ar2, order="A")
103+
assert r3.flags.f_contiguous
104+
r4 = dpt.add(ar1, ar2, order="K")
105+
assert r4.flags.f_contiguous
106+
107+
ar1 = dpt.ones(test_shape2, dtype=dt1, order="C")[:20, ::-2]
108+
ar2 = dpt.ones(test_shape2, dtype=dt2, order="C")[:20, ::-2]
109+
r4 = dpt.add(ar1, ar2, order="K")
110+
assert r4.strides == (n, -1)
111+
r5 = dpt.add(ar1, ar2, order="C")
112+
assert r5.strides == (n, 1)
113+
114+
ar1 = dpt.ones(test_shape2, dtype=dt1, order="C")[:20, ::-2].mT
115+
ar2 = dpt.ones(test_shape2, dtype=dt2, order="C")[:20, ::-2].mT
116+
r4 = dpt.add(ar1, ar2, order="K")
117+
assert r4.strides == (-1, n)
118+
r5 = dpt.add(ar1, ar2, order="C")
119+
assert r5.strides == (n, 1)
108120

109121

110122
def test_add_broadcasting():
@@ -266,7 +278,7 @@ def test_add_dtype_error(
266278
skip_if_dtype_not_supported(dtype, q)
267279

268280
ar1 = dpt.ones(5, dtype=dtype)
269-
ar2 = dpt.ones_like(ar1, dtype="f8")
281+
ar2 = dpt.ones_like(ar1, dtype="f4")
270282

271283
y = dpt.zeros_like(ar1, dtype="int8")
272284
assert_raises_regex(

0 commit comments

Comments
 (0)