Skip to content

Commit ff16329

Browse files
authored
Merge pull request #2972 from xiegengxin/rot-intrinsic
Improve the performance of rot by using AVX512 and AVX2 intrinsic
2 parents 433637c + 725ffbf commit ff16329

File tree

9 files changed

+648
-5
lines changed

9 files changed

+648
-5
lines changed

driver/others/blas_l1_thread.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha
8080
break;
8181
}
8282

83-
mode |= BLAS_LEGACY;
83+
if(!(mode & BLAS_PTHREAD)) mode |= BLAS_LEGACY;
8484

8585
for (i = 0; i < nthreads; i++) blas_queue_init(&queue[i]);
8686

driver/others/blas_server_win32.c

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -476,12 +476,15 @@ int exec_blas(BLASLONG num, blas_queue_t *queue){
476476

477477
routine = queue -> routine;
478478

479-
if (!(queue -> mode & BLAS_LEGACY)) {
479+
if (queue -> mode & BLAS_LEGACY) {
480+
legacy_exec(routine, queue -> mode, queue -> args, queue -> sb);
481+
} else
482+
if (queue -> mode & BLAS_PTHREAD) {
483+
void (*pthreadcompat)(void *) = queue -> routine;
484+
(pthreadcompat)(queue -> args);
485+
} else
480486
(routine)(queue -> args, queue -> range_m, queue -> range_n,
481487
queue -> sa, queue -> sb, 0);
482-
} else {
483-
legacy_exec(routine, queue -> mode, queue -> args, queue -> sb);
484-
}
485488

486489
if ((num > 1) && queue -> next) exec_blas_async_wait(num - 1, queue -> next);
487490

kernel/x86_64/KERNEL.HASWELL

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,6 @@ ZGEMM3MKERNEL = zgemm3m_kernel_4x4_haswell.c
102102

103103
SASUMKERNEL = sasum.c
104104
DASUMKERNEL = dasum.c
105+
106+
SROTKERNEL = srot.c
107+
DROTKERNEL = drot.c

kernel/x86_64/drot.c

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
#include "common.h"
2+
3+
#if defined(SKYLAKEX)
4+
#include "drot_microk_skylakex-2.c"
5+
#elif defined(HASWELL)
6+
#include "drot_microk_haswell-2.c"
7+
#endif
8+
9+
#ifndef HAVE_DROT_KERNEL
10+
11+
/*
 * Portable fallback kernel: applies a plane rotation to two contiguous
 * (unit-stride) vectors in place.
 *
 *   x[i] <- c * x[i] + s * y[i]
 *   y[i] <- c * y[i] - s * x[i]      for i in [0, n)
 *
 * n    number of elements to rotate
 * x, y vectors, both read and overwritten
 * c, s rotation coefficients
 *
 * Only compiled when no architecture-specific micro-kernel has defined
 * HAVE_DROT_KERNEL.
 */
static void drot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
{
    BLASLONG i = 0;

    /* The main loop is unrolled by 4, so its bound is the largest multiple
     * of 4 not exceeding n.  (A bound of n & ~7 would push a complete group
     * of four elements into the scalar tail loop for no reason.) */
    BLASLONG n4 = n & (~3);

    while (i < n4) {
        /* Load all operands of the group before storing any result. */
        FLOAT x0 = x[i + 0];
        FLOAT y0 = y[i + 0];
        FLOAT x1 = x[i + 1];
        FLOAT y1 = y[i + 1];
        FLOAT x2 = x[i + 2];
        FLOAT y2 = y[i + 2];
        FLOAT x3 = x[i + 3];
        FLOAT y3 = y[i + 3];

        FLOAT f0 = c*x0 + s*y0;
        FLOAT g0 = c*y0 - s*x0;
        FLOAT f1 = c*x1 + s*y1;
        FLOAT g1 = c*y1 - s*x1;
        FLOAT f2 = c*x2 + s*y2;
        FLOAT g2 = c*y2 - s*x2;
        FLOAT f3 = c*x3 + s*y3;
        FLOAT g3 = c*y3 - s*x3;

        x[i + 0] = f0;
        y[i + 0] = g0;
        x[i + 1] = f1;
        y[i + 1] = g1;
        x[i + 2] = f2;
        y[i + 2] = g2;
        x[i + 3] = f3;
        y[i + 3] = g3;

        i += 4;
    }

    /* Scalar tail: at most three remaining elements. */
    while (i < n) {
        FLOAT temp = c*x[i] + s*y[i];
        y[i] = c*y[i] - s*x[i];
        x[i] = temp;

        i++;
    }
}
65+
66+
#endif
67+
/*
 * Applies the plane rotation (c, s) to n element pairs of x and y in place:
 *
 *   x <- c * x + s * y
 *   y <- c * y - s * x
 *
 * inc_x / inc_y are the element strides of x and y.  When both strides are
 * 1 the work is handed to drot_kernel(); any other stride combination is
 * processed by the scalar loop below.  Nothing is done for n <= 0.
 */
static void rot_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT c, FLOAT s)
{
    BLASLONG k, px, py;

    if (n <= 0) {
        return;
    }

    /* Contiguous data: use the (possibly vectorised) kernel. */
    if (inc_x == 1 && inc_y == 1) {
        drot_kernel(n, x, y, c, s);
        return;
    }

    /* Strided data: walk both vectors with their own offsets. */
    for (k = 0, px = 0, py = 0; k < n; k++, px += inc_x, py += inc_y) {
        /* The new x must be held back until y has consumed the old x. */
        FLOAT rotated_x = c * x[px] + s * y[py];

        y[py] = c * y[py] - s * x[px];
        x[px] = rotated_x;
    }
}
92+
93+
94+
#if defined(SMP)
95+
/*
 * Worker entry point handed to blas_level1_thread() by CNAME.
 *
 * Unpacks one blas_arg_t and forwards it to rot_compute():
 *   args->m          element count
 *   args->a / lda    x vector and its stride
 *   args->b / ldb    y vector and its stride
 *   args->alpha      two FLOATs, {c, s}, as packed by CNAME
 *
 * NOTE(review): presumably blas_level1_thread() fills these fields with the
 * slice assigned to each thread; that code is not visible here.
 *
 * Always returns 0.
 */
