@@ -4661,11 +4661,13 @@ static struct ggml_tensor * ggml_sub_impl(
4661
4661
struct ggml_tensor * a,
4662
4662
struct ggml_tensor * b,
4663
4663
bool inplace) {
4664
- GGML_ASSERT(ggml_are_same_shape(a, b ));
4664
+ GGML_ASSERT(ggml_can_repeat(b, a ));
4665
4665
4666
4666
bool is_node = false;
4667
4667
4668
4668
if (!inplace && (a->grad || b->grad)) {
4669
+ // TODO: support backward pass for broadcasting
4670
+ GGML_ASSERT(ggml_are_same_shape(a, b));
4669
4671
is_node = true;
4670
4672
}
4671
4673
@@ -10104,11 +10106,10 @@ static void ggml_compute_forward_sub_f32(
10104
10106
const struct ggml_tensor * src0 = dst->src[0];
10105
10107
const struct ggml_tensor * src1 = dst->src[1];
10106
10108
10107
- if (params->ith != 0) {
10108
- return;
10109
- }
10109
+ assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
10110
10110
10111
- assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
10111
+ const int ith = params->ith;
10112
+ const int nth = params->nth;
10112
10113
10113
10114
const int nr = ggml_nrows(src0);
10114
10115
@@ -10117,40 +10118,55 @@ static void ggml_compute_forward_sub_f32(
10117
10118
GGML_ASSERT( nb0 == sizeof(float));
10118
10119
GGML_ASSERT(nb00 == sizeof(float));
10119
10120
10121
+ // rows per thread
10122
+ const int dr = (nr + nth - 1)/nth;
10123
+
10124
+ // row range for this thread
10125
+ const int ir0 = dr*ith;
10126
+ const int ir1 = MIN(ir0 + dr, nr);
10127
+
10120
10128
if (nb10 == sizeof(float)) {
10121
- for (int ir = 0 ; ir < nr ; ++ir) {
10122
- // src0, src1 and dst are same shape => same indices
10123
- const int i3 = ir/(ne2*ne1 );
10124
- const int i2 = (ir - i3*ne2*ne1)/ne1 ;
10125
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1 );
10129
+ for (int ir = ir0 ; ir < ir1 ; ++ir) {
10130
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
10131
+ const int64_t i03 = ir/(ne02*ne01 );
10132
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01 ;
10133
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01 );
10126
10134
10135
+ const int64_t i13 = i03 % ne13;
10136
+ const int64_t i12 = i02 % ne12;
10137
+ const int64_t i11 = i01 % ne11;
10138
+ const int64_t nr0 = ne00 / ne10;
10139
+
10140
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
10141
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
10142
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
10143
+
10144
+ for (int64_t r = 0; r < nr0; ++r) {
10127
10145
#ifdef GGML_USE_ACCELERATE
10128
- vDSP_vsub(
10129
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
10130
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
10131
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
10132
- ne0);
10146
+ vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
10133
10147
#else
10134
- ggml_vec_sub_f32(ne0,
10135
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
10136
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
10137
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
10148
+ ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
10138
10149
#endif
10139
- // }
10140
- // }
10150
+ }
10141
10151
}
10142
10152
} else {
10143
10153
// src1 is not contiguous
10144
- for (int ir = 0; ir < nr; ++ir) {
10145
- // src0, src1 and dst are same shape => same indices
10146
- const int i3 = ir/(ne2*ne1);
10147
- const int i2 = (ir - i3*ne2*ne1)/ne1;
10148
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
10154
+ for (int ir = ir0; ir < ir1; ++ir) {
10155
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
10156
+ const int64_t i03 = ir/(ne02*ne01);
10157
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
10158
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
10159
+
10160
+ const int64_t i13 = i03 % ne13;
10161
+ const int64_t i12 = i02 % ne12;
10162
+ const int64_t i11 = i01 % ne11;
10163
+
10164
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
10165
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
10149
10166
10150
- float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
10151
- float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
10152
- for (int i0 = 0; i0 < ne0; i0++) {
10153
- float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
10167
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
10168
+ const int64_t i10 = i0 % ne10;
10169
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
10154
10170
10155
10171
dst_ptr[i0] = src0_ptr[i0] - *src1_ptr;
10156
10172
}
0 commit comments