@@ -6592,18 +6592,21 @@ static void ggml_call_mul_mat(
 
 // ggml_compute_forward_conv_2d
 
-static void ggml_compute_forward_conv_2d_f32(const ggml_compute_params * params,
-                                             ggml_tensor * dst) {
-
-    const ggml_tensor * src    = dst->src[1]; // [W H C_in N]
-    const ggml_tensor * kernel = dst->src[0]; // [W H C_in C_out]
+static void ggml_compute_forward_conv_2d_f32(
+        const ggml_compute_params * params,
+        const ggml_tensor *         kernel, // [KW, KH, IC, OC] - fp32
+        const ggml_tensor *         src,    // [W, H, C, N]
+        ggml_tensor *               dst) {  // [OW, OH, OC, N]
 
     GGML_ASSERT(ggml_is_contiguous(kernel));
+    GGML_ASSERT(kernel->type == GGML_TYPE_F32);
 
-    const int32_t stride_x = dst->op_params[0];
-    const int32_t stride_y = dst->op_params[1];
-    const int32_t pad_x    = dst->op_params[2];
-    const int32_t pad_y    = dst->op_params[3];
+    const int32_t stride_x   = dst->op_params[0];
+    const int32_t stride_y   = dst->op_params[1];
+    const int32_t pad_x      = dst->op_params[2];
+    const int32_t pad_y      = dst->op_params[3];
+    const int32_t dilation_x = dst->op_params[4];
+    const int32_t dilation_y = dst->op_params[5];
 
     const int64_t c_in  = src->ne[2];
     const int64_t c_out = kernel->ne[3];
@@ -6616,193 +6619,104 @@ static void ggml_compute_forward_conv_2d_f32(const ggml_compute_params * params,
     const int64_t dst_w = dst->ne[0];
     const int64_t dst_h = dst->ne[1];
 
-
-    float * src_data = (float *) src->data;
-    float * knl_data = (float *) kernel->data;
-    float * dst_data = (float *) dst->data;
-
+    float * src_data = (float *) src->data;
+    float * knl_data = (float *) kernel->data;
+    float * dst_data = (float *) dst->data;
 
     const int64_t knl_n       = knl_w * knl_h * c_in;
     const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
-
-
-
-    const int64_t space_per_patch = knl_n * sizeof(float) + patch_total * c_out * sizeof(float);
 
-    const int64_t batch_size = params->wsize / space_per_patch;
+    const int64_t space_per_patch   = knl_n * sizeof(float) + c_out * sizeof(float);
+    const int64_t batch_size        = params->wsize / space_per_patch;
     const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
-    const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
-
+    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
 
     GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
 
-    float * tmp = (float *) params->wdata; // per-thread scratch
+    float * tmp = (float *) params->wdata;
 
     for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
 
         const int64_t patch_start_batch = batch_i * patches_per_batch;
         const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch,
                                                    patch_total);
-        const int64_t patch_n = patch_end_batch - patch_start_batch;
+        const int64_t patch_n           = patch_end_batch - patch_start_batch;
 
-        const int64_t patch_per_thread =
-            (patch_n + params->nth - 1) / params->nth;
-        const int64_t patch_start = patch_start_batch +
-            params->ith * patch_per_thread;
-        const int64_t patch_end = std::min(patch_start + patch_per_thread,
-                                           patch_end_batch);
+        const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth;
+        const int64_t patch_start      = patch_start_batch + params->ith * patch_per_thread;
+        const int64_t patch_end        = std::min(patch_start + patch_per_thread, patch_end_batch);
 
         // im2col for a patch
         for (int64_t p = patch_start; p < patch_end; ++p) {
-            const int64_t b  = p / (dst_w * dst_h);
-            const int64_t dy = (p / dst_w) % dst_h;
-            const int64_t dx = p % dst_w;
+            const int64_t batch_n = p / (dst_w * dst_h);
+            const int64_t src_x   = (p / dst_w) % dst_h;
+            const int64_t src_y   = p % dst_w;
 
-            const float * src_base = (const float *)((char *)src_data + b * src->nb[3]);
-            float * out_row = tmp + (p % patches_per_batch) * knl_n;
+            float * src_base = (float *)((char *)src_data + batch_n * src->nb[3]);
+            float * dst_row  = tmp + (p % patches_per_batch) * knl_n;
 
-            // Extract patch in IC,KH,KW order (same as im2col)
             for (int64_t ic = 0; ic < c_in; ++ic) {
                 for (int64_t ky = 0; ky < knl_h; ++ky) {
                     for (int64_t kx = 0; kx < knl_w; ++kx) {
-                        const int64_t sy = dy * stride_y + ky - pad_y;
-                        const int64_t sx = dx * stride_x + kx - pad_x;
-
+                        const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
+                        const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;
+
                         int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
-
+
                         if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
-                            out_row[dst_idx] = 0.0f;
+                            dst_row[dst_idx] = 0.0f;
                         } else {
-                            float * src_ptr = (float *)((char *)src_base +
-                                sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
-                            out_row[dst_idx] = *src_ptr;
+                            float * src_ptr = (float *)((char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
+                            dst_row[dst_idx] = *src_ptr;
                         }
                     }
                 }
             }
         } // patches handled by this thread
 
-        ggml_barrier(params->threadpool); // wait for all threads
+        ggml_barrier(params->threadpool);
 
-        // GEMM output is patch_n * cout
         float * gemm_output = tmp + patches_per_batch * knl_n;
-
+
         // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
         ggml_call_mul_mat(params, patch_n, c_out, knl_n,
                           tmp, knl_data, gemm_output);
-
-        // Barrier to ensure GEMM completes before permutation
+
         ggml_barrier(params->threadpool);
-
-        // Distribute permutation work across threads
+
+
+        // permute back [OC, N, OH, OW] to [N, OC, OH, OW]
         const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
         const int64_t permute_start = params->ith * permute_per_thread;
         const int64_t permute_end   = std::min(permute_start + permute_per_thread, patch_n);
-
-        // Each thread handles part of the permutation from [patch_n, c_out] to WHCN layout
+
         for (int64_t i = permute_start; i < permute_end; ++i) {
-            const int64_t p  = patch_start_batch + i;
-            const int64_t b  = p / (dst_w * dst_h);  // batch index
-            const int64_t dy = (p / dst_w) % dst_h;  // height index
-            const int64_t dx = p % dst_w;            // width index
-
-            // Copy all channels for this spatial position
+            const int64_t p       = patch_start_batch + i;
+            const int64_t batch_n = p / (dst_w * dst_h);
+            const int64_t dst_y   = (p / dst_w) % dst_h;
+            const int64_t dst_x   = p % dst_w;
+
             for (int64_t oc = 0; oc < c_out; ++oc) {
                 const float value = gemm_output[i * c_out + oc];
                 // Write to WHCN layout: dst[w, h, c, n]
-                float * dst_ptr = (float *)((char *)dst_data +
-                    dx * dst->nb[0] + dy * dst->nb[1] + oc * dst->nb[2] + b * dst->nb[3]);
+                float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
                 *dst_ptr = value;
             }
         }
     }
 }
 
-static void ggml_compute_forward_conv_2d_f16(
-        const ggml_compute_params * params,
-        const ggml_tensor *         kernel, // [KW, KH, IC, OC]
-        const ggml_tensor *         src,    // [W, H, C, N]
-        ggml_tensor *               dst) {  // [OW, OH, OC, N]
-
-    const int32_t s0 = ggml_get_op_params_i32(dst, 0);
-    const int32_t s1 = ggml_get_op_params_i32(dst, 1);
-    const int32_t p0 = ggml_get_op_params_i32(dst, 2);
-    const int32_t p1 = ggml_get_op_params_i32(dst, 3);
-    const int32_t d0 = ggml_get_op_params_i32(dst, 4);
-    const int32_t d1 = ggml_get_op_params_i32(dst, 5);
-
-    const int64_t OW = dst->ne[0];
-    const int64_t OH = dst->ne[1];
-    const int64_t OC = dst->ne[2];
-    const int64_t N  = dst->ne[3];
-
-    const int64_t IW = src->ne[0];
-    const int64_t IH = src->ne[1];
-    const int64_t IC = src->ne[2];
-
-    const int64_t KW = kernel->ne[0];
-    const int64_t KH = kernel->ne[1];
-
-    const ggml_fp16_t * kernel_data = (const ggml_fp16_t *)kernel->data;
-    const ggml_fp16_t * src_data    = (const ggml_fp16_t *)src->data;
-    ggml_fp16_t *       dst_data    = (ggml_fp16_t *)dst->data;
-
-    const int64_t rows_total      = OH * N;
-    const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
-    const int64_t row_start       = params->ith * rows_per_thread;
-    const int64_t row_end         = MIN(row_start + rows_per_thread, rows_total);
-
-    for (int64_t row = row_start; row < row_end; ++row) {
-        const int64_t oh = row % OH;
-        const int64_t n  = row / OH;
-        const ggml_fp16_t * src_batch = src_data + n * IW * IH * IC;
-
-        for (int64_t ow = 0; ow < OW; ++ow) {
-            for (int64_t oc = 0; oc < OC; ++oc) {
-                float sum = 0.0f;
-                const ggml_fp16_t * kernel_channel = kernel_data + oc * KW * KH * IC;
-                for (int64_t kh = 0; kh < KH; ++kh) {
-                    const int64_t ih = oh * s1 - p1 + kh * d1;
-                    if (ih < 0 || ih >= IH) continue;
-
-                    for (int64_t kw = 0; kw < KW; ++kw) {
-                        const int64_t iw = ow * s0 - p0 + kw * d0;
-                        if (iw < 0 || iw >= IW) continue;
-
-                        for (int64_t ic = 0; ic < IC; ++ic) {
-                            const ggml_fp16_t * kernel_ptr = kernel_channel + (kh * KW + kw) + ic * KW * KH;
-                            const ggml_fp16_t * src_ptr    = src_batch + (ih * IW + iw) + ic * IW * IH;
-                            sum += GGML_FP16_TO_FP32(*kernel_ptr) * GGML_FP16_TO_FP32(*src_ptr);
-                        }
-                    }
-                }
-
-                dst_data[((n * OC + oc) * OH + oh) * OW + ow] = GGML_FP32_TO_FP16(sum);
-            }
-        }
-    }
-}
-
 void ggml_compute_forward_conv_2d(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
 
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_conv_2d_f16(params, src0, src1, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_conv_2d_f32(params, dst);
-            } break;
-        default:
-            {
-                GGML_ABORT("fatal error");
-            }
+    if (src0->type == GGML_TYPE_F16) {
+        GGML_ASSERT(false && "F16 not supported yet");
+    } else {
+        ggml_compute_forward_conv_2d_f32(params, src0, src1, dst);
     }
 }
 
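For readers following the change: the new f32 path computes conv2d as im2col followed by one GEMM per batch of patches, then permutes the [patch_n, c_out] GEMM output back into the WHCN dst layout. Below is a minimal standalone C++ sketch of that decomposition, with plain loops standing in for ggml_call_mul_mat and dense float vectors standing in for ggml tensors; the function name and flat indexing here are illustrative only, not ggml API.

// Standalone sketch of the im2col + GEMM convolution used in the patch.
// Layouts follow the diff: src is [W, H, C, N] (W fastest), kernel is
// [KW, KH, IC, OC], dst is [OW, OH, OC, N]. A naive dot product stands
// in for the GEMM; conv2d_im2col_ref is a hypothetical name.
#include <cstdint>
#include <vector>

static void conv2d_im2col_ref(
        const std::vector<float> & src, int64_t IW, int64_t IH, int64_t IC, int64_t N,
        const std::vector<float> & knl, int64_t KW, int64_t KH, int64_t OC,
        std::vector<float> & dst,
        int32_t stride, int32_t pad, int32_t dilation) {
    // standard output-size formula with dilation
    const int64_t OW = (IW + 2*pad - dilation*(KW - 1) - 1) / stride + 1;
    const int64_t OH = (IH + 2*pad - dilation*(KH - 1) - 1) / stride + 1;
    const int64_t knl_n       = KW * KH * IC;  // one im2col row per output pixel
    const int64_t patch_total = N * OW * OH;

    // im2col: flatten each receptive field into a row of knl_n samples
    std::vector<float> rows(patch_total * knl_n, 0.0f);
    for (int64_t p = 0; p < patch_total; ++p) {
        const int64_t n  = p / (OW * OH);
        const int64_t oy = (p / OW) % OH;
        const int64_t ox = p % OW;
        for (int64_t ic = 0; ic < IC; ++ic)
        for (int64_t ky = 0; ky < KH; ++ky)
        for (int64_t kx = 0; kx < KW; ++kx) {
            const int64_t sy = oy*stride + ky*dilation - pad;  // same formula as the patch
            const int64_t sx = ox*stride + kx*dilation - pad;
            if (sx < 0 || sx >= IW || sy < 0 || sy >= IH) continue;  // zero padding
            rows[p*knl_n + ic*(KH*KW) + ky*KW + kx] =
                src[((n*IC + ic)*IH + sy)*IW + sx];
        }
    }

    // GEMM: rows[patch_total, knl_n] x knl[knl_n, OC], written straight
    // into the WHCN dst layout (the "permute back" step of the patch).
    dst.assign(patch_total * OC, 0.0f);
    for (int64_t p = 0; p < patch_total; ++p) {
        const int64_t n  = p / (OW * OH);
        const int64_t oy = (p / OW) % OH;
        const int64_t ox = p % OW;
        for (int64_t oc = 0; oc < OC; ++oc) {
            float acc = 0.0f;
            for (int64_t k = 0; k < knl_n; ++k) {
                acc += rows[p*knl_n + k] * knl[oc*knl_n + k];
            }
            dst[((n*OC + oc)*OH + oy)*OW + ox] = acc;
        }
    }
}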
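One behavioral fix in the diff is worth spelling out. The old sizing, space_per_patch = knl_n * sizeof(float) + patch_total * c_out * sizeof(float), charged every single patch for the entire GEMM output, so batch_size = wsize / space_per_patch collapsed toward zero on large outputs; the new cost is one im2col row plus one output row per patch. As a worked example with assumed sizes (not from the patch): a 3x3 kernel with IC = 64 gives knl_n = 576 floats; with OC = 128, space_per_patch = (576 + 128) * 4 = 2816 bytes, so a 16 MiB workspace batches floor(16777216 / 2816) = 5957 patches, which patches_per_batch rounds down to 5952, a multiple of 8.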