static int rot_thread_function(blas_arg_t *args)
{
    FLOAT *cs = (FLOAT *)args->alpha;

    rot_compute(args->m, args->a, args->lda, args->b, args->ldb, cs[0], cs[1]);

    return 0;
}
105+
106+
extern int blas_level1_thread(int mode, BLASLONG m, BLASLONG n, BLASLONG k, void *alpha, void *a, BLASLONG lda, void *b, BLASLONG ldb, void *c, BLASLONG ldc, int (*function)(), int nthreads);
107+
#endif
108+
/*
 * Public rot entry point: rotates the n element pairs of x and y (strides
 * inc_x / inc_y) in place by the coefficients c and s.
 *
 * In SMP builds the work is dispatched through blas_level1_thread() when
 * both strides are non-zero and n exceeds 100000; in every other case the
 * rotation is computed directly on the calling thread.
 *
 * Always returns 0.
 */
int CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y, FLOAT c, FLOAT s)
{
#if defined(SMP)
    /* c and s travel to the workers packed into the alpha argument. */
    FLOAT alpha[2] = {c, s};
    /* Placeholder for the unused "c" matrix slot of blas_level1_thread(). */
    FLOAT dummy_c;
    int nthreads = 1;

    /* Only large problems with non-zero strides are worth threading. */
    if (inc_x != 0 && inc_y != 0 && n > 100000) {
        nthreads = num_cpu_avail(1);
    }

    if (nthreads != 1) {
#if defined(DOUBLE)
        int mode = BLAS_DOUBLE | BLAS_REAL | BLAS_PTHREAD;
#else
        int mode = BLAS_SINGLE | BLAS_REAL | BLAS_PTHREAD;
#endif
        blas_level1_thread(mode, n, 0, 0, alpha, x, inc_x, y, inc_y, &dummy_c, 0, (void *)rot_thread_function, nthreads);
        return 0;
    }
#endif

    /* Single-threaded path (also the only path in non-SMP builds). */
    rot_compute(n, x, inc_x, y, inc_y, c, s);
    return 0;
}

kernel/x86_64/drot_microk_haswell-2.c

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
 * AVX2/FMA micro-kernel for drot.
 *
 * This kernel only uses 256-bit AVX and FMA intrinsics, so on GCC it is
 * gated on __AVX2__ and __FMA__.  (Gating it on an AVX-512 feature macro
 * would keep it from ever being built for an actual Haswell target and
 * silently select the scalar fallback instead.)  The clang clause is
 * version-based only.
 */
#if (( defined(__GNUC__) && __GNUC__ > 6 && defined(__AVX2__) && defined(__FMA__)) || (defined(__clang__) && __clang_major__ >= 9))

#define HAVE_DROT_KERNEL 1

#include <immintrin.h>
#include <stdint.h>

/*
 * Applies a plane rotation to two contiguous vectors in place:
 *
 *   x[i] <- c * x[i] + s * y[i]
 *   y[i] <- c * y[i] - s * x[i]      for i in [0, n)
 *
 * n    number of elements
 * x, y unit-stride vectors, read and overwritten (unaligned access is used)
 * c, s rotation coefficients
 *
 * Elements are processed in groups of 16, then groups of 4, then one by one.
 */
