Skip to content

Commit 7481959

Browse files
committed
Reorder source file collection in setup.py
Move source file collection logic to maintain consistent code organization and improve readability of the build configuration. No functional changes were made to the source file selection process.
1 parent 16d22c1 commit 7481959

File tree

1 file changed

+19
-19
lines changed

1 file changed

+19
-19
lines changed

setup.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -292,25 +292,6 @@ def get_extensions():
292292
extra_compile_args["nvcc"].append("-g")
293293
extra_link_args.append("/DEBUG")
294294

295-
use_cutlass = False
296-
if use_cuda and not IS_ROCM and not IS_WINDOWS:
297-
use_cutlass = True
298-
cutlass_dir = os.path.join(third_party_path, "cutlass")
299-
cutlass_include_dir = os.path.join(cutlass_dir, "include")
300-
cutlass_tools_include_dir = os.path.join(
301-
cutlass_dir, "tools", "util", "include"
302-
)
303-
cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir)
304-
if use_cutlass:
305-
extra_compile_args["nvcc"].extend(
306-
[
307-
"-DTORCHAO_USE_CUTLASS",
308-
"-I" + cutlass_include_dir,
309-
"-I" + cutlass_tools_include_dir,
310-
"-I" + cutlass_extensions_include_dir,
311-
]
312-
)
313-
314295
# Get base directory and source paths
315296
curdir = os.path.dirname(os.path.curdir)
316297
extensions_dir = os.path.join(curdir, "torchao", "csrc")
@@ -335,6 +316,25 @@ def get_extensions():
335316
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
336317
)
337318

319+
use_cutlass = False
320+
if use_cuda and not IS_ROCM and not IS_WINDOWS:
321+
use_cutlass = True
322+
cutlass_dir = os.path.join(third_party_path, "cutlass")
323+
cutlass_include_dir = os.path.join(cutlass_dir, "include")
324+
cutlass_tools_include_dir = os.path.join(
325+
cutlass_dir, "tools", "util", "include"
326+
)
327+
cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir)
328+
if use_cutlass:
329+
extra_compile_args["nvcc"].extend(
330+
[
331+
"-DTORCHAO_USE_CUTLASS",
332+
"-I" + cutlass_include_dir,
333+
"-I" + cutlass_tools_include_dir,
334+
"-I" + cutlass_extensions_include_dir,
335+
]
336+
)
337+
338338
# Collect CUDA source files if needed
339339
if not IS_ROCM and use_cuda:
340340
sources += cuda_sources

0 commit comments

Comments
 (0)