1
1
import pytest
2
2
import numpy as np
3
- from openblas_wrap import (
4
- # level 1
5
- dnrm2 , ddot , daxpy ,
6
- # level 3
7
- dgemm , dsyrk ,
8
- # lapack
9
- dgesv , # linalg.solve
10
- dgesdd , dgesdd_lwork , # linalg.svd
11
- dsyev , dsyev_lwork , # linalg.eigh
12
- )
3
+ import openblas_wrap as ow
4
+
5
+ dtype_map = {
6
+ 's' : np . float32 ,
7
+ 'd' : np . float64 ,
8
+ 'c' : np . complex64 ,
9
+ 'z' : np . complex128 ,
10
+ 'dz' : np . complex128 ,
11
+ }
12
+
13
13
14
14
# ### BLAS level 1 ###
15
15
16
16
# dnrm2
17
17
18
18
dnrm2_sizes = [100 , 1000 ]
19
19
20
- def run_dnrm2 (n , x , incx ):
21
- res = dnrm2 (x , n , incx = incx )
20
+ def run_dnrm2 (n , x , incx , func ):
21
+ res = func (x , n , incx = incx )
22
22
return res
23
23
24
24
25
+ @pytest .mark .parametrize ('variant' , ['d' , 'dz' ])
25
26
@pytest .mark .parametrize ('n' , dnrm2_sizes )
26
- def test_nrm2 (benchmark , n ):
27
+ def test_nrm2 (benchmark , n , variant ):
27
28
rndm = np .random .RandomState (1234 )
28
- x = np .array (rndm .uniform (size = (n ,)), dtype = float )
29
- result = benchmark (run_dnrm2 , n , x , 1 )
29
+ dtyp = dtype_map [variant ]
30
+
31
+ x = np .array (rndm .uniform (size = (n ,)), dtype = dtyp )
32
+ nrm2 = ow .get_func ('nrm2' , variant )
33
+ result = benchmark (run_dnrm2 , n , x , 1 , nrm2 )
30
34
31
35
32
36
# ddot
33
37
34
38
ddot_sizes = [100 , 1000 ]
35
39
36
- def run_ddot (x , y ,):
37
- res = ddot (x , y )
40
+ def run_ddot (x , y , func ):
41
+ res = func (x , y )
38
42
return res
39
43
40
44
41
45
@pytest .mark .parametrize ('n' , ddot_sizes )
42
46
def test_dot (benchmark , n ):
43
47
rndm = np .random .RandomState (1234 )
48
+
44
49
x = np .array (rndm .uniform (size = (n ,)), dtype = float )
45
50
y = np .array (rndm .uniform (size = (n ,)), dtype = float )
46
- result = benchmark (run_ddot , x , y )
51
+ dot = ow .get_func ('dot' , 'd' )
52
+ result = benchmark (run_ddot , x , y , dot )
47
53
48
54
49
55
# daxpy
50
56
51
57
daxpy_sizes = [100 , 1000 ]
52
58
53
- def run_daxpy (x , y ,):
54
- res = daxpy (x , y , a = 2.0 )
59
+ def run_daxpy (x , y , func ):
60
+ res = func (x , y , a = 2.0 )
55
61
return res
56
62
57
63
64
+ @pytest .mark .parametrize ('variant' , ['s' , 'd' , 'c' , 'z' ])
58
65
@pytest .mark .parametrize ('n' , daxpy_sizes )
59
- def test_daxpy (benchmark , n ):
66
+ def test_daxpy (benchmark , n , variant ):
60
67
rndm = np .random .RandomState (1234 )
61
- x = np .array (rndm .uniform (size = (n ,)), dtype = float )
62
- y = np .array (rndm .uniform (size = (n ,)), dtype = float )
63
- result = benchmark (run_daxpy , x , y )
68
+ dtyp = dtype_map [variant ]
69
+
70
+ x = np .array (rndm .uniform (size = (n ,)), dtype = dtyp )
71
+ y = np .array (rndm .uniform (size = (n ,)), dtype = dtyp )
72
+ axpy = ow .get_func ('axpy' , variant )
73
+ result = benchmark (run_daxpy , x , y , axpy )
74
+
75
+
76
+ # ### BLAS level 2 ###
77
+
78
+ gemv_sizes = [100 , 1000 ]
79
+
80
+ def run_gemv (a , x , y , func ):
81
+ res = func (1.0 , a , x , y = y , overwrite_y = True )
82
+ return res
83
+
84
+
85
+ @pytest .mark .parametrize ('variant' , ['s' , 'd' , 'c' , 'z' ])
86
+ @pytest .mark .parametrize ('n' , gemv_sizes )
87
+ def test_dgemv (benchmark , n , variant ):
88
+ rndm = np .random .RandomState (1234 )
89
+ dtyp = dtype_map [variant ]
90
+
91
+ x = np .array (rndm .uniform (size = (n ,)), dtype = dtyp )
92
+ y = np .empty (n , dtype = dtyp )
93
+
94
+ a = np .array (rndm .uniform (size = (n ,n )), dtype = dtyp )
95
+ x = np .array (rndm .uniform (size = (n ,)), dtype = dtyp )
96
+ y = np .zeros (n , dtype = dtyp )
97
+
98
+ gemv = ow .get_func ('gemv' , variant )
99
+ result = benchmark (run_gemv , a , x , y , gemv )
64
100
101
+ assert result is y
65
102
66
103
104
+ # dgbmv
105
+
106
+ dgbmv_sizes = [100 , 1000 ]
107
+
108
+ def run_gbmv (m , n , kl , ku , a , x , y , func ):
109
+ res = func (m , n , kl , ku , 1.0 , a , x , y = y , overwrite_y = True )
110
+ return res
111
+
112
+
113
+
114
+ @pytest .mark .parametrize ('variant' , ['s' , 'd' , 'c' , 'z' ])
115
+ @pytest .mark .parametrize ('n' , dgbmv_sizes )
116
+ @pytest .mark .parametrize ('kl' , [1 ])
117
+ def test_dgbmv (benchmark , n , kl , variant ):
118
+ rndm = np .random .RandomState (1234 )
119
+ dtyp = dtype_map [variant ]
120
+
121
+ x = np .array (rndm .uniform (size = (n ,)), dtype = dtyp )
122
+ y = np .empty (n , dtype = dtyp )
123
+
124
+ m = n
125
+
126
+ a = rndm .uniform (size = (2 * kl + 1 , n ))
127
+ a = np .array (a , dtype = dtyp , order = 'F' )
128
+
129
+ gbmv = ow .get_func ('gbmv' , variant )
130
+ result = benchmark (run_gbmv , m , n , kl , kl , a , x , y , gbmv )
131
+ assert result is y
132
+
67
133
68
134
# ### BLAS level 3 ###
69
135
70
136
# dgemm
71
137
72
138
gemm_sizes = [100 , 1000 ]
73
139
74
- def run_gemm (a , b , c ):
140
+ def run_gemm (a , b , c , func ):
75
141
alpha = 1.0
76
- res = dgemm (alpha , a , b , c = c , overwrite_c = True )
142
+ res = func (alpha , a , b , c = c , overwrite_c = True )
77
143
return res
78
144
79
145
146
+ @pytest .mark .parametrize ('variant' , ['s' , 'd' , 'c' , 'z' ])
80
147
@pytest .mark .parametrize ('n' , gemm_sizes )
81
- def test_gemm (benchmark , n ):
148
+ def test_gemm (benchmark , n , variant ):
82
149
rndm = np .random .RandomState (1234 )
83
- a = np .array (rndm .uniform (size = (n , n )), dtype = float , order = 'F' )
84
- b = np .array (rndm .uniform (size = (n , n )), dtype = float , order = 'F' )
85
- c = np .empty ((n , n ), dtype = float , order = 'F' )
86
- result = benchmark (run_gemm , a , b , c )
150
+ dtyp = dtype_map [variant ]
151
+ a = np .array (rndm .uniform (size = (n , n )), dtype = dtyp , order = 'F' )
152
+ b = np .array (rndm .uniform (size = (n , n )), dtype = dtyp , order = 'F' )
153
+ c = np .empty ((n , n ), dtype = dtyp , order = 'F' )
154
+ gemm = ow .get_func ('gemm' , variant )
155
+ result = benchmark (run_gemm , a , b , c , gemm )
87
156
assert result is c
88
157
89
158
@@ -92,17 +161,20 @@ def test_gemm(benchmark, n):
92
161
syrk_sizes = [100 , 1000 ]
93
162
94
163
95
- def run_syrk (a , c ):
96
- res = dsyrk (1.0 , a , c = c , overwrite_c = True )
164
+ def run_syrk (a , c , func ):
165
+ res = func (1.0 , a , c = c , overwrite_c = True )
97
166
return res
98
167
99
168
169
+ @pytest .mark .parametrize ('variant' , ['s' , 'd' , 'c' , 'z' ])
100
170
@pytest .mark .parametrize ('n' , syrk_sizes )
101
- def test_syrk (benchmark , n ):
171
+ def test_syrk (benchmark , n , variant ):
102
172
rndm = np .random .RandomState (1234 )
103
- a = np .array (rndm .uniform (size = (n , n )), dtype = float , order = 'F' )
104
- c = np .empty ((n , n ), dtype = float , order = 'F' )
105
- result = benchmark (run_syrk , a , c )
173
+ dtyp = dtype_map [variant ]
174
+ a = np .array (rndm .uniform (size = (n , n )), dtype = dtyp , order = 'F' )
175
+ c = np .empty ((n , n ), dtype = dtyp , order = 'F' )
176
+ syrk = ow .get_func ('syrk' , variant )
177
+ result = benchmark (run_syrk , a , c , syrk )
106
178
assert result is c
107
179
108
180
@@ -113,18 +185,22 @@ def test_syrk(benchmark, n):
113
185
gesv_sizes = [100 , 1000 ]
114
186
115
187
116
- def run_gesv (a , b ):
117
- res = dgesv (a , b , overwrite_a = True , overwrite_b = True )
188
+ def run_gesv (a , b , func ):
189
+ res = func (a , b , overwrite_a = True , overwrite_b = True )
118
190
return res
119
191
120
192
193
+ @pytest .mark .parametrize ('variant' , ['s' , 'd' , 'c' , 'z' ])
121
194
@pytest .mark .parametrize ('n' , gesv_sizes )
122
- def test_gesv (benchmark , n ):
195
+ def test_gesv (benchmark , n , variant ):
123
196
rndm = np .random .RandomState (1234 )
124
- a = (np .array (rndm .uniform (size = (n , n )), dtype = float , order = 'F' ) +
125
- np .eye (n , order = 'F' ))
126
- b = np .array (rndm .uniform (size = (n , 1 )), order = 'F' )
127
- lu , piv , x , info = benchmark (run_gesv , a , b )
197
+ dtyp = dtype_map [variant ]
198
+
199
+ a = (np .array (rndm .uniform (size = (n , n )), dtype = dtyp , order = 'F' ) +
200
+ np .eye (n , dtype = dtyp , order = 'F' ))
201
+ b = np .array (rndm .uniform (size = (n , 1 )), dtype = dtyp , order = 'F' )
202
+ gesv = ow .get_func ('gesv' , variant )
203
+ lu , piv , x , info = benchmark (run_gesv , a , b , gesv )
128
204
assert lu is a
129
205
assert x is b
130
206
assert info == 0
@@ -135,49 +211,63 @@ def test_gesv(benchmark, n):
135
211
gesdd_sizes = [(100 , 5 ), (1000 , 222 )]
136
212
137
213
138
- def run_gesdd (a , lwork ):
139
- res = dgesdd (a , lwork = lwork , full_matrices = False , overwrite_a = False )
214
+ def run_gesdd (a , lwork , func ):
215
+ res = func (a , lwork = lwork , full_matrices = False , overwrite_a = False )
140
216
return res
141
217
142
218
219
+ @pytest .mark .parametrize ('variant' , ['s' , 'd' ])
143
220
@pytest .mark .parametrize ('mn' , gesdd_sizes )
144
- def test_gesdd (benchmark , mn ):
221
+ def test_gesdd (benchmark , mn , variant ):
145
222
m , n = mn
146
223
rndm = np .random .RandomState (1234 )
147
- a = np .array (rndm .uniform (size = (m , n )), dtype = float , order = 'F' )
224
+ dtyp = dtype_map [variant ]
225
+
226
+ a = np .array (rndm .uniform (size = (m , n )), dtype = dtyp , order = 'F' )
148
227
149
- lwork , info = dgesdd_lwork (m , n )
228
+ gesdd_lwork = ow .get_func ('gesdd_lwork' , variant )
229
+
230
+ lwork , info = gesdd_lwork (m , n )
150
231
lwork = int (lwork )
151
232
assert info == 0
152
233
153
- u , s , vt , info = benchmark (run_gesdd , a , lwork )
234
+ gesdd = ow .get_func ('gesdd' , variant )
235
+ u , s , vt , info = benchmark (run_gesdd , a , lwork , gesdd )
154
236
155
237
assert info == 0
156
- np .testing .assert_allclose (u @ np .diag (s ) @ vt , a , atol = 1e-13 )
238
+
239
+ atol = {'s' : 1e-5 , 'd' : 1e-13 }
240
+
241
+ np .testing .assert_allclose (u @ np .diag (s ) @ vt , a , atol = atol [variant ])
157
242
158
243
159
244
# linalg.eigh
160
245
161
246
syev_sizes = [50 , 200 ]
162
247
163
248
164
- def run_syev (a , lwork ):
165
- res = dsyev (a , lwork = lwork , overwrite_a = True )
249
+ def run_syev (a , lwork , func ):
250
+ res = func (a , lwork = lwork , overwrite_a = True )
166
251
return res
167
252
168
253
254
+ @pytest .mark .parametrize ('variant' , ['s' , 'd' ])
169
255
@pytest .mark .parametrize ('n' , syev_sizes )
170
- def test_syev (benchmark , n ):
256
+ def test_syev (benchmark , n , variant ):
171
257
rndm = np .random .RandomState (1234 )
258
+ dtyp = dtype_map [variant ]
259
+
172
260
a = rndm .uniform (size = (n , n ))
173
- a = np .asarray (a + a .T , dtype = float , order = 'F' )
261
+ a = np .asarray (a + a .T , dtype = dtyp , order = 'F' )
174
262
a_ = a .copy ()
175
263
264
+ dsyev_lwork = ow .get_func ('syev_lwork' , variant )
176
265
lwork , info = dsyev_lwork (n )
177
266
lwork = int (lwork )
178
267
assert info == 0
179
268
180
- w , v , info = benchmark (run_syev , a , lwork )
269
+ syev = ow .get_func ('syev' , variant )
270
+ w , v , info = benchmark (run_syev , a , lwork , syev )
181
271
182
272
assert info == 0
183
273
assert a is v # overwrite_a=True
0 commit comments