Skip to content

Commit c2e058f

Browse files
authored
vulkan/cuda: Fix im2col when KW!=KH (ggml-org#14789)
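The tid is decomposed into "ow + ky*OW + kx*OW*KH", so the stride between consecutive kx values is OW*KH. Change "ksize" to match: it was previously computed as OW*KW whenever KH > 1, which only gives the right stride when KW == KH.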
1 parent c82d48e commit c2e058f

File tree

3 files changed

+4
-5
lines changed

3 files changed

+4
-5
lines changed

ggml/src/ggml-cuda/im2col.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@ static __global__ void im2col_kernel(
1010
return;
1111
}
1212

13-
const int64_t ksize = OW * (KH > 1 ? KW : 1);
13+
const int64_t ksize = OW * KH;
1414
const int64_t kx = i / ksize;
1515
const int64_t kd = kx * ksize;
1616
const int64_t ky = (i - kd) / OW;

ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -40,12 +40,10 @@ void main() {
4040
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
4141
const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
4242
const int oh_s1 = int(oh) * p.s1;
43-
const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
43+
const uint ksize = p.OW * p.KH;
4444

4545
const uint base_linear_idx = gidx * NUM_ITER;
4646

47-
const uint max_ky = ksize / p.OW;
48-
4947
uint current_kx = base_linear_idx / ksize;
5048
const uint rem = base_linear_idx - (current_kx * ksize);
5149
uint current_ky = rem / p.OW;
@@ -76,7 +74,7 @@ void main() {
7674

7775
if (++current_ix == p.OW) {
7876
current_ix = 0;
79-
if (++current_ky == max_ky) {
77+
if (++current_ky == p.KH) {
8078
current_ky = 0;
8179
current_kx++;
8280
}

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5093,6 +5093,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
50935093
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2048}, {3, 3, 2, 2048}, 1, 1, 1, 1, 1, 1, true));
50945094
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));
50955095
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));
5096+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {5, 5, 1, 32}, {3, 4, 1, 32}, 1, 1, 0, 0, 1, 1, true));
50965097

50975098
// Conv_2D test cases
50985099
#ifdef DETAILED_TESTS

0 commit comments

Comments
 (0)