Skip to content

Commit 9da7ad5

Browse files
authored
Update function params and corresponding usages.
Differential Revision: D78056221
Pull Request resolved: #2524
1 parent 0b62f3f commit 9da7ad5

File tree

3 files changed

+30
-16
lines changed

3 files changed

+30
-16
lines changed

torchao/experimental/kernels/cpu/aarch64/linear/groupwise_lowbit_weight/groupwise_lowbit_weight_lut.h

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,17 @@ chunked and interleaved during the packing process.
4444
* @param input Pointer to the source activation matrix (float32, row-major).
4545
*/
4646
template <int mr_, int kr_, int sr_>
47-
inline void pack_activations(float* output, int m, int k, const float* input) {
47+
inline void pack_activations(
48+
float* output,
49+
int m,
50+
int k,
51+
const float* input,
52+
int mr,
53+
int kr,
54+
int sr) {
55+
(void)mr; // unused
56+
(void)kr; // unused
57+
(void)sr; // unused
4858
activation_packing::pack_activations<mr_, kr_, sr_>(output, m, k, input);
4959
}
5060

@@ -100,7 +110,7 @@ row-major).
100110
* @param bias Pointer to the bias vector (float32, row-major).
101111
*/
102112
template <int weight_nbit_, int nr_, int kr_, int sr_>
103-
void pack_weights_for_groupwise_lut_kernel(
113+
void pack_weights(
104114
/*output*/
105115
void* packed_weights_ptr,
106116
/*inputs*/
@@ -113,7 +123,14 @@ void pack_weights_for_groupwise_lut_kernel(
113123
int lut_group_size,
114124
bool has_scales,
115125
bool has_bias,
116-
const float* bias) {
126+
const float* bias,
127+
int nr,
128+
int kr,
129+
int sr) {
130+
(void)nr; // unused
131+
(void)kr; // unused
132+
(void)sr; // unused
133+
117134
weight_packing::pack_weights<weight_nbit_, nr_, kr_, sr_>(
118135
packed_weights_ptr,
119136
weight_qvals_indices,
@@ -190,7 +207,12 @@ inline void groupwise_lowbit_weight_lut_kernel_1x4x32(
190207
* @param k The K dimension (width) of the activation matrix.
191208
* @return The byte offset from the start of the buffer.
192209
*/
193-
inline size_t packed_activations_offset(int m_idx, int k) {
210+
inline size_t
211+
packed_activations_offset(int m_idx, int k, int mr, int kr, int sr) {
212+
(void)mr; // unused
213+
(void)kr; // unused
214+
(void)sr; // unused
215+
194216
// For a simple padded row-major format, the offset is just m_idx * k.
195217
return sizeof(float) * m_idx * k;
196218
}

torchao/experimental/kernels/cpu/aarch64/tests/test_lut.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ void test_groupwise_lowbit_lut_kernel(
7171
std::vector<float> packed_activations_buffer(
7272
kernel_api::packed_activations_size(m, k, mr_, kr_, sr_));
7373
kernel_api::pack_activations<mr_, kr_, sr_>(
74-
packed_activations_buffer.data(), m, k, source_activations.data());
74+
packed_activations_buffer.data(), m, k, source_activations.data(), mr_, kr_, sr_);
7575
// 3. Pack Weights
7676
std::vector<char> packed_weights(kernel_api::packed_weights_size(
7777
n,
@@ -84,7 +84,7 @@ void test_groupwise_lowbit_lut_kernel(
8484
kr_,
8585
sr_));
8686
kernel_api::
87-
pack_weights_for_groupwise_lut_kernel<weight_nbit_, nr_, kr_, sr_>(
87+
pack_weights<weight_nbit_, nr_, kr_, sr_>(
8888
packed_weights.data(),
8989
test_case.weight_qval_indices.data(),
9090
test_case.weight_scales.data(),
@@ -95,7 +95,7 @@ void test_groupwise_lowbit_lut_kernel(
9595
flat_lut_group_size,
9696
has_scales_,
9797
has_bias,
98-
test_case.bias.data());
98+
test_case.bias.data(), nr_, kr_, sr_);
9999

100100
// 4. Run the kernel
101101
std::vector<float> output(m * n);

torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -640,11 +640,10 @@ struct groupwise_lowbit_weight_lut_test_case {
640640
const int total_weights = n * k;
641641
// Frequencies are controlled by their group sizes.
642642
assert(total_weights % scale_group_size == 0);
643-
assert(total_weights % lut_group_size == 0);
644643

645644
// The number of unique scales/LUTs is derived directly from their group size.
646645
const int num_scales = total_weights / scale_group_size;
647-
const int num_luts = total_weights / lut_group_size;
646+
const int num_luts = (total_weights + lut_group_size - 1) / lut_group_size;
648647
const int lut_size = 1 << weight_nbit;
649648
std::mt19937 gen(std::random_device{}());
650649

@@ -726,9 +725,6 @@ struct groupwise_lowbit_weight_lut_test_case {
726725
int weight_nbit, bool has_scales,
727726
bool has_bias, bool has_clamp) {
728727

729-
std::cout << "[Generator Info] Using 'Per-Group' model.\n"
730-
<< " - Both scales and LUTs will switch every " << group_size << " weights." << std::endl;
731-
732728
// Just call the decoupled generator with the same group size for both.
733729
return _generate_master(
734730
m, k, n,
@@ -748,10 +744,6 @@ struct groupwise_lowbit_weight_lut_test_case {
748744
int scale_group_size, int lut_group_size, int weight_nbit, bool has_scales,
749745
bool has_bias, bool has_clamp) {
750746

751-
std::cout << "[Generator Info] Using 'Decoupled Grouping' model.\n"
752-
<< " - Scales will switch every " << scale_group_size << " weights.\n"
753-
<< " - LUTs will switch every " << lut_group_size << " weights." << std::endl;
754-
755747
return _generate_master(
756748
m, k, n,
757749
scale_group_size, lut_group_size,

0 commit comments

Comments
 (0)