@@ -6116,6 +6116,163 @@ void ggml_compute_forward_im2col_back_f32(
    }
}

+ // ggml_compute_forward_conv_2d
+
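+ // direct (no im2col) 2D convolution; the work is split across threads by output row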
+ static void ggml_compute_forward_conv_2d_f32(
+         const ggml_compute_params * params,
+         const ggml_tensor *         kernel, // [KW, KH, IC, OC]
+         const ggml_tensor *         src,    // [W, H, C, N]
+         ggml_tensor *               dst) {  // [OW, OH, OC, N]
+
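+     // op params packed on dst: stride (s0, s1), padding (p0, p1), dilation (d0, d1)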
+     const int32_t s0 = ggml_get_op_params_i32(dst, 0);
+     const int32_t s1 = ggml_get_op_params_i32(dst, 1);
+     const int32_t p0 = ggml_get_op_params_i32(dst, 2);
+     const int32_t p1 = ggml_get_op_params_i32(dst, 3);
+     const int32_t d0 = ggml_get_op_params_i32(dst, 4);
+     const int32_t d1 = ggml_get_op_params_i32(dst, 5);
+
+     const int64_t OW = dst->ne[0];
+     const int64_t OH = dst->ne[1];
+     const int64_t OC = dst->ne[2];
+     const int64_t N  = dst->ne[3];
+
+     const int64_t IW = src->ne[0];
+     const int64_t IH = src->ne[1];
+     const int64_t IC = src->ne[2];
+
+     const int64_t KW = kernel->ne[0];
+     const int64_t KH = kernel->ne[1];
+
+     const float * kernel_data = (const float *) kernel->data;
+     const float * src_data    = (const float *) src->data;
+     float       * dst_data    = (float *) dst->data;
+
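+     // split the OH*N output rows evenly across the nth threads; each row decodes to an (oh, n) pair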
+     const int64_t rows_total      = OH * N;
+     const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
+     const int64_t row_start       = params->ith * rows_per_thread;
+     const int64_t row_end         = MIN(row_start + rows_per_thread, rows_total);
+
+     for (int64_t row = row_start; row < row_end; ++row) {
+         const int64_t oh = row % OH;
+         const int64_t n  = row / OH;
+         const float * src_batch = src_data + n * IW * IH * IC;
+
+         for (int64_t ow = 0; ow < OW; ++ow) {
+             for (int64_t oc = 0; oc < OC; ++oc) {
+                 float sum = 0.0f;
+                 const float * kernel_channel = kernel_data + oc * KW * KH * IC;
+
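+                 // walk the KH x KW kernel window; taps outside the input are skipped (implicit zero padding)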
+                 for (int64_t kh = 0; kh < KH; ++kh) {
+                     const int64_t ih = oh * s1 - p1 + kh * d1;
+                     if (ih < 0 || ih >= IH) continue;
+
+                     for (int64_t kw = 0; kw < KW; ++kw) {
+                         const int64_t iw = ow * s0 - p0 + kw * d0;
+                         if (iw < 0 || iw >= IW) continue;
+
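+                         // reduce over the input channels; the simd pragma hints the compiler to vectorize this loop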
+                         #pragma omp simd
+                         for (int64_t ic = 0; ic < IC; ++ic) {
+                             const float * kernel_ptr = kernel_channel + (kh * KW + kw) + ic * KW * KH;
+                             const float * src_ptr    = src_batch + (ih * IW + iw) + ic * IW * IH;
+                             sum += (*kernel_ptr) * (*src_ptr);
+                         }
+                     }
+                 }
+
+                 dst_data[((n * OC + oc) * OH + oh) * OW + ow] = sum;
+             }
+         }
+     }
+ }
+
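+ // same loop structure as the f32 path, but with fp16 tensors and fp32 accumulation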
+ static void ggml_compute_forward_conv_2d_f16(
+         const ggml_compute_params * params,
+         const ggml_tensor *         kernel, // [KW, KH, IC, OC]
+         const ggml_tensor *         src,    // [W, H, C, N]
+         ggml_tensor *               dst) {  // [OW, OH, OC, N]
+
+     const int32_t s0 = ggml_get_op_params_i32(dst, 0);
+     const int32_t s1 = ggml_get_op_params_i32(dst, 1);
+     const int32_t p0 = ggml_get_op_params_i32(dst, 2);
+     const int32_t p1 = ggml_get_op_params_i32(dst, 3);
+     const int32_t d0 = ggml_get_op_params_i32(dst, 4);
+     const int32_t d1 = ggml_get_op_params_i32(dst, 5);
+
+     const int64_t OW = dst->ne[0];
+     const int64_t OH = dst->ne[1];
+     const int64_t OC = dst->ne[2];
+     const int64_t N  = dst->ne[3];
+
+     const int64_t IW = src->ne[0];
+     const int64_t IH = src->ne[1];
+     const int64_t IC = src->ne[2];
+
+     const int64_t KW = kernel->ne[0];
+     const int64_t KH = kernel->ne[1];
+
+     const ggml_fp16_t * kernel_data = (const ggml_fp16_t *) kernel->data;
+     const ggml_fp16_t * src_data    = (const ggml_fp16_t *) src->data;
+     ggml_fp16_t       * dst_data    = (ggml_fp16_t *) dst->data;
+
+     const int64_t rows_total      = OH * N;
+     const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
+     const int64_t row_start       = params->ith * rows_per_thread;
+     const int64_t row_end         = MIN(row_start + rows_per_thread, rows_total);
+
+     for (int64_t row = row_start; row < row_end; ++row) {
+         const int64_t oh = row % OH;
+         const int64_t n  = row / OH;
+         const ggml_fp16_t * src_batch = src_data + n * IW * IH * IC;
+
+         for (int64_t ow = 0; ow < OW; ++ow) {
+             for (int64_t oc = 0; oc < OC; ++oc) {
+                 float sum = 0.0f;
+                 const ggml_fp16_t * kernel_channel = kernel_data + oc * KW * KH * IC;
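+                 // accumulate in fp32; convert back to fp16 only on the final store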
+                 for (int64_t kh = 0; kh < KH; ++kh) {
+                     const int64_t ih = oh * s1 - p1 + kh * d1;
+                     if (ih < 0 || ih >= IH) continue;
+
+                     for (int64_t kw = 0; kw < KW; ++kw) {
+                         const int64_t iw = ow * s0 - p0 + kw * d0;
+                         if (iw < 0 || iw >= IW) continue;
+
+                         for (int64_t ic = 0; ic < IC; ++ic) {
+                             const ggml_fp16_t * kernel_ptr = kernel_channel + (kh * KW + kw) + ic * KW * KH;
+                             const ggml_fp16_t * src_ptr    = src_batch + (ih * IW + iw) + ic * IW * IH;
+                             sum += GGML_FP16_TO_FP32(*kernel_ptr) * GGML_FP16_TO_FP32(*src_ptr);
+                         }
+                     }
+                 }
+
+                 dst_data[((n * OC + oc) * OH + oh) * OW + ow] = GGML_FP32_TO_FP16(sum);
+             }
+         }
+     }
+ }
+
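+ // src0 is the kernel and src1 the input; dispatch on the kernel type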
+ void ggml_compute_forward_conv_2d(
+         const ggml_compute_params * params,
+         ggml_tensor * dst) {
+
+     const ggml_tensor * src0 = dst->src[0];
+     const ggml_tensor * src1 = dst->src[1];
+
+     switch (src0->type) {
+         case GGML_TYPE_F16:
+             {
+                 ggml_compute_forward_conv_2d_f16(params, src0, src1, dst);
+             } break;
+         case GGML_TYPE_F32:
+             {
+                 ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
+             } break;
+         default:
+             {
+                 GGML_ABORT("fatal error");
+             }
+     }
+ }
+
// ggml_compute_forward_conv_transpose_2d

void ggml_compute_forward_conv_transpose_2d(