|
5 | 5 | #include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
|
6 | 6 | #include <torchao/experimental/kernels/cpu/aarch64/macro.h>
|
7 | 7 | #include <torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h>
|
| 8 | +#include <torchao/experimental/kernels/cpu/aarch64/packing/utils.h> |
8 | 9 | #include <array>
|
9 | 10 | #include <cstring>
|
10 | 11 |
|
@@ -125,61 +126,6 @@ TORCHAO_ALWAYS_INLINE inline void unpack_buffer(
|
125 | 126 | assert(false);
|
126 | 127 | }
|
127 | 128 |
|
128 |
| -// Packs nr * kr values for GEMM with packing params (nr, kr, sr) |
129 |
| -// It takes (kr / sr) values from each of nr columns and writes to packed_values |
130 |
| -// This is repeated sr times |
131 |
| -template <typename T> |
132 |
| -void pack_values( |
133 |
| - // Output |
134 |
| - T* packed_values, |
135 |
| - // Inputs |
136 |
| - const T* values, |
137 |
| - int nr, |
138 |
| - int kr, |
139 |
| - int sr) { |
140 |
| - assert(kr % sr == 0); |
141 |
| - int kr_per_sr = kr / sr; |
142 |
| - int dst_idx = 0; |
143 |
| - for (int sr_idx = 0; sr_idx < sr; sr_idx++) { |
144 |
| - for (int n_idx = 0; n_idx < nr; n_idx++) { |
145 |
| - // Take kr_per_sr values from column n_idx |
146 |
| - std::memcpy( |
147 |
| - packed_values + dst_idx, |
148 |
| - values + n_idx * kr + sr_idx * kr_per_sr, |
149 |
| - sizeof(T) * kr_per_sr); |
150 |
| - dst_idx += kr_per_sr; |
151 |
| - } |
152 |
| - } |
153 |
| -} |
154 |
| - |
155 |
| -// Undoes pack_values |
156 |
| -template <typename T> |
157 |
| -void unpack_values( |
158 |
| - // Output |
159 |
| - T* values, |
160 |
| - // Inputs |
161 |
| - const T* packed_values, |
162 |
| - int nr, |
163 |
| - int kr, |
164 |
| - int sr) { |
165 |
| - // packed_values and values should have size nr * kr |
166 |
| - // This function takes (kr / sr) from each column of nr columns and writes to |
167 |
| - // output This is repeated sr times |
168 |
| - assert(kr % sr == 0); |
169 |
| - int kr_per_sr = kr / sr; |
170 |
| - int dst_idx = 0; |
171 |
| - for (int sr_idx = 0; sr_idx < sr; sr_idx++) { |
172 |
| - for (int n_idx = 0; n_idx < nr; n_idx++) { |
173 |
| - // Take kr_per_sr values from column n_idx |
174 |
| - std::memcpy( |
175 |
| - values + n_idx * kr + sr_idx * kr_per_sr, |
176 |
| - packed_values + dst_idx, |
177 |
| - sizeof(T) * kr_per_sr); |
178 |
| - dst_idx += kr_per_sr; |
179 |
| - } |
180 |
| - } |
181 |
| -} |
182 |
| - |
183 | 129 | // Size in bytes of 1 packed weights column
|
184 | 130 | size_t inline packed_weights_size_per_n(
|
185 | 131 | int k,
|
@@ -344,7 +290,7 @@ TORCHAO_ALWAYS_INLINE inline void pack_weights_impl(
|
344 | 290 | }
|
345 | 291 |
|
346 | 292 | // Pack buffer
|
347 |
| - internal::pack_values(packed_values, buffer.data(), nr, kr, sr); |
| 293 | + torchao::packing::pack_values(packed_values, buffer.data(), nr, kr, sr); |
348 | 294 | if constexpr (has_lut) {
|
349 | 295 | internal::pack_buffer_for_lut<weight_nbit, kr, nr>(
|
350 | 296 | packed_weights_byte_ptr, packed_values);
|
@@ -498,7 +444,7 @@ void unpack_weights_at_n_idx(
|
498 | 444 | internal::unpack_buffer<weight_nbit, kr, nr>(
|
499 | 445 | packed_values, packed_weights_byte_ptr);
|
500 | 446 | packed_weights_byte_ptr += packed_buffer_bytes;
|
501 |
| - internal::unpack_values(buffer.data(), packed_values, nr, kr, sr); |
| 447 | + torchao::packing::unpack_values(buffer.data(), packed_values, nr, kr, sr); |
502 | 448 |
|
503 | 449 | // Write weight_qvals
|
504 | 450 | for (int j = 0; j < nr; j++) {
|
|
0 commit comments