Skip to content

Commit ba66175

Browse files
ggerganovleejet
andauthored
sync : ggml (fix im2col) (#4591)
* cuda : fix im2col_f32_f16 (ggml/#658) ggml-ci * ggml-alloc : fix ggml_tallocr_is_own --------- Co-authored-by: leejet <leejet714@gmail.com>
1 parent a558769 commit ba66175

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

ggml-alloc.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ static void remove_allocated_tensor(ggml_tallocr_t alloc, struct ggml_tensor * t
7272

7373
// check if a tensor is allocated by this buffer
7474
static bool ggml_tallocr_is_own(ggml_tallocr_t alloc, const struct ggml_tensor * tensor) {
75-
return tensor->buffer == alloc->buffer;
75+
return tensor->buffer == alloc->buffer && (!tensor->view_src || tensor->view_src->buffer == alloc->buffer);
7676
}
7777

7878
static bool ggml_is_view(struct ggml_tensor * t) {

ggml-cuda.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5273,17 +5273,17 @@ static __global__ void im2col_f32_f16(
52735273
const int ky = (i - kd) / OW;
52745274
const int ix = i % OW;
52755275

5276-
const int iiw = ix * s0 + kx * d0 - p0;
5277-
const int iih = blockIdx.y * s1 + ky * d1 - p1;
5276+
const int64_t iiw = ix * s0 + kx * d0 - p0;
5277+
const int64_t iih = blockIdx.y * s1 + ky * d1 - p1;
52785278

5279-
const int offset_dst =
5279+
const int64_t offset_dst =
52805280
(blockIdx.y * OW + ix) * CHW +
52815281
(blockIdx.z * (KW * KH) + ky * KW + kx);
52825282

52835283
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
52845284
dst[offset_dst] = __float2half(0.0f);
52855285
} else {
5286-
const int offset_src = blockIdx.z * offset_delta;
5286+
const int64_t offset_src = blockIdx.z * offset_delta;
52875287
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
52885288
}
52895289
}

0 commit comments

Comments
 (0)