Skip to content

Commit 8f93751

Browse files
metal lowbit kernels: pip install (#1785)
1 parent 4a4925f commit 8f93751

File tree

6 files changed

+38
-10
lines changed

6 files changed

+38
-10
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -375,3 +375,4 @@ checkpoints/
375375

376376
# Experimental
377377
torchao/experimental/cmake-out
378+
torchao/experimental/deps

setup.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,20 @@ def use_debug_mode():
7575
CUDAExtension,
7676
)
7777

78+
build_torchao_experimental_mps = (
79+
os.getenv("TORCHAO_BUILD_EXPERIMENTAL_MPS") == "1"
80+
and build_torchao_experimental
81+
and torch.mps.is_available()
82+
)
83+
84+
if os.getenv("TORCHAO_BUILD_EXPERIMENTAL_MPS") == "1":
85+
if use_cpp != "1":
86+
print("Building experimental MPS ops requires USE_CPP=1")
87+
if not platform.machine().startswith("arm64") or platform.system() != "Darwin":
88+
print("Experimental MPS ops require Apple Silicon.")
89+
if not torch.mps.is_available():
90+
print("MPS not available. Skipping compilation of experimental MPS ops.")
91+
7892
# Constant known variables used throughout this file
7993
cwd = os.path.abspath(os.path.curdir)
8094
third_party_path = os.path.join(cwd, "third_party")
@@ -174,15 +188,19 @@ def build_cmake(self, ext):
174188
if not os.path.exists(self.build_temp):
175189
os.makedirs(self.build_temp)
176190

191+
build_mps_ops = "ON" if build_torchao_experimental_mps else "OFF"
192+
177193
subprocess.check_call(
178194
[
179195
"cmake",
180196
ext.sourcedir,
181197
"-DCMAKE_BUILD_TYPE=" + build_type,
182198
# Disable now because 1) KleidiAI increases build time, and 2) KleidiAI has accuracy issues due to BF16
183199
"-DTORCHAO_BUILD_KLEIDIAI=OFF",
200+
"-DTORCHAO_BUILD_MPS_OPS=" + build_mps_ops,
184201
"-DTorch_DIR=" + torch_dir,
185202
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
203+
"-DCMAKE_INSTALL_PREFIX=cmake-out",
186204
],
187205
cwd=self.build_temp,
188206
)

torchao/experimental/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ if (NOT CMAKE_BUILD_TYPE)
1616
endif()
1717

1818
option(TORCHAO_BUILD_EXECUTORCH_OPS "Building torchao ops for ExecuTorch." OFF)
19+
option(TORCHAO_BUILD_MPS_OPS "Building torchao MPS ops" OFF)
1920

2021

2122
if(NOT TORCHAO_INCLUDE_DIRS)
@@ -51,6 +52,12 @@ if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
5152
torchao_ops_linear_8bit_act_xbit_weight_aten
5253
torchao_ops_embedding_xbit_aten
5354
)
55+
if (TORCHAO_BUILD_MPS_OPS)
56+
message(STATUS "Building with MPS support")
57+
add_subdirectory(ops/mps)
58+
target_link_libraries(torchao_ops_aten PRIVATE torchao_ops_mps_aten)
59+
endif()
60+
5461
install(
5562
TARGETS torchao_ops_aten
5663
EXPORT _targets

torchao/experimental/ops/mps/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,13 @@ find_package(Torch REQUIRED)
2828
# Generate metal_shader_lib.h by running gen_metal_shader_lib.py
2929
set(METAL_SHADERS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/metal)
3030
file(GLOB METAL_FILES ${METAL_SHADERS_DIR}/*.metal)
31+
set(METAL_SHADERS_YAML ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/metal.yaml)
3132
set(GEN_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py)
3233
set(GENERATED_METAL_SHADER_LIB ${CMAKE_INSTALL_PREFIX}/include/torchao/experimental/kernels/mps/src/metal_shader_lib.h)
3334
add_custom_command(
3435
OUTPUT ${GENERATED_METAL_SHADER_LIB}
3536
COMMAND python ${GEN_SCRIPT} ${GENERATED_METAL_SHADER_LIB}
36-
DEPENDS ${METAL_FILES} ${GEN_SCRIPT}
37+
DEPENDS ${METAL_FILES} ${METAL_SHADERS_YAML} ${GEN_SCRIPT}
3738
COMMENT "Generating metal_shader_lib.h using gen_metal_shader_lib.py"
3839
)
3940
add_custom_target(generated_metal_shader_lib ALL DEPENDS ${GENERATED_METAL_SHADER_LIB})

torchao/experimental/ops/mps/test/test_lowbit.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,18 @@
1010
import torch
1111
from parameterized import parameterized
1212

13-
libname = "libtorchao_ops_mps_aten.dylib"
14-
libpath = os.path.abspath(
15-
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
16-
)
13+
import torchao # noqa: F401
1714

1815
try:
1916
for nbit in range(1, 8):
2017
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
2118
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
2219
except AttributeError:
2320
try:
21+
libname = "libtorchao_ops_mps_aten.dylib"
22+
libpath = os.path.abspath(
23+
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
24+
)
2425
torch.ops.load_library(libpath)
2526
except:
2627
raise RuntimeError(f"Failed to load library {libpath}")

torchao/experimental/ops/mps/test/test_quantizer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,19 @@
1212
import torch
1313
from parameterized import parameterized
1414

15+
import torchao # noqa: F401
1516
from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer, _quantize
1617

17-
libname = "libtorchao_ops_mps_aten.dylib"
18-
libpath = os.path.abspath(
19-
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
20-
)
21-
2218
try:
2319
for nbit in range(1, 8):
2420
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
2521
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
2622
except AttributeError:
2723
try:
24+
libname = "libtorchao_ops_mps_aten.dylib"
25+
libpath = os.path.abspath(
26+
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
27+
)
2828
torch.ops.load_library(libpath)
2929
except:
3030
raise RuntimeError(f"Failed to load library {libpath}")

0 commit comments

Comments
 (0)