Skip to content

Commit f872788

Browse files
committed
Add tests for conv_transpose_1d_gemm
Signed-off-by: Salvatore Mesoraca <s.mesoraca16@gmail.com>
1 parent 822aebd commit f872788

File tree

4 files changed

+1003
-0
lines changed

4 files changed

+1003
-0
lines changed

tests/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,12 @@ add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
350350
target_link_libraries(${TEST_TARGET} PRIVATE ggml)
351351
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
352352

353+
# test-conv-transpose-1d-gemm
354+
355+
set(TEST_TARGET test-conv-transpose-1d-gemm)
356+
add_executable(${TEST_TARGET} ${TEST_TARGET}.cpp)
357+
target_link_libraries(${TEST_TARGET} PRIVATE ggml)
358+
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
353359

354360
#
355361
# test-dup

tests/test-backend-ops.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1412,6 +1412,43 @@ struct test_conv_transpose_1d : public test_case {
14121412
}
14131413
};
14141414

1415+
struct test_conv_transpose_1d_gemm : public test_case {
1416+
const std::array<int64_t, 4> ne_input;
1417+
const std::array<int64_t, 4> ne_kernel;
1418+
1419+
const int s0; // stride
1420+
const int p0; // padding
1421+
const int d0; // dilation
1422+
1423+
ggml_type input_type;
1424+
ggml_type kernel_type;
1425+
1426+
std::string vars() override {
1427+
return VARS_TO_STR5(ne_input, ne_kernel, s0, p0, d0);
1428+
}
1429+
1430+
test_conv_transpose_1d_gemm(std::array<int64_t, 4> ne_input = {197, 32, 1, 1}, // [input_width, input_height, input_channels, 1]
1431+
std::array<int64_t, 4> ne_kernel = {16, 32, 32, 1}, // [kernel_width, kernel_height, input_channels, 1]
1432+
int s0 = 1, int p0 = 0, int d0 = 1,
1433+
ggml_type input_type = GGML_TYPE_F32,
1434+
ggml_type kernel_type = GGML_TYPE_F16)
1435+
: ne_input(ne_input)
1436+
, ne_kernel(ne_kernel)
1437+
, s0(s0)
1438+
, p0(p0)
1439+
, d0(d0)
1440+
, input_type(input_type)
1441+
, kernel_type(kernel_type)
1442+
{}
1443+
1444+
ggml_tensor * build_graph(ggml_context * ctx) override {
1445+
ggml_tensor * input = ggml_new_tensor(ctx, input_type, 4, ne_input.data());
1446+
ggml_tensor * kernel = ggml_new_tensor(ctx, kernel_type, 4, ne_kernel.data());
1447+
ggml_tensor * out = ggml_conv_transpose_1d_gemm(ctx, kernel, input, s0, p0, d0);
1448+
return out;
1449+
}
1450+
};
1451+
14151452
// GGML_OP_IM2COL
14161453
struct test_im2col : public test_case {
14171454
const ggml_type type_input;
@@ -2330,6 +2367,25 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
23302367
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
23312368
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
23322369

2370+
test_cases.emplace_back(new test_conv_transpose_1d_gemm());
2371+
for (int64_t s0 = 1; s0 < 4; ++s0) {
2372+
for (int64_t p0 = 0; p0 < 2; ++p0) {
2373+
for (int64_t d0 = 1; d0 < 4; ++d0) {
2374+
test_cases.emplace_back(new test_conv_transpose_1d_gemm({3,2,1,1}, {2,3,2,1}, s0, p0, d0));
2375+
test_cases.emplace_back(new test_conv_transpose_1d_gemm({3,2,1,1}, {3,2,2,1}, s0, p0, d0));
2376+
test_cases.emplace_back(new test_conv_transpose_1d_gemm({3,2,1,1}, {3,1,2,1}, s0, p0, d0));
2377+
test_cases.emplace_back(new test_conv_transpose_1d_gemm({2,1,1,1}, {3,1,1,1}, s0, p0, d0));
2378+
test_cases.emplace_back(new test_conv_transpose_1d_gemm({3,2,1,1}, {2,3,2,1},
2379+
s0, p0, d0, GGML_TYPE_F16));
2380+
test_cases.emplace_back(new test_conv_transpose_1d_gemm({3,2,1,1}, {3,2,2,1},
2381+
s0, p0, d0, GGML_TYPE_F16));
2382+
test_cases.emplace_back(new test_conv_transpose_1d_gemm({3,2,1,1}, {3,1,2,1},
2383+
s0, p0, d0, GGML_TYPE_F16));
2384+
test_cases.emplace_back(new test_conv_transpose_1d_gemm({2,1,1,1}, {3,1,1,1},
2385+
s0, p0, d0, GGML_TYPE_F16));
2386+
}
2387+
}
2388+
}
23332389

23342390
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
23352391
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1}));

0 commit comments

Comments
 (0)