6
6
7
7
#pragma once
8
8
#include < cpuinfo.h>
9
- // #include <glog/logging.h>
10
9
#include < torchao/experimental/ops/linear_8bit_act_xbit_weight/linear_8bit_act_xbit_weight.h>
11
10
#include < torchao/experimental/ops/packed_weights_header.h>
12
11
@@ -121,6 +120,21 @@ void check_format(PackedWeightsFormat format,
121
120
}
122
121
}
123
122
123
+ void log_registration (PackedWeightsFormat format, std::string description) {
124
+ // Logging is only supported in ATen mode
125
+ #ifdef USE_ATEN
126
+ LOG (INFO) << " Registering ukernel config for linear_8bit_act_xbit_weight" << std::endl
127
+ << " \t Description: " << description << std::endl
128
+ << " \t format.type=" << static_cast <int >(format.type ) << std::endl
129
+ << " \t format.weight_nbit=" << format.weight_nbit << std::endl
130
+ << " \t format.has_weight_zeros=" << format.has_weight_zeros << std::endl
131
+ << " \t format.has_bias=" << format.has_bias << std::endl
132
+ << " \t format.nr=" << format.nr << std::endl
133
+ << " \t format.kr=" << format.kr << std::endl
134
+ << " \t format.sr=" << format.sr << std::endl;
135
+ #endif // USE_ATEN
136
+ }
137
+
124
138
template <int weight_nbit, bool has_weight_zeros, bool has_bias, bool has_clamp>
125
139
void register_ukernel_config_universal (UKernelConfigRegistrationTable &table,
126
140
PackedWeightsFormat format,
@@ -135,6 +149,7 @@ void register_ukernel_config_universal(UKernelConfigRegistrationTable &table,
135
149
if (format.nr == 8 && format.kr == 16 && format.sr == 2 ) {
136
150
#if defined(TORCHAO_BUILD_CPU_AARCH64)
137
151
if (cpuinfo_has_arm_neon_dot ()) {
152
+ log_registration (format, " universal" );
138
153
namespace kernel = torchao::kernels::cpu::aarch64::linear::
139
154
channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot;
140
155
table.register_ukernel_config (
@@ -211,6 +226,7 @@ void register_ukernel_config_kleidi(UKernelConfigRegistrationTable &table,
211
226
#if defined(TORCHAO_ENABLE_ARM_I8MM)
212
227
if (cpuinfo_has_arm_i8mm ()) {
213
228
constexpr int n_step = 8 ;
229
+ log_registration (format, " kleidiai: matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm" );
214
230
table.register_ukernel_config (
215
231
format, uarch,
216
232
UKernelConfig{
@@ -228,6 +244,7 @@ void register_ukernel_config_kleidi(UKernelConfigRegistrationTable &table,
228
244
229
245
if (cpuinfo_has_arm_neon_dot ()) {
230
246
constexpr int n_step = 8 ;
247
+ log_registration (format, " kleidiai: matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod" );
231
248
table.register_ukernel_config (
232
249
format, uarch,
233
250
UKernelConfig{
@@ -249,6 +266,7 @@ void register_ukernel_config_kleidi(UKernelConfigRegistrationTable &table,
249
266
constexpr int sr = 2 ;
250
267
if (cpuinfo_has_arm_neon_dot ()) {
251
268
constexpr int n_step = 4 ;
269
+ log_registration (format, " kleidiai: matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod" );
252
270
table.register_ukernel_config (
253
271
format, uarch,
254
272
UKernelConfig{
0 commit comments