Skip to content

Fix technical debt in trig element-wise function tests #2126

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .github/workflows/generate-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ jobs:
- name: Install dpctl dependencies
shell: bash -l {0}
run: |
# TODO: unpin numpy when numpy#29167 resolved
pip install numpy"<2.3.0" cython setuptools"<80" pytest pytest-cov scikit-build cmake coverage[toml] versioneer[toml]==0.29
pip install numpy cython setuptools"<80" pytest pytest-cov scikit-build cmake coverage[toml] versioneer[toml]==0.29

- name: Build dpctl with coverage
shell: bash -l {0}
Expand Down
3 changes: 1 addition & 2 deletions .github/workflows/os-llvm-sycl-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,7 @@ jobs:
- name: Install dpctl dependencies
shell: bash -l {0}
run: |
# TODO: unpin numpy when numpy#29167 resolved
pip install numpy"<2.3.0" cython setuptools"<80" pytest scikit-build cmake ninja versioneer[toml]==0.29
pip install numpy cython setuptools"<80" pytest scikit-build cmake ninja versioneer[toml]==0.29

- name: Checkout repo
uses: actions/checkout@v4.2.2
Expand Down
132 changes: 2 additions & 130 deletions dpctl/tests/elementwise/test_hyperbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import os
import re

import numpy as np
import pytest
from numpy.testing import assert_allclose
Expand All @@ -34,7 +30,6 @@
(np.arctanh, dpt.atanh),
]
_all_funcs = _hyper_funcs + _inv_hyper_funcs
_dpt_funcs = [t[1] for t in _all_funcs]


@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
Expand All @@ -45,17 +40,10 @@ def test_hyper_out_type(np_call, dpt_call, dtype):

a = 1 if np_call == np.arccosh else 0

X = dpt.asarray(a, dtype=dtype, sycl_queue=q)
expected_dtype = np_call(np.array(a, dtype=dtype)).dtype
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
assert dpt_call(X).dtype == expected_dtype

X = dpt.asarray(a, dtype=dtype, sycl_queue=q)
x = dpt.asarray(a, dtype=dtype, sycl_queue=q)
expected_dtype = np_call(np.array(a, dtype=dtype)).dtype
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
Y = dpt.empty_like(X, dtype=expected_dtype)
dpt_call(X, out=Y)
assert_allclose(dpt.asnumpy(dpt_call(X)), dpt.asnumpy(Y))
assert dpt_call(x).dtype == expected_dtype


@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
Expand Down Expand Up @@ -119,79 +107,6 @@ def test_hyper_complex_contig(np_call, dpt_call, dtype):
)


@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@pytest.mark.parametrize("usm_type", ["device", "shared", "host"])
def test_hyper_usm_type(np_call, dpt_call, usm_type):
q = get_queue_or_skip()

arg_dt = np.dtype("f4")
input_shape = (10, 10, 10, 10)
X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q)
if np_call == np.arctanh:
X[..., 0::2] = -0.4
X[..., 1::2] = 0.3
elif np_call == np.arccosh:
X[..., 0::2] = 2.2
X[..., 1::2] = 5.5
else:
X[..., 0::2] = -4.4
X[..., 1::2] = 5.5

Y = dpt_call(X)
assert Y.usm_type == X.usm_type
assert Y.sycl_queue == X.sycl_queue
assert Y.flags.c_contiguous

expected_Y = np_call(dpt.asnumpy(X))
tol = 8 * dpt.finfo(Y.dtype).resolution
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)


@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@pytest.mark.parametrize("dtype", _all_dtypes)
def test_hyper_order(np_call, dpt_call, dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

arg_dt = np.dtype(dtype)
input_shape = (4, 4, 4, 4)
X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q)
if np_call == np.arctanh:
X[..., 0::2] = -0.4
X[..., 1::2] = 0.3
elif np_call == np.arccosh:
X[..., 0::2] = 2.2
X[..., 1::2] = 5.5
else:
X[..., 0::2] = -4.4
X[..., 1::2] = 5.5

for perms in itertools.permutations(range(4)):
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
with np.errstate(all="ignore"):
expected_Y = np_call(dpt.asnumpy(U))
for ord in ["C", "F", "A", "K"]:
Y = dpt_call(U, order=ord)
tol = 8 * max(
dpt.finfo(Y.dtype).resolution,
np.finfo(expected_Y.dtype).resolution,
)
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)


