Skip to content

Commit 13816bd

Browse files
authored
Follow recommendation on the interaction with numpy.ndarray in binary ops (#2266)
The PR proposes to follow NEP-13 [recommendations](https://numpy.org/neps/nep-0013-ufunc-overrides.html#behavior-in-combination-with-python-s-binary-operations) on how to interact with `numpy.ndarray` in binary the operations. It will set `__array_ufunc__ = None` which means that dpnp implements Python binary operations freely and so `numpy.ufuncs` called on this argument will raise `TypeError`: ```python a = numpy.ones(10) b = dpnp.ones(10) a += b --------------------------------------------------------------------------- TypeError Traceback (most recent call last) Cell In[9], line 1 ----> 1 a += b TypeError: operand 'dpnp_array' does not support ufuncs (__array_ufunc__=None) ``` And an elementwise operation with `numpy.ndarray` will cause an explicit exception in dpnp: ```python a + b --------------------------------------------------------------------------- TypeError Traceback (most recent call last) Cell In[4], line 1 ----> 1 a + b File ~/code/dpnp/dpnp/dpnp_array.py:518, in dpnp_array.__radd__(self, other) 516 def __radd__(self, other): 517 """Return ``value+self``.""" --> 518 return dpnp.add(other, self) File ~/code/dpnp/dpnp/dpnp_algo/dpnp_elementwise_common.py:314, in DPNPBinaryFunc.__call__(self, x1, x2, out, where, order, dtype, subok, **kwargs) 303 def __call__( 304 self, 305 x1, (...) 312 **kwargs, 313 ): --> 314 dpnp.check_supported_arrays_type( 315 x1, x2, scalar_type=True, all_scalars=False 316 ) 317 if kwargs: 318 raise NotImplementedError( 319 f"Requested function={self.name_} with kwargs={kwargs} " 320 "isn't currently supported." 321 ) File ~/code/dpnp/dpnp/dpnp_iface.py:400, in check_supported_arrays_type(scalar_type, all_scalars, *arrays) 397 if scalar_type and dpnp.isscalar(a): 398 continue --> 400 raise TypeError( 401 f"An array must be any of supported type, but got {type(a)}" 402 ) 404 if len(arrays) > 0 and not (all_scalars or any_is_array): 405 raise TypeError( 406 "At least one input must be of supported array type, " 407 "but got all scalars." 408 ) TypeError: An array must be any of supported type, but got <class 'numpy.ndarray'> ``` Previously it works as in a way: ```python a = numpy.ones(10) b = dpnp.ones(10) a + b # Out: # array([array(2.), array(2.), array(2.), array(2.), array(2.), array(2.), # array(2.), array(2.), array(2.), array(2.)], dtype=object) ``` Note, some updates in tests from #2260 have been mapped to that PR.
1 parent 7ce5219 commit 13816bd

File tree

6 files changed

+52
-11
lines changed

6 files changed

+52
-11
lines changed

dpnp/dpnp_array.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,9 @@ def __and__(self, other):
192192
# '__array_prepare__',
193193
# '__array_priority__',
194194
# '__array_struct__',
195-
# '__array_ufunc__',
195+
196+
__array_ufunc__ = None
197+
196198
# '__array_wrap__',
197199

198200
def __array_namespace__(self, /, *, api_version=None):

dpnp/tests/helper.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,20 @@ def get_all_dtypes(
161161
return dtypes
162162

163163

164+
def get_array(xp, a):
165+
"""
166+
Cast input array `a` to a type supported by `xp` interface.
167+
168+
Implicit conversion of either DPNP or DPCTL array to a NumPy array is not
169+
allowed. Input array has to be explicitly casted with `asnumpy` function.
170+
171+
"""
172+
173+
if xp is numpy and dpnp.is_supported_array_type(a):
174+
return dpnp.asnumpy(a)
175+
return a
176+
177+
164178
def generate_random_numpy_array(
165179
shape,
166180
dtype=None,

dpnp/tests/test_arraycreation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .helper import (
1818
assert_dtype_allclose,
1919
get_all_dtypes,
20+
get_array,
2021
)
2122
from .third_party.cupy import testing
2223

@@ -768,7 +769,7 @@ def test_space_numpy_dtype(func, start_dtype, stop_dtype):
768769
],
769770
)
770771
def test_linspace_arrays(start, stop):
771-
func = lambda xp: xp.linspace(start, stop, 10)
772+
func = lambda xp: xp.linspace(get_array(xp, start), get_array(xp, stop), 10)
772773
assert func(numpy).shape == func(dpnp).shape
773774

774775

dpnp/tests/test_linalg.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1935,7 +1935,7 @@ def test_matrix_rank(self, data, dtype):
19351935

19361936
np_rank = numpy.linalg.matrix_rank(a)
19371937
dp_rank = dpnp.linalg.matrix_rank(a_dp)
1938-
assert np_rank == dp_rank
1938+
assert dp_rank.asnumpy() == np_rank
19391939

19401940
@pytest.mark.parametrize("dtype", get_all_dtypes())
19411941
@pytest.mark.parametrize(
@@ -1953,7 +1953,7 @@ def test_matrix_rank_hermitian(self, data, dtype):
19531953

19541954
np_rank = numpy.linalg.matrix_rank(a, hermitian=True)
19551955
dp_rank = dpnp.linalg.matrix_rank(a_dp, hermitian=True)
1956-
assert np_rank == dp_rank
1956+
assert dp_rank.asnumpy() == np_rank
19571957

19581958
@pytest.mark.parametrize(
19591959
"high_tol, low_tol",
@@ -1986,15 +1986,15 @@ def test_matrix_rank_tolerance(self, high_tol, low_tol):
19861986
dp_rank_high_tol = dpnp.linalg.matrix_rank(
19871987
a_dp, hermitian=True, tol=dp_high_tol
19881988
)
1989-
assert np_rank_high_tol == dp_rank_high_tol
1989+
assert dp_rank_high_tol.asnumpy() == np_rank_high_tol
19901990

19911991
np_rank_low_tol = numpy.linalg.matrix_rank(
19921992
a, hermitian=True, tol=low_tol
19931993
)
19941994
dp_rank_low_tol = dpnp.linalg.matrix_rank(
19951995
a_dp, hermitian=True, tol=dp_low_tol
19961996
)
1997-
assert np_rank_low_tol == dp_rank_low_tol
1997+
assert dp_rank_low_tol.asnumpy() == np_rank_low_tol
19981998

19991999
# rtol kwarg was added in numpy 2.0
20002000
@testing.with_requires("numpy>=2.0")
@@ -2807,15 +2807,14 @@ def check_decomposition(
28072807
for i in range(min(dp_a.shape[-2], dp_a.shape[-1])):
28082808
dpnp_diag_s[..., i, i] = dp_s[..., i]
28092809
reconstructed = dpnp.dot(dp_u, dpnp.dot(dpnp_diag_s, dp_vt))
2810-
# TODO: use assert dpnp.allclose() inside check_decomposition()
2811-
# when it will support complex dtypes
2812-
assert_allclose(dp_a, reconstructed, rtol=tol, atol=1e-4)
2810+
2811+
assert dpnp.allclose(dp_a, reconstructed, rtol=tol, atol=1e-4)
28132812

28142813
assert_allclose(dp_s, np_s, rtol=tol, atol=1e-03)
28152814

28162815
if compute_vt:
28172816
for i in range(min(dp_a.shape[-2], dp_a.shape[-1])):
2818-
if np_u[..., 0, i] * dp_u[..., 0, i] < 0:
2817+
if np_u[..., 0, i] * dpnp.asnumpy(dp_u[..., 0, i]) < 0:
28192818
np_u[..., :, i] = -np_u[..., :, i]
28202819
np_vt[..., i, :] = -np_vt[..., i, :]
28212820
for i in range(numpy.count_nonzero(np_s > tol)):

dpnp/tests/test_manipulation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .helper import (
1616
assert_dtype_allclose,
1717
get_all_dtypes,
18+
get_array,
1819
get_complex_dtypes,
1920
get_float_complex_dtypes,
2021
get_float_dtypes,
@@ -1232,7 +1233,10 @@ def test_axes(self):
12321233
def test_axes_type(self, axes):
12331234
a = numpy.ones((50, 40, 3))
12341235
ia = dpnp.array(a)
1235-
assert_equal(dpnp.rot90(ia, axes=axes), numpy.rot90(a, axes=axes))
1236+
assert_equal(
1237+
dpnp.rot90(ia, axes=axes),
1238+
numpy.rot90(a, axes=get_array(numpy, axes)),
1239+
)
12361240

12371241
def test_rotation_axes(self):
12381242
a = numpy.arange(8).reshape((2, 2, 2))

dpnp/tests/test_ndarray.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,27 @@ def test_wrong_api_version(self, api_version):
150150
)
151151

152152

153+
class TestArrayUfunc:
154+
def test_add(self):
155+
a = numpy.ones(10)
156+
b = dpnp.ones(10)
157+
msg = "An array must be any of supported type"
158+
159+
with assert_raises_regex(TypeError, msg):
160+
a + b
161+
162+
with assert_raises_regex(TypeError, msg):
163+
b + a
164+
165+
def test_add_inplace(self):
166+
a = numpy.ones(10)
167+
b = dpnp.ones(10)
168+
with assert_raises_regex(
169+
TypeError, "operand 'dpnp_array' does not support ufuncs"
170+
):
171+
a += b
172+
173+
153174
class TestItem:
154175
@pytest.mark.parametrize("args", [2, 7, (1, 2), (2, 0)])
155176
def test_basic(self, args):

0 commit comments

Comments
 (0)