+ /***************************************************************************
+ Copyright (c) 2014, The OpenBLAS Project
+ All rights reserved.
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are
+ met:
+ 1. Redistributions of source code must retain the above copyright
+    notice, this list of conditions and the following disclaimer.
+ 2. Redistributions in binary form must reproduce the above copyright
+    notice, this list of conditions and the following disclaimer in
+    the documentation and/or other materials provided with the
+    distribution.
+ 3. Neither the name of the OpenBLAS project nor the names of
+    its contributors may be used to endorse or promote products
+    derived from this software without specific prior written permission.
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+ LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *****************************************************************************/
+
+ /* need a new enough GCC (or clang) for AVX-512 support */
+ #if ((defined(__GNUC__) && __GNUC__ >= 6 && defined(__AVX512CD__)) || (defined(__clang__) && __clang_major__ >= 6))
+
+ #define HAVE_SGEMV_N_SKYLAKE_KERNEL 1
+ #include "common.h"
+ #include <immintrin.h>
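+
+ /* sgemv "n" inner kernel: y[0..m-1] += alpha * A * x, where A is an m-by-n
+    column-major panel with leading dimension lda and unit-stride x and y.
+    Rows are processed in 512-bit blocks of 128, 64, 32 and 16 floats, with a
+    masked tail covering the last (m & 15) rows. */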
+ static int sgemv_kernel_n_128(BLASLONG m, BLASLONG n, float alpha, float *a, BLASLONG lda, float *x, float *y)
+ {
+     __m512 matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7;
+     __m512 accum512_0, accum512_1, accum512_2, accum512_3, accum512_4, accum512_5, accum512_6, accum512_7;
+     __m512 xArray_0;
+     __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+     BLASLONG tag_m_128x = m & (~127);
+     BLASLONG tag_m_64x = m & (~63);
+     BLASLONG tag_m_32x = m & (~31);
+     BLASLONG tag_m_16x = m & (~15);
+
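+     /* main loop: 128 rows per iteration, accumulated in eight 512-bit registers */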
+     for (BLASLONG idx_m = 0; idx_m < tag_m_128x; idx_m += 128) {
+         accum512_0 = _mm512_setzero_ps();
+         accum512_1 = _mm512_setzero_ps();
+         accum512_2 = _mm512_setzero_ps();
+         accum512_3 = _mm512_setzero_ps();
+         accum512_4 = _mm512_setzero_ps();
+         accum512_5 = _mm512_setzero_ps();
+         accum512_6 = _mm512_setzero_ps();
+         accum512_7 = _mm512_setzero_ps();
+
+         for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
+             xArray_0 = _mm512_set1_ps(x[idx_n]);
+
+             matrixArray_0 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 0]);
+             matrixArray_1 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 16]);
+             matrixArray_2 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 32]);
+             matrixArray_3 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 48]);
+             matrixArray_4 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 64]);
+             matrixArray_5 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 80]);
+             matrixArray_6 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 96]);
+             matrixArray_7 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 112]);
+
+             accum512_0 = _mm512_fmadd_ps(matrixArray_0, xArray_0, accum512_0);
+             accum512_1 = _mm512_fmadd_ps(matrixArray_1, xArray_0, accum512_1);
+             accum512_2 = _mm512_fmadd_ps(matrixArray_2, xArray_0, accum512_2);
+             accum512_3 = _mm512_fmadd_ps(matrixArray_3, xArray_0, accum512_3);
+             accum512_4 = _mm512_fmadd_ps(matrixArray_4, xArray_0, accum512_4);
+             accum512_5 = _mm512_fmadd_ps(matrixArray_5, xArray_0, accum512_5);
+             accum512_6 = _mm512_fmadd_ps(matrixArray_6, xArray_0, accum512_6);
+             accum512_7 = _mm512_fmadd_ps(matrixArray_7, xArray_0, accum512_7);
+         }
+
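+         /* scale the eight accumulators by alpha and add them into y */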
+         _mm512_storeu_ps(&y[idx_m + 0], _mm512_fmadd_ps(accum512_0, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 0])));
+         _mm512_storeu_ps(&y[idx_m + 16], _mm512_fmadd_ps(accum512_1, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 16])));
+         _mm512_storeu_ps(&y[idx_m + 32], _mm512_fmadd_ps(accum512_2, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 32])));
+         _mm512_storeu_ps(&y[idx_m + 48], _mm512_fmadd_ps(accum512_3, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 48])));
+         _mm512_storeu_ps(&y[idx_m + 64], _mm512_fmadd_ps(accum512_4, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 64])));
+         _mm512_storeu_ps(&y[idx_m + 80], _mm512_fmadd_ps(accum512_5, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 80])));
+         _mm512_storeu_ps(&y[idx_m + 96], _mm512_fmadd_ps(accum512_6, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 96])));
+         _mm512_storeu_ps(&y[idx_m + 112], _mm512_fmadd_ps(accum512_7, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 112])));
+     }
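+
+     /* fewer than 128 rows remain: fall back to 64-, 32- and 16-row blocks,
+        then a masked tail for the last (m & 15) rows */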
+     if (tag_m_128x != m) {
+         for (BLASLONG idx_m = tag_m_128x; idx_m < tag_m_64x; idx_m += 64) {
+             accum512_0 = _mm512_setzero_ps();
+             accum512_1 = _mm512_setzero_ps();
+             accum512_2 = _mm512_setzero_ps();
+             accum512_3 = _mm512_setzero_ps();
+
+             for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
+                 xArray_0 = _mm512_set1_ps(x[idx_n]);
+
+                 matrixArray_0 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 0]);
+                 matrixArray_1 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 16]);
+                 matrixArray_2 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 32]);
+                 matrixArray_3 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 48]);
+
+                 accum512_0 = _mm512_fmadd_ps(matrixArray_0, xArray_0, accum512_0);
+                 accum512_1 = _mm512_fmadd_ps(matrixArray_1, xArray_0, accum512_1);
+                 accum512_2 = _mm512_fmadd_ps(matrixArray_2, xArray_0, accum512_2);
+                 accum512_3 = _mm512_fmadd_ps(matrixArray_3, xArray_0, accum512_3);
+             }
+
+             _mm512_storeu_ps(&y[idx_m + 0], _mm512_fmadd_ps(accum512_0, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 0])));
+             _mm512_storeu_ps(&y[idx_m + 16], _mm512_fmadd_ps(accum512_1, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 16])));
+             _mm512_storeu_ps(&y[idx_m + 32], _mm512_fmadd_ps(accum512_2, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 32])));
+             _mm512_storeu_ps(&y[idx_m + 48], _mm512_fmadd_ps(accum512_3, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 48])));
+         }
+
+         if (tag_m_64x != m) {
+             for (BLASLONG idx_m = tag_m_64x; idx_m < tag_m_32x; idx_m += 32) {
+                 accum512_0 = _mm512_setzero_ps();
+                 accum512_1 = _mm512_setzero_ps();
+
+                 for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
+                     xArray_0 = _mm512_set1_ps(x[idx_n]);
+
+                     matrixArray_0 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 0]);
+                     matrixArray_1 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 16]);
+
+                     accum512_0 = _mm512_fmadd_ps(matrixArray_0, xArray_0, accum512_0);
+                     accum512_1 = _mm512_fmadd_ps(matrixArray_1, xArray_0, accum512_1);
+                 }
+
+                 _mm512_storeu_ps(&y[idx_m + 0], _mm512_fmadd_ps(accum512_0, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 0])));
+                 _mm512_storeu_ps(&y[idx_m + 16], _mm512_fmadd_ps(accum512_1, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 16])));
+             }
+
+             if (tag_m_32x != m) {
+                 for (BLASLONG idx_m = tag_m_32x; idx_m < tag_m_16x; idx_m += 16) {
+                     accum512_0 = _mm512_setzero_ps();
+
+                     for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
+                         xArray_0 = _mm512_set1_ps(x[idx_n]);
+
+                         matrixArray_0 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 0]);
+
+                         accum512_0 = _mm512_fmadd_ps(matrixArray_0, xArray_0, accum512_0);
+                     }
+
+                     _mm512_storeu_ps(&y[idx_m + 0], _mm512_fmadd_ps(accum512_0, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 0])));
+                 }
+
+                 if (tag_m_16x != m) {
+                     accum512_0 = _mm512_setzero_ps();
+
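+                     /* tail: build a 16-bit mask whose low (m & 15) bits are set,
+                        so masked loads/stores touch only the remaining rows */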
+                     unsigned short tail_mask_value = (((unsigned int)0xffff) >> (16 - (m & 15)));
+                     __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
+
+                     for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
+                         xArray_0 = _mm512_set1_ps(x[idx_n]);
+                         matrixArray_0 = _mm512_maskz_loadu_ps(tail_mask, &a[idx_n * lda + tag_m_16x]);
+
+                         accum512_0 = _mm512_fmadd_ps(matrixArray_0, xArray_0, accum512_0);
+                     }
+
+                     _mm512_mask_storeu_ps(&y[tag_m_16x], tail_mask, _mm512_fmadd_ps(accum512_0, ALPHAVECTOR, _mm512_maskz_loadu_ps(tail_mask, &y[tag_m_16x])));
+
+                 }
+             }
+         }
+     }
+     return 0;
+ }
+
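+ /* Same computation with 256-bit vectors: rows are processed in blocks of 32, 16
+    and 8 floats using AVX-512VL masked loads/FMAs, with a masked tail covering
+    the last (m & 7) rows. */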
+ static int sgemv_kernel_n_64(BLASLONG m, BLASLONG n, float alpha, float *a, BLASLONG lda, float *x, float *y)
+ {
+     __m256 ma0, ma1, ma2, ma3, ma4, ma5, ma6, ma7;
+     __m256 as0, as1, as2, as3, as4, as5, as6, as7;
+     __m256 alphav = _mm256_set1_ps(alpha);
+     __m256 xv;
+     BLASLONG tag_m_32x = m & (~31);
+     BLASLONG tag_m_16x = m & (~15);
+     BLASLONG tag_m_8x = m & (~7);
+     __mmask8 one_mask = 0xff;
+
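+     /* main loop: 32 rows per iteration, accumulated in four 256-bit registers */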
+     for (BLASLONG idx_m = 0; idx_m < tag_m_32x; idx_m += 32) {
+         as0 = _mm256_setzero_ps();
+         as1 = _mm256_setzero_ps();
+         as2 = _mm256_setzero_ps();
+         as3 = _mm256_setzero_ps();
+
+         for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
+             xv = _mm256_set1_ps(x[idx_n]);
+             ma0 = _mm256_maskz_loadu_ps(one_mask, &a[idx_n * lda + idx_m + 0]);
+             ma1 = _mm256_maskz_loadu_ps(one_mask, &a[idx_n * lda + idx_m + 8]);
+             ma2 = _mm256_maskz_loadu_ps(one_mask, &a[idx_n * lda + idx_m + 16]);
+             ma3 = _mm256_maskz_loadu_ps(one_mask, &a[idx_n * lda + idx_m + 24]);
+
+             as0 = _mm256_maskz_fmadd_ps(one_mask, ma0, xv, as0);
+             as1 = _mm256_maskz_fmadd_ps(one_mask, ma1, xv, as1);
+             as2 = _mm256_maskz_fmadd_ps(one_mask, ma2, xv, as2);
+             as3 = _mm256_maskz_fmadd_ps(one_mask, ma3, xv, as3);
+         }
+         _mm256_mask_storeu_ps(&y[idx_m], one_mask, _mm256_maskz_fmadd_ps(one_mask, as0, alphav, _mm256_maskz_loadu_ps(one_mask, &y[idx_m])));
+         _mm256_mask_storeu_ps(&y[idx_m + 8], one_mask, _mm256_maskz_fmadd_ps(one_mask, as1, alphav, _mm256_maskz_loadu_ps(one_mask, &y[idx_m + 8])));
+         _mm256_mask_storeu_ps(&y[idx_m + 16], one_mask, _mm256_maskz_fmadd_ps(one_mask, as2, alphav, _mm256_maskz_loadu_ps(one_mask, &y[idx_m + 16])));
+         _mm256_mask_storeu_ps(&y[idx_m + 24], one_mask, _mm256_maskz_fmadd_ps(one_mask, as3, alphav, _mm256_maskz_loadu_ps(one_mask, &y[idx_m + 24])));
+     }
+
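+     /* fewer than 32 rows remain: 16- and 8-row blocks, then a masked tail for
+        the last (m & 7) rows */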
+     if (tag_m_32x != m) {
+         for (BLASLONG idx_m = tag_m_32x; idx_m < tag_m_16x; idx_m += 16) {
+             as4 = _mm256_setzero_ps();
+             as5 = _mm256_setzero_ps();
+
+             for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
+                 xv = _mm256_set1_ps(x[idx_n]);
+                 ma4 = _mm256_maskz_loadu_ps(one_mask, &a[idx_n * lda + idx_m + 0]);
+                 ma5 = _mm256_maskz_loadu_ps(one_mask, &a[idx_n * lda + idx_m + 8]);
+
+                 as4 = _mm256_maskz_fmadd_ps(one_mask, ma4, xv, as4);
+                 as5 = _mm256_maskz_fmadd_ps(one_mask, ma5, xv, as5);
+             }
+             _mm256_mask_storeu_ps(&y[idx_m], one_mask, _mm256_maskz_fmadd_ps(one_mask, as4, alphav, _mm256_maskz_loadu_ps(one_mask, &y[idx_m])));
+             _mm256_mask_storeu_ps(&y[idx_m + 8], one_mask, _mm256_maskz_fmadd_ps(one_mask, as5, alphav, _mm256_maskz_loadu_ps(one_mask, &y[idx_m + 8])));
+         }
+
+         if (tag_m_16x != m) {
+             for (BLASLONG idx_m = tag_m_16x; idx_m < tag_m_8x; idx_m += 8) {
+                 as6 = _mm256_setzero_ps();
+
+                 for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
+                     xv = _mm256_set1_ps(x[idx_n]);
+                     ma6 = _mm256_maskz_loadu_ps(one_mask, &a[idx_n * lda + idx_m]);
+                     as6 = _mm256_maskz_fmadd_ps(one_mask, ma6, xv, as6);
+                 }
+                 _mm256_mask_storeu_ps(&y[idx_m], one_mask, _mm256_maskz_fmadd_ps(one_mask, as6, alphav, _mm256_maskz_loadu_ps(one_mask, &y[idx_m])));
+             }
+
+             if (tag_m_8x != m) {
+                 as7 = _mm256_setzero_ps();
+
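+                 /* tail: build an 8-bit mask whose low (m & 7) bits are set */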
+                 unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8 - (m & 7)));
+                 __mmask8 tail_mask = *((__mmask8*) &tail_mask_uint);
+
+                 for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
+                     xv = _mm256_set1_ps(x[idx_n]);
+                     ma7 = _mm256_maskz_loadu_ps(tail_mask, &a[idx_n * lda + tag_m_8x]);
+
+                     as7 = _mm256_maskz_fmadd_ps(tail_mask, ma7, xv, as7);
+                 }
+
+                 _mm256_mask_storeu_ps(&y[tag_m_8x], tail_mask, _mm256_maskz_fmadd_ps(tail_mask, as7, alphav, _mm256_maskz_loadu_ps(tail_mask, &y[tag_m_8x])));
+
+             }
+         }
+     }
+
+     return 0;
+ }
+
+ #endif