@@ -91,9 +91,9 @@ def __init__(self):
            default=(self._is_arm64() and self._is_macos()),
        )
        if self.build_cpu_aarch64:
-            assert (
-                self._is_arm64()
-            ), "TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine"
+            assert self._is_arm64(), (
+                "TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine"
+            )

        # TORCHAO_BUILD_KLEIDIAI is disabled by default for now because
        # 1) It increases the build time
@@ -102,9 +102,9 @@ def __init__(self):
            "TORCHAO_BUILD_KLEIDIAI", default=False
        )
        if self.build_kleidi_ai:
-            assert (
-                self.build_cpu_aarch64
-            ), "TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set"
+            assert self.build_cpu_aarch64, (
+                "TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set"
+            )

        # TORCHAO_BUILD_EXPERIMENTAL_MPS is disabled by default.
        self.build_experimental_mps = self._os_bool_var(
@@ -113,9 +113,9 @@ def __init__(self):
        if self.build_experimental_mps:
            assert self._is_macos(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MacOS"
            assert self._is_arm64(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires arm64"
-            assert (
-                torch.mps.is_available()
-            ), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available"
+            assert torch.mps.is_available(), (
+                "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available"
+            )

    def _is_arm64(self) -> bool:
        return platform.machine().startswith("arm64")
@@ -341,6 +341,7 @@ def get_extensions():
    hip_sources = list(
        glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
    )
+
    extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
    hip_sources += list(
        glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
@@ -349,6 +350,16 @@ def get_extensions():
    # Collect CUDA source files if needed
    if not IS_ROCM and use_cuda:
        sources += cuda_sources
+    elif IS_ROCM and use_cuda:
+        # Add ROCm GPU architecture check
+        gpu_arch = torch.cuda.get_device_properties(0).gcnArchName
+        if "gfx942" not in gpu_arch:
+            print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
+            print(
+                "Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
+            )
+        else:
+            sources += hip_sources
    else:
        # Remove CUTLASS-based kernels from the cuda_sources list. An
        # assumption is that these files will have "cutlass" in its
@@ -360,18 +371,6 @@ def get_extensions():
        )
        sources = [s for s in sources if s not in cutlass_sources]

-    # TOOD: Remove this and use what CUDA has once we fix all the builds.
-    if IS_ROCM and use_cuda:
-        # Add ROCm GPU architecture check
-        gpu_arch = torch.cuda.get_device_properties(0).gcnArchName
-        if "gfx942" not in gpu_arch:
-            print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
-            print(
-                "Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
-            )
-        else:
-            sources += hip_sources
-
    ext_modules = []
    if len(sources) > 0:
        ext_modules.append(
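
For readers skimming the diff, the net effect on source selection in get_extensions() is easier to see as a condensed sketch. This is not the actual setup.py: IS_ROCM, use_cuda, and the source lists below are placeholder values, and rocm_gpu_arch() is a hypothetical stand-in for torch.cuda.get_device_properties(0).gcnArchName so the snippet runs without a GPU.

    # Condensed sketch of the source-selection flow after this change
    # (assumed placeholders; not the real setup.py).
    IS_ROCM = False      # stand-in for the real ROCm detection
    use_cuda = True      # stand-in for the real CUDA/ROCm toolchain check
    sources, cuda_sources, hip_sources = [], ["a.cu"], ["a_hip.cu"]

    def rocm_gpu_arch() -> str:
        # Hypothetical helper standing in for
        # torch.cuda.get_device_properties(0).gcnArchName.
        return "gfx942"

    if not IS_ROCM and use_cuda:
        # Plain CUDA build: take all CUDA kernels.
        sources += cuda_sources
    elif IS_ROCM and use_cuda:
        # ROCm build: only gfx942 gets the HIP kernels; anything else warns
        # and skips them, mirroring the added elif branch above.
        gpu_arch = rocm_gpu_arch()
        if "gfx942" not in gpu_arch:
            print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
            print("Currently only gfx942 is supported. Skipping compilation of ROCm extensions")
        else:
            sources += hip_sources
    else:
        # CPU-only build: the CUTLASS-based kernels are filtered out in the
        # else branch of the real code.
        pass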