
Commit 55bc882

Add kernel implementation for the LUT kernel
Differential Revision: D77315506 Pull Request resolved: #2489
1 parent 2a24a00 commit 55bc882

File tree

1 file changed: +239 -0 lines changed
  • torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight

Lines changed: 239 additions & 0 deletions
@@ -0,0 +1,239 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#pragma once

#if defined(__aarch64__) || defined(__ARM_NEON)
#include <arm_neon.h>
#include <torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/pack_weights.h>
#include <torchao/experimental/kernels/cpu/aarch64/lut/lut.h>
#include <array>
#include <cassert>
#include <cstring>

namespace torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut::
    kernel {

namespace lut_utils = torchao::lut;
namespace weight_packing = torchao::kernels::cpu::aarch64::linear::
    groupwise_lowbit_weight_lut::weight_packing;

namespace internal {

/**
 * @brief Computes a single tile of the output matrix.
 * @tparam weight_nbit_ The bit-precision of the quantized weight indices.
 * @tparam has_scales A compile-time flag to enable the application of scales.
 *
 * @param accum A NEON vector of 4 floats used as an in-out accumulator.
 * @param activation_tile_ptr Pointer to the 32-float activation tile.
 * @param packed_indices_ptr Pointer to the bit-packed weight indices.
 * @param lut_neon The dequantization LUT, pre-formatted for NEON lookups.
 * @param scale_vec A NEON vector with the four dequantization scales.
 */
template <int weight_nbit_, bool has_scales>
TORCHAO_ALWAYS_INLINE static inline void compute_tile_1x4x32(
    float32x4_t& accum,
    const float* __restrict__ activation_tile_ptr,
    const uint8_t* __restrict__ packed_indices_ptr,
    const uint8x16x4_t& lut_neon,
    const float32x4_t scale_vec) {
  // 1. Unpack indices
  uint8x16_t idx0, idx1, idx2, idx3, idx4, idx5, idx6, idx7;
  bitpacking::vec_unpack_128_uintx_values<weight_nbit_>(
      idx0, idx1, idx2, idx3, idx4, idx5, idx6, idx7, packed_indices_ptr);

  const std::array<uint8x16_t, 8> unpacked_indices = {
      idx0, idx1, idx2, idx3, idx4, idx5, idx6, idx7};

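  // The 128 unpacked indices cover one kr_ (32 k-values) x nr_ (4 columns)
  // tile; each of the eight 16-lane vectors holds one 4x4 sub-tile of
  // indices and is consumed by one iteration of the loop below.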
  for (int sr_idx = 0; sr_idx < 8; ++sr_idx) {
    // Load the 4 activations corresponding to this chunk
    const float* activation_chunk_ptr = activation_tile_ptr + sr_idx * 4;
    float32x4_t a = vld1q_f32(activation_chunk_ptr);

    // Lookup the 4x4 weight sub-tile (as columns)
    float32x4_t w_col0, w_col1, w_col2, w_col3;
    lut_utils::lookup_from_fp32_lut(
        w_col0, w_col1, w_col2, w_col3, lut_neon, unpacked_indices[sr_idx]);

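    // Transpose the 4x4 sub-tile from column vectors to row vectors, so that
    // each w_row_i holds the 4 output-column weights for one k index and the
    // FMAs below compute accum[j] += a[i] * W[i][j] for i = 0..3.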
    float32x4x2_t tmp0 = vtrnq_f32(w_col0, w_col1);
    float32x4x2_t tmp1 = vtrnq_f32(w_col2, w_col3);
    float32x4_t w_row0 =
        vcombine_f32(vget_low_f32(tmp0.val[0]), vget_low_f32(tmp1.val[0]));
    float32x4_t w_row1 =
        vcombine_f32(vget_low_f32(tmp0.val[1]), vget_low_f32(tmp1.val[1]));
    float32x4_t w_row2 =
        vcombine_f32(vget_high_f32(tmp0.val[0]), vget_high_f32(tmp1.val[0]));
    float32x4_t w_row3 =
        vcombine_f32(vget_high_f32(tmp0.val[1]), vget_high_f32(tmp1.val[1]));

    // Conditionally apply scales at compile time
    if constexpr (has_scales) {
      w_row0 = vmulq_f32(w_row0, scale_vec);
      w_row1 = vmulq_f32(w_row1, scale_vec);
      w_row2 = vmulq_f32(w_row2, scale_vec);
      w_row3 = vmulq_f32(w_row3, scale_vec);
    }

    // Use vfmaq_n_f32 to multiply each row vector by the corresponding scalar
    // activation.
    accum = vfmaq_n_f32(
        accum, w_row0, vgetq_lane_f32(a, 0)); // accum += w_row0 * a[0]
    accum = vfmaq_n_f32(
        accum, w_row1, vgetq_lane_f32(a, 1)); // accum += w_row1 * a[1]
    accum = vfmaq_n_f32(
        accum, w_row2, vgetq_lane_f32(a, 2)); // accum += w_row2 * a[2]
    accum = vfmaq_n_f32(
        accum, w_row3, vgetq_lane_f32(a, 3)); // accum += w_row3 * a[3]
  }
}

/**
 * @brief Stores the accumulated values to the output matrix.
 * @tparam mr_ The row-tiling factor of the micro-kernel.
 * @tparam nr_ The column-tiling factor of the micro-kernel.
 *
 * @param output The output matrix.
 * @param ldc The leading dimension of the output matrix.
 * @param n_cols The number of columns in the output matrix.
 * @param n_tile_start The starting column index of the current tile.
 * @param accum The accumulated values.
 * @param bias_ptr The pointer to the bias vector.
 * @param has_clamp Whether to apply clamping.
 * @param clamp_min_vec The minimum value for clamping.
 * @param clamp_max_vec The maximum value for clamping.
 */
template <int mr_, int nr_>
TORCHAO_ALWAYS_INLINE static inline void post_process_and_store(
    float* __restrict__ output,
    int ldc,
    int n_cols,
    int n_tile_start,
    const float32x4_t accum[mr_][nr_ / 4],
    const float* __restrict__ bias_ptr,
    bool has_clamp,
    const float32x4_t& clamp_min_vec,
    const float32x4_t& clamp_max_vec) {
  constexpr int NR_VEC = nr_ / 4;
  for (int m = 0; m < mr_; ++m) {
    float* out_row = output + m * ldc;
    for (int nb = 0; nb < NR_VEC; ++nb) {
      float32x4_t res = accum[m][nb];
      if (bias_ptr != nullptr) {
        float32x4_t bias_vec = vld1q_f32(bias_ptr + nb * 4);
        res = vaddq_f32(res, bias_vec);
      }
      if (has_clamp) {
        res = vmaxq_f32(res, clamp_min_vec);
        res = vminq_f32(res, clamp_max_vec);
      }

      const int current_n_offset = n_tile_start + nb * 4;
      const int remaining_cols = n_cols - current_n_offset;
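      // Tail handling: when fewer than 4 output columns remain, spill the
      // vector to a scratch array and copy only the valid columns.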
      if (remaining_cols < 4) {
        float temp_res[4];
        vst1q_f32(temp_res, res);
        for (int i = 0; i < remaining_cols; ++i) {
          *(out_row + current_n_offset + i) = temp_res[i];
        }
      } else {
        vst1q_f32(out_row + current_n_offset, res);
      }
    }
  }
}

} // namespace internal

/**
 * @brief The main kernel for groupwise low-bit weight LUT.
 */
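// Layout of the packed weight stream consumed below, per n-tile (column
// block): a 16-entry fp32 LUT, then for every kr_-sized k-tile an optional
// set of nr_ fp32 scales (emitted at scale_group_size boundaries when
// has_scales is set) followed by the bit-packed indices, and finally nr_
// fp32 bias values when has_bias is set.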
template <int weight_nbit_, bool has_scales>
void groupwise_lowbit_weight_lut_kernel_1x4x32(
    float* output,
    int output_m_stride,
    int m,
    int n,
    int k,
    int scale_group_size,
    int lut_group_size,
    const void* packed_weights,
    const void* packed_activations,
    float clamp_min,
    float clamp_max,
    bool has_bias,
    bool has_clamp) {
  constexpr int mr_ = 1;
  constexpr int nr_ = 4;
  constexpr int kr_ = 32;

  const auto* typed_activations_ptr =
      static_cast<const float*>(packed_activations);
  const float32x4_t clamp_min_vec = vdupq_n_f32(clamp_min);
  const float32x4_t clamp_max_vec = vdupq_n_f32(clamp_max);
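  // nr_ * kr_ indices at weight_nbit_ bits each, rounded up to whole bytes.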
  constexpr int bytes_per_weight_tile = ((nr_ * kr_ * weight_nbit_) + 7) / 8;

  for (int m_tile_start = 0; m_tile_start < m; m_tile_start += mr_) {
    const float* activation_row_ptr = typed_activations_ptr + m_tile_start * k;
    const uint8_t* packed_ptr = static_cast<const uint8_t*>(packed_weights);

    for (int n_tile_start = 0; n_tile_start < n; n_tile_start += nr_) {
      float32x4_t accumulators[mr_][nr_ / 4] = {{vdupq_n_f32(0.0f)}};

      uint8x16x4_t lut_neon;
      // Load the 16-float LUT for this tile.
      lut_utils::load_fp32_lut(
          lut_neon, reinterpret_cast<const float*>(packed_ptr));
      // Advance the pointer past the LUT.
      packed_ptr += 16 * sizeof(float);
      float32x4_t scale_vec = vdupq_n_f32(1.0f);
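      // Fresh per-column scales are loaded whenever the k-tile crosses a
      // scale_group_size boundary; otherwise the previous scale_vec (or the
      // identity scale of 1.0f) is reused.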
      for (int k_tile_start = 0; k_tile_start < k; k_tile_start += kr_) {
        if constexpr (has_scales) {
          const float* scale_for_tile = nullptr;

          if (k_tile_start % scale_group_size == 0) {
            scale_for_tile = reinterpret_cast<const float*>(packed_ptr);
            scale_vec = vld1q_f32(scale_for_tile);
            packed_ptr += nr_ * sizeof(float);
          }
        }

        // The current packed_ptr points to the weight indices.
        const uint8_t* indices_ptr = packed_ptr;

        internal::compute_tile_1x4x32<weight_nbit_, has_scales>(
            accumulators[0][0],
            activation_row_ptr + k_tile_start,
            indices_ptr,
            lut_neon,
            scale_vec);

        // Advance pointer past the weights that were just used.
        packed_ptr += bytes_per_weight_tile;
      }

      const float* bias_for_tile = nullptr;
      if (has_bias) {
        bias_for_tile = reinterpret_cast<const float*>(packed_ptr);
        packed_ptr += nr_ * sizeof(float);
      }

      float* output_row_ptr = output + m_tile_start * output_m_stride;
      internal::post_process_and_store<mr_, nr_>(
          output_row_ptr,
          output_m_stride,
          n,
          n_tile_start,
          accumulators,
          bias_for_tile,
          has_clamp,
          clamp_min_vec,
          clamp_max_vec);
    }
  }
}
} // namespace
// torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut::kernel
#endif // defined(__aarch64__) || defined(__ARM_NEON)
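
For reference, a minimal sketch of how this kernel template might be invoked. Only the template parameters and the function signature come from the diff above; the bit width, group sizes, clamp range, and buffer handling are illustrative assumptions, and the packed buffers are assumed to have been produced by the matching packing routines in pack_weights.h (not part of this commit). Note that lut_group_size is accepted by the signature but is not read inside this micro-kernel body.

```cpp
// Hypothetical driver; names and values below are assumptions for
// illustration. Assumes the header introduced by this commit is included.
#include <vector>

void run_lut_kernel_example(
    const void* packed_weights,      // assumed packed in this kernel's layout
    const void* packed_activations,  // assumed row-major m x k fp32
    int m, int n, int k) {           // k assumed to be a multiple of kr_ = 32
  namespace kernel_ns = torchao::kernels::cpu::aarch64::linear::
      groupwise_lowbit_weight_lut::kernel;

  constexpr int weight_nbit = 4;    // assumed index bit width
  constexpr bool has_scales = true;

  std::vector<float> out(static_cast<size_t>(m) * n);
  kernel_ns::groupwise_lowbit_weight_lut_kernel_1x4x32<weight_nbit, has_scales>(
      out.data(),
      /*output_m_stride=*/n,
      m,
      n,
      k,
      /*scale_group_size=*/32,  // assumed
      /*lut_group_size=*/k,     // placeholder; unused in this kernel body
      packed_weights,
      packed_activations,
      /*clamp_min=*/0.0f,
      /*clamp_max=*/6.0f,
      /*has_bias=*/true,
      /*has_clamp=*/true);
}
```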
