|
3 | 3 | #include "ggml-cpu.h"
|
4 | 4 | #include "ggml-impl.h"
|
5 | 5 | #include "binary-ops.h"
|
| 6 | +#include "ggml.h" |
6 | 7 | #include "unary-ops.h"
|
7 | 8 | #include "vec.h"
|
8 | 9 |
|
@@ -6058,70 +6059,173 @@ void ggml_compute_forward_im2col_back_f32(
|
6058 | 6059 | }
|
6059 | 6060 | }
|
6060 | 6061 |
|
| 6062 | +static void ggml_call_mul_mat( |
| 6063 | + const ggml_compute_params * params, |
| 6064 | + int64_t m, int64_t n, int64_t k, |
| 6065 | + void * a, void * b, void * c) { |
| 6066 | + |
| 6067 | + struct ggml_tensor src1 = {}; |
| 6068 | + src1.ne[0] = k; |
| 6069 | + src1.ne[1] = m; |
| 6070 | + src1.ne[2] = 1; |
| 6071 | + src1.ne[3] = 1; |
| 6072 | + src1.nb[0] = sizeof(float); |
| 6073 | + src1.nb[1] = k * sizeof(float); |
| 6074 | + src1.nb[2] = src1.nb[1]; |
| 6075 | + src1.nb[3] = src1.nb[2]; |
| 6076 | + src1.data = a; |
| 6077 | + |
| 6078 | + struct ggml_tensor src0 = {}; |
| 6079 | + src0.ne[0] = k; |
| 6080 | + src0.ne[1] = n; |
| 6081 | + src0.ne[2] = 1; |
| 6082 | + src0.ne[3] = 1; |
| 6083 | + src0.nb[0] = sizeof(float); |
| 6084 | + src0.nb[1] = k * sizeof(float); |
| 6085 | + src0.nb[2] = src0.nb[1]; |
| 6086 | + src0.nb[3] = src0.nb[2]; |
| 6087 | + src0.data = b; |
| 6088 | + |
| 6089 | + struct ggml_tensor dst = {}; |
| 6090 | + dst.ne[0] = n; |
| 6091 | + dst.ne[1] = m; |
| 6092 | + dst.ne[2] = 1; |
| 6093 | + dst.ne[3] = 1; |
| 6094 | + dst.nb[0] = sizeof(float); |
| 6095 | + dst.nb[1] = n * sizeof(float); |
| 6096 | + dst.nb[2] = dst.nb[1]; |
| 6097 | + dst.nb[3] = dst.nb[2]; |
| 6098 | + dst.data = c; |
| 6099 | + dst.src[0] = &src0; |
| 6100 | + dst.src[1] = &src1; |
| 6101 | + |
| 6102 | + ggml_compute_forward_mul_mat(params, &dst); |
| 6103 | +} |
| 6104 | + |
| 6105 | + |
// ggml_compute_forward_conv_2d

