Skip to content

Commit 1be8a5c

Browse files
Adding to multiply func support of complex inputs (#1126)
1 parent f44bd53 commit 1be8a5c

File tree

2 files changed

+32
-24
lines changed

2 files changed

+32
-24
lines changed

dpnp/backend/kernels/dpnp_krnl_elemwise.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,34 @@ static void func_map_init_elemwise_2arg_3type(func_map_t& fmap)
817817
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_DBL][eft_LNG] = {eft_DBL, (void*)dpnp_multiply_c<double, double, int64_t>};
818818
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_DBL][eft_FLT] = {eft_DBL, (void*)dpnp_multiply_c<double, double, float>};
819819
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_multiply_c<double, double, double>};
820+
821+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_C64][eft_BLN] = {
822+
eft_C64, (void*)dpnp_multiply_c<std::complex<float>, std::complex<float>, bool>};
823+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_C64][eft_INT] = {
824+
eft_C64, (void*)dpnp_multiply_c<std::complex<float>, std::complex<float>, int32_t>};
825+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_C64][eft_LNG] = {
826+
eft_C64, (void*)dpnp_multiply_c<std::complex<float>, std::complex<float>, int64_t>};
827+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_C64][eft_FLT] = {
828+
eft_C64, (void*)dpnp_multiply_c<std::complex<float>, std::complex<float>, float>};
829+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_C64][eft_DBL] = {
830+
eft_C128, (void*)dpnp_multiply_c<std::complex<double>, std::complex<float>, double>};
831+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_C64][eft_C64] = {
832+
eft_C64, (void*)dpnp_multiply_c<std::complex<float>, std::complex<float>, std::complex<float>>};
833+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_C64][eft_C128] = {
834+
eft_C128, (void*)dpnp_multiply_c<std::complex<double>, std::complex<float>, std::complex<double>>};
835+
836+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_C128][eft_BLN] = {
837+
eft_C128, (void*)dpnp_multiply_c<std::complex<double>, std::complex<double>, bool>};
838+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_C128][eft_INT] = {
839+
eft_C128, (void*)dpnp_multiply_c<std::complex<double>, std::complex<double>, int32_t>};
840+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_C128][eft_LNG] = {
841+
eft_C128, (void*)dpnp_multiply_c<std::complex<double>, std::complex<double>, int64_t>};
842+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_C128][eft_FLT] = {
843+
eft_C128, (void*)dpnp_multiply_c<std::complex<double>, std::complex<double>, float>};
844+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_C128][eft_DBL] = {
845+
eft_C128, (void*)dpnp_multiply_c<std::complex<double>, std::complex<double>, double>};
846+
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_C128][eft_C64] = {
847+
eft_C128, (void*)dpnp_multiply_c<std::complex<double>, std::complex<double>, std::complex<float>>};
820848
fmap[DPNPFuncName::DPNP_FN_MULTIPLY][eft_C128][eft_C128] = {
821849
eft_C128, (void*)dpnp_multiply_c<std::complex<double>, std::complex<double>, std::complex<double>>};
822850

tests/test_mathematical.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -54,32 +54,12 @@ def test_diff(array):
5454
numpy.testing.assert_allclose(expected, result)
5555

5656

57-
@pytest.mark.parametrize("data",
58-
[[[1 + 1j, -2j], [3 - 3j, 4j]]],
59-
ids=['[[1+1j, -2j], [3-3j, 4j]]'])
60-
def test_multiply_complex(data):
61-
np_a = numpy.array(data)
62-
dpnp_a = dpnp.array(data)
63-
64-
result = dpnp.multiply(dpnp_a, dpnp_a)
65-
expected = numpy.multiply(np_a, np_a)
66-
numpy.testing.assert_array_equal(result, expected)
67-
68-
result = dpnp.multiply(dpnp_a, 0.5j)
69-
expected = numpy.multiply(np_a, 0.5j)
70-
numpy.testing.assert_array_equal(result, expected)
71-
72-
result = dpnp.multiply(0.5j, dpnp_a)
73-
expected = numpy.multiply(0.5j, np_a)
74-
numpy.testing.assert_array_equal(result, expected)
75-
76-
7757
@pytest.mark.parametrize("dtype1",
78-
[numpy.bool_, numpy.float64, numpy.float32, numpy.int64, numpy.int32],
79-
ids=['numpy.bool_', 'numpy.float64', 'numpy.float32', 'numpy.int64', 'numpy.int32'])
58+
[numpy.bool_, numpy.float64, numpy.float32, numpy.int64, numpy.int32, numpy.complex64, numpy.complex128],
59+
ids=['numpy.bool_', 'numpy.float64', 'numpy.float32', 'numpy.int64', 'numpy.int32', 'numpy.complex64', 'numpy.complex128'])
8060
@pytest.mark.parametrize("dtype2",
81-
[numpy.bool_, numpy.float64, numpy.float32, numpy.int64, numpy.int32],
82-
ids=['numpy.bool_', 'numpy.float64', 'numpy.float32', 'numpy.int64', 'numpy.int32'])
61+
[numpy.bool_, numpy.float64, numpy.float32, numpy.int64, numpy.int32, numpy.complex64, numpy.complex128],
62+
ids=['numpy.bool_', 'numpy.float64', 'numpy.float32', 'numpy.int64', 'numpy.int32', 'numpy.complex64', 'numpy.complex128'])
8363
@pytest.mark.parametrize("data",
8464
[[[1, 2], [3, 4]]],
8565
ids=['[[1, 2], [3, 4]]'])

0 commit comments

Comments
 (0)