Skip to content

Commit aa05645

Browse files
authored
Merge pull request #2126 from IntelPython/refactor-trig-tests
Fix technical debt in trig element-wise function tests
2 parents e0dd5bd + 636147d commit aa05645

File tree

4 files changed

+6
-264
lines changed

4 files changed

+6
-264
lines changed

.github/workflows/generate-coverage.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,7 @@ jobs:
9191
- name: Install dpctl dependencies
9292
shell: bash -l {0}
9393
run: |
94-
# TODO: unpin numpy when numpy#29167 resolved
95-
pip install numpy"<2.3.0" cython setuptools"<80" pytest pytest-cov scikit-build cmake coverage[toml] versioneer[toml]==0.29
94+
pip install numpy cython setuptools"<80" pytest pytest-cov scikit-build cmake coverage[toml] versioneer[toml]==0.29
9695
9796
- name: Build dpctl with coverage
9897
shell: bash -l {0}

.github/workflows/os-llvm-sycl-build.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,7 @@ jobs:
107107
- name: Install dpctl dependencies
108108
shell: bash -l {0}
109109
run: |
110-
# TODO: unpin numpy when numpy#29167 resolved
111-
pip install numpy"<2.3.0" cython setuptools"<80" pytest scikit-build cmake ninja versioneer[toml]==0.29
110+
pip install numpy cython setuptools"<80" pytest scikit-build cmake ninja versioneer[toml]==0.29
112111
113112
- name: Checkout repo
114113
uses: actions/checkout@v4.2.2

dpctl/tests/elementwise/test_hyperbolic.py

Lines changed: 2 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,6 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
import itertools
18-
import os
19-
import re
20-
2117
import numpy as np
2218
import pytest
2319
from numpy.testing import assert_allclose
@@ -34,7 +30,6 @@
3430
(np.arctanh, dpt.atanh),
3531
]
3632
_all_funcs = _hyper_funcs + _inv_hyper_funcs
37-
_dpt_funcs = [t[1] for t in _all_funcs]
3833

3934

4035
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@@ -45,17 +40,10 @@ def test_hyper_out_type(np_call, dpt_call, dtype):
4540

4641
a = 1 if np_call == np.arccosh else 0
4742

48-
X = dpt.asarray(a, dtype=dtype, sycl_queue=q)
49-
expected_dtype = np_call(np.array(a, dtype=dtype)).dtype
50-
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
51-
assert dpt_call(X).dtype == expected_dtype
52-
53-
X = dpt.asarray(a, dtype=dtype, sycl_queue=q)
43+
x = dpt.asarray(a, dtype=dtype, sycl_queue=q)
5444
expected_dtype = np_call(np.array(a, dtype=dtype)).dtype
5545
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
56-
Y = dpt.empty_like(X, dtype=expected_dtype)
57-
dpt_call(X, out=Y)
58-
assert_allclose(dpt.asnumpy(dpt_call(X)), dpt.asnumpy(Y))
46+
assert dpt_call(x).dtype == expected_dtype
5947

6048

6149
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@@ -119,79 +107,6 @@ def test_hyper_complex_contig(np_call, dpt_call, dtype):
119107
)
120108

121109

