Skip to content

Commit 6cdca2c

Browse files
KsanaKozlovashssf
andauthored
add parameter out to exp, log and cos funcs (#775)
* add parameter out to exp, log and cos funcs * change call_fptr_1in_1out function * fix typo * remove keyword parameters * Update dpnp_algo.pyx Co-authored-by: Sergey Shalnov <shssf@users.noreply.github.com>
1 parent 7e93609 commit 6cdca2c

File tree

5 files changed

+190
-60
lines changed

5 files changed

+190
-60
lines changed

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -336,19 +336,19 @@ cpdef dparray dpnp_arcsinh(dpnp_descriptor array1)
336336
cpdef dparray dpnp_arctan(dpnp_descriptor array1)
337337
cpdef dparray dpnp_arctanh(dpnp_descriptor array1)
338338
cpdef dparray dpnp_cbrt(dpnp_descriptor array1)
339-
cpdef dparray dpnp_cos(dpnp_descriptor array1)
339+
cpdef dparray dpnp_cos(dpnp_descriptor array1, dparray out)
340340
cpdef dparray dpnp_cosh(dpnp_descriptor array1)
341341
cpdef dparray dpnp_degrees(dpnp_descriptor array1)
342-
cpdef dparray dpnp_exp(dpnp_descriptor array1)
342+
cpdef dparray dpnp_exp(dpnp_descriptor array1, dparray out)
343343
cpdef dparray dpnp_exp2(dpnp_descriptor array1)
344344
cpdef dparray dpnp_expm1(dpnp_descriptor array1)
345-
cpdef dparray dpnp_log(dpnp_descriptor array1)
345+
cpdef dparray dpnp_log(dpnp_descriptor array1, dparray out)
346346
cpdef dparray dpnp_log10(dpnp_descriptor array1)
347347
cpdef dparray dpnp_log1p(dpnp_descriptor array1)
348348
cpdef dparray dpnp_log2(dpnp_descriptor array1)
349349
cpdef dparray dpnp_radians(dpnp_descriptor array1)
350350
cpdef dparray dpnp_recip(dpnp_descriptor array1)
351-
cpdef dparray dpnp_sin(dpnp_descriptor array1, dparray out=*)
351+
cpdef dparray dpnp_sin(dpnp_descriptor array1, dparray out)
352352
cpdef dparray dpnp_sinh(dpnp_descriptor array1)
353353
cpdef dparray dpnp_sqrt(dpnp_descriptor array1)
354354
cpdef dparray dpnp_square(dpnp_descriptor array1)

dpnp/dpnp_algo/dpnp_algo.pyx

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,16 +254,27 @@ cdef dparray call_fptr_1out(DPNPFuncName fptr_name, dparray_shape_type result_sh
254254
return result
255255

256256

257-
cdef dparray call_fptr_1in_1out(DPNPFuncName fptr_name, utils.dpnp_descriptor x1, dparray_shape_type result_shape):
257+
cdef dparray call_fptr_1in_1out(DPNPFuncName fptr_name, utils.dpnp_descriptor x1, dparray_shape_type result_shape, dparray out=None, func_name=None):
258258

259259
""" Convert string type names (dparray.dtype) to C enum DPNPFuncType """
260260
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype)
261261

262262
""" get the FPTR data structure """
263263
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(fptr_name, param1_type, param1_type)
264+
265+
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
264266

265-
""" Create result array with type given by FPTR data """
266-
cdef dparray result = utils.create_output_array(result_shape, kernel_data.return_type, None)
267+
cdef dparray result
268+
269+
if out is None:
270+
""" Create result array with type given by FPTR data """
271+
result = utils.create_output_array(result_shape, kernel_data.return_type, None)
272+
else:
273+
if out.dtype != result_type:
274+
utils.checker_throw_value_error(func_name, 'out.dtype', out.dtype, result_type)
275+
if out.shape != result_shape:
276+
utils.checker_throw_value_error(func_name, 'out.shape', out.shape, result_shape)
277+
result = out
267278

268279
cdef fptr_1in_1out_t func = <fptr_1in_1out_t > kernel_data.ptr
269280
""" Call FPTR function """

dpnp/dpnp_algo/dpnp_algo_trigonometric.pyx

Lines changed: 8 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ cpdef dparray dpnp_cbrt(utils.dpnp_descriptor x1):
9292
return call_fptr_1in_1out(DPNP_FN_CBRT, x1, x1.shape)
9393

9494

95-
cpdef dparray dpnp_cos(utils.dpnp_descriptor x1):
96-
return call_fptr_1in_1out(DPNP_FN_COS, x1, x1.shape)
95+
cpdef dparray dpnp_cos(utils.dpnp_descriptor x1, dparray out):
96+
return call_fptr_1in_1out(DPNP_FN_COS, x1, x1.shape, out=out, func_name='cos')
9797

9898

9999
cpdef dparray dpnp_cosh(utils.dpnp_descriptor x1):
@@ -104,8 +104,8 @@ cpdef dparray dpnp_degrees(utils.dpnp_descriptor x1):
104104
return call_fptr_1in_1out(DPNP_FN_DEGREES, x1, x1.shape)
105105

106106

107-
cpdef dparray dpnp_exp(utils.dpnp_descriptor x1):
108-
return call_fptr_1in_1out(DPNP_FN_EXP, x1, x1.shape)
107+
cpdef dparray dpnp_exp(utils.dpnp_descriptor x1, dparray out):
108+
return call_fptr_1in_1out(DPNP_FN_EXP, x1, x1.shape, out=out, func_name='exp')
109109

110110

111111
cpdef dparray dpnp_exp2(utils.dpnp_descriptor x1):
@@ -116,8 +116,8 @@ cpdef dparray dpnp_expm1(utils.dpnp_descriptor x1):
116116
return call_fptr_1in_1out(DPNP_FN_EXPM1, x1, x1.shape)
117117

118118

119-
cpdef dparray dpnp_log(utils.dpnp_descriptor x1):
120-
return call_fptr_1in_1out(DPNP_FN_LOG, x1, x1.shape)
119+
cpdef dparray dpnp_log(utils.dpnp_descriptor x1, dparray out):
120+
return call_fptr_1in_1out(DPNP_FN_LOG, x1, x1.shape, out=out, func_name='log')
121121

122122

123123
cpdef dparray dpnp_log10(utils.dpnp_descriptor x1):
@@ -140,32 +140,8 @@ cpdef dparray dpnp_radians(utils.dpnp_descriptor x1):
140140
return call_fptr_1in_1out(DPNP_FN_RADIANS, x1, x1.shape)
141141

142142

143-
cpdef dparray dpnp_sin(utils.dpnp_descriptor x1, dparray out=None):
144-
145-
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype)
146-
147-
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_SIN, param1_type, param1_type)
148-
149-
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
150-
151-
shape_result = x1.shape
152-
153-
cdef dparray result
154-
155-
if out is not None:
156-
if out.dtype != result_type:
157-
utils.checker_throw_value_error('sin', 'out.dtype', out.dtype, result_type)
158-
if out.shape != shape_result:
159-
utils.checker_throw_value_error('sin', 'out.shape', out.shape, shape_result)
160-
result = out
161-
else:
162-
result = dparray(shape_result, dtype=result_type)
163-
164-
cdef fptr_1in_1out_t func = <fptr_1in_1out_t > kernel_data.ptr
165-
166-
func(x1.get_data(), result.get_data(), x1.size)
167-
168-
return result
143+
cpdef dparray dpnp_sin(utils.dpnp_descriptor x1, dparray out):
144+
return call_fptr_1in_1out(DPNP_FN_SIN, x1, x1.shape, out=out, func_name='sin')
169145

