@@ -6105,18 +6105,21 @@ static void ggml_call_mul_mat(
 
 // ggml_compute_forward_conv_2d
 
-static void ggml_compute_forward_conv_2d_f32(const ggml_compute_params * params,
-                                             ggml_tensor * dst) {
-
-    const ggml_tensor * src    = dst->src[1];   // [W H C_in N]
-    const ggml_tensor * kernel = dst->src[0];   // [W H C_in C_out]
+static void ggml_compute_forward_conv_2d_f32(
+        const ggml_compute_params * params,
+        const ggml_tensor *         kernel,  // [KW, KH, IC, OC] - fp32
+        const ggml_tensor *         src,     // [W, H, C, N]
+        ggml_tensor *               dst) {   // [OW, OH, OC, N]
 
     GGML_ASSERT(ggml_is_contiguous(kernel));
+    GGML_ASSERT(kernel->type == GGML_TYPE_F32);
 
-    const int32_t stride_x = dst->op_params[0];
-    const int32_t stride_y = dst->op_params[1];
-    const int32_t pad_x    = dst->op_params[2];
-    const int32_t pad_y    = dst->op_params[3];
+    const int32_t stride_x   = dst->op_params[0];
+    const int32_t stride_y   = dst->op_params[1];
+    const int32_t pad_x      = dst->op_params[2];
+    const int32_t pad_y      = dst->op_params[3];
+    const int32_t dilation_x = dst->op_params[4];
+    const int32_t dilation_y = dst->op_params[5];
 
     const int64_t c_in  = src->ne[2];
     const int64_t c_out = kernel->ne[3];
@@ -6129,193 +6132,104 @@ static void ggml_compute_forward_conv_2d_f32(const ggml_compute_params * params,
     const int64_t dst_w = dst->ne[0];
     const int64_t dst_h = dst->ne[1];
 
-
-    float * src_data = (float *) src->data;
-    float * knl_data = (float *) kernel->data;
-    float * dst_data = (float *) dst->data;
-
+    float * src_data = (float *) src->data;
+    float * knl_data = (float *) kernel->data;
+    float * dst_data = (float *) dst->data;
 
     const int64_t knl_n       = knl_w * knl_h * c_in;
     const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
-
-
-
-    const int64_t space_per_patch = knl_n * sizeof(float) + patch_total * c_out * sizeof(float);
 
-    const int64_t batch_size = params->wsize / space_per_patch;
+    const int64_t space_per_patch   = knl_n * sizeof(float) + c_out * sizeof(float);
+    const int64_t batch_size        = params->wsize / space_per_patch;
     const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
-    const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
-
+    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
 
     GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
 
-    float * tmp = (float *) params->wdata; // per-thread scratch
+    float * tmp = (float *) params->wdata;
 
     for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
 
         const int64_t patch_start_batch = batch_i * patches_per_batch;
         const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch,
                                                    patch_total);
-        const int64_t patch_n = patch_end_batch - patch_start_batch;
+        const int64_t patch_n           = patch_end_batch - patch_start_batch;
 
-        const int64_t patch_per_thread =
-            (patch_n + params->nth - 1) / params->nth;
-        const int64_t patch_start = patch_start_batch +
-            params->ith * patch_per_thread;
-        const int64_t patch_end = std::min(patch_start + patch_per_thread,
-                                           patch_end_batch);
+        const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth;
+        const int64_t patch_start      = patch_start_batch + params->ith * patch_per_thread;
+        const int64_t patch_end        = std::min(patch_start + patch_per_thread, patch_end_batch);
 
         // im2col for a patch
         for (int64_t p = patch_start; p < patch_end; ++p) {
-            const int64_t b  = p / (dst_w * dst_h);
-            const int64_t dy = (p / dst_w) % dst_h;
-            const int64_t dx = p % dst_w;
+            const int64_t batch_n = p / (dst_w * dst_h);
+            const int64_t src_x   = (p / dst_w) % dst_h;
+            const int64_t src_y   = p % dst_w;
 
-            const float * src_base = (const float *)((char *)src_data + b * src->nb[3]);
-            float * out_row = tmp + (p % patches_per_batch) * knl_n;
+            float * src_base = (float *)((char *)src_data + batch_n * src->nb[3]);
+            float * dst_row  = tmp + (p % patches_per_batch) * knl_n;
 
-            // Extract patch in IC,KH,KW order (same as im2col)
             for (int64_t ic = 0; ic < c_in; ++ic) {
                 for (int64_t ky = 0; ky < knl_h; ++ky) {
                     for (int64_t kx = 0; kx < knl_w; ++kx) {
-                        const int64_t sy = dy * stride_y + ky - pad_y;
-                        const int64_t sx = dx * stride_x + kx - pad_x;
-
+                        const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
+                        const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;
+
                         int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
-
+
                         if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
-                            out_row[dst_idx] = 0.0f;
+                            dst_row[dst_idx] = 0.0f;
                         } else {
-                            float * src_ptr = (float *)((char *)src_base +
-                                sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
-                            out_row[dst_idx] = *src_ptr;
+                            float * src_ptr = (float *)((char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
+                            dst_row[dst_idx] = *src_ptr;
                         }
                     }
                 }
             }
         } // patches handled by this thread
 
-        ggml_barrier(params->threadpool); // wait for all threads
+        ggml_barrier(params->threadpool);
 
-        // GEMM output is patch_n * cout
         float * gemm_output = tmp + patches_per_batch * knl_n;
-
+
         // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
         ggml_call_mul_mat(params, patch_n, c_out, knl_n,
                           tmp, knl_data, gemm_output);
-
-        // Barrier to ensure GEMM completes before permutation
+
         ggml_barrier(params->threadpool);
-
-        // Distribute permutation work across threads
+
+
+        // permute back [OC, N, OH, OW] to [N, OC, OH, OW]
         const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
         const int64_t permute_start = params->ith * permute_per_thread;
         const int64_t permute_end   = std::min(permute_start + permute_per_thread, patch_n);
-
-        // Each thread handles part of the permutation from [patch_n, c_out] to WHCN layout
+
         for (int64_t i = permute_start; i < permute_end; ++i) {
-            const int64_t p  = patch_start_batch + i;
-            const int64_t b  = p / (dst_w * dst_h);   // batch index
-            const int64_t dy = (p / dst_w) % dst_h;   // height index
-            const int64_t dx = p % dst_w;             // width index
-
-            // Copy all channels for this spatial position
+            const int64_t p       = patch_start_batch + i;
+            const int64_t batch_n = p / (dst_w * dst_h);
+            const int64_t dst_y   = (p / dst_w) % dst_h;
+            const int64_t dst_x   = p % dst_w;
+
             for (int64_t oc = 0; oc < c_out; ++oc) {
                 const float value = gemm_output[i * c_out + oc];
                 // Write to WHCN layout: dst[w, h, c, n]
-                float * dst_ptr = (float *)((char *)dst_data +
-                    dx * dst->nb[0] + dy * dst->nb[1] + oc * dst->nb[2] + b * dst->nb[3]);
+                float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
                 *dst_ptr = value;
             }
         }
     }
 }
 
-static void ggml_compute_forward_conv_2d_f16(
-        const ggml_compute_params * params,
-        const ggml_tensor *         kernel,  // [KW, KH, IC, OC]
-        const ggml_tensor *         src,     // [W, H, C, N]
-        ggml_tensor *               dst) {   // [OW, OH, OC, N]
-
-    const int32_t s0 = ggml_get_op_params_i32(dst, 0);
-    const int32_t s1 = ggml_get_op_params_i32(dst, 1);
-    const int32_t p0 = ggml_get_op_params_i32(dst, 2);
-    const int32_t p1 = ggml_get_op_params_i32(dst, 3);
-    const int32_t d0 = ggml_get_op_params_i32(dst, 4);
-    const int32_t d1 = ggml_get_op_params_i32(dst, 5);
-
-    const int64_t OW = dst->ne[0];
-    const int64_t OH = dst->ne[1];
-    const int64_t OC = dst->ne[2];
-    const int64_t N  = dst->ne[3];
-
-    const int64_t IW = src->ne[0];
-    const int64_t IH = src->ne[1];
-    const int64_t IC = src->ne[2];
-
-    const int64_t KW = kernel->ne[0];
-    const int64_t KH = kernel->ne[1];
-
-    const ggml_fp16_t * kernel_data = (const ggml_fp16_t *)kernel->data;
-    const ggml_fp16_t * src_data    = (const ggml_fp16_t *)src->data;
-    ggml_fp16_t *       dst_data    = (ggml_fp16_t *)dst->data;
-
-    const int64_t rows_total      = OH * N;
-    const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
-    const int64_t row_start       = params->ith * rows_per_thread;
-    const int64_t row_end         = MIN(row_start + rows_per_thread, rows_total);
-
-    for (int64_t row = row_start; row < row_end; ++row) {
-        const int64_t oh = row % OH;
-        const int64_t n  = row / OH;
-        const ggml_fp16_t * src_batch = src_data + n * IW * IH * IC;
-
-        for (int64_t ow = 0; ow < OW; ++ow) {
-            for (int64_t oc = 0; oc < OC; ++oc) {
-                float sum = 0.0f;
-                const ggml_fp16_t * kernel_channel = kernel_data + oc * KW * KH * IC;
-                for (int64_t kh = 0; kh < KH; ++kh) {
-                    const int64_t ih = oh * s1 - p1 + kh * d1;
-                    if (ih < 0 || ih >= IH) continue;
-
-                    for (int64_t kw = 0; kw < KW; ++kw) {
-                        const int64_t iw = ow * s0 - p0 + kw * d0;
-                        if (iw < 0 || iw >= IW) continue;
-
-                        for (int64_t ic = 0; ic < IC; ++ic) {
-                            const ggml_fp16_t * kernel_ptr = kernel_channel + (kh * KW + kw) + ic * KW * KH;
-                            const ggml_fp16_t * src_ptr    = src_batch + (ih * IW + iw) + ic * IW * IH;
-                            sum += GGML_FP16_TO_FP32(*kernel_ptr) * GGML_FP16_TO_FP32(*src_ptr);
-                        }
-                    }
-                }
-
-                dst_data[((n * OC + oc) * OH + oh) * OW + ow] = GGML_FP32_TO_FP16(sum);
-            }
-        }
-    }
-}
-
 void ggml_compute_forward_conv_2d(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
 
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_2d_f16(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_conv_2d_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
+    if (src0->type == GGML_TYPE_F16) {
+        GGML_ASSERT(false && "F16 not supported yet");
+    } else {
+        ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
     }
 }
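
Note on the workspace sizing above, derived from the code rather than stated anywhere in the diff: the old space_per_patch reserved patch_total * c_out output floats for every single patch, so params->wsize was divided by a vastly overestimated figure; the new value is one im2col row (knl_n floats) plus one GEMM output row (c_out floats) per patch. The shared tmp workspace is then laid out per batch as:

    tmp                                  // patches_per_batch * knl_n floats: im2col rows
    tmp + patches_per_batch * knl_n      // patches_per_batch * c_out floats: gemm_output

which is exactly the gemm_output offset used after the first barrier, so params->wsize / space_per_patch bounds how many patches fit in one batch.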
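
The two new dilation parameters enter only through the source-coordinate math, sy = src_x * stride_y + ky * dilation_y - pad_y (and likewise for sx). The output extents stored in dst->ne must then follow the standard dilated-convolution size formula, computed on the graph side before this kernel runs. A reference helper for that formula (illustrative only, not part of this diff):

    // Output extent of one spatial dimension for input size `in`, kernel size `k`,
    // stride `s`, padding `p`, dilation `d` (standard dilated-conv formula).
    static inline int64_t conv_out_size(int64_t in, int64_t k, int32_t s, int32_t p, int32_t d) {
        return (in + 2*p - d*(k - 1) - 1) / s + 1;
    }
    // e.g. in = 4, k = 3, s = 1, p = 1, d = 2  ->  (4 + 2 - 4 - 1)/1 + 1 = 2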
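
For readers new to the im2col + GEMM approach, here is a self-contained toy sketch of the same decomposition, with a naive triple loop standing in for ggml_call_mul_mat. All names and sizes are illustrative and none of this is ggml API; only the index math mirrors the kernel above:

    #include <stdio.h>
    #include <string.h>

    enum { IW = 4, IH = 4, IC = 2, OC = 3, KW = 3, KH = 3,
           STRIDE = 1, PAD = 1, DIL = 1 };
    enum { OW = (IW + 2*PAD - DIL*(KW - 1) - 1)/STRIDE + 1,   /* = 4 */
           OH = (IH + 2*PAD - DIL*(KH - 1) - 1)/STRIDE + 1,   /* = 4 */
           KNL_N = KW*KH*IC };                                /* im2col row width */

    int main(void) {
        static float src[IC][IH][IW];      /* ggml-style [W,H,C] layout (W fastest), N = 1 */
        static float knl[OC][KNL_N];       /* kernel flattened per output channel */
        static float cols[OW*OH][KNL_N];   /* im2col matrix: one row per output pixel */
        static float out[OW*OH][OC];       /* GEMM result: [patches, OC] */

        /* deterministic fill */
        for (int ic = 0; ic < IC; ++ic)
            for (int y = 0; y < IH; ++y)
                for (int x = 0; x < IW; ++x)
                    src[ic][y][x] = (float)(ic + y + x);
        for (int oc = 0; oc < OC; ++oc)
            for (int k = 0; k < KNL_N; ++k)
                knl[oc][k] = oc == 0 ? 1.0f : 0.5f;

        /* im2col: same sy/sx math as the kernel above (stride, dilation, padding) */
        for (int p = 0; p < OW*OH; ++p) {
            const int dy = p / OW, dx = p % OW;
            for (int ic = 0; ic < IC; ++ic)
                for (int ky = 0; ky < KH; ++ky)
                    for (int kx = 0; kx < KW; ++kx) {
                        const int sy  = dy*STRIDE + ky*DIL - PAD;
                        const int sx  = dx*STRIDE + kx*DIL - PAD;
                        const int idx = ic*(KH*KW) + ky*KW + kx;
                        cols[p][idx] = (sy < 0 || sy >= IH || sx < 0 || sx >= IW)
                                           ? 0.0f : src[ic][sy][sx];
                    }
        }

        /* naive GEMM standing in for ggml_call_mul_mat:
           [patches, KNL_N] x [KNL_N, OC] -> [patches, OC] */
        memset(out, 0, sizeof(out));
        for (int p = 0; p < OW*OH; ++p)
            for (int oc = 0; oc < OC; ++oc)
                for (int k = 0; k < KNL_N; ++k)
                    out[p][oc] += cols[p][k] * knl[oc][k];

        printf("out[0][0] = %f\n", out[0][0]);  /* one value per (pixel, channel) */
        return 0;
    }

The real kernel additionally batches patches to fit the workspace, splits both the im2col fill and the final WHCN scatter across threads, and separates the two phases with ggml_barrier calls.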