
Commit b49f23c

Add fp32xint8 matmul
Differential Revision: D71370597 Pull Request resolved: #2004
1 parent e2369d3 commit b49f23c


4 files changed: +489 -9 lines changed

torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h

Lines changed: 275 additions & 0 deletions
@@ -0,0 +1,275 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#if defined(__aarch64__) || defined(__ARM_NEON)

#include <algorithm>
#include <cassert>
#include <cstring>
#include <memory> // for std::unique_ptr / std::make_unique used below

#include <arm_neon.h>
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>
#include <torchao/experimental/kernels/cpu/aarch64/matmul/matmul_utils.h>

namespace torchao::kernels::cpu::aarch64::quantized_matmul {
namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32::internal {

namespace {

/*
The helpers below compute a 1x16x4 block: they load a float32x4_t of 4 values
from a (along k), and 4 int8x16_t values from b (one row of 16 per k).
For each int8x16_t of b:
- subl to subtract b_zero_point from b, to get b_low, b_high (int16x8_t)
- vmovl to widen into b_low_low, b_low_high, b_high_low, b_high_high
- vcvtq to convert to float32x4_t, so we have 4 of these
- vmulq_n_f32 to apply b_scale to each of the 4 float32x4_t
- vfmaq_n_f32 to multiply each of the 4 float32x4_t by a[k] and accumulate
  into the 4 float32x4_t partial sums
By doing the above for each of the 4 values of a along k, we use all values
along the k dim of a and accumulate a 1x16 output tile in 4 float32x4_t
values.
*/
TORCHAO_ALWAYS_INLINE void block_mul_1x16x1(
    const float32_t a,
    const int8x16_t& b_vec,
    const int8_t b_zero_point,
    const float b_scale,
    float32x4_t (&partial_sums)[4]) {
  int8x8_t b_zero_point_vec = vdup_n_s8(b_zero_point);
  int16x8_t b_vec_low = vsubl_s8(vget_low_s8(b_vec), b_zero_point_vec);
  int16x8_t b_vec_high = vsubl_s8(vget_high_s8(b_vec), b_zero_point_vec);
  float32x4_t b_vec_low_low = vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_vec_low)));
  float32x4_t b_vec_low_high =
      vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_vec_low)));
  float32x4_t b_vec_high_low =
      vcvtq_f32_s32(vmovl_s16(vget_low_s16(b_vec_high)));
  float32x4_t b_vec_high_high =
      vcvtq_f32_s32(vmovl_s16(vget_high_s16(b_vec_high)));
  b_vec_low_low = vmulq_n_f32(b_vec_low_low, b_scale);
  b_vec_low_high = vmulq_n_f32(b_vec_low_high, b_scale);
  b_vec_high_low = vmulq_n_f32(b_vec_high_low, b_scale);
  b_vec_high_high = vmulq_n_f32(b_vec_high_high, b_scale);

  partial_sums[0] = vfmaq_n_f32(partial_sums[0], b_vec_low_low, a);
  partial_sums[1] = vfmaq_n_f32(partial_sums[1], b_vec_low_high, a);
  partial_sums[2] = vfmaq_n_f32(partial_sums[2], b_vec_high_low, a);
  partial_sums[3] = vfmaq_n_f32(partial_sums[3], b_vec_high_high, a);
}

void block_mul_1x16x4(
    const float32_t* a,
    const int8_t* b,
    const size_t ldb,
    const int8_t* b_zero_point,
    const float* b_scale,
    float32x4_t (&partial_sums)[4]) {
#pragma unroll(8)
  for (int i = 0; i < 4; i++) {
    int8x16_t b_vec = vld1q_s8(b + i * ldb);
    block_mul_1x16x1(a[i], b_vec, b_zero_point[i], b_scale[i], partial_sums);
  }
}

} // namespace

template <bool b_has_zeros, bool a_transposed, bool b_transposed>
struct KernelImpl {
  static void run(
      int m,
      int n,
      int k,
      const void* lhs,
      int lhs_stride_m,
      const void* rhs,
      int rhs_stride_n,
      float32_t* output,
      int out_stride_m,
      const int8_t* rhs_zero_points,
      const float* rhs_scales,
      const float beta,
      const int rhs_qparams_stride);
};

template <>
struct KernelImpl<true, false, false> {
  static void run(
      int m,
      int n,
      int k,
      const float* lhs,
      int lhs_stride_m,
      const int8_t* rhs,
      int rhs_stride_n,
      float32_t* output,
      int out_stride_m,
      const int8_t* rhs_zero_points,
      const float* rhs_scales,
      const float beta,
      const int rhs_qparams_stride) {
    std::unique_ptr<int8_t[]> rhs_zero_points_transposed =
        std::make_unique<int8_t[]>(k);
    std::unique_ptr<float[]> rhs_scales_transposed =
        std::make_unique<float[]>(k);
    if (rhs_qparams_stride > 1) {
      utils::transpose_scales_and_zero_points(
          rhs_zero_points,
          rhs_scales,
          rhs_zero_points_transposed.get(),
          rhs_scales_transposed.get(),
          k,
          rhs_qparams_stride);
      rhs_zero_points = rhs_zero_points_transposed.get();
      rhs_scales = rhs_scales_transposed.get();
    }

    constexpr int nr = 16;
    constexpr int kr = 4;
    for (int m_idx = 0; m_idx < m; m_idx++) {
      // Loop over 16 cols at a time
      // Access to partial tiles must be protected
      assert(n >= nr);
      for (int n_idx = 0; n_idx < n; n_idx += nr) {
        // If remaining is < nr, that must mean that (nr - remaining) items
        // don't need to be computed.
        // In order to avoid out-of-bounds access, we need to rewind n_idx a
        // bit
        // |-------------------|-------------------|
        // 0-------------------8-------------------16
        // 0-------------------8-----10
        // If n = 10 and nr = 8 then at n_idx = 8, remaining is 2, so we
        // rewind n_idx to 8 - (8 - 2) = 2
        int remaining = std::min(n - n_idx, nr);
        n_idx = n_idx - (nr - remaining);
        // Set activation_ptr to start of activation qvals for row m_idx
        const float* lhs_ptr = lhs + m_idx * lhs_stride_m;
        const int8_t* rhs_ptr = rhs + n_idx;
        float32x4_t sums[nr / 4] = {vdupq_n_f32(0)};

        // Loop k_idx by group
        int k_idx = 0;
        for (; (k_idx + kr) <= k; k_idx += kr) {
          block_mul_1x16x4(
              lhs_ptr,
              rhs_ptr,
              rhs_stride_n,
              rhs_zero_points + k_idx,
              rhs_scales + k_idx,
              sums);
          lhs_ptr += kr;
          rhs_ptr += kr * rhs_stride_n;
        }

        for (int ki = 0; ki < (k - k_idx); ++ki) {
          // For each of the remaining k values
          // Load 1 float from lhs
          // Load 16 int8_t from rhs
          // And multiply + add into the 16 accumulators
          // arranged as float32x4_t[4]
          int8x16_t rhs_vec = vld1q_s8(rhs_ptr + ki * rhs_stride_n);
          block_mul_1x16x1(
              lhs_ptr[ki],
              rhs_vec,
              rhs_zero_points[k_idx + ki],
              rhs_scales[k_idx + ki],
              sums);
        }

        // Store result
        // Because we adjust n_idx, we may end up writing the same location
        // twice
        // Note that the reason this case is being handled only for this kernel
        // and not others in this directory is because only for this kernel
        // we support accumulation.
        float* store_loc = output + m_idx * out_stride_m + n_idx;
        if (remaining < 16) {
          // If remaining is < 16, then not all of the 16 accumulators are
          // valid. That is, not all of float32x4_t[4] are valid. We need to
          // find the first valid one, and then store the rest of the
          // accumulators in the same order.
          // First valid one is at 3 - ((remaining - 1) / 4) because:
          // If remaining is say 10 then the first 6 are not valid.
          // Thus the first group of 4 at sums[0] is not valid.
          // In the second group of 4, the first 2 are not valid.
          // The rest are valid.
          int start_sum_idx = 3 - ((remaining - 1) / 4);
          // If remaining is 11, then sums[1] has 3 valid values,
          // so 3 - (11 - 1) % 4 = 3 - 10 % 4 = 3 - 2 = 1.
          // Thus there is 1 invalid value in the first valid group of 4.
          int invalid_values_in_32x4_reg = 3 - (remaining - 1) % 4;
          store_loc += start_sum_idx * 4;
          store_loc += invalid_values_in_32x4_reg;
          if (invalid_values_in_32x4_reg > 0) {
            for (int val_idx = invalid_values_in_32x4_reg; val_idx < 4;
                 ++val_idx) {
              *store_loc = sums[start_sum_idx][val_idx] + (*store_loc) * beta;
              store_loc += 1;
            }
            start_sum_idx++;
          }
          for (int out_idx = 0, sum_idx = start_sum_idx; sum_idx < nr / 4;
               out_idx += 4, ++sum_idx) {
            float32x4_t sum_val = vld1q_f32(store_loc + out_idx);
            sums[sum_idx] = vfmaq_n_f32(sums[sum_idx], sum_val, beta);
            vst1q_f32(store_loc + out_idx, sums[sum_idx]);
          }
        } else {
          for (int out_idx = 0, sum_idx = 0; out_idx < nr;
               out_idx += 4, ++sum_idx) {
            float32x4_t sum_val = vld1q_f32(store_loc + out_idx);
            sums[sum_idx] = vfmaq_n_f32(sums[sum_idx], sum_val, beta);
            vst1q_f32(store_loc + out_idx, sums[sum_idx]);
          }
        }
      } // n_idx
    } // m_idx
  }
};

} // namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32::internal

namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32 {
template <bool b_has_zeros, bool a_transposed, bool b_transposed>
void kernel(
    int m,
    int n,
    int k,
    const float* lhs,
    int lhs_stride_m,
    const int8_t* rhs,
    int rhs_stride_n,
    float32_t* output,
    int out_stride_m,
    const int8_t* rhs_zero_points,
    const float* rhs_scales,
    const float beta,
    const int rhs_qparams_stride) {
  torchao::kernels::cpu::aarch64::quantized_matmul::
      fp32_a_input_channelwise_8bit_b_1x16x4_f32::internal::
          KernelImpl<b_has_zeros, a_transposed, b_transposed>::run(
              m,
              n,
              k,
              lhs,
              lhs_stride_m,
              rhs,
              rhs_stride_n,
              output,
              out_stride_m,
              rhs_zero_points,
              rhs_scales,
              beta,
              rhs_qparams_stride);
}
} // namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32
} // namespace torchao::kernels::cpu::aarch64::quantized_matmul

#endif // defined(__aarch64__) || defined(__ARM_NEON)
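
For reference, here is a minimal usage sketch of the new kernel. Only the kernel<b_has_zeros, a_transposed, b_transposed> signature and the fp32_a_input_channelwise_8bit_b_1x16x4_f32 namespace come from this commit; the shapes, strides, and buffer values below are hypothetical, and the kernel requires n >= 16 (see the assert(n >= nr) in the implementation).

#include <cstdint>
#include <vector>

#include <torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h>

int main() {
  // Hypothetical problem size: 2x8 fp32 activations times an 8x32 int8
  // weight matrix with per-k (input-channel) quantization parameters.
  const int m = 2, n = 32, k = 8;

  std::vector<float> lhs(m * k, 1.0f);    // fp32 activations, row-major m x k
  std::vector<int8_t> rhs(k * n, 1);      // int8 weights, row-major k x n
  std::vector<float> output(m * n, 0.0f); // fp32 output, row-major m x n
  std::vector<int8_t> zero_points(k, 0);  // one zero point per k
  std::vector<float> scales(k, 0.05f);    // one scale per k

  // beta scales the existing output before accumulation; beta = 0.0f simply
  // overwrites it. rhs_qparams_stride = 1 means zero points/scales are
  // already contiguous, so no transposition is needed inside the kernel.
  using namespace torchao::kernels::cpu::aarch64::quantized_matmul;
  fp32_a_input_channelwise_8bit_b_1x16x4_f32::kernel<true, false, false>(
      m, n, k,
      lhs.data(), /*lhs_stride_m=*/k,
      rhs.data(), /*rhs_stride_n=*/n,
      output.data(), /*out_stride_m=*/n,
      zero_points.data(), scales.data(),
      /*beta=*/0.0f, /*rhs_qparams_stride=*/1);
  return 0;
}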

torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h

Lines changed: 21 additions & 0 deletions
@@ -66,9 +66,30 @@ void kernel(
    const int rhs_qparams_stride);

} // namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal

namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32 {

template <bool b_has_zeros, bool a_transposed, bool b_transposed>
void kernel(
    int m,
    int n,
    int k,
    const float* lhs,
    int lhs_stride_m,
    const int8_t* rhs,
    int rhs_stride_n,
    float32_t* output,
    int out_stride_m,
    const int8_t* rhs_zero_points,
    const float* rhs_scales,
    const float beta,
    const int rhs_qparams_stride);

} // namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32
} // namespace torchao::kernels::cpu::aarch64::quantized_matmul

#include <torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h>
#include <torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h>
#include <torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h>

#endif // defined(__aarch64__) || defined(__ARM_NEON)
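
As a cross-check on what the NEON kernel computes, here is a plain scalar sketch of the same math: per-k dequantization of the int8 rhs followed by an fp32 matmul with beta accumulation into the existing output. This reference is not part of the commit; the function name and the row-major layout assumptions are illustrative only.

#include <cstdint>

// Computes output = lhs * dequant(rhs) + beta * output, where rhs is int8
// with one (zero_point, scale) pair per input channel k.
void fp32_int8_matmul_reference(
    int m, int n, int k,
    const float* lhs, int lhs_stride_m,
    const int8_t* rhs, int rhs_stride_n,
    float* output, int out_stride_m,
    const int8_t* rhs_zero_points,
    const float* rhs_scales,
    float beta) {
  for (int mi = 0; mi < m; ++mi) {
    for (int ni = 0; ni < n; ++ni) {
      float acc = 0.0f;
      for (int ki = 0; ki < k; ++ki) {
        // Dequantize one int8 weight with its per-k zero point and scale.
        float b_val = (static_cast<float>(rhs[ki * rhs_stride_n + ni]) -
                       static_cast<float>(rhs_zero_points[ki])) *
                      rhs_scales[ki];
        acc += lhs[mi * lhs_stride_m + ki] * b_val;
      }
      float* out = output + mi * out_stride_m + ni;
      *out = acc + beta * (*out);
    }
  }
}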
