Commit 04259eb

Add test cases for q @ k attention variant
Differential Revision: D71936846
Pull Request resolved: #2051
Parent: 54d5a68
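For context on what the new tests cover: the q @ k attention variant computes attention scores as a channelwise (per-token) 8-bit quantized matmul, one s_q x s_k score matrix per batch and head. Below is a minimal reference sketch of that computation, assuming row-major int8 operands with one zero point and scale per row; the names and exact layout are illustrative assumptions, not the kernel's API.

```cpp
#include <cstdint>

// Hedged reference sketch (not the torchao kernel API): out = q @ k^T with
// per-token (per-row) 8-bit quantization on both operands. q is m x k and
// k_mat is n x k, both row-major; each row carries its own zero point/scale.
void q_at_k_reference(
    int m, int n, int k,
    const int8_t* q, const int8_t* q_zeros, const float* q_scales,
    const int8_t* k_mat, const int8_t* k_zeros, const float* k_scales,
    float* out) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      // Accumulate in int32 on the zero-point-corrected int8 values,
      // then apply the two per-row scales once per output element.
      int32_t acc = 0;
      for (int l = 0; l < k; ++l) {
        acc += (int32_t{q[i * k + l]} - q_zeros[i]) *
               (int32_t{k_mat[j * k + l]} - k_zeros[j]);
      }
      out[i * n + j] = float(acc) * q_scales[i] * k_scales[j];
    }
  }
}
```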

2 files changed: +38 −6 lines

torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h

Lines changed: 1 addition & 1 deletion
@@ -288,7 +288,7 @@ struct KernelImpl<true, true, false, true> {
   constexpr int nr = 8;
   constexpr int kr = 8;
   assert(m % mr == 0);
-  assert(k % kr == 0);
+  assert(k % 16 == 0);
   assert(n >= nr);
   std::vector<int8_t> rhs_packed(n * k);
   // Since we are casting int8_t to float32_t in order to transpose matrix in a
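Note on the kernel change above: the 4x8x8 GEMM kernel's alignment requirement on the reduction dimension tightens from kr (8) to 16, and the shape checks in the test file below are updated to match. Presumably the packing path, which per the surviving comment casts int8_t to float32_t while transposing, consumes k in chunks of 16. A caller that cannot guarantee the new alignment could zero-pad the reduction dimension; a minimal sketch with a hypothetical helper (not part of torchao):

```cpp
// Hypothetical helper: round the reduction dimension up to the multiple of
// 16 that the 4x8x8 GEMM kernel now asserts, so a caller can zero-pad its
// int8 buffers before packing. round_up_to_16(33) == 48, for example.
inline int round_up_to_16(int k) {
  return (k + 15) & ~15;
}
```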

torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp

Lines changed: 37 additions & 5 deletions
@@ -53,7 +53,7 @@ struct test_channelwise_8bit_channelwise_8bit_b<
       const int,
       const int);
   kernel_fn_type kernel_fn = nullptr;
-  if (use_gemm && (m % 4 == 0) && (n % 8 == 0) && (k % 8 == 0)) {
+  if (use_gemm && (m % 4 == 0) && (n % 8 == 0) && (k % 16 == 0)) {
     using namespace torchao::kernels::cpu::aarch64::quantized_matmul::
         channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot;
     kernel_fn = kernel<a_has_zeros, b_has_zeros, false, true>;
@@ -531,9 +531,6 @@ static void test_8bit_per_token_q_at_k_matmul_attention(
   channelwise_8bit_a_channelwise_8bit_b_q_at_k_attention_test_case::
       generate(b, s_q, s_k, h, d, transpose);
 
-  using namespace torchao::kernels::cpu::aarch64::quantized_matmul::
-      channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot;
-
   size_t q_b_stride = test_case.b_q_stride;
   size_t q_h_stride = test_case.h_q_stride;
   size_t q_s_q_stride = test_case.s_q_stride;
@@ -553,9 +550,36 @@ static void test_8bit_per_token_q_at_k_matmul_attention(
   size_t output_h_stride = s_q * s_k;
   size_t output_s_q_stride = s_k;
 
+  using kernel_fn_type = void (*)(
+      int,
+      int,
+      int,
+      const void*,
+      int,
+      const void*,
+      int,
+      float*,
+      int,
+      const int8_t*,
+      const int8_t*,
+      const float*,
+      const float*,
+      const int,
+      const int);
+  kernel_fn_type kernel_fn = nullptr;
+  if ((s_q % 4 == 0) && (s_k % 8 == 0) && (d % 16 == 0)) {
+    using namespace torchao::kernels::cpu::aarch64::quantized_matmul::
+        channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot;
+    kernel_fn = kernel<true, true, false, true>;
+  } else {
+    using namespace torchao::kernels::cpu::aarch64::quantized_matmul::
+        channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot;
+    kernel_fn = kernel<true, true, false, true>;
+  }
+
   for (int b_idx = 0; b_idx < b; b_idx++) {
     for (int h_idx = 0; h_idx < h; h_idx++) {
-      kernel<true, true, false, true>(
+      kernel_fn(
           s_q,
           s_k,
           d,
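The dispatch added above selects the 4x8x8 GEMM kernel when the query length, key length, and head dimension are suitably aligned, and otherwise falls back to the existing 1x8x16 kernel. A small standalone sketch (mirroring that condition, not torchao code) showing which path each test shape in this file takes:

```cpp
#include <cstdio>

// Mirrors the dispatch condition in the hunk above; s_q, s_k, and d are the
// query length, key length, and head dimension of the attention test case.
static bool takes_gemm_path(int s_q, int s_k, int d) {
  return (s_q % 4 == 0) && (s_k % 8 == 0) && (d % 16 == 0);
}

int main() {
  std::printf("Basic (16, 16, 16): %d\n", takes_gemm_path(16, 16, 16));            // 1 -> 4x8x8 GEMM
  std::printf("BasicGemmKernel (4, 16, 16): %d\n", takes_gemm_path(4, 16, 16));    // 1 -> 4x8x8 GEMM
  std::printf("PrimeHeadsAndHeadDim (8, 8, 33): %d\n", takes_gemm_path(8, 8, 33)); // 0 -> 1x8x16 fallback
  return 0;
}
```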
@@ -587,6 +611,14 @@ TEST(test_8bit_per_token_q_at_k_matmul_attention, Basic) {
   test_8bit_per_token_q_at_k_matmul_attention(1, 16, 16, 8, 16);
 }
 
+TEST(test_8bit_per_token_q_at_k_matmul_attention, BasicGemmKernel) {
+  test_8bit_per_token_q_at_k_matmul_attention(1, 4, 16, 4, 16);
+}
+
+TEST(test_8bit_per_token_q_at_k_matmul_attention, BasicGemmKernelNoTranspose) {
+  test_8bit_per_token_q_at_k_matmul_attention(1, 4, 16, 4, 16, false);
+}
+
 TEST(test_8bit_per_token_q_at_k_matmul_attention, PrimeHeadsAndHeadDim) {
   test_8bit_per_token_q_at_k_matmul_attention(1, 8, 8, 7, 33);
 }
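Both new BasicGemmKernel cases use shapes (b = 1, s_q = 4, s_k = 16, h = 4, d = 16) that satisfy the GEMM dispatch condition, and the NoTranspose variant additionally passes false as the final argument, which (reading it as the transpose flag forwarded to generate()) exercises the 4x8x8 path with a non-transposed K layout as well. The pre-existing PrimeHeadsAndHeadDim case (d = 33) keeps the 1x8x16 fallback covered.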
