Skip to content

Commit c80d344

Browse files
committed
modifying test cases
1 parent 4c9a340 commit c80d344

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

dpctl/tests/elementwise/test_complex.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -127,16 +127,24 @@ def test_projection_complex(dtype):
127127
q = get_queue_or_skip()
128128
skip_if_dtype_not_supported(dtype, q)
129129

130-
X = [complex(1, 2), complex(dpt.inf, -1), complex(0, -dpt.inf)]
131-
Y = [complex(1, 2), complex(dpt.inf, -0), complex(dpt.inf, -0)]
130+
X = [
131+
complex(1, 2),
132+
complex(dpt.inf, -1),
133+
complex(0, -dpt.inf),
134+
complex(-dpt.inf, dpt.nan),
135+
]
136+
Y = [
137+
complex(1, 2),
138+
complex(np.inf, -0.0),
139+
complex(np.inf, -0.0),
140+
complex(np.inf, 0.0),
141+
]
132142

133143
Xf = dpt.asarray(X, dtype=dtype, sycl_queue=q)
134-
Yf = dpt.asarray(Y, dtype=dtype, sycl_queue=q)
144+
Yf = np.array(Y, dtype=dtype)
135145

136146
tol = 8 * dpt.finfo(Xf.dtype).resolution
137-
assert_allclose(
138-
dpt.asnumpy(dpt.proj(Xf)), dpt.asnumpy(Yf), atol=tol, rtol=tol
139-
)
147+
assert_allclose(dpt.asnumpy(dpt.proj(Xf)), Yf, atol=tol, rtol=tol)
140148

141149

142150
@pytest.mark.parametrize("dtype", _all_dtypes)
@@ -146,19 +154,17 @@ def test_projection(dtype):
146154

147155
Xf = dpt.asarray(1, dtype=dtype, sycl_queue=q)
148156
out_dtype = dpt.proj(Xf).dtype
149-
Yf = dpt.asarray(complex(1, 0), dtype=out_dtype, sycl_queue=q)
157+
Yf = np.array(complex(1, 0), dtype=out_dtype)
150158

151159
tol = 8 * dpt.finfo(Yf.dtype).resolution
152-
assert_allclose(
153-
dpt.asnumpy(dpt.proj(Xf)), dpt.asnumpy(Yf), atol=tol, rtol=tol
154-
)
160+
assert_allclose(dpt.asnumpy(dpt.proj(Xf)), Yf, atol=tol, rtol=tol)
155161

156162

157163
@pytest.mark.parametrize(
158164
"np_call, dpt_call",
159165
[(np.real, dpt.real), (np.imag, dpt.imag), (np.conj, dpt.conj)],
160166
)
161-
@pytest.mark.parametrize("dtype", ["f", "d"])
167+
@pytest.mark.parametrize("dtype", ["f4", "f8"])
162168
@pytest.mark.parametrize("stride", [-1, 1, 2, 4, 5])
163169
def test_complex_strided(np_call, dpt_call, dtype, stride):
164170
q = get_queue_or_skip()
@@ -176,7 +182,7 @@ def test_complex_strided(np_call, dpt_call, dtype, stride):
176182
assert_allclose(y, dpt.asnumpy(z), atol=tol, rtol=tol)
177183

178184

179-
@pytest.mark.parametrize("dtype", ["e", "f", "d"])
185+
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
180186
def test_complex_special_cases(dtype):
181187
q = get_queue_or_skip()
182188
skip_if_dtype_not_supported(dtype, q)

0 commit comments

Comments
 (0)