@@ -33,6 +33,8 @@ namespace vec_op {
33
33
#endif
34
34
35
35
#define FORCE_INLINE __attribute__ ((always_inline)) inline
36
+ // Number of elements in single ASIMD vector of given Datatype
37
+ #define NUM_ELEMENTS_REG (vec ) (sizeof (vec) / sizeof (vec[0 ]))
36
38
37
39
namespace {
38
40
template <typename T, T... indexes, typename F>
@@ -86,8 +88,8 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
86
88
}
87
89
88
90
void save (void * ptr, const int elem_num) const {
89
- int full_blocks = elem_num / 8 ;
90
- int remainder = elem_num % 8 ;
91
+ int full_blocks = elem_num / NUM_ELEMENTS_REG (reg. val [ 0 ]) ;
92
+ int remainder = elem_num % NUM_ELEMENTS_REG (reg. val [ 0 ]) ;
91
93
92
94
if (full_blocks > 0 ) {
93
95
vst1q_f16 (reinterpret_cast <__fp16*>(ptr), reg.val [0 ]);
@@ -197,6 +199,25 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
197
199
vcvtq_high_bf16_f32 (vcvtq_low_bf16_f32 (v.val [2 ]), v.val [3 ])}) {};
198
200
199
201
void save (void * ptr) const { *reinterpret_cast <bfloat16x8x2_t *>(ptr) = reg; };
202
+ void save (void * ptr, const int elem_num) const {
203
+ int full_blocks = elem_num / NUM_ELEMENTS_REG (reg.val [0 ]);
204
+ int remainder = elem_num % NUM_ELEMENTS_REG (reg.val [0 ]);
205
+ for (int i = 0 ; i < full_blocks; i++)
206
+ vst1q_bf16 (
207
+ reinterpret_cast <__bf16*>(ptr) + NUM_ELEMENTS_REG (reg.val [0 ]) * i,
208
+ reg.val [i]);
209
+ if (remainder > 0 ) {
210
+ bfloat16x8_t temp = reg.val [full_blocks];
211
+ bfloat16_t * base = reinterpret_cast <bfloat16_t *>(ptr) + full_blocks * 8 ;
212
+ if (remainder > 0 ) base[0 ] = vgetq_lane_bf16 (temp, 0 );
213
+ if (remainder > 1 ) base[1 ] = vgetq_lane_bf16 (temp, 1 );
214
+ if (remainder > 2 ) base[2 ] = vgetq_lane_bf16 (temp, 2 );
215
+ if (remainder > 3 ) base[3 ] = vgetq_lane_bf16 (temp, 3 );
216
+ if (remainder > 4 ) base[4 ] = vgetq_lane_bf16 (temp, 4 );
217
+ if (remainder > 5 ) base[5 ] = vgetq_lane_bf16 (temp, 5 );
218
+ if (remainder > 6 ) base[6 ] = vgetq_lane_bf16 (temp, 6 );
219
+ }
220
+ };
200
221
};
201
222
202
223
struct BF16Vec32 : public Vec <BF16Vec32> {
@@ -213,6 +234,25 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
213
234
: reg({vec8_data.reg , vec8_data.reg , vec8_data.reg , vec8_data.reg }) {};
214
235
215
236
void save (void * ptr) const { *reinterpret_cast <bfloat16x8x4_t *>(ptr) = reg; };
237
+ void save (void * ptr, const int elem_num) const {
238
+ int full_blocks = elem_num / NUM_ELEMENTS_REG (reg.val [0 ]);
239
+ int remainder = elem_num % NUM_ELEMENTS_REG (reg.val [0 ]);
240
+ for (int i = 0 ; i < full_blocks; i++)
241
+ vst1q_bf16 (
242
+ reinterpret_cast <__bf16*>(ptr) + NUM_ELEMENTS_REG (reg.val [0 ]) * i,
243
+ reg.val [i]);
244
+ if (remainder > 0 ) {
245
+ bfloat16x8_t temp = reg.val [full_blocks];
246
+ bfloat16_t * base = reinterpret_cast <bfloat16_t *>(ptr) + full_blocks * 8 ;
247
+ base[0 ] = vgetq_lane_bf16 (temp, 0 );
248
+ if (remainder > 1 ) base[1 ] = vgetq_lane_bf16 (temp, 1 );
249
+ if (remainder > 2 ) base[2 ] = vgetq_lane_bf16 (temp, 2 );
250
+ if (remainder > 3 ) base[3 ] = vgetq_lane_bf16 (temp, 3 );
251
+ if (remainder > 4 ) base[4 ] = vgetq_lane_bf16 (temp, 4 );
252
+ if (remainder > 5 ) base[5 ] = vgetq_lane_bf16 (temp, 5 );
253
+ if (remainder > 6 ) base[6 ] = vgetq_lane_bf16 (temp, 6 );
254
+ }
255
+ };
216
256
};
217
257
#endif
218
258
@@ -372,6 +412,48 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
372
412
}
373
413
};
374
414
415
+ struct INT32Vec16 : public Vec <INT32Vec16> {
416
+ constexpr static int VEC_ELEM_NUM = 16 ;
417
+ union AliasReg {
418
+ int32x4x4_t reg;
419
+ int32_t values[VEC_ELEM_NUM];
420
+ };
421
+ int32x4x4_t reg;
422
+
423
+ explicit INT32Vec16 (const void * ptr) {
424
+ reg.val [0 ] = vld1q_s32 (reinterpret_cast <const int32_t *>(ptr));
425
+ reg.val [1 ] = vld1q_s32 (reinterpret_cast <const int32_t *>(ptr) + 4 );
426
+ reg.val [2 ] = vld1q_s32 (reinterpret_cast <const int32_t *>(ptr) + 8 );
427
+ reg.val [3 ] = vld1q_s32 (reinterpret_cast <const int32_t *>(ptr) + 12 );
428
+ }
429
+
430
+ void save (int32_t * ptr) const {
431
+ vst1q_s32 (ptr, reg.val [0 ]);
432
+ vst1q_s32 (ptr + 4 , reg.val [1 ]);
433
+ vst1q_s32 (ptr + 8 , reg.val [2 ]);
434
+ vst1q_s32 (ptr + 12 , reg.val [3 ]);
435
+ };
436
+
437
+ void save (int32_t * ptr, const int elem_num) const {
438
+ int full_blocks = elem_num / NUM_ELEMENTS_REG (reg.val [0 ]);
439
+ int remainder = elem_num % NUM_ELEMENTS_REG (reg.val [0 ]);
440
+
441
+ for (int i = 0 ; i < full_blocks; i++)
442
+ vst1q_s32 (
443
+ reinterpret_cast <__int32_t *>(ptr) + NUM_ELEMENTS_REG (reg.val [0 ]) * i,
444
+ reg.val [i]);
445
+
446
+ if (remainder > 0 ) {
447
+ int32x4_t temp = reg.val [full_blocks];
448
+ int32_t * base = reinterpret_cast <int32_t *>(ptr) + full_blocks * 4 ;
449
+ if (remainder > 0 ) base[0 ] = vgetq_lane_s32 (temp, 0 );
450
+ if (remainder > 1 ) base[1 ] = vgetq_lane_s32 (temp, 1 );
451
+ if (remainder > 2 ) base[2 ] = vgetq_lane_s32 (temp, 2 );
452
+ if (remainder > 3 ) base[3 ] = vgetq_lane_s32 (temp, 3 );
453
+ }
454
+ }
455
+ };
456
+
375
457
struct FP32Vec16 : public Vec <FP32Vec16> {
376
458
constexpr static int VEC_ELEM_NUM = 16 ;
377
459
union AliasReg {
@@ -434,7 +516,12 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
434
516
reg.val [2 ] = vcvt_f32_f16 (vget_low_f16 (v.reg .val [1 ]));
435
517
reg.val [3 ] = vcvt_f32_f16 (vget_high_f16 (v.reg .val [1 ]));
436
518
};
437
-
519
+ explicit FP32Vec16 (const INT32Vec16& v) {
520
+ reg.val [0 ] = vcvtq_f32_s32 (v.reg .val [0 ]);
521
+ reg.val [1 ] = vcvtq_f32_s32 (v.reg .val [1 ]);
522
+ reg.val [2 ] = vcvtq_f32_s32 (v.reg .val [2 ]);
523
+ reg.val [3 ] = vcvtq_f32_s32 (v.reg .val [3 ]);
524
+ };
438
525
FP32Vec16 operator +(const FP32Vec16& b) const {
439
526
return FP32Vec16 (float32x4x4_t ({vaddq_f32 (reg.val [0 ], b.reg .val [0 ]),
440
527
vaddq_f32 (reg.val [1 ], b.reg .val [1 ]),
@@ -463,6 +550,85 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
463
550
vdivq_f32 (reg.val [3 ], b.reg .val [3 ])}));
464
551
};
465
552
553
+ FP32Vec16 clamp (const FP32Vec16& min, const FP32Vec16& max) const {
554
+ return FP32Vec16 (float32x4x4_t (
555
+ {vminq_f32 (max.reg .val [0 ], vmaxq_f32 (min.reg .val [0 ], reg.val [0 ])),
556
+ vminq_f32 (max.reg .val [1 ], vmaxq_f32 (min.reg .val [1 ], reg.val [1 ])),
557
+ vminq_f32 (max.reg .val [2 ], vmaxq_f32 (min.reg .val [2 ], reg.val [2 ])),
558
+ vminq_f32 (max.reg .val [3 ], vmaxq_f32 (min.reg .val [3 ], reg.val [3 ]))}));
559
+ };
560
+
561
+ FP32Vec16 max (const FP32Vec16& b) const {
562
+ return FP32Vec16 (float32x4x4_t ({vmaxq_f32 (b.reg .val [0 ], reg.val [0 ]),
563
+ vmaxq_f32 (b.reg .val [1 ], reg.val [1 ]),
564
+ vmaxq_f32 (b.reg .val [2 ], reg.val [2 ]),
565
+ vmaxq_f32 (b.reg .val [3 ], reg.val [3 ])}));
566
+ };
567
+
568
+ FP32Vec16 max (const FP32Vec16& b, const int elem_num) const {
569
+ int full_blocks = elem_num / NUM_ELEMENTS_REG (reg.val [0 ]);
570
+ int remainder = elem_num % NUM_ELEMENTS_REG (reg.val [0 ]);
571
+ float32x4x4_t temp;
572
+
573
+ for (int i = 0 ; i < full_blocks; i++)
574
+ temp.val [i] = vmaxq_f32 (b.reg .val [i], reg.val [i]);
575
+
576
+ if (remainder > 0 ) {
577
+ float max_v = std::max (vgetq_lane_f32 (reg.val [full_blocks], 0 ),
578
+ vgetq_lane_f32 (b.reg .val [full_blocks], 0 ));
579
+ temp.val [full_blocks] = vsetq_lane_f32 (max_v, temp.val [full_blocks], 0 );
580
+ }
581
+ if (remainder > 1 ) {
582
+ float max_v = std::max (vgetq_lane_f32 (reg.val [full_blocks], 1 ),
583
+ vgetq_lane_f32 (b.reg .val [full_blocks], 1 ));
584
+ temp.val [full_blocks] = vsetq_lane_f32 (max_v, temp.val [full_blocks], 1 );
585
+ }
586
+ if (remainder > 2 ) {
587
+ float max_v = std::max (vgetq_lane_f32 (reg.val [full_blocks], 2 ),
588
+ vgetq_lane_f32 (b.reg .val [full_blocks], 2 ));
589
+ temp.val [full_blocks] = vsetq_lane_f32 (max_v, temp.val [full_blocks], 2 );
590
+ }
591
+ return FP32Vec16 (temp);
592
+ };
593
+
594
+ FP32Vec16 min (const FP32Vec16& b) const {
595
+ return FP32Vec16 (float32x4x4_t ({
596
+ vminq_f32 (b.reg .val [0 ], reg.val [0 ]),
597
+ vminq_f32 (b.reg .val [1 ], reg.val [1 ]),
598
+ vminq_f32 (b.reg .val [2 ], reg.val [2 ]),
599
+ vminq_f32 (b.reg .val [3 ], reg.val [3 ]),
600
+ }));
601
+ };
602
+ FP32Vec16 min (const FP32Vec16& b, const int elem_num) const {
603
+ int full_blocks = elem_num / NUM_ELEMENTS_REG (reg.val [0 ]);
604
+ const int remainder = elem_num % NUM_ELEMENTS_REG (reg.val [0 ]);
605
+ float32x4x4_t temp;
606
+ for (int i = 0 ; i < full_blocks; i++)
607
+ temp.val [i] = vminq_f32 (b.reg .val [i], reg.val [i]);
608
+
609
+ if (remainder > 0 ) {
610
+ float min_v = std::min (vgetq_lane_f32 (reg.val [full_blocks], 0 ),
611
+ vgetq_lane_f32 (b.reg .val [full_blocks], 0 ));
612
+ temp.val [full_blocks] = vsetq_lane_f32 (min_v, temp.val [full_blocks], 0 );
613
+ }
614
+ if (remainder > 1 ) {
615
+ float min_v = std::min (vgetq_lane_f32 (reg.val [full_blocks], 1 ),
616
+ vgetq_lane_f32 (b.reg .val [full_blocks], 1 ));
617
+ temp.val [full_blocks] = vsetq_lane_f32 (min_v, temp.val [full_blocks], 1 );
618
+ }
619
+ if (remainder > 2 ) {
620
+ float min_v = std::min (vgetq_lane_f32 (reg.val [full_blocks], 2 ),
621
+ vgetq_lane_f32 (b.reg .val [full_blocks], 2 ));
622
+ temp.val [full_blocks] = vsetq_lane_f32 (min_v, temp.val [full_blocks], 2 );
623
+ }
624
+
625
+ return FP32Vec16 (temp);
626
+ };
627
+ FP32Vec16 abs () const {
628
+ return FP32Vec16 (
629
+ float32x4x4_t ({vabsq_f32 (reg.val [0 ]), vabsq_f32 (reg.val [1 ]),
630
+ vabsq_f32 (reg.val [2 ]), vabsq_f32 (reg.val [3 ])}));
631
+ }
466
632
float reduce_sum () const {
467
633
AliasReg ar;
468
634
ar.reg = reg;
@@ -473,6 +639,24 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
473
639
return answer;
474
640
};
475
641
642
+ float reduce_max () const {
643
+ AliasReg ar;
644
+ ar.reg = reg;
645
+ float max_v = std::numeric_limits<float >::lowest ();
646
+ unroll_loop<int , VEC_ELEM_NUM>(
647
+ [&max_v, &ar](int i) { max_v = std::max (max_v, ar.values [i]); });
648
+ return max_v;
649
+ }
650
+
651
+ float reduce_min () const {
652
+ AliasReg ar;
653
+ ar.reg = reg;
654
+ float min_v = std::numeric_limits<float >::max ();
655
+ unroll_loop<int , VEC_ELEM_NUM>(
656
+ [&min_v, &ar](int i) { min_v = std::min (min_v, ar.values [i]); });
657
+ return min_v;
658
+ }
659
+
476
660
template <int group_size>
477
661
float reduce_sub_sum (int idx) {
478
662
static_assert (VEC_ELEM_NUM % group_size == 0 );
@@ -493,6 +677,83 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
493
677
vst1q_f32 (ptr + 8 , reg.val [2 ]);
494
678
vst1q_f32 (ptr + 12 , reg.val [3 ]);
495
679
};
680
+
681
+ void save (float * ptr, const int elem_num) const {
682
+ int full_blocks = elem_num / NUM_ELEMENTS_REG (reg.val [0 ]);
683
+ int remainder = elem_num % NUM_ELEMENTS_REG (reg.val [0 ]);
684
+
685
+ for (int i = 0 ; i < full_blocks; i++)
686
+ vst1q_f32 (
687
+ reinterpret_cast <float32_t *>(ptr) + NUM_ELEMENTS_REG (reg.val [0 ]) * i,
688
+ reg.val [i]);
689
+
690
+ if (remainder > 0 ) {
691
+ float32x4_t temp = reg.val [full_blocks];
692
+ float * base = reinterpret_cast <float32_t *>(ptr) +
693
+ full_blocks * NUM_ELEMENTS_REG (reg.val [0 ]);
694
+ if (remainder > 0 ) base[0 ] = vgetq_lane_f32 (temp, 0 );
695
+ if (remainder > 1 ) base[1 ] = vgetq_lane_f32 (temp, 1 );
696
+ if (remainder > 2 ) base[2 ] = vgetq_lane_f32 (temp, 2 );
697
+ }
698
+ }
699
+ };
700
+
701
+ struct INT8Vec16 : public Vec <INT8Vec16> {
702
+ constexpr static int VEC_ELEM_NUM = 16 ;
703
+ union AliasReg {
704
+ int8x16_t reg;
705
+ int8_t values[VEC_ELEM_NUM];
706
+ };
707
+ int8x16_t reg;
708
+
709
+ explicit INT8Vec16 (const FP32Vec16& vec) {
710
+ // Convert each 128-bit float32 vector to int32
711
+ int32x4_t part0 =
712
+ vcvtq_s32_f32 (vec.reg .val [0 ]); // Convert first 128-bit block
713
+ int32x4_t part1 =
714
+ vcvtq_s32_f32 (vec.reg .val [1 ]); // Convert second 128-bit block
715
+ int32x4_t part2 =
716
+ vcvtq_s32_f32 (vec.reg .val [2 ]); // Convert third 128-bit block
717
+ int32x4_t part3 =
718
+ vcvtq_s32_f32 (vec.reg .val [3 ]); // Convert fourth 128-bit block
719
+
720
+ // Narrow each 32-bit vector to 8 bits and combine
721
+ int8x8_t lower =
722
+ vqmovn_s16 (vcombine_s16 (vqmovn_s32 (part0), vqmovn_s32 (part1)));
723
+ int8x8_t upper =
724
+ vqmovn_s16 (vcombine_s16 (vqmovn_s32 (part2), vqmovn_s32 (part3)));
725
+ reg = vcombine_s8 (lower, upper); // Combine to form a single 128-bit vector
726
+ }
727
+
728
+ void save (int8_t * ptr) const { vst1q_s8 (ptr, reg); };
729
+
730
+ void save (int8_t * ptr, const int elem_num) const {
731
+ int full_blocks = elem_num / NUM_ELEMENTS_REG (reg);
732
+ int remainder = elem_num % NUM_ELEMENTS_REG (reg);
733
+
734
+ for (int i = 0 ; i < full_blocks; i++)
735
+ vst1q_s8 (reinterpret_cast <int8_t *>(ptr) + NUM_ELEMENTS_REG (reg) * i, reg);
736
+ if (remainder > 0 ) {
737
+ int8x16_t temp = reg;
738
+ int8_t * base =
739
+ reinterpret_cast <int8_t *>(ptr) + full_blocks * NUM_ELEMENTS_REG (reg);
740
+ if (remainder > 0 ) base[0 ] = vgetq_lane_s8 (temp, 0 );
741
+ if (remainder > 1 ) base[1 ] = vgetq_lane_s8 (temp, 1 );
742
+ if (remainder > 2 ) base[2 ] = vgetq_lane_s8 (temp, 2 );
743
+ if (remainder > 3 ) base[3 ] = vgetq_lane_s8 (temp, 3 );
744
+ if (remainder > 4 ) base[4 ] = vgetq_lane_s8 (temp, 4 );
745
+ if (remainder > 5 ) base[5 ] = vgetq_lane_s8 (temp, 5 );
746
+ if (remainder > 6 ) base[6 ] = vgetq_lane_s8 (temp, 6 );
747
+ if (remainder > 7 ) base[7 ] = vgetq_lane_s8 (temp, 7 );
748
+ if (remainder > 8 ) base[8 ] = vgetq_lane_s8 (temp, 8 );
749
+ if (remainder > 9 ) base[9 ] = vgetq_lane_s8 (temp, 9 );
750
+ if (remainder > 10 ) base[10 ] = vgetq_lane_s8 (temp, 10 );
751
+ if (remainder > 11 ) base[11 ] = vgetq_lane_s8 (temp, 11 );
752
+ if (remainder > 12 ) base[12 ] = vgetq_lane_s8 (temp, 12 );
753
+ if (remainder > 13 ) base[13 ] = vgetq_lane_s8 (temp, 13 );
754
+ if (remainder > 14 ) base[14 ] = vgetq_lane_s8 (temp, 14 );
755
+ }
756
+ };
496
757
};
497
758
498
759
template <typename T>
0 commit comments