#pragma once

#if defined(__aarch64__) || defined(__ARM_NEON)

#include <arm_neon.h>
#include <stddef.h>
#include <cassert>
#include <stdexcept>

#include <torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/kernel_f32-impl.h>
#include <torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/pack_activations.h>
#include <torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/pack_weights.h>

namespace torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut {

/**
 * @brief Calculates the total size in bytes required for the packed
 * activation buffer.
 *
 * @param m The number of rows in the source activation matrix.
 * @param k The number of columns in the source activation matrix.
 * @param mr The row-tiling factor of the micro-kernel.
 * @param kr The column-tiling factor of the micro-kernel.
 * @param sr The split ratio of the micro-kernel.
 */
inline size_t packed_activations_size(int m, int k, int mr, int kr, int sr) {
  (void)mr; // unused
  (void)kr; // unused
  (void)sr; // unused
  return activation_packing::packed_activations_size(m, k);
}

/**
 * @brief Packs a row-major activation matrix into a kernel-optimized blocked
 * layout.
 *
 * @tparam mr_ The row-tiling factor of the micro-kernel (currently only 1 is
 * supported).
 * @tparam kr_ The column-tiling factor of the micro-kernel (e.g., 32).
 * @tparam sr_ Split ratio that determines how the K dimension of a tile is
 * chunked and interleaved during the packing process.
 * @param output Pointer to the destination buffer.
 * @param m The number of rows in the source activation matrix.
 * @param k The number of columns in the source activation matrix.
 * @param input Pointer to the source activation matrix (float32, row-major).
 */
template <int mr_, int kr_, int sr_>
inline void pack_activations(float* output, int m, int k, const float* input) {
  activation_packing::pack_activations<mr_, kr_, sr_>(output, m, k, input);
}

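// Example usage (illustrative sketch, not part of this header): packing an
// m x k fp32 activation matrix. The tile parameters mr=1, kr=32, sr=8 are
// assumptions and must match the target micro-kernel configuration; <vector>
// is assumed to be included by the caller.
//
//   const int m = 4, k = 64;
//   std::vector<float> activations(m * k);  // row-major fp32 input
//   std::vector<char> packed(
//       packed_activations_size(m, k, /*mr=*/1, /*kr=*/32, /*sr=*/8));
//   pack_activations</*mr_=*/1, /*kr_=*/32, /*sr_=*/8>(
//       reinterpret_cast<float*>(packed.data()), m, k, activations.data());
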
/**
 * @brief Calculates the total size in bytes required for the packed weight
 * buffer for the groupwise LUT kernel format.
 *
 * @param n The number of columns in the weight matrix.
 * @param k The number of rows in the weight matrix.
 * @param weight_nbit The number of bits per weight (e.g., 2, 3, 4).
 * @param scale_group_size The number of weights along the K dim that share a
 * scale factor.
 * @param has_scales If true, the packed buffer will contain scale factors.
 * @param has_bias If true, the packed buffer will contain bias terms.
 * @param nr The column-tiling factor for the kernel (e.g., 16).
 * @param kr The K-dimension tiling factor for the kernel (e.g., 16).
 * @param sr The split ratio of the micro-kernel.
 * @return The total required size of the packed buffer in bytes.
 */
inline size_t packed_weights_size(
    int n,
    int k,
    int weight_nbit,
    int scale_group_size,
    bool has_scales,
    bool has_bias,
    int nr,
    int kr,
    int sr) {
  (void)sr; // unused
  return weight_packing::packed_weights_size(
      n, k, weight_nbit, scale_group_size, has_scales, has_bias, nr, kr);
}

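// Example usage (illustrative sketch): sizing the packed weight buffer for a
// 4-bit n=128 x k=256 weight matrix with 32-wide scale groups. The tile
// parameters nr=4, kr=32, sr=8 are assumptions chosen to match the 1x4x32
// kernel below.
//
//   const size_t packed_bytes = packed_weights_size(
//       /*n=*/128, /*k=*/256, /*weight_nbit=*/4, /*scale_group_size=*/32,
//       /*has_scales=*/true, /*has_bias=*/true, /*nr=*/4, /*kr=*/32, /*sr=*/8);
//   std::vector<char> packed_weights(packed_bytes);
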
/**
 * @brief Packs weights, LUTs, scales and bias into a kernel-optimized format.
 * @tparam weight_nbit_ The true bit-width of the weights.
 * @tparam nr_ The column-tiling factor for the kernel (e.g., 4).
 * @tparam kr_ The K-dimension tiling factor of the micro-kernel (e.g., 32).
 * @tparam sr_ Split ratio that determines how the K dimension of a weight tile
 * is chunked and interleaved during the packing process.
 * @param packed_weights_ptr Pointer to the destination buffer.
 * @param weight_qvals_indices Pointer to the quantized weight LUT indices
 * (uint8, row-major).
 * @param weight_scales Pointer to the scale factors (float32, row-major).
 * @param weight_luts Pointer to the LUTs (float32, row-major).
 * @param n The number of columns in the weight matrix.
 * @param k The number of rows in the weight matrix.
 * @param scale_group_size The number of weights that share a scale factor.
 * @param lut_group_size The number of weights that share a LUT.
 * @param has_scales If true, the packed buffer will contain scale factors.
 * @param has_bias If true, the packed buffer will contain bias terms.
 * @param bias Pointer to the bias vector (float32).
 */
template <int weight_nbit_, int nr_, int kr_, int sr_>
void pack_weights_for_groupwise_lut_kernel(
    /*output*/
    void* packed_weights_ptr,
    /*inputs*/
    const uint8_t* weight_qvals_indices,
    const float* weight_scales,
    const float* weight_luts,
    int n,
    int k,
    int scale_group_size,
    int lut_group_size,
    bool has_scales,
    bool has_bias,
    const float* bias) {
  weight_packing::pack_weights<weight_nbit_, nr_, kr_, sr_>(
      packed_weights_ptr,
      weight_qvals_indices,
      weight_scales,
      weight_luts,
      n,
      k,
      scale_group_size,
      lut_group_size,
      has_scales,
      has_bias,
      bias);
}

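// Example usage (illustrative sketch): packing 4-bit LUT-quantized weights
// into the buffer sized above. All shapes, group sizes, and buffer layouts in
// the comments are assumptions for illustration; each LUT is assumed to hold
// 2^weight_nbit fp32 entries.
//
//   pack_weights_for_groupwise_lut_kernel</*weight_nbit_=*/4, /*nr_=*/4,
//                                         /*kr_=*/32, /*sr_=*/8>(
//       packed_weights.data(),
//       weight_qvals_indices.data(),  // uint8_t[n * k] LUT indices
//       weight_scales.data(),         // float[(n * k) / scale_group_size]
//       weight_luts.data(),           // float[num_luts * 16], 16 = 2^4 per LUT
//       /*n=*/128, /*k=*/256,
//       /*scale_group_size=*/32, /*lut_group_size=*/256,
//       /*has_scales=*/true, /*has_bias=*/true,
//       bias.data());                 // float[n]
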
/**
 * @brief Computes a group-wise low-bit GEMM using an optimized NEON kernel.
 *
 * This function dispatches to the NEON micro-kernel specialized for a
 * 1x4x32 (MR x NR x KR) tiling.
 * @tparam weight_nbit_ The true bit-width of the weights (e.g., 2, 3, 4).
 * @tparam has_scales_ If true, applies the scales.
 * @param output Pointer to the output matrix C.
 * @param output_m_stride The stride (in elements) between rows of the output
 * matrix.
 * @param m Number of rows in A and C.
 * @param n Number of columns in B and C.
 * @param k Number of columns in A and rows in B.
 * @param scale_group_size The grouping factor for scales.
 * @param lut_group_size The grouping factor for LUTs.
 * @param packed_weights Pointer to the pre-packed weight buffer.
 * @param packed_activations Pointer to the pre-packed activation buffer.
 * @param clamp_min Minimum value for the fused clamp (ReLU) operation.
 * @param clamp_max Maximum value for the fused clamp (ReLU6) operation.
 * @param has_bias If true, applies the bias stored in the packed weight
 * buffer.
 * @param has_clamp If true, applies the clamping.
 */
template <int weight_nbit_, bool has_scales_>
inline void groupwise_lowbit_weight_lut_kernel_1x4x32(
    float* output,
    int output_m_stride,
    int m,
    int n,
    int k,
    int scale_group_size,
    int lut_group_size,
    const void* packed_weights,
    const void* packed_activations,
    float clamp_min,
    float clamp_max,
    bool has_bias,
    bool has_clamp) {
  kernel::groupwise_lowbit_weight_lut_kernel_1x4x32<weight_nbit_, has_scales_>(
      output,
      output_m_stride,
      m,
      n,
      k,
      scale_group_size,
      lut_group_size,
      packed_weights,
      packed_activations,
      clamp_min,
      clamp_max,
      has_bias,
      has_clamp);
}
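
// Example usage (illustrative sketch): running the 1x4x32 kernel on buffers
// prepared with the packing helpers above. Template arguments, group sizes,
// and clamp values are assumptions and must match those used at packing time.
//
//   std::vector<float> output(m * n);
//   groupwise_lowbit_weight_lut_kernel_1x4x32</*weight_nbit_=*/4,
//                                             /*has_scales_=*/true>(
//       output.data(), /*output_m_stride=*/n, m, n, k,
//       scale_group_size, lut_group_size,
//       packed_weights.data(), packed.data() /*packed activations*/,
//       /*clamp_min=*/0.0f, /*clamp_max=*/6.0f,
//       /*has_bias=*/true, /*has_clamp=*/true);
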
} // namespace
  // torchao::kernels::cpu::aarch64::linear::groupwise_lowbit_weight_lut

#endif // defined(__aarch64__) || defined(__ARM_NEON)