Skip to content

Commit 04aa878

Browse files
committed
address comments
1 parent 7379c61 commit 04aa878

File tree

4 files changed

+51
-36
lines changed

4 files changed

+51
-36
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -420,38 +420,6 @@ def dpnp_ceil(x, out=None, order="K"):
420420
"""
421421

422422

423-
def _call_cos(src, dst, sycl_queue, depends=None):
424-
"""A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""
425-
426-
if depends is None:
427-
depends = []
428-
429-
if vmi._mkl_cos_to_call(sycl_queue, src, dst):
430-
# call pybind11 extension for cos() function from OneMKL VM
431-
return vmi._cos(sycl_queue, src, dst, depends)
432-
return ti._cos(src, dst, sycl_queue, depends)
433-
434-
435-
cos_func = UnaryElementwiseFunc(
436-
"cos", ti._cos_result_type, _call_cos, _cos_docstring
437-
)
438-
439-
440-
def dpnp_cos(x, out=None, order="K"):
441-
"""
442-
Invokes cos() function from pybind11 extension of OneMKL VM if possible.
443-
444-
Otherwise fully relies on dpctl.tensor implementation for cos() function.
445-
"""
446-
447-
# dpctl.tensor only works with usm_ndarray
448-
x1_usm = dpnp.get_usm_ndarray(x)
449-
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
450-
451-
res_usm = cos_func(x1_usm, out=out_usm, order=order)
452-
return dpnp_array._create_from_usm_ndarray(res_usm)
453-
454-
455423
_conj_docstring = """
456424
conj(x, out=None, order='K')
457425
@@ -490,6 +458,36 @@ def _call_conj(src, dst, sycl_queue, depends=None):
490458
)
491459

492460

461+
def dpnp_cos(x, out=None, order="K"):
462+
"""
463+
Invokes cos() function from pybind11 extension of OneMKL VM if possible.
464+
465+
Otherwise fully relies on dpctl.tensor implementation for cos() function.
466+
467+
"""
468+
469+
def _call_cos(src, dst, sycl_queue, depends=None):
470+
"""A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""
471+
472+
if depends is None:
473+
depends = []
474+
475+
if vmi._mkl_cos_to_call(sycl_queue, src, dst):
476+
# call pybind11 extension for cos() function from OneMKL VM
477+
return vmi._cos(sycl_queue, src, dst, depends)
478+
return ti._cos(src, dst, sycl_queue, depends)
479+
480+
# dpctl.tensor only works with usm_ndarray
481+
x1_usm = dpnp.get_usm_ndarray(x)
482+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
483+
484+
func = UnaryElementwiseFunc(
485+
"cos", ti._cos_result_type, _call_cos, _cos_docstring
486+
)
487+
res_usm = func(x1_usm, out=out_usm, order=order)
488+
return dpnp_array._create_from_usm_ndarray(res_usm)
489+
490+
493491
def dpnp_conj(x, out=None, order="K"):
494492
"""
495493
Invokes conj() function from pybind11 extension of OneMKL VM if possible.

dpnp/dpnp_array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,7 @@ def conj(self):
622622
623623
"""
624624

625-
if not dpnp.issubsctype(self.dtype, dpnp.complex_):
625+
if not dpnp.issubsctype(self.dtype, dpnp.complexfloating):
626626
return self
627627
else:
628628
return dpnp.conjugate(self)
@@ -635,7 +635,7 @@ def conjugate(self):
635635
636636
"""
637637

638-
if not dpnp.issubsctype(self.dtype, dpnp.complex_):
638+
if not dpnp.issubsctype(self.dtype, dpnp.complexfloating):
639639
return self
640640
else:
641641
return dpnp.conjugate(self)

dpnp/dpnp_iface_mathematical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def conjugate(
397397
Examples
398398
--------
399399
>>> import dpnp as np
400-
>>> np.conjugate(1+2j)
400+
>>> np.conjugate(np.array(1+2j))
401401
(1-2j)
402402
403403
>>> x = np.eye(2) + 1j * np.eye(2)

tests/third_party/cupy/math_tests/test_arithmetic.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@
2626
testing.product(
2727
{
2828
"nargs": [1],
29-
"name": ["reciprocal", "angle"],
29+
"name": [
30+
"reciprocal",
31+
"conj",
32+
"conjugate",
33+
"angle",
34+
],
3035
}
3136
)
3237
+ testing.product(
@@ -68,6 +73,18 @@ def test_raises_with_numpy_input(self):
6873
@testing.parameterize(
6974
*(
7075
testing.product(
76+
{
77+
"arg1": (
78+
[
79+
testing.shaped_arange((2, 3), numpy, dtype=d)
80+
for d in all_types
81+
]
82+
+ [0, 0.0j, 0j, 2, 2.0, 2j, True, False]
83+
),
84+
"name": ["conj", "conjugate"],
85+
}
86+
)
87+
+ testing.product(
7188
{
7289
"arg1": (
7390
[

0 commit comments

Comments
 (0)