Skip to content

Commit 4300079

Browse files
authored
Add support for KleidiAI int4 kernels on aarch64 Linux (#2169)
* Debugging ARM Neoverse-V1 * add comment to cmake * Debug NEOVERSE ARM * remove useless comments * clean * clean * debug * clean * load multiple potential paths * remove assertion * re-introduce assertion and define load_libtorchao_ops fn * add unit test to ensure * Ready for merge * last test * fix * PR feedback * debug * add comments * add ENABLE_ARM_NEON in build_torchao_ops
1 parent 720a177 commit 4300079

File tree

9 files changed

+175
-22
lines changed

9 files changed

+175
-22
lines changed

setup.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def read_version(file_path="version.txt"):
4949

5050
import platform
5151

52-
build_torchao_experimental = (
52+
build_macos_arm_auto = (
5353
use_cpp == "1"
5454
and platform.machine().startswith("arm64")
5555
and platform.system() == "Darwin"
@@ -119,8 +119,33 @@ def __init__(self):
119119
"TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available"
120120
)
121121

122+
# TORCHAO_PARALLEL_BACKEND specifies which parallel backend to use
123+
# Possible values: aten_openmp, executorch, openmp, pthreadpool, single_threaded
124+
self.parallel_backend = os.getenv("TORCHAO_PARALLEL_BACKEND", "aten_openmp")
125+
126+
# TORCHAO_ENABLE_ARM_NEON_DOT enable ARM NEON Dot Product extension
127+
# Enabled by default on macOS silicon
128+
self.enable_arm_neon_dot = self._os_bool_var(
129+
"TORCHAO_ENABLE_ARM_NEON_DOT",
130+
default=(self._is_arm64() and self._is_macos()),
131+
)
132+
if self.enable_arm_neon_dot:
133+
assert self.build_cpu_aarch64, (
134+
"TORCHAO_ENABLE_ARM_NEON_DOT requires TORCHAO_BUILD_CPU_AARCH64 be set"
135+
)
136+
137+
# TORCHAO_ENABLE_ARM_I8MM enable ARM 8-bit Integer Matrix Multiply instructions
138+
# Not enabled by default on macOS as not all silicon mac supports it
139+
self.enable_arm_i8mm = self._os_bool_var(
140+
"TORCHAO_ENABLE_ARM_I8MM", default=False
141+
)
142+
if self.enable_arm_i8mm:
143+
assert self.build_cpu_aarch64, (
144+
"TORCHAO_ENABLE_ARM_I8MM requires TORCHAO_BUILD_CPU_AARCH64 be set"
145+
)
146+
122147
def _is_arm64(self) -> bool:
123-
return platform.machine().startswith("arm64")
148+
return platform.machine().startswith("arm64") or platform.machine() == "aarch64"
124149

125150
def _is_macos(self) -> bool:
126151
return platform.system() == "Darwin"
@@ -431,7 +456,8 @@ def get_extensions():
431456
)
432457
)
433458

434-
if build_torchao_experimental:
459+
# Build CMakeLists from /torchao/experimental - additional options become available : TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND
460+
if build_macos_arm_auto or os.getenv("BUILD_TORCHAO_EXPERIMENTAL") == "1":
435461
build_options = BuildOptions()
436462

437463
def bool_to_on_off(value):
@@ -451,6 +477,9 @@ def bool_to_on_off(value):
451477
f"-DTORCHAO_BUILD_CPU_AARCH64={bool_to_on_off(build_options.build_cpu_aarch64)}",
452478
f"-DTORCHAO_BUILD_KLEIDIAI={bool_to_on_off(build_options.build_kleidi_ai)}",
453479
f"-DTORCHAO_BUILD_MPS_OPS={bool_to_on_off(build_options.build_experimental_mps)}",
480+
f"-DTORCHAO_ENABLE_ARM_NEON_DOT={bool_to_on_off(build_options.enable_arm_neon_dot)}",
481+
f"-DTORCHAO_ENABLE_ARM_I8MM={bool_to_on_off(build_options.enable_arm_i8mm)}",
482+
f"-DTORCHAO_PARALLEL_BACKEND={build_options.parallel_backend}",
454483
"-DTorch_DIR=" + torch_dir,
455484
]
456485
+ (

torchao/experimental/CMakeLists.txt

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@ if (NOT CMAKE_BUILD_TYPE)
1515
set(CMAKE_BUILD_TYPE Release)
1616
endif()
1717

18+
# Platform options
1819
option(TORCHAO_BUILD_EXECUTORCH_OPS "Building torchao ops for ExecuTorch." OFF)
1920
option(TORCHAO_BUILD_MPS_OPS "Building torchao MPS ops" OFF)
2021
option(TORCHAO_BUILD_CPU_AARCH64 "Build torchao's CPU aarch64 kernels" OFF)
2122
option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF)
23+
option(TORCHAO_ENABLE_ARM_NEON_DOT "Enable ARM Neon Dot Product extension" OFF)
24+
option(TORCHAO_ENABLE_ARM_I8MM "Enable ARM 8-bit Integer Matrix Multiply instructions" OFF)
2225