170146

171147
cpdef dparray dpnp_sinh(utils.dpnp_descriptor x1):

dpnp/dpnp_iface_trigonometric.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ def arctan2(x1, x2, dtype=None, out=None, where=True, **kwargs):
371371
return call_origin(numpy.arctan2, x1, x2, dtype=dtype, out=out, where=where, **kwargs)
372372

373373

374-
def cos(x1):
374+
def cos(x1, out=None, **kwargs):
375375
"""
376376
Trigonometric cosine, element-wise.
377377
@@ -395,9 +395,9 @@ def cos(x1):
395395

396396
x1_desc = dpnp.get_dpnp_descriptor(x1)
397397
if x1_desc:
398-
return dpnp_cos(x1_desc)
398+
return dpnp_cos(x1_desc, out)
399399

400-
return call_origin(numpy.cos, x1, **kwargs)
400+
return call_origin(numpy.cos, x1, out=out, **kwargs)
401401

402402

403403
def cosh(x1):
@@ -479,7 +479,7 @@ def degrees(x1):
479479
return call_origin(numpy.degrees, x1, **kwargs)
480480

481481

482-
def exp(x1):
482+
def exp(x1, out=None, **kwargs):
483483
"""
484484
Trigonometric exponent, element-wise.
485485
@@ -507,9 +507,9 @@ def exp(x1):
507507

