@@ -53,7 +53,7 @@ struct test_channelwise_8bit_channelwise_8bit_b<
53
53
const int ,
54
54
const int );
55
55
kernel_fn_type kernel_fn = nullptr ;
56
- if (use_gemm && (m % 4 == 0 ) && (n % 8 == 0 ) && (k % 8 == 0 )) {
56
+ if (use_gemm && (m % 4 == 0 ) && (n % 8 == 0 ) && (k % 16 == 0 )) {
57
57
using namespace torchao ::kernels::cpu::aarch64::quantized_matmul::
58
58
channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot;
59
59
kernel_fn = kernel<a_has_zeros, b_has_zeros, false , true >;
@@ -531,9 +531,6 @@ static void test_8bit_per_token_q_at_k_matmul_attention(
531
531
channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case::
532
532
generate (b, s_q, s_k, h, d, transpose);
533
533
534
- using namespace torchao ::kernels::cpu::aarch64::quantized_matmul::
535
- channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot;
536
-
537
534
size_t q_b_stride = test_case.b_q_stride ;
538
535
size_t q_h_stride = test_case.h_q_stride ;
539
536
size_t q_s_q_stride = test_case.s_q_stride ;
@@ -553,9 +550,36 @@ static void test_8bit_per_token_q_at_k_matmul_attention(
553
550
size_t output_h_stride = s_q * s_k;
554
551
size_t output_s_q_stride = s_k;
555
552
553
+ using kernel_fn_type = void (*)(
554
+ int ,
555
+ int ,
556
+ int ,
557
+ const void *,
558
+ int ,
559
+ const void *,
560
+ int ,
561
+ float *,
562
+ int ,
563
+ const int8_t *,
564
+ const int8_t *,
565
+ const float *,
566
+ const float *,
567
+ const int ,
568
+ const int );
569
+ kernel_fn_type kernel_fn = nullptr ;
570
+ if ((s_q % 4 == 0 ) && (s_k % 8 == 0 ) && (d % 16 == 0 )) {
571
+ using namespace torchao ::kernels::cpu::aarch64::quantized_matmul::
572
+ channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot;
573
+ kernel_fn = kernel<true , true , false , true >;
574
+ } else {
575
+ using namespace torchao ::kernels::cpu::aarch64::quantized_matmul::
576
+ channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot;
577
+ kernel_fn = kernel<true , true , false , true >;
578
+ }
579
+
556
580
for (int b_idx = 0 ; b_idx < b; b_idx++) {
557
581
for (int h_idx = 0 ; h_idx < h; h_idx++) {
558
- kernel< true , true , false , true > (
582
+ kernel_fn (
559
583
s_q,
560
584
s_k,
561
585
d,
@@ -587,6 +611,14 @@ TEST(test_8bit_per_token_q_at_k_matmul_attention, Basic) {
587
611
test_8bit_per_token_q_at_k_matmul_attention (1 , 16 , 16 , 8 , 16 );
588
612
}
589
613
614
// Calls the Q@K attention matmul test helper with arguments (1, 4, 16, 4, 16).
// NOTE(review): presumably the arguments map to (b, s_q, s_k, h, d), matching
// the generate(b, s_q, s_k, h, d, transpose) call inside the helper. Under that
// reading s_q % 4 == 0, s_k % 8 == 0 and d % 16 == 0 all hold, so the helper
// selects the 4x8x8 kernel instead of the 1x8x16 one. Confirm against the
// helper's signature.
+ TEST (test_8bit_per_token_q_at_k_matmul_attention, BasicGemmKernel) {
615
+ test_8bit_per_token_q_at_k_matmul_attention (1 , 4 , 16 , 4 , 16 );
616
+ }
617
+
618
// Same sizes as the BasicGemmKernel case, (1, 4, 16, 4, 16), but with an
// explicit sixth argument of false.
// NOTE(review): the sixth argument is presumably the `transpose` flag that the
// helper forwards to generate(b, s_q, s_k, h, d, transpose); its default value
// is not visible here. Confirm against the helper's signature.
+ TEST (test_8bit_per_token_q_at_k_matmul_attention, BasicGemmKernelNoTranspose) {
619
+ test_8bit_per_token_q_at_k_matmul_attention (1 , 4 , 16 , 4 , 16 , false );
620
+ }
621
+
590
622
// Calls the Q@K attention matmul test helper with arguments (1, 8, 8, 7, 33).
// NOTE(review): presumably the arguments map to (b, s_q, s_k, h, d). Under that
// reading d = 33 is not a multiple of 16, so the helper's kernel selection
// falls through to the 1x8x16 kernel. Confirm against the helper's signature.
// NOTE(review): 33 = 3 * 11 is not prime; the test name presumably means sizes
// that are not multiples of the kernel tile sizes. Verify the intent.
TEST (test_8bit_per_token_q_at_k_matmul_attention, PrimeHeadsAndHeadDim) {
591
623
test_8bit_per_token_q_at_k_matmul_attention (1 , 8 , 8 , 7 , 33 );
592
624
}
0 commit comments