2326
if(NOT TORCHAO_INCLUDE_DIRS)
2427
set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../..)
@@ -28,19 +31,49 @@ if(NOT DEFINED TORCHAO_PARALLEL_BACKEND)
2831
set(TORCHAO_PARALLEL_BACKEND aten_openmp)
2932
endif()
3033

31-
include(CMakePrintHelpers)
32-
34+
# Set default compiler options
3335
add_compile_options("-Wall" "-Werror" "-Wno-deprecated")
3436

3537
include(CMakePrintHelpers)
3638
message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
3739
include_directories(${TORCHAO_INCLUDE_DIRS})
3840

39-
4041
if(TORCHAO_BUILD_CPU_AARCH64)
4142
message(STATUS "Building with cpu/aarch64")
4243
add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64)
43-
add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT)
44+
45+
# Set aarch64 compiler options
46+
if (CMAKE_SYSTEM_NAME STREQUAL "Linux")
47+
message(STATUS "Add aarch64 linux compiler options")
48+
add_compile_options(
49+
"-fPIC"
50+
"-Wno-error=unknown-pragmas"
51+
"-Wno-array-parameter"
52+
"-Wno-maybe-uninitialized"
53+
"-Wno-sign-compare"
54+
)
55+
56+
# Since versions are hierarchical (each includes features from prior versions):
57+
# - dotprod is included by default in armv8.4-a and later
58+
# - i8mm is included by default in armv8.6-a and later
59+
if(TORCHAO_ENABLE_ARM_I8MM)
60+
message(STATUS "Using armv8.6-a (includes 'i8mm' and 'dotprod' flags)")
61+
add_compile_options("-march=armv8.6-a")
62+
elseif(TORCHAO_ENABLE_ARM_NEON_DOT)
63+
message(STATUS "Using armv8.4-a (includes '+dotprod' flag)")
64+
add_compile_options("-march=armv8.4-a")
65+
endif()
66+
endif()
67+
68+
if(TORCHAO_ENABLE_ARM_NEON_DOT)
69+
message(STATUS "Building with ARM NEON dot product support")
70+
add_compile_definitions(TORCHAO_ENABLE_ARM_NEON_DOT)
71+
endif()
72+
73+
if(TORCHAO_ENABLE_ARM_I8MM)
74+
message(STATUS "Building with ARM I8MM support")
75+
add_compile_definitions(TORCHAO_ENABLE_ARM_I8MM)
76+
endif()
4477

4578
# Defines torchao_kernels_aarch64
4679
add_subdirectory(kernels/cpu/aarch64)
@@ -51,26 +84,33 @@ if(TORCHAO_BUILD_CPU_AARCH64)
5184
endif()
5285
endif()
5386

87+
# Add quantized operation dir
5488
add_subdirectory(ops/linear_8bit_act_xbit_weight)
5589
add_subdirectory(ops/embedding_xbit)
5690

91+
# ATen ops lib
5792
add_library(torchao_ops_aten SHARED)
5893
target_link_libraries(
5994
torchao_ops_aten PRIVATE
6095
torchao_ops_linear_8bit_act_xbit_weight_aten
6196
torchao_ops_embedding_xbit_aten
6297
)
98+
99+
# Add MPS support if enabled
63100
if (TORCHAO_BUILD_MPS_OPS)
64101
message(STATUS "Building with MPS support")
65102
add_subdirectory(ops/mps)
66103
target_link_libraries(torchao_ops_aten PRIVATE torchao_ops_mps_aten)
67104
endif()
68105

