3
3
# This source code is licensed under the license found in the
4
4
# LICENSE file in the root directory of this source tree.
5
5
6
+ import copy
6
7
import glob
7
8
import os
8
9
import subprocess
@@ -75,6 +76,7 @@ def use_debug_mode():
75
76
BuildExtension ,
76
77
CppExtension ,
77
78
CUDAExtension ,
79
+ _get_cuda_arch_flags ,
78
80
)
79
81
80
82
IS_ROCM = (torch .version .hip is not None ) and (ROCM_HOME is not None )
@@ -269,7 +271,12 @@ def get_extensions():
269
271
extra_link_args = []
270
272
extra_compile_args = {
271
273
"cxx" : [f"-DPy_LIMITED_API={ PY3_9_HEXCODE } " ],
272
- "nvcc" : ["-O3" if not debug_mode else "-O0" , "-t=0" , "-std=c++17" ],
274
+ "nvcc" : [
275
+ "-DNDEBUG" if not debug_mode else "-DDEBUG" ,
276
+ "-O3" if not debug_mode else "-O0" ,
277
+ "-t=0" ,
278
+ "-std=c++17" ,
279
+ ],
273
280
}
274
281
275
282
if not IS_WINDOWS :
@@ -304,25 +311,6 @@ def get_extensions():
304
311
if use_cuda :
305
312
sources += cuda_sources
306
313
307
- use_cutlass = False
308
- if use_cuda and not IS_WINDOWS :
309
- use_cutlass = True
310
- cutlass_dir = os .path .join (third_party_path , "cutlass" )
311
- cutlass_include_dir = os .path .join (cutlass_dir , "include" )
312
- cutlass_tools_include_dir = os .path .join (
313
- cutlass_dir , "tools" , "util" , "include"
314
- )
315
- cutlass_extensions_include_dir = os .path .join (cwd , extensions_cuda_dir )
316
- if use_cutlass :
317
- extra_compile_args ["nvcc" ].extend (
318
- [
319
- "-DTORCHAO_USE_CUTLASS" ,
320
- "-I" + cutlass_include_dir ,
321
- "-I" + cutlass_tools_include_dir ,
322
- "-I" + cutlass_extensions_include_dir ,
323
- ]
324
- )
325
-
326
314
# Get base directory and source paths
327
315
curdir = os .path .dirname (os .path .curdir )
328
316
extensions_dir = os .path .join (curdir , "torchao" , "csrc" )
@@ -349,16 +337,6 @@ def get_extensions():
349
337
# Collect CUDA source files if needed
350
338
if not IS_ROCM and use_cuda :
351
339
sources += cuda_sources
352
- else :
353
- # Remove CUTLASS-based kernels from the cuda_sources list. An
354
- # assumption is that these files will have "cutlass" in its
355
- # name.
356
- cutlass_sources = list (
357
- glob .glob (
358
- os .path .join (extensions_cuda_dir , "**/*cutlass*.cu" ), recursive = True
359
- )
360
- )
361
- sources = [s for s in sources if s not in cutlass_sources ]
362
340
363
341
# TOOD: Remove this and use what CUDA has once we fix all the builds.
364
342
if IS_ROCM and use_cuda :
@@ -372,6 +350,72 @@ def get_extensions():
372
350
else :
373
351
sources += hip_sources
374
352
353
+ use_cutlass = False
354
+ cutlass_90a_sources = None
355
+ if use_cuda and not IS_ROCM and not IS_WINDOWS :
356
+ use_cutlass = True
357
+ cutlass_dir = os .path .join (third_party_path , "cutlass" )
358
+ cutlass_include_dir = os .path .join (cutlass_dir , "include" )
359
+ cutlass_tools_include_dir = os .path .join (
360
+ cutlass_dir , "tools" , "util" , "include"
361
+ )
362
+ cutlass_extensions_include_dir = os .path .join (cwd , extensions_cuda_dir )
363
+ if use_cutlass :
364
+ extra_compile_args ["nvcc" ].extend (
365
+ [
366
+ "-DTORCHAO_USE_CUTLASS" ,
367
+ "-I" + cutlass_include_dir ,
368
+ "-I" + cutlass_tools_include_dir ,
369
+ "-I" + cutlass_extensions_include_dir ,
370
+ "-DCUTE_USE_PACKED_TUPLE=1" ,
371
+ "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED" ,
372
+ "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1" ,
373
+ "-DCUTLASS_DEBUG_TRACE_LEVEL=0" ,
374
+ "--ftemplate-backtrace-limit=0" ,
375
+ # "--keep",
376
+ # "--ptxas-options=--verbose,--register-usage-level=5,--warn-on-local-memory-usage",
377
+ # "--resource-usage",
378
+ # "-lineinfo",
379
+ # "-DCUTLASS_ENABLE_GDC_FOR_SM90", # https://github.com/NVIDIA/cutlass/blob/main/media/docs/dependent_kernel_launch.md
380
+ ]
381
+ )
382
+
383
+ cuda_arch_flags = _get_cuda_arch_flags ()
384
+ build_for_sm90 = "-gencode=arch=compute_90,code=sm_90" in cuda_arch_flags
385
+ build_for_sm90a = "-gencode=arch=compute_90a,code=sm_90a" in cuda_arch_flags
386
+ if build_for_sm90 and not build_for_sm90a :
387
+ cutlass_90a_sources = [
388
+ os .path .join (
389
+ extensions_cuda_dir ,
390
+ "rowwise_scaled_linear_sparse_cutlass" ,
391
+ "rowwise_scaled_linear_sparse_cutlass_f8f8.cu" ,
392
+ ),
393
+ os .path .join (
394
+ extensions_cuda_dir ,
395
+ "to_sparse_semi_structured_cutlass_sm9x" ,
396
+ "to_sparse_semi_structured_cutlass_sm9x_f8.cu" ,
397
+ ),
398
+ ]
399
+ for dtypes in ["e4m3e4m3" , "e4m3e5m2" , "e5m2e4m3" , "e5m2e5m2" ]:
400
+ cutlass_90a_sources .append (
401
+ os .path .join (
402
+ extensions_cuda_dir ,
403
+ "rowwise_scaled_linear_sparse_cutlass" ,
404
+ "rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu" ,
405
+ )
406
+ )
407
+ sources = [s for s in sources if s not in cutlass_90a_sources ]
408
+ else :
409
+ # Remove CUTLASS-based kernels from the sources list. An
410
+ # assumption is that these files will have "cutlass" in its
411
+ # name.
412
+ cutlass_sources = list (
413
+ glob .glob (
414
+ os .path .join (extensions_cuda_dir , "**/*cutlass*.cu" ), recursive = True
415
+ )
416
+ )
417
+ sources = [s for s in sources if s not in cutlass_sources ]
418
+
375
419
ext_modules = []
376
420
if len (sources ) > 0 :
377
421
ext_modules .append (
@@ -384,6 +428,21 @@ def get_extensions():
384
428
)
385
429
)
386
430
431
+ if cutlass_90a_sources is not None and len (cutlass_90a_sources ) > 0 :
432
+ cutlass_90a_extra_compile_args = copy .deepcopy (extra_compile_args )
433
+ cutlass_90a_extra_compile_args ["nvcc" ].extend (
434
+ cuda_arch_flags + ["-gencode=arch=compute_90a,code=sm_90a" ]
435
+ )
436
+ ext_modules .append (
437
+ extension (
438
+ "torchao._C" ,
439
+ cutlass_90a_sources ,
440
+ py_limited_api = True ,
441
+ extra_compile_args = cutlass_90a_extra_compile_args ,
442
+ extra_link_args = extra_link_args ,
443
+ )
444
+ )
445
+
387
446
if build_torchao_experimental :
388
447
build_options = BuildOptions ()
389
448
0 commit comments