Skip to content

Commit 7b1172b

Browse files
authored
expansion multiply (#334)
* expansion multiply
1 parent 20cb4b0 commit 7b1172b

File tree

3 files changed

+67
-9
lines changed

3 files changed

+67
-9
lines changed

dpnp/backend.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ cpdef dparray dpnp_divide(dparray array1, dparray array2)
233233
cpdef dparray dpnp_hypot(dparray array1, dparray array2)
234234
cpdef dparray dpnp_maximum(dparray array1, dparray array2)
235235
cpdef dparray dpnp_minimum(dparray array1, dparray array2)
236-
cpdef dparray dpnp_multiply(dparray array1, dparray array2)
236+
cpdef dparray dpnp_multiply(dparray array1, array2)
237237
cpdef dparray dpnp_negative(dparray array1)
238238
cpdef dparray dpnp_power(dparray array1, dparray array2)
239239
cpdef dparray dpnp_remainder(dparray array1, dparray array2)

dpnp/backend_mathematical.pyx

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,61 @@ cpdef tuple dpnp_modf(dparray x1):
159159
return result1, result2
160160

161161

162-
cpdef dparray dpnp_multiply(dparray x1, dparray x2):
163-
return call_fptr_2in_1out(DPNP_FN_MULTIPLY, x1, x2, x1.shape)
162+
cpdef dparray dpnp_multiply(dparray x1, x2):
163+
x2_is_scalar = dpnp.isscalar(x2)
164+
165+
x1_dtype_ = x1.dtype
166+
x2_dtype_ = type(x2) if x2_is_scalar else x2.dtype
167+
168+
types_map = {float: dpnp.float64, int: dpnp.int64}
169+
x1_dtype = types_map.get(x1_dtype_, x1_dtype_)
170+
x2_dtype = types_map.get(x2_dtype_, x2_dtype_)
171+
172+
if x1_dtype == dpnp.float64:
173+
if x2_dtype == dpnp.float64:
174+
res_type = dpnp.float64
175+
elif x2_dtype == dpnp.float32:
176+
res_type = dpnp.float64
177+
elif x2_dtype == dpnp.int64:
178+
res_type = dpnp.float64
179+
elif x2_dtype == dpnp.int32:
180+
res_type = dpnp.float64
181+
elif x1_dtype == dpnp.float32:
182+
if x2_dtype == dpnp.float64:
183+
res_type = dpnp.float32
184+
elif x2_dtype == dpnp.float32:
185+
res_type = dpnp.float32
186+
elif x2_dtype == dpnp.int64:
187+
res_type = dpnp.float32
188+
elif x2_dtype == dpnp.int32:
189+
res_type = dpnp.float32
190+
elif x1_dtype == dpnp.int64:
191+
if x2_dtype == dpnp.float64:
192+
res_type = dpnp.float64
193+
elif x2_dtype == dpnp.float32:
194+
res_type = dpnp.float32
195+
elif x2_dtype == dpnp.int64:
196+
res_type = dpnp.int64
197+
elif x2_dtype == dpnp.int32:
198+
res_type = dpnp.int64
199+
elif x1_dtype == dpnp.int32:
200+
if x2_dtype == dpnp.float64:
201+
res_type = dpnp.float64
202+
elif x2_dtype == dpnp.float32:
203+
res_type = dpnp.float32
204+
elif x2_dtype == dpnp.int64:
205+
res_type = dpnp.int32
206+
elif x2_dtype == dpnp.int32:
207+
res_type = dpnp.int32
208+
209+
cdef dparray result = dparray(x1.shape, dtype=res_type)
210+
211+
if x2_is_scalar:
212+
for i in range(result.size):
213+
result[i] = x1[i] * x2
214+
return result
215+
else:
216+
return call_fptr_2in_1out(DPNP_FN_MULTIPLY, x1, x2, x1.shape)
164217

165218

166219
cpdef dpnp_nanprod(dparray x1):

dpnp/dpnp_iface_mathematical.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -684,14 +684,19 @@ def multiply(x1, x2, **kwargs):
684684
is_x1_dparray = isinstance(x1, dparray)
685685
is_x2_dparray = isinstance(x2, dparray)
686686

687-
if (not use_origin_backend(x1) and is_x1_dparray and is_x2_dparray and not kwargs):
688-
if (x1.size != x2.size):
689-
checker_throw_value_error("multiply", "size", x1.size, x2.size)
687+
is_x1_scalar = dpnp.isscalar(x1)
688+
is_x2_scalar = dpnp.isscalar(x2)
690689

691-
if (x1.shape != x2.shape):
692-
checker_throw_value_error("multiply", "shape", x1.shape, x2.shape)
690+
if (not use_origin_backend(x1) and (is_x1_dparray or is_x1_scalar)) and \
691+
(not use_origin_backend(x2) and (is_x2_dparray or is_x2_scalar)) and \
692+
not (is_x1_scalar and is_x2_scalar) and not kwargs:
693693

694-
return dpnp_multiply(x1, x2)
694+
if is_x1_scalar:
695+
return dpnp_multiply(x2, x1)
696+
else:
697+
if is_x1_dparray and is_x2_dparray:
698+
if (x1.size == x2.size) and (x1.shape == x2.shape):
699+
return dpnp_multiply(x1, x2)
695700

696701
return call_origin(numpy.multiply, x1, x2, **kwargs)
697702

0 commit comments

Comments
 (0)