Skip to content

Commit 76d68bf

Browse files
committed
Refactor CUDA and ROCm source file handling in setup.py
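Reorganize the source file selection logic for CUDA and ROCm builds in `get_extensions()`. The ROCm GPU architecture check (only `gfx942` is supported) and the addition of the HIP sources move out of a separate trailing `if IS_ROCM and use_cuda:` block into an `elif IS_ROCM and use_cuda:` branch directly after the CUDA branch, so the removal of CUTLASS-based kernels now lives only in the final `else` branch. Note that, as a consequence, this `else` branch is no longer executed for ROCm GPU builds, whereas before it ran whenever `not IS_ROCM and use_cuda` was false. The commit also reformats three multi-line `assert` statements in `__init__` so that the condition stays on the `assert` line and the message is parenthesized; this is a formatting-only change intended to improve the readability of the build configuration.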
1 parent 3a77641 commit 76d68bf

File tree

1 file changed

+20
-21
lines changed

1 file changed

+20
-21
lines changed

setup.py

Lines changed: 20 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -91,9 +91,9 @@ def __init__(self):
9191
default=(self._is_arm64() and self._is_macos()),
9292
)
9393
if self.build_cpu_aarch64:
94-
assert (
95-
self._is_arm64()
96-
), "TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine"
94+
assert self._is_arm64(), (
95+
"TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine"
96+
)
9797

9898
# TORCHAO_BUILD_KLEIDIAI is disabled by default for now because
9999
# 1) It increases the build time
@@ -102,9 +102,9 @@ def __init__(self):
102102
"TORCHAO_BUILD_KLEIDIAI", default=False
103103
)
104104
if self.build_kleidi_ai:
105-
assert (
106-
self.build_cpu_aarch64
107-
), "TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set"
105+
assert self.build_cpu_aarch64, (
106+
"TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set"
107+
)
108108

109109
# TORCHAO_BUILD_EXPERIMENTAL_MPS is disabled by default.
110110
self.build_experimental_mps = self._os_bool_var(
@@ -113,9 +113,9 @@ def __init__(self):
113113
if self.build_experimental_mps:
114114
assert self._is_macos(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MacOS"
115115
assert self._is_arm64(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires arm64"
116-
assert (
117-
torch.mps.is_available()
118-
), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available"
116+
assert torch.mps.is_available(), (
117+
"TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available"
118+
)
119119

120120
def _is_arm64(self) -> bool:
121121
return platform.machine().startswith("arm64")
@@ -341,6 +341,7 @@ def get_extensions():
341341
hip_sources = list(
342342
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
343343
)
344+
344345
extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
345346
hip_sources += list(
346347
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
@@ -349,6 +350,16 @@ def get_extensions():
349350
# Collect CUDA source files if needed
350351
if not IS_ROCM and use_cuda:
351352
sources += cuda_sources
353+
elif IS_ROCM and use_cuda:
354+
# Add ROCm GPU architecture check
355+
gpu_arch = torch.cuda.get_device_properties(0).gcnArchName
356+
if "gfx942" not in gpu_arch:
357+
print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
358+
print(
359+
"Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
360+
)
361+
else:
362+
sources += hip_sources
352363
else:
353364
# Remove CUTLASS-based kernels from the cuda_sources list. An
354365
# assumption is that these files will have "cutlass" in its
@@ -360,18 +371,6 @@ def get_extensions():
360371
)
361372
sources = [s for s in sources if s not in cutlass_sources]
362373

363-
# TOOD: Remove this and use what CUDA has once we fix all the builds.
364-
if IS_ROCM and use_cuda:
365-
# Add ROCm GPU architecture check
366-
gpu_arch = torch.cuda.get_device_properties(0).gcnArchName
367-
if "gfx942" not in gpu_arch:
368-
print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
369-
print(
370-
"Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
371-
)
372-
else:
373-
sources += hip_sources
374-
375374
ext_modules = []
376375
if len(sources) > 0:
377376
ext_modules.append(

0 commit comments

Comments
 (0)