14
14
# See the License for the specific language governing permissions and
15
15
# limitations under the License.
16
16
17
- import itertools
18
- import os
19
- import re
20
-
21
17
import numpy as np
22
18
import pytest
23
19
from numpy .testing import assert_allclose
34
30
(np .arctan , dpt .atan ),
35
31
]
36
32
_all_funcs = _trig_funcs + _inv_trig_funcs
37
- _dpt_funcs = [t [1 ] for t in _all_funcs ]
38
33
39
34
40
35
@pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
@@ -43,17 +38,10 @@ def test_trig_out_type(np_call, dpt_call, dtype):
43
38
q = get_queue_or_skip ()
44
39
skip_if_dtype_not_supported (dtype , q )
45
40
46
- X = dpt .asarray (0 , dtype = dtype , sycl_queue = q )
47
- expected_dtype = np_call (np .array (0 , dtype = dtype )).dtype
48
- expected_dtype = _map_to_device_dtype (expected_dtype , q .sycl_device )
49
- assert dpt_call (X ).dtype == expected_dtype
50
-
51
- X = dpt .asarray (0 , dtype = dtype , sycl_queue = q )
41
+ x = dpt .asarray (0 , dtype = dtype , sycl_queue = q )
52
42
expected_dtype = np_call (np .array (0 , dtype = dtype )).dtype
53
43
expected_dtype = _map_to_device_dtype (expected_dtype , q .sycl_device )
54
- Y = dpt .empty_like (X , dtype = expected_dtype )
55
- dpt_call (X , out = Y )
56
- assert_allclose (dpt .asnumpy (dpt_call (X )), dpt .asnumpy (Y ))
44
+ assert dpt_call (x ).dtype == expected_dtype
57
45
58
46
59
47
@pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
@@ -127,78 +115,6 @@ def test_trig_complex_contig(np_call, dpt_call, dtype):
127
115
assert_allclose (dpt .asnumpy (Z ), expected , atol = tol , rtol = tol )
128
116
129
117
130
- @pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
131
- @pytest .mark .parametrize ("usm_type" , ["device" , "shared" , "host" ])
132
- def test_trig_usm_type (np_call , dpt_call , usm_type ):
133
- q = get_queue_or_skip ()
134
-
135
- arg_dt = np .dtype ("f4" )
136
- input_shape = (10 , 10 , 10 , 10 )
137
- X = dpt .empty (input_shape , dtype = arg_dt , usm_type = usm_type , sycl_queue = q )
138
- if np_call in _trig_funcs :
139
- X [..., 0 ::2 ] = np .pi / 6
140
- X [..., 1 ::2 ] = np .pi / 3
141
- if np_call == np .arctan :
142
- X [..., 0 ::2 ] = - 2.2
143
- X [..., 1 ::2 ] = 3.3
144
- else :
145
- X [..., 0 ::2 ] = - 0.3
146
- X [..., 1 ::2 ] = 0.7
147
-
148
- Y = dpt_call (X )
149
- assert Y .usm_type == X .usm_type
150
- assert Y .sycl_queue == X .sycl_queue
151
- assert Y .flags .c_contiguous
152
-
153
- expected_Y = np_call (dpt .asnumpy (X ))
154
- tol = 8 * dpt .finfo (Y .dtype ).resolution
155
- assert_allclose (dpt .asnumpy (Y ), expected_Y , atol = tol , rtol = tol )
156
-
157
-
158
- @pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
159
- @pytest .mark .parametrize ("dtype" , _all_dtypes )
160
- def test_trig_order (np_call , dpt_call , dtype ):
161
- q = get_queue_or_skip ()
162
- skip_if_dtype_not_supported (dtype , q )
163
-
164
- arg_dt = np .dtype (dtype )
165
- input_shape = (4 , 4 , 4 , 4 )
166
- X = dpt .empty (input_shape , dtype = arg_dt , sycl_queue = q )
167
- if np_call in _trig_funcs :
168
- X [..., 0 ::2 ] = np .pi / 6
169
- X [..., 1 ::2 ] = np .pi / 3
170
- if np_call == np .arctan :
171
- X [..., 0 ::2 ] = - 2.2
172
- X [..., 1 ::2 ] = 3.3
173
- else :
174
- X [..., 0 ::2 ] = - 0.3
175
- X [..., 1 ::2 ] = 0.7
176
-
177
- for perms in itertools .permutations (range (4 )):
178
- U = dpt .permute_dims (X [:, ::- 1 , ::- 1 , :], perms )
179
- expected_Y = np_call (dpt .asnumpy (U ))
180
- for ord in ["C" , "F" , "A" , "K" ]:
181
- Y = dpt_call (U , order = ord )
182
- tol = 8 * max (
183
- dpt .finfo (Y .dtype ).resolution ,
184
- np .finfo (expected_Y .dtype ).resolution ,
185
- )
186
- assert_allclose (dpt .asnumpy (Y ), expected_Y , atol = tol , rtol = tol )
187
-
188
-
189
- @pytest .mark .parametrize ("callable" , _dpt_funcs )
190
- @pytest .mark .parametrize ("dtype" , _all_dtypes )
191
- def test_trig_error_dtype (callable , dtype ):
192
- q = get_queue_or_skip ()
193
- skip_if_dtype_not_supported (dtype , q )
194
-
195
- x = dpt .zeros (5 , dtype = dtype )
196
- y = dpt .empty_like (x , dtype = "int16" )
197
- with pytest .raises (ValueError ) as excinfo :
198
- callable (x , out = y )
199
- assert re .match ("Output array of type.*is needed" , str (excinfo .value ))
200
-
201
-
202
118
@pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
203
119
@pytest .mark .parametrize ("dtype" , ["f2" , "f4" , "f8" ])
204
120
def test_trig_real_strided (np_call , dpt_call , dtype ):
@@ -298,47 +214,3 @@ def test_trig_real_special_cases(np_call, dpt_call, dtype):
298
214
tol = 8 * dpt .finfo (dtype ).resolution
299
215
Y = dpt_call (yf )
300
216
assert_allclose (dpt .asnumpy (Y ), Y_np , atol = tol , rtol = tol )
301
-
302
-
303
- @pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
304
- @pytest .mark .parametrize ("dtype" , ["c8" , "c16" ])
305
- def test_trig_complex_special_cases_conj_property (np_call , dpt_call , dtype ):
306
- q = get_queue_or_skip ()
307
- skip_if_dtype_not_supported (dtype , q )
308
-
309
- x = [np .nan , np .inf , - np .inf , + 0.0 , - 0.0 , + 1.0 , - 1.0 ]
310
- xc = [complex (* val ) for val in itertools .product (x , repeat = 2 )]
311
-
312
- Xc_np = np .array (xc , dtype = dtype )
313
- Xc = dpt .asarray (Xc_np , dtype = dtype , sycl_queue = q )
314
-
315
- tol = 50 * dpt .finfo (dtype ).resolution
316
- Y = dpt_call (Xc )
317
- Yc = dpt_call (dpt .conj (Xc ))
318
-
319
- dpt .allclose (Y , dpt .conj (Yc ), atol = tol , rtol = tol )
320
-
321
-
322
- @pytest .mark .skipif (
323
- os .name != "posix" , reason = "Known to fail on Windows due to bug in NumPy"
324
- )
325
- @pytest .mark .parametrize ("np_call, dpt_call" , _all_funcs )
326
- @pytest .mark .parametrize ("dtype" , ["c8" , "c16" ])
327
- def test_trig_complex_special_cases (np_call , dpt_call , dtype ):
328
-
329
- q = get_queue_or_skip ()
330
- skip_if_dtype_not_supported (dtype , q )
331
-
332
- x = [np .nan , np .inf , - np .inf , + 0.0 , - 0.0 , + 1.0 , - 1.0 ]
333
- xc = [complex (* val ) for val in itertools .product (x , repeat = 2 )]
334
-
335
- Xc_np = np .array (xc , dtype = dtype )
336
- Xc = dpt .asarray (Xc_np , dtype = dtype , sycl_queue = q )
337
-
338
- with np .errstate (all = "ignore" ):
339
- Ynp = np_call (Xc_np )
340
-
341
- tol = 50 * dpt .finfo (dtype ).resolution
342
- Y = dpt_call (Xc )
343
- assert_allclose (dpt .asnumpy (dpt .real (Y )), np .real (Ynp ), atol = tol , rtol = tol )
344
- assert_allclose (dpt .asnumpy (dpt .imag (Y )), np .imag (Ynp ), atol = tol , rtol = tol )
0 commit comments