static void drot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
{
    BLASLONG i = 0;

    /* End indices of the 4-wide and 16-wide vector loops. */
    BLASLONG tail_index_4 = n&(~3);
    BLASLONG tail_index_16 = n&(~15);

    /* Broadcast unconditionally: cheap, and the vectors are never left
     * uninitialised. */
    __m256d c_256 = _mm256_set1_pd(c);
    __m256d s_256 = _mm256_set1_pd(s);

    __m256d x0, x1, x2, x3;
    __m256d y0, y1, y2, y3;
    __m256d t0, t1, t2, t3;

    /* Main loop: four 256-bit vectors (16 doubles) per iteration. */
    for (i = 0; i < tail_index_16; i += 16) {
        x0 = _mm256_loadu_pd(&x[i + 0]);
        x1 = _mm256_loadu_pd(&x[i + 4]);
        x2 = _mm256_loadu_pd(&x[i + 8]);
        x3 = _mm256_loadu_pd(&x[i +12]);
        y0 = _mm256_loadu_pd(&y[i + 0]);
        y1 = _mm256_loadu_pd(&y[i + 4]);
        y2 = _mm256_loadu_pd(&y[i + 8]);
        y3 = _mm256_loadu_pd(&y[i +12]);

        /* new x = c*x + s*y */
        t0 = _mm256_mul_pd(s_256, y0);
        t1 = _mm256_mul_pd(s_256, y1);
        t2 = _mm256_mul_pd(s_256, y2);
        t3 = _mm256_mul_pd(s_256, y3);

        t0 = _mm256_fmadd_pd(c_256, x0, t0);
        t1 = _mm256_fmadd_pd(c_256, x1, t1);
        t2 = _mm256_fmadd_pd(c_256, x2, t2);
        t3 = _mm256_fmadd_pd(c_256, x3, t3);

        _mm256_storeu_pd(&x[i + 0], t0);
        _mm256_storeu_pd(&x[i + 4], t1);
        _mm256_storeu_pd(&x[i + 8], t2);
        _mm256_storeu_pd(&x[i +12], t3);

        /* new y = c*y - s*x, using the x values loaded before the store */
        t0 = _mm256_mul_pd(s_256, x0);
        t1 = _mm256_mul_pd(s_256, x1);
        t2 = _mm256_mul_pd(s_256, x2);
        t3 = _mm256_mul_pd(s_256, x3);

        t0 = _mm256_fmsub_pd(c_256, y0, t0);
        t1 = _mm256_fmsub_pd(c_256, y1, t1);
        t2 = _mm256_fmsub_pd(c_256, y2, t2);
        t3 = _mm256_fmsub_pd(c_256, y3, t3);

        _mm256_storeu_pd(&y[i + 0], t0);
        _mm256_storeu_pd(&y[i + 4], t1);
        _mm256_storeu_pd(&y[i + 8], t2);
        _mm256_storeu_pd(&y[i +12], t3);
    }

    /* Remaining full vectors: one 256-bit vector (4 doubles) at a time. */
    for (i = tail_index_16; i < tail_index_4; i += 4) {
        x0 = _mm256_loadu_pd(&x[i]);
        y0 = _mm256_loadu_pd(&y[i]);

        t0 = _mm256_mul_pd(s_256, y0);
        t0 = _mm256_fmadd_pd(c_256, x0, t0);
        _mm256_storeu_pd(&x[i], t0);

        t0 = _mm256_mul_pd(s_256, x0);
        t0 = _mm256_fmsub_pd(c_256, y0, t0);
        _mm256_storeu_pd(&y[i], t0);
    }

    /* Scalar tail: at most three elements. */
    for (i = tail_index_4; i < n; ++i) {
        FLOAT temp = c * x[i] + s * y[i];
        y[i] = c * y[i] - s * x[i];
        x[i] = temp;
    }
}
#endif
kernel/x86_64/drot_microk_skylakex-2.c

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/* need a new enough GCC for avx512 support */
2+
#if (( defined(__GNUC__) && __GNUC__ > 6 && defined(__AVX512CD__)) || (defined(__clang__) && __clang_major__ >= 9))
3+
4+
#define HAVE_DROT_KERNEL 1
5+
6+
#include <immintrin.h>
7+
#include <stdint.h>
8+
9+
/*
 * AVX-512 kernel: applies a plane rotation to two contiguous vectors in
 * place.
 *
 *   x[i] <- c * x[i] + s * y[i]
 *   y[i] <- c * y[i] - s * x[i]      for i in [0, n)
 *
 * n    number of elements
 * x, y unit-stride vectors, read and overwritten (unaligned access is used)
 * c, s rotation coefficients
 *
 * Elements are processed in groups of 32, then groups of 8; the final
 * n % 8 elements are handled with one masked load/store pair.
 */
