Skip to content

Commit 16d22c1

Browse files
committed
Improve CUTLASS kernel support detection for non-Windows platforms
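Tighten the condition that enables CUTLASS kernels in setup.py: CUTLASS support is now turned on only when building with CUDA on a platform that is neither ROCm nor Windows (`use_cuda and not IS_ROCM and not IS_WINDOWS`), where previously only Windows was excluded. This commit also removes the block in `get_extensions()` that globbed the `.cpp` and `.cu` sources under `torchao/csrc` ahead of the CUTLASS check.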
1 parent 76d68bf commit 16d22c1

File tree

1 file changed

+1
-13
lines changed

1 file changed

+1
-13
lines changed

setup.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -292,20 +292,8 @@ def get_extensions():
292292
extra_compile_args["nvcc"].append("-g")
293293
extra_link_args.append("/DEBUG")
294294

295-
curdir = os.path.dirname(os.path.curdir)
296-
extensions_dir = os.path.join(curdir, "torchao", "csrc")
297-
sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))
298-
299-
extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
300-
cuda_sources = list(
301-
glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
302-
)
303-
304-
if use_cuda:
305-
sources += cuda_sources
306-
307295
use_cutlass = False
308-
if use_cuda and not IS_WINDOWS:
296+
if use_cuda and not IS_ROCM and not IS_WINDOWS:
309297
use_cutlass = True
310298
cutlass_dir = os.path.join(third_party_path, "cutlass")
311299
cutlass_include_dir = os.path.join(cutlass_dir, "include")

0 commit comments

Comments
 (0)