@@ -6546,29 +6546,29 @@ void ggml_compute_forward_im2col_back_f32(
     }
 }
 
-static void ggml_call_mul_mat(
-    const ggml_compute_params * params,
-    int64_t m, int64_t n, int64_t k,
-    void * a, void * b, void * c) {
-
+static void ggml_call_mul_mat(ggml_type T, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
+                              void * a, void * b, void * c) {
+    const ggml_type_traits * traits = ggml_get_type_traits(T);
     struct ggml_tensor src1 = {};
+    src1.type = T;
     src1.ne[0] = k;
     src1.ne[1] = m;
     src1.ne[2] = 1;
     src1.ne[3] = 1;
-    src1.nb[0] = sizeof(float);
-    src1.nb[1] = k * sizeof(float);
+    src1.nb[0] = traits->type_size;
+    src1.nb[1] = k * traits->type_size;
     src1.nb[2] = src1.nb[1];
     src1.nb[3] = src1.nb[2];
     src1.data = a;
 
     struct ggml_tensor src0 = {};
+    src0.type = T;
     src0.ne[0] = k;
     src0.ne[1] = n;
     src0.ne[2] = 1;
     src0.ne[3] = 1;
-    src0.nb[0] = sizeof(float);
-    src0.nb[1] = k * sizeof(float);
+    src0.nb[0] = traits->type_size;
+    src0.nb[1] = k * traits->type_size;
     src0.nb[2] = src0.nb[1];
     src0.nb[3] = src0.nb[2];
     src0.data = b;
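Side note (not part of the patch): the wrapper tensors above now take their strides from the element size of the type instead of hard-coding sizeof(float). A minimal sketch of that lookup, using only ggml_get_type_traits and type_size exactly as they appear in the hunk:

    // requires "ggml.h"; type_size is 2 for GGML_TYPE_F16 and 4 for GGML_TYPE_F32
    const ggml_type_traits * traits = ggml_get_type_traits(GGML_TYPE_F16);
    const int64_t k         = 64;                       // illustrative row length
    const size_t  row_bytes = k * traits->type_size;    // same expression as src1.nb[1] above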
@@ -6589,17 +6589,18 @@ static void ggml_call_mul_mat(
     ggml_compute_forward_mul_mat(params, &dst);
 }
 
-
 // ggml_compute_forward_conv_2d
 
-static void ggml_compute_forward_conv_2d_f32(
-    const ggml_compute_params * params,
-    const ggml_tensor * kernel,  // [KW, KH, IC, OC] - fp32
-    const ggml_tensor * src,     // [W, H, C, N]
-    ggml_tensor * dst) {         // [OW, OH, OC, N]
+static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
+                                              const ggml_tensor *         kernel,  // [KW, KH, IC, OC]
+                                              const ggml_tensor *         src,     // [W, H, C, N]
+                                              ggml_tensor *               dst,     // [OW, OH, OC, N]
+                                              ggml_type                   kernel_type) {
 
     GGML_ASSERT(ggml_is_contiguous(kernel));
-    GGML_ASSERT(kernel->type == GGML_TYPE_F32);
+    GGML_ASSERT(kernel->type == kernel_type);
+
+    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
 
     const int32_t stride_x = dst->op_params[0];
     const int32_t stride_y = dst->op_params[1];
@@ -6620,20 +6621,20 @@ static void ggml_compute_forward_conv_2d_f32(
     const int64_t dst_h = dst->ne[1];
 
     float * src_data = (float *) src->data;
-    float * knl_data = (float *) kernel->data;
+    void  * knl_data = kernel->data;
     float * dst_data = (float *) dst->data;
 
     const int64_t knl_n       = knl_w * knl_h * c_in;
     const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
 
-    const int64_t space_per_patch   = knl_n * sizeof(float) + c_out * sizeof(float);
+    const int64_t space_per_patch   = knl_n * traits->type_size + c_out * sizeof(float);
     const int64_t batch_size        = params->wsize / space_per_patch;
     const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
     const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
 
     GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
 
-    float * tmp = (float *) params->wdata;
+    void * tmp = params->wdata;
 
     for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
 
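A rough sizing example for the workspace split above (illustrative numbers, not taken from the patch): a 3x3 kernel over 64 input channels with 128 output channels and an F16 kernel type gives

    const int64_t knl_n           = 3 * 3 * 64;                        // 576 elements per im2col row
    const int64_t space_per_patch = knl_n * 2 + 128 * sizeof(float);   // 1152 B (F16 row) + 512 B (F32 GEMM out) = 1664 B

so batch_size = params->wsize / 1664 patches are staged per pass.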
@@ -6653,7 +6654,7 @@ static void ggml_compute_forward_conv_2d_f32(
             const int64_t src_y = p % dst_w;
 
             float * src_base = (float *)((char *)src_data + batch_n * src->nb[3]);
-            float * dst_row  = tmp + (p % patches_per_batch) * knl_n;
+            char  * dst_row  = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
 
             for (int64_t ic = 0; ic < c_in; ++ic) {
                 for (int64_t ky = 0; ky < knl_h; ++ky) {
@@ -6663,11 +6664,19 @@ static void ggml_compute_forward_conv_2d_f32(
 
                         int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
 
+                        float src_val;
                         if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
-                            dst_row[dst_idx] = 0.0f;
+                            src_val = 0.0f;
                         } else {
                             float * src_ptr = (float *)((char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
-                            dst_row[dst_idx] = *src_ptr;
+                            src_val = *src_ptr;
+                        }
+
+                        char * element_ptr = dst_row + dst_idx * traits->type_size;
+                        if (kernel_type == GGML_TYPE_F32) {
+                            *(float *) element_ptr = src_val;
+                        } else if (kernel_type == GGML_TYPE_F16) {
+                            *(ggml_fp16_t *) element_ptr = GGML_FP32_TO_FP16(src_val);
                         }
                     }
                 }
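The store above is the only place in the gather loop where the buffer's element type matters. A small round-trip sketch of the F16 path, assuming the matching GGML_FP16_TO_FP32 macro from the same internal headers as GGML_FP32_TO_FP16:

    float       src_val = 1.25f;                        // value gathered from the source image
    ggml_fp16_t packed  = GGML_FP32_TO_FP16(src_val);   // what lands in the F16 im2col buffer
    float       back    = GGML_FP16_TO_FP32(packed);    // 1.25f round-trips exactly; values outside fp16 range would not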
@@ -6676,11 +6685,10 @@ static void ggml_compute_forward_conv_2d_f32(
 
         ggml_barrier(params->threadpool);
 
-        float * gemm_output = tmp + patches_per_batch * knl_n;
+        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);
 
         // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
-        ggml_call_mul_mat(params, patch_n, c_out, knl_n,
-                          tmp, knl_data, gemm_output);
+        ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
 
         ggml_barrier(params->threadpool);
 
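For reference, how this call maps onto the wrapper tensors built in ggml_call_mul_mat at the top of the patch (inferred from the shapes, not stated explicitly in the diff):

    // src1 wraps the staged patches, src0 wraps the kernel, both in kernel_type:
    //   src1.ne = { knl_n, patch_n }   // k x m
    //   src0.ne = { knl_n, c_out   }   // k x n
    // the F32 result is laid out as { c_out, patch_n }, which is why the
    // scatter loop in the next hunk reads gemm_output[i * c_out + oc]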
@@ -6698,7 +6706,6 @@ static void ggml_compute_forward_conv_2d_f32(
 
             for (int64_t oc = 0; oc < c_out; ++oc) {
                 const float value = gemm_output[i * c_out + oc];
-                // Write to WHCN layout: dst[w, h, c, n]
                 float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
                 *dst_ptr = value;
             }
@@ -6713,11 +6720,7 @@ void ggml_compute_forward_conv_2d(
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
 
-    if (src0->type == GGML_TYPE_F16) {
-        GGML_ASSERT(false && "F16 not supported yet");
-    } else {
-        ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
-    }
+    ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
 }
 
 // ggml_compute_forward_conv_transpose_2d