Skip to content

Commit e69f0c7

Browse files
committed
Half decent
1 parent b4af5d9 commit e69f0c7

File tree

4 files changed

+168
-50
lines changed

4 files changed

+168
-50
lines changed

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,10 @@ static void ggml_init_arm_arch_features(void) {
687687

688688
#endif // __ARM_ARCH
689689

690+
void ggml_compute_forward_mul_mat(
691+
const struct ggml_compute_params * params,
692+
struct ggml_tensor * dst);
693+
690694
struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
691695
GGML_ASSERT(!ggml_get_no_alloc(ctx));
692696

@@ -1193,7 +1197,7 @@ static void ggml_compute_forward_mul_mat_one_chunk(
11931197
}
11941198
}
11951199

1196-
static void ggml_compute_forward_mul_mat(
1200+
void ggml_compute_forward_mul_mat(
11971201
const struct ggml_compute_params * params,
11981202
struct ggml_tensor * dst) {
11991203

@@ -2735,6 +2739,12 @@ struct ggml_cplan ggml_graph_plan(
27352739
GGML_ABORT("fatal error");
27362740
}
27372741
} break;
2742+
case GGML_OP_CONV_2D:
2743+
{
2744+
cur = GGML_IM2COL_WORK_SIZE;
2745+
//Add enough space for kernel transpose
2746+
cur += sizeof(ggml_fp16_t)*node->src[1]->ne[0]*node->src[1]->ne[1]*node->src[1]->ne[2]*node->src[1]->ne[3];
2747+
} break;
27382748
case GGML_OP_CONV_TRANSPOSE_2D:
27392749
{
27402750
const int64_t ne00 = node->src[0]->ne[0]; // W

ggml/src/ggml-cpu/ops.cpp

Lines changed: 152 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "ggml-cpu.h"
44
#include "ggml-impl.h"
55
#include "binary-ops.h"
6+
#include "ggml.h"
67
#include "unary-ops.h"
78
#include "vec.h"
89

@@ -6116,70 +6117,173 @@ void ggml_compute_forward_im2col_back_f32(
61166117
}
61176118
}
61186119

6120+
static void ggml_call_mul_mat(
6121+
const ggml_compute_params * params,
6122+
int64_t m, int64_t n, int64_t k,
6123+
void * a, void * b, void * c) {
6124+
6125+
struct ggml_tensor src1 = {};
6126+
src1.ne[0] = k;
6127+
src1.ne[1] = m;
6128+
src1.ne[2] = 1;
6129+
src1.ne[3] = 1;
6130+
src1.nb[0] = sizeof(float);
6131+
src1.nb[1] = k * sizeof(float);
6132+
src1.nb[2] = src1.nb[1];
6133+
src1.nb[3] = src1.nb[2];
6134+
src1.data = a;
6135+
6136+
struct ggml_tensor src0 = {};
6137+
src0.ne[0] = k;
6138+
src0.ne[1] = n;
6139+
src0.ne[2] = 1;
6140+
src0.ne[3] = 1;
6141+
src0.nb[0] = sizeof(float);
6142+
src0.nb[1] = k * sizeof(float);
6143+
src0.nb[2] = src0.nb[1];
6144+
src0.nb[3] = src0.nb[2];
6145+
src0.data = b;
6146+
6147+
struct ggml_tensor dst = {};
6148+
dst.ne[0] = n;
6149+
dst.ne[1] = m;
6150+
dst.ne[2] = 1;
6151+
dst.ne[3] = 1;
6152+
dst.nb[0] = sizeof(float);
6153+
dst.nb[1] = n * sizeof(float);
6154+
dst.nb[2] = dst.nb[1];
6155+
dst.nb[3] = dst.nb[2];
6156+
dst.data = c;
6157+
dst.src[0] = &src0;
6158+
dst.src[1] = &src1;
6159+
6160+
ggml_compute_forward_mul_mat(params, &dst);
6161+
}
6162+
6163+
61196164
// ggml_compute_forward_conv_2d
61206165

6121-
static void ggml_compute_forward_conv_2d_f32(
6122-
const ggml_compute_params * params,
6123-
const ggml_tensor * kernel, // [KW, KH, IC, OC]
6124-
const ggml_tensor * src, // [W, H, C, N]
6125-
ggml_tensor * dst) { // [OW, OH, OC, N]
6166+
static void ggml_compute_forward_conv_2d_f32(const ggml_compute_params * params,
6167+
ggml_tensor * dst) {
61266168

6127-
const int32_t s0 = ggml_get_op_params_i32(dst, 0);
6128-
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
6129-
const int32_t p0 = ggml_get_op_params_i32(dst, 2);
6130-
const int32_t p1 = ggml_get_op_params_i32(dst, 3);
6131-
const int32_t d0 = ggml_get_op_params_i32(dst, 4);
6132-
const int32_t d1 = ggml_get_op_params_i32(dst, 5);
6169+
const ggml_tensor * src = dst->src[1]; // [W H C_in N]
6170+
const ggml_tensor * kernel = dst->src[0]; // [W H C_in C_out]
61336171

6134-
const int64_t OW = dst->ne[0];
6135-
const int64_t OH = dst->ne[1];
6136-
const int64_t OC = dst->ne[2];
6137-
const int64_t N = dst->ne[3];
6172+
GGML_ASSERT(ggml_is_contiguous(kernel));
61386173

6139-
const int64_t IW = src->ne[0];
6140-
const int64_t IH = src->ne[1];
6141-
const int64_t IC = src->ne[2];
6174+
const int32_t stride_x = dst->op_params[0];
6175+
const int32_t stride_y = dst->op_params[1];
6176+
const int32_t pad_x = dst->op_params[2];
6177+
const int32_t pad_y = dst->op_params[3];
61426178

6143-
const int64_t KW = kernel->ne[0];
6144-
const int64_t KH = kernel->ne[1];
6179+
const int64_t c_in = src->ne[2];
6180+
const int64_t c_out = kernel->ne[3];
6181+
GGML_ASSERT(c_in == kernel->ne[2]);
61456182

6146-
const float * kernel_data = (const float *)kernel->data;
6147-
const float * src_data = (const float *)src->data;
6148-
float * dst_data = (float *)dst->data;
6183+
const int64_t src_w = src->ne[0];
6184+
const int64_t src_h = src->ne[1];
6185+
const int64_t knl_w = kernel->ne[0];
6186+
const int64_t knl_h = kernel->ne[1];
6187+
const int64_t dst_w = dst->ne[0];
6188+
const int64_t dst_h = dst->ne[1];
61496189

6150-
const int64_t rows_total = OH * N;
6151-
const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
6152-
const int64_t row_start = params->ith * rows_per_thread;
6153-
const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
61546190

6155-
for (int64_t row = row_start; row < row_end; ++row) {
6156-
const int64_t oh = row % OH;
6157-
const int64_t n = row / OH;
6158-
const float * src_batch = src_data + n * IW * IH * IC;
6191+
float * src_data = (float *) src->data;
6192+
float * knl_data = (float *) kernel->data;
6193+
float * dst_data = ( float *) dst->data;
61596194

6160-
for (int64_t ow = 0; ow < OW; ++ow) {
6161-
for (int64_t oc = 0; oc < OC; ++oc) {
6162-
float sum = 0.0f;
6163-
const float * kernel_channel = kernel_data + oc * KW * KH * IC;
61646195

6165-
for (int64_t kh = 0; kh < KH; ++kh) {
6166-
const int64_t ih = oh * s1 - p1 + kh * d1;
6167-
if (ih < 0 || ih >= IH) continue;
6196+
const int64_t knl_n = knl_w * knl_h * c_in;
6197+
const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
6198+
61686199

6169-
for (int64_t kw = 0; kw < KW; ++kw) {
6170-
const int64_t iw = ow * s0 - p0 + kw * d0;
6171-
if (iw < 0 || iw >= IW) continue;
6200+
6201+
const int64_t space_per_patch = knl_n * sizeof(float) + patch_total * c_out * sizeof(float);
61726202

6173-
#pragma omp simd
6174-
for (int64_t ic = 0; ic < IC; ++ic) {
6175-
const float * kernel_ptr = kernel_channel + (kh * KW + kw) + ic * KW * KH;
6176-
const float * src_ptr = src_batch + (ih * IW + iw) + ic * IW * IH;
6177-
sum += (*kernel_ptr) * (*src_ptr);
6203+
const int64_t batch_size = params->wsize / space_per_patch;
6204+
const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
6205+
const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
6206+
6207+
6208+
GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
6209+
6210+
float * tmp = (float *) params->wdata; // per-thread scratch
6211+
6212+
for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
6213+
6214+
const int64_t patch_start_batch = batch_i * patches_per_batch;
6215+
const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch,
6216+
patch_total);
6217+
const int64_t patch_n = patch_end_batch - patch_start_batch;
6218+
6219+
const int64_t patch_per_thread =
6220+
(patch_n + params->nth - 1) / params->nth;
6221+
const int64_t patch_start = patch_start_batch +
6222+
params->ith * patch_per_thread;
6223+
const int64_t patch_end = std::min(patch_start + patch_per_thread,
6224+
patch_end_batch);
6225+
6226+
//im2col for a patch
6227+
for (int64_t p = patch_start; p < patch_end; ++p) {
6228+
const int64_t b = p / (dst_w * dst_h);
6229+
const int64_t dy = (p / dst_w) % dst_h;
6230+
const int64_t dx = p % dst_w;
6231+
6232+
const float * src_base = (const float *)((char *)src_data + b * src->nb[3]);
6233+
float * out_row = tmp + (p % patches_per_batch) * knl_n;
6234+
6235+
// Extract patch in IC,KH,KW order (same as im2col)
6236+
for (int64_t ic = 0; ic < c_in; ++ic) {
6237+
for (int64_t ky = 0; ky < knl_h; ++ky) {
6238+
for (int64_t kx = 0; kx < knl_w; ++kx) {
6239+
const int64_t sy = dy * stride_y + ky - pad_y;
6240+
const int64_t sx = dx * stride_x + kx - pad_x;
6241+
6242+
int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
6243+
6244+
if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
6245+
out_row[dst_idx] = 0.0f;
6246+
} else {
6247+
float * src_ptr = (float *)((char *)src_base +
6248+
sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
6249+
out_row[dst_idx] = *src_ptr;
61786250
}
61796251
}
61806252
}
6253+
}
6254+
} // patches handled by this thread
6255+
6256+
ggml_barrier(params->threadpool); // wait for all threads
61816257

6182-
dst_data[((n * OC + oc) * OH + oh) * OW + ow] = sum;
6258+
//GEMM output is patch_n * cout
6259+
float * gemm_output = tmp + patches_per_batch * knl_n;
6260+
6261+
// GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
6262+
ggml_call_mul_mat(params, patch_n, c_out, knl_n,
6263+
tmp, knl_data, gemm_output);
6264+
6265+
// Barrier to ensure GEMM completes before permutation
6266+
ggml_barrier(params->threadpool);
6267+
6268+
// Distribute permutation work across threads
6269+
const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
6270+
const int64_t permute_start = params->ith * permute_per_thread;
6271+
const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);
6272+
6273+
// Each thread handles part of the permutation from [patch_n, c_out] to WHCN layout
6274+
for (int64_t i = permute_start; i < permute_end; ++i) {
6275+
const int64_t p = patch_start_batch + i;
6276+
const int64_t b = p / (dst_w * dst_h); // batch index
6277+
const int64_t dy = (p / dst_w) % dst_h; // height index
6278+
const int64_t dx = p % dst_w; // width index
6279+
6280+
// Copy all channels for this spatial position
6281+
for (int64_t oc = 0; oc < c_out; ++oc) {
6282+
const float value = gemm_output[i * c_out + oc];
6283+
// Write to WHCN layout: dst[w, h, c, n]
6284+
float * dst_ptr = (float *)((char *)dst_data +
6285+
dx * dst->nb[0] + dy * dst->nb[1] + oc * dst->nb[2] + b * dst->nb[3]);
6286+
*dst_ptr = value;
61836287
}
61846288
}
61856289
}
@@ -6264,7 +6368,7 @@ void ggml_compute_forward_conv_2d(
62646368
} break;
62656369
case GGML_TYPE_F32:
62666370
{
6267-
ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
6371+
ggml_compute_forward_conv_2d_f32(params, dst);
62686372
} break;
62696373
default:
62706374
{

ggml/src/ggml-cpu/ops.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020

2121
static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
2222

23+
// Work buffer size for im2col operations in CONV2D
24+
#define GGML_IM2COL_WORK_SIZE (16 * 1024 * 1024) // 16MB work buffer
25+
2326
#ifdef __cplusplus
2427
extern "C" {
2528
#endif
@@ -107,6 +110,7 @@ void ggml_compute_forward_custom(const struct ggml_compute_params * params, stru
107110
void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst);
108111
void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
109112
void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
113+
void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
110114

111115
#ifdef __cplusplus
112116
}

ggml/src/ggml.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4172,7 +4172,7 @@ struct ggml_tensor * ggml_conv_2d_direct(
41724172
int d1) {// dilation dimension 1
41734173

41744174
GGML_ASSERT(a->ne[2] == b->ne[2]);
4175-
GGML_ASSERT(a->type == b->type);
4175+
//GGML_ASSERT(a->type == b->type);
41764176

41774177
int64_t ne[4];
41784178
ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);

0 commit comments

Comments (0)