106+
# Install ATen targets
69107
install(
70108
TARGETS torchao_ops_aten
71109
EXPORT _targets
72110
DESTINATION lib
73111
)
112+
113+
# Build executorch lib if enabled
74114
if(TORCHAO_BUILD_EXECUTORCH_OPS)
75115
add_library(torchao_ops_executorch STATIC)
76116
target_link_libraries(torchao_ops_executorch PRIVATE

torchao/experimental/build_torchao_ops.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
2222
-DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \
2323
-DTORCHAO_BUILD_EXECUTORCH_OPS="${TORCHAO_BUILD_EXECUTORCH_OPS}" \
2424
-DTORCHAO_BUILD_CPU_AARCH64=ON \
25+
-DTORCHAO_ENABLE_ARM_NEON_DOT=ON \
2526
-S . \
2627
-B ${CMAKE_OUT}
2728
cmake --build ${CMAKE_OUT} -j 16 --target install --config Release

torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <torchao/experimental/kernels/cpu/aarch64/valpacking/valpack.h>
88
#include <cassert>
99
#include <cstring>
10+
#include <cstdint>
1011

1112
// Interleaves data across channels (row/column) and groups.
1213
// Each channel is the same size (vals_per_channel) and is

torchao/experimental/op_lib.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,44 @@
1010
from torch import Tensor
1111
from torch.library import impl
1212

13-
# Load C++ ops
14-
lib_path = Path(__file__).parent.parent
15-
libs = list(lib_path.glob("libtorchao_ops_aten.*"))
16-
assert len(libs) == 1, (
17-
f"Expected to find one libtorchao_ops_aten.* library at {lib_path}, but found {len(libs)}"
18-
)
19-
torch.ops.load_library(str(libs[0]))
13+
# Load C++ ops - use multiple potential paths
14+
potential_paths = [
15+
# Standard path from the module location
16+
Path(__file__).parent.parent,
17+
# Site-packages installation path
18+
Path(torch.__file__).parent.parent / "torchao",
19+
# For editable installs
20+
Path(__file__).parent.parent.parent / "torchao",
21+
]
2022

2123

24+
def find_and_load_libtorchao_ops(potential_paths):
25+
for lib_path in potential_paths:
26+
libs = list(lib_path.glob("libtorchao_ops_aten.*"))
27+
28+
if not libs:
29+
continue
30+
31+
assert len(libs) == 1, (
32+
f"Expected to find one libtorchao_ops_aten.* library at {lib_path}, but found {len(libs)}"
33+
)
34+
35+
target_lib = libs[0]
36+
print(f"Found library at: {target_lib}")
37+
38+
try:
39+
torch.ops.load_library(str(target_lib))
40+
return
41+
except Exception as e:
42+
print(f"Error loading library from {target_lib}: {e}")
43+
44+
raise FileNotFoundError(
45+
"Could not find libtorchao_ops_aten library in any of the provided paths"
46+
)
47+
48+
49+
find_and_load_libtorchao_ops(potential_paths)
50+
2251
# Define meta ops. To support dynamic shapes, some meta ops need to
2352
# be defined in python instead of C++.
2453
torchao_lib = torch.library.Library("torchao", "IMPL")

torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_config.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ struct UKernelConfig {
190190
TORCHAO_CHECK(pack_weights != nullptr || pack_weights_with_lut != nullptr, "pack_weights or pack_weights_with_lut must be set");
191191

192192
bool linear_configs_set = true; // first linear config must be set
193-
for (int i = 0; i < linear_configs.size(); i++) {
193+
for (size_t i = 0; i < linear_configs.size(); i++) {
194194
if (linear_configs_set) {
195195
TORCHAO_CHECK(
196196
linear_configs[i].m_step >= 1,
@@ -225,7 +225,7 @@ struct UKernelConfig {
225225
assert(m >= 1);
226226
assert(linear_configs[0].m_step >= 1);
227227

228-
int i = 0;
228+
size_t i = 0;
229229
while (i + 1 < linear_configs.size() && linear_configs[i + 1].m_step >= 1 &&
230230
linear_configs[i + 1].m_step <= m) {
231231
assert(linear_configs[i].m_step < linear_configs[i + 1].m_step);
@@ -235,7 +235,7 @@ struct UKernelConfig {
235235
assert(i < linear_configs.size());
236236
assert(linear_configs[i].m_step >= 1);
237237
assert(i == 0 || linear_configs[i].m_step <= m);
238-
return i;
238+
return static_cast<int>(i);
239239
}
240240
};
241241

torchao/experimental/ops/packed_weights_header.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class PackedWeightsHeader {
4343
auto header = reinterpret_cast<int*>(packed_weights);
4444
header[0] = magic;
4545
header[1] = static_cast<int>(type);
46-
for (int i = 0; i < params.size(); i++) {
46+
for (size_t i = 0; i < params.size(); i++) {
4747
header[i + 2] = params[i];
4848
}
4949
}
@@ -52,7 +52,7 @@ class PackedWeightsHeader {
5252
auto header = reinterpret_cast<const int*>(packed_weights);
5353
assert(header[0] == PackedWeightsHeader::magic);
5454
params_type params;
55-
for (int i = 0; i < params.size(); i++) {
55+
for (size_t i = 0; i < params.size(); i++) {
5656
params[i] = header[i + 2];
5757
}
5858
return PackedWeightsHeader(
@@ -63,7 +63,7 @@ class PackedWeightsHeader {
6363
if (type != other.type) {
6464
return false;
6565
}
66-
for (int i = 0; i < params.size(); i++) {
66+
for (size_t i = 0; i < params.size(); i++) {
6767
if (params[i] != other.params[i]) {
6868
return false;
6969
}
@@ -79,7 +79,7 @@ namespace std {
7979
struct hash<torchao::ops::PackedWeightsHeader> {
8080
std::size_t operator()(const torchao::ops::PackedWeightsHeader& f) const {
8181
std::size_t hash = std::hash<int>()(static_cast<int>(f.type));
82-
for (int i = 0; i < f.params.size(); i++) {
82+
for (size_t i = 0; i < f.params.size(); i++) {
8383
hash ^= std::hash<int>()(f.params[i]);
8484
}
8585
return hash;

torchao/experimental/ops/parallel-aten-impl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
// LICENSE file in the root directory of this source tree.
66

77
#pragma once
8-
#include <Aten/Parallel.h>
8+
#include <ATen/Parallel.h>
99
#include <torch/library.h>
1010
#include <torch/torch.h>
1111

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
from pathlib import Path
9+
from unittest.mock import MagicMock, patch
10+
11+
12+
class TestLibTorchAoOpsLoader(unittest.TestCase):
13+
def test_find_and_load_success(self):
14+
mock_paths = [Path("/test/path1")]
15+
mock_lib = MagicMock()
16+
mock_lib.__str__.return_value = "/test/path1/libtorchao_ops_aten.so"
17+
18+
with patch("pathlib.Path.glob", return_value=[mock_lib]):
19+
with patch("torch.ops.load_library") as mock_load:
20+
from ..op_lib import find_and_load_libtorchao_ops
21+
22+
find_and_load_libtorchao_ops(mock_paths)
23+
24+
mock_load.assert_called_once_with("/test/path1/libtorchao_ops_aten.so")
25+
26+
def test_no_library_found(self):
27+
mock_paths = [Path("/test/path1"), Path("/test/path2")]
28+
29+
with patch("pathlib.Path.glob", return_value=[]):
30+
from ..op_lib import find_and_load_libtorchao_ops
31+
32+
with self.assertRaises(FileNotFoundError):
33+
find_and_load_libtorchao_ops(mock_paths)
34+
35+
def test_multiple_libraries_error(self):
36+
mock_paths = [Path("/test/path1")]
37+
mock_lib1 = MagicMock()
38+
mock_lib2 = MagicMock()
39+
mock_libs = [mock_lib1, mock_lib2]
40+
41+
with patch("pathlib.Path.glob", return_value=mock_libs):
42+
from ..op_lib import find_and_load_libtorchao_ops
43+
44+
try:
45+
find_and_load_libtorchao_ops(mock_paths)
46+
self.fail("Expected AssertionError was not raised")
47+
except AssertionError as e:
48+
expected_error_msg = f"Expected to find one libtorchao_ops_aten.* library at {mock_paths[0]}, but found 2"
49+
self.assertIn(expected_error_msg, str(e))
50+
51+
52+
if __name__ == "__main__":
53+
unittest.main()

0 commit comments

Comments
 (0)