Skip to content

Commit c99e37c

Browse files
authored
Add gemm kernel to interface
Differential Revision: D71833068 Pull Request resolved: #2040
1 parent 5cb1fa1 commit c99e37c

File tree

3 files changed

+97
-5
lines changed

3 files changed

+97
-5
lines changed

torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#pragma once
1212

13+
#include <cassert>
1314
#if defined(__aarch64__) && defined(__ARM_NEON)
1415

1516
#include <arm_neon.h>
@@ -106,6 +107,83 @@ void kernel(
106107
const int rhs_qparams_stride);
107108

108109
} // namespace fp32_a_input_channelwise_8bit_b_4x16x4_f32
110+
111+
namespace fp32_a_input_channelwise_8bit_b_f32 {
112+
113+
template <bool b_has_zeros, bool a_transposed, bool b_tranposed>
114+
void kernel(
115+
int m,
116+
int n,
117+
int k,
118+
const float* lhs,
119+
int lhs_stride_m,
120+
const int8_t* rhs,
121+
int rhs_stride_n,
122+
float32_t* output,
123+
int out_stride_m,
124+
const int8_t* rhs_zero_points,
125+
const float* rhs_scales,
126+
const float beta,
127+
const int rhs_qparams_stride);
128+
129+
template <bool b_has_zeros, bool a_transposed, bool b_tranposed>
130+
void kernel(
131+
int m,
132+
int n,
133+
int k,
134+
const float* lhs,
135+
int lhs_stride_m,
136+
const int8_t* rhs,
137+
int rhs_stride_n,
138+
float32_t* output,
139+
int out_stride_m,
140+
const int8_t* rhs_zero_points,
141+
const float* rhs_scales,
142+
const float beta,
143+
const int rhs_qparams_stride) {
144+
assert(n >= 16);
145+
if (m > 16) {
146+
auto remaining_m = m % 16;
147+
auto m_for_gemm_kernel = m - remaining_m;
148+
fp32_a_input_channelwise_8bit_b_4x16x4_f32::
149+
kernel<b_has_zeros, a_transposed, b_tranposed>(
150+
m_for_gemm_kernel,
151+
n,
152+
k,
153+
lhs,
154+
lhs_stride_m,
155+
rhs,
156+
rhs_stride_n,
157+
output,
158+
out_stride_m,
159+
rhs_zero_points,
160+
rhs_scales,
161+
beta,
162+
rhs_qparams_stride);
163+
output += m_for_gemm_kernel * out_stride_m;
164+
lhs += m_for_gemm_kernel * lhs_stride_m;
165+
m = remaining_m;
166+
}
167+
if (m > 0) {
168+
fp32_a_input_channelwise_8bit_b_1x16x4_f32::
169+
kernel<b_has_zeros, a_transposed, b_tranposed>(
170+
m,
171+
n,
172+
k,
173+
lhs,
174+
lhs_stride_m,
175+
rhs,
176+
rhs_stride_n,
177+
output,
178+
out_stride_m,
179+
rhs_zero_points,
180+
rhs_scales,
181+
beta,
182+
rhs_qparams_stride);
183+
}
184+
}
185+
186+
} // namespace fp32_a_input_channelwise_8bit_b_f32
109187
} // namespace torchao::kernels::cpu::aarch64::quantized_matmul
110188

111189
#include <torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h>

torchao/experimental/kernels/cpu/interface/quantized_matmul.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212
#include <torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h>
1313

1414
#if defined(__aarch64__) && defined(__ARM_NEON)
15-
#include <torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h>
16-
#include <torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h>
17-
#include <torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h>
15+
#include <torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h>
1816
#endif // defined(__aarch64__) && defined(__ARM_NEON)
1917

2018
namespace torchao::kernels::cpu::quantized_matmul {
@@ -138,8 +136,8 @@ get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
138136
if (!a_transposed && !b_transposed && n >= 16) {
139137
a_stride_m = k;
140138
b_stride_n = n;
141-
return aarch64::quantized_matmul::
142-
fp32_a_input_channelwise_8bit_b_1x16x4_f32::kernel<true, false, false>;
139+
return aarch64::quantized_matmul::fp32_a_input_channelwise_8bit_b_f32::
140+
kernel<true, false, false>;
143141
}
144142
#endif // defined(__aarch64__) && defined(__ARM_NEON)
145143
assert(!a_transposed);

torchao/experimental/kernels/cpu/interface/test_qmatmul_interface.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,22 @@ TEST_P(
624624
/*m=*/4, /*k=*/5, /*n=*/3, beta(), *this, 32);
625625
}
626626

627+
// Exercises the fp32-A / channelwise-8bit-B interface with sizes that are not
// multiples of 16 (m = 19, k = 37, n = 35).
TEST_P(
    FP32A_QuantizedB_FP32C_Interface_Test,
    BTranposedWithZeroPointsOddSizes2) {
  constexpr int kM = 19;
  constexpr int kK = 37;
  constexpr int kN = 35;
  // The boolean flags are passed through exactly as in the sibling tests;
  // their meaning is defined by the fixture's generate().
  generate(kM, kK, kN, true, false, false);
  test_fp32_a_input_channelwise_8bit_b(kM, kK, kN, beta(), *this);
}
634+
635+
// Same shape of check as the non-strided variant, with m = 23, k = 37, n = 50
// and an extra trailing argument of 32 given to both the fixture and the
// checker (presumably a stride, per the test name; see generate()).
TEST_P(
    FP32A_QuantizedB_FP32C_Interface_Test,
    BTranposedWithZeroPointsOddSizesStrided2) {
  constexpr int kM = 23;
  constexpr int kK = 37;
  constexpr int kN = 50;
  constexpr int kStride = 32;
  generate(kM, kK, kN, true, false, false, kStride);
  test_fp32_a_input_channelwise_8bit_b(kM, kK, kN, beta(), *this, kStride);
}
642+
627643
INSTANTIATE_TEST_SUITE_P(
628644
F32AInt8BFP32CTest,
629645
FP32A_QuantizedB_FP32C_Interface_Test,

0 commit comments

Comments
 (0)