Skip to content

Commit 993f0df

Browse files
smeso authored and ggerganov committed
ggml : support forward pass broadcasting in ggml_sub (ggml/914)
* ggml: support forward pass broadcasting in ggml_sub

  Signed-off-by: Salvatore Mesoraca <s.mesoraca16@gmail.com>

* Use assert instead of GGML_ASSERT in ggml_compute_forward_sub_f32

  The check is already performed in ggml_sub_impl

  Signed-off-by: Salvatore Mesoraca <s.mesoraca16@gmail.com>

---------

Signed-off-by: Salvatore Mesoraca <s.mesoraca16@gmail.com>
1 parent 9b17884 commit 993f0df

File tree

1 file changed

+46
-30
lines changed

1 file changed

+46
-30
lines changed

ggml/src/ggml.c

Lines changed: 46 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -4661,11 +4661,13 @@ static struct ggml_tensor * ggml_sub_impl(
46614661
struct ggml_tensor * a,
46624662
struct ggml_tensor * b,
46634663
bool inplace) {
4664-
GGML_ASSERT(ggml_are_same_shape(a, b));
4664+
GGML_ASSERT(ggml_can_repeat(b, a));
46654665

46664666
bool is_node = false;
46674667

46684668
if (!inplace && (a->grad || b->grad)) {
4669+
// TODO: support backward pass for broadcasting
4670+
GGML_ASSERT(ggml_are_same_shape(a, b));
46694671
is_node = true;
46704672
}
46714673

@@ -10104,11 +10106,10 @@ static void ggml_compute_forward_sub_f32(
1010410106
const struct ggml_tensor * src0 = dst->src[0];
1010510107
const struct ggml_tensor * src1 = dst->src[1];
1010610108

10107-
if (params->ith != 0) {
10108-
return;
10109-
}
10109+
assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
1011010110

10111-
assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
10111+
const int ith = params->ith;
10112+
const int nth = params->nth;
1011210113

1011310114
const int nr = ggml_nrows(src0);
1011410115

@@ -10117,40 +10118,55 @@ static void ggml_compute_forward_sub_f32(
1011710118
GGML_ASSERT( nb0 == sizeof(float));
1011810119
GGML_ASSERT(nb00 == sizeof(float));
1011910120

10121+
// rows per thread
10122+
const int dr = (nr + nth - 1)/nth;
10123+
10124+
// row range for this thread
10125+
const int ir0 = dr*ith;
10126+
const int ir1 = MIN(ir0 + dr, nr);
10127+
1012010128
if (nb10 == sizeof(float)) {
10121-
for (int ir = 0; ir < nr; ++ir) {
10122-
// src0, src1 and dst are same shape => same indices
10123-
const int i3 = ir/(ne2*ne1);
10124-
const int i2 = (ir - i3*ne2*ne1)/ne1;
10125-
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
10129+
for (int ir = ir0; ir < ir1; ++ir) {
10130+
// src1 is broadcastable across src0 and dst in i1, i2, i3
10131+
const int64_t i03 = ir/(ne02*ne01);
10132+
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
10133+
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
1012610134

10135+
const int64_t i13 = i03 % ne13;
10136+
const int64_t i12 = i02 % ne12;
10137+
const int64_t i11 = i01 % ne11;
10138+
const int64_t nr0 = ne00 / ne10;
10139+
10140+
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
10141+
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
10142+
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
10143+
10144+
for (int64_t r = 0; r < nr0; ++r) {
1012710145
#ifdef GGML_USE_ACCELERATE
10128-
vDSP_vsub(
10129-
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
10130-
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
10131-
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
10132-
ne0);
10146+
vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
1013310147
#else
10134-
ggml_vec_sub_f32(ne0,
10135-
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
10136-
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
10137-
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
10148+
ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
1013810149
#endif
10139-
// }
10140-
// }
10150+
}
1014110151
}
1014210152
} else {
1014310153
// src1 is not contiguous
10144-
for (int ir = 0; ir < nr; ++ir) {
10145-
// src0, src1 and dst are same shape => same indices
10146-
const int i3 = ir/(ne2*ne1);
10147-
const int i2 = (ir - i3*ne2*ne1)/ne1;
10148-
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
10154+
for (int ir = ir0; ir < ir1; ++ir) {
10155+
// src1 is broadcastable across src0 and dst in i1, i2, i3
10156+
const int64_t i03 = ir/(ne02*ne01);
10157+
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
10158+
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
10159+
10160+
const int64_t i13 = i03 % ne13;
10161+
const int64_t i12 = i02 % ne12;
10162+
const int64_t i11 = i01 % ne11;
10163+
10164+
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
10165+
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
1014910166

10150-
float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
10151-
float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
10152-
for (int i0 = 0; i0 < ne0; i0++) {
10153-
float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
10167+
for (int64_t i0 = 0; i0 < ne0; ++i0) {
10168+
const int64_t i10 = i0 % ne10;
10169+
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
1015410170

1015510171
dst_ptr[i0] = src0_ptr[i0] - *src1_ptr;
1015610172
}

0 commit comments

Comments
 (0)