Skip to content

Commit b4af5d9

Browse files
committed
Conv2D: Add CPU version
1 parent b25e927 commit b4af5d9

File tree

5 files changed

+214
-0
lines changed

5 files changed

+214
-0
lines changed

ggml/include/ggml.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,7 @@ extern "C" {
482482
GGML_OP_CONV_TRANSPOSE_1D,
483483
GGML_OP_IM2COL,
484484
GGML_OP_IM2COL_BACK,
485+
GGML_OP_CONV_2D,
485486
GGML_OP_CONV_2D_DW,
486487
GGML_OP_CONV_TRANSPOSE_2D,
487488
GGML_OP_POOL_1D,
@@ -1744,6 +1745,17 @@ extern "C" {
17441745
struct ggml_tensor * b,
17451746
int stride);
17461747

1748+
// Creates a GGML_OP_CONV_2D node that convolves input b with kernel a.
// The definition asserts that a and b have the same type and that
// a->ne[2] == b->ne[2] (input channel counts match).
// The result has the type of b and shape [OW, OH, OC, N], where OW/OH are
// derived from the input size, kernel size, stride, padding and dilation,
// OC = a->ne[3] and N = b->ne[3].
GGML_API struct ggml_tensor * ggml_conv_2d_direct(
1749+
struct ggml_context * ctx,
1750+
struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC]
1751+
struct ggml_tensor * b, // input data [W, H, C, N]
1752+
int s0, // stride dimension 0
1753+
int s1, // stride dimension 1
1754+
int p0, // padding dimension 0
1755+
int p1, // padding dimension 1
1756+
int d0, // dilation dimension 0
1757+
int d1); // dilation dimension 1
1758+
17471759
enum ggml_op_pool {
17481760
GGML_OP_POOL_MAX,
17491761
GGML_OP_POOL_AVG,

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1866,6 +1866,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
18661866
{
18671867
ggml_compute_forward_im2col_back_f32(params, tensor);
18681868
} break;
1869+
case GGML_OP_CONV_2D:
1870+
{
1871+
ggml_compute_forward_conv_2d(params, tensor);
1872+
} break;
18691873
case GGML_OP_CONV_2D_DW:
18701874
{
18711875
ggml_compute_forward_conv_2d_dw(params, tensor);
@@ -2212,6 +2216,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
22122216
} break;
22132217
case GGML_OP_IM2COL:
22142218
case GGML_OP_IM2COL_BACK:
2219+
case GGML_OP_CONV_2D:
22152220
case GGML_OP_CONV_2D_DW:
22162221
case GGML_OP_CONV_TRANSPOSE_1D:
22172222
case GGML_OP_CONV_TRANSPOSE_2D:

ggml/src/ggml-cpu/ops.cpp

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6116,6 +6116,163 @@ void ggml_compute_forward_im2col_back_f32(
61166116
}
61176117
}
61186118

