Skip to content

Commit 870b650

Browse files
committed
Half decent
1 parent 4e3f47c commit 870b650

File tree

4 files changed

+168
-50
lines changed

4 files changed

+168
-50
lines changed

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,10 @@ static void ggml_init_arm_arch_features(void) {
683683

684684
#endif // __ARM_ARCH
685685

686+
void ggml_compute_forward_mul_mat(
687+
const struct ggml_compute_params * params,
688+
struct ggml_tensor * dst);
689+
686690
struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value) {
687691
GGML_ASSERT(!ggml_get_no_alloc(ctx));
688692

@@ -1189,7 +1193,7 @@ static void ggml_compute_forward_mul_mat_one_chunk(
11891193
}
11901194
}
11911195

1192-
static void ggml_compute_forward_mul_mat(
1196+
void ggml_compute_forward_mul_mat(
11931197
const struct ggml_compute_params * params,
11941198
struct ggml_tensor * dst) {
11951199

@@ -2726,6 +2730,12 @@ struct ggml_cplan ggml_graph_plan(
27262730
GGML_ABORT("fatal error");
27272731
}
27282732
} break;
2733+
case GGML_OP_CONV_2D:
2734+
{
2735+
cur = GGML_IM2COL_WORK_SIZE;
2736+
// add space for a transposed fp16 copy, sized from src[1] — NOTE(review): comment said "kernel transpose", but in the CONV_2D op the kernel is src[0]; confirm which tensor this scratch is for
2737+
cur += sizeof(ggml_fp16_t)*node->src[1]->ne[0]*node->src[1]->ne[1]*node->src[1]->ne[2]*node->src[1]->ne[3];
2738+
} break;
27292739
case GGML_OP_CONV_TRANSPOSE_2D:
27302740
{
27312741
const int64_t ne00 = node->src[0]->ne[0]; // W

ggml/src/ggml-cpu/ops.cpp

Lines changed: 152 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "ggml-cpu.h"
44
#include "ggml-impl.h"
55
#include "binary-ops.h"
6+
#include "ggml.h"
67
#include "unary-ops.h"
78
#include "vec.h"
89

@@ -6058,70 +6059,173 @@ void ggml_compute_forward_im2col_back_f32(
60586059
}
60596060
}
60606061

6062+
static void ggml_call_mul_mat(
6063+
const ggml_compute_params * params,
6064+
int64_t m, int64_t n, int64_t k,
6065+
void * a, void * b, void * c) {
6066+
6067+
struct ggml_tensor src1 = {};
6068+
src1.ne[0] = k;
6069+
src1.ne[1] = m;
6070+
src1.ne[2] = 1;
6071+
src1.ne[3] = 1;
6072+
src1.nb[0] = sizeof(float);
6073+
src1.nb[1] = k * sizeof(float);
6074+
src1.nb[2] = src1.nb[1];
6075+
src1.nb[3] = src1.nb[2];
6076+
src1.data = a;
6077+
6078+
struct ggml_tensor src0 = {};
6079+
src0.ne[0] = k;
6080+
src0.ne[1] = n;
6081+
src0.ne[2] = 1;
6082+
src0.ne[3] = 1;
6083+
src0.nb[0] = sizeof(float);
6084+
src0.nb[1] = k * sizeof(float);
6085+
src0.nb[2] = src0.nb[1];
6086+
src0.nb[3] = src0.nb[2];
6087+
src0.data = b;
6088+
6089+
struct ggml_tensor dst = {};
6090+
dst.ne[0] = n;
6091+
dst.ne[1] = m;
6092+
dst.ne[2] = 1;
6093+
dst.ne[3] = 1;
6094+
dst.nb[0] = sizeof(float);
6095+
dst.nb[1] = n * sizeof(float);
6096+
dst.nb[2] = dst.nb[1];
6097+
dst.nb[3] = dst.nb[2];
6098+
dst.data = c;
6099+
dst.src[0] = &src0;
6100+
dst.src[1] = &src1;
6101+
6102+
ggml_compute_forward_mul_mat(params, &dst);
6103+
}
6104+
6105+
60616106
// ggml_compute_forward_conv_2d
60626107

6063-
static void ggml_compute_forward_conv_2d_f32(
6064-
const ggml_compute_params * params,
6065-
const ggml_tensor * kernel, // [KW, KH, IC, OC]
6066-
const ggml_tensor * src, // [W, H, C, N]
6067-
ggml_tensor * dst) { // [OW, OH, OC, N]
6108+
static void ggml_compute_forward_conv_2d_f32(const ggml_compute_params * params,
6109+
ggml_tensor * dst) {
60686110

6069-
const int32_t s0 = ggml_get_op_params_i32(dst, 0);
6070-
const int32_t s1 = ggml_get_op_params_i32(dst, 1);
6071-
const int32_t p0 = ggml_get_op_params_i32(dst, 2);
6072-
const int32_t p1 = ggml_get_op_params_i32(dst, 3);
6073-
const int32_t d0 = ggml_get_op_params_i32(dst, 4);
6074-
const int32_t d1 = ggml_get_op_params_i32(dst, 5);
6111+
const ggml_tensor * src = dst->src[1]; // [W H C_in N]
6112+
const ggml_tensor * kernel = dst->src[0]; // [W H C_in C_out]
60756113

6076-
const int64_t OW = dst->ne[0];
6077-
const int64_t OH = dst->ne[1];
6078-
const int64_t OC = dst->ne[2];
6079-
const int64_t N = dst->ne[3];
6114+
GGML_ASSERT(ggml_is_contiguous(kernel));
60806115

6081-
const int64_t IW = src->ne[0];
6082-
const int64_t IH = src->ne[1];
6083-
const int64_t IC = src->ne[2];
6116+
const int32_t stride_x = dst->op_params[0];
6117+
const int32_t stride_y = dst->op_params[1];
6118+
const int32_t pad_x = dst->op_params[2];
6119+
const int32_t pad_y = dst->op_params[3];
60846120

6085-
const int64_t KW = kernel->ne[0];
6086-
const int64_t KH = kernel->ne[1];
6121+
const int64_t c_in = src->ne[2];
6122+
const int64_t c_out = kernel->ne[3];
6123+
GGML_ASSERT(c_in == kernel->ne[2]);
60876124

6088-
const float * kernel_data = (const float *)kernel->data;
6089-
const float * src_data = (const float *)src->data;
6090-
float * dst_data = (float *)dst->data;
6125+
const int64_t src_w = src->ne[0];
6126+
const int64_t src_h = src->ne[1];
6127+
const int64_t knl_w = kernel->ne[0];
6128+
const int64_t knl_h = kernel->ne[1];
6129+
const int64_t dst_w = dst->ne[0];
6130+
const int64_t dst_h = dst->ne[1];
60916131

6092-
const int64_t rows_total = OH * N;
6093-
const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
6094-
const int64_t row_start = params->ith * rows_per_thread;
6095-
const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
60966132

6097-
for (int64_t row = row_start; row < row_end; ++row) {
6098-
const int64_t oh = row % OH;
6099-
const int64_t n = row / OH;
6100-
const float * src_batch = src_data + n * IW * IH * IC;
6133+
float * src_data = (float *) src->data;
6134+
float * knl_data = (float *) kernel->data;
6135+
float * dst_data = ( float *) dst->data;
61016136

6102-
for (int64_t ow = 0; ow < OW; ++ow) {
6103-
for (int64_t oc = 0; oc < OC; ++oc) {
6104-
float sum = 0.0f;
6105-
const float * kernel_channel = kernel_data + oc * KW * KH * IC;
61066137

6107-
for (int64_t kh = 0; kh < KH; ++kh) {
6108-
const int64_t ih = oh * s1 - p1 + kh * d1;
6109-
if (ih < 0 || ih >= IH) continue;
6138+
const int64_t knl_n = knl_w * knl_h * c_in;
6139+
const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
6140+
61106141

6111-
for (int64_t kw = 0; kw < KW; ++kw) {
6112-
const int64_t iw = ow * s0 - p0 + kw * d0;
6113-
if (iw < 0 || iw >= IW) continue;
6142+
6143+
const int64_t space_per_patch = knl_n * sizeof(float) + patch_total * c_out * sizeof(float);
61146144

6115-
#pragma omp simd
6116-
for (int64_t ic = 0; ic < IC; ++ic) {
6117-
const float * kernel_ptr = kernel_channel + (kh * KW + kw) + ic * KW * KH;
6118-
const float * src_ptr = src_batch + (ih * IW + iw) + ic * IW * IH;
6119-
sum += (*kernel_ptr) * (*src_ptr);
6145+
const int64_t batch_size = params->wsize / space_per_patch;
6146+
const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
6147+
const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
6148+
6149+
6150+
GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
6151+
6152+
float * tmp = (float *) params->wdata; // per-thread scratch
6153+
6154+
for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
6155+
6156+
const int64_t patch_start_batch = batch_i * patches_per_batch;
6157+
const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch,
6158+
patch_total);
6159+
const int64_t patch_n = patch_end_batch - patch_start_batch;
6160+
6161+
const int64_t patch_per_thread =
6162+
(patch_n + params->nth - 1) / params->nth;
6163+
const int64_t patch_start = patch_start_batch +
6164+
params->ith * patch_per_thread;
6165+
const int64_t patch_end = std::min(patch_start + patch_per_thread,
6166+
patch_end_batch);
6167+
6168+
//im2col for a patch
6169+
for (int64_t p = patch_start; p < patch_end; ++p) {
6170+
const int64_t b = p / (dst_w * dst_h);
6171+
const int64_t dy = (p / dst_w) % dst_h;
6172+
const int64_t dx = p % dst_w;
6173+
6174+
const float * src_base = (const float *)((char *)src_data + b * src->nb[3]);
6175+
float * out_row = tmp + (p % patches_per_batch) * knl_n;
6176+
6177+
// Extract patch in IC,KH,KW order (same as im2col)
6178+
for (int64_t ic = 0; ic < c_in; ++ic) {
6179+
for (int64_t ky = 0; ky < knl_h; ++ky) {
6180+
for (int64_t kx = 0; kx < knl_w; ++kx) {
6181+
const int64_t sy = dy * stride_y + ky - pad_y;
6182+
const int64_t sx = dx * stride_x + kx - pad_x;
6183+
6184+
int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
6185+
6186+
if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
6187+
out_row[dst_idx] = 0.0f;
6188+
} else {
6189+
float * src_ptr = (float *)((char *)src_base +
6190+
sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
6191+
out_row[dst_idx] = *src_ptr;
61206192
}
61216193
}
61226194
}
6195+
}
6196+
} // patches handled by this thread
6197+
6198+
ggml_barrier(params->threadpool); // wait for all threads
61236199

6124-
dst_data[((n * OC + oc) * OH + oh) * OW + ow] = sum;
6200+
//GEMM output is patch_n * cout
6201+
float * gemm_output = tmp + patches_per_batch * knl_n;
6202+
6203+
// GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
6204+
ggml_call_mul_mat(params, patch_n, c_out, knl_n,
6205+
tmp, knl_data, gemm_output);
6206+
6207+
// Barrier to ensure GEMM completes before permutation
6208+
ggml_barrier(params->threadpool);
6209+
6210+
// Distribute permutation work across threads
6211+
const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
6212+
const int64_t permute_start = params->ith * permute_per_thread;
6213+
const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);
6214+
6215+
// Each thread handles part of the permutation from [patch_n, c_out] to WHCN layout
6216+
for (int64_t i = permute_start; i < permute_end; ++i) {
6217+
const int64_t p = patch_start_batch + i;
6218+
const int64_t b = p / (dst_w * dst_h); // batch index
6219+
const int64_t dy = (p / dst_w) % dst_h; // height index
6220+
const int64_t dx = p % dst_w; // width index
6221+
6222+
// Copy all channels for this spatial position
6223+
for (int64_t oc = 0; oc < c_out; ++oc) {
6224+
const float value = gemm_output[i * c_out + oc];
6225+
// Write to WHCN layout: dst[w, h, c, n]
6226+
float * dst_ptr = (float *)((char *)dst_data +
6227+
dx * dst->nb[0] + dy * dst->nb[1] + oc * dst->nb[2] + b * dst->nb[3]);
6228+
*dst_ptr = value;
61256229
}
61266230
}
61276231
}
@@ -6206,7 +6310,7 @@ void ggml_compute_forward_conv_2d(
62066310
} break;
62076311
case GGML_TYPE_F32:
62086312
{
6209-
ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
6313+
ggml_compute_forward_conv_2d_f32(params, dst);
62106314
} break;
62116315
default:
62126316
{

ggml/src/ggml-cpu/ops.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020

2121
static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
2222

23+
// Work buffer size for im2col operations in CONV2D
24+
#define GGML_IM2COL_WORK_SIZE (16 * 1024 * 1024) // 16MB work buffer
25+
2326
#ifdef __cplusplus
2427
extern "C" {
2528
#endif
@@ -106,6 +109,7 @@ void ggml_compute_forward_custom(const struct ggml_compute_params * params, stru
106109
void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst);
107110
void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
108111
void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
112+
void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
109113

110114
#ifdef __cplusplus
111115
}

ggml/src/ggml.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4146,7 +4146,7 @@ struct ggml_tensor * ggml_conv_2d_direct(
41464146
int d1) {// dilation dimension 1
41474147

41484148
GGML_ASSERT(a->ne[2] == b->ne[2]);
4149-
GGML_ASSERT(a->type == b->type);
4149+
//GGML_ASSERT(a->type == b->type);
41504150

41514151
int64_t ne[4];
41524152
ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);

0 commit comments

Comments
 (0)