@@ -569,7 +569,7 @@ static void test_fp32_attn_scores_at_v_matmul_attention(
569
569
b, s_attn, s_v, h, d, transpose_v);
570
570
571
571
using namespace torchao ::kernels::cpu::aarch64::quantized_matmul::
572
- fp32_a_input_channelwise_8bit_b_1x16x4_f32 ;
572
+ fp32_a_input_channelwise_8bit_b_f32 ;
573
573
574
574
size_t attn_b_stride = test_case.b_attn_stride ;
575
575
size_t attn_h_stride = test_case.h_attn_stride ;
@@ -644,4 +644,14 @@ TEST(test_fp32_attn_scores_at_v_matmul_attention, PrimeSequenceDimNoTranspose) {
644
644
test_fp32_attn_scores_at_v_matmul_attention (1 , 7 , 9 , 7 , 33 , false );
645
645
}
646
646
647
+ TEST (test_fp32_attn_scores_at_v_matmul_attention, BasicNoTranspose2) { // Added test case: runs the shared attention-scores-at-V matmul helper once with the fixed arguments below.
648
+ test_fp32_attn_scores_at_v_matmul_attention (1 , 13 , 20 , 8 , 16 , false ); // NOTE(review): arguments presumably map to (b, s_attn, s_v, h, d, transpose_v) -- confirm against the helper's signature.
649
+ }
650
+
651
+ TEST ( // Added test case: same helper call as PrimeSequenceDimNoTranspose except for the third argument (17 instead of 9).
652
+ test_fp32_attn_scores_at_v_matmul_attention,
653
+ PrimeSequenceDimNoTranspose2) {
654
+ test_fp32_attn_scores_at_v_matmul_attention (1 , 7 , 17 , 7 , 33 , false ); // NOTE(review): arguments presumably map to (b, s_attn, s_v, h, d, transpose_v) -- confirm against the helper's signature.
655
+ }
656
+
647
657
#endif // defined(__aarch64__) || defined(__ARM_NEON)
0 commit comments