@@ -127,16 +127,24 @@ def test_projection_complex(dtype):
127
127
q = get_queue_or_skip ()
128
128
skip_if_dtype_not_supported (dtype , q )
129
129
130
- X = [complex (1 , 2 ), complex (dpt .inf , - 1 ), complex (0 , - dpt .inf )]
131
- Y = [complex (1 , 2 ), complex (dpt .inf , - 0 ), complex (dpt .inf , - 0 )]
130
+ X = [
131
+ complex (1 , 2 ),
132
+ complex (dpt .inf , - 1 ),
133
+ complex (0 , - dpt .inf ),
134
+ complex (- dpt .inf , dpt .nan ),
135
+ ]
136
+ Y = [
137
+ complex (1 , 2 ),
138
+ complex (np .inf , - 0.0 ),
139
+ complex (np .inf , - 0.0 ),
140
+ complex (np .inf , 0.0 ),
141
+ ]
132
142
133
143
Xf = dpt .asarray (X , dtype = dtype , sycl_queue = q )
134
- Yf = dpt . asarray (Y , dtype = dtype , sycl_queue = q )
144
+ Yf = np . array (Y , dtype = dtype )
135
145
136
146
tol = 8 * dpt .finfo (Xf .dtype ).resolution
137
- assert_allclose (
138
- dpt .asnumpy (dpt .proj (Xf )), dpt .asnumpy (Yf ), atol = tol , rtol = tol
139
- )
147
+ assert_allclose (dpt .asnumpy (dpt .proj (Xf )), Yf , atol = tol , rtol = tol )
140
148
141
149
142
150
@pytest .mark .parametrize ("dtype" , _all_dtypes )
@@ -146,19 +154,17 @@ def test_projection(dtype):
146
154
147
155
Xf = dpt .asarray (1 , dtype = dtype , sycl_queue = q )
148
156
out_dtype = dpt .proj (Xf ).dtype
149
- Yf = dpt . asarray (complex (1 , 0 ), dtype = out_dtype , sycl_queue = q )
157
+ Yf = np . array (complex (1 , 0 ), dtype = out_dtype )
150
158
151
159
tol = 8 * dpt .finfo (Yf .dtype ).resolution
152
- assert_allclose (
153
- dpt .asnumpy (dpt .proj (Xf )), dpt .asnumpy (Yf ), atol = tol , rtol = tol
154
- )
160
+ assert_allclose (dpt .asnumpy (dpt .proj (Xf )), Yf , atol = tol , rtol = tol )
155
161
156
162
157
163
@pytest .mark .parametrize (
158
164
"np_call, dpt_call" ,
159
165
[(np .real , dpt .real ), (np .imag , dpt .imag ), (np .conj , dpt .conj )],
160
166
)
161
- @pytest .mark .parametrize ("dtype" , ["f " , "d " ])
167
+ @pytest .mark .parametrize ("dtype" , ["f4 " , "f8 " ])
162
168
@pytest .mark .parametrize ("stride" , [- 1 , 1 , 2 , 4 , 5 ])
163
169
def test_complex_strided (np_call , dpt_call , dtype , stride ):
164
170
q = get_queue_or_skip ()
@@ -176,7 +182,7 @@ def test_complex_strided(np_call, dpt_call, dtype, stride):
176
182
assert_allclose (y , dpt .asnumpy (z ), atol = tol , rtol = tol )
177
183
178
184
179
- @pytest .mark .parametrize ("dtype" , ["e " , "f " , "d " ])
185
+ @pytest .mark .parametrize ("dtype" , ["f2 " , "f4 " , "f8 " ])
180
186
def test_complex_special_cases (dtype ):
181
187
q = get_queue_or_skip ()
182
188
skip_if_dtype_not_supported (dtype , q )
0 commit comments