Skip to content

Commit a3b857f

Browse files
authored
vectorized row sum
Differential Revision: D71833069 Pull Request resolved: #2034
1 parent 625a76e commit a3b857f

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@ TORCHAO_ALWAYS_INLINE static void block_mul_1x8x16(
3232
const size_t ldb,
3333
int32x4_t (&partial_sums)[8],
3434
int32_t& row_sum_a,
35-
int32_t (&row_sum_b)[8]) {
35+
int32x4_t (&row_sum_b)[8]) {
3636
int8x16_t a_vec = vld1q_s8(a);
37+
int8x16_t ones = vdupq_n_s8(1);
3738
row_sum_a = row_sum_a + vaddlvq_s8(a_vec);
3839

3940
// godbolt (https://godbolt.org/z/9vbq1d1qY) shows this loops doesnt quantize
@@ -42,8 +43,9 @@ TORCHAO_ALWAYS_INLINE static void block_mul_1x8x16(
4243
// deconstruct the loop and do manual optimization. Or just write assembly.
4344
#pragma unroll(8)
4445
for (int i = 0; i < 8; ++i) {
45-
int8x16_t b_vec = vld1q_s8(b + i * ldb);
46-
row_sum_b[i] = row_sum_b[i] + vaddlvq_s8(b_vec);
46+
int8x16_t b_vec = vld1q_s8(b);
47+
b += ldb;
48+
row_sum_b[i] = vdotq_s32(row_sum_b[i], b_vec, ones);
4749
partial_sums[i] = vdotq_s32(partial_sums[i], a_vec, b_vec);
4850
}
4951
}
@@ -234,8 +236,9 @@ struct KernelImpl<true, true, false, true> {
234236
const int8_t* rhs_ptr = (const int8_t*)rhs + n_idx * rhs_stride_n;
235237
int32x4_t int32_sums[nr] = {vdupq_n_s32(0)};
236238
int32_t row_sum_lhs = 0;
237-
int32_t row_sum_rhs[nr] = {0, 0, 0, 0, 0, 0, 0, 0};
239+
int32x4_t row_sum_rhs_vec[nr] = {vdupq_n_s32(0)};
238240
int32_t sums[nr];
241+
int32_t row_sum_rhs[nr];
239242

240243
// Loop k_idx by group
241244
int k_idx = 0;
@@ -246,12 +249,13 @@ struct KernelImpl<true, true, false, true> {
246249
rhs_stride_n,
247250
int32_sums,
248251
row_sum_lhs,
249-
row_sum_rhs);
252+
row_sum_rhs_vec);
250253
lhs_ptr += kr;
251254
rhs_ptr += kr;
252255
}
253256

254257
reduce_1x8_int32x4_t_sums(int32_sums, sums);
258+
reduce_1x8_int32x4_t_sums(row_sum_rhs_vec, row_sum_rhs);
255259
for (int ki = 0; ki < (k - k_idx); ++ki) {
256260
row_sum_lhs += (int32_t)lhs_ptr[ki];
257261
}

0 commit comments

Comments
 (0)