Skip to content

Commit 773b8f6

Browse files
authored
Move pack functions to general location to share code
Differential Revision: D77040219 Pull Request resolved: #2480
1 parent 2defe30 commit 773b8f6

File tree

2 files changed

+70
-57
lines changed
  • torchao/experimental/kernels/cpu/aarch64

2 files changed

+70
-57
lines changed

torchao/experimental/kernels/cpu/aarch64/linear/channelwise_8bit_activation_groupwise_lowbit_weight/pack_weights.h

Lines changed: 3 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
66
#include <torchao/experimental/kernels/cpu/aarch64/macro.h>
77
#include <torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h>
8+
#include <torchao/experimental/kernels/cpu/aarch64/packing/utils.h>
89
#include <array>
910
#include <cstring>
1011

@@ -125,61 +126,6 @@ TORCHAO_ALWAYS_INLINE inline void unpack_buffer(
125126
assert(false);
126127
}
127128

128-
// Packs an nr x kr tile of values for GEMM with packing params (nr, kr, sr).
//
// `values` is read as nr columns of kr contiguous elements each (column n
// starts at values + n * kr). Every column is split into sr chunks of
// (kr / sr) elements. The output is written chunk-major: for each chunk index
// in [0, sr), that chunk of column 0, column 1, ..., column nr - 1 is appended
// to `packed_values`.
//
// Preconditions (checked with assert): nr >= 0, kr >= 0, sr > 0 and
// kr % sr == 0. Elements are copied with std::memcpy, so the two buffers must
// not overlap and T is expected to be trivially copyable.
template <typename T>
void pack_values(
    // Output
    T* packed_values,
    // Inputs
    const T* values,
    int nr,
    int kr,
    int sr) {
  assert(nr >= 0 && kr >= 0);
  // sr is tested first so the assert expression itself never divides by zero.
  assert(sr > 0 && kr % sr == 0);
  if (nr <= 0 || kr <= 0 || sr <= 0) {
    // Nothing to copy. This also keeps NDEBUG builds from dividing by zero or
    // passing a wrapped-around byte count to memcpy.
    return;
  }

  // Offsets and byte counts are kept in size_t so they cannot overflow int.
  const std::size_t num_cols = static_cast<std::size_t>(nr);
  const std::size_t col_len = static_cast<std::size_t>(kr);
  const std::size_t num_chunks = static_cast<std::size_t>(sr);
  const std::size_t chunk_len = col_len / num_chunks;
  const std::size_t chunk_bytes = sizeof(T) * chunk_len;

  T* dst = packed_values;
  for (std::size_t chunk = 0; chunk < num_chunks; ++chunk) {
    // Start of this chunk inside column 0; stepping by col_len moves to the
    // same chunk of the next column.
    const T* src = values + chunk * chunk_len;
    for (std::size_t col = 0; col < num_cols; ++col) {
      std::memcpy(dst, src, chunk_bytes);
      dst += chunk_len;
      src += col_len;
    }
  }
}
154-
155-
// Undoes pack_values: scatters a chunk-major packed buffer back into nr
// columns of kr contiguous elements each.
//
// `packed_values` is read sequentially. For each chunk index in [0, sr), the
// next (kr / sr) elements are written to that chunk of column 0, then
// column 1, ..., then column nr - 1 of `values` (column n starts at
// values + n * kr).
//
// Preconditions (checked with assert): nr >= 0, kr >= 0, sr > 0 and
// kr % sr == 0. Both buffers hold nr * kr elements. Elements are copied with
// std::memcpy, so the buffers must not overlap and T is expected to be
// trivially copyable.
template <typename T>
void unpack_values(
    // Output
    T* values,
    // Inputs
    const T* packed_values,
    int nr,
    int kr,
    int sr) {
  assert(nr >= 0 && kr >= 0);
  // sr is tested first so the assert expression itself never divides by zero.
  assert(sr > 0 && kr % sr == 0);
  if (nr <= 0 || kr <= 0 || sr <= 0) {
    // Nothing to copy. This also keeps NDEBUG builds from dividing by zero or
    // passing a wrapped-around byte count to memcpy.
    return;
  }

  // Offsets and byte counts are kept in size_t so they cannot overflow int.
  const std::size_t num_cols = static_cast<std::size_t>(nr);
  const std::size_t col_len = static_cast<std::size_t>(kr);
  const std::size_t num_chunks = static_cast<std::size_t>(sr);
  const std::size_t chunk_len = col_len / num_chunks;
  const std::size_t chunk_bytes = sizeof(T) * chunk_len;

  const T* src = packed_values;
  for (std::size_t chunk = 0; chunk < num_chunks; ++chunk) {
    // Start of this chunk inside column 0; stepping by col_len moves to the
    // same chunk of the next column.
    T* dst = values + chunk * chunk_len;
    for (std::size_t col = 0; col < num_cols; ++col) {
      std::memcpy(dst, src, chunk_bytes);
      src += chunk_len;
      dst += col_len;
    }
  }
}
182-
183129
// Size in bytes of 1 packed weights column
184130
size_t inline packed_weights_size_per_n(
185131
int k,
@@ -344,7 +290,7 @@ TORCHAO_ALWAYS_INLINE inline void pack_weights_impl(
344290
}
345291

346292
// Pack buffer
347-
internal::pack_values(packed_values, buffer.data(), nr, kr, sr);
293+
torchao::packing::pack_values(packed_values, buffer.data(), nr, kr, sr);
348294
if constexpr (has_lut) {
349295
internal::pack_buffer_for_lut<weight_nbit, kr, nr>(
350296
packed_weights_byte_ptr, packed_values);
@@ -498,7 +444,7 @@ void unpack_weights_at_n_idx(
498444
internal::unpack_buffer<weight_nbit, kr, nr>(
499445
packed_values, packed_weights_byte_ptr);
500446
packed_weights_byte_ptr += packed_buffer_bytes;
501-
internal::unpack_values(buffer.data(), packed_values, nr, kr, sr);
447+
torchao::packing::unpack_values(buffer.data(), packed_values, nr, kr, sr);
502448

503449
// Write weight_qvals
504450
for (int j = 0; j < nr; j++) {
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#include <cassert>
8+
#include <cstring>
9+
10+
namespace torchao::packing {
11+
12+
// Packs an nr x kr tile of values for GEMM with packing params (nr, kr, sr).
//
// `values` is read as nr columns of kr contiguous elements each (column n
// starts at values + n * kr). Every column is split into sr chunks of
// (kr / sr) elements. The output is written chunk-major: for each chunk index
// in [0, sr), that chunk of column 0, column 1, ..., column nr - 1 is appended
// to `packed_values`.
//
// Preconditions (checked with assert): nr >= 0, kr >= 0, sr > 0 and
// kr % sr == 0. Elements are copied with std::memcpy, so the two buffers must
// not overlap and T is expected to be trivially copyable.
template <typename T>
void pack_values(
    // Output
    T* packed_values,
    // Inputs
    const T* values,
    int nr,
    int kr,
    int sr) {
  assert(nr >= 0 && kr >= 0);
  // sr is tested first so the assert expression itself never divides by zero.
  assert(sr > 0 && kr % sr == 0);
  if (nr <= 0 || kr <= 0 || sr <= 0) {
    // Nothing to copy. This also keeps NDEBUG builds from dividing by zero or
    // passing a wrapped-around byte count to memcpy.
    return;
  }

  // Offsets and byte counts are kept in size_t so they cannot overflow int.
  const std::size_t num_cols = static_cast<std::size_t>(nr);
  const std::size_t col_len = static_cast<std::size_t>(kr);
  const std::size_t num_chunks = static_cast<std::size_t>(sr);
  const std::size_t chunk_len = col_len / num_chunks;
  const std::size_t chunk_bytes = sizeof(T) * chunk_len;

  T* dst = packed_values;
  for (std::size_t chunk = 0; chunk < num_chunks; ++chunk) {
    // Start of this chunk inside column 0; stepping by col_len moves to the
    // same chunk of the next column.
    const T* src = values + chunk * chunk_len;
    for (std::size_t col = 0; col < num_cols; ++col) {
      std::memcpy(dst, src, chunk_bytes);
      dst += chunk_len;
      src += col_len;
    }
  }
}
38+
39+
// Undoes pack_values: scatters a chunk-major packed buffer back into nr
// columns of kr contiguous elements each.
//
// `packed_values` is read sequentially. For each chunk index in [0, sr), the
// next (kr / sr) elements are written to that chunk of column 0, then
// column 1, ..., then column nr - 1 of `values` (column n starts at
// values + n * kr).
//
// Preconditions (checked with assert): nr >= 0, kr >= 0, sr > 0 and
// kr % sr == 0. Both buffers hold nr * kr elements. Elements are copied with
// std::memcpy, so the buffers must not overlap and T is expected to be
// trivially copyable.
template <typename T>
void unpack_values(
    // Output
    T* values,
    // Inputs
    const T* packed_values,
    int nr,
    int kr,
    int sr) {
  assert(nr >= 0 && kr >= 0);
  // sr is tested first so the assert expression itself never divides by zero.
  assert(sr > 0 && kr % sr == 0);
  if (nr <= 0 || kr <= 0 || sr <= 0) {
    // Nothing to copy. This also keeps NDEBUG builds from dividing by zero or
    // passing a wrapped-around byte count to memcpy.
    return;
  }

  // Offsets and byte counts are kept in size_t so they cannot overflow int.
  const std::size_t num_cols = static_cast<std::size_t>(nr);
  const std::size_t col_len = static_cast<std::size_t>(kr);
  const std::size_t num_chunks = static_cast<std::size_t>(sr);
  const std::size_t chunk_len = col_len / num_chunks;
  const std::size_t chunk_bytes = sizeof(T) * chunk_len;

  const T* src = packed_values;
  for (std::size_t chunk = 0; chunk < num_chunks; ++chunk) {
    // Start of this chunk inside column 0; stepping by col_len moves to the
    // same chunk of the next column.
    T* dst = values + chunk * chunk_len;
    for (std::size_t col = 0; col < num_cols; ++col) {
      std::memcpy(dst, src, chunk_bytes);
      src += chunk_len;
      dst += col_len;
    }
  }
}
66+
67+
} // namespace torchao::packing

0 commit comments

Comments
 (0)