@@ -1412,6 +1412,43 @@ struct test_conv_transpose_1d : public test_case {
1412
1412
}
1413
1413
};
1414
1414
1415
+ struct test_conv_transpose_1d_gemm : public test_case {
1416
+ const std::array<int64_t , 4 > ne_input;
1417
+ const std::array<int64_t , 4 > ne_kernel;
1418
+
1419
+ const int s0; // stride
1420
+ const int p0; // padding
1421
+ const int d0; // dilation
1422
+
1423
+ ggml_type input_type;
1424
+ ggml_type kernel_type;
1425
+
1426
+ std::string vars () override {
1427
+ return VARS_TO_STR5 (ne_input, ne_kernel, s0, p0, d0);
1428
+ }
1429
+
1430
+ test_conv_transpose_1d_gemm (std::array<int64_t , 4 > ne_input = {197 , 32 , 1 , 1 }, // [input_width, input_height, input_channels, 1]
1431
+ std::array<int64_t , 4 > ne_kernel = {16 , 32 , 32 , 1 }, // [kernel_width, kernel_height, input_channels, 1]
1432
+ int s0 = 1 , int p0 = 0 , int d0 = 1 ,
1433
+ ggml_type input_type = GGML_TYPE_F32,
1434
+ ggml_type kernel_type = GGML_TYPE_F16)
1435
+ : ne_input(ne_input)
1436
+ , ne_kernel(ne_kernel)
1437
+ , s0(s0)
1438
+ , p0(p0)
1439
+ , d0(d0)
1440
+ , input_type(input_type)
1441
+ , kernel_type(kernel_type)
1442
+ {}
1443
+
1444
+ ggml_tensor * build_graph (ggml_context * ctx) override {
1445
+ ggml_tensor * input = ggml_new_tensor (ctx, input_type, 4 , ne_input.data ());
1446
+ ggml_tensor * kernel = ggml_new_tensor (ctx, kernel_type, 4 , ne_kernel.data ());
1447
+ ggml_tensor * out = ggml_conv_transpose_1d_gemm (ctx, kernel, input, s0, p0, d0);
1448
+ return out;
1449
+ }
1450
+ };
1451
+
1415
1452
// GGML_OP_IM2COL
1416
1453
struct test_im2col : public test_case {
1417
1454
const ggml_type type_input;
@@ -2330,6 +2367,25 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
2330
2367
test_cases.emplace_back (new test_conv_transpose_1d ({3 ,2 ,1 ,1 }, {3 ,1 ,2 ,1 }, 1 , 0 , 1 ));
2331
2368
test_cases.emplace_back (new test_conv_transpose_1d ({2 ,1 ,1 ,1 }, {3 ,1 ,1 ,1 }, 1 , 0 , 1 ));
2332
2369
2370
+ test_cases.emplace_back (new test_conv_transpose_1d_gemm ());
2371
+ for (int64_t s0 = 1 ; s0 < 4 ; ++s0) {
2372
+ for (int64_t p0 = 0 ; p0 < 2 ; ++p0) {
2373
+ for (int64_t d0 = 1 ; d0 < 4 ; ++d0) {
2374
+ test_cases.emplace_back (new test_conv_transpose_1d_gemm ({3 ,2 ,1 ,1 }, {2 ,3 ,2 ,1 }, s0, p0, d0));
2375
+ test_cases.emplace_back (new test_conv_transpose_1d_gemm ({3 ,2 ,1 ,1 }, {3 ,2 ,2 ,1 }, s0, p0, d0));
2376
+ test_cases.emplace_back (new test_conv_transpose_1d_gemm ({3 ,2 ,1 ,1 }, {3 ,1 ,2 ,1 }, s0, p0, d0));
2377
+ test_cases.emplace_back (new test_conv_transpose_1d_gemm ({2 ,1 ,1 ,1 }, {3 ,1 ,1 ,1 }, s0, p0, d0));
2378
+ test_cases.emplace_back (new test_conv_transpose_1d_gemm ({3 ,2 ,1 ,1 }, {2 ,3 ,2 ,1 },
2379
+ s0, p0, d0, GGML_TYPE_F16));
2380
+ test_cases.emplace_back (new test_conv_transpose_1d_gemm ({3 ,2 ,1 ,1 }, {3 ,2 ,2 ,1 },
2381
+ s0, p0, d0, GGML_TYPE_F16));
2382
+ test_cases.emplace_back (new test_conv_transpose_1d_gemm ({3 ,2 ,1 ,1 }, {3 ,1 ,2 ,1 },
2383
+ s0, p0, d0, GGML_TYPE_F16));
2384
+ test_cases.emplace_back (new test_conv_transpose_1d_gemm ({2 ,1 ,1 ,1 }, {3 ,1 ,1 ,1 },
2385
+ s0, p0, d0, GGML_TYPE_F16));
2386
+ }
2387
+ }
2388
+ }
2333
2389
2334
2390
test_cases.emplace_back (new test_repeat (GGML_TYPE_F32, {10 , 10 , 10 , 10 }, {1 , 1 , 1 , 1 }));
2335
2391
test_cases.emplace_back (new test_repeat (GGML_TYPE_F32, {10 , 10 , 10 , 10 }, {2 , 1 , 1 , 1 }));
0 commit comments