122-
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
123-
@pytest.mark.parametrize("usm_type", ["device", "shared", "host"])
124-
def test_hyper_usm_type(np_call, dpt_call, usm_type):
125-
q = get_queue_or_skip()
126-
127-
arg_dt = np.dtype("f4")
128-
input_shape = (10, 10, 10, 10)
129-
X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q)
130-
if np_call == np.arctanh:
131-
X[..., 0::2] = -0.4
132-
X[..., 1::2] = 0.3
133-
elif np_call == np.arccosh:
134-
X[..., 0::2] = 2.2
135-
X[..., 1::2] = 5.5
136-
else:
137-
X[..., 0::2] = -4.4
138-
X[..., 1::2] = 5.5
139-
140-
Y = dpt_call(X)
141-
assert Y.usm_type == X.usm_type
142-
assert Y.sycl_queue == X.sycl_queue
143-
assert Y.flags.c_contiguous
144-
145-
expected_Y = np_call(dpt.asnumpy(X))
146-
tol = 8 * dpt.finfo(Y.dtype).resolution
147-
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
148-
149-
150-
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
151-
@pytest.mark.parametrize("dtype", _all_dtypes)
152-
def test_hyper_order(np_call, dpt_call, dtype):
153-
q = get_queue_or_skip()
154-
skip_if_dtype_not_supported(dtype, q)
155-
156-
arg_dt = np.dtype(dtype)
157-
input_shape = (4, 4, 4, 4)
158-
X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q)
159-
if np_call == np.arctanh:
160-
X[..., 0::2] = -0.4
161-
X[..., 1::2] = 0.3
162-
elif np_call == np.arccosh:
163-
X[..., 0::2] = 2.2
164-
X[..., 1::2] = 5.5
165-
else:
166-
X[..., 0::2] = -4.4
167-
X[..., 1::2] = 5.5
168-
169-
for perms in itertools.permutations(range(4)):
170-
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
171-
with np.errstate(all="ignore"):
172-
expected_Y = np_call(dpt.asnumpy(U))
173-
for ord in ["C", "F", "A", "K"]:
174-
Y = dpt_call(U, order=ord)
175-
tol = 8 * max(
176-
dpt.finfo(Y.dtype).resolution,
177-
np.finfo(expected_Y.dtype).resolution,
178-
)
179-
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
180-
181-
182-
@pytest.mark.parametrize("callable", _dpt_funcs)
183-
@pytest.mark.parametrize("dtype", _all_dtypes)
184-
def test_hyper_error_dtype(callable, dtype):
185-
q = get_queue_or_skip()
186-
skip_if_dtype_not_supported(dtype, q)
187-
188-
x = dpt.ones(5, dtype=dtype)
189-
y = dpt.empty_like(x, dtype="int16")
190-
with pytest.raises(ValueError) as excinfo:
191-
callable(x, out=y)
192-
assert re.match("Output array of type.*is needed", str(excinfo.value))
193-
194-
195110
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
196111
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
197112
def test_hyper_real_strided(np_call, dpt_call, dtype):
@@ -270,46 +185,3 @@ def test_hyper_real_special_cases(np_call, dpt_call, dtype):
270185

271186
tol = 8 * dpt.finfo(dtype).resolution
272187
assert_allclose(dpt.asnumpy(dpt_call(yf)), Y_np, atol=tol, rtol=tol)
273-
274-
275-
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
276-
@pytest.mark.parametrize("dtype", ["c8", "c16"])
277-
def test_hyper_complex_special_cases_conj_property(np_call, dpt_call, dtype):
278-
q = get_queue_or_skip()
279-
skip_if_dtype_not_supported(dtype, q)
280-
281-
x = [np.nan, np.inf, -np.inf, +0.0, -0.0, +1.0, -1.0]
282-
xc = [complex(*val) for val in itertools.product(x, repeat=2)]
283-
284-
Xc_np = np.array(xc, dtype=dtype)
285-
Xc = dpt.asarray(Xc_np, dtype=dtype, sycl_queue=q)
286-
287-
tol = 50 * dpt.finfo(dtype).resolution
288-
Y = dpt_call(Xc)
289-
Yc = dpt_call(dpt.conj(Xc))
290-
291-
dpt.allclose(Y, dpt.conj(Yc), atol=tol, rtol=tol)
292-
293-
294-
@pytest.mark.skipif(
295-
os.name != "posix", reason="Known to fail on Windows due to bug in NumPy"
296-
)
297-
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
298-
@pytest.mark.parametrize("dtype", ["c8", "c16"])
299-
def test_hyper_complex_special_cases(np_call, dpt_call, dtype):
300-
q = get_queue_or_skip()
301-
skip_if_dtype_not_supported(dtype, q)
302-
303-
x = [np.nan, np.inf, -np.inf, +0.0, -0.0, +1.0, -1.0]
304-
xc = [complex(*val) for val in itertools.product(x, repeat=2)]
305-
306-
Xc_np = np.array(xc, dtype=dtype)
307-
Xc = dpt.asarray(Xc_np, dtype=dtype, sycl_queue=q)
308-
309-
with np.errstate(all="ignore"):
310-
Ynp = np_call(Xc_np)
311-
312-
tol = 50 * dpt.finfo(dtype).resolution
313-
Y = dpt_call(Xc)
314-
assert_allclose(dpt.asnumpy(dpt.real(Y)), np.real(Ynp), atol=tol, rtol=tol)
315-
assert_allclose(dpt.asnumpy(dpt.imag(Y)), np.imag(Ynp), atol=tol, rtol=tol)

