Commit 5cb1fa1

Add gemm for fp32_a_int8_b matmul kernel
Differential Revision: D71833070
Pull Request resolved: #2039
1 parent a3b857f commit 5cb1fa1

3 files changed: +429 -6 lines

torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_4x16x4_f32_impl.h

Lines changed: 355 additions & 0 deletions
@@ -0,0 +1,355 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#if defined(__aarch64__) && defined(__ARM_NEON)

#include <algorithm>
#include <cassert>
#include <cstring>

#include <arm_neon.h>
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>
#include <torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h>

namespace torchao::kernels::cpu::aarch64::quantized_matmul {
namespace fp32_a_input_channelwise_8bit_b_4x16x4_f32::internal {

namespace {

/*
This computes a 4x16x4 block: a 4x4 tile of a (loaded as 4 float32x4_t values
via a transpose) is multiplied against 4 int8x16_t values of b, one per k step.
For each int8x16_t of b (each k step):
- 4x4 float32x4_t values are accumulated (4 rows x 16 output columns)
- load the 4 a values for this k step (one per row) into a float32x4_t
- load b[i] into an int8x16_t
- vsubl to subtract b_zero_point from b, giving b_low, b_high
- vmovl to widen into b_low_low, b_low_high, b_high_low, b_high_high
- vcvtq to convert to float32x4_t; we will have 4 of these, each scaled by
  b_scale
- for each of the 4 float32x4_t of b:
  - vfmaq_n_f32 to multiply it by a[lane] for lane = 0..3 (one per row) and
    accumulate into the partial sums
By doing the above 4 times (one per k step), we use all values of a along the
k dim and accumulate into 4x4 float32x4_t values.
*/
TORCHAO_ALWAYS_INLINE void block_mul_4x16x1(
    const float32x4_t& a,
    const int8x16_t& b_vec,
    const int8_t b_zero_point,
    const float b_scale,
    float32x4_t (&partial_sums)[4][4]) {
  int8x8_t b_zero_point_vec = vdup_n_s8(b_zero_point);
  int16x8_t b_vec_low = vsubl_s8(vget_low_s8(b_vec), b_zero_point_vec);
  int16x8_t b_vec_high = vsubl_s8(vget_high_s8(b_vec), b_zero_point_vec);
  float32x4_t b_vec_low_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_vec_low)));
  float32x4_t b_vec_low_high =
      vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_vec_low)));
  float32x4_t b_vec_high_low =
      vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_vec_high)));
  float32x4_t b_vec_high_high =
      vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_vec_high)));
  b_vec_low_low = vmulq_n_f32(b_vec_low_low, b_scale);
  b_vec_low_high = vmulq_n_f32(b_vec_low_high, b_scale);
  b_vec_high_low = vmulq_n_f32(b_vec_high_low, b_scale);
  b_vec_high_high = vmulq_n_f32(b_vec_high_high, b_scale);

  partial_sums[0][0] = vfmaq_n_f32(partial_sums[0][0], b_vec_low_low, a[0]);
  partial_sums[0][1] = vfmaq_n_f32(partial_sums[0][1], b_vec_low_high, a[0]);
  partial_sums[0][2] = vfmaq_n_f32(partial_sums[0][2], b_vec_high_low, a[0]);
  partial_sums[0][3] = vfmaq_n_f32(partial_sums[0][3], b_vec_high_high, a[0]);

  partial_sums[1][0] = vfmaq_n_f32(partial_sums[1][0], b_vec_low_low, a[1]);
  partial_sums[1][1] = vfmaq_n_f32(partial_sums[1][1], b_vec_low_high, a[1]);
  partial_sums[1][2] = vfmaq_n_f32(partial_sums[1][2], b_vec_high_low, a[1]);
  partial_sums[1][3] = vfmaq_n_f32(partial_sums[1][3], b_vec_high_high, a[1]);

  partial_sums[2][0] = vfmaq_n_f32(partial_sums[2][0], b_vec_low_low, a[2]);
  partial_sums[2][1] = vfmaq_n_f32(partial_sums[2][1], b_vec_low_high, a[2]);
  partial_sums[2][2] = vfmaq_n_f32(partial_sums[2][2], b_vec_high_low, a[2]);
  partial_sums[2][3] = vfmaq_n_f32(partial_sums[2][3], b_vec_high_high, a[2]);

  partial_sums[3][0] = vfmaq_n_f32(partial_sums[3][0], b_vec_low_low, a[3]);
  partial_sums[3][1] = vfmaq_n_f32(partial_sums[3][1], b_vec_low_high, a[3]);
  partial_sums[3][2] = vfmaq_n_f32(partial_sums[3][2], b_vec_high_low, a[3]);
  partial_sums[3][3] = vfmaq_n_f32(partial_sums[3][3], b_vec_high_high, a[3]);
}
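For reference, each block_mul_4x16x1 call is arithmetically equivalent to the scalar sketch below (illustrative only, not part of the header; column c of the 4x16 tile maps to lane c % 4 of partial_sums[row][c / 4]):

inline void block_mul_4x16x1_scalar_sketch(
    const float a[4],            // one k slice of the 4 lhs rows
    const int8_t b[16],          // 16 rhs columns at that k
    int8_t b_zero_point,
    float b_scale,
    float partial_sums[4][16]) { // 4 rows x 16 columns of accumulators
  for (int row = 0; row < 4; ++row) {
    for (int col = 0; col < 16; ++col) {
      // Dequantize one rhs element and accumulate into the output tile.
      float b_val = static_cast<float>(b[col] - b_zero_point) * b_scale;
      partial_sums[row][col] += a[row] * b_val;
    }
  }
}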

TORCHAO_ALWAYS_INLINE void transpose_4x4(
    const float32_t* a,
    const size_t lda,
    float32x4_t (&transposed)[4]) {
  float32x4_t a_vec_0 = vld1q_f32(a + 0 * lda);
  float32x4_t a_vec_1 = vld1q_f32(a + lda);
  float32x4_t a_vec_2 = vld1q_f32(a + 2 * lda);
  float32x4_t a_vec_3 = vld1q_f32(a + 3 * lda);
  // Transpose the 4x4 matrix formed by a_vec_0, a_vec_1, a_vec_2, a_vec_3
  float32x4x2_t a01 = vtrnq_f32(a_vec_0, a_vec_1);
  float32x4x2_t a23 = vtrnq_f32(a_vec_2, a_vec_3);

  float32x4_t a_vec_0_t =
      vcombine_f32(vget_low_f32(a01.val[0]), vget_low_f32(a23.val[0]));
  float32x4_t a_vec_1_t =
      vcombine_f32(vget_low_f32(a01.val[1]), vget_low_f32(a23.val[1]));
  float32x4_t a_vec_2_t =
      vcombine_f32(vget_high_f32(a01.val[0]), vget_high_f32(a23.val[0]));
  float32x4_t a_vec_3_t =
      vcombine_f32(vget_high_f32(a01.val[1]), vget_high_f32(a23.val[1]));

  transposed[0] = a_vec_0_t;
  transposed[1] = a_vec_1_t;
  transposed[2] = a_vec_2_t;
  transposed[3] = a_vec_3_t;
}
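In scalar terms, transpose_4x4 regroups a row-major 4x4 tile of a into k-major order; a minimal sketch, assuming lda is the element stride between rows:

// transposed[j][i] == a[i * lda + j]   for i, j in [0, 4)
inline void transpose_4x4_scalar_sketch(
    const float* a, size_t lda, float transposed[4][4]) {
  for (int j = 0; j < 4; ++j) {
    for (int i = 0; i < 4; ++i) {
      transposed[j][i] = a[i * lda + j];
    }
  }
}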

TORCHAO_ALWAYS_INLINE void block_mul_4x16x4(
    const float32_t* a,
    const size_t lda,
    const int8_t* b,
    const size_t ldb,
    const int8_t* b_zero_point,
    const float* b_scale,
    float32x4_t (&partial_sums)[4][4]) {
  float32x4_t a_vec[4];
  transpose_4x4(a, lda, a_vec);

  int8x16_t b_vec = vld1q_s8(b + 0 * ldb);
  block_mul_4x16x1(a_vec[0], b_vec, b_zero_point[0], b_scale[0], partial_sums);
  b_vec = vld1q_s8(b + 1 * ldb);
  block_mul_4x16x1(a_vec[1], b_vec, b_zero_point[1], b_scale[1], partial_sums);
  b_vec = vld1q_s8(b + 2 * ldb);
  block_mul_4x16x1(a_vec[2], b_vec, b_zero_point[2], b_scale[2], partial_sums);
  b_vec = vld1q_s8(b + 3 * ldb);
  block_mul_4x16x1(a_vec[3], b_vec, b_zero_point[3], b_scale[3], partial_sums);
}

} // namespace

template <bool b_has_zeros, bool a_transposed, bool b_transposed>
struct KernelImpl {
  static void run(
      int m,
      int n,
      int k,
      const void* lhs,
      int lhs_stride_m,
      const void* rhs,
      int rhs_stride_n,
      float32_t* output,
      int out_stride_m,
      const int8_t* rhs_zero_points,
      const float* rhs_scales,
      const float beta,
      const int rhs_qparams_stride);
};

/*
Parameter notes:
rhs_stride_n: Since b_transposed == false, the expected shape of rhs is k x n.
Thus rhs_stride_n is the stride of the k dim, that is, how many bytes apart
consecutive elements along the k dim are.
*/
template <>
struct KernelImpl<true, false, false> {
  static void run(
      int m,
      int n,
      int k,
      const float* lhs,
      int lhs_stride_m,
      const int8_t* rhs,
      int rhs_stride_n,
      float32_t* output,
      int out_stride_m,
      const int8_t* rhs_zero_points,
      const float* rhs_scales,
      const float beta,
      const int rhs_qparams_stride) {
    std::vector<int8_t> rhs_zero_points_transposed;
    std::vector<float> rhs_scales_transposed;
    if (rhs_qparams_stride > 1) {
      rhs_zero_points_transposed.resize(k);
      rhs_scales_transposed.resize(k);
      utils::transpose_scales_and_zero_points(
          rhs_zero_points,
          rhs_scales,
          rhs_zero_points_transposed.data(),
          rhs_scales_transposed.data(),
          k,
          rhs_qparams_stride);
      rhs_zero_points = rhs_zero_points_transposed.data();
      rhs_scales = rhs_scales_transposed.data();
    }

    constexpr int mr = 4;
    constexpr int nr = 16;
    constexpr int kr = 4;
    assert(m % mr == 0);
    assert(kr == 4);
    assert(n >= nr);
    for (int m_idx = 0; m_idx < m; m_idx += mr) {
      const float* lhs_ptr = lhs + m_idx * lhs_stride_m;
      // Loop over 16 cols at a time.
      // Access to partial tiles must be protected; see the n_idx rewind below.
      for (int n_idx = 0; n_idx < n; n_idx += nr) {
        // If remaining is < nr, that must mean that (nr - remaining) items
        // don't need to be computed.
        // In order to avoid out-of-bounds access, we need to rewind n_idx a
        // bit.
        // |-------------------|-------------------|
        // 0-------------------8-------------------16
        // 0-------------------8-----10
        // If n = 10 and nr = 8, then at n_idx = 8 we have
        // remaining = min(n - n_idx, nr) = 2, so we rewind n_idx to
        // 8 - (8 - 2) = 2.
        int remaining = std::min(n - n_idx, nr);
        n_idx = n_idx - (nr - remaining);
        // Set rhs_ptr to the start of the current 16-column block of rhs
        const int8_t* rhs_ptr = rhs + n_idx;
        float32x4_t sums[mr][(nr / 4)] = {{vdupq_n_f32(0)}};

        // Loop over k_idx in groups of kr
        int k_idx = 0;
        const float* current_lhs_ptr = lhs_ptr;
        for (; (k_idx + kr) <= k; k_idx += kr) {
          block_mul_4x16x4(
              current_lhs_ptr,
              lhs_stride_m,
              rhs_ptr,
              rhs_stride_n,
              rhs_zero_points + k_idx,
              rhs_scales + k_idx,
              sums);
          current_lhs_ptr += kr;
          rhs_ptr += kr * rhs_stride_n;
        }

        for (int ki = 0; ki < (k - k_idx); ++ki) {
          // For each of the remaining k values:
          // load 4 floats from lhs (one per row of the tile),
          // load 16 int8 values from rhs,
          // and multiply + add into the 16 accumulators
          // arranged as float32x4_t[4][4].
          int8x16_t rhs_vec = vld1q_s8(rhs_ptr + ki * rhs_stride_n);
          float32x4_t lhs_vec = {
              current_lhs_ptr[ki + 0 * lhs_stride_m],
              current_lhs_ptr[ki + 1 * lhs_stride_m],
              current_lhs_ptr[ki + 2 * lhs_stride_m],
              current_lhs_ptr[ki + 3 * lhs_stride_m]};
          block_mul_4x16x1(
              lhs_vec,
              rhs_vec,
              rhs_zero_points[k_idx + ki],
              rhs_scales[k_idx + ki],
              sums);
        }

        // Store result.
        // Because we adjust n_idx, we may end up writing the same location
        // twice.
        // Note that this case is handled only for this kernel and not for the
        // others in this directory because only this kernel supports
        // accumulation.
        float* store_loc = output + m_idx * out_stride_m + n_idx;
        if (remaining < 16) {
          // If remaining is < 16, then not all of the 16 accumulators are
          // valid. That is, not all of float32x4_t[4] are valid. We need to
          // find the first valid one, and then store the rest of the
          // accumulators in the same order.
          // The first valid one is at 3 - ((remaining - 1) / 4) because:
          // if remaining is, say, 10, then the first 6 are not valid.
          // Thus the first group of 4 at sums[0] is not valid.
          // In the second group of 4, the first 2 are not valid.
          // The rest are valid.
          int start_sum_idx = 3 - ((remaining - 1) / 4);
          // If remaining is 11, then sums[1] has 3 valid values,
          // so 3 - (11 - 1) % 4 = 3 - 10 % 4 = 3 - 2 = 1.
          // Thus there is 1 invalid value in that first partially valid group
          // of 4.
          int invalid_values_in_32x4_reg = 3 - (remaining - 1) % 4;
          store_loc += start_sum_idx * 4;
          store_loc += invalid_values_in_32x4_reg;
          if (invalid_values_in_32x4_reg > 0) {
            for (int m_out_idx = 0; m_out_idx < mr; m_out_idx++) {
              float* store_loc_local = store_loc + m_out_idx * out_stride_m;
              for (int val_idx = invalid_values_in_32x4_reg; val_idx < 4;
                   ++val_idx) {
                *store_loc_local = sums[m_out_idx][start_sum_idx][val_idx] +
                    (*store_loc_local) * beta;
                store_loc_local += 1;
              }
            }
            start_sum_idx++;
            store_loc += (4 - invalid_values_in_32x4_reg);
          }
          for (int m_out_idx = 0; m_out_idx < mr; m_out_idx++) {
            float* store_loc_local = store_loc + m_out_idx * out_stride_m;
            for (int out_idx = 0, sum_idx = start_sum_idx; sum_idx < nr / 4;
                 out_idx += 4, ++sum_idx) {
              float32x4_t sum_val = vld1q_f32(store_loc_local + out_idx);
              sums[m_out_idx][sum_idx] =
                  vfmaq_n_f32(sums[m_out_idx][sum_idx], sum_val, beta);
              vst1q_f32(store_loc_local + out_idx, sums[m_out_idx][sum_idx]);
            }
          }
        } else {
          for (int m_out_idx = 0; m_out_idx < mr; m_out_idx++) {
            float* store_loc_local = store_loc + m_out_idx * out_stride_m;
            for (int out_idx = 0, sum_idx = 0; out_idx < nr;
                 out_idx += 4, ++sum_idx) {
              float32x4_t sum_val = vld1q_f32(store_loc_local + out_idx);
              sums[m_out_idx][sum_idx] =
                  vfmaq_n_f32(sums[m_out_idx][sum_idx], sum_val, beta);
              vst1q_f32(store_loc_local + out_idx, sums[m_out_idx][sum_idx]);
            }
          }
        }
      } // n_idx
    } // m_idx
  }
};
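Ignoring the NEON tiling and the partial-tile rewind, the specialization above computes the following scalar reference (a sketch assuming rhs_qparams_stride == 1; the function name is hypothetical):

inline void fp32_a_int8_b_gemm_scalar_sketch(
    int m, int n, int k,
    const float* lhs, int lhs_stride_m,
    const int8_t* rhs, int rhs_stride_n,
    float* output, int out_stride_m,
    const int8_t* rhs_zero_points,  // one zero point per k (input channel)
    const float* rhs_scales,        // one scale per k (input channel)
    float beta) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (int kk = 0; kk < k; ++kk) {
        // Dequantize rhs element (kk, j) with its per-input-channel qparams.
        float b_val = static_cast<float>(
                          rhs[kk * rhs_stride_n + j] - rhs_zero_points[kk]) *
            rhs_scales[kk];
        acc += lhs[i * lhs_stride_m + kk] * b_val;
      }
      // Accumulate beta * existing output, as the NEON store path does.
      output[i * out_stride_m + j] = acc + beta * output[i * out_stride_m + j];
    }
  }
}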

} // namespace fp32_a_input_channelwise_8bit_b_4x16x4_f32::internal

namespace fp32_a_input_channelwise_8bit_b_4x16x4_f32 {
template <bool b_has_zeros, bool a_transposed, bool b_transposed>
void kernel(
    int m,
    int n,
    int k,
    const float* lhs,
    int lhs_stride_m,
    const int8_t* rhs,
    int rhs_stride_n,
    float32_t* output,
    int out_stride_m,
    const int8_t* rhs_zero_points,
    const float* rhs_scales,
    const float beta,
    const int rhs_qparams_stride) {
  torchao::kernels::cpu::aarch64::quantized_matmul::
      fp32_a_input_channelwise_8bit_b_4x16x4_f32::internal::
          KernelImpl<b_has_zeros, a_transposed, b_transposed>::run(
              m,
              n,
              k,
              lhs,
              lhs_stride_m,
              rhs,
              rhs_stride_n,
              output,
              out_stride_m,
              rhs_zero_points,
              rhs_scales,
              beta,
              rhs_qparams_stride);
}
} // namespace fp32_a_input_channelwise_8bit_b_4x16x4_f32
} // namespace torchao::kernels::cpu::aarch64::quantized_matmul

#endif // defined(__aarch64__) && defined(__ARM_NEON)
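A minimal call of the new 4x16x4 kernel could look like the sketch below (the wrapper function, buffer names, and sizes are hypothetical; per the asserts in the implementation, m must be a multiple of 4 and n at least 16):

void run_example(
    const float* lhs_ptr,
    const int8_t* rhs_ptr,
    float* out_ptr,
    const int8_t* rhs_zero_points,
    const float* rhs_scales) {
  using namespace torchao::kernels::cpu::aarch64::quantized_matmul;
  // out (8 x 32) = lhs (8 x 64, fp32) x dequant(rhs) (64 x 32, int8 with
  // per-input-channel scale/zero point), plus beta * existing out.
  fp32_a_input_channelwise_8bit_b_4x16x4_f32::kernel<true, false, false>(
      /*m=*/8, /*n=*/32, /*k=*/64,
      lhs_ptr, /*lhs_stride_m=*/64,
      rhs_ptr, /*rhs_stride_n=*/32,
      out_ptr, /*out_stride_m=*/32,
      rhs_zero_points, rhs_scales,
      /*beta=*/0.0f,
      /*rhs_qparams_stride=*/1);
}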

torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h

Lines changed: 21 additions & 0 deletions
@@ -86,10 +86,31 @@ void kernel(
     const int rhs_qparams_stride);
 
 } // namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32
+
+namespace fp32_a_input_channelwise_8bit_b_4x16x4_f32 {
+
+template <bool b_has_zeros, bool a_transposed, bool b_transposed>
+void kernel(
+    int m,
+    int n,
+    int k,
+    const float* lhs,
+    int lhs_stride_m,
+    const int8_t* rhs,
+    int rhs_stride_n,
+    float32_t* output,
+    int out_stride_m,
+    const int8_t* rhs_zero_points,
+    const float* rhs_scales,
+    const float beta,
+    const int rhs_qparams_stride);
+
+} // namespace fp32_a_input_channelwise_8bit_b_4x16x4_f32
 } // namespace torchao::kernels::cpu::aarch64::quantized_matmul
 
 #include <torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h>
 #include <torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h>
 #include <torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h>
+#include <torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_4x16x4_f32_impl.h>
 
 #endif // defined(__aarch64__) && defined(__ARM_NEON)