6119+
// ggml_compute_forward_conv_2d
6120+
6121+
static void ggml_compute_forward_conv_2d_f32(
6122+
const ggml_compute_params * params,
6123+
const ggml_tensor * kernel, // [KW, KH, IC, OC]
6124+
const ggml_tensor * src, // [W, H, C, N]
6125+
ggml_tensor * dst) { // [OW, OH, OC, N]
6126+
6127+
const int32_t s0 = ggml_get_op_params_i32(dst, 0);
6128+
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
6129+
const int32_t p0 = ggml_get_op_params_i32(dst, 2);
6130+
const int32_t p1 = ggml_get_op_params_i32(dst, 3);
6131+
const int32_t d0 = ggml_get_op_params_i32(dst, 4);
6132+
const int32_t d1 = ggml_get_op_params_i32(dst, 5);
6133+
6134+
const int64_t OW = dst->ne[0];
6135+
const int64_t OH = dst->ne[1];
6136+
const int64_t OC = dst->ne[2];
6137+
const int64_t N = dst->ne[3];
6138+
6139+
const int64_t IW = src->ne[0];
6140+
const int64_t IH = src->ne[1];
6141+
const int64_t IC = src->ne[2];
6142+
6143+
const int64_t KW = kernel->ne[0];
6144+
const int64_t KH = kernel->ne[1];
6145+
6146+
const float * kernel_data = (const float *)kernel->data;
6147+
const float * src_data = (const float *)src->data;
6148+
float * dst_data = (float *)dst->data;
6149+
6150+
const int64_t rows_total = OH * N;
6151+
const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
6152+
const int64_t row_start = params->ith * rows_per_thread;
6153+
const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
6154+
6155+
for (int64_t row = row_start; row < row_end; ++row) {
6156+
const int64_t oh = row % OH;
6157+
const int64_t n = row / OH;
6158+
const float * src_batch = src_data + n * IW * IH * IC;
6159+
6160+
for (int64_t ow = 0; ow < OW; ++ow) {
6161+
for (int64_t oc = 0; oc < OC; ++oc) {
6162+
float sum = 0.0f;
6163+
const float * kernel_channel = kernel_data + oc * KW * KH * IC;
6164+
6165+
for (int64_t kh = 0; kh < KH; ++kh) {
6166+
const int64_t ih = oh * s1 - p1 + kh * d1;
6167+
if (ih < 0 || ih >= IH) continue;
6168+
6169+
for (int64_t kw = 0; kw < KW; ++kw) {
6170+
const int64_t iw = ow * s0 - p0 + kw * d0;
6171+
if (iw < 0 || iw >= IW) continue;
6172+
6173+
#pragma omp simd
6174+
for (int64_t ic = 0; ic < IC; ++ic) {
6175+
const float * kernel_ptr = kernel_channel + (kh * KW + kw) + ic * KW * KH;
6176+
const float * src_ptr = src_batch + (ih * IW + iw) + ic * IW * IH;
6177+
sum += (*kernel_ptr) * (*src_ptr);
6178+
}
6179+
}
6180+
}
6181+
6182+
dst_data[((n * OC + oc) * OH + oh) * OW + ow] = sum;
6183+
}
6184+
}
6185+
}
6186+
}
6187+
6188+
static void ggml_compute_forward_conv_2d_f16(
6189+
const ggml_compute_params * params,
6190+
const ggml_tensor * kernel, // [KW, KH, IC, OC]
6191+
const ggml_tensor * src, // [W, H, C, N]
6192+
ggml_tensor * dst) { // [OW, OH, OC, N]
6193+
6194+
const int32_t s0 = ggml_get_op_params_i32(dst, 0);
6195+
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
6196+
const int32_t p0 = ggml_get_op_params_i32(dst, 2);
6197+
const int32_t p1 = ggml_get_op_params_i32(dst, 3);
6198+
const int32_t d0 = ggml_get_op_params_i32(dst, 4);
6199+
const int32_t d1 = ggml_get_op_params_i32(dst, 5);
6200+
6201+
const int64_t OW = dst->ne[0];
6202+
const int64_t OH = dst->ne[1];
6203+
const int64_t OC = dst->ne[2];
6204+
const int64_t N = dst->ne[3];
6205+
6206+
const int64_t IW = src->ne[0];
6207+
const int64_t IH = src->ne[1];
6208+
const int64_t IC = src->ne[2];
6209+
6210+
const int64_t KW = kernel->ne[0];
6211+
const int64_t KH = kernel->ne[1];
6212+
6213+
const ggml_fp16_t * kernel_data = (const ggml_fp16_t *)kernel->data;
6214+
const ggml_fp16_t * src_data = (const ggml_fp16_t *)src->data;
6215+
ggml_fp16_t * dst_data = (ggml_fp16_t *)dst->data;
6216+
6217+
const int64_t rows_total = OH * N;
6218+
const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
6219+
const int64_t row_start = params->ith * rows_per_thread;
6220+
const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
6221+
6222+
for (int64_t row = row_start; row < row_end; ++row) {
6223+
const int64_t oh = row % OH;
6224+
const int64_t n = row / OH;
6225+
const ggml_fp16_t * src_batch = src_data + n * IW * IH * IC;
6226+
6227+
for (int64_t ow = 0; ow < OW; ++ow) {
6228+
for (int64_t oc = 0; oc < OC; ++oc) {
6229+
float sum = 0.0f;
6230+
const ggml_fp16_t * kernel_channel = kernel_data + oc * KW * KH * IC;
6231+
for (int64_t kh = 0; kh < KH; ++kh) {
6232+
const int64_t ih = oh * s1 - p1 + kh * d1;
6233+
if (ih < 0 || ih >= IH) continue;
6234+
6235+
for (int64_t kw = 0; kw < KW; ++kw) {
6236+
const int64_t iw = ow * s0 - p0 + kw * d0;
6237+
if (iw < 0 || iw >= IW) continue;
6238+
6239+
for (int64_t ic = 0; ic < IC; ++ic) {
6240+
const ggml_fp16_t * kernel_ptr = kernel_channel + (kh * KW + kw) + ic * KW * KH;
6241+
const ggml_fp16_t * src_ptr = src_batch + (ih * IW + iw) + ic * IW * IH;
6242+
sum += GGML_FP16_TO_FP32(*kernel_ptr) * GGML_FP16_TO_FP32(*src_ptr);
6243+
}
6244+
}
6245+
}
6246+
6247+
dst_data[((n * OC + oc) * OH + oh) * OW + ow] = GGML_FP32_TO_FP16(sum);
6248+
}
6249+
}
6250+
}
6251+
}
6252+
6253+
void ggml_compute_forward_conv_2d(
6254+
const ggml_compute_params * params,
6255+
ggml_tensor * dst) {
6256+
6257+
const ggml_tensor * src0 = dst->src[0];
6258+
const ggml_tensor * src1 = dst->src[1];
6259+
6260+
switch (src0->type) {
6261+
case GGML_TYPE_F16:
6262+
{
6263+
ggml_compute_forward_conv_2d_f16(params, src0, src1, dst);
6264+
} break;
6265+
case GGML_TYPE_F32:
6266+
{
6267+
ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
6268+
} break;
6269+
default:
6270+
{
6271+
GGML_ABORT("fatal error");
6272+
}
6273+
}
6274+
}
6275+
61196276
// ggml_compute_forward_conv_transpose_2d
61206277