dpctl/tests/elementwise/test_trigonometric.py

Lines changed: 2 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,6 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
import itertools
18-
import os
19-
import re
20-
2117
import numpy as np
2218
import pytest
2319
from numpy.testing import assert_allclose
@@ -34,7 +30,6 @@
3430
(np.arctan, dpt.atan),
3531
]
3632
_all_funcs = _trig_funcs + _inv_trig_funcs
37-
_dpt_funcs = [t[1] for t in _all_funcs]
3833

3934

4035
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@@ -43,17 +38,10 @@ def test_trig_out_type(np_call, dpt_call, dtype):
4338
q = get_queue_or_skip()
4439
skip_if_dtype_not_supported(dtype, q)
4540

46-
X = dpt.asarray(0, dtype=dtype, sycl_queue=q)
47-
expected_dtype = np_call(np.array(0, dtype=dtype)).dtype
48-
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
49-
assert dpt_call(X).dtype == expected_dtype
50-
51-
X = dpt.asarray(0, dtype=dtype, sycl_queue=q)
41+
x = dpt.asarray(0, dtype=dtype, sycl_queue=q)
5242
expected_dtype = np_call(np.array(0, dtype=dtype)).dtype
5343
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
54-
Y = dpt.empty_like(X, dtype=expected_dtype)
55-
dpt_call(X, out=Y)
56-
assert_allclose(dpt.asnumpy(dpt_call(X)), dpt.asnumpy(Y))
44+
assert dpt_call(x).dtype == expected_dtype
5745

5846

5947
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@@ -127,78 +115,6 @@ def test_trig_complex_contig(np_call, dpt_call, dtype):
127115
assert_allclose(dpt.asnumpy(Z), expected, atol=tol, rtol=tol)
128116

129117