// Direct f32 2-D convolution: im2col into the op's scratch buffer, one f32
// GEMM against the (contiguous) kernel, then a permutation of the GEMM
// output back into dst's WHCN layout. Patches are processed in batches so
// the scratch fits in params->wdata / params->wsize.
//
//   kernel (dst->src[0]): [KW, KH, C_in, C_out], must be contiguous
//   src    (dst->src[1]): [W,  H,  C_in, N]
//   dst                 : [OW, OH, C_out, N]
//
// Runs on all threads of the op; the two ggml_barrier() calls keep the
// im2col, GEMM and permutation phases in lock-step across threads.
static void ggml_compute_forward_conv_2d_f32(const ggml_compute_params * params,
                                             ggml_tensor * dst) {

    const ggml_tensor * src = dst->src[1];    // [W H C_in N]
    const ggml_tensor * kernel = dst->src[0]; // [W H C_in C_out]

    // kernel data is indexed linearly below (knl_data used as a dense
    // [knl_n x c_out] matrix), so it must be contiguous
    GGML_ASSERT(ggml_is_contiguous(kernel));

    // op_params: 0 = stride x, 1 = stride y, 2 = pad x, 3 = pad y.
    // NOTE(review): op_params[4..5] (dilation in the previous
    // implementation) are not read here, so dilation is effectively 1 —
    // confirm the graph never builds this op with dilation != 1.
    const int32_t stride_x = dst->op_params[0];
    const int32_t stride_y = dst->op_params[1];
    const int32_t pad_x = dst->op_params[2];
    const int32_t pad_y = dst->op_params[3];

    const int64_t c_in = src->ne[2];
    const int64_t c_out = kernel->ne[3];
    GGML_ASSERT(c_in == kernel->ne[2]); // input channels must match kernel

    const int64_t src_w = src->ne[0];
    const int64_t src_h = src->ne[1];
    const int64_t knl_w = kernel->ne[0];
    const int64_t knl_h = kernel->ne[1];
    const int64_t dst_w = dst->ne[0];
    const int64_t dst_h = dst->ne[1];

    float * src_data = (float *) src->data;
    float * knl_data = (float *) kernel->data;
    float * dst_data = ( float *) dst->data;

    // knl_n: length of one im2col row = one flattened receptive field
    const int64_t knl_n = knl_w * knl_h * c_in;
    // patch_total: one patch per output pixel per batch element
    const int64_t patch_total = dst->ne[3] * dst_w * dst_h;

    // NOTE(review): this includes patch_total * c_out — the GEMM output for
    // ALL patches — rather than a per-patch c_out term; the per-patch scratch
    // need is (knl_n + c_out) floats, so this looks like an overestimate.
    // Confirm against the wsize reserved for GGML_OP_CONV_2D.
    const int64_t space_per_patch = knl_n * sizeof(float) + patch_total * c_out * sizeof(float);

    // NOTE(review): if wsize < space_per_patch then batch_size and
    // patches_per_batch are 0 and the batch_n division below divides by
    // zero before the GGML_ASSERT fires — assert should come first.
    const int64_t batch_size = params->wsize / space_per_patch;
    // round down to a multiple of 8 patches per batch when possible
    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
    const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;

    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);

    // Shared scratch for ALL threads (not per-thread): the first
    // patches_per_batch * knl_n floats hold the im2col rows, followed by
    // the GEMM output. Threads write disjoint rows, indexed by patch.
    float * tmp = (float *) params->wdata;

    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {

        // patch range [patch_start_batch, patch_end_batch) for this batch
        const int64_t patch_start_batch = batch_i * patches_per_batch;
        const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch,
                                                 patch_total);
        const int64_t patch_n = patch_end_batch - patch_start_batch;

        // split this batch's patches evenly across threads
        const int64_t patch_per_thread =
            (patch_n + params->nth - 1) / params->nth;
        const int64_t patch_start = patch_start_batch +
                                    params->ith * patch_per_thread;
        const int64_t patch_end = std::min(patch_start + patch_per_thread,
                                           patch_end_batch);

        //im2col for a patch
        for (int64_t p = patch_start; p < patch_end; ++p) {
            // decompose the flat patch index into (batch, out-y, out-x)
            const int64_t b = p / (dst_w * dst_h);
            const int64_t dy = (p / dst_w) % dst_h;
            const int64_t dx = p % dst_w;

            const float * src_base = (const float *)((char *)src_data + b * src->nb[3]);
            // row of the im2col matrix for this patch (index is relative
            // to the current batch, hence the modulo)
            float * out_row = tmp + (p % patches_per_batch) * knl_n;

            // Extract patch in IC,KH,KW order (same as im2col)
            for (int64_t ic = 0; ic < c_in; ++ic) {
                for (int64_t ky = 0; ky < knl_h; ++ky) {
                    for (int64_t kx = 0; kx < knl_w; ++kx) {
                        // source pixel covered by kernel tap (kx, ky)
                        const int64_t sy = dy * stride_y + ky - pad_y;
                        const int64_t sx = dx * stride_x + kx - pad_x;

                        int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;

                        // zero-pad outside the source image
                        if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
                            out_row[dst_idx] = 0.0f;
                        } else {
                            float * src_ptr = (float *)((char *)src_base +
                                sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
                            out_row[dst_idx] = *src_ptr;
                        }
                    }
                }
            }
        } // patches handled by this thread

        ggml_barrier(params->threadpool); // wait for all threads' im2col rows

        //GEMM output is patch_n * cout
        float * gemm_output = tmp + patches_per_batch * knl_n;

        // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
        // (entered by all threads; mul_mat splits the work internally)
        ggml_call_mul_mat(params, patch_n, c_out, knl_n,
                          tmp, knl_data, gemm_output);

        // Barrier to ensure GEMM completes before permutation
        ggml_barrier(params->threadpool);

        // Distribute permutation work across threads
        const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
        const int64_t permute_start = params->ith * permute_per_thread;
        const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);

        // Each thread handles part of the permutation from [patch_n, c_out] to WHCN layout
        for (int64_t i = permute_start; i < permute_end; ++i) {
            const int64_t p = patch_start_batch + i;
            const int64_t b = p / (dst_w * dst_h);  // batch index
            const int64_t dy = (p / dst_w) % dst_h; // height index
            const int64_t dx = p % dst_w;           // width index

            // Copy all channels for this spatial position
            for (int64_t oc = 0; oc < c_out; ++oc) {
                const float value = gemm_output[i * c_out + oc];
                // Write to WHCN layout: dst[w, h, c, n]
                float * dst_ptr = (float *)((char *)dst_data +
                    dx * dst->nb[0] + dy * dst->nb[1] + oc * dst->nb[2] + b * dst->nb[3]);
                *dst_ptr = value;
            }
        }
    }
}
|
@@ -6206,7 +6310,7 @@ void ggml_compute_forward_conv_2d(
|
6206 | 6310 | } break;
|
6207 | 6311 | case GGML_TYPE_F32:
|
6208 | 6312 | {
|
6209 |
| - ggml_compute_forward_conv_2d_f32(params, src0, src1, dst); |
| 6313 | + ggml_compute_forward_conv_2d_f32(params, dst); |
6210 | 6314 | } break;
|
6211 | 6315 | default:
|
6212 | 6316 | {
|
|
0 commit comments