@pytest.mark.parametrize("callable", _dpt_funcs)
@pytest.mark.parametrize("dtype", _all_dtypes)
def test_hyper_error_dtype(callable, dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

x = dpt.ones(5, dtype=dtype)
y = dpt.empty_like(x, dtype="int16")
with pytest.raises(ValueError) as excinfo:
callable(x, out=y)
assert re.match("Output array of type.*is needed", str(excinfo.value))


@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
def test_hyper_real_strided(np_call, dpt_call, dtype):
Expand Down Expand Up @@ -270,46 +185,3 @@ def test_hyper_real_special_cases(np_call, dpt_call, dtype):

tol = 8 * dpt.finfo(dtype).resolution
assert_allclose(dpt.asnumpy(dpt_call(yf)), Y_np, atol=tol, rtol=tol)


@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@pytest.mark.parametrize("dtype", ["c8", "c16"])
def test_hyper_complex_special_cases_conj_property(np_call, dpt_call, dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

x = [np.nan, np.inf, -np.inf, +0.0, -0.0, +1.0, -1.0]
xc = [complex(*val) for val in itertools.product(x, repeat=2)]

Xc_np = np.array(xc, dtype=dtype)
Xc = dpt.asarray(Xc_np, dtype=dtype, sycl_queue=q)

tol = 50 * dpt.finfo(dtype).resolution
Y = dpt_call(Xc)
Yc = dpt_call(dpt.conj(Xc))

dpt.allclose(Y, dpt.conj(Yc), atol=tol, rtol=tol)


@pytest.mark.skipif(
os.name != "posix", reason="Known to fail on Windows due to bug in NumPy"
)
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@pytest.mark.parametrize("dtype", ["c8", "c16"])
def test_hyper_complex_special_cases(np_call, dpt_call, dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

x = [np.nan, np.inf, -np.inf, +0.0, -0.0, +1.0, -1.0]
xc = [complex(*val) for val in itertools.product(x, repeat=2)]

Xc_np = np.array(xc, dtype=dtype)
Xc = dpt.asarray(Xc_np, dtype=dtype, sycl_queue=q)

with np.errstate(all="ignore"):
Ynp = np_call(Xc_np)

tol = 50 * dpt.finfo(dtype).resolution
Y = dpt_call(Xc)
assert_allclose(dpt.asnumpy(dpt.real(Y)), np.real(Ynp), atol=tol, rtol=tol)
assert_allclose(dpt.asnumpy(dpt.imag(Y)), np.imag(Ynp), atol=tol, rtol=tol)
132 changes: 2 additions & 130 deletions dpctl/tests/elementwise/test_trigonometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import os
import re

import numpy as np
import pytest
from numpy.testing import assert_allclose
Expand All @@ -34,7 +30,6 @@
(np.arctan, dpt.atan),
]
_all_funcs = _trig_funcs + _inv_trig_funcs
_dpt_funcs = [t[1] for t in _all_funcs]


@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
Expand All @@ -43,17 +38,10 @@ def test_trig_out_type(np_call, dpt_call, dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

X = dpt.asarray(0, dtype=dtype, sycl_queue=q)
expected_dtype = np_call(np.array(0, dtype=dtype)).dtype
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
assert dpt_call(X).dtype == expected_dtype

X = dpt.asarray(0, dtype=dtype, sycl_queue=q)
x = dpt.asarray(0, dtype=dtype, sycl_queue=q)
expected_dtype = np_call(np.array(0, dtype=dtype)).dtype
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
Y = dpt.empty_like(X, dtype=expected_dtype)
dpt_call(X, out=Y)
assert_allclose(dpt.asnumpy(dpt_call(X)), dpt.asnumpy(Y))
assert dpt_call(x).dtype == expected_dtype


@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
Expand Down Expand Up @@ -127,78 +115,6 @@ def test_trig_complex_contig(np_call, dpt_call, dtype):
assert_allclose(dpt.asnumpy(Z), expected, atol=tol, rtol=tol)


@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@pytest.mark.parametrize("usm_type", ["device", "shared", "host"])
def test_trig_usm_type(np_call, dpt_call, usm_type):
q = get_queue_or_skip()

arg_dt = np.dtype("f4")
input_shape = (10, 10, 10, 10)
X = dpt.empty(input_shape, dtype=arg_dt, usm_type=usm_type, sycl_queue=q)
if np_call in _trig_funcs:
X[..., 0::2] = np.pi / 6
X[..., 1::2] = np.pi / 3
if np_call == np.arctan:
X[..., 0::2] = -2.2
X[..., 1::2] = 3.3
else:
X[..., 0::2] = -0.3
X[..., 1::2] = 0.7

Y = dpt_call(X)
assert Y.usm_type == X.usm_type
assert Y.sycl_queue == X.sycl_queue
assert Y.flags.c_contiguous

expected_Y = np_call(dpt.asnumpy(X))
tol = 8 * dpt.finfo(Y.dtype).resolution
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)


@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@pytest.mark.parametrize("dtype", _all_dtypes)
def test_trig_order(np_call, dpt_call, dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

arg_dt = np.dtype(dtype)
input_shape = (4, 4, 4, 4)
X = dpt.empty(input_shape, dtype=arg_dt, sycl_queue=q)
if np_call in _trig_funcs:
X[..., 0::2] = np.pi / 6
X[..., 1::2] = np.pi / 3
if np_call == np.arctan:
X[..., 0::2] = -2.2
X[..., 1::2] = 3.3
else:
X[..., 0::2] = -0.3
X[..., 1::2] = 0.7

for perms in itertools.permutations(range(4)):
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
expected_Y = np_call(dpt.asnumpy(U))
for ord in ["C", "F", "A", "K"]:
Y = dpt_call(U, order=ord)
tol = 8 * max(
dpt.finfo(Y.dtype).resolution,
np.finfo(expected_Y.dtype).resolution,
)
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)


@pytest.mark.parametrize("callable", _dpt_funcs)
@pytest.mark.parametrize("dtype", _all_dtypes)
def test_trig_error_dtype(callable, dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

x = dpt.zeros(5, dtype=dtype)
y = dpt.empty_like(x, dtype="int16")
with pytest.raises(ValueError) as excinfo:
callable(x, out=y)
assert re.match("Output array of type.*is needed", str(excinfo.value))


@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
def test_trig_real_strided(np_call, dpt_call, dtype):
Expand Down Expand Up @@ -298,47 +214,3 @@ def test_trig_real_special_cases(np_call, dpt_call, dtype):
tol = 8 * dpt.finfo(dtype).resolution
Y = dpt_call(yf)
assert_allclose(dpt.asnumpy(Y), Y_np, atol=tol, rtol=tol)


@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@pytest.mark.parametrize("dtype", ["c8", "c16"])
def test_trig_complex_special_cases_conj_property(np_call, dpt_call, dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

x = [np.nan, np.inf, -np.inf, +0.0, -0.0, +1.0, -1.0]
xc = [complex(*val) for val in itertools.product(x, repeat=2)]

Xc_np = np.array(xc, dtype=dtype)
Xc = dpt.asarray(Xc_np, dtype=dtype, sycl_queue=q)

tol = 50 * dpt.finfo(dtype).resolution
Y = dpt_call(Xc)
Yc = dpt_call(dpt.conj(Xc))

dpt.allclose(Y, dpt.conj(Yc), atol=tol, rtol=tol)


@pytest.mark.skipif(
os.name != "posix", reason="Known to fail on Windows due to bug in NumPy"
)
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
@pytest.mark.parametrize("dtype", ["c8", "c16"])
def test_trig_complex_special_cases(np_call, dpt_call, dtype):

q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

x = [np.nan, np.inf, -np.inf, +0.0, -0.0, +1.0, -1.0]
xc = [complex(*val) for val in itertools.product(x, repeat=2)]

Xc_np = np.array(xc, dtype=dtype)
Xc = dpt.asarray(Xc_np, dtype=dtype, sycl_queue=q)

with np.errstate(all="ignore"):
Ynp = np_call(Xc_np)

tol = 50 * dpt.finfo(dtype).resolution
Y = dpt_call(Xc)
assert_allclose(dpt.asnumpy(dpt.real(Y)), np.real(Ynp), atol=tol, rtol=tol)
assert_allclose(dpt.asnumpy(dpt.imag(Y)), np.imag(Ynp), atol=tol, rtol=tol)
Loading