Skip to content

Commit b31302b

Browse files
authored
Add API interface for kernel.
Differential Revision: D77312726 Pull Request resolved: #2492
1 parent 55bc882 commit b31302b

File tree

1 file changed

+187
-0
lines changed

1 file changed

+187
-0
lines changed
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
#pragma once
2+
3+
#if defined(__aarch64__) || defined(__ARM_NEON)
4+
5+
#include <arm_neon.h>
6+
#include <stddef.h>
7+
#include <cassert>
8+
#include <stdexcept>
9+
10+
#include <torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/kernel_f32-impl.h>
11+
#include <torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/pack_activations.h>
12+
#include <torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/pack_weights.h>
13+
14+
namespace torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut {
15+
16+
/**
17+
* @brief Calculates the total size in bytes required for the packed weight.
18+
*
19+
* @param m The number of rows in the source activation matrix.
20+
* @param k The number of columns in the source activation matrix.
21+
* @param mr The row-tiling factor of the micro-kernel.
22+
* @param kr The column-tiling factor of the micro-kernel.
23+
* @param sr The split ratio of the micro-kernel.
24+
*/
25+
inline size_t packed_activations_size(int m, int k, int mr, int kr, int sr) {
26+
(void)mr; // unused
27+
(void)kr; // unused
28+
(void)sr; // unused
29+
return activation_packing::packed_activations_size(m, k);
30+
}
31+
32+
/**
33+
* @brief Packs a row-major activation matrix into a kernel-optimized blocked
34+
layout.
35+
*
36+
* @tparam mr_ The row-tiling factor of the micro-kernel (Currently only have
37+
1).
38+
* @tparam kr_ The column-tiling factor of the micro-kernel (e.g., 32).
39+
* @tparam sr_ Split ratio determine how the k dimension of a weight tile is
40+
chunked and interleaved during the packing process.
41+
* @param output Pointer to the destination buffer.
42+
* @param m The number of rows in the source activation matrix.
43+
* @param k The number of columns in the source activation matrix.
44+
* @param input Pointer to the source activation matrix (float32, row-major).
45+
*/
46+
template <int mr_, int kr_, int sr_>
47+
inline void pack_activations(float* output, int m, int k, const float* input) {
48+
activation_packing::pack_activations<mr_, kr_, sr_>(output, m, k, input);
49+
}
50+
51+
/**
52+
* @brief Calculates the total size in bytes required for the packed weight
53+
* buffer for the groupwise LUT kernel format.
54+
*
55+
* @param n The number of columns in the weight matrix.
56+
* @param k The number of rows in the weight matrix.
57+
* @param weight_nbit The number of bits per weight (e.g., 2, 3, 4).
58+
* @param scale_group_size The number of weights along the K dim that share a
59+
* scale factor.
60+
* @param has_scales If true, the packed buffer will contain scale factors.
61+
* @param has_bias If true, the packed buffer will contain bias terms.
62+
* @param nr The column-tiling factor for the kernel (e.g., 16).
63+
* @param kr The column-tiling factor for the kernel (e.g., 16).
64+
* @param sr The split ratio of the micro-kernel.
65+
* @return The total required size of the packed buffer in bytes.
66+
*/
67+
inline size_t packed_weights_size(
68+
int n,
69+
int k,
70+
int weight_nbit,
71+
int scale_group_size,
72+
bool has_scales,
73+
bool has_bias,
74+
int nr,
75+
int kr,
76+
int sr) {
77+
(void)sr; // unused
78+
return weight_packing::packed_weights_size(
79+
n, k, weight_nbit, scale_group_size, has_scales, has_bias, nr, kr);
80+
}
81+
82+
/**
83+
* @brief Packs weights, LUTs, scales and bias into a kernel-optimized format.
84+
* @tparam weight_nbit_ The true bit-width of the weights.
85+
* @tparam nr_ The column-tiling factor for the kernel (e.g., 4).
86+
* @tparam kr_ The column-tiling factor of the micro-kernel (e.g., 32).
87+
* @tparam sr_ Split ratio determine how the k dimension of a weight tile is
88+
chunked and interleaved during the packing process.
89+
* @param packed_weights_ptr Pointer to the destination buffer.
90+
* @param weight_qvals_indices Pointer to the quantized weight matrix (uint8,
91+
row-major).
92+
* @param weight_scales Pointer to the scale factors (float32, row-major).
93+
* @param weight_luts Pointer to the LUTs (float32, row-major).
94+
* @param n The number of columns in the weight matrix.
95+
* @param k The number of rows in the weight matrix.
96+
* @param scale_group_size The number of weights that share a scale factor.
97+
* @param lut_group_size The number of weights that share a LUT.
98+
* @param has_scales If true, the packed buffer will contain scale factors.
99+
* @param has_bias If true, the packed buffer will contain bias terms.
100+
* @param bias Pointer to the bias vector (float32, row-major).
101+
*/
102+
template <int weight_nbit_, int nr_, int kr_, int sr_>
103+
void pack_weights_for_groupwise_lut_kernel(
104+
/*output*/
105+
void* packed_weights_ptr,
106+
/*inputs*/
107+
const uint8_t* weight_qvals_indices,
108+
const float* weight_scales,
109+
const float* weight_luts,
110+
int n,
111+
int k,
112+
int scale_group_size,
113+
int lut_group_size,
114+
bool has_scales,
115+
bool has_bias,
116+
const float* bias) {
117+
weight_packing::pack_weights<weight_nbit_, nr_, kr_, sr_>(
118+
packed_weights_ptr,
119+
weight_qvals_indices,
120+
weight_scales,
121+
weight_luts,
122+
n,
123+
k,
124+
scale_group_size,
125+
lut_group_size,
126+
has_scales,
127+
has_bias,
128+
bias);
129+
}
130+
131+
/**
132+
* @brief Computes a group-wise low-bit GEMM using an optimized NEON kernel.
133+
*
134+
* This function selects the best available micro-kernel based on the provided
135+
* tile sizes (MR and NR) and dispatches the computation.
136+
* @tparam weight_nbit_ The true bit-width of the weights (e.g., 2, 3, 4).
137+
* @tparam has_scales_ If true, applies the scales.
138+
* @param output Pointer to the output matrix C.
139+
* @param output_m_stride The stride (in elements) between rows of the output
140+
* matrix.
141+
* @param m Number of rows in A and C.
142+
* @param n Number of columns in B and C.
143+
* @param k Number of columns in A and rows in B.
144+
* @param scale_group_size The grouping factor for scales.
145+
* @param lut_group_size The grouping factor for LUTs.
146+
* @param packed_weights Pointer to the pre-packed weight buffer.
147+
* @param packed_activations Pointer to the pre-packed activation buffer.
148+
* @param biases Pointer to the bias vector.
149+
* @param clamp_min Minimum value for the fused clamp (ReLU) operation.
150+
* @param clamp_max Maximum value for the fused clamp (ReLU6) operation.
151+
* @param has_bias If true, applies the bias.
152+
* @param has_clamp If true, applies the clamping.
153+
*/
154+
template <int weight_nbit_, bool has_scales_>
155+
inline void groupwise_lowbit_weight_lut_kernel_1x4x32(
156+
float* output,
157+
int output_m_stride,
158+
int m,
159+
int n,
160+
int k,
161+
int scale_group_size,
162+
int lut_group_size,
163+
const void* packed_weights,
164+
const void* packed_activations,
165+
float clamp_min,
166+
float clamp_max,
167+
bool has_bias,
168+
bool has_clamp) {
169+
kernel::groupwise_lowbit_weight_lut_kernel_1x4x32<weight_nbit_, has_scales_>(
170+
output,
171+
output_m_stride,
172+
m,
173+
n,
174+
k,
175+
scale_group_size,
176+
lut_group_size,
177+
packed_weights,
178+
packed_activations,
179+
clamp_min,
180+
clamp_max,
181+
has_bias,
182+
has_clamp);
183+
}
184+
} // namespace
185+
// torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut
186+
187+
#endif // defined(__aarch64__) || defined(__ARM_NEON)

0 commit comments

Comments
 (0)