Commit f788897

Add tests for attention matmul for gemm kernels
Differential Revision: D71833062
Pull Request resolved: #2041
1 parent: c99e37c

Showing 1 changed file with 11 additions and 1 deletion.

torchao/experimental/kernels/cpu/aarch64/tests/test_qmatmul.cpp

Lines changed: 11 additions & 1 deletion
@@ -569,7 +569,7 @@ static void test_fp32_attn_scores_at_v_matmul_attention(
       b, s_attn, s_v, h, d, transpose_v);
 
   using namespace torchao::kernels::cpu::aarch64::quantized_matmul::
-      fp32_a_input_channelwise_8bit_b_1x16x4_f32;
+      fp32_a_input_channelwise_8bit_b_f32;
 
   size_t attn_b_stride = test_case.b_attn_stride;
   size_t attn_h_stride = test_case.h_attn_stride;
@@ -644,4 +644,14 @@ TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeSequenceDimNoTranspose) {
   test_fp32_attn_scores_at_v_matmul_attention(1, 7, 9, 7, 33, false);
 }
 
+TEST(test_fp32_attn_scores_at_v_matmul_attention, BasicNoTranspose2) {
+  test_fp32_attn_scores_at_v_matmul_attention(1, 13, 20, 8, 16, false);
+}
+
+TEST(
+    test_fp32_attn_scores_at_v_matmul_attention,
+    PrimeSequenceDimNoTranspose2) {
+  test_fp32_attn_scores_at_v_matmul_attention(1, 7, 17, 7, 33, false);
+}
+
 #endif // defined(__aarch64__) || defined(__ARM_NEON)
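
The first hunk points the test at the generalized fp32_a_input_channelwise_8bit_b_f32 namespace instead of the tile-specific fp32_a_input_channelwise_8bit_b_1x16x4_f32 one, and the two new cases appear to cover sequence lengths (s_v = 20 and s_v = 17) that are not multiples of the old 16-wide tile. As the namespace name suggests, the gemm under test takes an fp32 A operand (the attention scores) and an 8-bit B operand (V) quantized per input channel, producing fp32 output. Below is a minimal scalar reference sketch of that contraction for one (batch, head) slice; the function name, the per-channel (scale, zero point) scheme, and the row-major layouts are illustrative assumptions, not the actual torchao test harness.

// Hedged reference sketch, not the torchao harness: a plausible scalar
// baseline for a gemm whose A operand is fp32 and whose B operand is 8-bit
// quantized per input channel (the reduction dimension k), with fp32 output.
#include <cstddef>
#include <cstdint>

// a:        [m, k] fp32 attention scores (m = s_attn, k = s_v)
// b_qvals:  [k, n] int8 quantized V values (n = head dim d)
// b_scales: [k] per-input-channel dequant scales   (assumed scheme)
// b_zeros:  [k] per-input-channel zero points      (assumed scheme)
// out:      [m, n] fp32 result
void reference_qmatmul_slice(
    const float* a,
    const int8_t* b_qvals,
    const float* b_scales,
    const int8_t* b_zeros,
    float* out,
    size_t m,
    size_t k,
    size_t n) {
  for (size_t i = 0; i < m; ++i) {
    for (size_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (size_t kk = 0; kk < k; ++kk) {
        // Dequantize B on the fly: one (scale, zero point) per row kk of B.
        float b_val = b_scales[kk] *
            (static_cast<int>(b_qvals[kk * n + j]) - b_zeros[kk]);
        acc += a[i * k + kk] * b_val;
      }
      out[i * n + j] = acc;
    }
  }
}

Mapped onto the attention test, m = s_attn, k = s_v, and n = d for each of the b × h slices, so the new cases exercise this contraction at (m, k, n) = (13, 20, 16) and (7, 17, 33); a kernel test would run the optimized NEON implementation on the same inputs and assert elementwise closeness against such a reference.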
