@@ -32,8 +32,9 @@ TORCHAO_ALWAYS_INLINE static void block_mul_1x8x16(
     const size_t ldb,
     int32x4_t (&partial_sums)[8],
     int32_t& row_sum_a,
-    int32_t (&row_sum_b)[8]) {
+    int32x4_t (&row_sum_b)[8]) {
   int8x16_t a_vec = vld1q_s8(a);
+  int8x16_t ones = vdupq_n_s8(1);
   row_sum_a = row_sum_a + vaddlvq_s8(a_vec);
 
   // godbolt (https://godbolt.org/z/9vbq1d1qY) shows this loops doesnt quantize
@@ -42,8 +43,9 @@ TORCHAO_ALWAYS_INLINE static void block_mul_1x8x16(
   // deconstruct the loop and do manual optimization. Or just write assembly.
 #pragma unroll(8)
   for (int i = 0; i < 8; ++i) {
-    int8x16_t b_vec = vld1q_s8(b + i * ldb);
-    row_sum_b[i] = row_sum_b[i] + vaddlvq_s8(b_vec);
+    int8x16_t b_vec = vld1q_s8(b);
+    b += ldb;
+    row_sum_b[i] = vdotq_s32(row_sum_b[i], b_vec, ones);
     partial_sums[i] = vdotq_s32(partial_sums[i], a_vec, b_vec);
   }
 }
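
The key change in this function: instead of reducing each b_vec to a scalar with vaddlvq_s8 on every iteration, the row sum is now accumulated with vdotq_s32 against a vector of ones, which keeps four partial sums per row in an int32x4_t and defers the cross-lane reduction to after the k loop. A minimal standalone sketch of that trick (illustrative only, not part of the patch; assumes an aarch64 toolchain with the dotprod extension, e.g. -march=armv8.2-a+dotprod):

#include <arm_neon.h>
#include <cstdint>
#include <cstdio>

// Illustrative only (not from the patch): sum 16 int8 lanes two ways.
int main() {
  int8_t data[16];
  for (int i = 0; i < 16; ++i) {
    data[i] = (int8_t)(i - 8);
  }
  int8x16_t v = vld1q_s8(data);

  // Old path: widening horizontal add across all 16 lanes, once per load.
  int32_t sum_addlv = (int32_t)vaddlvq_s8(v);

  // New path: sdot against a vector of ones. Each of the four int32 lanes
  // accumulates the sum of one 4-lane group of int8 values, so the
  // expensive cross-lane reduction can be hoisted out of the hot loop.
  int32x4_t acc = vdupq_n_s32(0);
  acc = vdotq_s32(acc, v, vdupq_n_s8(1));
  int32_t sum_dot = vaddvq_s32(acc);  // reduce once, at the end

  printf("%d %d\n", (int)sum_addlv, (int)sum_dot);  // both print -8
  return 0;
}

Both paths yield the same scalar; the dot-product form stays in vector registers inside the loop and pays for a horizontal reduction once per tile instead of once per iteration.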
@@ -234,8 +236,9 @@ struct KernelImpl<true, true, false, true> {
       const int8_t* rhs_ptr = (const int8_t*)rhs + n_idx * rhs_stride_n;
       int32x4_t int32_sums[nr] = {vdupq_n_s32(0)};
       int32_t row_sum_lhs = 0;
-      int32_t row_sum_rhs[nr] = {0, 0, 0, 0, 0, 0, 0, 0};
+      int32x4_t row_sum_rhs_vec[nr] = {vdupq_n_s32(0)};
       int32_t sums[nr];
+      int32_t row_sum_rhs[nr];
 
       // Loop k_idx by group
       int k_idx = 0;
@@ -246,12 +249,13 @@ struct KernelImpl<true, true, false, true> {
             rhs_stride_n,
             int32_sums,
             row_sum_lhs,
-            row_sum_rhs);
+            row_sum_rhs_vec);
         lhs_ptr += kr;
         rhs_ptr += kr;
       }
 
       reduce_1x8_int32x4_t_sums(int32_sums, sums);
+      reduce_1x8_int32x4_t_sums(row_sum_rhs_vec, row_sum_rhs);
       for (int ki = 0; ki < (k - k_idx); ++ki) {
         row_sum_lhs += (int32_t)lhs_ptr[ki];
       }
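
Folding the new int32x4_t rhs accumulators back to scalars reuses the existing reduce_1x8_int32x4_t_sums helper, also applied to int32_sums above; its body sits outside this diff. A helper of that shape would plausibly look like the sketch below (an assumption about its implementation, not the actual torchao code; the _sketch suffix marks it as hypothetical):

#include <arm_neon.h>
#include <cstdint>

// Sketch of a reduce_1x8_int32x4_t_sums-shaped helper (the real body is
// elsewhere in the file and may differ): fold each of the eight
// int32x4_t accumulators into one scalar per output slot.
static void reduce_1x8_int32x4_t_sums_sketch(
    const int32x4_t (&vec_sums)[8],
    int32_t (&sums)[8]) {
  for (int i = 0; i < 8; ++i) {
    sums[i] = vaddvq_s32(vec_sums[i]);  // horizontal add of 4 lanes
  }
}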