@@ -6117,29 +6117,29 @@ void ggml_compute_forward_im2col_back_f32(
6117
6117
}
6118
6118
}
6119
6119
6120
- static void ggml_call_mul_mat (
6121
- const ggml_compute_params * params,
6122
- int64_t m, int64_t n, int64_t k,
6123
- void * a, void * b, void * c) {
6124
-
6120
+ static void ggml_call_mul_mat (ggml_type T, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
6121
+ void * a, void * b, void * c) {
6122
+ const ggml_type_traits * traits = ggml_get_type_traits (T);
6125
6123
struct ggml_tensor src1 = {};
6124
+ src1.type = T;
6126
6125
src1.ne [0 ] = k;
6127
6126
src1.ne [1 ] = m;
6128
6127
src1.ne [2 ] = 1 ;
6129
6128
src1.ne [3 ] = 1 ;
6130
- src1.nb [0 ] = sizeof ( float ) ;
6131
- src1.nb [1 ] = k * sizeof ( float ) ;
6129
+ src1.nb [0 ] = traits-> type_size ;
6130
+ src1.nb [1 ] = k * traits-> type_size ;
6132
6131
src1.nb [2 ] = src1.nb [1 ];
6133
6132
src1.nb [3 ] = src1.nb [2 ];
6134
6133
src1.data = a;
6135
6134
6136
6135
struct ggml_tensor src0 = {};
6136
+ src0.type = T;
6137
6137
src0.ne [0 ] = k;
6138
6138
src0.ne [1 ] = n;
6139
6139
src0.ne [2 ] = 1 ;
6140
6140
src0.ne [3 ] = 1 ;
6141
- src0.nb [0 ] = sizeof ( float ) ;
6142
- src0.nb [1 ] = k * sizeof ( float ) ;
6141
+ src0.nb [0 ] = traits-> type_size ;
6142
+ src0.nb [1 ] = k * traits-> type_size ;
6143
6143
src0.nb [2 ] = src0.nb [1 ];
6144
6144
src0.nb [3 ] = src0.nb [2 ];
6145
6145
src0.data = b;
@@ -6160,17 +6160,18 @@ static void ggml_call_mul_mat(
6160
6160
ggml_compute_forward_mul_mat (params, &dst);
6161
6161
}
6162
6162
6163
-
6164
6163
// ggml_compute_forward_conv_2d
6165
6164
6166
- static void ggml_compute_forward_conv_2d_f32 (
6167
- const ggml_compute_params * params,
6168
- const ggml_tensor * kernel , // [KW, KH, IC, OC] - fp32
6169
- const ggml_tensor * src , // [W, H, C , N]
6170
- ggml_tensor * dst) { // [OW, OH, OC, N]
6165
+ static void ggml_compute_forward_conv_2d_impl ( const ggml_compute_params * params,
6166
+ const ggml_tensor * kernel, // [KW, KH, IC, OC]
6167
+ const ggml_tensor * src , // [W, H, C, N]
6168
+ ggml_tensor * dst , // [OW, OH, OC , N]
6169
+ ggml_type kernel_type) {
6171
6170
6172
6171
GGML_ASSERT (ggml_is_contiguous (kernel));
6173
- GGML_ASSERT (kernel->type == GGML_TYPE_F32);
6172
+ GGML_ASSERT (kernel->type == kernel_type);
6173
+
6174
+ const ggml_type_traits * traits = ggml_get_type_traits (kernel_type);
6174
6175
6175
6176
const int32_t stride_x = dst->op_params [0 ];
6176
6177
const int32_t stride_y = dst->op_params [1 ];
@@ -6191,20 +6192,20 @@ static void ggml_compute_forward_conv_2d_f32(
6191
6192
const int64_t dst_h = dst->ne [1 ];
6192
6193
6193
6194
float * src_data = (float *) src->data ;
6194
- float * knl_data = ( float *) kernel->data ;
6195
+ void * knl_data = kernel->data ;
6195
6196
float * dst_data = (float *) dst->data ;
6196
6197
6197
6198
const int64_t knl_n = knl_w * knl_h * c_in;
6198
6199
const int64_t patch_total = dst->ne [3 ] * dst_w * dst_h;
6199
6200
6200
- const int64_t space_per_patch = knl_n * sizeof ( float ) + c_out * sizeof (float );
6201
+ const int64_t space_per_patch = knl_n * traits-> type_size + c_out * sizeof (float );
6201
6202
const int64_t batch_size = params->wsize / space_per_patch;
6202
6203
const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8 ) * 8 : batch_size;
6203
6204
const int64_t batch_n = (patch_total + patches_per_batch - 1 ) / patches_per_batch;
6204
6205
6205
6206
GGML_ASSERT (patches_per_batch > 0 && batch_size >= 1 );
6206
6207
6207
- float * tmp = ( float *) params->wdata ;
6208
+ void * tmp = params->wdata ;
6208
6209
6209
6210
for (int64_t batch_i = 0 ; batch_i < batch_n; ++batch_i) {
6210
6211
@@ -6224,7 +6225,7 @@ static void ggml_compute_forward_conv_2d_f32(
6224
6225
const int64_t src_y = p % dst_w;
6225
6226
6226
6227
float * src_base = (float *)((char *)src_data + batch_n * src->nb [3 ]);
6227
- float * dst_row = tmp + (p % patches_per_batch) * knl_n;
6228
+ char * dst_row = ( char *) tmp + (p % patches_per_batch) * knl_n * traits-> type_size ;
6228
6229
6229
6230
for (int64_t ic = 0 ; ic < c_in; ++ic) {
6230
6231
for (int64_t ky = 0 ; ky < knl_h; ++ky) {
@@ -6234,11 +6235,19 @@ static void ggml_compute_forward_conv_2d_f32(
6234
6235
6235
6236
int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
6236
6237
6238
+ float src_val;
6237
6239
if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
6238
- dst_row[dst_idx] = 0 .0f ;
6240
+ src_val = 0 .0f ;
6239
6241
} else {
6240
6242
float * src_ptr = (float *)((char *)src_base + sx * src->nb [0 ] + sy * src->nb [1 ] + ic * src->nb [2 ]);
6241
- dst_row[dst_idx] = *src_ptr;
6243
+ src_val = *src_ptr;
6244
+ }
6245
+
6246
+ char * element_ptr = dst_row + dst_idx * traits->type_size ;
6247
+ if (kernel_type == GGML_TYPE_F32) {
6248
+ *(float *) element_ptr = src_val;
6249
+ } else if (kernel_type == GGML_TYPE_F16) {
6250
+ *(ggml_fp16_t *) element_ptr = GGML_FP32_TO_FP16 (src_val);
6242
6251
}
6243
6252
}
6244
6253
}
@@ -6247,11 +6256,10 @@ static void ggml_compute_forward_conv_2d_f32(
6247
6256
6248
6257
ggml_barrier (params->threadpool );
6249
6258
6250
- float * gemm_output = tmp + patches_per_batch * knl_n;
6259
+ float * gemm_output = ( float *) (( char *) tmp + patches_per_batch * knl_n * traits-> type_size ) ;
6251
6260
6252
6261
// GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
6253
- ggml_call_mul_mat (params, patch_n, c_out, knl_n,
6254
- tmp, knl_data, gemm_output);
6262
+ ggml_call_mul_mat (kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
6255
6263
6256
6264
ggml_barrier (params->threadpool );
6257
6265
@@ -6269,7 +6277,6 @@ static void ggml_compute_forward_conv_2d_f32(
6269
6277
6270
6278
for (int64_t oc = 0 ; oc < c_out; ++oc) {
6271
6279
const float value = gemm_output[i * c_out + oc];
6272
- // Write to WHCN layout: dst[w, h, c, n]
6273
6280
float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb [0 ] + dst_y * dst->nb [1 ] + oc * dst->nb [2 ] + batch_n * dst->nb [3 ]);
6274
6281
*dst_ptr = value;
6275
6282
}
@@ -6284,11 +6291,7 @@ void ggml_compute_forward_conv_2d(
6284
6291
const ggml_tensor * src0 = dst->src [0 ];
6285
6292
const ggml_tensor * src1 = dst->src [1 ];
6286
6293
6287
- if (src0->type == GGML_TYPE_F16) {
6288
- GGML_ASSERT (false && " F16 not supported yet" );
6289
- } else {
6290
- ggml_compute_forward_conv_2d_f32 (params, src0, src1, dst);
6291
- }
6294
+ ggml_compute_forward_conv_2d_impl (params, src0, src1, dst, src0->type );
6292
6295
}
6293
6296
6294
6297
// ggml_compute_forward_conv_transpose_2d
0 commit comments