diff --git a/setup.py b/setup.py index e1bad04cd2..cbab781709 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ def read_version(file_path="version.txt"): import platform -build_torchao_experimental = ( +build_cpp = ( use_cpp == "1" and platform.machine().startswith("arm64") and platform.system() == "Darwin" @@ -75,9 +75,9 @@ def use_debug_mode(): CUDAExtension, ) -build_torchao_experimental_mps = ( +build_cpp_mps = ( os.getenv("TORCHAO_BUILD_EXPERIMENTAL_MPS") == "1" - and build_torchao_experimental + and build_cpp and torch.mps.is_available() ) @@ -188,7 +188,7 @@ def build_cmake(self, ext): if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) - build_mps_ops = "ON" if build_torchao_experimental_mps else "OFF" + build_mps_ops = "ON" if build_cpp_mps else "OFF" subprocess.check_call( [ @@ -214,6 +214,9 @@ def __init__(self, name, sourcedir=""): def get_extensions(): + if not build_cpp: + return [] + debug_mode = use_debug_mode() if debug_mode: print("Compiling in debug mode") @@ -309,13 +312,12 @@ def get_extensions(): ) ) - if build_torchao_experimental: - ext_modules.append( - CMakeExtension( - "torchao.experimental", - sourcedir="torchao/experimental", - ) + ext_modules.append( + CMakeExtension( + "torchao.experimental", + sourcedir="torchao/experimental", ) + ) return ext_modules