// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#pragma once

#if defined(__aarch64__) || defined(__ARM_NEON)
#include <arm_neon.h>
// bitpack.h provides bitpacking::vec_unpack_128_uintx_values and macro.h
// provides TORCHAO_ALWAYS_INLINE, both of which are used below.
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
#include <torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/pack_weights.h>
#include <torchao/experimental/kernels/cpu/aarch64/lut/lut.h>
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>
#include <array>
#include <cassert>
#include <cstring>

namespace torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut::
    kernel {

namespace lut_utils = torchao::lut;
namespace weight_packing = torchao::kernels::cpu::aarch64::linear::
    groupwise_lowbit_weight_lut::weight_packing;

namespace internal {

/**
 * @brief Computes a single tile of the output matrix.
 * @tparam weight_nbit_ The bit-precision of the quantized weight indices.
 * @tparam has_scales A compile-time flag to enable the application of scales.
 *
 * @param accum A NEON vector of 4 floats used as an in-out accumulator.
 * @param activation_tile_ptr Pointer to the 32-float activation tile.
 * @param packed_indices_ptr Pointer to the bit-packed weight indices.
 * @param lut_neon The dequantization LUT, pre-formatted for NEON lookups.
 * @param scale_vec A NEON vector with the four dequantization scales.
 */
template <int weight_nbit_, bool has_scales>
TORCHAO_ALWAYS_INLINE static inline void compute_tile_1x4x32(
    float32x4_t& accum,
    const float* __restrict__ activation_tile_ptr,
    const uint8_t* __restrict__ packed_indices_ptr,
    const uint8x16x4_t& lut_neon,
    const float32x4_t scale_vec) {
  // 1. Unpack indices
  uint8x16_t idx0, idx1, idx2, idx3, idx4, idx5, idx6, idx7;
  bitpacking::vec_unpack_128_uintx_values<weight_nbit_>(
      idx0, idx1, idx2, idx3, idx4, idx5, idx6, idx7, packed_indices_ptr);

  const std::array<uint8x16_t, 8> unpacked_indices = {
      idx0, idx1, idx2, idx3, idx4, idx5, idx6, idx7};

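  // 2. Accumulate over the 8 chunks. The 128 unpacked indices cover the full
  // kr x nr = 32 x 4 tile; each uint8x16_t chunk selects one 4x4 sub-tile
  // (4 consecutive k-values across the 4 output columns), which the lookup
  // below yields as four column vectors.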
  for (int sr_idx = 0; sr_idx < 8; ++sr_idx) {
    // Load the 4 activations corresponding to this chunk
    const float* activation_chunk_ptr = activation_tile_ptr + sr_idx * 4;
    float32x4_t a = vld1q_f32(activation_chunk_ptr);

    // Lookup the 4x4 weight sub-tile (as columns)
    float32x4_t w_col0, w_col1, w_col2, w_col3;
    lut_utils::lookup_from_fp32_lut(
        w_col0, w_col1, w_col2, w_col3, lut_neon, unpacked_indices[sr_idx]);
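
    // Transpose the 4x4 sub-tile from columns to rows: vtrnq_f32 interleaves
    // lane pairs and vcombine_f32 stitches the halves, so w_row_i holds the
    // weights for k-value i across the 4 output columns.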
    float32x4x2_t tmp0 = vtrnq_f32(w_col0, w_col1);
    float32x4x2_t tmp1 = vtrnq_f32(w_col2, w_col3);
    float32x4_t w_row0 =
        vcombine_f32(vget_low_f32(tmp0.val[0]), vget_low_f32(tmp1.val[0]));
    float32x4_t w_row1 =
        vcombine_f32(vget_low_f32(tmp0.val[1]), vget_low_f32(tmp1.val[1]));
    float32x4_t w_row2 =
        vcombine_f32(vget_high_f32(tmp0.val[0]), vget_high_f32(tmp1.val[0]));
    float32x4_t w_row3 =
        vcombine_f32(vget_high_f32(tmp0.val[1]), vget_high_f32(tmp1.val[1]));

    // Conditionally apply scales at compile time
    if constexpr (has_scales) {
      w_row0 = vmulq_f32(w_row0, scale_vec);
      w_row1 = vmulq_f32(w_row1, scale_vec);
      w_row2 = vmulq_f32(w_row2, scale_vec);
      w_row3 = vmulq_f32(w_row3, scale_vec);
    }

    // Use vfmaq_n_f32 to multiply each row vector by the corresponding scalar
    // activation.
    accum = vfmaq_n_f32(
        accum, w_row0, vgetq_lane_f32(a, 0)); // accum += w_row0 * a[0]
    accum = vfmaq_n_f32(
        accum, w_row1, vgetq_lane_f32(a, 1)); // accum += w_row1 * a[1]
    accum = vfmaq_n_f32(
        accum, w_row2, vgetq_lane_f32(a, 2)); // accum += w_row2 * a[2]
    accum = vfmaq_n_f32(
        accum, w_row3, vgetq_lane_f32(a, 3)); // accum += w_row3 * a[3]
  }
}

/**
 * @brief Stores the accumulated values to the output matrix.
 * @tparam mr_ The row-tiling factor of the micro-kernel.
 * @tparam nr_ The column-tiling factor of the micro-kernel.
 *
 * @param output The output matrix.
 * @param ldc The leading dimension of the output matrix.
 * @param n_cols The number of columns in the output matrix.
 * @param n_tile_start The starting column index of the current tile.
 * @param accum The accumulated values.
 * @param bias_ptr The pointer to the bias vector.
 * @param has_clamp Whether to apply clamping.
 * @param clamp_min_vec The minimum value for clamping.
 * @param clamp_max_vec The maximum value for clamping.
 */
template <int mr_, int nr_>
TORCHAO_ALWAYS_INLINE static inline void post_process_and_store(
    float* __restrict__ output,
    int ldc,
    int n_cols,
    int n_tile_start,
    const float32x4_t accum[mr_][nr_ / 4],
    const float* __restrict__ bias_ptr,
    bool has_clamp,
    const float32x4_t& clamp_min_vec,
    const float32x4_t& clamp_max_vec) {
  constexpr int NR_VEC = nr_ / 4;
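  // The nr_ output columns are processed as NR_VEC vectors of 4 floats each.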
  for (int m = 0; m < mr_; ++m) {
    float* out_row = output + m * ldc;
    for (int nb = 0; nb < NR_VEC; ++nb) {
      float32x4_t res = accum[m][nb];
      if (bias_ptr != nullptr) {
        float32x4_t bias_vec = vld1q_f32(bias_ptr + nb * 4);
        res = vaddq_f32(res, bias_vec);
      }
      if (has_clamp) {
        res = vmaxq_f32(res, clamp_min_vec);
        res = vminq_f32(res, clamp_max_vec);
      }

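      // Ragged edge: when fewer than 4 columns remain (n_cols not a multiple
      // of 4), spill the vector to a scratch array and store scalars.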
      const int current_n_offset = n_tile_start + nb * 4;
      const int remaining_cols = n_cols - current_n_offset;
      if (remaining_cols < 4) {
        float temp_res[4];
        vst1q_f32(temp_res, res);
        for (int i = 0; i < remaining_cols; ++i) {
          *(out_row + current_n_offset + i) = temp_res[i];
        }
      } else {
        vst1q_f32(out_row + current_n_offset, res);
      }
    }
  }
}

} // namespace internal

/**
 * @brief The main kernel for groupwise low-bit weight LUT dequantization and
 * matmul, tiled as 1x4x32 (mr_ x nr_ x kr_).
 */
template <int weight_nbit_, bool has_scales>
void groupwise_lowbit_weight_lut_kernel_1x4x32(
    float* output,
    int output_m_stride,
    int m,
    int n,
    int k,
    int scale_group_size,
    int lut_group_size,
    const void* packed_weights,
    const void* packed_activations,
    float clamp_min,
    float clamp_max,
    bool has_bias,
    bool has_clamp) {
  constexpr int mr_ = 1;
  constexpr int nr_ = 4;
  constexpr int kr_ = 32;

  const auto* typed_activations_ptr =
      static_cast<const float*>(packed_activations);
  const float32x4_t clamp_min_vec = vdupq_n_f32(clamp_min);
  const float32x4_t clamp_max_vec = vdupq_n_f32(clamp_max);
  constexpr int bytes_per_weight_tile = ((nr_ * kr_ * weight_nbit_) + 7) / 8;
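  // nr_ * kr_ = 128 indices at weight_nbit_ bits each, rounded up to whole
  // bytes; e.g. 16 bytes/tile for 1-bit weights, 64 bytes/tile for 4-bit.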

  for (int m_tile_start = 0; m_tile_start < m; m_tile_start += mr_) {
    const float* activation_row_ptr = typed_activations_ptr + m_tile_start * k;
    const uint8_t* packed_ptr = static_cast<const uint8_t*>(packed_weights);

    for (int n_tile_start = 0; n_tile_start < n; n_tile_start += nr_) {
      float32x4_t accumulators[mr_][nr_ / 4] = {{vdupq_n_f32(0.0f)}};

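      // Per-n-tile packed layout (as produced by the companion pack_weights
      // routine; inferred here from the pointer arithmetic below):
      //   [16-float LUT]
      //   then, per k-tile: [nr_ float scales, at each scale-group boundary,
      //   only if has_scales][bytes_per_weight_tile packed indices]
      //   finally: [nr_ float biases, only if has_bias]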
      uint8x16x4_t lut_neon;
      // Load the 16-float LUT for this tile.
      lut_utils::load_fp32_lut(
          lut_neon, reinterpret_cast<const float*>(packed_ptr));
      // Advance the pointer past the LUT.
      packed_ptr += 16 * sizeof(float);
      float32x4_t scale_vec = vdupq_n_f32(1.0f);
      for (int k_tile_start = 0; k_tile_start < k; k_tile_start += kr_) {
        if constexpr (has_scales) {
          const float* scale_for_tile = nullptr;

          if (k_tile_start % scale_group_size == 0) {
            scale_for_tile = reinterpret_cast<const float*>(packed_ptr);
            scale_vec = vld1q_f32(scale_for_tile);
            packed_ptr += nr_ * sizeof(float);
          }
        }

        // The current packed_ptr points to the weight indices.
        const uint8_t* indices_ptr = packed_ptr;

        internal::compute_tile_1x4x32<weight_nbit_, has_scales>(
            accumulators[0][0],
            activation_row_ptr + k_tile_start,
            indices_ptr,
            lut_neon,
            scale_vec);

        // Advance pointer past the weights that were just used.
        packed_ptr += bytes_per_weight_tile;
      }

      const float* bias_for_tile = nullptr;
      if (has_bias) {
        bias_for_tile = reinterpret_cast<const float*>(packed_ptr);
        packed_ptr += nr_ * sizeof(float);
      }

      float* output_row_ptr = output + m_tile_start * output_m_stride;
      internal::post_process_and_store<mr_, nr_>(
          output_row_ptr,
          output_m_stride,
          n,
          n_tile_start,
          accumulators,
          bias_for_tile,
          has_clamp,
          clamp_min_vec,
          clamp_max_vec);
    }
  }
}
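
// Illustrative call sketch (hypothetical buffers and shapes; the packed
// buffers must come from the matching packing routines, and the group sizes
// must agree with how the weights were packed):
//
//   groupwise_lowbit_weight_lut_kernel_1x4x32<4, /*has_scales=*/true>(
//       output, /*output_m_stride=*/n, m, n, k,
//       scale_group_size, lut_group_size,
//       packed_weights, packed_activations,
//       /*clamp_min=*/0.0f, /*clamp_max=*/6.0f,
//       /*has_bias=*/true, /*has_clamp=*/true);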
} // namespace
  // torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut::kernel
#endif // defined(__aarch64__) || defined(__ARM_NEON)