Skip to content

Commit 822a6d3

Browse files
fairydreaming, sszymczy
authored and committed
SYCL : support non-contiguous tensors in binary ops (add, sub, etc) (ggml-org#12399)
* sycl : support non-contiguous tensors in binary ops
* sycl : silence unused variable warning
---------
Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
1 parent cb1de04 commit 822a6d3

File tree

1 file changed

+61
-26
lines changed

1 file changed

+61
-26
lines changed

ggml/src/ggml-sycl/common.hpp

Lines changed: 61 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -434,6 +434,7 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
434434
int ne0, int ne1, int ne2, int ne3,
435435
int ne10, int ne11, int ne12, int ne13,
436436
/*int s0, */ int s1, int s2, int s3,
437+
/*int s00,*/ int s01, int s02, int s03,
437438
/*int s10,*/ int s11, int s12, int s13,
438439
const sycl::nd_item<3> &item_ct1) {
439440
const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
@@ -455,9 +456,9 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
455456
const int i12 = i2 % ne12;
456457
const int i13 = i3 % ne13;
457458

458-
const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
459+
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
459460
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
460-
const size_t i_dst = i_src0;
461+
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
461462

462463
const src0_t * src0_row = src0 + i_src0;
463464
const src1_t * src1_row = src1 + i_src1;
@@ -475,6 +476,7 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
475476
int ne0, int ne1, int ne2, int ne3,
476477
int ne10, int ne11, int ne12, int ne13,
477478
/*int s0, */ int s1, int s2, int s3,
479+
/*int s00,*/ int s01, int s02, int s03,
478480
/*int s10,*/ int s11, int s12, int s13,
479481
const sycl::nd_item<3> &item_ct1) {
480482

@@ -494,9 +496,9 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t
494496
const int i12 = i2 % ne12;
495497
const int i13 = i3 % ne13;
496498

497-
const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
499+
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
498500
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
499-
const size_t i_dst = i_src0;
501+
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
500502

501503
const src0_t * src0_row = src0 + i_src0;
502504
const src1_t * src1_row = src1 + i_src1;
@@ -526,9 +528,11 @@ struct bin_bcast_sycl {
526528
int nr[4] = { nr0, nr1, nr2, nr3 };
527529

528530
// collapse dimensions until first broadcast dimension
529-
int64_t cne0[] = {ne0, ne1, ne2, ne3};
531+
int64_t cne[] = {ne0, ne1, ne2, ne3};
532+
int64_t cne0[] = {ne00, ne01, ne02, ne03};
530533
int64_t cne1[] = {ne10, ne11, ne12, ne13};
531-
size_t cnb0[] = {nb0, nb1, nb2, nb3};
534+
size_t cnb[] = {nb0, nb1, nb2, nb3};
535+
size_t cnb0[] = {nb00, nb01, nb02, nb03};
532536
size_t cnb1[] = {nb10, nb11, nb12, nb13};
533537
auto collapse = [](int64_t cne[]) {
534538
cne[0] *= cne[1];
@@ -543,32 +547,41 @@ struct bin_bcast_sycl {
543547
cnb[3] *= cne[3];
544548
};
545549

546-
for (int i = 0; i < 4; i++) {
547-
if (nr[i] != 1) {
548-
break;
549-
}
550-
if (i > 0) {
551-
collapse_nb(cnb0, cne0);
552-
collapse_nb(cnb1, cne1);
553-
collapse(cne0);
554-
collapse(cne1);
550+
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
551+
for (int i = 0; i < 4; i++) {
552+
if (nr[i] != 1) {
553+
break;
554+
}
555+
if (i > 0) {
556+
collapse_nb(cnb, cne);
557+
collapse_nb(cnb0, cne0);
558+
collapse_nb(cnb1, cne1);
559+
collapse(cne);
560+
collapse(cne0);
561+
collapse(cne1);
562+
}
555563
}
556564
}
557565
{
558-
int64_t ne0 = cne0[0];
559-
int64_t ne1 = cne0[1];
560-
int64_t ne2 = cne0[2];
561-
int64_t ne3 = cne0[3];
566+
int64_t ne0 = cne[0];
567+
int64_t ne1 = cne[1];
568+
int64_t ne2 = cne[2];
569+
int64_t ne3 = cne[3];
562570

563571
int64_t ne10 = cne1[0];
564572
int64_t ne11 = cne1[1];
565573
int64_t ne12 = cne1[2];
566574
int64_t ne13 = cne1[3];
567575

568-
size_t nb0 = cnb0[0];
569-
size_t nb1 = cnb0[1];
570-
size_t nb2 = cnb0[2];
571-
size_t nb3 = cnb0[3];
576+
size_t nb0 = cnb[0];
577+
size_t nb1 = cnb[1];
578+
size_t nb2 = cnb[2];
579+
size_t nb3 = cnb[3];
580+
581+
size_t nb00 = cnb0[0];
582+
size_t nb01 = cnb0[1];
583+
size_t nb02 = cnb0[2];
584+
size_t nb03 = cnb0[3];
572585

573586
size_t nb10 = cnb1[0];
574587
size_t nb11 = cnb1[1];
@@ -585,6 +598,28 @@ struct bin_bcast_sycl {
585598
size_t s12 = nb12 / sizeof(src1_t);
586599
size_t s13 = nb13 / sizeof(src1_t);
587600

601+
size_t s00 = nb00 / sizeof(src0_t);
602+
size_t s01 = nb01 / sizeof(src0_t);
603+
size_t s02 = nb02 / sizeof(src0_t);
604+
size_t s03 = nb03 / sizeof(src0_t);
605+
606+
GGML_UNUSED(s00);
607+
608+
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
609+
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
610+
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
611+
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
612+
613+
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
614+
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
615+
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
616+
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
617+
618+
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
619+
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
620+
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
621+
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
622+
588623
GGML_ASSERT(s0 == 1);
589624
GGML_ASSERT(s10 == 1);
590625

@@ -621,8 +656,8 @@ struct bin_bcast_sycl {
621656
[=](sycl::nd_item<3> item_ct1) {
622657
k_bin_bcast_unravel<bin_op>(
623658
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
624-
ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
625-
s13, item_ct1);
659+
ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
660+
s03, s11, s12, s13, item_ct1);
626661
});
627662
}
628663
} else {
@@ -640,7 +675,7 @@ struct bin_bcast_sycl {
640675
[=](sycl::nd_item<3> item_ct1) {
641676
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
642677
ne2, ne3, ne10, ne11, ne12, ne13,
643-
s1, s2, s3, s11, s12, s13,
678+
s1, s2, s3, s01, s02, s03, s11, s12, s13,
644679
item_ct1);
645680
});
646681
}

0 commit comments

Comments
 (0)