
Commit 2a24a00

Add packing activation and packing weight.
Differential Revision: D77312714
Pull Request resolved: #2486
1 parent 2d61be8 commit 2a24a00

File tree

2 files changed: +259 −0
Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
#pragma once

#if defined(__aarch64__) || defined(__ARM_NEON)

#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

namespace torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut::
    activation_packing {

inline size_t packed_activations_size(int m, int k) {
  return m * k * sizeof(float);
}

template <int mr_, int kr_, int sr_>
void pack_activations(
    // Output
    float* packed_activations,
    // Inputs
    int m,
    int k,
    const float* activations) {
  static_assert(mr_ == 1);
  std::memcpy(packed_activations, activations, sizeof(float) * m * k);
}
} // namespace
  // torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut::activation_packing

#endif // defined(__aarch64__) || defined(__ARM_NEON)
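
A minimal usage sketch (not part of the commit; the caller name and sizes are illustrative, and the template arguments `mr_ = 1, kr_ = 32, sr_ = 8` are assumptions matching the `mr_ == 1` constraint above) showing how the two activation-packing entry points compose, assuming this header is included:

#include <cstddef>
#include <vector>

namespace ap = torchao::kernels::cpu::aarch64::linear::
    groupwise_lowbit_weight_lut::activation_packing;

// Hypothetical caller: pack a 2 x 64 row-major activation matrix.
void example_pack_activations() {
  int m = 2, k = 64;
  std::vector<float> activations(m * k, 1.0f);
  // packed_activations_size returns a byte count; size a float buffer from it.
  std::vector<float> packed(ap::packed_activations_size(m, k) / sizeof(float));
  ap::pack_activations<1, 32, 8>(packed.data(), m, k, activations.data());
}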
Lines changed: 228 additions & 0 deletions
@@ -0,0 +1,228 @@
#pragma once

#if defined(__aarch64__) || defined(__ARM_NEON)
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
#include <torchao/experimental/kernels/cpu/aarch64/lut/lut.h>
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>
#include <torchao/experimental/kernels/cpu/aarch64/packing/utils.h>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <numeric>
#include <vector>

namespace torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut::
    weight_packing {
namespace lut_utils = torchao::lut;
namespace packing_utils = torchao::packing;

/**
 * @brief Calculates the exact buffer size in bytes for packed weights.
 *
 * This function computes the total memory required for a weight buffer based
 * on a specific packing layout. The calculation accounts for tiled weights, a
 * Look-Up Table (LUT), and optional interleaved scales and biases. It assumes
 * the 'n' dimension is padded to be a multiple of the tile height 'nr'.
 *
 * @param n The number of output channels (columns) in the weight matrix.
 * @param k The number of input channels (rows) in the weight matrix.
 * @param weight_nbit The bit precision for each weight (e.g., 4, 8).
 * @param scale_group_size The number of weights that share a single scale
 * factor.
 * @param has_scales Set to true to include space for scaling factors.
 * @param has_bias Set to true to include space for a bias vector.
 * @param nr The tile height used for packing along the 'n' dimension.
 * @param kr The tile width used for packing along the 'k' dimension.
 * @return The total required size in bytes for the complete packed buffer.
 */
inline size_t packed_weights_size(
    int n,
    int k,
    int weight_nbit,
    int scale_group_size,
    bool has_scales,
    bool has_bias,
    int nr,
    int kr) {
  size_t size_per_n_strip = 0;

  // 1. Size of the LUT, written once per strip.
  size_per_n_strip += 16 * sizeof(float);

  // 2. Size of the interleaved scales.
  if (has_scales) {
    assert(
        k % scale_group_size == 0 &&
        "k must be a multiple of scale_group_size");
    size_t num_scale_blocks = k / scale_group_size;
    size_per_n_strip += num_scale_blocks * nr * sizeof(float);
  }

  // 3. Size of the packed weight tiles.
  assert(k % kr == 0 && "k must be a multiple of kr");
  size_t num_k_tiles = k / kr;
  size_t bytes_per_weight_tile = ((nr * kr * weight_nbit) + 7) / 8;
  size_per_n_strip += num_k_tiles * bytes_per_weight_tile;

  // 4. Size of the bias, written once per strip.
  if (has_bias) {
    size_per_n_strip += nr * sizeof(float);
  }

  // Calculate the total number of n-strips, padding n to a multiple of nr.
  int num_n_strips = (n + nr - 1) / nr;

  return size_per_n_strip * num_n_strips;
}
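
// Worked size example (illustrative numbers, not part of the original
// commit): for n = 8, k = 64, weight_nbit = 4, scale_group_size = 32,
// has_scales = has_bias = true, nr = 4, kr = 32, each n-strip holds
//   LUT:     16 * sizeof(float)                 = 64 bytes
//   scales:  (64 / 32) * 4 * sizeof(float)      = 32 bytes
//   weights: (64 / 32) * ((4 * 32 * 4 + 7) / 8) = 2 * 64 = 128 bytes
//   bias:    4 * sizeof(float)                  = 16 bytes
// for 240 bytes per strip; with (8 + 4 - 1) / 4 = 2 strips, the function
// returns 480 bytes.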

/**
 * @brief Packs weights, LUTs, scales and bias into a kernel-optimized format.
 * @details The function organizes the output buffer into "n-strips," where
 * each strip corresponds to a tile of `nr_` columns from the weight matrix.
 * The memory layout for each strip is as follows:
 *   1. **Look-Up Table (LUT):** A 16-element float LUT is written once at
 *      the beginning of the strip.
 *   2. **Interleaved Scales:** If `has_scales` is true, dequantization
 *      scales are interleaved. For each group of `scale_group_size`
 *      elements along the k-dimension, `nr_` scale values (one for each
 *      column in the strip) are written.
 *   3. **Packed Weight Tiles:** The core weight data is tiled into
 *      (`nr_` x `kr_`) blocks. These blocks are then bit-packed and
 *      interleaved according to the `sr_` ratio before being written.
 *   4. **Bias:** If `has_bias` is true, `nr_` bias values are appended
 *      at the end of the strip.
 *
 * @tparam weight_nbit_ The true bit-width of the weights.
 * @tparam nr_ The column-tiling factor for the kernel (e.g., 4).
 * @tparam kr_ The k-dimension tiling factor of the micro-kernel (e.g., 32).
 * @tparam sr_ Split ratio that determines how the k dimension of a weight
 * tile is chunked and interleaved during the packing process.
 * @param packed_weights_ptr Pointer to the destination buffer.
 * @param weight_qval_indices Pointer to the quantized weight matrix (uint8,
 * row-major).
 * @param weight_scales Pointer to the scale factors (float32, row-major).
 * @param weight_luts Pointer to the LUTs (float32, row-major).
 * @param n The number of columns in the weight matrix.
 * @param k The number of rows in the weight matrix.
 * @param scale_group_size The number of weights that share a scale factor.
 * @param lut_group_size The number of weights that share a LUT.
 * @param has_scales If true, the packed buffer will contain scale factors.
 * @param has_bias If true, the packed buffer will contain bias terms.
 * @param bias Pointer to the bias vector (float32, row-major).
 */
template <int weight_nbit_, int nr_, int kr_, int sr_>
TORCHAO_ALWAYS_INLINE inline void pack_weights(
    // Output
    void* packed_weights_ptr,
    // Inputs
    const uint8_t* weight_qval_indices,
    const float* weight_scales,
    const float* weight_luts,
    int n,
    int k,
    int scale_group_size,
    int lut_group_size,
    bool has_scales,
    bool has_bias,
    const float* bias) {
  static_assert(nr_ == 4);
  static_assert(kr_ == 32);
  static_assert(sr_ == 8);
  static_assert(kr_ % sr_ == 0, "kr must be divisible by sr");
  assert(k % kr_ == 0 && "K must be a multiple of tile dimension kr");
  assert(scale_group_size > 0 && "Scale group size must be positive");
  assert(lut_group_size > 0 && "LUT group size must be positive");

  // Grouping hierarchy constraint
  assert(
      lut_group_size % scale_group_size == 0 &&
      "LUT group size must be a multiple of scale group size");

  // Group compatibility constraints with tile dimensions
  assert(
      lut_group_size % (k * nr_) == 0 &&
      "LUT group size must be compatible with tile dimensions");
  assert(scale_group_size % kr_ == 0 && "Scale group size % kr must be 0");

  auto* out_ptr = reinterpret_cast<uint8_t*>(packed_weights_ptr);
  constexpr int kLutBufferSize = 16;
  std::vector<float> lut_buffer(kLutBufferSize);

  std::vector<uint8_t> padded_tile(nr_ * kr_);

  std::vector<uint8_t> tmp_buffer(128);
  constexpr int bytes_per_128_packed_values =
      ((nr_ * kr_ * weight_nbit_) + 7) / 8;

  const int lut_size = 1 << weight_nbit_;
  const int scales_per_col = k / scale_group_size;

  for (int n_idx = 0; n_idx < n; n_idx += nr_) {
    int current_lut_idx = (n_idx * k) / lut_group_size;

    // Zero-fill the LUT buffer so entries beyond lut_size stay zero, then
    // copy the active LUT over it and write all 16 floats to the strip.
    std::memset(lut_buffer.data(), 0, 16 * sizeof(float));
    std::memcpy(out_ptr, lut_buffer.data(), 16 * sizeof(float));

    std::memcpy(
        lut_buffer.data(),
        weight_luts + current_lut_idx * lut_size,
        lut_size * sizeof(float));
    std::memcpy(out_ptr, lut_buffer.data(), 16 * sizeof(float));
    out_ptr += 16 * sizeof(float);

    for (int k_idx = 0; k_idx < k; k_idx += kr_) {
      int w_idx = n_idx * k + k_idx;
      // Write scales if k_idx is a multiple of scale_group_size
      if (has_scales && (k_idx % scale_group_size == 0)) {
        int scale_idx = w_idx / scale_group_size;
        // Write scales for the next nr columns
        for (int j = 0; j < nr_; j++) {
          float scale = 0.0;
          if (n_idx + j < n) {
            scale = weight_scales[scale_idx + j * scales_per_col];
          }
          std::memcpy(out_ptr, &scale, sizeof(float));
          out_ptr += sizeof(float);
        }
      }
      // Write one packed 128-value (nr_ x kr_) tile
      std::memset(padded_tile.data(), 0, 128);
      for (int j = 0; j < nr_; j++) {
        if (n_idx + j < n) {
          std::memcpy(
              padded_tile.data() + j * kr_,
              weight_qval_indices + w_idx + j * k,
              kr_);
        }
      }
      packing_utils::pack_values(
          tmp_buffer.data(), padded_tile.data(), nr_, kr_, sr_);
      const uint8_t* buffer = tmp_buffer.data();
      torchao::bitpacking::vec_pack_128_uintx_values<weight_nbit_>(
          reinterpret_cast<uint8_t*>(out_ptr),
          vld1q_u8(buffer),
          vld1q_u8(buffer + 16),
          vld1q_u8(buffer + 32),
          vld1q_u8(buffer + 48),
          vld1q_u8(buffer + 64),
          vld1q_u8(buffer + 80),
          vld1q_u8(buffer + 96),
          vld1q_u8(buffer + 112));
      out_ptr += bytes_per_128_packed_values;
    } // k_idx

    if (has_bias) {
      for (int i = 0; i < nr_; i++) {
        float current_bias = 0.0;
        if (n_idx + i < n) {
          current_bias = bias[n_idx + i];
        }
        std::memcpy(out_ptr, &current_bias, sizeof(float));
        out_ptr += sizeof(float);
      }
    }
  }
}
} // namespace
  // torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut::weight_packing

#endif // defined(__aarch64__) || defined(__ARM_NEON)
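
A minimal end-to-end sketch of the weight path (not part of the commit; the caller name `example_pack_weights`, the matrix shape, and the group sizes are illustrative assumptions, and the header above is assumed to be included):

#include <cstddef>
#include <cstdint>
#include <vector>

namespace wp = torchao::kernels::cpu::aarch64::linear::
    groupwise_lowbit_weight_lut::weight_packing;

// Hypothetical caller: pack an 8 x 64 matrix of 4-bit LUT indices, with one
// LUT shared by the whole matrix and one scale per 32 weights.
void example_pack_weights() {
  constexpr int weight_nbit = 4, nr = 4, kr = 32, sr = 8;
  const int n = 8, k = 64;
  const int scale_group_size = 32;
  const int lut_group_size = n * k; // satisfies lut_group_size % (k * nr) == 0

  std::vector<uint8_t> qval_indices(n * k, 0);               // row-major indices
  std::vector<float> scales(n * k / scale_group_size, 1.0f); // one per group
  std::vector<float> luts(1 << weight_nbit, 0.0f);           // a single 16-entry LUT
  std::vector<float> bias(n, 0.0f);

  // Size the destination buffer, then pack into it.
  std::vector<uint8_t> packed(wp::packed_weights_size(
      n, k, weight_nbit, scale_group_size,
      /*has_scales=*/true, /*has_bias=*/true, nr, kr));
  wp::pack_weights<weight_nbit, nr, kr, sr>(
      packed.data(), qval_indices.data(), scales.data(), luts.data(),
      n, k, scale_group_size, lut_group_size,
      /*has_scales=*/true, /*has_bias=*/true, bias.data());
}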
