
Commit efe6a83

smeso authored and ggerganov committed
ggml : fix cont with transposed tensors when one dimension is 1 (ggml/934)
* ggml_cont: fix issue with transposed tensors when one dimension is 1

  When using multiple threads, it is not enough to check that the tensors are contiguous for ggml_compute_forward_dup_same_cont to work correctly; the tensors' strides also need to match.

  Signed-off-by: Salvatore Mesoraca <s.mesoraca16@gmail.com>

* Add ggml_cont tests

  Signed-off-by: Salvatore Mesoraca <s.mesoraca16@gmail.com>

* Remove dead code

  It isn't possible to reach this code, because all these functions are invoked by ggml_compute_forward_dup if and only if src0->type != dst->type.

  Signed-off-by: Salvatore Mesoraca <s.mesoraca16@gmail.com>

* Make ggml_compute_forward_dup_same_cont work with contiguous tensors

  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
  Signed-off-by: Salvatore Mesoraca <s.mesoraca16@gmail.com>

---------

Signed-off-by: Salvatore Mesoraca <s.mesoraca16@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
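To see why the contiguity check alone was insufficient, here is a minimal standalone sketch (my own illustration, not part of the commit; it assumes the ggml_is_contiguous behavior of this era, which skips the stride check for size-1 dimensions): transposing a {N, 1} tensor yields a {1, N} view whose elements are still packed, so the view passes the contiguity check, yet its nb[0] holds the original row stride rather than the element size. The old kernel derived per-thread byte offsets from nb[0], so with more than one thread every thread but the first copied from the wrong offset.

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(ip);

        // an {8, 1} tensor, then a transposed {1, 8} view of it
        struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 1);
        struct ggml_tensor * t = ggml_transpose(ctx, a);

        // the view is accepted as contiguous (size-1 dims are skipped) ...
        printf("contiguous: %d\n", ggml_is_contiguous(t));
        // ... yet nb[0] is 8*sizeof(float), not sizeof(float)
        printf("nb[0] = %zu, element size = %zu\n",
               t->nb[0], ggml_type_size(t->type));

        ggml_free(ctx);
        return 0;
    }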
1 parent fbb7fcf commit efe6a83

File tree

2 files changed: +14 -21 lines changed


ggml/src/ggml.c

Lines changed: 5 additions & 21 deletions
@@ -8322,8 +8322,7 @@ static void ggml_compute_forward_dup_same_cont(
     GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
     GGML_ASSERT(src0->type == dst->type);
 
-    const size_t nb00 = src0->nb[0];
-    const size_t nb0 = dst->nb[0];
+    const size_t nb0 = ggml_type_size(src0->type);
 
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
@@ -8337,8 +8336,8 @@ static void ggml_compute_forward_dup_same_cont(
     if (ie0 < ie1) {
         memcpy(
             ((char *) dst->data + ie0*nb0),
-            ((char *) src0->data + ie0*nb00),
-            (ie1 - ie0) * ggml_type_size(src0->type));
+            ((char *) src0->data + ie0*nb0),
+            (ie1 - ie0) * nb0);
     }
 }
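Restated outside the diff, the per-thread copy now works like this (a sketch with a hypothetical helper; ie0/ie1/ith/nth mirror the names in ggml.c): each thread handles a disjoint slice of the flat element range, and because both tensors are packed, the element size is the one correct stride for source and destination alike.

    #include <string.h>
    #include <stdint.h>
    #include <stddef.h>

    // hypothetical helper mirroring ggml_compute_forward_dup_same_cont
    static void dup_same_cont_sketch(char * dst, const char * src,
                                     int64_t nelements, size_t type_size,
                                     int ith, int nth) {
        // split [0, nelements) into nth roughly equal chunks
        const int64_t dr  = (nelements + nth - 1)/nth;
        const int64_t ie0 = dr*ith;
        const int64_t ie1 = ie0 + dr < nelements ? ie0 + dr : nelements;

        if (ie0 < ie1) {
            // both tensors are packed, so the element size addresses them
            // correctly; deriving the offset from src->nb[0] broke for
            // "contiguous" transposed views with a size-1 leading dimension
            memcpy(dst + ie0*type_size,
                   src + ie0*type_size,
                   (ie1 - ie0)*type_size);
        }
    }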

@@ -8355,11 +8354,6 @@ static void ggml_compute_forward_dup_f16(
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
 
-    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
-        ggml_compute_forward_dup_same_cont(params, dst);
-        return;
-    }
-
     // parallelize by rows
     const int nr = ne01;
     // number of rows per thread
@@ -8624,11 +8618,6 @@ static void ggml_compute_forward_dup_bf16(
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
 
-    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
-        ggml_compute_forward_dup_same_cont(params, dst);
-        return;
-    }
-
     // parallelize by rows
     const int nr = ne01;
     // number of rows per thread
@@ -8980,11 +8969,6 @@ static void ggml_compute_forward_dup_f32(
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
 
-    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
-        ggml_compute_forward_dup_same_cont(params, dst);
-        return;
-    }
-
     // parallelize by rows
     const int nr = ne01;
     // number of rows per thread
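Why the three removed branches were dead code, per the commit message: ggml_compute_forward_dup routes every same-type copy to ggml_compute_forward_dup_bytes before the type-specific functions ever run. A condensed sketch of that dispatch (simplified from ggml.c; the real switch handles more cases):

    static void forward_dup_dispatch_sketch(struct ggml_compute_params * params,
                                            struct ggml_tensor * dst) {
        const struct ggml_tensor * src0 = dst->src[0];

        if (src0->type == dst->type) {
            ggml_compute_forward_dup_bytes(params, dst); // same-type fast path
            return;
        }
        // only reached when src0->type != dst->type, so the removed
        // "same type and contiguous" early-returns could never fire
        switch (src0->type) {
            case GGML_TYPE_F16:  ggml_compute_forward_dup_f16 (params, dst); break;
            case GGML_TYPE_BF16: ggml_compute_forward_dup_bf16(params, dst); break;
            case GGML_TYPE_F32:  ggml_compute_forward_dup_f32 (params, dst); break;
            default:             GGML_ASSERT(false);
        }
    }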
@@ -9294,13 +9278,13 @@ static void ggml_compute_forward_dup_bytes(
     GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
     GGML_ASSERT(src0->type == dst->type);
 
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
     if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
         ggml_compute_forward_dup_same_cont(params, dst);
         return;
     }
 
-    GGML_TENSOR_UNARY_OP_LOCALS;
-
     const size_t type_size = ggml_type_size(src0->type);
     const int ith = params->ith; // thread index
     const int nth = params->nth; // number of threads
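For reference, GGML_TENSOR_UNARY_OP_LOCALS roughly expands to const locals holding the shapes and strides of src0 and dst (an abridged sketch; the real macro stamps out all four dimensions of each array). Hoisting it above the early return only moves declarations, so behavior is unchanged:

    const int64_t ne00 = src0->ne[0]; // ... ne01, ne02, ne03
    const size_t  nb00 = src0->nb[0]; // ... nb01, nb02, nb03
    const int64_t ne0  = dst->ne[0];  // ... ne1, ne2, ne3
    const size_t  nb0  = dst->nb[0];  // ... nb1, nb2, nb3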

tests/test-backend-ops.cpp

Lines changed: 9 additions & 0 deletions
@@ -2322,6 +2322,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     }
 
     test_cases.emplace_back(new test_cont());
+    test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 1 ,1}));
+    test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 1, 3 ,5}));
+    test_cases.emplace_back(new test_cont(GGML_TYPE_F32, {2, 3, 5 ,7}));
+    test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 1, 1 ,1}));
+    test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 1, 3 ,5}));
+    test_cases.emplace_back(new test_cont(GGML_TYPE_F16, {2, 3, 5 ,7}));
+    test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 1, 1 ,1}));
+    test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 1, 3 ,5}));
+    test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 3, 5 ,7}));
 
     auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
         for (auto op : {ggml_add, ggml_mul, ggml_div}) {
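Spelled out in plain ggml calls, each new case roughly does the following (a hedged sketch of test_cont's graph, which is defined earlier in test-backend-ops.cpp):

    // e.g. test_cont(GGML_TYPE_F32, {2, 1, 3, 5})
    struct ggml_tensor * t   = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 2, 1, 3, 5);
    struct ggml_tensor * tt  = ggml_transpose(ctx, t); // swaps dims 0 and 1
    struct ggml_tensor * out = ggml_cont(ctx, tt);     // exercises the fixed dup path

The shapes with a 1 in the second dimension ({2, 1, 1, 1} and {2, 1, 3, 5}) become views with a size-1 leading dimension after the transpose, which is presumably the case the contiguity check mishandled; {2, 3, 5, 7} covers the general transposed layout, and the three element types vary the copy width.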

Comments (0)