Skip to content

Commit ea995d8

Browse files
balisujohn authored and slaren committed
feat: cuda implementation for ggml_conv_transpose_1d (ggml/854)
* conv transpose 1d passing test for 1d input and kernel
* working for different input and output channel counts, added test for variable stride
* initial draft appears to work with stride other than 1
* working with all old and new conv1d tests
* added a test for large tensors
* removed use cuda hardcoding
* restored test-conv-transpose.c
* removed unused arguments, and fixed bug where test failure would cause subsequent tests to fail
* fixed accumulator bug
* added test to test-backend-ops
* fixed mistake
* addressed review
* fixed includes
* removed blank lines
* style and warning fixes
* return failure when test fails
* fix supports_op

---------

Co-authored-by: slaren <slarengh@gmail.com>
1 parent f10e7c8 commit ea995d8

File tree

4 files changed

+146
-1
lines changed

4 files changed

+146
-1
lines changed

ggml/src/ggml-cuda.cu

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "ggml-cuda/tsembd.cuh"
3030
#include "ggml-cuda/unary.cuh"
3131
#include "ggml-cuda/upscale.cuh"
32+
#include "ggml-cuda/conv-transpose-1d.cuh"
3233

3334
#include <algorithm>
3435
#include <array>
@@ -2261,6 +2262,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
22612262
case GGML_OP_IM2COL:
22622263
ggml_cuda_op_im2col(ctx, dst);
22632264
break;
2265+
case GGML_OP_CONV_TRANSPOSE_1D:
2266+
ggml_cuda_op_conv_transpose_1d(ctx,dst);
2267+
break;
22642268
case GGML_OP_POOL_2D:
22652269
ggml_cuda_op_pool2d(ctx, dst);
22662270
break;
@@ -2804,6 +2808,15 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
28042808
ggml_type src0_type = op->src[0]->type;
28052809
return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
28062810
} break;
2811+
case GGML_OP_CONV_TRANSPOSE_1D:
2812+
{
2813+
ggml_type src0_type = op->src[0]->type;
2814+
ggml_type src1_type = op->src[1]->type;
2815+
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
2816+
return true;
2817+
}
2818+
return false;
2819+
} break;
28072820
case GGML_OP_NONE:
28082821
case GGML_OP_RESHAPE:
28092822
case GGML_OP_VIEW:
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#include "conv-transpose-1d.cuh"
2+
3+
// CUDA kernel for GGML_OP_CONV_TRANSPOSE_1D (F32 kernel, F32 input).
// One thread per output element; launched with a 1-D grid sized from
// output_size (= ggml_nelements(dst)), so the early-out guard below
// handles the partial final block.
//
// Tensor layouts (contiguous, ggml order, fastest-varying dim first):
//   src0 (kernel): ne = [K, Cout, Cin, 1]
//   src1 (input):  ne = [L, Cin, 1, 1]
//   dst  (output): ne = [out_w, Cout, 1, 1]
//
// p0 (padding) and d0 (dilation) are accepted for interface parity but are
// not applied here — the host wrapper currently fixes them to 0 and 1.
static __global__ void conv_transpose_1d_kernel(
        const int s0, const int p0, const int d0, const int output_size,
        const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
        const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
        const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
        const float * src0, const float * src1, float * dst) {
    const int global_index = threadIdx.x + blockIdx.x * blockDim.x;
    if (global_index >= output_size) {
        return;
    }

    // output channel this thread writes, and the position within that row;
    // idx is loop-invariant, so it is hoisted out of the channel loop
    const int out_index = global_index / dst_ne0;
    const int idx       = global_index % dst_ne0;

    float accumulator = 0.0f;

    // accumulate contributions from every input channel
    for (int c = 0; c < src0_ne2; c++) {
        // start of the (c -> out_index) kernel slice, and of input channel c
        const int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);
        const int input_offset  = src1_ne0 * c;

        for (int i = 0; i < src1_ne0; i++) {
            // input element i only contributes to output positions [i*s0, i*s0 + K)
            if (idx < i*s0 || idx >= i*s0 + src0_ne0) {
                continue;
            }
            const int weight_idx = idx - i*s0;

            const float kernel_weight = src0[kernel_offset + weight_idx];
            const float input_value   = src1[input_offset + i];

            accumulator += kernel_weight * input_value;
        }
    }
    dst[global_index] = accumulator;
}
38+
39+
// Host-side launcher for conv_transpose_1d_kernel: sizes a 1-D grid so that
// every output element gets one thread, then forwards all shape parameters
// unchanged on the given stream.
static void conv_transpose_1d_f32_f32_cuda(
        const int s0, const int p0, const int d0, const int output_size,
        const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
        const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
        const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
        const float * src0, const float * src1, float * dst,
        cudaStream_t stream) {

    const int block_size = CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE;
    const int num_blocks = (output_size + block_size - 1) / block_size; // ceil-div
    conv_transpose_1d_kernel<<<num_blocks, block_size, 0, stream>>>(
        s0, p0, d0, output_size,
        src0_ne0, src0_ne1, src0_ne2, src0_ne3,
        src1_ne0, src1_ne1, src1_ne2, src1_ne3,
        dst_ne0, dst_ne1, dst_ne2, dst_ne3,
        src0, src1, dst);
}
55+
56+
// Dispatch GGML_OP_CONV_TRANSPOSE_1D on the CUDA backend.
// dst->src[0] is the F32 kernel tensor, dst->src[1] the F32 input tensor;
// op_params[0] holds the stride s0. Padding and dilation are not yet
// implemented on CUDA and are forced to 0 and 1 here.
void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;

    const ggml_tensor * src1 = dst->src[1];
    const float * src1_d = (const float *)src1->data;

    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32); // supports_op only admits F32/F32
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    // the kernel indexes both tensors with flat ne-derived offsets
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));

    const int32_t * opts = (const int32_t *)dst->op_params;

    const int s0 = opts[0];
    const int p0 = 0; // padding  — not yet supported; opts value ignored
    const int d0 = 1; // dilation — not yet supported; opts value ignored

    const int64_t output_size = ggml_nelements(dst);

    conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size,
        src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
        src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
        dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
        src0_d, src1_d, dst_d, stream);
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#include "common.cuh"

