@@ -292,25 +292,6 @@ def get_extensions():
292
292
extra_compile_args ["nvcc" ].append ("-g" )
293
293
extra_link_args .append ("/DEBUG" )
294
294
295
- use_cutlass = False
296
- if use_cuda and not IS_ROCM and not IS_WINDOWS :
297
- use_cutlass = True
298
- cutlass_dir = os .path .join (third_party_path , "cutlass" )
299
- cutlass_include_dir = os .path .join (cutlass_dir , "include" )
300
- cutlass_tools_include_dir = os .path .join (
301
- cutlass_dir , "tools" , "util" , "include"
302
- )
303
- cutlass_extensions_include_dir = os .path .join (cwd , extensions_cuda_dir )
304
- if use_cutlass :
305
- extra_compile_args ["nvcc" ].extend (
306
- [
307
- "-DTORCHAO_USE_CUTLASS" ,
308
- "-I" + cutlass_include_dir ,
309
- "-I" + cutlass_tools_include_dir ,
310
- "-I" + cutlass_extensions_include_dir ,
311
- ]
312
- )
313
-
314
295
# Get base directory and source paths
315
296
curdir = os .path .dirname (os .path .curdir )
316
297
extensions_dir = os .path .join (curdir , "torchao" , "csrc" )
@@ -335,6 +316,25 @@ def get_extensions():
335
316
glob .glob (os .path .join (extensions_hip_dir , "*.cu" ), recursive = True )
336
317
)
337
318
319
+ use_cutlass = False
320
+ if use_cuda and not IS_ROCM and not IS_WINDOWS :
321
+ use_cutlass = True
322
+ cutlass_dir = os .path .join (third_party_path , "cutlass" )
323
+ cutlass_include_dir = os .path .join (cutlass_dir , "include" )
324
+ cutlass_tools_include_dir = os .path .join (
325
+ cutlass_dir , "tools" , "util" , "include"
326
+ )
327
+ cutlass_extensions_include_dir = os .path .join (cwd , extensions_cuda_dir )
328
+ if use_cutlass :
329
+ extra_compile_args ["nvcc" ].extend (
330
+ [
331
+ "-DTORCHAO_USE_CUTLASS" ,
332
+ "-I" + cutlass_include_dir ,
333
+ "-I" + cutlass_tools_include_dir ,
334
+ "-I" + cutlass_extensions_include_dir ,
335
+ ]
336
+ )
337
+
338
338
# Collect CUDA source files if needed
339
339
if not IS_ROCM and use_cuda :
340
340
sources += cuda_sources
0 commit comments