Skip to content

Commit 25377e0

Browse files
authored
Migrate to config for Int8DynamicActivationIntxWeightConfig (#1836)
* init * up * up * up * up * up * lint * up * lint
1 parent 24c966c commit 25377e0

File tree

6 files changed: +445 −373 lines changed

.github/workflows/torchao_experimental_test.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ jobs:
3636
pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu"
3737
pip install numpy
3838
pip install pytest
39+
pip install parameterized
3940
USE_CPP=1 pip install .
4041
- name: Run python tests
4142
run: |

torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h

Lines changed: 19 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66

77
#pragma once
88
#include <cpuinfo.h>
9-
// #include <glog/logging.h>
109
#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h>
1110
#include <torchao/experimental/ops/packed_weights_header.h>
1211

@@ -121,6 +120,21 @@ void check_format(PackedWeightsFormat format,
121120
}
122121
}
123122

123+
void log_registration(PackedWeightsFormat format, std::string description) {
124+
// Logging is only supported in ATen mode
125+
#ifdef USE_ATEN
126+
LOG(INFO) << "Registering ukernel config for linear_8bit_act_xbit_weight" << std::endl
127+
<< "\tDescription: " << description << std::endl
128+
<< "\tformat.type=" << static_cast<int>(format.type) << std::endl
129+
<< "\tformat.weight_nbit=" << format.weight_nbit << std::endl
130+
<< "\tformat.has_weight_zeros=" << format.has_weight_zeros << std::endl
131+
<< "\tformat.has_bias=" << format.has_bias << std::endl
132+
<< "\tformat.nr=" << format.nr << std::endl
133+
<< "\tformat.kr=" << format.kr << std::endl
134+
<< "\tformat.sr=" << format.sr << std::endl;
135+
#endif // USE_ATEN
136+
}
137+
124138
template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
125139
void register_ukernel_config_universal(UKernelConfigRegistrationTable &table,
126140
PackedWeightsFormat format,
@@ -135,6 +149,7 @@ void register_ukernel_config_universal(UKernelConfigRegistrationTable &table,
135149
if (format.nr == 8 && format.kr == 16 && format.sr == 2) {
136150
#if defined(TORCHAO_BUILD_CPU_AARCH64)
137151
if (cpuinfo_has_arm_neon_dot()) {
152+
log_registration(format, "universal");
138153
namespace kernel = torchao::kernels::cpu::aarch64::linear::
139154
channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot;
140155
table.register_ukernel_config(
@@ -211,6 +226,7 @@ void register_ukernel_config_kleidi(UKernelConfigRegistrationTable &table,
211226
#if defined(TORCHAO_ENABLE_ARM_I8MM)
212227
if (cpuinfo_has_arm_i8mm()) {
213228
constexpr int n_step = 8;
229+
log_registration(format, "kleidiai: matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm");
214230
table.register_ukernel_config(
215231
format, uarch,
216232
UKernelConfig{
@@ -228,6 +244,7 @@ void register_ukernel_config_kleidi(UKernelConfigRegistrationTable &table,
228244

229245
if (cpuinfo_has_arm_neon_dot()) {
230246
constexpr int n_step = 8;
247+
log_registration(format, "kleidiai: matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod");
231248
table.register_ukernel_config(
232249
format, uarch,
233250
UKernelConfig{
@@ -249,6 +266,7 @@ void register_ukernel_config_kleidi(UKernelConfigRegistrationTable &table,
249266
constexpr int sr = 2;
250267
if (cpuinfo_has_arm_neon_dot()) {
251268
constexpr int n_step = 4;
269+
log_registration(format, "kleidiai: matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod");
252270
table.register_ukernel_config(
253271
format, uarch,
254272
UKernelConfig{

0 commit comments

Comments (0)