// Threads per block for the conv_transpose_1d kernel launch.
// NOTE(review): the macro name is missing an 'S' ("TRANPOSE"); kept as-is
// because it is referenced from conv-transpose-1d.cu.
#define CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE 256

// Computes GGML_OP_CONV_TRANSPOSE_1D for dst on the CUDA backend
// (implemented in conv-transpose-1d.cu); expects contiguous F32 src tensors.
void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

tests/test-backend-ops.cpp

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1266,6 +1266,36 @@ struct test_pool2d : public test_case {
12661266
}
12671267
};
12681268

1269+
// GGML_OP_CONV_TRANSPOSE_1D
1270+
struct test_conv_transpose_1d : public test_case {
1271+
1272+
const std::array<int64_t, 4> ne_input;
1273+
const std::array<int64_t, 4> ne_kernel;
1274+
1275+
// stride
1276+
const int s0;
1277+
// padding
1278+
const int p0;
1279+
// dilation
1280+
const int d0;
1281+
1282+
std::string vars() override {
1283+
return VARS_TO_STR5(ne_input, ne_kernel, s0, p0, d0);
1284+
}
1285+
1286+
test_conv_transpose_1d(std::array<int64_t, 4> ne_input = {197, 32, 1, 1}, // [input_width, input_height, input_channels, 1]
1287+
std::array<int64_t, 4> ne_kernel = {16, 32, 32, 1}, // [kernel_width, kernel_height, input_channels, 1]
1288+
int s0 = 1, int p0 = 0, int d0 = 1)
1289+
: ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), p0(p0), d0(d0) {}
1290+
1291+
ggml_tensor * build_graph(ggml_context * ctx) override {
1292+
ggml_tensor * input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_input.data());
1293+
ggml_tensor * kernel = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne_kernel.data());
1294+
ggml_tensor * out = ggml_conv_transpose_1d(ctx, kernel, input, s0, p0, d0);
1295+
return out;
1296+
}
1297+
};
1298+
12691299
// GGML_OP_IM2COL
12701300
struct test_im2col : public test_case {
12711301
const ggml_type type_input;
@@ -1279,7 +1309,7 @@ struct test_im2col : public test_case {
12791309
// padding
12801310
const int p0;
12811311
const int p1;
1282-
// dilatation
1312+
// dilation
12831313
const int d0;
12841314
const int d1;
12851315
// mode
@@ -2098,6 +2128,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
20982128
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32));
20992129
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16));
21002130

2131+
test_cases.emplace_back(new test_conv_transpose_1d());
2132+
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 3, 0, 1));
2133+
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 2, 0, 1));
2134+
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {2,3,2,1}, 1, 0, 1));
2135+
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 2, 0, 1));
2136+
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,2,2,1}, 1, 0, 1));
2137+
test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
2138+
test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
2139+
2140+
21012141
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
21022142
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1}));
21032143
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 2, 1, 1}));

0 commit comments

Comments
 (0)