@@ -3724,7 +3724,8 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         struct ggml_tensor  * view_src,
         size_t                view_offs) {
 
-    assert(n_dims >= 1 && n_dims <= GGML_MAX_DIMS);
+    GGML_ASSERT(type >= 0 && type < GGML_TYPE_COUNT);
+    GGML_ASSERT(n_dims >= 1 && n_dims <= GGML_MAX_DIMS);
 
     // find the base tensor and absolute offset
     if (view_src != NULL && view_src->view_src != NULL) {
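Note: the new `GGML_ASSERT(type >= 0 && type < GGML_TYPE_COUNT)` fails fast before `type` is used to index ggml's per-type lookup tables. A minimal standalone sketch of the failure mode this guards against (the enum and size table here are hypothetical stand-ins, not ggml code):

```c
#include <assert.h>
#include <stddef.h>
#include <stdio.h>

/* hypothetical stand-ins for ggml's type enum and per-type size table */
enum sketch_type { SKETCH_F32, SKETCH_F16, SKETCH_Q4_0, SKETCH_TYPE_COUNT };
static const size_t sketch_type_size[SKETCH_TYPE_COUNT] = { 4, 2, 18 };

int main(void) {
    int raw = 57; /* e.g. a corrupted type id deserialized from a model file */
    /* same shape of check as the new assert: reject the value before
       sketch_type_size[raw] becomes an out-of-bounds read */
    assert(raw >= 0 && raw < SKETCH_TYPE_COUNT);
    printf("element size: %zu\n", sketch_type_size[raw]);
    return 0;
}
```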
@@ -4660,11 +4661,13 @@ static struct ggml_tensor * ggml_sub_impl(
         struct ggml_tensor * a,
         struct ggml_tensor * b,
         bool inplace) {
-    GGML_ASSERT(ggml_are_same_shape(a, b));
+    GGML_ASSERT(ggml_can_repeat(b, a));
 
     bool is_node = false;
 
     if (!inplace && (a->grad || b->grad)) {
+        // TODO: support backward pass for broadcasting
+        GGML_ASSERT(ggml_are_same_shape(a, b));
         is_node = true;
     }
 
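Note: relaxing `ggml_are_same_shape(a, b)` to `ggml_can_repeat(b, a)` lets the forward pass subtract a `b` that merely tiles into `a`'s shape; the stricter same-shape assert now applies only when a backward pass is requested, per the TODO. A simplified sketch of the broadcast rule (plain `int64_t` dims instead of `ggml_tensor`, and ignoring ggml's empty-tensor special case):

```c
#include <stdbool.h>
#include <stdint.h>

/* b can be repeated to fill a iff every dimension of a is a whole
   multiple of the matching dimension of b */
static bool can_repeat_sketch(const int64_t b_ne[4], const int64_t a_ne[4]) {
    for (int i = 0; i < 4; ++i) {
        if (a_ne[i] % b_ne[i] != 0) {
            return false;
        }
    }
    return true;
}
```

For example, a per-channel bias of shape `[n, 1, 1, 1]` can now be subtracted from a tensor of shape `[n, rows, 1, batch]` without first materializing a repeated copy via `ggml_repeat`.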
@@ -10103,11 +10106,10 @@ static void ggml_compute_forward_sub_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
 
-    if (params->ith != 0) {
-        return;
-    }
+    assert(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
 
-    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+    const int ith = params->ith;
+    const int nth = params->nth;
 
     const int nr = ggml_nrows(src0);
 
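Note: this hunk also changes the kernel from effectively single-threaded (every thread except `ith == 0` returned immediately) to splitting the rows across all `nth` threads. A standalone sketch of the same partitioning arithmetic used in the next hunk, on example values:

```c
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int nr  = 10, nth = 4;            /* e.g. 10 rows, 4 threads */
    const int dr  = (nr + nth - 1)/nth;     /* ceil(nr/nth) rows per thread */
    for (int ith = 0; ith < nth; ++ith) {
        const int ir0 = dr*ith;             /* first row for this thread */
        const int ir1 = MIN(ir0 + dr, nr);  /* one past the last row */
        printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
    }
    return 0;
}
```

The last thread's range is clamped to `nr`, so an uneven division leaves it with a shorter slice rather than reading past the end.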
@@ -10116,40 +10118,55 @@ static void ggml_compute_forward_sub_f32(
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
 
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
     if (nb10 == sizeof(float)) {
-        for (int ir = 0; ir < nr; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
 
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+            const int64_t nr0 = ne00 / ne10;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+            float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+
+            for (int64_t r = 0; r < nr0; ++r) {
 #ifdef GGML_USE_ACCELERATE
-            vDSP_vsub(
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
-                    ne0);
+                vDSP_vsub(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10);
 #else
-            ggml_vec_sub_f32(ne0,
-                    (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
-                    (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
-                    (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+                ggml_vec_sub_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr);
 #endif
-                // }
-                // }
+            }
         }
     } else {
         // src1 is not contiguous
-        for (int ir = 0; ir < nr; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+        for (int ir = ir0; ir < ir1; ++ir) {
+            // src1 is broadcastable across src0 and dst in i1, i2, i3
+            const int64_t i03 = ir/(ne02*ne01);
+            const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+            const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+            const int64_t i13 = i03 % ne13;
+            const int64_t i12 = i02 % ne12;
+            const int64_t i11 = i01 % ne11;
+
+            float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+            float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
 
-            float * dst_ptr  = (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-            float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-            for (int i0 = 0; i0 < ne0; i0++) {
-                float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+            for (int64_t i0 = 0; i0 < ne0; ++i0) {
+                const int64_t i10 = i0 % ne10;
+                float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10);
 
                 dst_ptr[i0] = src0_ptr[i0] - *src1_ptr;
             }
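Note: in the contiguous branch, each `src0` row of length `ne00` is treated as `nr0 = ne00/ne10` back-to-back copies of a `src1` row, while the outer indices wrap with `%` (`i11 = i01 % ne11`, and so on). A standalone sketch of that inner-repeat arithmetic on concrete numbers:

```c
#include <stdio.h>

int main(void) {
    enum { ne00 = 6, ne10 = 3 };
    const int nr0 = ne00 / ne10;        /* src1 row repeats twice per src0 row */

    const float src0[ne00] = { 10, 20, 30, 40, 50, 60 };
    const float src1[ne10] = {  1,  2,  3 };
    float       dst [ne00];

    /* subtract the src1 row from each of the nr0 segments of the src0 row */
    for (int r = 0; r < nr0; ++r) {
        for (int i = 0; i < ne10; ++i) {
            dst[r*ne10 + i] = src0[r*ne10 + i] - src1[i];
        }
    }
    for (int i = 0; i < ne00; ++i) {
        printf("%g ", dst[i]);          /* prints: 9 18 27 39 48 57 */
    }
    printf("\n");
    return 0;
}
```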