@@ -5170,9 +5170,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
5170
5170
5171
5171
test_cases.emplace_back (new test_l2_norm (GGML_TYPE_F32, {64 , 5 , 4 , 3 }, 1e-12f ));
5172
5172
5173
- test_cases.emplace_back (new test_ssm_conv (GGML_TYPE_F32, {4 , 1536 , 1 , 1 }, {4 , 1536 , 1 , 1 }));
5174
- test_cases.emplace_back (new test_ssm_conv (GGML_TYPE_F32, {8 , 1536 , 1 , 1 }, {4 , 1536 , 1 , 1 }));
5175
- test_cases.emplace_back (new test_ssm_conv (GGML_TYPE_F32, {4 , 1536 , 4 , 1 }, {4 , 1536 , 1 , 1 }));
5173
+ for (int64_t d_conv : {3 , 4 }) {
5174
+ for (int64_t d_inner: {1024 , 1536 , 2048 }) {
5175
+ test_cases.emplace_back (new test_ssm_conv (GGML_TYPE_F32, {4 , d_inner, 1 , 1 }, {d_conv, d_inner, 1 , 1 }));
5176
+ test_cases.emplace_back (new test_ssm_conv (GGML_TYPE_F32, {8 , d_inner, 1 , 1 }, {d_conv, d_inner, 1 , 1 }));
5177
+ test_cases.emplace_back (new test_ssm_conv (GGML_TYPE_F32, {4 , d_inner, 4 , 1 }, {d_conv, d_inner, 1 , 1 }));
5178
+ }
5179
+ }
5176
5180
5177
5181
test_cases.emplace_back (new test_ssm_scan (GGML_TYPE_F32, 16 , 1 , 1024 , 1 , 32 , 4 )); // Mamba-1
5178
5182
test_cases.emplace_back (new test_ssm_scan (GGML_TYPE_F32, 128 , 64 , 16 , 2 , 32 , 4 )); // Mamba-2
0 commit comments