508508
x1_desc = dpnp.get_dpnp_descriptor(x1)
509509
if x1_desc:
510-
return dpnp_exp(x1_desc)
510+
return dpnp_exp(x1_desc, out)
511511

512-
return call_origin(numpy.exp, x1)
512+
return call_origin(numpy.exp, x1, out=out, **kwargs)
513513

514514

515515
def exp2(x1):
@@ -630,7 +630,7 @@ def hypot(x1, x2, dtype=None, out=None, where=True, **kwargs):
630630
return call_origin(numpy.hypot, x1, x2, dtype=dtype, out=out, where=where, **kwargs)
631631

632632

633-
def log(x1):
633+
def log(x1, out=None, **kwargs):
634634
"""
635635
Trigonometric logarithm, element-wise.
636636
@@ -662,9 +662,9 @@ def log(x1):
662662

663663
x1_desc = dpnp.get_dpnp_descriptor(x1)
664664
if x1_desc:
665-
return dpnp_log(x1_desc)
665+
return dpnp_log(x1_desc, out)
666666

667-
return call_origin(numpy.log, x1)
667+
return call_origin(numpy.log, x1, out=out, **kwargs)
668668

669669

670670
def log10(x1):
@@ -876,7 +876,7 @@ def sin(x1, out=None, **kwargs):
876876

877877
x1_desc = dpnp.get_dpnp_descriptor(x1)
878878
if x1_desc:
879-
return dpnp_sin(x1_desc, out=out)
879+
return dpnp_sin(x1_desc, out)
880880

881881
return call_origin(numpy.sin, x1, out=out, **kwargs)
882882

tests/test_umath.py

Lines changed: 154 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,160 @@ def test_umaths(test_cases):
7373
numpy.testing.assert_allclose(result, expected, rtol=1e-6)
7474

7575

76-
def test_sin():
77-
array_data = numpy.arange(10)
78-
out = numpy.empty(10, dtype=numpy.float64)
76+
class TestSin:
7977

80-
# DPNP
81-
dp_array = dpnp.array(array_data, dtype=dpnp.float64)
82-
dp_out = dpnp.array(out, dtype=dpnp.float64)
83-
result = dpnp.sin(dp_array, out=dp_out)
78+
def test_sin_ordinary(self):
79+
array_data = numpy.arange(10)
80+
out = numpy.empty(10, dtype=numpy.float64)
8481

85-
# original
86-
np_array = numpy.array(array_data, dtype=numpy.float64)
87-
expected = numpy.sin(np_array, out=out)
82+
# DPNP
83+
dp_array = dpnp.array(array_data, dtype=dpnp.float64)
84+
dp_out = dpnp.array(out, dtype=dpnp.float64)
85+
result = dpnp.sin(dp_array, out=dp_out)
86+
87+
# original
88+
np_array = numpy.array(array_data, dtype=numpy.float64)
89+
expected = numpy.sin(np_array, out=out)
90+
91+
numpy.testing.assert_array_equal(expected, result)
92+
93+
@pytest.mark.parametrize("dtype",
94+
[numpy.float32, numpy.int64, numpy.int32],
95+
ids=['numpy.float32', 'numpy.int64', 'numpy.int32'])
96+
def test_invalid_dtype(self, dtype):
97+
98+
dp_array = dpnp.arange(10, dtype=dpnp.float64)
99+
dp_out = dpnp.empty(10, dtype=dtype)
100+
101+
with pytest.raises(ValueError):
102+
dpnp.sin(dp_array, out=dp_out)
103+
104+
@pytest.mark.parametrize("shape",
105+
[(0,), (15, ), (2,2)],
106+
ids=['(0,)', '(15, )', '(2,2)'])
107+
def test_invalid_shape(self, shape):
108+
109+
dp_array = dpnp.arange(10, dtype=dpnp.float64)
110+
dp_out = dpnp.empty(shape, dtype=dpnp.float64)
111+
112+
with pytest.raises(ValueError):
113+
dpnp.sin(dp_array, out=dp_out)
114+
115+
class TestCos:
116+
117+
def test_cos(self):
118+
array_data = numpy.arange(10)
119+
out = numpy.empty(10, dtype=numpy.float64)
120+
121+
# DPNP
122+
dp_array = dpnp.array(array_data, dtype=dpnp.float64)
123+
dp_out = dpnp.array(out, dtype=dpnp.float64)
124+
result = dpnp.cos(dp_array, out=dp_out)
125+
126+
# original
127+
np_array = numpy.array(array_data, dtype=numpy.float64)
128+
expected = numpy.cos(np_array, out=out)
129+
130+
numpy.testing.assert_array_equal(expected, result)
131+
132+
@pytest.mark.parametrize("dtype",
133+
[numpy.float32, numpy.int64, numpy.int32],
134+
ids=['numpy.float32', 'numpy.int64', 'numpy.int32'])
135+
def test_invalid_dtype(self, dtype):
136+
137+
dp_array = dpnp.arange(10, dtype=dpnp.float64)
138+
dp_out = dpnp.empty(10, dtype=dtype)
139+
140+
with pytest.raises(ValueError):
141+
dpnp.cos(dp_array, out=dp_out)
142+
143+
@pytest.mark.parametrize("shape",
144+
[(0,), (15, ), (2,2)],
145+
ids=['(0,)', '(15, )', '(2,2)'])
146+
def test_invalid_shape(self, shape):
147+
148+
dp_array = dpnp.arange(10, dtype=dpnp.float64)
149+
dp_out = dpnp.empty(shape, dtype=dpnp.float64)
150+
151+
with pytest.raises(ValueError):
152+
dpnp.cos(dp_array, out=dp_out)
153+
154+
155+
class TestsLog:
156+
157+
def test_log(self):
158+
array_data = numpy.arange(10)
159+
out = numpy.empty(10, dtype=numpy.float64)
160+
161+
# DPNP
162+
dp_array = dpnp.array(array_data, dtype=dpnp.float64)
163+
dp_out = dpnp.array(out, dtype=dpnp.float64)
164+
result = dpnp.log(dp_array, out=dp_out)
165+
166+
# original
167+
np_array = numpy.array(array_data, dtype=numpy.float64)
168+
expected = numpy.log(np_array, out=out)
169+
170+
numpy.testing.assert_array_equal(expected, result)
171+
172+
@pytest.mark.parametrize("dtype",
173+
[numpy.float32, numpy.int64, numpy.int32],
174+
ids=['numpy.float32', 'numpy.int64', 'numpy.int32'])
175+
def test_invalid_dtype(self, dtype):
176+
177+
dp_array = dpnp.arange(10, dtype=dpnp.float64)
178+
dp_out = dpnp.empty(10, dtype=dtype)
179+
180+
with pytest.raises(ValueError):
181+
dpnp.log(dp_array, out=dp_out)
182+
183+
@pytest.mark.parametrize("shape",
184+
[(0,), (15, ), (2,2)],
185+
ids=['(0,)', '(15, )', '(2,2)'])
186+
def test_invalid_shape(self, shape):
187+
188+
dp_array = dpnp.arange(10, dtype=dpnp.float64)
189+
dp_out = dpnp.empty(shape, dtype=dpnp.float64)
190+
191+
with pytest.raises(ValueError):
192+
dpnp.log(dp_array, out=dp_out)
193+
194+
195+
class TestExp:
196+
197+
def test_exp(self):
198+
array_data = numpy.arange(10)
199+
out = numpy.empty(10, dtype=numpy.float64)
200+
201+
# DPNP
202+
dp_array = dpnp.array(array_data, dtype=dpnp.float64)
203+
dp_out = dpnp.array(out, dtype=dpnp.float64)
204+
result = dpnp.exp(dp_array, out=dp_out)
205+
206+
# original
207+
np_array = numpy.array(array_data, dtype=numpy.float64)
208+
expected = numpy.exp(np_array, out=out)
209+
210+
numpy.testing.assert_array_equal(expected, result)
211+
212+
@pytest.mark.parametrize("dtype",
213+
[numpy.float32, numpy.int64, numpy.int32],
214+
ids=['numpy.float32', 'numpy.int64', 'numpy.int32'])
215+
def test_invalid_dtype(self, dtype):
216+
217+
dp_array = dpnp.arange(10, dtype=dpnp.float64)
218+
dp_out = dpnp.empty(10, dtype=dtype)
219+
220+
with pytest.raises(ValueError):
221+
dpnp.exp(dp_array, out=dp_out)
222+
223+
@pytest.mark.parametrize("shape",
224+
[(0,), (15, ), (2,2)],
225+
ids=['(0,)', '(15, )', '(2,2)'])
226+
def test_invalid_shape(self, shape):
227+
228+
dp_array = dpnp.arange(10, dtype=dpnp.float64)
229+
dp_out = dpnp.empty(shape, dtype=dpnp.float64)
88230

89-
numpy.testing.assert_array_equal(expected, result)
231+
with pytest.raises(ValueError):
232+
dpnp.exp(dp_array, out=dp_out)

0 commit comments

Comments
 (0)