@@ -6058,6 +6058,163 @@ void ggml_compute_forward_im2col_back_f32(
     }
 }
 
+// ggml_compute_forward_conv_2d
+
+static void ggml_compute_forward_conv_2d_f32(
+        const ggml_compute_params * params,
+        const ggml_tensor *         kernel,  // [KW, KH, IC, OC]
+        const ggml_tensor *         src,     // [W, H, C, N]
+        ggml_tensor *               dst) {   // [OW, OH, OC, N]
+
+    const int32_t s0 = ggml_get_op_params_i32(dst, 0);
+    const int32_t s1 = ggml_get_op_params_i32(dst, 1);
+    const int32_t p0 = ggml_get_op_params_i32(dst, 2);
+    const int32_t p1 = ggml_get_op_params_i32(dst, 3);
+    const int32_t d0 = ggml_get_op_params_i32(dst, 4);
+    const int32_t d1 = ggml_get_op_params_i32(dst, 5);
+
+    const int64_t OW = dst->ne[0];
+    const int64_t OH = dst->ne[1];
+    const int64_t OC = dst->ne[2];
+    const int64_t N  = dst->ne[3];
+
+    const int64_t IW = src->ne[0];
+    const int64_t IH = src->ne[1];
+    const int64_t IC = src->ne[2];
+
+    const int64_t KW = kernel->ne[0];
+    const int64_t KH = kernel->ne[1];
+
+    const float * kernel_data = (const float *)kernel->data;
+    const float * src_data    = (const float *)src->data;
+    float *       dst_data    = (float *)dst->data;
+
+    const int64_t rows_total      = OH * N;
+    const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
+    const int64_t row_start       = params->ith * rows_per_thread;
+    const int64_t row_end         = MIN(row_start + rows_per_thread, rows_total);
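+    // Parallelization: the OH*N output rows are split statically across the
+    // nth threads; each thread owns the contiguous range [row_start, row_end),
+    // so the loops below need no synchronization.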
+
+    for (int64_t row = row_start; row < row_end; ++row) {
+        const int64_t oh = row % OH;
+        const int64_t n  = row / OH;
+        const float * src_batch = src_data + n * IW * IH * IC;
+
+        for (int64_t ow = 0; ow < OW; ++ow) {
+            for (int64_t oc = 0; oc < OC; ++oc) {
+                float sum = 0.0f;
+                const float * kernel_channel = kernel_data + oc * KW * KH * IC;
+
+                for (int64_t kh = 0; kh < KH; ++kh) {
+                    const int64_t ih = oh * s1 - p1 + kh * d1;
+                    if (ih < 0 || ih >= IH) continue;
+
+                    for (int64_t kw = 0; kw < KW; ++kw) {
+                        const int64_t iw = ow * s0 - p0 + kw * d0;
+                        if (iw < 0 || iw >= IW) continue;
+
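+                        // Contiguity is assumed here (nb strides are not
+                        // consulted): stepping one input channel advances the
+                        // kernel by KW*KH elements ([KW, KH, IC, OC] layout)
+                        // and the input by IW*IH elements ([W, H, C, N]).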
+                        #pragma omp simd
+                        for (int64_t ic = 0; ic < IC; ++ic) {
+                            const float * kernel_ptr = kernel_channel + (kh * KW + kw) + ic * KW * KH;
+                            const float * src_ptr    = src_batch + (ih * IW + iw) + ic * IW * IH;
+                            sum += (*kernel_ptr) * (*src_ptr);
+                        }
+                    }
+                }
+
+                dst_data[((n * OC + oc) * OH + oh) * OW + ow] = sum;
+            }
+        }
+    }
+}
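+
+// Shape note (illustrative sketch, not part of this op): the OW/OH values read
+// from dst->ne above are fixed when the dst tensor is created, and are expected
+// to follow the usual dilated-convolution rule that ggml's im2col path also
+// uses. A hypothetical helper spelling that rule out:
+//
+// static int64_t conv2d_out_size(int64_t in, int64_t k, int32_t s, int32_t p, int32_t d) {
+//     return (in + 2*p - d*(k - 1) - 1) / s + 1;   // e.g. 224,3,1,1,1 -> 224
+// }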
+
+static void ggml_compute_forward_conv_2d_f16(
+        const ggml_compute_params * params,
+        const ggml_tensor *         kernel,  // [KW, KH, IC, OC]
+        const ggml_tensor *         src,     // [W, H, C, N]
+        ggml_tensor *               dst) {   // [OW, OH, OC, N]
+
+    const int32_t s0 = ggml_get_op_params_i32(dst, 0);
+    const int32_t s1 = ggml_get_op_params_i32(dst, 1);
+    const int32_t p0 = ggml_get_op_params_i32(dst, 2);
+    const int32_t p1 = ggml_get_op_params_i32(dst, 3);
+    const int32_t d0 = ggml_get_op_params_i32(dst, 4);
+    const int32_t d1 = ggml_get_op_params_i32(dst, 5);
+
+    const int64_t OW = dst->ne[0];
+    const int64_t OH = dst->ne[1];
+    const int64_t OC = dst->ne[2];
+    const int64_t N  = dst->ne[3];
+
+    const int64_t IW = src->ne[0];
+    const int64_t IH = src->ne[1];
+    const int64_t IC = src->ne[2];
+
+    const int64_t KW = kernel->ne[0];
+    const int64_t KH = kernel->ne[1];
+
+    const ggml_fp16_t * kernel_data = (const ggml_fp16_t *)kernel->data;
+    const ggml_fp16_t * src_data    = (const ggml_fp16_t *)src->data;
+    ggml_fp16_t *       dst_data    = (ggml_fp16_t *)dst->data;
+
+    const int64_t rows_total      = OH * N;
+    const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
+    const int64_t row_start       = params->ith * rows_per_thread;
+    const int64_t row_end         = MIN(row_start + rows_per_thread, rows_total);
+
+    for (int64_t row = row_start; row < row_end; ++row) {
+        const int64_t oh = row % OH;
+        const int64_t n  = row / OH;
+        const ggml_fp16_t * src_batch = src_data + n * IW * IH * IC;
+
+        for (int64_t ow = 0; ow < OW; ++ow) {
+            for (int64_t oc = 0; oc < OC; ++oc) {
+                float sum = 0.0f;
+                const ggml_fp16_t * kernel_channel = kernel_data + oc * KW * KH * IC;
+                for (int64_t kh = 0; kh < KH; ++kh) {
+                    const int64_t ih = oh * s1 - p1 + kh * d1;
+                    if (ih < 0 || ih >= IH) continue;
+
+                    for (int64_t kw = 0; kw < KW; ++kw) {
+                        const int64_t iw = ow * s0 - p0 + kw * d0;
+                        if (iw < 0 || iw >= IW) continue;
+
+                        for (int64_t ic = 0; ic < IC; ++ic) {
+                            const ggml_fp16_t * kernel_ptr = kernel_channel + (kh * KW + kw) + ic * KW * KH;
+                            const ggml_fp16_t * src_ptr    = src_batch + (ih * IW + iw) + ic * IW * IH;
+                            sum += GGML_FP16_TO_FP32(*kernel_ptr) * GGML_FP16_TO_FP32(*src_ptr);
+                        }
+                    }
+                }
+
+                dst_data[((n * OC + oc) * OH + oh) * OW + ow] = GGML_FP32_TO_FP16(sum);
+            }
+        }
+    }
+}
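+
+// Precision note: the f16 path accumulates each dot product in a float and
+// converts to fp16 once per output element, so rounding error does not
+// compound across the KW*KH*IC partial products.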
+
+void ggml_compute_forward_conv_2d(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_conv_2d_f16(params, src0, src1, dst);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
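+
+// Usage sketch (illustrative): whatever graph-level call creates the
+// GGML_OP_CONV_2D node must store the six op params in the order read above
+// (s0, s1, p0, p1, d0, d1). Assuming an entry point named ggml_conv_2d_direct
+// (the name is an assumption here):
+//
+//     // kernel: [KW, KH, IC, OC], input: [W, H, C, N]
+//     struct ggml_tensor * out = ggml_conv_2d_direct(ctx, kernel, input,
+//             /*s0*/ 1, /*s1*/ 1, /*p0*/ 0, /*p1*/ 0, /*d0*/ 1, /*d1*/ 1);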
+
 // ggml_compute_forward_conv_transpose_2d
 
 void ggml_compute_forward_conv_transpose_2d(