#include "common.h"
#include <stdint.h>
#include <immintrin.h>
+ //register usage: zmm3 for alpha, zmm4-zmm7 for temporary use, zmm8-zmm31 for accumulators.
/* row-major c_block */
- /* 64-bit pointer registers: a_block_pointer,b_block_pointer,c_pointer;*/
#define INNER_KERNEL_k1m1n8 \
"prefetcht0 384(%1);"\
"prefetcht0 768(%0); vmovupd (%1),%%zmm5; addq $64,%1;"\
#define INNER_STORE_m1n8(c1,disp) \
"kxnorw %%k1,%%k1,%%k1;"\
"vgatherqpd "#disp"(%3,%%zmm6,1), %%zmm7 %{%%k1%};"\
- "vaddpd %%zmm7,"#c1","#c1";"\
+ "vfmadd132pd %%zmm3,%%zmm7,"#c1";"\
"kxnorw %%k1,%%k1,%%k1;"\
"vscatterqpd "#c1", "#disp"(%3,%%zmm6,1) %{%%k1%};"
"vblendmpd "#c8","#c4",%%zmm7%{%5%};vshuff64x2 $0x4e,%%zmm7,%%zmm7,%%zmm7;"\
"vblendmpd "#c4",%%zmm7,"#c4"%{%5%};vblendmpd %%zmm7,"#c8","#c8"%{%5%};"
+ //%7 for k01(input) only when m=4
#define INNER_STORE_4x8(c1,c2,c3,c4) \
- "vmovupd (%3),%%zmm4%{%5%};vmovupd -32(%3,%4,4),%%zmm4%{%7%};vaddpd %%zmm4,"#c1","#c1";"\
+ "vmovupd (%3),%%zmm4%{%5%};vmovupd -32(%3,%4,4),%%zmm4%{%7%};vfmadd132pd %%zmm3,%%zmm4,"#c1";"\
"vmovupd "#c1",(%3)%{%5%}; vmovupd "#c1",-32(%3,%4,4)%{%7%}; leaq (%3,%4,1),%3;"\
- "vmovupd (%3),%%zmm5%{%5%};vmovupd -32(%3,%4,4),%%zmm5%{%7%};vaddpd %%zmm5,"#c2","#c2";"\
+ "vmovupd (%3),%%zmm5%{%5%};vmovupd -32(%3,%4,4),%%zmm5%{%7%};vfmadd132pd %%zmm3,%%zmm5,"#c2";"\
"vmovupd "#c2",(%3)%{%5%}; vmovupd "#c2",-32(%3,%4,4)%{%7%}; leaq (%3,%4,1),%3;"\
- "vmovupd (%3),%%zmm6%{%5%};vmovupd -32(%3,%4,4),%%zmm6%{%7%};vaddpd %%zmm6,"#c3","#c3";"\
+ "vmovupd (%3),%%zmm6%{%5%};vmovupd -32(%3,%4,4),%%zmm6%{%7%};vfmadd132pd %%zmm3,%%zmm6,"#c3";"\
"vmovupd "#c3",(%3)%{%5%}; vmovupd "#c3",-32(%3,%4,4)%{%7%}; leaq (%3,%4,1),%3;"\
- "vmovupd (%3),%%zmm7%{%5%};vmovupd -32(%3,%4,4),%%zmm7%{%7%};vaddpd %%zmm7,"#c4","#c4";"\
+ "vmovupd (%3),%%zmm7%{%5%};vmovupd -32(%3,%4,4),%%zmm7%{%7%};vfmadd132pd %%zmm3,%%zmm7,"#c4";"\
"vmovupd "#c4",(%3)%{%5%}; vmovupd "#c4",-32(%3,%4,4)%{%7%}; leaq (%3,%4,1),%3;"\
"leaq (%3,%4,4),%3;"
#define INNER_STORE_8x8(c1,c2,c3,c4,c5,c6,c7,c8) \
"prefetcht1 120(%3); prefetcht1 120(%3,%4,1);"\
- "vaddpd (%3),"#c1","#c1"; vmovupd "#c1",(%3); vaddpd (%3,%4,1),"#c2","#c2"; vmovupd "#c2",(%3,%4,1); leaq (%3,%4,2),%3;"\
+ "vfmadd213pd (%3),%%zmm3,"#c1"; vmovupd "#c1",(%3); vfmadd213pd (%3,%4,1),%%zmm3,"#c2"; vmovupd "#c2",(%3,%4,1); leaq (%3,%4,2),%3;"\
"prefetcht1 120(%3); prefetcht1 120(%3,%4,1);"\
- "vaddpd (%3),"#c3","#c3"; vmovupd "#c3",(%3); vaddpd (%3,%4,1),"#c4","#c4"; vmovupd "#c4",(%3,%4,1); leaq (%3,%4,2),%3;"\
+ "vfmadd213pd (%3),%%zmm3,"#c3"; vmovupd "#c3",(%3); vfmadd213pd (%3,%4,1),%%zmm3,"#c4"; vmovupd "#c4",(%3,%4,1); leaq (%3,%4,2),%3;"\
"prefetcht1 120(%3); prefetcht1 120(%3,%4,1);"\
- "vaddpd (%3),"#c5","#c5"; vmovupd "#c5",(%3); vaddpd (%3,%4,1),"#c6","#c6"; vmovupd "#c6",(%3,%4,1); leaq (%3,%4,2),%3;"\
+ "vfmadd213pd (%3),%%zmm3,"#c5"; vmovupd "#c5",(%3); vfmadd213pd (%3,%4,1),%%zmm3,"#c6"; vmovupd "#c6",(%3,%4,1); leaq (%3,%4,2),%3;"\
"prefetcht1 120(%3); prefetcht1 120(%3,%4,1);"\
- "vaddpd (%3),"#c7","#c7"; vmovupd "#c7",(%3); vaddpd (%3,%4,1),"#c8","#c8"; vmovupd "#c8",(%3,%4,1); leaq (%3,%4,2),%3;"
+ "vfmadd213pd (%3),%%zmm3,"#c7"; vmovupd "#c7",(%3); vfmadd213pd (%3,%4,1),%%zmm3,"#c8"; vmovupd "#c8",(%3,%4,1); leaq (%3,%4,2),%3;"
#define INNER_SAVE_m4n8 \
INNER_TRANS_4x8(%%zmm8,%%zmm9,%%zmm10,%%zmm11)\
#define COMPUTE_n8 {\
__asm__ __volatile__(\
+ "vbroadcastsd (%9),%%zmm3;"\
"movq %8,%%r14;movq %2,%%r13;movq %2,%%r12;shlq $6,%%r12;"\
"cmpq $8,%8; jb 42222f;"\
"42221:\n\t"\
"42225:\n\t"\
"movq %%r14,%8;shlq $3,%8;subq %8,%3;shrq $3,%8;"\
"shlq $3,%4;addq %4,%3;shrq $3,%4;"\
- :"+r"(a_block_pointer),"+r"(packed_b_pointer),"+r"(K),"+r"(c_pointer),"+r"(ldc_in_bytes),"+Yk"(k02),"+Yk"(k03),"+Yk"(k01),"+r"(M)\
- ::"zmm4","zmm5","zmm6","zmm7","zmm8","zmm9","zmm10","zmm11","zmm12","zmm13","zmm14","zmm15","cc","memory","k1","r13","r14");\
+ :"+r"(a_block_pointer),"+r"(packed_b_pointer),"+r"(K),"+r"(c_pointer),"+r"(ldc_in_bytes),"+Yk"(k02),"+Yk"(k03),"+Yk"(k01),"+r"(M),"+r"(alpha)\
+ ::"zmm3","zmm4","zmm5","zmm6","zmm7","zmm8","zmm9","zmm10","zmm11","zmm12","zmm13","zmm14","zmm15","cc","memory","k1","r13","r14");\
a_block_pointer -= M * K;\
}
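// note: operand %9 is the new alpha pointer ("+r"(alpha)); its value is broadcast once into zmm3 at the top of the
// asm block, and zmm3 joins the clobber list so the compiler knows it is overwritten. COMPUTE_n16 and COMPUTE_n24
// below follow the same pattern.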
#define COMPUTE_n16 {\
__asm__ __volatile__(\
+ "vbroadcastsd (%9),%%zmm3;"\
"movq %8,%%r14;movq %2,%%r13;movq %2,%%r12;shlq $6,%%r12;"\
"cmpq $8,%8; jb 32222f;"\
"32221:\n\t"\
"movq %%r14,%8;shlq $3,%8;subq %8,%3;shrq $3,%8;"\
"shlq $4,%4;addq %4,%3;shrq $4,%4;"\
"leaq (%1,%%r12,2),%1;"\
- :"+r"(a_block_pointer),"+r"(packed_b_pointer),"+r"(K),"+r"(c_pointer),"+r"(ldc_in_bytes),"+Yk"(k02),"+Yk"(k03),"+Yk"(k01),"+r"(M)\
- ::"zmm4","zmm5","zmm6","zmm7","zmm8","zmm9","zmm10","zmm11","zmm12","zmm13","zmm14","zmm15","zmm16","zmm17",\
+ :"+r"(a_block_pointer),"+r"(packed_b_pointer),"+r"(K),"+r"(c_pointer),"+r"(ldc_in_bytes),"+Yk"(k02),"+Yk"(k03),"+Yk"(k01),"+r"(M),"+r"(alpha)\
+ ::"zmm3","zmm4","zmm5","zmm6","zmm7","zmm8","zmm9","zmm10","zmm11","zmm12","zmm13","zmm14","zmm15","zmm16","zmm17",\
"zmm18","zmm19","zmm20","zmm21","zmm22","zmm23","cc","memory","k1","r12","r13","r14");\
a_block_pointer -= M * K;\
}
#define COMPUTE_n24 {\
__asm__ __volatile__(\
+ "vbroadcastsd (%9),%%zmm3;"\
"movq %8,%%r14;movq %2,%%r13;movq %2,%%r12;shlq $6,%%r12;"\
"cmpq $8,%8; jb 22222f;"\
"22221:\n\t"\
"movq %%r14,%8;shlq $3,%8;subq %8,%3;shrq $3,%8;"\
"shlq $3,%4;addq %4,%3;shlq $1,%4;addq %4,%3;shrq $4,%4;"\
"leaq (%1,%%r12,2),%1; addq %%r12,%1;"\
- :"+r"(a_block_pointer),"+r"(packed_b_pointer),"+r"(K),"+r"(c_pointer),"+r"(ldc_in_bytes),"+Yk"(k02),"+Yk"(k03),"+Yk"(k01),"+r"(M)\
- ::"zmm4","zmm5","zmm6","zmm7","zmm8","zmm9","zmm10","zmm11","zmm12","zmm13","zmm14","zmm15","zmm16","zmm17","zmm18","zmm19",\
+ :"+r"(a_block_pointer),"+r"(packed_b_pointer),"+r"(K),"+r"(c_pointer),"+r"(ldc_in_bytes),"+Yk"(k02),"+Yk"(k03),"+Yk"(k01),"+r"(M),"+r"(alpha)\
+ ::"zmm3","zmm4","zmm5","zmm6","zmm7","zmm8","zmm9","zmm10","zmm11","zmm12","zmm13","zmm14","zmm15","zmm16","zmm17","zmm18","zmm19",\
"zmm20","zmm21","zmm22","zmm23","zmm24","zmm25","zmm26","zmm27","zmm28","zmm29","zmm30","zmm31","cc","memory","k1","r12","r13","r14");\
a_block_pointer -= M * K;\
}
- static void __attribute__ ((noinline)) KERNEL_MAIN(double *packed_a, double *packed_b, BLASLONG m, BLASLONG ndiv8, BLASLONG k, BLASLONG LDC, double *c){//icopy=8,ocopy=8
+ static void KERNEL_MAIN(double *packed_a, double *packed_b, BLASLONG m, BLASLONG ndiv8, BLASLONG k, BLASLONG LDC, double *c, double *alpha){//icopy=8,ocopy=8
//perform C += A<pack> B<pack>
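// (with alpha now applied at the C store instead of being folded into the packed copy of A,
//  the update performed here is effectively C += alpha * A<pack> * B<pack>)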
if(k==0 || m==0 || ndiv8==0) return;
int64_t ldc_in_bytes = (int64_t)LDC * sizeof(double);
int64_t K = (int64_t)k; int64_t M = (int64_t)m;
double *a_block_pointer;
double *c_pointer = c;
__mmask16 k01 = 0x00f0, k02 = 0x000f, k03 = 0x0033;
- BLASLONG ndiv8_count;
+ BLASLONG m_count, ndiv8_count, k_count;
double *packed_b_pointer = packed_b;
a_block_pointer = packed_a;
for(ndiv8_count=ndiv8; ndiv8_count>2; ndiv8_count-=3){
@@ -474,24 +478,27 @@ static void __attribute__ ((noinline)) KERNEL_MAIN(double *packed_a, double *pac
#define INIT_m8n2 zc2=INIT_m8n1
#define INIT_m8n4 zc4=zc3=INIT_m8n2
#define SAVE_m8n1 {\
- za1 = _mm512_loadu_pd(c_pointer);\
- zc1 = _mm512_add_pd(zc1,za1);\
+ __asm__ __volatile__("vbroadcastsd (%0),%1;":"+r"(alpha),"+v"(za1)::"memory");\
+ zb1 = _mm512_loadu_pd(c_pointer);\
+ zc1 = _mm512_fmadd_pd(zc1,za1,zb1);\
_mm512_storeu_pd(c_pointer,zc1);\
c_pointer += 8;\
}
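// note: in the intrinsic tail paths, alpha is broadcast into a vector register (here a one-instruction asm
// statement that fills za1), and the C update becomes an FMA: zc1 = alpha*zc1 + C. The SAVE_m8n2/m8n4, m4*,
// and m2* macros below follow the same pattern.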
482
487
#define SAVE_m8n2 {\
488
+ __asm__ __volatile__("vbroadcastsd (%0),%1;":"+r"(alpha),"+v"(za1)::"memory");\
483
489
zb1 = _mm512_loadu_pd(c_pointer); zb2 = _mm512_loadu_pd(c_pointer+LDC);\
484
- zc1 = _mm512_add_pd (zc1,zb1); zc2 = _mm512_add_pd (zc2,zb2);\
490
+ zc1 = _mm512_fmadd_pd (zc1,za1, zb1); zc2 = _mm512_fmadd_pd (zc2,za1 ,zb2);\
485
491
_mm512_storeu_pd(c_pointer,zc1); _mm512_storeu_pd(c_pointer+LDC,zc2);\
486
492
c_pointer += 8;\
487
493
}
488
494
#define SAVE_m8n4 {\
495
+ __asm__ __volatile__("vbroadcastsd (%0),%1;":"+r"(alpha),"+v"(za1)::"memory");\
489
496
zb1 = _mm512_loadu_pd(c_pointer); zb2 = _mm512_loadu_pd(c_pointer+LDC);\
490
- zc1 = _mm512_add_pd (zc1,zb1); zc2 = _mm512_add_pd (zc2,zb2);\
497
+ zc1 = _mm512_fmadd_pd (zc1,za1, zb1); zc2 = _mm512_fmadd_pd (zc2,za1 ,zb2);\
491
498
_mm512_storeu_pd(c_pointer,zc1); _mm512_storeu_pd(c_pointer+LDC,zc2);\
492
499
c_pointer += LDC*2;\
493
500
zb1 = _mm512_loadu_pd(c_pointer); zb2 = _mm512_loadu_pd(c_pointer+LDC);\
494
- zc3 = _mm512_add_pd (zc3,zb1); zc4 = _mm512_add_pd (zc4,zb2);\
501
+ zc3 = _mm512_fmadd_pd (zc3,za1, zb1); zc4 = _mm512_fmadd_pd (zc4,za1 ,zb2);\
495
502
_mm512_storeu_pd(c_pointer,zc3); _mm512_storeu_pd(c_pointer+LDC,zc4);\
496
503
c_pointer += 8-LDC*2;\
497
504
}
@@ -518,24 +525,27 @@ static void __attribute__ ((noinline)) KERNEL_MAIN(double *packed_a, double *pac
#define INIT_m4n2 yc2=INIT_m4n1
#define INIT_m4n4 yc4=yc3=INIT_m4n2
#define SAVE_m4n1 {\
+ yb1 = _mm256_broadcast_sd(alpha);\
ya1 = _mm256_loadu_pd(c_pointer);\
- yc1 = _mm256_add_pd(yc1,ya1);\
+ yc1 = _mm256_fmadd_pd(yc1,yb1,ya1);\
_mm256_storeu_pd(c_pointer,yc1);\
c_pointer += 4;\
}
#define SAVE_m4n2 {\
+ ya1 = _mm256_broadcast_sd(alpha);\
yb1 = _mm256_loadu_pd(c_pointer); yb2 = _mm256_loadu_pd(c_pointer+LDC);\
- yc1 = _mm256_add_pd(yc1,yb1); yc2 = _mm256_add_pd(yc2,yb2);\
+ yc1 = _mm256_fmadd_pd(yc1,ya1,yb1); yc2 = _mm256_fmadd_pd(yc2,ya1,yb2);\
_mm256_storeu_pd(c_pointer,yc1); _mm256_storeu_pd(c_pointer+LDC,yc2);\
c_pointer += 4;\
}
#define SAVE_m4n4 {\
+ ya1 = _mm256_broadcast_sd(alpha);\
yb1 = _mm256_loadu_pd(c_pointer); yb2 = _mm256_loadu_pd(c_pointer+LDC);\
- yc1 = _mm256_add_pd(yc1,yb1); yc2 = _mm256_add_pd(yc2,yb2);\
+ yc1 = _mm256_fmadd_pd(yc1,ya1,yb1); yc2 = _mm256_fmadd_pd(yc2,ya1,yb2);\
_mm256_storeu_pd(c_pointer,yc1); _mm256_storeu_pd(c_pointer+LDC,yc2);\
c_pointer += LDC*2;\
yb1 = _mm256_loadu_pd(c_pointer); yb2 = _mm256_loadu_pd(c_pointer+LDC);\
- yc3 = _mm256_add_pd(yc3,yb1); yc4 = _mm256_add_pd(yc4,yb2);\
+ yc3 = _mm256_fmadd_pd(yc3,ya1,yb1); yc4 = _mm256_fmadd_pd(yc4,ya1,yb2);\
_mm256_storeu_pd(c_pointer,yc3); _mm256_storeu_pd(c_pointer+LDC,yc4);\
c_pointer += 4-LDC*2;\
}
@@ -553,14 +563,16 @@ static void __attribute__ ((noinline)) KERNEL_MAIN(double *packed_a, double *pac
#define INIT_m2n1 xc1=_mm_setzero_pd();
#define INIT_m2n2 xc2=INIT_m2n1
#define SAVE_m2n1 {\
+ xb1 = _mm_loaddup_pd(alpha);\
xa1 = _mm_loadu_pd(c_pointer);\
- xc1 = _mm_add_pd(xc1,xa1);\
+ xc1 = _mm_fmadd_pd(xc1,xb1,xa1);\
_mm_storeu_pd(c_pointer,xc1);\
c_pointer += 2;\
}
#define SAVE_m2n2 {\
+ xa1 = _mm_loaddup_pd(alpha);\
xb1 = _mm_loadu_pd(c_pointer); xb2 = _mm_loadu_pd(c_pointer+LDC);\
- xc1 = _mm_add_pd(xc1,xb1); xc2 = _mm_add_pd(xc2,xb2);\
+ xc1 = _mm_fmadd_pd(xc1,xa1,xb1); xc2 = _mm_fmadd_pd(xc2,xa1,xb2);\
_mm_storeu_pd(c_pointer,xc1); _mm_storeu_pd(c_pointer+LDC,xc2);\
c_pointer += 2;\
}
@@ -571,7 +583,7 @@ static void __attribute__ ((noinline)) KERNEL_MAIN(double *packed_a, double *pac
}
#define INIT_m1n1 sc1=0.0;
#define SAVE_m1n1 {\
- *c_pointer += sc1;\
+ *c_pointer += sc1 * (*alpha);\
c_pointer++;\
}
@@ -596,6 +608,9 @@ static void __attribute__ ((noinline)) KERNEL_MAIN(double *packed_a, double *pac
#define INIT_m1n4 INIT_m4n1
#define INIT_m2n4 INIT_m4n2
#define SAVE_m2n4 {\
+ ya1 = _mm256_broadcast_sd(alpha);\
+ yc1 = _mm256_mul_pd(yc1,ya1);\
+ yc2 = _mm256_mul_pd(yc2,ya1);\
yb1 = _mm256_unpacklo_pd(yc1,yc2);\
yb2 = _mm256_unpackhi_pd(yc1,yc2);\
xb1 = _mm_add_pd(_mm_loadu_pd(c_pointer),_mm256_extractf128_pd(yb1,0));\
@@ -609,12 +624,16 @@ static void __attribute__ ((noinline)) KERNEL_MAIN(double *packed_a, double *pac
c_pointer += 2;\
}
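// note: in these transposed tails the accumulators are pre-scaled by alpha with a plain multiply,
// because the following unpack/extract steps add the scaled values to C with ordinary adds.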
#define SAVE_m1n2 {\
+ xb1 = _mm_loaddup_pd(alpha);\
+ xc1 = _mm_mul_pd(xc1,xb1);\
*c_pointer += _mm_cvtsd_f64(xc1);\
xa1 = _mm_unpackhi_pd(xc1,xc1);\
c_pointer[LDC] += _mm_cvtsd_f64(xa1);\
c_pointer++;\
}
#define SAVE_m1n4 {\
+ ya1 = _mm256_broadcast_sd(alpha);\
+ yc1 = _mm256_mul_pd(yc1,ya1);\
xb1 = _mm256_extractf128_pd(yc1,0);\
*c_pointer += _mm_cvtsd_f64(xb1);\
xb2 = _mm_unpackhi_pd(xb1,xb1);\
@@ -626,7 +645,7 @@ static void __attribute__ ((noinline)) KERNEL_MAIN(double *packed_a, double *pac
c_pointer++;\
}
- static void KERNEL_EDGE(double *packed_a, double *packed_b, BLASLONG m, BLASLONG edge_n, BLASLONG k, BLASLONG LDC, double *c){//icopy=8,ocopy=8
+ static void __attribute__ ((noinline)) KERNEL_EDGE(double *packed_a, double *packed_b, BLASLONG m, BLASLONG edge_n, BLASLONG k, BLASLONG LDC, double *c, double *alpha){//icopy=8,ocopy=8
//perform C += A<pack> B<pack>, edge_n<8 must be satisfied!
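// (as in KERNEL_MAIN, the update here is now effectively C += alpha * A<pack> * B<pack>; KERNEL_EDGE also
//  picks up the noinline attribute that KERNEL_MAIN dropped)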
if(k==0 || m==0 || edge_n==0) return;
double *a_block_pointer, *b_block_pointer, *b_base_pointer;
@@ -724,30 +743,30 @@ static void KERNEL_EDGE(double *packed_a, double *packed_b, BLASLONG m, BLASLONG
}
}
}
- static void copy_4_to_8(double *src, double *dst, BLASLONG m, BLASLONG k, double alpha){
- BLASLONG m_count, k_count; double *src1, *dst1, *src2; __m256d tmp, alp;
- src1 = src; dst1 = dst; src2 = src1 + 4*k; alp = _mm256_set1_pd(alpha);
+ static void copy_4_to_8(double *src, double *dst, BLASLONG m, BLASLONG k){
+ BLASLONG m_count, k_count; double *src1, *dst1, *src2; __m256d tmp;
+ src1 = src; dst1 = dst; src2 = src1 + 4*k;
for(m_count=m; m_count>7; m_count-=8){
for(k_count=k; k_count>0; k_count--){
- tmp = _mm256_loadu_pd(src1); tmp = _mm256_mul_pd(tmp,alp); _mm256_storeu_pd(dst1+0,tmp); src1 += 4;
- tmp = _mm256_loadu_pd(src2); tmp = _mm256_mul_pd(tmp,alp); _mm256_storeu_pd(dst1+4,tmp); src2 += 4;
+ tmp = _mm256_loadu_pd(src1); _mm256_storeu_pd(dst1+0,tmp); src1 += 4;
+ tmp = _mm256_loadu_pd(src2); _mm256_storeu_pd(dst1+4,tmp); src2 += 4;
dst1 += 8;
}
src1 += 4*k; src2 += 4*k;
}
for(; m_count>0; m_count--){
for(k_count=k; k_count>0; k_count--){
- *dst1 = (*src1); src1++; dst1++;
+ *dst1 = (*src1); src1++; dst1++;
}
}
}
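// note: alpha is no longer folded into the packed copy of A; copy_4_to_8 now only merges pairs of 4-row panels
// of packed A into 8-row panels (a straight copy for the remainder rows), and scaling happens once at store time.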
int __attribute__ ((noinline)) CNAME(BLASLONG m, BLASLONG n, BLASLONG k, double alpha, double * __restrict__ A, double * __restrict__ B, double * __restrict__ C, BLASLONG ldc){
- if(m==0 || n==0 || k==0) return 0;
- BLASLONG ndiv8 = n/8;
+ if(m==0 || n==0 || k==0 || alpha==0.0) return 0;
+ BLASLONG ndiv8 = n/8; double ALPHA = alpha;
double *packed_a = (double *)malloc(m*k*sizeof(double));
- copy_4_to_8(A, packed_a, m, k, alpha);
- if(ndiv8>0) KERNEL_MAIN(packed_a, B, m, ndiv8, k, ldc, C);
- if(n>ndiv8*8) KERNEL_EDGE(packed_a, B+(int64_t)k*(int64_t)ndiv8*8, m, n-ndiv8*8, k, ldc, C+(int64_t)ldc*(int64_t)ndiv8*8);
+ copy_4_to_8(A, packed_a, m, k);
+ if(ndiv8>0) KERNEL_MAIN(packed_a, B, m, ndiv8, k, ldc, C, &ALPHA);
+ if(n>ndiv8*8) KERNEL_EDGE(packed_a, B+(int64_t)k*(int64_t)ndiv8*8, m, n-ndiv8*8, k, ldc, C+(int64_t)ldc*(int64_t)ndiv8*8, &ALPHA);
free(packed_a); packed_a = NULL;
return 0;
}
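// note: CNAME keeps a local copy ALPHA and passes its address, so both the inline-asm and intrinsic store paths
// can broadcast alpha from memory; when alpha == 0.0 the update alpha*A*B is identically zero, so the kernel
// returns early and leaves C untouched.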