
Commit 663a95d

Add gemm int8 a x int8 b to interface
Differential Revision: D71936844
Pull Request resolved: #2055
1 parent: 5a31ec8

File tree

4 files changed: +137 −13 lines


torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot-impl.h

Lines changed: 1 addition & 12 deletions
```diff
@@ -289,7 +289,7 @@ struct KernelImpl<true, true, false, true> {
     constexpr int kr = 8;
     assert(m % mr == 0);
     assert(k % 16 == 0);
-    assert(n >= nr);
+    assert(n % nr == 0);
     std::vector<int8_t> rhs_packed(n * k);
     // Since we are casting int8_t to float32_t in order to transpose matrix in a
     // way to keep 4 of the k values together, we must adjust stride as well as
@@ -307,17 +307,6 @@ struct KernelImpl<true, true, false, true> {
 
     for (int m_idx = 0; m_idx < m; m_idx += mr) {
       for (int n_idx = 0; n_idx < n; n_idx += nr) {
-        // If remaining is < nr, that must mean that (nr - remaining) items
-        // don't need to be computed.
-        // In order to avoid out-of-bounds access, we need to rewind n_idx a
-        // bit.
-        // |-------------------|-------------------|
-        // 0-------------------8-------------------16
-        // 0-------------------8-----10
-        // If n = 10 and nr = 8 then at n_idx = 8, we need to rewind n_idx to
-        // 8 - (8 - 2) = 2.
-        int remaining = std::min(n - n_idx, nr);
-        n_idx = n_idx - (nr - remaining);
         // Set activation_ptr to start of activation qvals for row m_idx
         const int8_t* lhs_ptr = (const int8_t*)lhs + m_idx * lhs_stride_m;
         const int8_t* rhs_ptr = (const int8_t*)rhs_packed.data() +
```
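
The deleted rewind existed because the old assert only required n >= nr: when n was not a multiple of nr, the last tile was shifted left so it stayed in bounds, at the cost of recomputing a few columns. The dispatcher added in matmul.h (next file) only routes shapes with n % nr == 0 to this kernel, so the assert is tightened instead. A standalone sketch (not part of the PR) of the old rewind arithmetic:

```cpp
// Illustration only: the tile-rewind logic this commit removes.
// With n = 10 and nr = 8, the second tile would start at n_idx = 8 and cover
// columns 8..15, past n = 10, so the old kernel rewound it to n_idx = 2 and
// recomputed columns 2..7. Output: tile starts at 0, then at 2.
#include <algorithm>
#include <cstdio>

int main() {
  const int n = 10, nr = 8;
  for (int n_idx = 0; n_idx < n; n_idx += nr) {
    int remaining = std::min(n - n_idx, nr);
    n_idx = n_idx - (nr - remaining);  // no-op while a full tile remains
    printf("tile starts at n_idx = %d\n", n_idx);
  }
  return 0;
}
```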

torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h

Lines changed: 123 additions & 0 deletions
```diff
@@ -42,6 +42,129 @@ void kernel(
 
 } // namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot
 
+namespace channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot {
+
+template <
+    bool a_has_zeros,
+    bool b_has_zeros,
+    bool a_transposed,
+    bool b_tranposed>
+void kernel(
+    int m,
+    int n,
+    int k,
+    const void* lhs,
+    int lhs_stride_m,
+    const void* rhs,
+    int rhs_stride_n,
+    float32_t* output,
+    int out_stride_m,
+    const int8_t* lhs_zero_points,
+    const int8_t* rhs_zero_points,
+    const float* lhs_scales,
+    const float* rhs_scales,
+    const int lhs_qparams_stride,
+    const int rhs_qparams_stride);
+
+} // namespace channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot
+
+namespace channelwise_8bit_a_channelwise_8bit_b_f32 {
+
+template <
+    bool a_has_zeros,
+    bool b_has_zeros,
+    bool a_transposed,
+    bool b_tranposed>
+void kernel(
+    int m,
+    int n,
+    int k,
+    const void* lhs,
+    int lhs_stride_m,
+    const void* rhs,
+    int rhs_stride_n,
+    float32_t* output,
+    int out_stride_m,
+    const int8_t* lhs_zero_points,
+    const int8_t* rhs_zero_points,
+    const float* lhs_scales,
+    const float* rhs_scales,
+    const int lhs_qparams_stride,
+    const int rhs_qparams_stride);
+
+template <
+    bool a_has_zeros,
+    bool b_has_zeros,
+    bool a_transposed,
+    bool b_tranposed>
+void kernel(
+    int m,
+    int n,
+    int k,
+    const void* lhs,
+    int lhs_stride_m,
+    const void* rhs,
+    int rhs_stride_n,
+    float32_t* output,
+    int out_stride_m,
+    const int8_t* lhs_zero_points,
+    const int8_t* rhs_zero_points,
+    const float* lhs_scales,
+    const float* rhs_scales,
+    const int lhs_qparams_stride,
+    const int rhs_qparams_stride) {
+  // TODO: Replace this with KernelConfig based dispatch
+  constexpr size_t gemm_nr = 8;
+  constexpr size_t gemm_kr = 16;
+  if ((n % gemm_nr == 0) && (k % gemm_kr == 0) && m > 4) {
+    auto remaining_m = m % 4;
+    auto m_for_gemm_kernel = m - remaining_m;
+    channelwise_8bit_a_channelwise_8bit_b_4x8x8_f32_neondot::
+        kernel<a_has_zeros, b_has_zeros, a_transposed, b_tranposed>(
+            m_for_gemm_kernel,
+            n,
+            k,
+            lhs,
+            lhs_stride_m,
+            rhs,
+            rhs_stride_n,
+            output,
+            out_stride_m,
+            lhs_zero_points,
+            rhs_zero_points,
+            lhs_scales,
+            rhs_scales,
+            lhs_qparams_stride,
+            rhs_qparams_stride);
+    output += m_for_gemm_kernel * out_stride_m;
+    lhs = (static_cast<const int8_t*>(lhs) + m_for_gemm_kernel * lhs_stride_m);
+    lhs_zero_points = lhs_zero_points + m_for_gemm_kernel * lhs_qparams_stride;
+    lhs_scales = lhs_scales + m_for_gemm_kernel * lhs_qparams_stride;
+    m = remaining_m;
+  }
+  if (m > 0) {
+    channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::
+        kernel<a_has_zeros, b_has_zeros, a_transposed, b_tranposed>(
+            m,
+            n,
+            k,
+            lhs,
+            lhs_stride_m,
+            rhs,
+            rhs_stride_n,
+            output,
+            out_stride_m,
+            lhs_zero_points,
+            rhs_zero_points,
+            lhs_scales,
+            rhs_scales,
+            lhs_qparams_stride,
+            rhs_qparams_stride);
+  }
+}
+
+} // namespace channelwise_8bit_a_channelwise_8bit_b_f32
+
 namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal {
 
 template <
```
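
The definition above splits m between the two kernels: full groups of 4 rows go to the new 4x8x8 GEMM kernel, and any leftover rows fall through to the existing 1x8x16 GEMV kernel. A standalone sketch (not part of the PR) of that split arithmetic, using the shape the new test exercises (m = 11, k = 128, n = 16); the mr = 4 constant is an assumption read off the kernel's 4x8x8 name:

```cpp
// Illustration only: how the dispatcher divides rows between GEMM and GEMV.
#include <cstdio>

int main() {
  const int m = 11, n = 16, k = 128;  // shape from the new GemmGemvMix test
  constexpr int gemm_mr = 4, gemm_nr = 8, gemm_kr = 16;
  if ((n % gemm_nr == 0) && (k % gemm_kr == 0) && m > gemm_mr) {
    int remaining_m = m % gemm_mr;            // 11 % 4 = 3 rows left over
    int m_for_gemm_kernel = m - remaining_m;  // 8 rows for the GEMM kernel
    printf("GEMM kernel: rows [0, %d); GEMV kernel: rows [%d, %d)\n",
           m_for_gemm_kernel, m_for_gemm_kernel, m);
  }
  return 0;
}
```

After the GEMM call, the dispatcher advances output, lhs, and the lhs quantization parameters by m_for_gemm_kernel rows so the GEMV call sees only the remaining rows; the rhs arguments are untouched because the split is along m only.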

torchao/experimental/kernels/cpu/interface/quantized_matmul.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -70,7 +70,7 @@ get_int8_a_int8_b_channelwise_qmatmul(
   a_stride_m = k;
   b_stride_n = k;
   return aarch64::quantized_matmul::
-      channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::
+      channelwise_8bit_a_channelwise_8bit_b_f32::
          kernel<true, true, false, true>;
 }
 #endif // defined(__aarch64__) && defined(__ARM_NEON)
```
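
With this one-line change the interface hands out the dispatching kernel instead of the GEMV-only kernel. A hypothetical caller (not from the PR) is sketched below: the include path and the `aarch64::quantized_matmul` namespace qualification are assumptions inferred from the file layout and the qualifier above, the kernel signature is the one declared in matmul.h, and the buffers hold placeholder values. This compiles only on aarch64 with NEON, where float32_t is defined.

```cpp
// Hypothetical usage sketch; include path and namespace are assumptions.
#include <cstdint>
#include <vector>

#include <torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h>  // assumed path

void run_example() {
  const int m = 11, k = 128, n = 16;  // shape from the new GemmGemvMix test
  std::vector<int8_t> a(m * k), b(n * k);            // quantized values
  std::vector<int8_t> a_zeros(m, 0), b_zeros(n, 0);  // per-channel zero points
  std::vector<float> a_scales(m, 1.0f), b_scales(n, 1.0f);
  std::vector<float> out(m * n);
  // b is transposed, so both strides are k, matching the interface's
  // a_stride_m = b_stride_n = k above.
  aarch64::quantized_matmul::channelwise_8bit_a_channelwise_8bit_b_f32::
      kernel</*a_has_zeros=*/true, /*b_has_zeros=*/true,
             /*a_transposed=*/false, /*b_transposed=*/true>(
          m, n, k,
          a.data(), /*lhs_stride_m=*/k,
          b.data(), /*rhs_stride_n=*/k,
          out.data(), /*out_stride_m=*/n,
          a_zeros.data(), b_zeros.data(),
          a_scales.data(), b_scales.data(),
          /*lhs_qparams_stride=*/1, /*rhs_qparams_stride=*/1);
}
```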

torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp

Lines changed: 12 additions & 0 deletions
```diff
@@ -347,6 +347,18 @@ TEST(test_channelwise_8bit_channelwise_8bit_b, TranposeBWithZeroPointsLargeM) {
       /*m=*/4, /*k=*/128, /*n=*/16);
 }
 
+TEST(
+    test_channelwise_8bit_channelwise_8bit_b,
+    TranposeBWithZeroPointsLargeMWithGemmGemvMix) {
+  test_channelwise_8bit_channelwise_8bit_b<
+      true /*a_has_zeros*/,
+      true /*b_has_zeros*/,
+      false /*a_transposed*/,
+      true /*b_transposed*/>::
+      Run(
+          /*m=*/11, /*k=*/128, /*n=*/16);
+}
+
 TEST(
     test_channelwise_8bit_channelwise_8bit_b,
     TranposedBWithZeroPointsOddSizes) {
```
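
The m = 11 shape is what makes this test a GEMM/GEMV mix: n % 8 == 0 and k % 16 == 0 hold and m > 4, so eight rows run through the GEMM kernel and the remaining three through the GEMV kernel, covering the mixed path the commit title describes.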