static void drot_kernel(BLASLONG n, FLOAT *x, FLOAT *y, FLOAT c, FLOAT s)
{
    BLASLONG pos;

    /* End indices of the 8-wide and 32-wide vector loops. */
    const BLASLONG end8 = n & (~7);
    const BLASLONG end32 = n & (~31);
    /* Number of elements left over after the last full 8-wide vector. */
    const BLASLONG rem = n & 7;

    const __m512d vc = _mm512_set1_pd(c);
    const __m512d vs = _mm512_set1_pd(s);

    __m512d xa, xb, xc, xd;
    __m512d ya, yb, yc, yd;
    __m512d ra, rb, rc, rd;

    /* Main loop: four 512-bit vectors (32 doubles) per iteration. */
    for (pos = 0; pos < end32; pos += 32) {
        xa = _mm512_loadu_pd(&x[pos + 0]);
        xb = _mm512_loadu_pd(&x[pos + 8]);
        xc = _mm512_loadu_pd(&x[pos + 16]);
        xd = _mm512_loadu_pd(&x[pos + 24]);
        ya = _mm512_loadu_pd(&y[pos + 0]);
        yb = _mm512_loadu_pd(&y[pos + 8]);
        yc = _mm512_loadu_pd(&y[pos + 16]);
        yd = _mm512_loadu_pd(&y[pos + 24]);

        /* new x = c*x + s*y */
        ra = _mm512_mul_pd(vs, ya);
        rb = _mm512_mul_pd(vs, yb);
        rc = _mm512_mul_pd(vs, yc);
        rd = _mm512_mul_pd(vs, yd);

        ra = _mm512_fmadd_pd(vc, xa, ra);
        rb = _mm512_fmadd_pd(vc, xb, rb);
        rc = _mm512_fmadd_pd(vc, xc, rc);
        rd = _mm512_fmadd_pd(vc, xd, rd);

        _mm512_storeu_pd(&x[pos + 0], ra);
        _mm512_storeu_pd(&x[pos + 8], rb);
        _mm512_storeu_pd(&x[pos + 16], rc);
        _mm512_storeu_pd(&x[pos + 24], rd);

        /* new y = c*y - s*x, using the x values loaded before the store */
        ra = _mm512_mul_pd(vs, xa);
        rb = _mm512_mul_pd(vs, xb);
        rc = _mm512_mul_pd(vs, xc);
        rd = _mm512_mul_pd(vs, xd);

        ra = _mm512_fmsub_pd(vc, ya, ra);
        rb = _mm512_fmsub_pd(vc, yb, rb);
        rc = _mm512_fmsub_pd(vc, yc, rc);
        rd = _mm512_fmsub_pd(vc, yd, rd);

        _mm512_storeu_pd(&y[pos + 0], ra);
        _mm512_storeu_pd(&y[pos + 8], rb);
        _mm512_storeu_pd(&y[pos + 16], rc);
        _mm512_storeu_pd(&y[pos + 24], rd);
    }

    /* Remaining full vectors: one 512-bit vector (8 doubles) at a time. */
    for (pos = end32; pos < end8; pos += 8) {
        xa = _mm512_loadu_pd(&x[pos]);
        ya = _mm512_loadu_pd(&y[pos]);

        ra = _mm512_mul_pd(vs, ya);
        ra = _mm512_fmadd_pd(vc, xa, ra);
        _mm512_storeu_pd(&x[pos], ra);

        ra = _mm512_mul_pd(vs, xa);
        ra = _mm512_fmsub_pd(vc, ya, ra);
        _mm512_storeu_pd(&y[pos], ra);
    }

    /* Masked tail: the low `rem` mask bits select the leftover elements. */
    if (rem > 0) {
        const __mmask8 tail_mask = (__mmask8)(0xff >> (8 - rem));
        const __m512d tx = _mm512_maskz_loadu_pd(tail_mask, &x[end8]);
        const __m512d ty = _mm512_maskz_loadu_pd(tail_mask, &y[end8]);
        __m512d acc;

        acc = _mm512_mul_pd(vs, ty);
        acc = _mm512_fmadd_pd(vc, tx, acc);
        _mm512_mask_storeu_pd(&x[end8], tail_mask, acc);

        acc = _mm512_mul_pd(vs, tx);
        acc = _mm512_fmsub_pd(vc, ty, acc);
        _mm512_mask_storeu_pd(&y[end8], tail_mask, acc);
    }
}
94+
#endif

0 commit comments

Comments
 (0)