Commit 02087a6

Merge pull request #3205 from intelmy/sgemv_n_opt
optimize on sgemv_n for small n
2 parents 03b4d79 + c0ca63e commit 02087a6
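
For context: sgemv_n is the no-transpose single-precision matrix-vector product, y := alpha*A*x + beta*y, with A stored column-major. Callers reach the kernels changed below through the standard CBLAS entry point; a minimal usage sketch (assumes an OpenBLAS build with the cblas.h header, linked with -lopenblas):

/* build: cc demo.c -lopenblas */
#include <stdio.h>
#include <cblas.h>

int main(void)
{
    /* 3x2 column-major A; small n is exactly what this commit targets */
    float A[6] = {1, 2, 3, 4, 5, 6};   /* columns (1,2,3) and (4,5,6) */
    float x[2] = {1, 1};
    float y[3] = {0, 0, 0};

    /* y = 1.0f*A*x + 0.0f*y -> (5, 7, 9) */
    cblas_sgemv(CblasColMajor, CblasNoTrans, 3, 2, 1.0f, A, 3, x, 1, 0.0f, y, 1);
    printf("%g %g %g\n", y[0], y[1], y[2]);
    return 0;
}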

File tree: 2 files changed (+308, -8 lines)

kernel/x86_64/sgemv_n_4.c

Lines changed: 50 additions & 8 deletions
@@ -35,8 +35,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #include "sgemv_n_microk_nehalem-4.c"
 #elif defined(SANDYBRIDGE)
 #include "sgemv_n_microk_sandy-4.c"
-#elif defined(HASWELL) || defined(ZEN) || defined (SKYLAKEX) || defined (COOPERLAKE)
+#elif defined(HASWELL) || defined(ZEN)
 #include "sgemv_n_microk_haswell-4.c"
+#elif defined (SKYLAKEX) || defined (COOPERLAKE)
+#include "sgemv_n_microk_haswell-4.c"
+#include "sgemv_n_microk_skylakex-8.c"
 #endif
 
 #if defined(STEAMROLLER) || defined(EXCAVATOR)
@@ -291,6 +294,41 @@ static void add_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest)
 
 int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLONG lda, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT *buffer)
 {
+    if ( m < 1 || n < 1) return(0);
+
+#ifdef HAVE_SGEMV_N_SKYLAKE_KERNEL
+    if (m <= 16384 && n <= 48 && !(n == 4))
+    {
+        FLOAT *xbuffer_align = x;
+        FLOAT *ybuffer_align = y;
+
+        FLOAT *xbuffer = NULL;
+        FLOAT *ybuffer = NULL;
+
+        if (inc_x != 1) {
+            xbuffer_align = buffer;
+            for(BLASLONG i=0; i<n; i++) {
+                xbuffer_align[i] = x[i*inc_x];
+            }
+        }
+
+        if (inc_y != 1) {
+            ybuffer_align = buffer + n;
+            for(BLASLONG i=0; i<m; i++) {
+                ybuffer_align[i] = y[i*inc_y];
+            }
+        }
+        sgemv_kernel_n_128(m, n, alpha, a, lda, xbuffer_align, ybuffer_align);
+
+        if(inc_y != 1) {
+            for(BLASLONG i=0; i<m; i++) {
+                y[i*inc_y] = ybuffer_align[i];
+            }
+        }
+        return(0);
+    }
+#endif
     BLASLONG i;
     FLOAT *a_ptr;
     FLOAT *x_ptr;
@@ -305,9 +343,6 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
     BLASLONG lda8 = lda << 3;
     FLOAT xbuffer[8],*ybuffer;
 
-    if ( m < 1 ) return(0);
-    if ( n < 1 ) return(0);
-
     ybuffer = buffer;
 
     if ( inc_x == 1 )
@@ -322,10 +357,9 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
 
     }
 
-    m3 = m & 3 ;
-    m1 = m & -4 ;
-    m2 = (m & (NBMAX-1)) - m3 ;
-
+    m3 = m & 3 ;
+    m1 = m & -4 ;
+    m2 = (m & (NBMAX-1)) - m3 ;
 
     y_ptr = y;
 
@@ -383,15 +417,23 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
 
     if ( n2 & 2 )
     {
+#ifdef HAVE_SGEMV_N_SKYLAKE_KERNEL
+        sgemv_kernel_n_64(NB, 2, alpha, a_ptr, lda, x_ptr, ybuffer);
+#else
         sgemv_kernel_4x2(NB,ap,x_ptr,ybuffer,&alpha);
+#endif
         a_ptr += lda*2;
         x_ptr += 2;
     }
 
 
     if ( n2 & 1 )
    {
+#ifdef HAVE_SGEMV_N_SKYLAKE_KERNEL
+        sgemv_kernel_n_64(NB, 1, alpha, a_ptr, lda, x_ptr, ybuffer);
+#else
         sgemv_kernel_4x1(NB,a_ptr,x_ptr,ybuffer,&alpha);
+#endif
         /* a_ptr += lda;
            x_ptr += 1a; */
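
Every path above computes the same untransposed product, y += alpha*A*x, on a column-major A with leading dimension lda; the kernels only change how the loop over m is vectorized. A scalar reference sketch (hypothetical helper name, unit strides assumed):

#include <stddef.h>

/* Scalar reference for sgemv "N": y[i] += alpha * sum_j a[i + j*lda] * x[j]. */
static void sgemv_n_ref(size_t m, size_t n, float alpha,
                        const float *a, size_t lda,
                        const float *x, float *y)
{
    for (size_t j = 0; j < n; j++) {       /* one column of A per iteration */
        float xj = alpha * x[j];           /* fold alpha into the x element */
        for (size_t i = 0; i < m; i++)
            y[i] += a[j * lda + i] * xj;   /* axpy column j into y          */
    }
}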

kernel/x86_64/sgemv_n_microk_skylakex-8.c (new file)

Lines changed: 258 additions & 0 deletions
@@ -0,0 +1,258 @@
+/***************************************************************************
+Copyright (c) 2014, The OpenBLAS Project
+All rights reserved.
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are
+met:
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in
+the documentation and/or other materials provided with the
+distribution.
+3. Neither the name of the OpenBLAS project nor the names of
+its contributors may be used to endorse or promote products
+derived from this software without specific prior written permission.
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*****************************************************************************/
+
+/* need a new enough GCC for avx512 support */
+#if (( defined(__GNUC__) && __GNUC__ >= 6 && defined(__AVX512CD__)) || (defined(__clang__) && __clang_major__ >= 6))
+
+#define HAVE_SGEMV_N_SKYLAKE_KERNEL 1
+#include "common.h"
+#include <immintrin.h>
+static int sgemv_kernel_n_128(BLASLONG m, BLASLONG n, float alpha, float *a, BLASLONG lda, float *x, float *y)
+{
+    __m512 matrixArray_0, matrixArray_1, matrixArray_2, matrixArray_3, matrixArray_4, matrixArray_5, matrixArray_6, matrixArray_7;
+    __m512 accum512_0, accum512_1, accum512_2, accum512_3, accum512_4, accum512_5, accum512_6, accum512_7;
+    __m512 xArray_0;
+    __m512 ALPHAVECTOR = _mm512_set1_ps(alpha);
+    BLASLONG tag_m_128x = m & (~127);
+    BLASLONG tag_m_64x = m & (~63);
+    BLASLONG tag_m_32x = m & (~31);
+    BLASLONG tag_m_16x = m & (~15);
+
+    for (BLASLONG idx_m = 0; idx_m < tag_m_128x; idx_m+=128) {
+        accum512_0 = _mm512_setzero_ps();
+        accum512_1 = _mm512_setzero_ps();
+        accum512_2 = _mm512_setzero_ps();
+        accum512_3 = _mm512_setzero_ps();
+        accum512_4 = _mm512_setzero_ps();
+        accum512_5 = _mm512_setzero_ps();
+        accum512_6 = _mm512_setzero_ps();
+        accum512_7 = _mm512_setzero_ps();
+
+        for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
+            xArray_0 = _mm512_set1_ps(x[idx_n]);
+
+            matrixArray_0 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 0]);
+            matrixArray_1 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 16]);
+            matrixArray_2 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 32]);
+            matrixArray_3 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 48]);
+            matrixArray_4 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 64]);
+            matrixArray_5 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 80]);
+            matrixArray_6 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 96]);
+            matrixArray_7 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 112]);
+
+            accum512_0 = _mm512_fmadd_ps(matrixArray_0, xArray_0, accum512_0);
+            accum512_1 = _mm512_fmadd_ps(matrixArray_1, xArray_0, accum512_1);
+            accum512_2 = _mm512_fmadd_ps(matrixArray_2, xArray_0, accum512_2);
+            accum512_3 = _mm512_fmadd_ps(matrixArray_3, xArray_0, accum512_3);
+            accum512_4 = _mm512_fmadd_ps(matrixArray_4, xArray_0, accum512_4);
+            accum512_5 = _mm512_fmadd_ps(matrixArray_5, xArray_0, accum512_5);
+            accum512_6 = _mm512_fmadd_ps(matrixArray_6, xArray_0, accum512_6);
+            accum512_7 = _mm512_fmadd_ps(matrixArray_7, xArray_0, accum512_7);
+        }
+
+        _mm512_storeu_ps(&y[idx_m + 0], _mm512_fmadd_ps(accum512_0, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 0])));
+        _mm512_storeu_ps(&y[idx_m + 16], _mm512_fmadd_ps(accum512_1, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 16])));
+        _mm512_storeu_ps(&y[idx_m + 32], _mm512_fmadd_ps(accum512_2, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 32])));
+        _mm512_storeu_ps(&y[idx_m + 48], _mm512_fmadd_ps(accum512_3, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 48])));
+        _mm512_storeu_ps(&y[idx_m + 64], _mm512_fmadd_ps(accum512_4, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 64])));
+        _mm512_storeu_ps(&y[idx_m + 80], _mm512_fmadd_ps(accum512_5, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 80])));
+        _mm512_storeu_ps(&y[idx_m + 96], _mm512_fmadd_ps(accum512_6, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 96])));
+        _mm512_storeu_ps(&y[idx_m + 112], _mm512_fmadd_ps(accum512_7, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 112])));
+    }
+    if (tag_m_128x != m) {
+        for (BLASLONG idx_m = tag_m_128x; idx_m < tag_m_64x; idx_m+=64) {
+            accum512_0 = _mm512_setzero_ps();
+            accum512_1 = _mm512_setzero_ps();
+            accum512_2 = _mm512_setzero_ps();
+            accum512_3 = _mm512_setzero_ps();
+
+            for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
+                xArray_0 = _mm512_set1_ps(x[idx_n]);
+
+                matrixArray_0 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 0]);
+                matrixArray_1 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 16]);
+                matrixArray_2 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 32]);
+                matrixArray_3 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 48]);
+
+                accum512_0 = _mm512_fmadd_ps(matrixArray_0, xArray_0, accum512_0);
+                accum512_1 = _mm512_fmadd_ps(matrixArray_1, xArray_0, accum512_1);
+                accum512_2 = _mm512_fmadd_ps(matrixArray_2, xArray_0, accum512_2);
+                accum512_3 = _mm512_fmadd_ps(matrixArray_3, xArray_0, accum512_3);
+            }
+
+            _mm512_storeu_ps(&y[idx_m + 0], _mm512_fmadd_ps(accum512_0, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 0])));
+            _mm512_storeu_ps(&y[idx_m + 16], _mm512_fmadd_ps(accum512_1, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 16])));
+            _mm512_storeu_ps(&y[idx_m + 32], _mm512_fmadd_ps(accum512_2, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 32])));
+            _mm512_storeu_ps(&y[idx_m + 48], _mm512_fmadd_ps(accum512_3, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 48])));
+        }
+
+        if(tag_m_64x != m) {
+            for (BLASLONG idx_m = tag_m_64x; idx_m < tag_m_32x; idx_m+=32) {
+                accum512_0 = _mm512_setzero_ps();
+                accum512_1 = _mm512_setzero_ps();
+
+                for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
+                    xArray_0 = _mm512_set1_ps(x[idx_n]);
+
+                    matrixArray_0 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 0]);
+                    matrixArray_1 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 16]);
+
+                    accum512_0 = _mm512_fmadd_ps(matrixArray_0, xArray_0, accum512_0);
+                    accum512_1 = _mm512_fmadd_ps(matrixArray_1, xArray_0, accum512_1);
+                }
+
+                _mm512_storeu_ps(&y[idx_m + 0], _mm512_fmadd_ps(accum512_0, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 0])));
+                _mm512_storeu_ps(&y[idx_m + 16], _mm512_fmadd_ps(accum512_1, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 16])));
+            }
+
+            if(tag_m_32x != m) {
+
+                for (BLASLONG idx_m = tag_m_32x; idx_m < tag_m_16x; idx_m+=16) {
+                    accum512_0 = _mm512_setzero_ps();
+
+                    for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
+                        xArray_0 = _mm512_set1_ps(x[idx_n]);
+
+                        matrixArray_0 = _mm512_loadu_ps(&a[idx_n * lda + idx_m + 0]);
+
+                        accum512_0 = _mm512_fmadd_ps(matrixArray_0, xArray_0, accum512_0);
+                    }
+
+                    _mm512_storeu_ps(&y[idx_m + 0], _mm512_fmadd_ps(accum512_0, ALPHAVECTOR, _mm512_loadu_ps(&y[idx_m + 0])));
+                }
+
+                if (tag_m_16x != m) {
+                    accum512_0 = _mm512_setzero_ps();
+
+                    unsigned short tail_mask_value = (((unsigned int)0xffff) >> (16-(m&15)));
+                    __mmask16 tail_mask = *((__mmask16*) &tail_mask_value);
+
+                    for(BLASLONG idx_n = 0; idx_n < n; idx_n++) {
+                        xArray_0 = _mm512_set1_ps(x[idx_n]);
+                        matrixArray_0 = _mm512_maskz_loadu_ps(tail_mask, &a[idx_n * lda + tag_m_16x]);
+
+                        accum512_0 = _mm512_fmadd_ps(matrixArray_0, xArray_0, accum512_0);
+                    }
+
+                    _mm512_mask_storeu_ps(&y[tag_m_16x], tail_mask, _mm512_fmadd_ps(accum512_0, ALPHAVECTOR, _mm512_maskz_loadu_ps(tail_mask, &y[tag_m_16x])));
+
+                }
+            }
+        }
+    }
+    return 0;
+}
+
+static int sgemv_kernel_n_64(BLASLONG m, BLASLONG n, float alpha, float *a, BLASLONG lda, float *x, float *y)
+{
+    __m256 ma0, ma1, ma2, ma3, ma4, ma5, ma6, ma7;
+    __m256 as0, as1, as2, as3, as4, as5, as6, as7;
+    __m256 alphav = _mm256_set1_ps(alpha);
+    __m256 xv;
+    BLASLONG tag_m_32x = m & (~31);
+    BLASLONG tag_m_16x = m & (~15);
+    BLASLONG tag_m_8x = m & (~7);
+    __mmask8 one_mask = 0xff;
+
+    for (BLASLONG idx_m = 0; idx_m < tag_m_32x; idx_m+=32) {
+        as0 = _mm256_setzero_ps();
+        as1 = _mm256_setzero_ps();
+        as2 = _mm256_setzero_ps();
+        as3 = _mm256_setzero_ps();
+
+        for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
+            xv = _mm256_set1_ps(x[idx_n]);
+            ma0 = _mm256_maskz_loadu_ps(one_mask, &a[idx_n * lda + idx_m +0]);
+            ma1 = _mm256_maskz_loadu_ps(one_mask, &a[idx_n * lda + idx_m +8]);
+            ma2 = _mm256_maskz_loadu_ps(one_mask, &a[idx_n * lda + idx_m +16]);
+            ma3 = _mm256_maskz_loadu_ps(one_mask, &a[idx_n * lda + idx_m +24]);
+
+            as0 = _mm256_maskz_fmadd_ps(one_mask, ma0, xv, as0);
+            as1 = _mm256_maskz_fmadd_ps(one_mask, ma1, xv, as1);
+            as2 = _mm256_maskz_fmadd_ps(one_mask, ma2, xv, as2);
+            as3 = _mm256_maskz_fmadd_ps(one_mask, ma3, xv, as3);
+        }
+        _mm256_mask_storeu_ps(&y[idx_m], one_mask, _mm256_maskz_fmadd_ps(one_mask, as0, alphav, _mm256_maskz_loadu_ps(one_mask, &y[idx_m])));
+        _mm256_mask_storeu_ps(&y[idx_m + 8], one_mask, _mm256_maskz_fmadd_ps(one_mask, as1, alphav, _mm256_maskz_loadu_ps(one_mask, &y[idx_m + 8])));
+        _mm256_mask_storeu_ps(&y[idx_m + 16], one_mask, _mm256_maskz_fmadd_ps(one_mask, as2, alphav, _mm256_maskz_loadu_ps(one_mask, &y[idx_m + 16])));
+        _mm256_mask_storeu_ps(&y[idx_m + 24], one_mask, _mm256_maskz_fmadd_ps(one_mask, as3, alphav, _mm256_maskz_loadu_ps(one_mask, &y[idx_m + 24])));
+
+    }
+
+    if (tag_m_32x != m ) {
+        for (BLASLONG idx_m = tag_m_32x; idx_m < tag_m_16x; idx_m+=16) {
+            as4 = _mm256_setzero_ps();
+            as5 = _mm256_setzero_ps();
+
+            for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
+                xv = _mm256_set1_ps(x[idx_n]);
+                ma4 = _mm256_maskz_loadu_ps(one_mask, &a[idx_n * lda + idx_m +0]);
+                ma5 = _mm256_maskz_loadu_ps(one_mask, &a[idx_n * lda + idx_m +8]);
+
+                as4 = _mm256_maskz_fmadd_ps(one_mask, ma4, xv, as4);
+                as5 = _mm256_maskz_fmadd_ps(one_mask, ma5, xv, as5);
+            }
+            _mm256_mask_storeu_ps(&y[idx_m], one_mask, _mm256_maskz_fmadd_ps(one_mask, as4, alphav, _mm256_maskz_loadu_ps(one_mask, &y[idx_m])));
+            _mm256_mask_storeu_ps(&y[idx_m + 8], one_mask, _mm256_maskz_fmadd_ps(one_mask, as5, alphav, _mm256_maskz_loadu_ps(one_mask, &y[idx_m + 8])));
+        }
+
+        if (tag_m_16x != m ) {
+            for (BLASLONG idx_m = tag_m_16x; idx_m < tag_m_8x; idx_m+=8) {
+                as6 = _mm256_setzero_ps();
+
+                for (BLASLONG idx_n = 0; idx_n < n; idx_n++) {
+                    xv = _mm256_set1_ps(x[idx_n]);
+                    ma6 = _mm256_maskz_loadu_ps(one_mask, &a[idx_n * lda + idx_m]);
+                    as6 = _mm256_maskz_fmadd_ps(one_mask, ma6, xv, as6);
+                }
+                _mm256_mask_storeu_ps(&y[idx_m], one_mask, _mm256_maskz_fmadd_ps(one_mask, as6, alphav, _mm256_maskz_loadu_ps(one_mask, &y[idx_m])));
+            }
+
+            if (tag_m_8x != m) {
+                as7 = _mm256_setzero_ps();
+
+                unsigned char tail_mask_uint = (((unsigned char)0xff) >> (8-(m&7)));
+                __mmask8 tail_mask = *((__mmask8*) &tail_mask_uint);
+
+                for(BLASLONG idx_n = 0; idx_n < n; idx_n++) {
+                    xv = _mm256_set1_ps(x[idx_n]);
+                    ma7 = _mm256_maskz_loadu_ps(tail_mask, &a[idx_n * lda + tag_m_8x]);
+
+                    as7 = _mm256_maskz_fmadd_ps(tail_mask, ma7, xv, as7);
+                }
+
+                _mm256_mask_storeu_ps(&y[tag_m_8x], tail_mask, _mm256_maskz_fmadd_ps(tail_mask, as7, alphav, _mm256_maskz_loadu_ps(tail_mask, &y[tag_m_8x])));
+
+            }
+        }
+    }
+
+    return 0;
+}
+
+
+#endif
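
Both kernels retire the final m % 16 (zmm) or m % 8 (ymm) rows with masked loads and stores instead of a scalar cleanup loop; the mask keeps one bit per remaining lane. A standalone sketch of that mask arithmetic (plain C, no intrinsics; note the kernels only build the mask when the remainder is nonzero, which matters because shifting a 16-bit value by 16 would be undefined):

#include <stdio.h>

/* Print the tail mask built for the last (m & 15) rows: one set bit
   per surviving float lane of a 16-lane zmm register. */
int main(void)
{
    for (unsigned m = 17; m <= 20; m++) {
        unsigned rem = m & 15;   /* rows left after the full 16-wide blocks */
        unsigned short mask = (unsigned short)(0xffffu >> (16 - rem));
        printf("m=%u rem=%u mask=0x%04x\n", m, rem, mask);
    }
    return 0;   /* e.g. m=17 -> rem=1 -> mask=0x0001 (one active lane) */
}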
