
Commit a2c0311

Tiled approach for F32
1 parent 48b7fa2 commit a2c0311

2 files changed: +53 -138 lines changed
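
The change replaces the direct per-element conv2d loop for F32 with an im2col + GEMM pipeline that walks the output in fixed-size batches of patches ("tiles") sized to fit the per-op scratch buffer. For orientation, the sketch below shows the usual output-geometry relation for a strided, padded, dilated convolution that the dst shape [OW, OH, OC, N] in the diff is assumed to follow; the helper name is illustrative, not a ggml API.

#include <cstdint>
#include <cstdio>

// Output extent of a strided, padded, dilated convolution along one axis.
// conv_out_size is an illustrative helper, not a ggml function.
static int64_t conv_out_size(int64_t in, int64_t kernel, int64_t stride, int64_t pad, int64_t dilation) {
    return (in + 2*pad - dilation*(kernel - 1) - 1) / stride + 1;
}

int main() {
    // Example: 64x64 input, 3x3 kernel, stride 1, pad 1, dilation 2 -> 62x62 output.
    const int64_t ow = conv_out_size(64, 3, 1, 1, 2);
    const int64_t oh = conv_out_size(64, 3, 1, 1, 2);
    std::printf("OW=%lld OH=%lld\n", (long long) ow, (long long) oh);
    return 0;
}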

ggml/src/ggml-cpu/ops.cpp

Lines changed: 52 additions & 138 deletions
@@ -6592,18 +6592,21 @@ static void ggml_call_mul_mat(
 
 // ggml_compute_forward_conv_2d
 
-static void ggml_compute_forward_conv_2d_f32(const ggml_compute_params * params,
-                                             ggml_tensor * dst) {
-
-    const ggml_tensor * src = dst->src[1];    // [W H C_in N]
-    const ggml_tensor * kernel = dst->src[0]; // [W H C_in C_out]
+static void ggml_compute_forward_conv_2d_f32(
+        const ggml_compute_params * params,
+        const ggml_tensor * kernel,  // [KW, KH, IC, OC] - fp32
+        const ggml_tensor * src,     // [W, H, C, N]
+        ggml_tensor * dst) {         // [OW, OH, OC, N]
 
     GGML_ASSERT(ggml_is_contiguous(kernel));
+    GGML_ASSERT(kernel->type == GGML_TYPE_F32);
 
-    const int32_t stride_x = dst->op_params[0];
-    const int32_t stride_y = dst->op_params[1];
-    const int32_t pad_x = dst->op_params[2];
-    const int32_t pad_y = dst->op_params[3];
+    const int32_t stride_x = dst->op_params[0];
+    const int32_t stride_y = dst->op_params[1];
+    const int32_t pad_x = dst->op_params[2];
+    const int32_t pad_y = dst->op_params[3];
+    const int32_t dilation_x = dst->op_params[4];
+    const int32_t dilation_y = dst->op_params[5];
 
     const int64_t c_in = src->ne[2];
     const int64_t c_out = kernel->ne[3];
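
A note on the hunk above: the F32 path now reads six int32 op params, adding dilation_x/dilation_y after the stride and padding pairs, in the same order the removed F16 path further down reads them via ggml_get_op_params_i32. A minimal sketch of that layout, with illustrative (non-ggml) names:

#include <cstdint>

// Layout of the six int32 op params read above: stride, padding, then the
// newly added dilation. Struct and function names are illustrative, not ggml API.
struct conv2d_params {
    int32_t stride_x, stride_y;
    int32_t pad_x, pad_y;
    int32_t dilation_x, dilation_y;
};

static conv2d_params unpack_conv2d_params(const int32_t op_params[6]) {
    return { op_params[0], op_params[1],   // stride_x, stride_y
             op_params[2], op_params[3],   // pad_x,    pad_y
             op_params[4], op_params[5] }; // dilation_x, dilation_y
}
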
@@ -6616,193 +6619,104 @@ static void ggml_compute_forward_conv_2d_f32(const ggml_compute_params * params,
     const int64_t dst_w = dst->ne[0];
     const int64_t dst_h = dst->ne[1];
 
-
-    float * src_data = (float *) src->data;
-    float * knl_data = (float *) kernel->data;
-    float * dst_data = ( float *) dst->data;
-
+    float * src_data = (float*) src->data;
+    float * knl_data = (float*) kernel->data;
+    float * dst_data = (float*) dst->data;
 
     const int64_t knl_n = knl_w * knl_h * c_in;
     const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
-
-
-
-    const int64_t space_per_patch = knl_n * sizeof(float) + patch_total * c_out * sizeof(float);
 
-    const int64_t batch_size = params->wsize / space_per_patch;
+    const int64_t space_per_patch = knl_n * sizeof(float) + c_out * sizeof(float);
+    const int64_t batch_size = params->wsize / space_per_patch;
     const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
-    const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
-
+    const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
 
     GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
 
-    float * tmp = (float *) params->wdata; // per-thread scratch
+    float * tmp = (float *) params->wdata;
 
     for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
 
         const int64_t patch_start_batch = batch_i * patches_per_batch;
         const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch,
                                                  patch_total);
-        const int64_t patch_n = patch_end_batch - patch_start_batch;
+        const int64_t patch_n = patch_end_batch - patch_start_batch;
 
-        const int64_t patch_per_thread =
-            (patch_n + params->nth - 1) / params->nth;
-        const int64_t patch_start = patch_start_batch +
-            params->ith * patch_per_thread;
-        const int64_t patch_end = std::min(patch_start + patch_per_thread,
-                                           patch_end_batch);
+        const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth;
+        const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
+        const int64_t patch_end = std::min(patch_start + patch_per_thread,patch_end_batch);
 
         //im2col for a patch
         for (int64_t p = patch_start; p < patch_end; ++p) {
-            const int64_t b = p / (dst_w * dst_h);
-            const int64_t dy = (p / dst_w) % dst_h;
-            const int64_t dx = p % dst_w;
+            const int64_t batch_n = p / (dst_w * dst_h);
+            const int64_t src_x = (p / dst_w) % dst_h;
+            const int64_t src_y = p % dst_w;
 
-            const float * src_base = (const float *)((char *)src_data + b * src->nb[3]);
-            float * out_row = tmp + (p % patches_per_batch) * knl_n;
+            float * src_base = (float *)((char *)src_data + batch_n * src->nb[3]);
+            float * dst_row = tmp + (p % patches_per_batch) * knl_n;
 
-            // Extract patch in IC,KH,KW order (same as im2col)
             for (int64_t ic = 0; ic < c_in; ++ic) {
                 for (int64_t ky = 0; ky < knl_h; ++ky) {
                     for (int64_t kx = 0; kx < knl_w; ++kx) {
-                        const int64_t sy = dy * stride_y + ky - pad_y;
-                        const int64_t sx = dx * stride_x + kx - pad_x;
-
+                        const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
+                        const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;
+
                         int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
-
+
                         if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
-                            out_row[dst_idx] = 0.0f;
+                            dst_row[dst_idx] = 0.0f;
                         } else {
-                            float * src_ptr = (float *)((char *)src_base +
-                                sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
-                            out_row[dst_idx] = *src_ptr;
+                            float * src_ptr = (float *)((char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
+                            dst_row[dst_idx] = *src_ptr;
                         }
                     }
                 }
             }
         } // patches handled by this thread
 
-        ggml_barrier(params->threadpool); // wait for all threads
+        ggml_barrier(params->threadpool);
 
-        //GEMM output is patch_n * cout
         float * gemm_output = tmp + patches_per_batch * knl_n;
-
+
         // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
         ggml_call_mul_mat(params, patch_n, c_out, knl_n,
                           tmp, knl_data, gemm_output);
-
-        // Barrier to ensure GEMM completes before permutation
+
         ggml_barrier(params->threadpool);
-
-        // Distribute permutation work across threads
+
+
+        //permute back [OC, N, OH, OW] to [N, OC, OH, OW]
         const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
         const int64_t permute_start = params->ith * permute_per_thread;
         const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);
-
-        // Each thread handles part of the permutation from [patch_n, c_out] to WHCN layout
+
         for (int64_t i = permute_start; i < permute_end; ++i) {
-            const int64_t p = patch_start_batch + i;
-            const int64_t b = p / (dst_w * dst_h);   // batch index
-            const int64_t dy = (p / dst_w) % dst_h;  // height index
-            const int64_t dx = p % dst_w;            // width index
-
-            // Copy all channels for this spatial position
+            const int64_t p = patch_start_batch + i;
+            const int64_t batch_n = p / (dst_w * dst_h);
+            const int64_t dst_y = (p / dst_w) % dst_h;
+            const int64_t dst_x = p % dst_w;
+
             for (int64_t oc = 0; oc < c_out; ++oc) {
                 const float value = gemm_output[i * c_out + oc];
                 // Write to WHCN layout: dst[w, h, c, n]
-                float * dst_ptr = (float *)((char *)dst_data +
-                    dx * dst->nb[0] + dy * dst->nb[1] + oc * dst->nb[2] + b * dst->nb[3]);
+                float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
                 *dst_ptr = value;
             }
         }
     }
 }
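
A note on the batching arithmetic above: each patch needs knl_n floats for its im2col row plus c_out floats of GEMM output in the scratch buffer (params->wdata), so the tile size is derived from params->wsize, rounded down to a multiple of 8 when large enough, and the patch count is split into ceil(patch_total / patches_per_batch) tiles. A standalone sketch of the same arithmetic with hypothetical sizes:

#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical sizes: 3x3 kernel, 64 input channels, 128 output channels,
    // 62x62 output, batch of 1, and a 1 MiB per-op scratch buffer.
    const int64_t knl_n       = 3 * 3 * 64;   // floats per im2col row
    const int64_t c_out       = 128;          // floats of GEMM output per patch
    const int64_t patch_total = 1 * 62 * 62;
    const int64_t wsize       = 1 << 20;      // bytes of scratch

    // Same arithmetic as the F32 path above.
    const int64_t space_per_patch   = knl_n * sizeof(float) + c_out * sizeof(float);
    const int64_t batch_size        = wsize / space_per_patch;
    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;

    std::printf("patches_per_batch=%lld tiles=%lld\n",
                (long long) patches_per_batch, (long long) batch_n);
    return 0;
}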

-static void ggml_compute_forward_conv_2d_f16(
-    const ggml_compute_params * params,
-    const ggml_tensor * kernel, // [KW, KH, IC, OC]
-    const ggml_tensor * src,    // [W, H, C, N]
-    ggml_tensor * dst) {        // [OW, OH, OC, N]
-
-    const int32_t s0 = ggml_get_op_params_i32(dst, 0);
-    const int32_t s1 = ggml_get_op_params_i32(dst, 1);
-    const int32_t p0 = ggml_get_op_params_i32(dst, 2);
-    const int32_t p1 = ggml_get_op_params_i32(dst, 3);
-    const int32_t d0 = ggml_get_op_params_i32(dst, 4);
-    const int32_t d1 = ggml_get_op_params_i32(dst, 5);
-
-    const int64_t OW = dst->ne[0];
-    const int64_t OH = dst->ne[1];
-    const int64_t OC = dst->ne[2];
-    const int64_t N = dst->ne[3];
-
-    const int64_t IW = src->ne[0];
-    const int64_t IH = src->ne[1];
-    const int64_t IC = src->ne[2];
-
-    const int64_t KW = kernel->ne[0];
-    const int64_t KH = kernel->ne[1];
-
-    const ggml_fp16_t * kernel_data = (const ggml_fp16_t *)kernel->data;
-    const ggml_fp16_t * src_data = (const ggml_fp16_t *)src->data;
-    ggml_fp16_t * dst_data = (ggml_fp16_t *)dst->data;
-
-    const int64_t rows_total = OH * N;
-    const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
-    const int64_t row_start = params->ith * rows_per_thread;
-    const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
-
-    for (int64_t row = row_start; row < row_end; ++row) {
-        const int64_t oh = row % OH;
-        const int64_t n = row / OH;
-        const ggml_fp16_t * src_batch = src_data + n * IW * IH * IC;
-
-        for (int64_t ow = 0; ow < OW; ++ow) {
-            for (int64_t oc = 0; oc < OC; ++oc) {
-                float sum = 0.0f;
-                const ggml_fp16_t * kernel_channel = kernel_data + oc * KW * KH * IC;
-                for (int64_t kh = 0; kh < KH; ++kh) {
-                    const int64_t ih = oh * s1 - p1 + kh * d1;
-                    if (ih < 0 || ih >= IH) continue;
-
-                    for (int64_t kw = 0; kw < KW; ++kw) {
-                        const int64_t iw = ow * s0 - p0 + kw * d0;
-                        if (iw < 0 || iw >= IW) continue;
-
-                        for (int64_t ic = 0; ic < IC; ++ic) {
-                            const ggml_fp16_t * kernel_ptr = kernel_channel + (kh * KW + kw) + ic * KW * KH;
-                            const ggml_fp16_t * src_ptr = src_batch + (ih * IW + iw) + ic * IW * IH;
-                            sum += GGML_FP16_TO_FP32(*kernel_ptr) * GGML_FP16_TO_FP32(*src_ptr);
-                        }
-                    }
-                }
-
-                dst_data[((n * OC + oc) * OH + oh) * OW + ow] = GGML_FP32_TO_FP16(sum);
-            }
-        }
-    }
-}
-
 void ggml_compute_forward_conv_2d(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
 
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_2d_f16(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_conv_2d_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
+    if (src0->type == GGML_TYPE_F16) {
+        GGML_ASSERT(false && "F16 not supported yet");
+    } else {
+        ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
     }
 }
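
The deleted ggml_compute_forward_conv_2d_f16 above was a direct (naive) convolution. For reference, a plain-C++ F32 transcription of that same loop nest over contiguous [IW, IH, IC, N] / [KW, KH, IC, OC] / [OW, OH, OC, N] buffers is sketched below; it is the kind of result the tiled im2col + GEMM path has to reproduce numerically, and it is not ggml code.

#include <cstdint>

// Naive conv2d mirroring the removed F16 loop nest, but in F32. Indexing
// assumptions: src is [IW, IH, IC, N] with IW fastest, kernel is
// [KW, KH, IC, OC], dst is [OW, OH, OC, N]. Illustrative only.
void conv2d_ref_f32(const float * kernel, const float * src, float * dst,
                    int64_t IW, int64_t IH, int64_t IC, int64_t N,
                    int64_t KW, int64_t KH, int64_t OC,
                    int64_t OW, int64_t OH,
                    int64_t s0, int64_t s1, int64_t p0, int64_t p1,
                    int64_t d0, int64_t d1) {
    for (int64_t n = 0; n < N; ++n) {
        const float * src_batch = src + n * IW * IH * IC;
        for (int64_t oc = 0; oc < OC; ++oc) {
            const float * kernel_channel = kernel + oc * KW * KH * IC;
            for (int64_t oh = 0; oh < OH; ++oh) {
                for (int64_t ow = 0; ow < OW; ++ow) {
                    float sum = 0.0f;
                    for (int64_t kh = 0; kh < KH; ++kh) {
                        const int64_t ih = oh * s1 - p1 + kh * d1;
                        if (ih < 0 || ih >= IH) continue;
                        for (int64_t kw = 0; kw < KW; ++kw) {
                            const int64_t iw = ow * s0 - p0 + kw * d0;
                            if (iw < 0 || iw >= IW) continue;
                            for (int64_t ic = 0; ic < IC; ++ic) {
                                sum += kernel_channel[ic * KW * KH + kh * KW + kw] *
                                       src_batch[ic * IW * IH + ih * IW + iw];
                            }
                        }
                    }
                    dst[((n * OC + oc) * OH + oh) * OW + ow] = sum;
                }
            }
        }
    }
}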

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -195,6 +195,7 @@ endif()
 # llama_build_and_test(test-opt.cpp) # SLOW
 llama_build_and_test(test-gguf.cpp)
 llama_build_and_test(test-backend-ops.cpp)
+llama_build_and_test(test_conv2d_comparison.cpp)
 
 llama_build_and_test(test-model-load-cancel.cpp LABEL "model")
 llama_build_and_test(test-autorelease.cpp LABEL "model")
