Skip to content

Commit dd7db9f

Browse files
authored
Use quantized gemm only on aarch64
Differential Revision: D72413684 Pull Request resolved: #2023
1 parent 3bbf42a commit dd7db9f

File tree

5 files changed

+14
-14
lines changed

5 files changed

+14
-14
lines changed

torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
#pragma once
88

9-
#if defined(__aarch64__) || defined(__ARM_NEON)
9+
#if defined(__aarch64__) && defined(__ARM_NEON)
1010

1111
#include <algorithm>
1212
#include <cassert>
@@ -381,4 +381,4 @@ void kernel(
381381
} // namespace channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal
382382
} // namespace torchao::kernels::cpu::aarch64::quantized_matmul
383383

384-
#endif // defined(__aarch64__) || defined(__ARM_NEON)
384+
#endif // defined(__aarch64__) && defined(__ARM_NEON)

torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
#pragma once
88

9-
#if defined(__aarch64__) || defined(__ARM_NEON)
9+
#if defined(__aarch64__) && defined(__ARM_NEON)
1010

1111
#include <algorithm>
1212
#include <cassert>
@@ -333,4 +333,4 @@ void kernel(
333333
} // namespace channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot
334334
} // namespace torchao::kernels::cpu::aarch64::quantized_matmul
335335

336-
#endif // defined(__aarch64__) || defined(__ARM_NEON)
336+
#endif // defined(__aarch64__) && defined(__ARM_NEON)

torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
#pragma once
88

9-
#if defined(__aarch64__) || defined(__ARM_NEON)
9+
#if defined(__aarch64__) && defined(__ARM_NEON)
1010

1111
#include <algorithm>
1212
#include <cassert>
@@ -278,4 +278,4 @@ void kernel(
278278
} // namespace fp32_a_input_channelwise_8bit_b_1x16x4_f32
279279
} // namespace torchao::kernels::cpu::aarch64::quantized_matmul
280280

281-
#endif // defined(__aarch64__) || defined(__ARM_NEON)
281+
#endif // defined(__aarch64__) && defined(__ARM_NEON)

torchao/experimental/kernels/cpu/aarch64/matmul/matmul.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
#pragma once
1212

13-
#if defined(__aarch64__) || defined(__ARM_NEON)
13+
#if defined(__aarch64__) && defined(__ARM_NEON)
1414

1515
#include <arm_neon.h>
1616

@@ -92,4 +92,4 @@ void kernel(
9292
#include <torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h>
9393
#include <torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h>
9494

95-
#endif // defined(__aarch64__) || defined(__ARM_NEON)
95+
#endif // defined(__aarch64__) && defined(__ARM_NEON)

torchao/experimental/kernels/cpu/interface/quantized_matmul.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
#include <torchao/experimental/kernels/cpu/fallback/matmul/channelwise_8bit_a_channelwise_8bit_b.h>
1212
#include <torchao/experimental/kernels/cpu/fallback/matmul/fp32_a_channelwise_8bit_b_fp32_c.h>
1313

14-
#if defined(__aarch64__) || defined(__ARM_NEON)
14+
#if defined(__aarch64__) && defined(__ARM_NEON)
1515
#include <torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x16x16_f32_smlal-impl.h>
1616
#include <torchao/experimental/kernels/cpu/aarch64/matmul/channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot-impl.h>
1717
#include <torchao/experimental/kernels/cpu/aarch64/matmul/fp32_a_input_channelwise_8bit_b_1x16x4_f32_impl.h>
18-
#endif // defined(__aarch64__) || defined(__ARM_NEON)
18+
#endif // defined(__aarch64__) && defined(__ARM_NEON)
1919

2020
namespace torchao::kernels::cpu::quantized_matmul {
2121

@@ -67,15 +67,15 @@ get_int8_a_int8_b_channelwise_qmatmul(
6767
bool b_transposed,
6868
int& a_stride_m,
6969
int& b_stride_n) {
70-
#if defined(__aarch64__) || defined(__ARM_NEON)
70+
#if defined(__aarch64__) && defined(__ARM_NEON)
7171
if (!a_transposed && b_transposed && n >= 8) {
7272
a_stride_m = k;
7373
b_stride_n = k;
7474
return aarch64::quantized_matmul::
7575
channelwise_8bit_a_channelwise_8bit_b_1x8x16_f32_neondot::
7676
kernel<true, true, false, true>;
7777
}
78-
#endif // defined(__aarch64__) || defined(__ARM_NEON)
78+
#endif // defined(__aarch64__) && defined(__ARM_NEON)
7979
assert(!a_transposed);
8080
if (b_transposed) {
8181
a_stride_m = k;
@@ -134,14 +134,14 @@ get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
134134
bool b_transposed,
135135
int& a_stride_m,
136136
int& b_stride_n) {
137-
#if defined(__aarch64__) || defined(__ARM_NEON)
137+
#if defined(__aarch64__) && defined(__ARM_NEON)
138138
if (!a_transposed && !b_transposed && n >= 16) {
139139
a_stride_m = k;
140140
b_stride_n = n;
141141
return aarch64::quantized_matmul::
142142
fp32_a_input_channelwise_8bit_b_1x16x4_f32::kernel<true, false, false>;
143143
}
144-
#endif // defined(__aarch64__) || defined(__ARM_NEON)
144+
#endif // defined(__aarch64__) && defined(__ARM_NEON)
145145
assert(!a_transposed);
146146
if (b_transposed) {
147147
a_stride_m = k;

0 commit comments

Comments
 (0)