130-
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
131-
@pytest.mark.parametrize("usm_type", ["device", "shared", "host"])
132-
def test_trig_usm_type(np_call, dpt_call, usm_type):
133-
q = get_queue_or_skip()
134-
135-
arg_dt = np.dtype("f4")
136-
input_shape = (10, 10, 10, 10)
137-
X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q)
138-
if np_call in _trig_funcs:
139-
X[..., 0::2] = np.pi / 6
140-
X[..., 1::2] = np.pi / 3
141-
if np_call == np.arctan:
142-
X[..., 0::2] = -2.2
143-
X[..., 1::2] = 3.3
144-
else:
145-
X[..., 0::2] = -0.3
146-
X[..., 1::2] = 0.7
147-
148-
Y = dpt_call(X)
149-
assert Y.usm_type == X.usm_type
150-
assert Y.sycl_queue == X.sycl_queue
151-
assert Y.flags.c_contiguous
152-
153-
expected_Y = np_call(dpt.asnumpy(X))
154-
tol = 8 * dpt.finfo(Y.dtype).resolution
155-
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
156-
157-
158-
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
159-
@pytest.mark.parametrize("dtype", _all_dtypes)
160-
def test_trig_order(np_call, dpt_call, dtype):
161-
q = get_queue_or_skip()
162-
skip_if_dtype_not_supported(dtype, q)
163-
164-
arg_dt = np.dtype(dtype)
165-
input_shape = (4, 4, 4, 4)
166-
X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q)
167-
if np_call in _trig_funcs:
168-
X[..., 0::2] = np.pi / 6
169-
X[..., 1::2] = np.pi / 3
170-
if np_call == np.arctan:
171-
X[..., 0::2] = -2.2
172-
X[..., 1::2] = 3.3
173-
else:
174-
X[..., 0::2] = -0.3
175-
X[..., 1::2] = 0.7
176-
177-
for perms in itertools.permutations(range(4)):
178-
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
179-
expected_Y = np_call(dpt.asnumpy(U))
180-
for ord in ["C", "F", "A", "K"]:
181-
Y = dpt_call(U, order=ord)
182-
tol = 8 * max(
183-
dpt.finfo(Y.dtype).resolution,
184-
np.finfo(expected_Y.dtype).resolution,
185-
)
186-
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
187-
188-
189-
@pytest.mark.parametrize("callable", _dpt_funcs)
190-
@pytest.mark.parametrize("dtype", _all_dtypes)
191-
def test_trig_error_dtype(callable, dtype):
192-
q = get_queue_or_skip()
193-
skip_if_dtype_not_supported(dtype, q)
194-
195-
x = dpt.zeros(5, dtype=dtype)
196-
y = dpt.empty_like(x, dtype="int16")
197-
with pytest.raises(ValueError) as excinfo:
198-
callable(x, out=y)
199-
assert re.match("Output array of type.*is needed", str(excinfo.value))
200-
201-
202118
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
203119
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
204120
def test_trig_real_strided(np_call, dpt_call, dtype):
@@ -298,47 +214,3 @@ def test_trig_real_special_cases(np_call, dpt_call, dtype):
298214
tol = 8 * dpt.finfo(dtype).resolution
299215
Y = dpt_call(yf)
300216
assert_allclose(dpt.asnumpy(Y), Y_np, atol=tol, rtol=tol)
301-
302-
303-
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
304-
@pytest.mark.parametrize("dtype", ["c8", "c16"])
305-
def test_trig_complex_special_cases_conj_property(np_call, dpt_call, dtype):
306-
q = get_queue_or_skip()
307-
skip_if_dtype_not_supported(dtype, q)
308-
309-
x = [np.nan, np.inf, -np.inf, +0.0, -0.0, +1.0, -1.0]
310-
xc = [complex(*val) for val in itertools.product(x, repeat=2)]
311-
312-
Xc_np = np.array(xc, dtype=dtype)
313-
Xc = dpt.asarray(Xc_np, dtype=dtype, sycl_queue=q)
314-
315-
tol = 50 * dpt.finfo(dtype).resolution
316-
Y = dpt_call(Xc)
317-
Yc = dpt_call(dpt.conj(Xc))
318-
319-
dpt.allclose(Y, dpt.conj(Yc), atol=tol, rtol=tol)
320-
321-
322-
@pytest.mark.skipif(
323-
os.name != "posix", reason="Known to fail on Windows due to bug in NumPy"
324-
)
325-
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
326-
@pytest.mark.parametrize("dtype", ["c8", "c16"])
327-
def test_trig_complex_special_cases(np_call, dpt_call, dtype):
328-
329-
q = get_queue_or_skip()
330-
skip_if_dtype_not_supported(dtype, q)
331-
332-
x = [np.nan, np.inf, -np.inf, +0.0, -0.0, +1.0, -1.0]
333-
xc = [complex(*val) for val in itertools.product(x, repeat=2)]
334-
335-
Xc_np = np.array(xc, dtype=dtype)
336-
Xc = dpt.asarray(Xc_np, dtype=dtype, sycl_queue=q)
337-
338-
with np.errstate(all="ignore"):
339-
Ynp = np_call(Xc_np)
340-
341-
tol = 50 * dpt.finfo(dtype).resolution
342-
Y = dpt_call(Xc)
343-
assert_allclose(dpt.asnumpy(dpt.real(Y)), np.real(Ynp), atol=tol, rtol=tol)
344-
assert_allclose(dpt.asnumpy(dpt.imag(Y)), np.imag(Ynp), atol=tol, rtol=tol)

0 commit comments

Comments
 (0)