61216278
void ggml_compute_forward_conv_transpose_2d(

ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struc
6565
void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
6666
void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
6767
void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
68+
void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
6869
void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
6970
void ggml_compute_forward_conv_2d_dw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
7071
void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);

ggml/src/ggml.c

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,6 +1042,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
10421042
"conv_transpose_1d(x)",
10431043
"im2col(x)",
10441044
"im2col_back(x)",
1045+
"conv_2d(x)",
10451046
"conv_2d_dw(x)",
10461047
"conv_transpose_2d(x)",
10471048
"pool_1d(x)",
@@ -4157,6 +4158,44 @@ struct ggml_tensor * ggml_conv_2d_dw_direct(
41574158
return result;
41584159
}
41594160

4161+
// ggml_conv_2d_direct
4162+
4163+
struct ggml_tensor * ggml_conv_2d_direct(
4164+
struct ggml_context * ctx,
4165+
struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC]
4166+
struct ggml_tensor * b, // input data [W, H, C, N]
4167+
int s0, // stride dimension 0
4168+
int s1, // stride dimension 1
4169+
int p0, // padding dimension 0
4170+
int p1, // padding dimension 1
4171+
int d0, // dilation dimension 0
4172+
int d1) {// dilation dimension 1
4173+
4174+
GGML_ASSERT(a->ne[2] == b->ne[2]);
4175+
GGML_ASSERT(a->type == b->type);
4176+
4177+
int64_t ne[4];
4178+
ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
4179+
ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
4180+
ne[2] = a->ne[3];
4181+
ne[3] = b->ne[3];
4182+
4183+
struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
4184+
4185+
ggml_set_op_params_i32(result, 0, s0);
4186+
ggml_set_op_params_i32(result, 1, s1);
4187+
ggml_set_op_params_i32(result, 2, p0);
4188+
ggml_set_op_params_i32(result, 3, p1);
4189+
ggml_set_op_params_i32(result, 4, d0);
4190+
ggml_set_op_params_i32(result, 5, d1);
4191+
4192+
result->op = GGML_OP_CONV_2D;
4193+
result->src[0] = a;
4194+
result->src[1] = b;
4195+
4196+
return result;
4197+
}
4198+
41604199
// ggml_conv_transpose_2d_p0
41614200

41624201
static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {

0 commit comments

Comments
 (0)