Skip to content

Commit 9259584

Browse files
authored
Add bias support to torchao kernels (#1879)
Add bias support to torchao kernels (#1879). Summary: Pull Request resolved: #1879. This diff adds bias support for torchao kernels / quantizer. Differential Revision: D71093679.
1 parent 81f0bf2 commit 9259584

18 files changed

+2468
-1401
lines changed

torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_linear.cpp

Lines changed: 66 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -34,30 +34,38 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot(
3434
has_clamp);
3535

3636
std::vector<char> activation_data(
37-
activation_data_size<has_weight_zeros>(m, k, group_size));
38-
prepare_activation_data<has_weight_zeros>(
37+
activation_data_size(m, k, group_size, has_weight_zeros));
38+
prepare_activation_data(
3939
(void*)activation_data.data(),
4040
m,
4141
k,
4242
group_size,
43-
test_case.activations.data());
43+
test_case.activations.data(),
44+
has_weight_zeros);
4445

45-
std::vector<char> weight_data(
46-
weight_data_size<weight_nbit, has_weight_zeros, has_bias>(
47-
n, k, group_size));
48-
prepare_weight_data<weight_nbit, has_weight_zeros, has_bias>(
46+
std::vector<char> weight_data(weight_data_size<weight_nbit>(
47+
n, k, group_size, has_weight_zeros, has_bias));
48+
int8_t* weight_zeros_ptr = nullptr;
49+
if (has_weight_zeros) {
50+
weight_zeros_ptr = test_case.weight_zeros.data();
51+
}
52+
float* bias_ptr = nullptr;
53+
if (has_bias) {
54+
bias_ptr = test_case.bias.data();
55+
}
56+
prepare_weight_data<weight_nbit>(
4957
(void*)weight_data.data(),
5058
n,
5159
k,
5260
group_size,
5361
test_case.weight_qvals.data(),
5462
test_case.weight_scales.data(),
55-
test_case.weight_zeros.data(),
56-
test_case.bias.data());
63+
weight_zeros_ptr,
64+
bias_ptr);
5765

5866
std::vector<float> output(m * k);
5967
for (auto _ : state) {
60-
kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp>(
68+
kernel<weight_nbit>(
6169
output.data(),
6270
/*output_m_stride=*/n,
6371
m,
@@ -67,7 +75,10 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x1x32_f32_neondot(
6775
weight_data.data(),
6876
activation_data.data(),
6977
test_case.clamp_min,
70-
test_case.clamp_max);
78+
test_case.clamp_max,
79+
has_weight_zeros,
80+
has_bias,
81+
has_clamp);
7182
}
7283
}
7384

@@ -95,30 +106,38 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot(
95106
has_clamp);
96107

97108
std::vector<char> activation_data(
98-
activation_data_size<has_weight_zeros>(m, k, group_size));
99-
prepare_activation_data<has_weight_zeros>(
109+
activation_data_size(m, k, group_size, has_weight_zeros));
110+
prepare_activation_data(
100111
(void*)activation_data.data(),
101112
m,
102113
k,
103114
group_size,
104-
test_case.activations.data());
115+
test_case.activations.data(),
116+
has_weight_zeros);
105117

106-
std::vector<char> weight_data(
107-
weight_data_size<weight_nbit, has_weight_zeros, has_bias>(
108-
n, k, group_size));
109-
prepare_weight_data<weight_nbit, has_weight_zeros, has_bias>(
118+
std::vector<char> weight_data(weight_data_size<weight_nbit>(
119+
n, k, group_size, has_weight_zeros, has_bias));
120+
int8_t* weight_zeros_ptr = nullptr;
121+
if (has_weight_zeros) {
122+
weight_zeros_ptr = test_case.weight_zeros.data();
123+
}
124+
float* bias_ptr = nullptr;
125+
if (has_bias) {
126+
bias_ptr = test_case.bias.data();
127+
}
128+
prepare_weight_data<weight_nbit>(
110129
(void*)weight_data.data(),
111130
n,
112131
k,
113132
group_size,
114133
test_case.weight_qvals.data(),
115134
test_case.weight_scales.data(),
116-
test_case.weight_zeros.data(),
117-
test_case.bias.data());
135+
weight_zeros_ptr,
136+
bias_ptr);
118137

119138
std::vector<float> output(m * k);
120139
for (auto _ : state) {
121-
kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp>(
140+
kernel<weight_nbit>(
122141
output.data(),
123142
/*output_m_stride=*/n,
124143
m,
@@ -128,7 +147,10 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x4x16_f32_neondot(
128147
weight_data.data(),
129148
activation_data.data(),
130149
test_case.clamp_min,
131-
test_case.clamp_max);
150+
test_case.clamp_max,
151+
has_weight_zeros,
152+
has_bias,
153+
has_clamp);
132154
}
133155
}
134156

@@ -156,30 +178,38 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot(
156178
has_clamp);
157179

158180
std::vector<char> activation_data(
159-
activation_data_size<has_weight_zeros>(m, k, group_size));
160-
prepare_activation_data<has_weight_zeros>(
181+
activation_data_size(m, k, group_size, has_weight_zeros));
182+
prepare_activation_data(
161183
(void*)activation_data.data(),
162184
m,
163185
k,
164186
group_size,
165-
test_case.activations.data());
187+
test_case.activations.data(),
188+
has_weight_zeros);
166189

167-
std::vector<char> weight_data(
168-
weight_data_size<weight_nbit, has_weight_zeros, has_bias>(
169-
n, k, group_size));
170-
prepare_weight_data<weight_nbit, has_weight_zeros, has_bias>(
190+
std::vector<char> weight_data(weight_data_size<weight_nbit>(
191+
n, k, group_size, has_weight_zeros, has_bias));
192+
int8_t* weight_zeros_ptr = nullptr;
193+
if (has_weight_zeros) {
194+
weight_zeros_ptr = test_case.weight_zeros.data();
195+
}
196+
float* bias_ptr = nullptr;
197+
if (has_bias) {
198+
bias_ptr = test_case.bias.data();
199+
}
200+
prepare_weight_data<weight_nbit>(
171201
(void*)weight_data.data(),
172202
n,
173203
k,
174204
group_size,
175205
test_case.weight_qvals.data(),
176206
test_case.weight_scales.data(),
177-
test_case.weight_zeros.data(),
178-
test_case.bias.data());
207+
weight_zeros_ptr,
208+
bias_ptr);
179209

180210
std::vector<float> output(m * k);
181211
for (auto _ : state) {
182-
kernel<weight_nbit, has_weight_zeros, has_bias, has_clamp>(
212+
kernel<weight_nbit>(
183213
output.data(),
184214
/*output_m_stride=*/n,
185215
m,
@@ -189,7 +219,10 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot(
189219
weight_data.data(),
190220
activation_data.data(),
191221
test_case.clamp_min,
192-
test_case.clamp_max);
222+
test_case.clamp_max,
223+
has_weight_zeros,
224+
has_bias,
225+
has_clamp);
193226
}
194227
}
195228

0 commit comments

Comments
 (0)